AIMET PyTorch Compression API
Introduction
AIMET supports the following model compression techniques for PyTorch models:
- Weight SVD
- Spatial SVD
- Channel Pruning

Each of these compression techniques can be invoked through the AIMET API in one of two modes:
- Auto Mode: AIMET determines the optimal way to compress each layer of the model given an overall target compression ratio. The Greedy Compression Ratio Selection Algorithm is used to pick an appropriate compression ratio for each layer.
- Manual Mode: The user passes the desired compression ratio per layer to AIMET, and AIMET applies the specified compression technique to each of those layers to achieve it. It is recommended to start with Auto mode and then tweak the per-layer compression ratios using Manual mode if desired.
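As a rough illustration of what Auto mode has to decide, the sketch below picks per-layer compression ratios from a table of per-layer eval scores so that an overall target ratio is met. It is a simplified stand-in and not AIMET's implementation: the function names, the example scores and costs, and the cost model (compressed layer cost = original cost times ratio) are all assumptions made for this example.

```python
from typing import Dict

# Hypothetical data: eval score per layer for each candidate compression ratio.
# A lower ratio means a smaller (more heavily compressed) layer.
EvalScores = Dict[str, Dict[float, float]]


def pick_ratios(eval_scores: EvalScores, min_score: float) -> Dict[str, float]:
    """Per layer, pick the smallest candidate ratio whose score is >= min_score."""
    picked = {}
    for layer, scores in eval_scores.items():
        acceptable = [ratio for ratio, score in scores.items() if score >= min_score]
        # Leave the layer uncompressed (ratio 1.0) if no candidate is good enough.
        picked[layer] = min(acceptable) if acceptable else 1.0
    return picked


def overall_ratio(picked: Dict[str, float], layer_costs: Dict[str, float]) -> float:
    """Assumed cost model: compressed layer cost = original cost * ratio."""
    compressed = sum(layer_costs[layer] * ratio for layer, ratio in picked.items())
    return compressed / sum(layer_costs.values())


def greedy_select(eval_scores: EvalScores,
                  layer_costs: Dict[str, float],
                  target_comp_ratio: float) -> Dict[str, float]:
    """Lower the score threshold step by step until the target ratio is met."""
    thresholds = sorted(
        {score for scores in eval_scores.values() for score in scores.values()},
        reverse=True,
    )
    picked = {layer: 1.0 for layer in eval_scores}
    for threshold in thresholds:
        picked = pick_ratios(eval_scores, threshold)
        if overall_ratio(picked, layer_costs) <= target_comp_ratio:
            break
    return picked


# Made-up scores and costs for a two-layer model.
eval_scores = {
    "conv1": {0.25: 0.60, 0.5: 0.90, 0.75: 0.97},
    "conv2": {0.25: 0.92, 0.5: 0.95, 0.75: 0.98},
}
layer_costs = {"conv1": 100.0, "conv2": 300.0}
selected = greedy_select(eval_scores, layer_costs, target_comp_ratio=0.5)
print(selected)
```

With these made-up numbers the layer that tolerates compression better (conv2) ends up with the lower ratio, which is the kind of per-layer allocation Manual mode would otherwise require the user to choose by hand.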
 
 
Greedy Selection Parameters
class aimet_common.defs.GreedySelectionParameters(target_comp_ratio, num_comp_ratio_candidates=10, use_monotonic_fit=False, saved_eval_scores_dict=None)
Configuration parameters for the Greedy compression-ratio selection algorithm
Variables:
- target_comp_ratio – Target compression ratio, expressed as a value between 0 and 1. Compression ratio is the ratio of the cost of the compressed model to the cost of the original model.
- num_comp_ratio_candidates – Number of comp-ratio candidates to analyze per layer. More candidates allow a more granular distribution of compression at the cost of increased run-time during analysis. Default: 10. The value should be greater than 1.
- use_monotonic_fit – If True, eval scores in the eval dictionary are fitted to a monotonically increasing function. This is useful if the eval scores for some layers are not monotonically increasing. Default: False.
- saved_eval_scores_dict – Path to the eval_scores dictionary pickle file that was saved in a previous run. This is useful for speeding up experiments, for example when trying different target compression ratios. AIMET automatically saves the eval_scores dictionary pickle file in a ./data directory relative to the current path. The num_comp_ratio_candidates parameter is ignored when this option is used.
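To make the idea behind use_monotonic_fit more concrete, the sketch below forces a layer's eval scores to be non-decreasing as the compression ratio grows. It uses a simple running maximum as a stand-in. The fitting method AIMET actually applies is not described in this document, and the function name and scores are hypothetical.

```python
from itertools import accumulate
from typing import Dict


def fit_monotonic(scores_by_ratio: Dict[float, float]) -> Dict[float, float]:
    """Return scores that never decrease as the compression ratio increases.

    Stand-in technique: a running maximum over the scores, ordered by ratio.
    """
    ratios = sorted(scores_by_ratio)
    fitted = accumulate((scores_by_ratio[ratio] for ratio in ratios), max)
    return dict(zip(ratios, fitted))


# Made-up scores for one layer; the dips at 0.5 and 1.0 are measurement noise.
raw_scores = {0.25: 0.80, 0.5: 0.78, 0.75: 0.91, 1.0: 0.90}
fitted_scores = fit_monotonic(raw_scores)
print(fitted_scores)
```

A selection step that assumes "less compression never hurts accuracy" can then work on the fitted scores without being misled by the noisy dips.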
 
 
TAR Selection Parameters
class aimet_torch.defs.TarRankSelectionParameters(num_rank_indices)
Configuration parameters for the TAR compression-ratio selection algorithm
Variables:
- num_rank_indices – Number of rank indices for ratio selection.
 
Spatial SVD Configuration
class aimet_torch.defs.SpatialSvdParameters(mode, params, multiplicity=1)
Configuration parameters for Spatial SVD compression
Parameters:
- mode (Mode) – Either auto mode or manual mode
- params (Union[ManualModeParams, AutoModeParams]) – Parameters for the mode selected
- multiplicity – The multiplicity to which ranks/input channels will get rounded. Default: 1

class AutoModeParams(greedy_select_params, modules_to_ignore=None)
Configuration parameters for auto-mode compression
Parameters:
- greedy_select_params (GreedySelectionParameters) – Params for greedy comp-ratio selection algorithm
- modules_to_ignore (Optional[List[Module]]) – List of modules to ignore (None indicates nothing to ignore)

class ManualModeParams(list_of_module_comp_ratio_pairs)
Configuration parameters for manual-mode Spatial SVD compression
Parameters:
- list_of_module_comp_ratio_pairs (List[ModuleCompRatioPair]) – List of (module, comp-ratio) pairs
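The multiplicity parameter rounds ranks or input-channel counts to a multiple of a given number (for example 8). The sketch below shows what such rounding could look like. Whether AIMET rounds up or down, and how it treats edge cases, is not stated in this document, so both helpers and their names are hypothetical.

```python
def round_up_to_multiple(value: int, multiplicity: int, upper_bound: int) -> int:
    """Round value up to the next multiple, without exceeding upper_bound."""
    rounded = -(-value // multiplicity) * multiplicity  # ceiling division
    return min(rounded, upper_bound)


def round_down_to_multiple(value: int, multiplicity: int) -> int:
    """Round value down to a multiple; keep small values unchanged."""
    rounded = (value // multiplicity) * multiplicity
    return rounded if rounded > 0 else value


# A rank of 27 chosen by the selection step, in a layer with 64 channels.
print(round_up_to_multiple(27, 8, upper_bound=64))
print(round_down_to_multiple(27, 8))
# With the default multiplicity of 1 the value is left as it is.
print(round_up_to_multiple(27, 1, upper_bound=64))
```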
 
 
Weight SVD Configuration
class aimet_torch.defs.WeightSvdParameters(mode, params, multiplicity=1)
Configuration parameters for Weight SVD compression
Parameters:
- mode (Mode) – Either auto mode or manual mode
- params (Union[ManualModeParams, AutoModeParams]) – Parameters for the mode selected
- multiplicity – The multiplicity to which ranks/input channels will get rounded. Default: 1

class AutoModeParams(rank_select_scheme, select_params, modules_to_ignore=None)
Configuration parameters for auto-mode compression
Parameters:
- rank_select_scheme (RankSelectScheme) – Supports two options: greedy and tar
- select_params (Union[GreedySelectionParameters, TarRankSelectionParameters]) – Params for greedy/TAR comp-ratio selection algorithm
- modules_to_ignore (Optional[List[Module]]) – List of modules to ignore (None indicates nothing to ignore)

class ManualModeParams(list_of_module_comp_ratio_pairs)
Configuration parameters for manual-mode Weight SVD compression
Parameters:
- list_of_module_comp_ratio_pairs (List[ModuleCompRatioPair]) – List of (module, comp-ratio) pairs
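When choosing per-layer ratios for manual mode, it can help to see how a compression ratio relates to an SVD rank. The sketch below uses the textbook low-rank factorization of a fully connected weight matrix (out_features x in_features) into two factors of rank r, counting weights only and ignoring biases. It is an approximation for intuition: AIMET's own cost model and rank selection may differ, and the helper names are hypothetical.

```python
def rank_for_ratio(out_features: int, in_features: int, comp_ratio: float) -> int:
    """Largest rank whose factorized weight count stays within comp_ratio."""
    original_weights = out_features * in_features
    weights_per_rank = out_features + in_features  # one column of U plus one row of V
    return max(1, int(comp_ratio * original_weights / weights_per_rank))


def ratio_for_rank(out_features: int, in_features: int, rank: int) -> float:
    """Compression ratio: factorized weight count / original weight count."""
    return rank * (out_features + in_features) / (out_features * in_features)


# A 512 -> 1024 fully connected layer with a requested ratio of 0.5.
rank = rank_for_ratio(out_features=1024, in_features=512, comp_ratio=0.5)
achieved = ratio_for_rank(out_features=1024, in_features=512, rank=rank)
print(rank, achieved)
```

Because the rank must be an integer (and may be rounded further by multiplicity), the achieved ratio generally differs slightly from the requested one.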
 
 
Channel Pruning Configuration
class aimet_torch.defs.ChannelPruningParameters(data_loader, num_reconstruction_samples, allow_custom_downsample_ops, mode, params, multiplicity=1)
Configuration parameters for channel pruning compression

class AutoModeParams(greedy_select_params, modules_to_ignore=None)
Configuration parameters for auto-mode compression
Parameters:
- greedy_select_params (GreedySelectionParameters) – Params for greedy comp-ratio selection algorithm
- modules_to_ignore (Optional[List[Module]]) – List of modules to ignore (None indicates nothing to ignore)

class ManualModeParams(list_of_module_comp_ratio_pairs)
Configuration parameters for manual-mode channel pruning compression
Parameters:
- list_of_module_comp_ratio_pairs (List[ModuleCompRatioPair]) – List of (module, comp-ratio) pairs
 
 
- 
class 
Configuration Definitions
class aimet_common.defs.CostMetric
Enumeration of metrics to measure cost of a model/layer
mac = 1
- MAC: Cost modeled for compute requirements
memory = 2
- Memory: Cost modeled for space requirements

class aimet_common.defs.CompressionScheme
Enumeration of compression schemes supported in AIMET
channel_pruning = 3
- Channel Pruning
spatial_svd = 2
- Spatial SVD
weight_svd = 1
- Weight SVD

class aimet_torch.defs.ModuleCompRatioPair(module, comp_ratio)
Pair of torch.nn.Module and a compression ratio
Variables:
- module – Module of type torch.nn.Module
- comp_ratio – Compression ratio. Compression ratio is the ratio of the cost of the compressed model to the cost of the original model.
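The two cost metrics and the compression-ratio definition can be illustrated with a single convolution layer. The sketch below uses the usual textbook counts (weights for memory, multiply-accumulates for MAC) and ignores biases. It is an assumption-laden illustration rather than AIMET's cost calculator, and the ConvShape type and helper names are hypothetical.

```python
from dataclasses import dataclass, replace
from decimal import Decimal


@dataclass(frozen=True)
class ConvShape:
    """Hypothetical description of a conv layer and its output feature map."""
    in_channels: int
    out_channels: int
    kernel_h: int
    kernel_w: int
    out_h: int
    out_w: int


def memory_cost(conv: ConvShape) -> int:
    """Number of weights (bias ignored)."""
    return conv.out_channels * conv.in_channels * conv.kernel_h * conv.kernel_w


def mac_cost(conv: ConvShape) -> int:
    """Multiply-accumulates: every weight is used once per output position."""
    return memory_cost(conv) * conv.out_h * conv.out_w


def comp_ratio(compressed_cost: int, original_cost: int) -> Decimal:
    """Compression ratio = cost of compressed layer / cost of original layer."""
    return Decimal(compressed_cost) / Decimal(original_cost)


original = ConvShape(in_channels=32, out_channels=64, kernel_h=5, kernel_w=5,
                     out_h=8, out_w=8)
# Pruning half of the input channels halves both cost metrics for this layer.
pruned = replace(original, in_channels=16)

print(memory_cost(original), mac_cost(original))
print(comp_ratio(mac_cost(pruned), mac_cost(original)))
```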
 
 
Code Examples
Required imports
import os
from decimal import Decimal
import torch
# Compression-related imports
from aimet_common.defs import CostMetric, CompressionScheme, GreedySelectionParameters, RankSelectScheme
from aimet_torch.defs import WeightSvdParameters, SpatialSvdParameters, ChannelPruningParameters, \
    ModuleCompRatioPair
from aimet_torch.compress import ModelCompressor
from aimet_torch.examples import mnist_torch_model
Evaluation function
def evaluate_model(model: torch.nn.Module, eval_iterations: int, use_cuda: bool = False) -> float:
    """
    This is intended to be the user-defined model evaluation function.
    AIMET requires the above signature. So if the user's eval function does not
    match this signature, please create a simple wrapper.
    Note: Honoring the number of iterations is not absolutely necessary.
    However if all evaluations run over an entire epoch of validation data,
    the runtime for AIMET compression will obviously be higher.
    :param model: Model to evaluate
    :param eval_iterations: Number of iterations to use for evaluation.
            None for entire epoch.
    :param use_cuda: If true, evaluate using gpu acceleration
    :return: single float number (accuracy) representing model's performance
    """
    # Placeholder score: replace with a real evaluation of the model
    return .5
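The docstring above suggests wrapping an existing evaluation function when its signature differs. A minimal sketch of such a wrapper follows, assuming the user already has a per-batch scoring function and a re-iterable collection of batches. The names make_eval_callback and score_batch are hypothetical, and use_cuda is accepted only to match the expected signature.

```python
from itertools import islice
from typing import Any, Callable, Optional, Sequence


def make_eval_callback(score_batch: Callable[[Any, Any], float],
                       batches: Sequence[Any]) -> Callable[..., float]:
    """Adapt a per-batch scoring function to (model, eval_iterations, use_cuda)."""

    def evaluate_model(model: Any, eval_iterations: Optional[int] = None,
                       use_cuda: bool = False) -> float:
        # use_cuda is ignored in this sketch; a real wrapper would move the
        # model and the data to the GPU when it is True.
        # islice with a stop of None iterates over all batches (a full epoch).
        scores = [score_batch(model, batch)
                  for batch in islice(batches, eval_iterations)]
        return sum(scores) / len(scores) if scores else 0.0

    return evaluate_model


# Toy stand-ins: each "batch" is already its own score and the model is unused.
toy_batches = [1.0, 0.0, 1.0, 1.0]
evaluate_model = make_eval_callback(lambda model, batch: batch, toy_batches)

print(evaluate_model(None, 2))     # average over the first two batches
print(evaluate_model(None, None))  # average over all batches
```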
Compressing using Spatial SVD in auto mode with multiplicity = 8 for rank rounding
def spatial_svd_auto_mode():
    # load trained MNIST model
    model = torch.load(os.path.join('../', 'data', 'mnist_trained_on_GPU.pth'))
    # Specify the necessary parameters
    greedy_params = GreedySelectionParameters(target_comp_ratio=Decimal(0.8),
                                              num_comp_ratio_candidates=10)
    auto_params = SpatialSvdParameters.AutoModeParams(greedy_params,
                                                      modules_to_ignore=[model.conv1])
    params = SpatialSvdParameters(mode=SpatialSvdParameters.Mode.auto,
                                  params=auto_params, multiplicity=8)
    # Single call to compress the model
    results = ModelCompressor.compress_model(model,
                                             eval_callback=evaluate_model,
                                             eval_iterations=1000,
                                             input_shape=(1, 1, 28, 28),
                                             compress_scheme=CompressionScheme.spatial_svd,
                                             cost_metric=CostMetric.mac,
                                             parameters=params)
    compressed_model, stats = results
    print(compressed_model)
    print(stats)     # Stats object can be pretty-printed easily
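The examples in this section pass target_comp_ratio=Decimal(0.8). As a side note on Python's decimal module (not on AIMET itself), a Decimal built from a float literal keeps the float's binary rounding error, while a Decimal built from a string holds the exact decimal value. Whether this difference matters to AIMET is not stated in this document, so the sketch below only shows the Python behavior.

```python
from decimal import Decimal

from_float = Decimal(0.8)     # exact value of the binary float closest to 0.8
from_string = Decimal("0.8")  # exactly eight tenths

print(from_float)
print(from_string)
print(from_float == from_string)
```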
Compressing using Spatial SVD in manual mode
def spatial_svd_manual_mode():
    # Load a trained MNIST model
    model = torch.load(os.path.join('../', 'data', 'mnist_trained_on_GPU.pth'))
    # Specify the necessary parameters
    manual_params = SpatialSvdParameters.ManualModeParams([ModuleCompRatioPair(model.conv1, 0.5),
                                                           ModuleCompRatioPair(model.conv2, 0.4)])
    params = SpatialSvdParameters(mode=SpatialSvdParameters.Mode.manual,
                                  params=manual_params)
    # Single call to compress the model
    results = ModelCompressor.compress_model(model,
                                             eval_callback=evaluate_model,
                                             eval_iterations=1000,
                                             input_shape=(1, 1, 28, 28),
                                             compress_scheme=CompressionScheme.spatial_svd,
                                             cost_metric=CostMetric.mac,
                                             parameters=params)
    compressed_model, stats = results
    print(compressed_model)
    print(stats)    # Stats object can be pretty-printed easily
Compressing using Weight SVD in auto mode
def weight_svd_auto_mode():
    # Load trained MNIST model
    model = torch.load(os.path.join('../', 'data', 'mnist_trained_on_GPU.pth'))
    # Specify the necessary parameters
    greedy_params = GreedySelectionParameters(target_comp_ratio=Decimal(0.8),
                                              num_comp_ratio_candidates=10)
    rank_select = RankSelectScheme.greedy
    auto_params = WeightSvdParameters.AutoModeParams(rank_select_scheme=rank_select,
                                                     select_params=greedy_params,
                                                     modules_to_ignore=[model.conv1])
    params = WeightSvdParameters(mode=WeightSvdParameters.Mode.auto,
                                 params=auto_params)
    # Single call to compress the model
    results = ModelCompressor.compress_model(model,
                                             eval_callback=evaluate_model,
                                             eval_iterations=1000,
                                             input_shape=(1, 1, 28, 28),
                                             compress_scheme=CompressionScheme.weight_svd,
                                             cost_metric=CostMetric.mac,
                                             parameters=params)
    compressed_model, stats = results
    print(compressed_model)
    print(stats)     # Stats object can be pretty-printed easily
Compressing using Weight SVD in manual mode with multiplicity = 8 for rank rounding
def weight_svd_manual_mode():
    # Load a trained MNIST model
    model = torch.load(os.path.join('../', 'data', 'mnist_trained_on_GPU.pth'))
    # Specify the necessary parameters
    manual_params = WeightSvdParameters.ManualModeParams([ModuleCompRatioPair(model.conv1, 0.5),
                                                          ModuleCompRatioPair(model.conv2, 0.4)])
    params = WeightSvdParameters(mode=WeightSvdParameters.Mode.manual,
                                 params=manual_params, multiplicity=8)
    # Single call to compress the model
    results = ModelCompressor.compress_model(model,
                                             eval_callback=evaluate_model,
                                             eval_iterations=1000,
                                             input_shape=(1, 1, 28, 28),
                                             compress_scheme=CompressionScheme.weight_svd,
                                             cost_metric=CostMetric.mac,
                                             parameters=params)
    compressed_model, stats = results
    print(compressed_model)
    print(stats)    # Stats object can be pretty-printed easily
Compressing using Channel Pruning in auto mode
def channel_pruning_auto_mode():
    # Load trained MNIST model
    model = torch.load(os.path.join('../', 'data', 'mnist_trained_on_GPU.pth'))
    # Specify the necessary parameters
    greedy_params = GreedySelectionParameters(target_comp_ratio=Decimal(0.8),
                                              num_comp_ratio_candidates=10)
    auto_params = ChannelPruningParameters.AutoModeParams(greedy_params,
                                                          modules_to_ignore=[model.conv1])
    data_loader = mnist_torch_model.DataLoaderMnist(cuda=True, seed=1, shuffle=True)
    params = ChannelPruningParameters(data_loader=data_loader.train_loader,
                                      num_reconstruction_samples=500,
                                      allow_custom_downsample_ops=True,
                                      mode=ChannelPruningParameters.Mode.auto,
                                      params=auto_params)
    # Single call to compress the model
    results = ModelCompressor.compress_model(model,
                                             eval_callback=evaluate_model,
                                             eval_iterations=1000,
                                             input_shape=(1, 1, 28, 28),
                                             compress_scheme=CompressionScheme.channel_pruning,
                                             cost_metric=CostMetric.mac,
                                             parameters=params)
    compressed_model, stats = results
    print(compressed_model)
    print(stats)     # Stats object can be pretty-printed easily
Compressing using Channel Pruning in manual mode
def channel_pruning_manual_mode():
    # Load a trained MNIST model
    model = torch.load(os.path.join('../', 'data', 'mnist_trained_on_GPU.pth'))
    # Specify the necessary parameters
    manual_params = ChannelPruningParameters.ManualModeParams([ModuleCompRatioPair(model.conv2, 0.4)])
    data_loader = mnist_torch_model.DataLoaderMnist(cuda=True, seed=1, shuffle=True)
    params = ChannelPruningParameters(data_loader=data_loader.train_loader,
                                      num_reconstruction_samples=500,
                                      allow_custom_downsample_ops=True,
                                      mode=ChannelPruningParameters.Mode.manual,
                                      params=manual_params)
    # Single call to compress the model
    results = ModelCompressor.compress_model(model,
                                             eval_callback=evaluate_model,
                                             eval_iterations=1000,
                                             input_shape=(1, 1, 28, 28),
                                             compress_scheme=CompressionScheme.channel_pruning,
                                             cost_metric=CostMetric.mac,
                                             parameters=params)
    compressed_model, stats = results
    print(compressed_model)
    print(stats)    # Stats object can be pretty-printed easily
Example Training Object
class Trainer:
    """ Example trainer class """
    def __init__(self):
        self._layer_db = []
    def train_model(self, model, layer, train_flag=True):
        """
        Trains a model
        :param model: Model to be trained
        :param layer: layer which has to be fine tuned
        :param train_flag: Default: True. If True, the model gets trained
        :return:
        """
        if train_flag:
            mnist_torch_model.train(model, epochs=1, use_cuda=True, batch_size=50, batch_callback=None)
        self._layer_db.append(layer)
Compressing using Spatial SVD in auto mode with layer-wise fine tuning
def spatial_svd_auto_mode_with_layerwise_finetuning():
    # load trained MNIST model
    model = torch.load(os.path.join('../', 'data', 'mnist_trained_on_GPU.pth'))
    # Specify the necessary parameters
    greedy_params = GreedySelectionParameters(target_comp_ratio=Decimal(0.8),
                                              num_comp_ratio_candidates=10)
    auto_params = SpatialSvdParameters.AutoModeParams(greedy_params,
                                                      modules_to_ignore=[model.conv1])
    params = SpatialSvdParameters(mode=SpatialSvdParameters.Mode.auto,
                                  params=auto_params)
    # Single call to compress the model
    results = ModelCompressor.compress_model(model,
                                             eval_callback=evaluate_model,
                                             eval_iterations=1000,
                                             input_shape=(1, 1, 28, 28),
                                             compress_scheme=CompressionScheme.spatial_svd,
                                             cost_metric=CostMetric.mac,
                                             parameters=params, trainer=Trainer())
    compressed_model, stats = results
    print(compressed_model)
    print(stats)     # Stats object can be pretty-printed easily
