Model compression using spatial SVD

This notebook contains a working example of AIMET model compression using the spatial singular value decomposition (SVD) technique.

Spatial SVD aims to reduce the computational MACs (multiply-accumulate operations) or memory requirements of the model. After applying spatial SVD, the compressed model must be fine-tuned (meaning trained again for a few epochs) to restore accuracy to a level near the original model’s.
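Concretely, spatial SVD decomposes a single k×k convolution into a k×1 convolution followed by a 1×k convolution, with a reduced rank r of intermediate channels. The following is a rough sketch of the MAC arithmetic; the layer shapes and rank below are illustrative examples, not values used by this notebook.

[ ]:
# Illustrative MAC arithmetic for spatial SVD on one conv layer.
# All shapes and the rank are made-up examples, not AIMET defaults.
n, m = 64, 64        # input and output channels
kh, kw = 3, 3        # kernel height and width
H, W = 56, 56        # output feature-map size
r = 32               # rank of the decomposition (intermediate channels)

macs_original = n * m * kh * kw * H * W                        # one kh x kw conv
macs_decomposed = (n * r * kh * H * W) + (r * m * kw * H * W)  # kh x 1 conv, then 1 x kw conv

print(f'original: {macs_original:,} MACs')
print(f'spatial SVD (r={r}): {macs_decomposed:,} MACs '
      f'({macs_decomposed / macs_original:.2f}x)')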

See AIMET spatial SVD for a detailed discussion of spatial SVD and other compression techniques.

Overall flow

The example follows these high-level steps:

  1. Instantiate the example training and validation pipeline

  2. Load the model and evaluate it to find the baseline accuracy

  3. Compress the model and fine-tune:

    1. Compress the model using spatial SVD and evaluate it to find post-compression accuracy

    2. Fine-tune the model

Note

This notebook does not show state-of-the-art results. For example, it uses a relatively compression-friendly model (ResNet18). Also, some optimization parameters, such as the number of fine-tuning epochs, are chosen to improve the execution speed of the notebook.


Dataset

This example performs image classification on the ImageNet dataset. If you already have a version of the dataset, use that. Otherwise, download the dataset, for example from https://image-net.org/challenges/LSVRC/2012/index.

Note

The dataloader provided in this example relies on these features of the ImageNet dataset (see the layout sketch after this list):

  • Subfolders train for the training samples and val for the validation samples. See the PyTorch dataset description for more details.

  • One subdirectory per class, and one file per image sample.
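For reference, the expected layout looks like the following (the class folder and file names are illustrative):

    DATASET_DIR/
        train/
            n01440764/
                image_0001.JPEG
                image_0002.JPEG
                ...
            n01443537/
                ...
        val/
            n01440764/
                ...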

Note

To speed up the execution of this notebook, you can use a reduced subset of the ImageNet dataset. For example, the entire ILSVRC2012 dataset has 1000 classes, roughly 1,300 training samples per class, and 50 validation samples per class. For the purpose of running this notebook, you can reduce the dataset to, say, two samples per class.
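One way to build such a subset, assuming the directory layout described above, is to copy the first few images of each class into a new tree. The paths below are placeholders:

[ ]:
import os
import shutil

FULL_DIR = '/path/to/full/imagenet'       # placeholder
SUBSET_DIR = '/path/to/imagenet_subset'   # placeholder
SAMPLES_PER_CLASS = 2

for split in ('train', 'val'):
    for class_name in sorted(os.listdir(os.path.join(FULL_DIR, split))):
        src = os.path.join(FULL_DIR, split, class_name)
        dst = os.path.join(SUBSET_DIR, split, class_name)
        os.makedirs(dst, exist_ok=True)
        # Copy only the first few samples of each class.
        for fname in sorted(os.listdir(src))[:SAMPLES_PER_CLASS]:
            shutil.copy(os.path.join(src, fname), dst)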

Edit the cell below to specify the directory where the downloaded ImageNet dataset is saved.

[ ]:
DATASET_DIR = '/path/to/dataset/'         # Replace this path with a real directory

1. Instantiate the example training and validation pipeline

Use the following training and validation loop for the image classification task.

Things to note:

  • AIMET does not put limitations on how the training and validation pipeline is written. AIMET modifies the user’s model to compress it, and the compressed model is still a PyTorch model that can be used in place of the original model for inference or training.

  • AIMET does not put limitations on the interface of the evaluate() or train() methods. You should be able to use your existing evaluate and train routines as-is.

[ ]:
import os
import torch
from typing import List, Optional
from Examples.common import image_net_config
from Examples.torch.utils.image_net_evaluator import ImageNetEvaluator
from Examples.torch.utils.image_net_trainer import ImageNetTrainer

class ImageNetDataPipeline:

    @staticmethod
    def evaluate(model: torch.nn.Module, iterations: Optional[int], use_cuda: bool) -> float:
        """
        Given a torch model, evaluates its Top-1 accuracy on the dataset
        :param model: the model to evaluate
        :param iterations: the number of batches to be used to evaluate the model. A value of 'None' means the model will be
                           evaluated on the entire dataset once.
        :param use_cuda: whether or not the GPU should be used.
        :return: the top-1 accuracy of the model on the evaluation dataset
        """
        evaluator = ImageNetEvaluator(DATASET_DIR, image_size=image_net_config.dataset['image_size'],
                                      batch_size=image_net_config.evaluation['batch_size'],
                                      num_workers=image_net_config.evaluation['num_workers'])

        return evaluator.evaluate(model, iterations=iterations, use_cuda=use_cuda)

    @staticmethod
    def finetune(model: torch.nn.Module, epochs: int, learning_rate: float, learning_rate_schedule: List, use_cuda: bool):
        """
        Given a torch model, finetunes the model to improve its accuracy
        :param model: the model to finetune
        :param epochs: The number of epochs used during the finetuning step.
        :param learning_rate: The learning rate used during the finetuning step.
        :param learning_rate_schedule: The learning rate schedule used during the finetuning step.
        :param use_cuda: whether or not the GPU should be used.
        """
        trainer = ImageNetTrainer(DATASET_DIR, image_size=image_net_config.dataset['image_size'],
                                  batch_size=image_net_config.train['batch_size'],
                                  num_workers=image_net_config.train['num_workers'])

        trainer.train(model, max_epochs=epochs, learning_rate=learning_rate,
                      learning_rate_schedule=learning_rate_schedule, use_cuda=use_cuda)

2. Load the model and evaluate to get a baseline FP32 accuracy score

2.1 Load a pretrained resnet18 model from torchvision.

You can load any pretrained PyTorch model instead.
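For example, a larger torchvision model could be swapped in and the rest of the notebook would run unchanged, although the accuracy and compression results would differ. An illustrative alternative (not used here; this notebook loads resnet18 in the next cell):

[ ]:
# Illustrative alternative only; the next cell loads resnet18 instead.
from torchvision.models import resnet50
model = resnet50(pretrained=True)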

[ ]:
from torchvision.models import resnet18

model = resnet18(pretrained=True)

2.2 Decide whether to place the model on a CPU or CUDA device.

This example uses CUDA if it is available. You can change this logic and force a device placement if needed.

[ ]:
# Use CUDA if it is available; otherwise run on the CPU.
use_cuda = torch.cuda.is_available()
if use_cuda:
    model.to(torch.device('cuda'))

2.3 Compute the 32-bit floating point (FP32) accuracy of this model using the evaluate() routine.

[ ]:
accuracy = ImageNetDataPipeline.evaluate(model, iterations=None, use_cuda=use_cuda)
print(accuracy)

3. Compress the model and fine-tune

3.1 Compress the model using spatial SVD and evaluate it to find post-compression accuracy.

Use AIMET to define compression parameters for spatial SVD.

Some key parameters:

  • target_comp_ratio: The desired compression ratio for spatial SVD. This example uses 0.8, which instructs AIMET to compress the model’s cost to 80% of the original (a 20% reduction).

  • num_comp_ratio_candidates: Defines how many compression ratios to try. To calculate how compressible each layer is, AIMET tries different compression ratios. A value of three causes AIMET to try ratios of 0.33, 0.66, and 1.00. Higher values result in more granular measurements but take longer to complete. In practice, a good compromise is 10.

  • modules_to_ignore: A list that contains the references of model layers to ignore during compression. This example adds the first layer for illustration purposes. Other layers can be added if desired.

  • mode: Auto mode performs per-layer compressibility analysis and calculates how much to compress each layer. The alternative is Manual.

  • eval_callback: The model evaluation function. The expected signature of the evaluate function is <function_name>(model, eval_iterations, use_cuda). The function should return an accuracy metric.

  • eval_iterations: The number of batches of data with which to evaluate the model during compression. The example uses one for execution speed. In practice, use a number high enough to give meaningful accuracy results. The eval callback is expected to use the same samples for every invocation of the callback; see the sketch after this list for one way to ensure this.

  • compress_scheme: The ‘spatial svd’ compression scheme.

  • cost_metric: Set to target either MACs or memory for reduction by the desired compression ratio. The example chooses ‘mac’.
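If your own evaluation routine draws random batches, one way to meet the fixed-samples expectation is to evaluate on a fixed subset of the validation set. The following is a minimal sketch of such a callback, matching the signature described above; the subset size, batch size, and preprocessing are illustrative. This notebook itself simply reuses ImageNetDataPipeline.evaluate.

[ ]:
import os
import torch
from torchvision import datasets, transforms

# Build a fixed validation subset once, so that every invocation of the
# callback evaluates on exactly the same samples.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
val_data = datasets.ImageFolder(os.path.join(DATASET_DIR, 'val'), transform)
fixed_subset = torch.utils.data.Subset(val_data, list(range(512)))   # illustrative size
fixed_loader = torch.utils.data.DataLoader(fixed_subset, batch_size=64, shuffle=False)

def deterministic_eval(model: torch.nn.Module, eval_iterations: int, use_cuda: bool) -> float:
    """Matches the signature AIMET expects: (model, eval_iterations, use_cuda) -> accuracy."""
    device = torch.device('cuda' if use_cuda else 'cpu')
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for i, (images, labels) in enumerate(fixed_loader):
            if eval_iterations is not None and i >= eval_iterations:
                break
            predictions = model(images.to(device)).argmax(dim=1)
            correct += (predictions == labels.to(device)).sum().item()
            total += labels.size(0)
    return correct / total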

[ ]:
from decimal import Decimal
from aimet_torch.defs import GreedySelectionParameters, SpatialSvdParameters
from aimet_common.defs import CompressionScheme, CostMetric

# Target 80% of the original cost (a 20% reduction), trying 3 candidate
# compression ratios per layer. A string avoids float imprecision in Decimal.
greedy_params = GreedySelectionParameters(target_comp_ratio=Decimal('0.8'),
                                          num_comp_ratio_candidates=3)

# Skip the first convolution layer during compression.
modules_to_ignore = [model.conv1]
auto_params = SpatialSvdParameters.AutoModeParams(greedy_select_params=greedy_params,
                                                  modules_to_ignore=modules_to_ignore)
params = SpatialSvdParameters(mode=SpatialSvdParameters.Mode.auto, params=auto_params)

eval_callback = ImageNetDataPipeline.evaluate
eval_iterations = 1
compress_scheme = CompressionScheme.spatial_svd
cost_metric = CostMetric.mac

3.2 Call the AIMET ModelCompressor.compress_model API using the above parameters.

This call returns a compressed model and relevant statistics. While compressing, ModelCompressor evaluates the model using the evaluation function from the data pipeline.

[ ]:
from aimet_torch.compress import ModelCompressor
compressed_model, comp_stats = ModelCompressor.compress_model(model=model,
                                                              eval_callback=eval_callback,
                                                              eval_iterations=eval_iterations,
                                                              input_shape=(1, 3, 224, 224),
                                                              compress_scheme=compress_scheme,
                                                              cost_metric=cost_metric,
                                                              parameters=params)

print(comp_stats)

The compressed model is now ready to be used for inference or training.
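If you are curious, you can inspect the compressed model to see the effect of the decomposition. Layers that AIMET selected for spatial SVD show up as pairs of smaller convolutions (the exact module names depend on what was selected):

[ ]:
# Print every convolution in the compressed model. Layers compressed by
# spatial SVD appear as consecutive (k, 1) and (1, k) convolutions.
for name, module in compressed_model.named_modules():
    if isinstance(module, torch.nn.Conv2d):
        print(name, tuple(module.kernel_size), module.in_channels, '->', module.out_channels)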

3.3 Pass the compressed model to the same evaluation routine as before to calculate its post-compression accuracy.

[ ]:
accuracy = ImageNetDataPipeline.evaluate(compressed_model, iterations=None, use_cuda=use_cuda)
print(accuracy)

Model accuracy falls sharply after compression. This is expected. Fine-tuning is used to recover accuracy.

3.4 Fine-tune the model.

As with any training job, hyperparameters need to be searched for optimal results. Good starting points are to use a learning rate on the same order as the final learning rate of the original training run, and to drop the learning rate by a factor of 10 every 5 epochs or so.
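In plain PyTorch terms, that suggestion corresponds to a MultiStepLR-style schedule. The following is a generic sketch under that interpretation, not the internals of the example’s ImageNetTrainer; the optimizer choice and momentum value are illustrative.

[ ]:
import torch

# Start near the original training run's final learning rate and
# decay it by 10x after epochs 5 and 10.
optimizer = torch.optim.SGD(compressed_model.parameters(), lr=5e-7, momentum=0.9)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[5, 10], gamma=0.1)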

This example trains for only one epoch, but you can experiment with the parameters however you like.

[ ]:
ImageNetDataPipeline.finetune(compressed_model, epochs=1, learning_rate=5e-7, learning_rate_schedule=[5, 10],
                              use_cuda=use_cuda)

3.5 After fine-tuning, evaluate the model again to see the improvements in accuracy.

[ ]:
accuracy = ImageNetDataPipeline.evaluate(compressed_model, iterations=None, use_cuda=use_cuda)
print(accuracy)

Of course, there might be little gain in accuracy after only one epoch of training. Experiment with the hyperparameters to get better results.

Next steps

The next step is to save the fine-tuned model.

[ ]:
os.makedirs('./output/', exist_ok=True)
# Save the entire module (not just a state_dict), since compression changed the model's architecture.
torch.save(compressed_model, './output/finetuned_model')
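Because the call above saves the entire module rather than just a state_dict, the model can be loaded back with torch.load. Note that the classes AIMET used to build the compressed layers must be importable in the loading environment:

[ ]:
loaded_model = torch.load('./output/finetuned_model')
loaded_model.eval()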

For more information

See the AIMET API docs for details about the AIMET APIs and optional parameters.

See the other example notebooks to learn how to use other AIMET compression techniques.