Spatial SVD

Context

Spatial singular value decomposition (spatial SVD) is a tensor decomposition technique that decomposes one large layer (large in terms of multiply-accumulate (MAC) operations or memory) into two smaller layers.

Consider a convolution (Conv) layer with kernel (𝑚, 𝑛, ℎ, 𝑤), where:

  • 𝑚 is the number of input channels

  • 𝑛 is the number of output channels

  • ℎ is the height of the kernel

  • 𝑤 is the width of the kernel

Spatial SVD decomposes the kernel into two kernels, one of size (𝑚, 𝑘, ℎ, 1) and one of size (𝑘, 𝑛, 1, 𝑤), where 𝑘 is called the rank. The first kernel spans ℎ×1 and the second 1×𝑤, so applying them in sequence covers the original ℎ×𝑤 receptive field. The smaller the value of 𝑘, the larger the degree of compression.
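To see how the rank 𝑘 drives compression, the short sketch below counts kernel weights before and after the split. It is an illustration only, not AIMET code: it assumes cost is the number of kernel weights and ignores biases. For stride-1 convolutions with "same" padding the MAC ratio equals this weight ratio, because both decomposed layers then produce outputs of the original spatial size. The function name spatial_svd_weight_counts is hypothetical.

```python
def spatial_svd_weight_counts(m: int, n: int, h: int, w: int, k: int) -> tuple:
    """Return (original, decomposed) weight counts for a Conv kernel (m, n, h, w)
    split into kernels (m, k, h, 1) and (k, n, 1, w). Biases are ignored."""
    original = m * n * h * w
    decomposed = m * k * h + k * n * w  # (m, k, h, 1) plus (k, n, 1, w)
    return original, decomposed


# Example: 32 input channels, 64 output channels, 5x5 kernel
for k in (8, 16, 32):
    orig, dec = spatial_svd_weight_counts(32, 64, 5, 5, k)
    print(f"k={k}: {dec}/{orig} weights, ratio={dec / orig:.3f}")
```

Each unit of rank costs 𝑚·ℎ + 𝑛·𝑤 weights, so the decomposed cost grows linearly with 𝑘.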

The following figure illustrates how spatial SVD decomposes both the output channel dimension and the size of the Conv kernel itself.

[Figure: spatial SVD decomposition of a Conv kernel (spatial_svd.png)]
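The decomposition itself can be sketched with NumPy as a truncated SVD of the kernel reshaped into an (𝑚·ℎ) × (𝑛·𝑤) matrix. This is a minimal illustration, not AIMET's implementation: it assumes the (𝑚, 𝑛, ℎ, 𝑤) layout used above (PyTorch stores Conv2d weights as (out_channels, in_channels, ℎ, 𝑤), so a real weight would need transposing first), ignores biases, and splits each singular value evenly between the two factors. The helper names spatial_svd and reconstruct are hypothetical.

```python
import numpy as np


def spatial_svd(kernel: np.ndarray, k: int):
    """Split a kernel of shape (m, n, h, w) into kernels of shape
    (m, k, h, 1) and (k, n, 1, w) using a rank-k truncated SVD."""
    m, n, h, w = kernel.shape
    if not 1 <= k <= min(m * h, n * w):
        raise ValueError("rank k must be in [1, min(m*h, n*w)]")
    # Pair (input channel, kernel row) against (output channel, kernel column):
    # (m, n, h, w) -> (m, h, n, w) -> matrix of shape (m*h, n*w)
    mat = kernel.transpose(0, 2, 1, 3).reshape(m * h, n * w)
    u, s, vt = np.linalg.svd(mat, full_matrices=False)
    root_s = np.sqrt(s[:k])  # share each singular value between both factors
    # Left factor: (m*h, k) -> (m, h, k) -> (m, k, h) -> (m, k, h, 1)
    first = (u[:, :k] * root_s).reshape(m, h, k).transpose(0, 2, 1)[..., np.newaxis]
    # Right factor: (k, n*w) -> (k, n, w) -> (k, n, 1, w)
    second = (root_s[:, np.newaxis] * vt[:k]).reshape(k, n, w)[:, :, np.newaxis, :]
    return first, second


def reconstruct(first: np.ndarray, second: np.ndarray) -> np.ndarray:
    """Effective (m, n, h, w) kernel of the h x 1 kernel followed by the 1 x w kernel."""
    return np.einsum("arix,rbxj->abij", first, second)


rng = np.random.default_rng(0)
kernel = rng.standard_normal((4, 6, 3, 3))  # (m, n, h, w)
first, second = spatial_svd(kernel, k=5)
print(first.shape, second.shape)  # (4, 5, 3, 1) (5, 6, 1, 3)
```

With 𝑘 = min(𝑚·ℎ, 𝑛·𝑤) the two kernels reproduce the original exactly; a smaller 𝑘 gives the best rank-𝑘 approximation of the reshaped matrix in the Frobenius norm (Eckart–Young theorem).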

Workflow

Code example

Setup

import os
from decimal import Decimal
from typing import Optional

import torch

# Compression-related imports
from aimet_common.defs import CostMetric, CompressionScheme, GreedySelectionParameters, RankSelectScheme
from aimet_torch.defs import WeightSvdParameters, SpatialSvdParameters, ChannelPruningParameters, \
    ModuleCompRatioPair
from aimet_torch.compress import ModelCompressor


def evaluate_model(model: torch.nn.Module, eval_iterations: Optional[int], use_cuda: bool = False) -> float:
    """
    This is intended to be the user-defined model evaluation function.
    AIMET requires the above signature, so if your evaluation function does not
    match it, create a simple wrapper.

    Note: Honoring the number of iterations is not strictly necessary.
    However, if every evaluation runs over an entire epoch of validation data,
    AIMET compression will take longer.

    :param model: Model to evaluate
    :param eval_iterations: Number of iterations to use for evaluation.
            None for entire epoch.
    :param use_cuda: If true, evaluate using GPU acceleration
    :return: single float number (accuracy) representing model's performance
    """
    return 0.5  # Placeholder: replace with a real accuracy computation
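If your existing evaluation code scores one batch at a time, a wrapper along the following lines can adapt it to the required signature. This is a sketch: make_eval_callback, the batches sequence, and the score_batch function are hypothetical stand-ins for your own data and metric, use_cuda is accepted but ignored, and per-batch scores are averaged with equal weight (which differs from overall accuracy when batch sizes differ).

```python
from typing import Any, Callable, Optional, Sequence


def make_eval_callback(batches: Sequence[Any],
                       score_batch: Callable[[Any, Any], float]) -> Callable[..., float]:
    """Wrap a per-batch scoring function in the (model, eval_iterations, use_cuda)
    signature expected for eval_callback."""

    def eval_callback(model: Any, eval_iterations: Optional[int], use_cuda: bool = False) -> float:
        # None means "evaluate over the entire epoch"
        if eval_iterations is None:
            limit = len(batches)
        else:
            limit = min(eval_iterations, len(batches))
        if limit == 0:
            return 0.0
        # use_cuda is accepted for compatibility but ignored in this sketch
        total = sum(score_batch(model, batches[i]) for i in range(limit))
        return total / limit

    return eval_callback


# Usage with dummy data: each "batch" is already its own score
callback = make_eval_callback([1.0, 0.5, 0.75, 0.25], lambda model, batch: batch)
print(callback(None, 2))     # 0.75
print(callback(None, None))  # 0.625
```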

Compressing using Spatial SVD

Compressing using Spatial SVD in auto mode with multiplicity = 8 for rank rounding

def spatial_svd_auto_mode():

    # load trained MNIST model
    model = torch.load(os.path.join('../', 'data', 'mnist_trained_on_GPU.pth'))

    # Specify the necessary parameters
    greedy_params = GreedySelectionParameters(target_comp_ratio=Decimal('0.8'),
                                              num_comp_ratio_candidates=10)
    auto_params = SpatialSvdParameters.AutoModeParams(greedy_params,
                                                      modules_to_ignore=[model.conv1])

    params = SpatialSvdParameters(mode=SpatialSvdParameters.Mode.auto,
                                  params=auto_params, multiplicity=8)

    # Single call to compress the model
    results = ModelCompressor.compress_model(model,
                                             eval_callback=evaluate_model,
                                             eval_iterations=1000,
                                             input_shape=(1, 1, 28, 28),
                                             compress_scheme=CompressionScheme.spatial_svd,
                                             cost_metric=CostMetric.mac,
                                             parameters=params)

    compressed_model, stats = results
    print(compressed_model)
    print(stats)     # Stats object can be pretty-printed easily

Compressing using Spatial SVD in manual mode

def spatial_svd_manual_mode():

    # Load a trained MNIST model
    model = torch.load(os.path.join('../', 'data', 'mnist_trained_on_GPU.pth'))

    # Specify the necessary parameters
    manual_params = SpatialSvdParameters.ManualModeParams([ModuleCompRatioPair(model.conv1, 0.5),
                                                           ModuleCompRatioPair(model.conv2, 0.4)])
    params = SpatialSvdParameters(mode=SpatialSvdParameters.Mode.manual,
                                  params=manual_params)

    # Single call to compress the model
    results = ModelCompressor.compress_model(model,
                                             eval_callback=evaluate_model,
                                             eval_iterations=1000,
                                             input_shape=(1, 1, 28, 28),
                                             compress_scheme=CompressionScheme.spatial_svd,
                                             cost_metric=CostMetric.mac,
                                             parameters=params)

    compressed_model, stats = results
    print(compressed_model)
    print(stats)    # Stats object can be pretty-printed easily

API

Top-level API for Compression

class aimet_torch.compress.ModelCompressor

AIMET model compressor: Enables model compression using various schemes

static ModelCompressor.compress_model(model, eval_callback, eval_iterations, input_shape, compress_scheme, cost_metric, parameters, trainer=None, visualization_url=None)

Compress a given model using the specified parameters

Parameters:
  • model (Module) – Model to compress

  • eval_callback (Callable[[Any, Optional[int], bool], float]) – Evaluation callback. Expected signature is evaluate(model, iterations, use_cuda). Expected to return an accuracy metric.

  • eval_iterations – Number of iterations to run evaluation for

  • trainer – Training class that contains a callable, train_model, which takes the model, the layer being fine-tuned, and an optional train_flag parameter. Pass None if per-layer fine-tuning is not required while creating the final compressed model

  • input_shape (Tuple) – Shape of the input tensor for model

  • compress_scheme (CompressionScheme) – Compression scheme. See the enum for allowed values

  • cost_metric (CostMetric) – Cost metric to use for the compression-ratio (either mac or memory)

  • parameters (Union[SpatialSvdParameters, WeightSvdParameters, ChannelPruningParameters]) – Compression parameters specific to given compression scheme

  • visualization_url – URL where the visualizations will appear
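A trainer object only needs to expose train_model as described for the trainer parameter. The skeleton below is a hypothetical sketch: the class name LayerFineTuner, the fine_tune callback, the default value of train_flag, and the choice to skip fine-tuning when train_flag is False are all assumptions rather than documented AIMET behavior.

```python
from typing import Any, Callable


class LayerFineTuner:
    """Hypothetical object for the `trainer` argument; exposes
    train_model(model, layer, train_flag)."""

    def __init__(self, fine_tune: Callable[[Any, Any], None]):
        # fine_tune(model, layer) is a user-supplied routine (hypothetical)
        self._fine_tune = fine_tune

    def train_model(self, model: Any, layer: Any, train_flag: bool = True) -> None:
        # Assumption: train_flag=False means "skip fine-tuning for this call"
        if train_flag:
            self._fine_tune(model, layer)


# Usage with a dummy fine-tuning routine that just records what it was asked to tune
log = []
trainer = LayerFineTuner(lambda model, layer: log.append(layer))
trainer.train_model("model", "conv2")
trainer.train_model("model", "conv3", train_flag=False)
print(log)  # ['conv2']
```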

Return type:

Tuple[Module, CompressionStats]

Returns:

A tuple of the compressed model, and compression statistics

Greedy Selection Parameters

class aimet_common.defs.GreedySelectionParameters(target_comp_ratio, num_comp_ratio_candidates=10, use_monotonic_fit=False, saved_eval_scores_dict=None)

Configuration parameters for the Greedy compression-ratio selection algorithm

Variables:
  • target_comp_ratio – Target compression ratio, expressed as a value between 0 and 1. Compression ratio is the ratio of the cost of the compressed model to the cost of the original model, so a smaller value means more compression.

  • num_comp_ratio_candidates – Number of comp-ratio candidates to analyze per layer. More candidates allow a more granular distribution of compression at the cost of increased run time during analysis. Default value=10. Value should be greater than 1.

  • use_monotonic_fit – If True, eval scores in the eval dictionary are fitted to a monotonically increasing function. This is useful if the eval dictionary scores for some layers are not monotonically increasing. By default, this option is set to False.

  • saved_eval_scores_dict – Path to the eval_scores dictionary pickle file saved in a previous run. This is useful for speeding up experiments, for example when trying different target compression ratios. AIMET automatically saves the eval_scores dictionary pickle file in a ./data directory relative to the current path. The num_comp_ratio_candidates parameter is ignored when this option is used.
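The sketch below illustrates target_comp_ratio and num_comp_ratio_candidates. It assumes that per-layer candidates are evenly spaced ratios in (0, 1) with 1.0 excluded, which may not match AIMET's exact candidate list; the function names are hypothetical. Because the compression ratio is compressed cost divided by original cost, target_comp_ratio=Decimal('0.8') on a model with 1,000,000 MACs corresponds to a budget of 800,000 MACs.

```python
from decimal import Decimal
from typing import List


def comp_ratio_candidates(num_candidates: int) -> List[Decimal]:
    """Evenly spaced per-layer comp-ratio candidates in (0, 1).
    Assumption: 1.0 (no compression) is excluded from the list."""
    if num_candidates <= 1:
        raise ValueError("num_comp_ratio_candidates should be greater than 1")
    step = Decimal(1) / Decimal(num_candidates)
    return [step * i for i in range(1, num_candidates)]


def target_cost(original_cost: int, target_comp_ratio: Decimal) -> Decimal:
    """Cost budget implied by target_comp_ratio = compressed cost / original cost."""
    return Decimal(original_cost) * target_comp_ratio


print(comp_ratio_candidates(4))               # 0.25, 0.5, 0.75 as Decimals
print(target_cost(1_000_000, Decimal("0.8")))  # budget of 800,000
```

Using Decimal('0.8') rather than Decimal(0.8) keeps the ratio exact instead of inheriting binary floating-point error.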

Configuration Definitions

class aimet_common.defs.CostMetric(value)

Enumeration of metrics to measure cost of a model/layer

mac = 1

Cost modeled for compute requirements (type: MAC)

memory = 2

Cost modeled for space requirements (type: Memory)
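As a concrete sketch of the two metrics for a single Conv layer with kernel (𝑚, 𝑛, ℎ, 𝑤): it assumes memory cost is the number of weights and MAC cost is the number of weights times the number of output positions, with biases and groups ignored. AIMET's own cost model may count differently, and conv_output_size and conv_costs are hypothetical helpers.

```python
def conv_output_size(size: int, kernel: int, stride: int = 1, padding: int = 0,
                     dilation: int = 1) -> int:
    """Standard convolution output-size formula along one spatial axis."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def conv_costs(m: int, n: int, h: int, w: int, in_h: int, in_w: int,
               stride: int = 1, padding: int = 0) -> tuple:
    """Return (mac, memory) for a Conv layer with kernel (m, n, h, w).
    Assumption: memory = number of weights; mac = weights x output positions."""
    out_h = conv_output_size(in_h, h, stride, padding)
    out_w = conv_output_size(in_w, w, stride, padding)
    memory = m * n * h * w
    mac = memory * out_h * out_w
    return mac, memory


# 1 input channel, 32 output channels, 5x5 kernel on a 28x28 input (24x24 output)
print(conv_costs(1, 32, 5, 5, 28, 28))  # (460800, 800)
```

The same layer can therefore rank very differently under the two metrics: layers applied to large feature maps dominate MAC cost even when they hold few weights.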

class aimet_common.defs.CompressionScheme(value)

Enumeration of compression schemes supported in AIMET

channel_pruning = 3

Channel Pruning

spatial_svd = 2

Spatial SVD

weight_svd = 1

Weight SVD

Spatial SVD Configuration

class aimet_torch.defs.SpatialSvdParameters(mode, params, multiplicity=1)

Configuration parameters for spatial SVD compression

Parameters:
  • mode (Mode) – Either auto mode or manual mode

  • params (Union[ManualModeParams, AutoModeParams]) – Parameters for the mode selected

  • multiplicity – The multiple to which ranks/input channels are rounded. Default: 1
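The rounding can be sketched as follows. This assumes ranks are rounded up to the next multiple of multiplicity and then capped at the largest valid rank; AIMET's exact rounding rule may differ, and round_rank is a hypothetical helper. Under that assumption, multiplicity=8 (as in the auto-mode example above) would turn a rank of 13 into 16.

```python
def round_rank(rank: int, multiplicity: int, max_rank: int) -> int:
    """Round rank up to a multiple of multiplicity, never exceeding max_rank.
    Assumption: round-up-then-cap; AIMET's exact rule may differ."""
    if multiplicity < 1:
        raise ValueError("multiplicity must be >= 1")
    rounded = -(-rank // multiplicity) * multiplicity  # ceiling to the next multiple
    return min(rounded, max_rank)


print(round_rank(13, 8, 64))  # 16
print(round_rank(61, 8, 62))  # 62 (capped at max_rank)
```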

class AutoModeParams(greedy_select_params, modules_to_ignore=None)

Configuration parameters for auto-mode compression

Parameters:
  • greedy_select_params (GreedySelectionParameters) – Params for greedy comp-ratio selection algorithm

  • modules_to_ignore (Optional[List[Module]]) – List of modules to ignore (None indicates nothing to ignore)

class ManualModeParams(list_of_module_comp_ratio_pairs)

Configuration parameters for manual-mode spatial SVD compression

Parameters:

list_of_module_comp_ratio_pairs (List[ModuleCompRatioPair]) – List of (module, comp-ratio) pairs
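In manual mode each module is given a comp ratio, from which a rank has to be derived. As a rough sketch of that mapping (not AIMET's algorithm), assuming cost is the number of kernel weights, the largest admissible rank 𝑘 satisfies 𝑘·(𝑚·ℎ + 𝑛·𝑤) ≤ ratio·𝑚·𝑛·ℎ·𝑤. The helper name rank_for_comp_ratio is hypothetical.

```python
import math


def rank_for_comp_ratio(m: int, n: int, h: int, w: int, comp_ratio: float) -> int:
    """Largest rank k whose kernels (m, k, h, 1) + (k, n, 1, w) use at most
    comp_ratio of the original (m, n, h, w) weights (clamped to a valid rank)."""
    original = m * n * h * w
    per_rank = m * h + n * w  # weights added by each unit of rank
    k = math.floor(comp_ratio * original / per_rank)
    return max(1, min(k, min(m * h, n * w)))


# Kernel with 32 input channels, 64 output channels, 5x5, comp ratio 0.5
print(rank_for_comp_ratio(32, 64, 5, 5, 0.5))  # 53
```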

class Mode(value)

Mode enumeration

auto = 2

Auto mode

manual = 1

Manual mode