AIMET PyTorch Compression API

Introduction

AIMET supports the following model compression techniques for PyTorch models:
  • Weight SVD

  • Spatial SVD

  • Channel Pruning

Each of these compression techniques can be invoked through the AIMET API in one of two modes:
  • Auto Mode: In Auto mode, AIMET determines the optimal way to compress each layer of the model given an overall target compression ratio. The Greedy Compression Ratio Selection Algorithm is used to pick appropriate compression ratios for each layer.

  • Manual Mode: In Manual mode, the user passes the desired compression ratio for each layer to AIMET. AIMET applies the specified compression technique to each of those layers to achieve the requested per-layer compression ratio. It is recommended to start with Auto mode, and then tweak per-layer compression ratios using Manual mode if desired.
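
To make the idea behind Auto mode concrete, the following is a toy sketch of greedy per-layer ratio selection. It is not AIMET's implementation: the function name select_ratios, the score table, and the layer costs are all hypothetical, and the search strategy (relaxing an accuracy threshold until the target ratio is met) is an assumption made for illustration.

```python
from typing import Dict, Optional, Tuple

# layer name -> {candidate compression ratio: eval score at that ratio}
EvalScores = Dict[str, Dict[float, float]]


def select_ratios(eval_scores: EvalScores,
                  layer_costs: Dict[str, float],
                  target_ratio: float) -> Optional[Tuple[Dict[str, float], float]]:
    """Return (per-layer ratios, achieved overall ratio), or None if the target is unreachable."""
    total_cost = sum(layer_costs.values())
    # Try the strictest accuracy threshold first, then relax it step by step.
    thresholds = sorted({s for scores in eval_scores.values() for s in scores.values()},
                        reverse=True)
    for threshold in thresholds:
        chosen = {}
        for layer, scores in eval_scores.items():
            passing = [ratio for ratio, score in scores.items() if score >= threshold]
            # Compress as hard as the threshold allows; otherwise leave the layer uncompressed.
            chosen[layer] = min(passing) if passing else 1.0
        achieved = sum(layer_costs[name] * ratio for name, ratio in chosen.items()) / total_cost
        if achieved <= target_ratio:
            return chosen, achieved
    return None


# Hypothetical scores and costs for a two-layer model
scores = {"conv1": {0.5: 0.90, 0.75: 0.97},
          "conv2": {0.5: 0.96, 0.75: 0.98}}
costs = {"conv1": 100.0, "conv2": 300.0}
result = select_ratios(scores, costs, target_ratio=0.7)
```

In this toy run the more expensive layer (conv2) is compressed harder because it tolerates compression better, which is the kind of trade-off Auto mode is meant to find automatically.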


Top-level API for Compression

class aimet_torch.compress.ModelCompressor

AIMET model compressor: Enables model compression using various schemes


static ModelCompressor.compress_model(model, eval_callback, eval_iterations, input_shape, compress_scheme, cost_metric, parameters, trainer=None, visualization_url=None)

Compress a given model using the specified parameters

Parameters
  • model (Module) – Model to compress

  • eval_callback (Callable[[Any, Optional[int], bool], float]) – Evaluation callback. Expected signature is evaluate(model, iterations, use_cuda). Expected to return an accuracy metric.

  • eval_iterations – Iterations to run evaluation for

  • trainer – Training class that exposes a callable, train_model, which takes the model, the layer being fine-tuned, and an optional train_flag parameter. Pass None if per-layer fine-tuning is not required while creating the final compressed model.

  • input_shape (Tuple) – Shape of the input tensor for model

  • compress_scheme (CompressionScheme) – Compression scheme. See the enum for allowed values

  • cost_metric (CostMetric) – Cost metric to use for the compression-ratio (either mac or memory)

  • parameters (Union[SpatialSvdParameters, WeightSvdParameters, ChannelPruningParameters]) – Compression parameters specific to given compression scheme

  • visualization_url – URL, supplied by the user, at which visualizations will appear. Defaults to None.

Return type

Tuple[Module, CompressionStats]

Returns

A tuple of the compressed model, and compression statistics
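
If an existing evaluation routine does not match the evaluate(model, iterations, use_cuda) signature described above, a thin wrapper can adapt it. The sketch below is illustrative only: legacy_validate and make_eval_callback are hypothetical names, and the "model" is a plain Python function rather than a torch.nn.Module so that the example runs without AIMET or PyTorch.

```python
from typing import Any, Callable, Optional, Sequence, Tuple

Sample = Tuple[Any, Any]  # (input, expected label)


def legacy_validate(data: Sequence[Sample], predict: Callable[[Any], Any],
                    limit: Optional[int]) -> float:
    """Hypothetical pre-existing routine with a different argument order."""
    samples = data if limit is None else data[:limit]
    correct = sum(1 for x, y in samples if predict(x) == y)
    return correct / len(samples)


def make_eval_callback(data: Sequence[Sample]) -> Callable[[Any, Optional[int], bool], float]:
    """Build a callback with the (model, iterations, use_cuda) signature."""
    def evaluate(model, iterations=None, use_cuda=False):
        # use_cuda is accepted only to satisfy the signature; this toy routine ignores it.
        return legacy_validate(data, model, iterations)
    return evaluate


# Toy usage: the "model" predicts the parity of its input
toy_data = [(1, 1), (2, 0), (3, 1), (4, 1)]
evaluate = make_eval_callback(toy_data)
score = evaluate(lambda x: x % 2, None, False)
```

The resulting evaluate function could then be passed as eval_callback; a real wrapper would also move the model and data to the GPU when use_cuda is true.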


Greedy Selection Parameters

class aimet_common.defs.GreedySelectionParameters(target_comp_ratio, num_comp_ratio_candidates=10, use_monotonic_fit=False, saved_eval_scores_dict=None)

Configuration parameters for the Greedy compression-ratio selection algorithm

Variables
  • target_comp_ratio – Target compression ratio, expressed as a value between 0 and 1. Compression ratio is the ratio of the cost of the compressed model to the cost of the original model.

  • num_comp_ratio_candidates – Number of comp-ratio candidates to analyze per layer. More candidates allow a more granular distribution of compression at the cost of increased run-time during analysis. Default value is 10. The value should be greater than 1.

  • use_monotonic_fit – If True, eval scores in the eval dictionary are fitted to a monotonically increasing function. This is useful if you see the eval dict scores for some layers are not monotonically increasing. By default, this option is set to False.

  • saved_eval_scores_dict – Path to the eval_scores dictionary pickle file that was saved in a previous run. This is useful for speeding up experiments, for example when trying different target compression ratios. AIMET saves the eval_scores dictionary pickle file automatically in a ./data directory relative to the current path. The num_comp_ratio_candidates parameter is ignored when this option is used.
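
The following sketch illustrates two of the variables above under stated assumptions. It assumes the candidates are evenly spaced fractions of num_comp_ratio_candidates, and it uses a running maximum as one simple way to make scores non-decreasing; AIMET's actual candidate spacing and fitting method may differ. The helper names candidate_ratios and monotonic_fit are hypothetical.

```python
from decimal import Decimal
from typing import Dict, List


def candidate_ratios(num_candidates: int) -> List[Decimal]:
    """Evenly spaced candidate ratios strictly between 0 and 1 (assumed spacing)."""
    if num_candidates <= 1:
        raise ValueError("num_candidates must be greater than 1")
    return [Decimal(i) / Decimal(num_candidates) for i in range(1, num_candidates)]


def monotonic_fit(scores: Dict[Decimal, float]) -> Dict[Decimal, float]:
    """Make eval scores non-decreasing in the compression ratio via a running maximum."""
    fitted: Dict[Decimal, float] = {}
    best = float("-inf")
    for ratio in sorted(scores):
        best = max(best, scores[ratio])
        fitted[ratio] = best
    return fitted


candidates = candidate_ratios(10)

# Hypothetical scores for one layer: the dip at 0.5 is smoothed out by the fit
noisy = {Decimal("0.25"): 0.60, Decimal("0.5"): 0.55, Decimal("0.75"): 0.90}
fitted = monotonic_fit(noisy)
```

A larger compression ratio means less compression, so a layer's score would normally rise with the ratio; the fit removes dips that contradict that expectation.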


TAR Selection Parameters

class aimet_torch.defs.TarRankSelectionParameters(num_rank_indices)

Configuration parameters for the TAR compression-ratio selection algorithm

Variables

num_rank_indices – Number of rank indices for ratio selection.


Spatial SVD Configuration

class aimet_torch.defs.SpatialSvdParameters(mode, params, multiplicity=1)

Configuration parameters for spatial svd compression

Parameters
  • mode (Mode) – Either auto mode or manual mode

  • params (Union[ManualModeParams, AutoModeParams]) – Parameters for the mode selected

  • multiplicity – The multiplicity to which ranks/input channels will get rounded. Default: 1
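
As an illustration of what rounding to a multiplicity means, the sketch below rounds a rank up to the next multiple and caps it at the layer's maximum. The rounding direction, the cap, and the helper name round_to_multiplicity are assumptions made for illustration, not a statement of AIMET's internal behavior.

```python
def round_to_multiplicity(value: int, multiplicity: int, upper_bound: int) -> int:
    """Round value up to a multiple of multiplicity, never exceeding upper_bound."""
    if multiplicity < 1:
        raise ValueError("multiplicity must be at least 1")
    rounded = -(-value // multiplicity) * multiplicity  # ceiling division, then scale back
    return min(rounded, upper_bound)


rank_8 = round_to_multiplicity(13, multiplicity=8, upper_bound=64)   # next multiple of 8
rank_1 = round_to_multiplicity(13, multiplicity=1, upper_bound=64)   # default leaves the value unchanged
capped = round_to_multiplicity(60, multiplicity=8, upper_bound=62)   # cap wins over rounding
```

Channel counts that are multiples of 8 or 16 often map better onto hardware vector units, which is a common reason to set multiplicity above 1.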

class AutoModeParams(greedy_select_params, modules_to_ignore=None)

Configuration parameters for auto-mode compression

Parameters
  • greedy_select_params (GreedySelectionParameters) – Params for greedy comp-ratio selection algorithm

  • modules_to_ignore (Optional[List[Module]]) – List of modules to ignore (None indicates nothing to ignore)

class ManualModeParams(list_of_module_comp_ratio_pairs)

Configuration parameters for manual-mode spatial svd compression

Parameters

list_of_module_comp_ratio_pairs (List[ModuleCompRatioPair]) – List of (module, comp-ratio) pairs

class Mode

Mode enumeration

auto = 2

Auto mode

manual = 1

Manual mode


Weight SVD Configuration

class aimet_torch.defs.WeightSvdParameters(mode, params, multiplicity=1)

Configuration parameters for weight svd compression

Parameters
  • mode (Mode) – Either auto mode or manual mode

  • params (Union[ManualModeParams, AutoModeParams]) – Parameters for the mode selected

  • multiplicity – The multiplicity to which ranks/input channels will get rounded. Default: 1

class AutoModeParams(rank_select_scheme, select_params, modules_to_ignore=None)

Configuration parameters for auto-mode compression

Parameters
  • rank_select_scheme (RankSelectScheme) – Rank selection scheme. Supports two options: greedy and tar

  • select_params (Union[GreedySelectionParameters, TarRankSelectionParameters]) – Params for greedy/TAR comp-ratio selection algorithm

  • modules_to_ignore (Optional[List[Module]]) – List of modules to ignore (None indicates nothing to ignore)

class ManualModeParams(list_of_module_comp_ratio_pairs)

Configuration parameters for manual-mode weight svd compression

Parameters

list_of_module_comp_ratio_pairs (List[ModuleCompRatioPair]) – List of (module, comp-ratio) pairs

class Mode

Mode enumeration

auto = 2

Auto mode

manual = 1

Manual mode


Channel Pruning Configuration

class aimet_torch.defs.ChannelPruningParameters(data_loader, num_reconstruction_samples, allow_custom_downsample_ops, mode, params, multiplicity=1)

Configuration parameters for channel pruning compression

class AutoModeParams(greedy_select_params, modules_to_ignore=None)

Configuration parameters for auto-mode compression

Parameters
  • greedy_select_params (GreedySelectionParameters) – Params for greedy comp-ratio selection algorithm

  • modules_to_ignore (Optional[List[Module]]) – List of modules to ignore (None indicates nothing to ignore)

class ManualModeParams(list_of_module_comp_ratio_pairs)

Configuration parameters for manual-mode channel pruning compression

Parameters

list_of_module_comp_ratio_pairs (List[ModuleCompRatioPair]) – List of (module, comp-ratio) pairs

class Mode

Mode enumeration

auto = 2

Auto mode: AIMET computes optimal comp-ratio per layer

manual = 1

Manual mode: User specifies comp-ratio per layer


Configuration Definitions

class aimet_common.defs.CostMetric

Enumeration of metrics to measure cost of a model/layer

mac = 1

MAC: Cost modeled for compute requirements

memory = 2

Memory: Cost modeled for space requirements
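
To show how the two metrics differ, the sketch below computes both for a single convolution layer using the usual textbook formulas: memory as the number of weights, and MACs as the number of weights times the number of output positions. Bias terms are ignored, and the ConvLayer type and function names are hypothetical rather than part of AIMET.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ConvLayer:
    in_channels: int
    out_channels: int
    kernel_h: int
    kernel_w: int
    out_h: int  # height of the output feature map
    out_w: int  # width of the output feature map


def memory_cost(layer: ConvLayer) -> int:
    """Number of weights stored by the layer (bias ignored)."""
    return layer.in_channels * layer.out_channels * layer.kernel_h * layer.kernel_w


def mac_cost(layer: ConvLayer) -> int:
    """Multiply-accumulate operations: every weight is used once per output position."""
    return memory_cost(layer) * layer.out_h * layer.out_w


conv = ConvLayer(in_channels=32, out_channels=64, kernel_h=5, kernel_w=5, out_h=8, out_w=8)
weights = memory_cost(conv)
macs = mac_cost(conv)
```

Because MAC cost scales with the output feature-map size, early layers with large activations tend to dominate under the mac metric, while layers with many channels dominate under the memory metric.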


class aimet_common.defs.CompressionScheme

Enumeration of compression schemes supported in AIMET

channel_pruning = 3

Channel Pruning

spatial_svd = 2

Spatial SVD

weight_svd = 1

Weight SVD


class aimet_torch.defs.ModuleCompRatioPair(module, comp_ratio)

Pair of torch.nn.Module and a compression-ratio

Variables
  • module – Module of type torch.nn.Module

  • comp_ratio – Compression ratio. Compression ratio is the ratio of cost of compressed model to cost of the original model.
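
As a worked example of a per-layer compression ratio, the sketch below estimates the ratio obtained when a k_h x k_w convolution is factored into a k_h x 1 convolution followed by a 1 x k_w convolution with an intermediate rank, which is the general idea behind Spatial SVD. It counts weights only (memory cost, bias ignored), and the function name is hypothetical.

```python
def spatial_svd_comp_ratio(in_channels: int, out_channels: int,
                           kernel_h: int, kernel_w: int, rank: int) -> float:
    """Ratio of compressed weight count to original weight count for one conv layer."""
    original = in_channels * out_channels * kernel_h * kernel_w
    # First conv: in_channels -> rank with a (kernel_h x 1) kernel.
    # Second conv: rank -> out_channels with a (1 x kernel_w) kernel.
    compressed = in_channels * rank * kernel_h + rank * out_channels * kernel_w
    return compressed / original


ratio = spatial_svd_comp_ratio(in_channels=32, out_channels=64,
                               kernel_h=5, kernel_w=5, rank=16)
```

Here the original layer holds 51200 weights and the factored pair holds 7680, giving a ratio of 0.15; a smaller ratio means more aggressive compression.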


Code Examples

Required imports

import os
from decimal import Decimal
import torch


# Compression-related imports
from aimet_common.defs import CostMetric, CompressionScheme, GreedySelectionParameters, RankSelectScheme
from aimet_torch.defs import WeightSvdParameters, SpatialSvdParameters, ChannelPruningParameters, \
    ModuleCompRatioPair
from aimet_torch.compress import ModelCompressor
from aimet_torch.examples import mnist_torch_model

Evaluation function

def evaluate_model(model: torch.nn.Module, eval_iterations: int, use_cuda: bool = False) -> float:
    """
    This is intended to be the user-defined model evaluation function.
    AIMET requires the above signature. So if the user's eval function does not
    match this signature, please create a simple wrapper.

    Note: Honoring the number of iterations is not absolutely necessary.
    However if all evaluations run over an entire epoch of validation data,
    the runtime for AIMET compression will obviously be higher.

    :param model: Model to evaluate
    :param eval_iterations: Number of iterations to use for evaluation.
            None for entire epoch.
    :param use_cuda: If true, evaluate using gpu acceleration
    :return: single float number (accuracy) representing model's performance
    """
    return .5
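
The placeholder above always returns 0.5. A minimal sketch of how an evaluation loop might honor eval_iterations follows; it is framework-agnostic, so the predict callable and the list of batches are stand-ins for a real model and data loader, and evaluate_batches is a hypothetical helper name.

```python
from itertools import islice
from typing import Any, Callable, Iterable, Optional, Sequence, Tuple

Batch = Tuple[Sequence[Any], Sequence[Any]]  # (inputs, labels)


def evaluate_batches(predict: Callable[[Any], Any],
                     batches: Iterable[Batch],
                     eval_iterations: Optional[int] = None) -> float:
    """Accuracy over at most eval_iterations batches (all batches when None)."""
    correct = 0
    total = 0
    # islice with a stop of None consumes the whole iterable
    for inputs, labels in islice(batches, eval_iterations):
        predictions = [predict(x) for x in inputs]
        correct += sum(p == y for p, y in zip(predictions, labels))
        total += len(labels)
    return correct / total if total else 0.0


# Toy data: three batches of two samples each, with a parity "model"
toy_batches = [([1, 2], [1, 0]), ([3, 4], [0, 0]), ([5, 6], [1, 1])]
partial_accuracy = evaluate_batches(lambda x: x % 2, toy_batches, eval_iterations=2)
full_accuracy = evaluate_batches(lambda x: x % 2, toy_batches)
```

Capping the number of batches keeps each evaluation short, which matters because the callback is invoked many times during compression-ratio analysis.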

Compressing using Spatial SVD in auto mode with multiplicity = 8 for rank rounding

def spatial_svd_auto_mode():

    # load trained MNIST model
    model = torch.load(os.path.join('../', 'data', 'mnist_trained_on_GPU.pth'))

    # Specify the necessary parameters
    greedy_params = GreedySelectionParameters(target_comp_ratio=Decimal(0.8),
                                              num_comp_ratio_candidates=10)
    auto_params = SpatialSvdParameters.AutoModeParams(greedy_params,
                                                      modules_to_ignore=[model.conv1])

    params = SpatialSvdParameters(mode=SpatialSvdParameters.Mode.auto,
                                  params=auto_params, multiplicity=8)

    # Single call to compress the model
    results = ModelCompressor.compress_model(model,
                                             eval_callback=evaluate_model,
                                             eval_iterations=1000,
                                             input_shape=(1, 1, 28, 28),
                                             compress_scheme=CompressionScheme.spatial_svd,
                                             cost_metric=CostMetric.mac,
                                             parameters=params)

    compressed_model, stats = results
    print(compressed_model)
    print(stats)     # Stats object can be pretty-printed easily

Compressing using Spatial SVD in manual mode

def spatial_svd_manual_mode():

    # Load a trained MNIST model
    model = torch.load(os.path.join('../', 'data', 'mnist_trained_on_GPU.pth'))

    # Specify the necessary parameters
    manual_params = SpatialSvdParameters.ManualModeParams([ModuleCompRatioPair(model.conv1, 0.5),
                                                           ModuleCompRatioPair(model.conv2, 0.4)])
    params = SpatialSvdParameters(mode=SpatialSvdParameters.Mode.manual,
                                  params=manual_params)

    # Single call to compress the model
    results = ModelCompressor.compress_model(model,
                                             eval_callback=evaluate_model,
                                             eval_iterations=1000,
                                             input_shape=(1, 1, 28, 28),
                                             compress_scheme=CompressionScheme.spatial_svd,
                                             cost_metric=CostMetric.mac,
                                             parameters=params)

    compressed_model, stats = results
    print(compressed_model)
    print(stats)    # Stats object can be pretty-printed easily

Compressing using Weight SVD in auto mode

def weight_svd_auto_mode():

    # Load trained MNIST model
    model = torch.load(os.path.join('../', 'data', 'mnist_trained_on_GPU.pth'))

    # Specify the necessary parameters
    greedy_params = GreedySelectionParameters(target_comp_ratio=Decimal(0.8),
                                              num_comp_ratio_candidates=10)
    rank_select = RankSelectScheme.greedy
    auto_params = WeightSvdParameters.AutoModeParams(rank_select_scheme=rank_select,
                                                     select_params=greedy_params,
                                                     modules_to_ignore=[model.conv1])

    params = WeightSvdParameters(mode=WeightSvdParameters.Mode.auto,
                                 params=auto_params)

    # Single call to compress the model
    results = ModelCompressor.compress_model(model,
                                             eval_callback=evaluate_model,
                                             eval_iterations=1000,
                                             input_shape=(1, 1, 28, 28),
                                             compress_scheme=CompressionScheme.weight_svd,
                                             cost_metric=CostMetric.mac,
                                             parameters=params)

    compressed_model, stats = results
    print(compressed_model)
    print(stats)     # Stats object can be pretty-printed easily

Compressing using Weight SVD in manual mode with multiplicity = 8 for rank rounding

def weight_svd_manual_mode():

    # Load a trained MNIST model
    model = torch.load(os.path.join('../', 'data', 'mnist_trained_on_GPU.pth'))

    # Specify the necessary parameters
    manual_params = WeightSvdParameters.ManualModeParams([ModuleCompRatioPair(model.conv1, 0.5),
                                                          ModuleCompRatioPair(model.conv2, 0.4)])
    params = WeightSvdParameters(mode=WeightSvdParameters.Mode.manual,
                                 params=manual_params, multiplicity=8)

    # Single call to compress the model
    results = ModelCompressor.compress_model(model,
                                             eval_callback=evaluate_model,
                                             eval_iterations=1000,
                                             input_shape=(1, 1, 28, 28),
                                             compress_scheme=CompressionScheme.weight_svd,
                                             cost_metric=CostMetric.mac,
                                             parameters=params)

    compressed_model, stats = results
    print(compressed_model)
    print(stats)    # Stats object can be pretty-printed easily

Compressing using Channel Pruning in auto mode

def channel_pruning_auto_mode():

    # Load trained MNIST model
    model = torch.load(os.path.join('../', 'data', 'mnist_trained_on_GPU.pth'))

    # Specify the necessary parameters
    greedy_params = GreedySelectionParameters(target_comp_ratio=Decimal(0.8),
                                              num_comp_ratio_candidates=10)
    auto_params = ChannelPruningParameters.AutoModeParams(greedy_params,
                                                          modules_to_ignore=[model.conv1])

    data_loader = mnist_torch_model.DataLoaderMnist(cuda=True, seed=1, shuffle=True)
    params = ChannelPruningParameters(data_loader=data_loader.train_loader,
                                      num_reconstruction_samples=500,
                                      allow_custom_downsample_ops=True,
                                      mode=ChannelPruningParameters.Mode.auto,
                                      params=auto_params)

    # Single call to compress the model
    results = ModelCompressor.compress_model(model,
                                             eval_callback=evaluate_model,
                                             eval_iterations=1000,
                                             input_shape=(1, 1, 28, 28),
                                             compress_scheme=CompressionScheme.channel_pruning,
                                             cost_metric=CostMetric.mac,
                                             parameters=params)

    compressed_model, stats = results
    print(compressed_model)
    print(stats)     # Stats object can be pretty-printed easily

Compressing using Channel Pruning in manual mode

def channel_pruning_manual_mode():

    # Load a trained MNIST model
    model = torch.load(os.path.join('../', 'data', 'mnist_trained_on_GPU.pth'))

    # Specify the necessary parameters
    manual_params = ChannelPruningParameters.ManualModeParams([ModuleCompRatioPair(model.conv2, 0.4)])

    data_loader = mnist_torch_model.DataLoaderMnist(cuda=True, seed=1, shuffle=True)
    params = ChannelPruningParameters(data_loader=data_loader.train_loader,
                                      num_reconstruction_samples=500,
                                      allow_custom_downsample_ops=True,
                                      mode=ChannelPruningParameters.Mode.manual,
                                      params=manual_params)

    # Single call to compress the model
    results = ModelCompressor.compress_model(model,
                                             eval_callback=evaluate_model,
                                             eval_iterations=1000,
                                             input_shape=(1, 1, 28, 28),
                                             compress_scheme=CompressionScheme.channel_pruning,
                                             cost_metric=CostMetric.mac,
                                             parameters=params)

    compressed_model, stats = results
    print(compressed_model)
    print(stats)    # Stats object can be pretty-printed easily

Example Training Object

class Trainer:
    """ Example trainer class """

    def __init__(self):
        self._layer_db = []

    def train_model(self, model, layer, train_flag=True):
        """
        Trains a model
        :param model: Model to be trained
        :param layer: layer which has to be fine tuned
        :param train_flag: Default: True. If True, the model gets trained
        :return: None
        """
        if train_flag:
            mnist_torch_model.train(model, epochs=1, use_cuda=True, batch_size=50, batch_callback=None)
        self._layer_db.append(layer)

Compressing using Spatial SVD in auto mode with layer-wise fine tuning

def spatial_svd_auto_mode_with_layerwise_finetuning():

    # load trained MNIST model
    model = torch.load(os.path.join('../', 'data', 'mnist_trained_on_GPU.pth'))

    # Specify the necessary parameters
    greedy_params = GreedySelectionParameters(target_comp_ratio=Decimal(0.8),
                                              num_comp_ratio_candidates=10)
    auto_params = SpatialSvdParameters.AutoModeParams(greedy_params,
                                                      modules_to_ignore=[model.conv1])

    params = SpatialSvdParameters(mode=SpatialSvdParameters.Mode.auto,
                                  params=auto_params)

    # Single call to compress the model
    results = ModelCompressor.compress_model(model,
                                             eval_callback=evaluate_model,
                                             eval_iterations=1000,
                                             input_shape=(1, 1, 28, 28),
                                             compress_scheme=CompressionScheme.spatial_svd,
                                             cost_metric=CostMetric.mac,
                                             parameters=params, trainer=Trainer())

    compressed_model, stats = results
    print(compressed_model)
    print(stats)     # Stats object can be pretty-printed easily