Model compression using Spatial SVD

This notebook shows a working code example of how to use AIMET to perform model compression, using the Spatial SVD technique.

Here is a brief introduction to the techniques. Please refer to the AIMET user guide for more details.

  1. Spatial SVD: This is a tensor-decomposition technique generally applied to convolutional layers (Conv2D). Applying this technique decomposes a single convolutional layer into two. The weight tensor of the layer to be split is flattened to a 2D matrix, and singular value decomposition (SVD) is applied to this matrix. Compression is achieved by discarding the least significant singular values in the diagonal matrix. The two decomposed matrices are then reshaped into two separate, smaller convolutional layers (a minimal PyTorch sketch of this decomposition appears below).

  2. Channel Pruning: In this technique, AIMET discards the least significant input channels (using a magnitude metric) of a given convolutional (Conv2D) layer. The channel dimensions of the layers feeding into this convolutional layer are also modified to get back to a working graph. This technique additionally uses a layer-by-layer reconstruction procedure that adjusts the weights of the compressed layers to minimize the distance between each compressed layer's output and the corresponding layer output of the original model.

Both of the above are structured pruning techniques that aim to reduce the computational cost (multiply-accumulate operations, or MACs) or the memory footprint of the model. After applying either of these techniques, the compressed model needs to be fine-tuned (meaning trained again for a few epochs) to recover accuracy close to that of the original model.
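
To make the idea concrete, here is a minimal, illustrative sketch of the Spatial SVD decomposition for a single Conv2d layer. This is not AIMET's implementation (AIMET also handles bias, stride, padding, and rewiring the surrounding graph); the function name and the rank parameter are chosen purely for illustration.

[ ]:
import torch

def spatial_svd_sketch(conv: torch.nn.Conv2d, rank: int):
    """Illustration only: split a (kh x kw) conv into a (kh x 1) conv followed by
    a (1 x kw) conv, using a rank-truncated SVD of the flattened weight tensor."""
    n_out, n_in, kh, kw = conv.weight.shape

    # Flatten the 4D weight tensor into a 2D matrix of shape (n_in*kh, n_out*kw)
    w = conv.weight.detach().permute(1, 2, 0, 3).reshape(n_in * kh, n_out * kw)

    # Decompose, then keep only the `rank` largest singular values
    u, s, vh = torch.linalg.svd(w, full_matrices=False)
    u, s, vh = u[:, :rank], s[:rank], vh[:rank, :]

    # First layer: (kh x 1) kernels, n_in -> rank channels
    w1 = (u * s.sqrt()).reshape(n_in, kh, rank).permute(2, 0, 1).unsqueeze(-1)
    # Second layer: (1 x kw) kernels, rank -> n_out channels
    w2 = (s.sqrt().unsqueeze(1) * vh).reshape(rank, n_out, kw).permute(1, 0, 2).unsqueeze(2)
    return w1, w2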

This notebook shows a working code example of how technique #1 can be used to compress a model. You can find separate notebooks for #2, and for #1 followed by #2, in the same folder.

Overall flow

This notebook covers the following steps:

  1. Instantiate the example evaluation and training pipeline
  2. Load the model and evaluate it to find the baseline accuracy
  3. Compress the model and fine-tune:
     3.1 Compress the model using Spatial SVD and evaluate it to find the post-compression accuracy
     3.2 Fine-tune the model

What this notebook is not

  • This notebook is not designed to show state-of-the-art compression results. For example, some optimization parameters such as num_comp_ratio_candidates, eval_iterations, and epochs are deliberately chosen small so that the notebook executes more quickly.


Dataset

This notebook relies on the ImageNet dataset for the task of image classification. If you already have a version of the dataset readily available, please use that. Otherwise, please download the dataset from an appropriate location (e.g. https://image-net.org/challenges/LSVRC/2012/index.php#).

Note 1: The ImageNet dataset typically has the following characteristics, and the dataloaders provided in this example notebook rely on them (please see the PyTorch dataset description for more details):

  • Subfolders ‘train’ for the training samples and ‘val’ for the validation samples
  • A subdirectory per class, and a file per image sample

Note 2: To speed up the execution of this notebook, you may use a reduced subset of the ImageNet dataset. For example, the entire ILSVRC2012 dataset has 1000 classes, roughly 1300 training samples per class, and 50 validation samples per class. For the purpose of running this notebook, you could reduce the dataset to, say, 2 samples per class. This exercise is left up to the reader and is not necessary.

Edit the cell below and specify the directory where the downloaded ImageNet dataset is saved.

[ ]:
DATASET_DIR = '/path/to/dataset/'         # Please replace this with a real directory

1. Example evaluation and training pipeline

The following is an example training and validation pipeline for this image classification task.

  • Does AIMET have any limitations on how the training and validation pipeline is written? Not really. We will see later that AIMET modifies the user’s model to compress it, and the resultant model is still a PyTorch model. This compressed model can be used in place of the original model when doing inference or training.

  • Does AIMET put any limitation on the interface of the evaluate() or train() methods? Not really, but the evaluate() method should return a single number representing the accuracy of the model. Ideally, you should be able to use your existing evaluate and train routines as-is.

[ ]:
import os
import torch
from typing import List
from Examples.common import image_net_config
from Examples.torch.utils.image_net_evaluator import ImageNetEvaluator
from Examples.torch.utils.image_net_trainer import ImageNetTrainer

class ImageNetDataPipeline:

    @staticmethod
    def evaluate(model: torch.nn.Module, iterations: int, use_cuda: bool) -> float:
        """
        Given a torch model, evaluates its Top-1 accuracy on the dataset
        :param model: the model to evaluate
        :param iterations: the number of batches to be used to evaluate the model. A value of 'None' means the model will be
                           evaluated on the entire dataset once.
        :param use_cuda: whether or not the GPU should be used.
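        :return: the Top-1 accuracy of the model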
        """
        evaluator = ImageNetEvaluator(DATASET_DIR, image_size=image_net_config.dataset['image_size'],
                                      batch_size=image_net_config.evaluation['batch_size'],
                                      num_workers=image_net_config.evaluation['num_workers'])

        return evaluator.evaluate(model, iterations=iterations, use_cuda=use_cuda)

    @staticmethod
    def finetune(model: torch.nn.Module, epochs: int, learning_rate: float, learning_rate_schedule: List, use_cuda: bool):
        """
        Given a torch model, finetunes the model to improve its accuracy
        :param model: the model to finetune
        :param epochs: The number of epochs used during the finetuning step.
        :param learning_rate: The learning rate used during the finetuning step.
        :param learning_rate_schedule: The learning rate schedule used during the finetuning step.
        :param use_cuda: whether or not the GPU should be used.
        """
        trainer = ImageNetTrainer(DATASET_DIR, image_size=image_net_config.dataset['image_size'],
                                  batch_size=image_net_config.train['batch_size'],
                                  num_workers=image_net_config.train['num_workers'])

        trainer.train(model, max_epochs=epochs, learning_rate=learning_rate,
                      learning_rate_schedule=learning_rate_schedule, use_cuda=use_cuda)

2. Load the model and evaluate it to find the baseline accuracy

For this example notebook, we are going to load a pretrained resnet18 model from torchvision. You can load any other pretrained PyTorch model instead, in the same way.

[ ]:
from torchvision.models import resnet18

model = resnet18(pretrained=True)

We should decide whether to place the model on a CPU or CUDA device. This example code will use CUDA if available in your current execution environment. You can change this logic and force a device placement if needed.

[ ]:
use_cuda = False
if torch.cuda.is_available():
    use_cuda = True
    model.to(torch.device('cuda'))

Let’s determine the FP32 (32-bit floating point) accuracy of this model using the evaluate() routine.

[ ]:
accuracy = ImageNetDataPipeline.evaluate(model, iterations=None, use_cuda=use_cuda)
print(accuracy)

3. Compress the model and fine-tune

3.1. Compress model using Spatial SVD and evaluate it to find post-compression accuracy

Now we use AIMET to define the compression parameters for Spatial SVD, a few of which are explained below:

  • target_comp_ratio: The desired compression ratio for Spatial SVD. We are using 0.8 to compress the model by 20%.

  • num_comp_ratio_candidates: As part of determining how compressible each layer is, AIMET performs various measurements. This number denotes how many different compression ratios AIMET tries for each layer. We are using 3 here, which translates to ratios of approximately 0.33, 0.66, and 1.00 for each layer. A typical value is 10. The higher the number of candidates, the more granular the measurements for each layer, but also the longer these measurements take.

  • modules_to_ignore: This list can contain references to model layers that should be ignored during compression. We have added the first layer here to preserve the way the input interacts with the model; other layers can be added too if desired.

  • mode: We are choosing Auto mode, which means AIMET performs a per-layer compressibility analysis and determines how much to compress each layer. The alternative choice is Manual mode; a sketch of a Manual configuration appears after the next cell.

  • eval_callback: The model evaluation function. The expected signature is <function_name>(model, eval_iterations, use_cuda), and it is expected to return an accuracy metric.

  • eval_iterations: The number of batches of data used to evaluate the model while compressing it. We are using 1 to speed up the notebook execution, but please choose a number of samples high enough that the resulting accuracy can be trusted. The eval callback is expected to use the same samples for every invocation.

  • compress_scheme: We choose the ‘spatial svd’ compression scheme.

  • cost_metric: Determines whether the desired compression ratio targets a reduction in MACs or in memory. We are choosing ‘mac’ here.

[ ]:
from decimal import Decimal
from aimet_torch.defs import GreedySelectionParameters, SpatialSvdParameters
from aimet_common.defs import CompressionScheme, CostMetric

greedy_params = GreedySelectionParameters(target_comp_ratio=Decimal('0.8'),  # a string avoids float rounding in Decimal
                                          num_comp_ratio_candidates=3)
modules_to_ignore = [model.conv1]
auto_params = SpatialSvdParameters.AutoModeParams(greedy_select_params=greedy_params,
                                                  modules_to_ignore=modules_to_ignore)
params = SpatialSvdParameters(mode=SpatialSvdParameters.Mode.auto, params=auto_params)

eval_callback = ImageNetDataPipeline.evaluate
eval_iterations = 1
compress_scheme = CompressionScheme.spatial_svd
cost_metric = CostMetric.mac
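
The cell above configures Auto mode. For comparison, a Manual mode configuration might look like the sketch below, where you hand-pick the compression ratio for each layer. The ModuleCompRatioPair helper and the ManualModeParams signature are taken from the AIMET API docs; please verify them there before use.

[ ]:
# Sketch only (not executed in this flow): Manual mode with hand-picked
# per-layer compression ratios. Verify the exact signatures in the AIMET API docs.
from aimet_torch.defs import ModuleCompRatioPair

manual_params = SpatialSvdParameters.ManualModeParams(
    list_of_module_comp_ratio_pairs=[ModuleCompRatioPair(model.layer1[0].conv1, Decimal('0.5')),
                                     ModuleCompRatioPair(model.layer2[0].conv1, Decimal('0.7'))])
manual_svd_params = SpatialSvdParameters(mode=SpatialSvdParameters.Mode.manual,
                                         params=manual_params)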

We call the AIMET ModelCompressor.compress_model API using the above parameters. This call returns a compressed model as well as relevant statistics.

Note: ModelCompressor evaluates the model while compressing it, using the same evaluate function that is in our data pipeline.
[ ]:
from aimet_torch.compress import ModelCompressor
compressed_model, comp_stats = ModelCompressor.compress_model(model=model,
                                                              eval_callback=eval_callback,
                                                              eval_iterations=eval_iterations,
                                                              input_shape=(1, 3, 224, 224),
                                                              compress_scheme=compress_scheme,
                                                              cost_metric=cost_metric,
                                                              parameters=params)

print(comp_stats)

Now the compressed model is ready to be used for inference or training. First, we can pass this model to the same evaluation routine we used before to calculate the compressed model's accuracy.

[ ]:
accuracy = ImageNetDataPipeline.evaluate(compressed_model, iterations=None, use_cuda=use_cuda)
print(accuracy)

As you can see, the model accuracy falls sharply after compression. This is expected. We will use fine-tuning to recover the accuracy.

3.2. Fine-tune the model

After the model is compressed using Spatial SVD, we can simply train it for a few more epochs (typically 15-20). As with any training job, hyper-parameters need to be tuned for optimal results. Good starting points are to use a learning rate on the same order as the final learning rate used when training the original model, and to drop the learning rate by a factor of 10 every 5 epochs or so.

For the purpose of this example notebook, we are going to train only for 1 epoch. But feel free to change these parameters as you see fit.

[ ]:
ImageNetDataPipeline.finetune(compressed_model, epochs=1, learning_rate=5e-7, learning_rate_schedule=[5, 10],
                              use_cuda=use_cuda)

After we are done fine-tuning the compressed model, we can evaluate its floating-point accuracy on the same validation dataset to observe any improvement in accuracy.

[ ]:
accuracy = ImageNetDataPipeline.evaluate(compressed_model, iterations=None, use_cuda=use_cuda)
print(accuracy)

Depending on your settings, you should have observed a slight gain in accuracy after one epoch of training. Of course, this was just an example. Please try this with a model of your choice and tune the hyper-parameters to get the best results.

So we now have an improved model after compression using Spatial SVD. Optionally, this model can be saved like a regular PyTorch model.

[ ]:
os.makedirs('./output/', exist_ok=True)
torch.save(compressed_model, './output/finetuned_model')
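
To use the saved model later, it can be loaded back with torch.load. Note that saving a full model object (rather than a state_dict) requires the model's class definitions to be importable when loading; the snippet below is a minimal sketch.

[ ]:
loaded_model = torch.load('./output/finetuned_model')   # newer PyTorch versions may need weights_only=False
loaded_model.eval()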

Summary

We hope this notebook was useful for understanding how to use AIMET to perform compression with Spatial SVD. As noted above, some parameters were deliberately chosen to make the example run faster.

A few additional resources:

  • Refer to the AIMET API docs for more details on the APIs and optional parameters
  • Refer to the other example notebooks to understand how to use AIMET compression and quantization techniques