Model Compression Using Spatial SVD

This notebook shows a working code example of how to use AIMET to perform model compression. The Spatial SVD technique is used in this notebook to achieve model compression.

Here is a brief introduction to the techniques. Please refer to the AIMET user guide for more details.

  1. Spatial SVD: This is a tensor-decomposition technique generally applied to convolutional layers (Conv2D). Applying this technique decomposes a single convolutional layer into two. The weight tensor of the layer to be split is flattened to a 2D matrix and singular value decomposition (SVD) is applied to this matrix. Compression is achieved by discarding the least significant singular values in the diagonal matrix. The decomposed matrices are then reshaped into the weights of two separate convolutional layers.

  2. Channel Pruning: In this technique AIMET discards the least significant input channels (by a magnitude metric) of a given convolutional (Conv2D) layer. The layers of the model feeding into this convolutional layer also have their channel dimension modified to get back to a working graph. This technique also uses a layer-by-layer reconstruction procedure that modifies the weights of the compressed layers to minimize the distance between the compressed layer output and the corresponding layer output of the original model.

Both of the above techniques are structured pruning techniques that aim to reduce the computational cost (MACs) or the memory requirements of the model. After applying either of these techniques, the compressed model needs to be fine-tuned (meaning trained again for a few epochs) to recover accuracy close to that of the original model.
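As a rough illustration of the Spatial SVD description above, the following NumPy sketch splits one kernel into a (kh x 1) kernel followed by a (1 x kw) kernel. The function names, the (out, in, kh, kw) weight layout and the way the singular values are shared between the two factors are assumptions made for this example; it is not AIMET's implementation, which may differ in these details.

```python
import numpy as np


def spatial_svd_split(weight, rank):
    """Split a (out, in, kh, kw) kernel into (rank, in, kh, 1) and (out, rank, 1, kw)."""
    n, m, kh, kw = weight.shape
    # Flatten to a 2D matrix with rows indexed by (in, kh) and columns by (out, kw).
    mat = weight.transpose(1, 2, 0, 3).reshape(m * kh, n * kw)
    u, s, vt = np.linalg.svd(mat, full_matrices=False)
    # Keep the `rank` largest singular values and share them between both factors.
    sqrt_s = np.sqrt(s[:rank])
    u_r = u[:, :rank] * sqrt_s
    v_r = vt[:rank, :] * sqrt_s[:, None]
    first = u_r.reshape(m, kh, rank).transpose(2, 0, 1)[..., None]
    second = v_r.reshape(rank, n, kw).transpose(1, 0, 2)[:, :, None, :]
    return first, second


def recombine(first, second):
    """Kernel that is equivalent to applying `first` and then `second`."""
    return np.einsum('rmi,nrj->nmij', first[..., 0], second[:, :, 0, :])


rng = np.random.default_rng(0)
weight = rng.standard_normal((8, 4, 3, 3))        # out=8, in=4, 3x3 kernel
first, second = spatial_svd_split(weight, rank=6)
print(first.shape, second.shape)                  # shapes of the two new kernels
print(weight.size, first.size + second.size)      # parameter count before and after
```

Choosing a smaller rank discards more singular values, which lowers the cost of the two new layers but increases the approximation error that fine-tuning later has to recover.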

This notebook shows a working code example of how technique #1 can be used to compress the model. You can find separate notebooks for #2, and for #1 followed by #2, in the same folder.

Overall flow

This notebook covers the following:

  1. Instantiate the example evaluation and training pipeline
  2. Load the model and evaluate it to find the baseline accuracy
  3. Compress the model and fine-tune:
     3.1 Compress model using Spatial SVD and evaluate it to find post-compression accuracy
     3.2 Fine-tune the model

What this notebook is not

  • This notebook is not designed to show state-of-the-art compression results. For example, some optimization parameters such as num_comp_ratio_candidates, num_eval_iterations and epochs are deliberately chosen to have the notebook execute more quickly.


Dataset

This notebook relies on the ImageNet dataset for the task of image classification. If you already have a version of the dataset readily available, please use that. Otherwise, please download the dataset from an appropriate location (e.g. https://image-net.org/challenges/LSVRC/2012/index.php#) and convert it into tfrecords.

Note1: The ImageNet tfrecords dataset typically has the following characteristics, and the dataloader provided in this example notebook relies on them: a folder containing tfrecords files whose names start with ‘train*’ for training files and ‘valid*’ for validation files. Each tfrecord file should have the features ‘image/encoded’ for image data and ‘image/class/label’ for the corresponding class.

Note2: To speed up the execution of this notebook, you may use a reduced subset of the ImageNet dataset. For example, the entire ILSVRC2012 dataset has 1000 classes, 1000 training samples per class and 50 validation samples per class. For the purpose of running this notebook, you could reduce the dataset to, say, 2 samples per class and then convert it into tfrecords. This exercise is left up to the reader and is not necessary.

Edit the cell below and specify the directory where the downloaded ImageNet dataset is saved.

[ ]:
TFRECORDS_DIR = '/path/to/tfrecords/dir/'        # Please replace this with a real directory
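Before running the rest of the notebook, it can help to confirm that the directory follows the naming convention from Note1. The helper below is a hypothetical convenience written for this example (it is not part of AIMET or its Examples package) and only checks file names, not the tfrecord contents.

```python
import glob
import os


def list_tfrecords(tfrecords_dir):
    """Return (train_files, valid_files) found directly under tfrecords_dir."""
    train_files = sorted(glob.glob(os.path.join(tfrecords_dir, 'train*')))
    valid_files = sorted(glob.glob(os.path.join(tfrecords_dir, 'valid*')))
    return train_files, valid_files


# The path below is the same placeholder used above; replace it with a real directory.
train_files, valid_files = list_tfrecords('/path/to/tfrecords/dir/')
print(len(train_files), 'training files,', len(valid_files), 'validation files')
```

If either count is zero, the data loaders used later in this notebook will most likely have nothing to read.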

We disable logs at the INFO level and disable eager execution. We set the verbosity to ERROR, so TensorFlow will only display messages labelled ERROR (or more critical).

[ ]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # Hide INFO and WARNING messages from the TensorFlow C++ backend

import tensorflow.compat.v1 as tf         # TF1-style API (sessions and graphs)
tf.disable_eager_execution()
tf.logging.set_verbosity(tf.logging.ERROR)

1. Example evaluation and training pipeline

The following is an example training and validation loop for this image classification task.

  • Does AIMET have any limitations on how the training and validation pipeline is written? Not really. We will see later that AIMET modifies the user’s model to compress it, and the resulting model is still a regular TensorFlow session. This compressed model can be used in place of the original model when doing inference or training.

  • Does AIMET put any limitation on the interface of the evaluate() or train() methods? Not really, but the evaluate() method should return a single number representing the accuracy of the model. Ideally, you should be able to use your existing evaluate and train routines as-is.

[ ]:
from typing import List
from Examples.common import image_net_config
from Examples.tensorflow.utils.image_net_evaluator import ImageNetDataLoader
from Examples.tensorflow.utils.image_net_evaluator import ImageNetEvaluator
from Examples.tensorflow.utils.image_net_trainer import ImageNetTrainer

class ImageNetDataPipeline:
    """
    Provides APIs for model evaluation and finetuning using ImageNet Dataset.
    """

    @staticmethod
    def get_val_dataloader():
        """
        Instantiates a validation dataloader for ImageNet dataset and returns it
        """
        data_loader = ImageNetDataLoader(TFRECORDS_DIR,
                                         image_size=image_net_config.dataset['image_size'],
                                         batch_size=image_net_config.evaluation['batch_size'],
                                         format_bgr=True)

        return data_loader

    @staticmethod
    def evaluate(sess: tf.Session, iterations: int = None, use_cuda: bool = False) -> float:
        """
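        Given a TF session, evaluates its Top-1 accuracy on the validation dataset
        :param sess: The session whose graph is to be evaluated.
        :param iterations: The number of batches, passed through to the evaluator.
        :param use_cuda: Not used here; accepted so the signature matches AIMET's eval callback.
        :return: The Top-1 accuracy on the validation dataset.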
        """
        evaluator = ImageNetEvaluator(TFRECORDS_DIR, training_inputs=['keras_learning_phase:0'],
                                      data_inputs=['input_1:0'], validation_inputs=['labels:0'],
                                      image_size=image_net_config.dataset['image_size'],
                                      batch_size=image_net_config.evaluation['batch_size'],
                                      format_bgr=True)

        return evaluator.evaluate(sess, iterations)


    @staticmethod
    def finetune(sess: tf.Session, update_ops_name: List[str], epochs: int, learning_rate: float, decay_steps: int):
        """
        Given a TF session, finetunes it to improve its accuracy
        :param sess: The sess graph to fine-tune.
        :param update_ops_name: list of name of update ops (mostly BatchNorms' moving averages).
                                tf.GraphKeys.UPDATE_OPS collections is always used
                                in addition to this list
        :param epochs: The number of epochs used during the finetuning step.
        :param learning_rate: The learning rate used during the finetuning step.
        :param decay_steps: A number used to adjust(decay) the learning rate after every decay_steps epochs in training.
        """
        trainer = ImageNetTrainer(TFRECORDS_DIR, training_inputs=['keras_learning_phase:0'],
                                  data_inputs=['input_1:0'], validation_inputs=['labels:0'],
                                  image_size=image_net_config.dataset['image_size'],
                                  batch_size=image_net_config.train['batch_size'],
                                  num_epochs=epochs, format_bgr=True)

        trainer.train(sess, update_ops_name=update_ops_name, learning_rate=learning_rate, decay_steps=decay_steps)
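The main requirement on evaluate() stated above is that it returns a single accuracy number. As a framework-free sketch of that contract, the function below computes Top-1 accuracy over a limited number of batches. The names predict and batches are hypothetical stand-ins for a real model and data loader; this is not how ImageNetEvaluator is implemented.

```python
def top1_accuracy(predict, batches, iterations=None):
    """Fraction of samples whose predicted label equals the true label.

    predict: hypothetical callable mapping a batch of inputs to predicted labels
    batches: iterable of (inputs, labels) pairs
    iterations: optional cap on the number of batches to evaluate
    """
    correct = total = 0
    for index, (inputs, labels) in enumerate(batches):
        if iterations is not None and index >= iterations:
            break
        predictions = predict(inputs)
        correct += sum(1 for p, l in zip(predictions, labels) if p == l)
        total += len(labels)
    return correct / total if total else 0.0


batches = [([1, 2, 3], [1, 1, 1]), ([4], [0])]
parity_model = lambda xs: [x % 2 for x in xs]   # toy stand-in for a real model
print(top1_accuracy(parity_model, batches, iterations=1))
print(top1_accuracy(parity_model, batches))
```

Capping the number of batches is what makes the evaluation cheap enough to call repeatedly during compression, at the price of a noisier accuracy estimate.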


2. Load the model and evaluate it to find the baseline accuracy

For this example notebook, we are going to load a pretrained ResNet50 model from Keras and convert it to a TensorFlow session. Similarly, you can load any pretrained TensorFlow model instead.

Calling clear_session() releases the global state: this helps avoid clutter from old models and layers, especially when memory is limited.

By default the update ops are placed in tf.GraphKeys.UPDATE_OPS, so they need to be added as a dependency to the train_op. Since batchnorm ops are folded, these need to be ignored during training.

[ ]:
from tensorflow.compat.v1.keras.applications.resnet import ResNet50

tf.keras.backend.clear_session()

model = ResNet50(weights='imagenet', input_shape=(224, 224, 3))
update_ops_name = [op.name for op in model.updates] # Used for finetuning

The following utility method in AIMET sets BN layers in the model to eval mode. This allows AIMET to more easily read the BN parameters from the graph. Eventually we will fold BN layers into adjacent conv layers.

[ ]:
from aimet_tensorflow.utils.graph import update_keras_bn_ops_trainable_flag

model = update_keras_bn_ops_trainable_flag(model, load_save_path="./", trainable=False)

AIMET features currently support tensorflow sessions. add_image_net_computational_nodes_in_graph adds an output layer, softmax and loss functions to the Resnet50 model graph.

[ ]:
from Examples.tensorflow.utils.add_computational_nodes_in_graph import add_image_net_computational_nodes_in_graph

sess = tf.keras.backend.get_session()

# Creates the computation graph of ResNet within the tensorflow session.
add_image_net_computational_nodes_in_graph(sess, model.output.name, image_net_config.dataset['images_classes'])

Since all tensorflow input and output tensors have names, we identify the tensors needed by AIMET APIs here.

[ ]:
input_op_names = [model.input.name.split(":")[0]]
output_op_names = [model.output.name.split(":")[0]]
starting_op_names = input_op_names.copy()
starting_op_names.append('labels')
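In TensorFlow 1.x graphs a tensor name has the form "<op_name>:<output_index>", which is why the cell above keeps only the part before the colon. The tiny helper below is a hypothetical name used for illustration; it shows the same string handling on the 'input_1:0' tensor name that appears in the data pipeline above.

```python
def tensor_name_to_op_name(tensor_name):
    """Drop the ':<output_index>' suffix from a TF1-style tensor name."""
    return tensor_name.split(":")[0]


print(tensor_name_to_op_name("input_1:0"))
print(tensor_name_to_op_name("labels"))      # names without a suffix are unchanged
```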

We check whether TensorFlow is using a CPU or a CUDA device. This example code will use CUDA if it is available in your current execution environment.

[ ]:
use_cuda = tf.test.is_gpu_available(cuda_only=True)

Let’s determine the FP32 (floating point 32-bit) accuracy of this model using the evaluate() routine

[ ]:
accuracy = ImageNetDataPipeline.evaluate(sess=sess)
print(accuracy)

3. Compress the model and fine-tune

3.1. Compress model using Spatial SVD and evaluate it to find post-compression accuracy

Now we use AIMET to define compression parameters for Spatial SVD, a few of which are explained here.

  • target_comp_ratio: The desired compression ratio for Spatial SVD, expressed as the fraction of the original model that should remain after compression. To compress the model to 20% of its original size, use 0.2; this compresses the model by 80%. We are using 0.5 here, which compresses the model by 50%.

  • num_comp_ratio_candidates: As part of determining how compressible each layer is, AIMET performs various measurements. This number denotes how many different compression ratios AIMET tries for each layer. We are using 2 here, which translates to compression ratios of 0.50 and 1.00 at each layer. A typical value is 10. The higher the number of candidates, the more granular the measurements for each layer, but also the longer it takes to complete these measurements.

  • modules_to_ignore: This list can contain the references of model-layers that should be ignored during compression. We have added the first layer to be ignored to preserve the way the input interacts with the model; other layers can be added too if desired.

  • mode: We are choosing Auto mode, which means AIMET performs per-layer compressibility analysis and determines how much to compress each layer. The alternative is Manual mode.

  • data_loader: Used by Channel Pruning only, which needs unlabelled data samples for its layer-by-layer reconstruction procedure. You can pass an existing data loader, say for the validation or training dataset. It is not part of the Spatial SVD parameters created in this notebook.

  • num_reconstruction_samples: Used by Channel Pruning only. During the last stage of Channel Pruning, the Compression API tries to match the outputs of the pruned model to those of the original model through linear regression, and uses the result to change the weights in the pruned layer. The regression is done with this many random samples; a typical setting is ~1000 samples. It is not part of the Spatial SVD parameters created in this notebook.

  • allow_custom_downsample_ops: Used by Channel Pruning only. If this flag is enabled, AIMET Channel Pruning will insert downsample ops into the model graph if needed. Enabling this allows more convolutional layers to be considered for pruning, but the additional downsample layers may increase memory bandwidth overhead, so there is a trade-off to consider. We suggest disabling this by default. It is not part of the Spatial SVD parameters created in this notebook.

  • eval_callback: The model evaluation function. The expected signature of the evaluate function is <function_name>(sess, eval_iterations, use_cuda), and it is expected to return an accuracy metric.

  • eval_iterations: The number of batches of data to use for evaluating the model while it is being compressed. We are using 1 to speed up the notebook execution, but please choose a number of batches high enough that the measured accuracy can be trusted. The eval callback is expected to use the same samples on every invocation.

  • compress_scheme: We choose the ‘spatial svd’ compression scheme.

  • cost_metric: Determines whether the desired compression ratio targets a reduction in MACs or in memory. We are choosing ‘mac’ here.

The next cell creates the actual parameters for Spatial SVD. There are two modes for which you can choose parameters: Auto and Manual. For Auto, the only option is a greedy selection scheme, where the compression ratio for each layer is selected from a set list of candidates so that the model reaches the target ratio (specified through target_comp_ratio). For Manual, you have to specify the compression ratio for each layer; a general rule of thumb is to start from the ratios found by Auto mode.

[ ]:
from decimal import Decimal
from aimet_common.defs import CompressionScheme, CostMetric, GreedySelectionParameters
from aimet_tensorflow.defs import SpatialSvdParameters

greedy_params = GreedySelectionParameters(target_comp_ratio=Decimal(0.5),
                                          num_comp_ratio_candidates=2)

modules_to_ignore = [sess.graph.get_operation_by_name('conv1_conv/Conv2D')]
auto_params = SpatialSvdParameters.AutoModeParams(greedy_select_params=greedy_params,
                                                  modules_to_ignore=modules_to_ignore)

params = SpatialSvdParameters(input_op_names=starting_op_names,
                              output_op_names=output_op_names,
                              mode=SpatialSvdParameters.Mode.auto,
                              params=auto_params)


eval_callback = ImageNetDataPipeline.evaluate
eval_iterations = 1
compress_scheme =  CompressionScheme.spatial_svd
cost_metric = CostMetric.mac
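To make the Auto mode parameters more concrete, here is a simplified, framework-free sketch of a greedy selection over per-layer candidates. It is an illustration under stated assumptions, not AIMET's algorithm: the function names, the evenly spaced candidate list, the threshold search and the toy scores and costs are all invented for this example.

```python
from decimal import Decimal


def comp_ratio_candidates(num_candidates):
    """Evenly spaced ratios up to 1, e.g. 2 candidates give 0.5 and 1."""
    return [Decimal(i) / Decimal(num_candidates) for i in range(1, num_candidates + 1)]


def select_ratios(eval_scores, layer_costs, target_ratio):
    """Pick one ratio per layer so the cost-weighted overall ratio meets target_ratio.

    eval_scores: {layer: {ratio: accuracy measured with that layer at that ratio}}
    layer_costs: {layer: cost of the uncompressed layer, e.g. MACs}
    """
    total_cost = sum(layer_costs.values())
    # Try accuracy thresholds from strictest to most lenient.
    thresholds = sorted({s for scores in eval_scores.values() for s in scores.values()},
                        reverse=True)
    picks = {layer: Decimal(1) for layer in layer_costs}
    for threshold in thresholds:
        for layer, scores in eval_scores.items():
            passing = [ratio for ratio, score in scores.items() if score >= threshold]
            # Smallest passing ratio = most compression that still clears the threshold.
            picks[layer] = min(passing) if passing else Decimal(1)
        overall = sum(layer_costs[layer] * ratio for layer, ratio in picks.items()) / total_cost
        if overall <= target_ratio:
            break
    return picks


candidates = comp_ratio_candidates(2)
# Toy numbers: conv_a tolerates compression well, conv_b does not.
scores = {'conv_a': dict(zip(candidates, [0.70, 0.76])),
          'conv_b': dict(zip(candidates, [0.40, 0.76]))}
costs = {'conv_a': 100, 'conv_b': 300}
print(select_ratios(scores, costs, target_ratio=Decimal('0.9')))
```

The sketch shows why more candidates cost more time: every extra candidate means one more evaluation per layer before any ratio can be selected.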

We call the AIMET ModelCompressor.compress_model API using the above parameters. This call returns a compressed model as well as relevant statistics.

Note: the ModelCompressor evaluates the model while compressing using the same evaluate function that is in our data pipeline.

[ ]:
from aimet_tensorflow.compress import ModelCompressor

os.makedirs('./output/', exist_ok=True)

compressed_sess, comp_stats = ModelCompressor.compress_model(sess=sess,
                                                             working_dir="output",
                                                             eval_callback=ImageNetDataPipeline.evaluate,
                                                             eval_iterations=eval_iterations,
                                                             input_shape=(1, 3, 224, 224),
                                                             compress_scheme=compress_scheme,
                                                             cost_metric=cost_metric,
                                                             parameters=params)

print(comp_stats)
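Since cost_metric is set to MACs, it may help to see what that metric counts for a single layer. The sketch below uses the standard MAC count of a convolution and compares one layer with its two-layer Spatial SVD replacement. The helper names are hypothetical, and the example assumes stride 1 with the intermediate feature map having the same spatial size as the output; AIMET's own cost calculation is not shown in this notebook.

```python
def conv_macs(in_ch, out_ch, kh, kw, out_h, out_w):
    """Multiply-accumulate operations of one convolution layer."""
    return in_ch * out_ch * kh * kw * out_h * out_w


def spatial_svd_macs(in_ch, out_ch, kh, kw, out_h, out_w, rank):
    """MACs of a (kh x 1) conv into `rank` channels followed by a (1 x kw) conv."""
    first = conv_macs(in_ch, rank, kh, 1, out_h, out_w)
    second = conv_macs(rank, out_ch, 1, kw, out_h, out_w)
    return first + second


original = conv_macs(64, 64, 3, 3, 56, 56)
compressed = spatial_svd_macs(64, 64, 3, 3, 56, 56, rank=32)
print(original, compressed, compressed / original)
```

In this toy case a rank of 32 leaves about one third of the original MACs for that layer, which corresponds to a per-layer compression ratio of roughly 0.33.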

Now the compressed model is ready to be used for inference or training. First we pass this model to the same evaluation routine we used before to calculate the accuracy of the compressed model.

[ ]:
comp_accuracy = ImageNetDataPipeline.evaluate(compressed_sess)
print(comp_accuracy)

As you can see, the model accuracy fell sharply after compression. This is expected. We will use model fine-tuning to recover this accuracy.

3.2. Fine-tune the model

After the model is compressed using Spatial SVD, we can simply train the model for a few more epochs (typically 15-20). As with any training job, hyper-parameters need to be searched for optimal results. Good starting points are to use a learning rate on the same order as the ending learning rate when training the original model, and to drop the learning rate by a factor of 10 every 5 epochs or so.

For the purpose of this example notebook, we are going to train only for 1 epoch. But feel free to change these parameters as you see fit.
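The schedule suggested above (drop the learning rate by a factor of 10 every 5 epochs) can be written as a simple step decay. The function below is a hypothetical helper for illustration; how ImageNetTrainer applies decay_steps internally is not shown in this notebook and may differ.

```python
import math


def step_decay_lr(base_lr, epoch, drop_every=5, factor=0.1):
    """Learning rate after dropping by `factor` once every `drop_every` epochs."""
    return base_lr * factor ** (epoch // drop_every)


for epoch in (0, 4, 5, 12):
    print(epoch, step_decay_lr(1e-3, epoch))
```

With a base learning rate of 1e-3, epochs 0 to 4 train at about 1e-3, epochs 5 to 9 at about 1e-4, and so on.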

Compression can replace a few BN ops with new ones whose names are the original names prefixed with ‘reduced_’ (Channel Pruning does this), so the update_ops_name list should be updated accordingly before fine-tuning.

[ ]:
compr_graph_all_ops_name = [op.name for op in compressed_sess.graph.get_operations()]
update_ops_name_after_CP = []
for op_name in update_ops_name:
    if 'reduced_'+op_name in compr_graph_all_ops_name:
        update_ops_name_after_CP.append('reduced_'+ op_name)
    else:
        update_ops_name_after_CP.append(op_name)

ImageNetDataPipeline.finetune(compressed_sess, update_ops_name_after_CP, epochs=1, learning_rate=1e-3, decay_steps=5)

After we are done fine-tuning the compressed model, we can check the floating point accuracy against the same validation dataset to observe any improvement in accuracy.

[ ]:
accuracy = ImageNetDataPipeline.evaluate(compressed_sess)
print(accuracy)

Depending on your settings, you should have observed a slight gain in accuracy after one epoch of training. Of course, this was just an example. Please try this against the model of your choice and experiment with the hyper-parameters to get the best results.

So we have an improved model after compression using spatial SVD. Optionally, this model now can be saved like a regular tensorflow model.

[ ]:
from aimet_tensorflow.utils.graph_saver import save_model_to_meta

save_model_to_meta(compressed_sess, meta_path='./output/finetuned_model')

Summary

We hope this notebook was useful for understanding how to use AIMET to perform compression with Spatial SVD. As indicated above, some parameters have been chosen so that the example runs faster.

A few additional resources:

  • Refer to the AIMET API docs for more details on the APIs and optional parameters

  • Refer to the other example notebooks to understand how to use AIMET compression and quantization techniques