AIMET TensorFlow Compression API

Introduction

AIMET supports the following model compression techniques for TensorFlow models:
  • Spatial SVD

  • Channel Pruning

  • Weight SVD

For the Spatial SVD and Channel Pruning compression techniques, the AIMET API can be invoked in two modes:
  • Auto Mode: AIMET determines the optimal way to compress each layer of the model, given an overall target compression ratio. The Greedy Compression Ratio Selection algorithm picks an appropriate compression ratio for each layer.

  • Manual Mode: The user passes the desired compression ratio for each layer to AIMET, which applies the specified compression technique to each layer to achieve that ratio. We recommend starting with Auto mode and then, if desired, tweaking per-layer compression ratios using Manual mode.

For Weight SVD, we use Tar-Based Rank selection. Auto and Manual modes are supported for Weight SVD as well.


Top-level API for Compression

class aimet_tensorflow.compress.ModelCompressor

AIMET model compressor: enables model compression using various schemes


static ModelCompressor.compress_model(sess, working_dir, eval_callback, eval_iterations, input_shape, compress_scheme, cost_metric, parameters, trainer=None, visualization_url=None)

Compress a given model using the specified parameters

Parameters
  • sess (Session) – Model, represented by a tf.compat.v1.Session, to compress

  • working_dir (str) – Path of the directory in which to save the compressed TensorFlow meta file

  • eval_callback (Callable[[Any, Optional[int], bool], float]) – Evaluation callback. Expected signature is evaluate(model, iterations, use_cuda). Expected to return an accuracy metric.

  • eval_iterations – Number of iterations to run evaluation for

  • input_shape (Union[Tuple, List[Tuple]]) – Tuple, or list of tuples, of input shapes to the model (channels_last format)

  • compress_scheme (CompressionScheme) – Compression scheme. See the enum for allowed values

  • cost_metric (CostMetric) – Cost metric to use for the compression ratio (either mac or memory)

  • parameters (Union[SpatialSvdParameters, ChannelPruningParameters]) – Compression parameters specific to the given compression scheme

  • trainer – Training class containing a callable, train_model, which takes the model, the layer being fine-tuned, and an optional train_flag parameter. Pass None if per-layer fine-tuning is not required while creating the final compressed model

  • visualization_url – URL at which the visualizations will appear

Return type

Tuple[Session, CompressionStats]

Returns

A tuple of the compressed model session, and compression statistics
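Because input_shape must be given in channels_last format, a shape written in channels-first (N, C, H, W) order has to be permuted before it is passed in. The helper below is a hypothetical convenience function, not part of AIMET; it assumes 4-D image inputs and handles both a single tuple and a list of tuples, mirroring the Union[Tuple, List[Tuple]] type above.

```python
from typing import List, Tuple, Union

Shape = Tuple[int, ...]


def to_channels_last(shape: Union[Shape, List[Shape]]) -> Union[Shape, List[Shape]]:
    """Convert an (N, C, H, W) shape, or a list of them, to (N, H, W, C)."""
    if isinstance(shape, list):
        # Multi-input models: convert each input shape independently
        return [to_channels_last(s) for s in shape]
    n, c, h, w = shape
    return (n, h, w, c)


print(to_channels_last((1, 3, 224, 224)))                  # (1, 224, 224, 3)
print(to_channels_last([(1, 3, 32, 32), (1, 1, 32, 32)]))  # [(1, 32, 32, 3), (1, 32, 32, 1)]
```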


Greedy Selection Parameters

class aimet_common.defs.GreedySelectionParameters(target_comp_ratio, num_comp_ratio_candidates=10, use_monotonic_fit=False, saved_eval_scores_dict=None)

Configuration parameters for the Greedy compression-ratio selection algorithm

Variables
  • target_comp_ratio – Target compression ratio, expressed as a value between 0 and 1. Compression ratio is the ratio of the cost of the compressed model to the cost of the original model.

  • num_comp_ratio_candidates – Number of comp-ratio candidates to analyze per layer. More candidates allow a more granular distribution of compression, at the cost of increased run-time during analysis. Default value=10. Value should be greater than 1.

  • use_monotonic_fit – If True, eval scores in the eval dictionary are fitted to a monotonically increasing function. This is useful if you see the eval dict scores for some layers are not monotonically increasing. By default, this option is set to False.

  • saved_eval_scores_dict – Path to the eval_scores dictionary pickle file that was saved in a previous run. This is useful to speed up experiments, for example when trying different target compression ratios. AIMET automatically saves the eval_scores dictionary pickle file in a ./data directory relative to the current path. The num_comp_ratio_candidates parameter is ignored when this option is used.
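The parameters above can be made concrete with a small, self-contained sketch of greedy comp-ratio selection. This is not AIMET's implementation. It assumes that the candidates are evenly spaced ratios (0.1 to 0.9 for num_comp_ratio_candidates=10, which would also explain why the value must exceed 1), that an eval-score dictionary of the form {layer: {comp_ratio: score}} has already been filled in by evaluating the model with one layer compressed at a time, that scores lie in [0, 1], and that per-layer ratios are chosen by binary-searching an accuracy threshold until the whole-model cost ratio meets target_comp_ratio. All function names are hypothetical.

```python
from decimal import Decimal
from typing import Dict, List

# Hypothetical type: {layer_name: {comp_ratio: eval_score}}
EvalScores = Dict[str, Dict[Decimal, float]]


def comp_ratio_candidates(num_candidates: int) -> List[Decimal]:
    """Evenly spaced ratios in (0, 1); 1.0 (no compression) is excluded (assumption)."""
    return [Decimal(i) / num_candidates for i in range(1, num_candidates)]


def pick_ratios(eval_scores: EvalScores, threshold: float) -> Dict[str, Decimal]:
    """For each layer, pick the most aggressive ratio whose score still meets the threshold."""
    chosen = {}
    for layer, scores in eval_scores.items():
        passing = [ratio for ratio, score in scores.items() if score >= threshold]
        chosen[layer] = min(passing) if passing else Decimal(1)  # 1 = leave layer as-is
    return chosen


def model_comp_ratio(chosen, layer_costs, fixed_cost) -> Decimal:
    """Cost of the compressed model divided by the cost of the original model."""
    original = sum(layer_costs.values()) + fixed_cost
    compressed = sum(layer_costs[name] * chosen[name] for name in layer_costs) + fixed_cost
    return compressed / original


def greedy_select(eval_scores, layer_costs, fixed_cost, target, steps=30):
    """Binary-search the highest accuracy threshold whose ratios still meet the target."""
    low, high = 0.0, 1.0
    best = {layer: Decimal(1) for layer in eval_scores}
    for _ in range(steps):
        mid = (low + high) / 2
        chosen = pick_ratios(eval_scores, mid)
        if model_comp_ratio(chosen, layer_costs, fixed_cost) <= target:
            best, low = chosen, mid  # target met: try to keep more accuracy
        else:
            high = mid               # not compressed enough: relax the threshold
    return best


scores = {
    "conv_a": {Decimal("0.25"): 0.50, Decimal("0.5"): 0.80, Decimal("0.75"): 0.90},
    "conv_b": {Decimal("0.25"): 0.70, Decimal("0.5"): 0.85, Decimal("0.75"): 0.95},
}
print(greedy_select(scores, {"conv_a": 100, "conv_b": 100}, 0, Decimal("0.5")))
# {'conv_a': Decimal('0.5'), 'conv_b': Decimal('0.5')}
```

Searching over a single accuracy threshold, rather than over every combination of per-layer ratios, keeps the cost of the selection step linear in the number of layers once the eval-score dictionary exists, which is also why reusing a saved dictionary speeds up later runs.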


Spatial SVD Configuration

class aimet_tensorflow.defs.SpatialSvdParameters(input_op_names, output_op_names, mode, params, multiplicity=1)

Configuration parameters for spatial SVD compression

Parameters
  • input_op_names (List[str]) – List of input op names to the model

  • output_op_names (List[str]) – List of output op names of the model

  • mode (Mode) – Either auto mode or manual mode

  • params (Union[ManualModeParams, AutoModeParams]) – Parameters for the mode selected

  • multiplicity – The multiplicity to which ranks/input channels will get rounded. Default: 1
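To illustrate how a per-layer comp-ratio and multiplicity translate into a rank, the sketch below estimates the rank for one convolution. It assumes the usual spatial SVD factorization (a k_h x k_w convolution split into a k_h x 1 convolution with r output channels followed by a 1 x k_w convolution), stride 1 with 'same' padding so that the MAC and memory ratios coincide, and rounding up to the multiplicity. The function name is hypothetical and this is not AIMET's code.

```python
import math


def spatial_svd_rank(c_in, c_out, k_h, k_w, comp_ratio, multiplicity=1):
    """Return (rank, achieved_ratio) for one k_h x k_w conv split by spatial SVD."""
    original = c_in * c_out * k_h * k_w      # weights (and MACs per output pixel) before
    per_rank = c_in * k_h + c_out * k_w      # weights added per unit of rank after the split
    max_rank = min(c_in * k_h, c_out * k_w)  # rank of the reshaped (c_in*k_h) x (c_out*k_w) matrix
    ideal = comp_ratio * original / per_rank               # real-valued rank that hits the ratio
    rank = math.ceil(ideal / multiplicity) * multiplicity  # round up to the multiplicity
    rank = min(max(rank, multiplicity), max_rank)          # stay within [multiplicity, max_rank]
    return rank, rank * per_rank / original


print(spatial_svd_rank(64, 64, 3, 3, 0.4))                  # (39, 0.40625)
print(spatial_svd_rank(64, 64, 3, 3, 0.4, multiplicity=8))  # rank 40, achieved ratio ~0.417
```

Rounding to a multiplicity such as 8 trades a slightly higher achieved ratio for channel counts that tend to map better onto hardware; rounding up rather than down is an assumption here.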

class AutoModeParams(greedy_select_params, modules_to_ignore=None)

Configuration parameters for auto-mode compression

Parameters
  • greedy_select_params (GreedySelectionParameters) – Params for greedy comp-ratio selection algorithm

  • modules_to_ignore (Optional[List[Operation]]) – List of modules to ignore (None indicates nothing to ignore)

class ManualModeParams(list_of_module_comp_ratio_pairs)

Configuration parameters for manual-mode spatial SVD compression

Parameters
  • list_of_module_comp_ratio_pairs (List[ModuleCompRatioPair]) – List of (module, comp-ratio) pairs

class Mode

Mode enumeration

auto = 2

Auto mode

manual = 1

Manual mode


Channel Pruning Configuration

class aimet_tensorflow.defs.ChannelPruningParameters(input_op_names, output_op_names, data_set, batch_size, num_reconstruction_samples, allow_custom_downsample_ops, mode, params, multiplicity=1)

Configuration parameters for channel pruning compression

Parameters
  • input_op_names (List[str]) – List of input op names to the model

  • output_op_names (List[str]) – List of output op names of the model

  • data_set (DatasetV1) – Dataset from which the reconstruction samples are drawn

  • batch_size (int) – Batch size of data_set

  • num_reconstruction_samples (int) – Number of samples to use for reconstruction

  • allow_custom_downsample_ops (bool) – If set to True, DownSampleLayer and UpSampleLayer will be added as required

  • mode (Mode) – Indicates whether the mode is manual or auto

  • params (Union[ManualModeParams, AutoModeParams]) – ManualModeParams or AutoModeParams, depending on the value of mode

  • multiplicity – The multiplicity to which ranks/input channels will get rounded. Default: 1
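Channel pruning removes input channels from a layer and then refits the remaining weights so that the layer's outputs on sampled data stay close to the original; this is where data_set and num_reconstruction_samples come in. The sketch below shows that idea for a fully connected (or 1x1 convolution) layer using NumPy least squares. It is a simplified illustration, not AIMET's implementation: channels are selected here by the L2 norm of their weight rows, a stand-in heuristic, and the kept channel count is assumed to be rounded up to multiplicity. The function names are hypothetical.

```python
import math

import numpy as np


def channels_to_keep(c_in, comp_ratio, multiplicity=1):
    """Number of input channels to keep, rounded up to the multiplicity (assumption)."""
    keep = math.ceil(comp_ratio * c_in / multiplicity) * multiplicity
    return min(max(keep, multiplicity), c_in)


def prune_and_reconstruct(x, w, num_keep):
    """Prune input channels of y = x @ w and refit the kept weights by least squares.

    x: (num_samples, c_in) sampled layer inputs
    w: (c_in, c_out) original weights
    """
    y = x @ w                                           # original outputs to reproduce
    norms = np.linalg.norm(w, axis=1)                   # stand-in channel-importance score
    keep = np.sort(np.argsort(norms)[::-1][:num_keep])  # indices of the kept channels
    w_new, *_ = np.linalg.lstsq(x[:, keep], y, rcond=None)
    return keep, w_new


rng = np.random.default_rng(0)
x = rng.normal(size=(50, 8))  # 50 reconstruction samples, 8 input channels
w = rng.normal(size=(8, 4))
w[6:] = 0.0                   # channels 6 and 7 contribute nothing
keep, w_new = prune_and_reconstruct(x, w, channels_to_keep(8, 0.75))
print(keep)                                    # [0 1 2 3 4 5]
print(np.allclose(x[:, keep] @ w_new, x @ w))  # True
```

In a real network the pruned channels do carry signal, so the least-squares refit only approximates the original outputs; more reconstruction samples generally give a better-conditioned fit.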

class AutoModeParams(greedy_select_params, modules_to_ignore=None)

Configuration parameters for auto-mode compression

Parameters
  • greedy_select_params (GreedySelectionParameters) – Params for greedy comp-ratio selection algorithm

  • modules_to_ignore (Optional[List[Operation]]) – List of modules to ignore (None indicates nothing to ignore)

class ManualModeParams(list_of_module_comp_ratio_pairs)

Configuration parameters for manual-mode channel pruning compression

Parameters
  • list_of_module_comp_ratio_pairs (List[ModuleCompRatioPair]) – List of (module, comp-ratio) pairs

class Mode

Mode enumeration

auto = 2

Auto mode: AIMET computes the optimal comp-ratio per layer

manual = 1

Manual mode: User specifies comp-ratio per layer


Configuration Definitions

class aimet_common.defs.CostMetric

Enumeration of metrics to measure cost of a model/layer

mac = 1

MAC: Cost modeled for compute requirements

memory = 2

Memory: Cost modeled for space requirements
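The two cost metrics can be illustrated for individual layers. This sketch assumes bias terms are ignored and that each weight contributes one multiply-accumulate per output position; it is an approximation for illustration, not AIMET's cost calculator, and the function names are hypothetical.

```python
def conv_cost(c_in, c_out, k_h, k_w, out_h, out_w):
    """Return (mac, memory) for a Conv2D layer; bias is ignored."""
    memory = c_in * c_out * k_h * k_w  # number of weights
    mac = memory * out_h * out_w       # one MAC per weight per output position
    return mac, memory


def dense_cost(n_in, n_out):
    """Return (mac, memory) for a fully connected layer with batch size 1."""
    weights = n_in * n_out
    return weights, weights


# VGG16 block1_conv1: 3 -> 64 channels, 3x3 kernel, 224x224 output
print(conv_cost(3, 64, 3, 3, 224, 224))  # (86704128, 1728)
# VGG16 predictions layer: 4096 -> 1000
print(dense_cost(4096, 1000))            # (4096000, 4096000)
```

The contrast explains why the choice of cost_metric matters: early convolutions are cheap in memory but expensive in MACs, while large fully connected layers dominate memory, so the two metrics can steer compression toward different layers.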


class aimet_common.defs.CompressionScheme

Enumeration of compression schemes supported in AIMET

channel_pruning = 3

Channel Pruning

spatial_svd = 2

Spatial SVD

weight_svd = 1

Weight SVD


class aimet_tensorflow.defs.ModuleCompRatioPair(module, comp_ratio)

Pair of tf.Operation and a compression-ratio

Variables
  • module – Module of type tf.Operation

  • comp_ratio – Compression ratio. Compression ratio is the ratio of cost of compressed model to cost of the original model.


Code Examples

Required imports

from decimal import Decimal

import numpy as np
import tensorflow as tf
from tensorflow.python.keras.applications.vgg16 import VGG16

# Compression-related imports
from aimet_common.defs import GreedySelectionParameters
from aimet_common.defs import CostMetric, CompressionScheme
from aimet_tensorflow.defs import SpatialSvdParameters, ChannelPruningParameters, ModuleCompRatioPair
from aimet_tensorflow.compress import ModelCompressor

Evaluation function

def evaluate_model(sess: tf.compat.v1.Session, eval_iterations: int, use_cuda: bool) -> float:
    """
    This is intended to be the user-defined model evaluation function.
    AIMET requires the above signature. So if the user's eval function does not
    match this signature, please create a simple wrapper.

    Note: Honoring the number of iterations is not absolutely necessary.
    However if all evaluations run over an entire epoch of validation data,
    the runtime for AIMET compression will obviously be higher.

    :param sess: Tensorflow session
    :param eval_iterations: Number of iterations to use for evaluation.
            None for entire epoch.
    :param use_cuda: If true, evaluate using gpu acceleration
    :return: single float number (accuracy) representing model's performance
    """

    # Evaluate model should run data through the model and return an accuracy score.
    # If the model does not have nodes to measure accuracy, they will need to be added to the graph.
    return .5
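The docstring above recommends a simple wrapper when your evaluation function does not match the required signature. A minimal sketch of such a wrapper follows; user_eval and its (session, num_batches) signature are hypothetical stand-ins for your own code, and mapping None to a full pass over the data follows the docstring's "None for entire epoch" note.

```python
from typing import Any, Callable, Optional


def make_eval_callback(user_eval: Callable[[Any, int], float],
                       full_epoch_batches: int) -> Callable[[Any, Optional[int], bool], float]:
    """Adapt a hypothetical user_eval(sess, num_batches) to evaluate(sess, iterations, use_cuda)."""

    def eval_callback(sess: Any, eval_iterations: Optional[int], use_cuda: bool) -> float:
        # None means "entire epoch", so fall back to the full number of batches
        num_batches = full_epoch_batches if eval_iterations is None else eval_iterations
        # use_cuda is accepted for signature compatibility but not used here;
        # device placement is assumed to be handled by the session configuration
        return float(user_eval(sess, num_batches))

    return eval_callback


def user_eval(sess, num_batches):
    # Hypothetical stand-in: a real function would run num_batches through the graph
    return min(1.0, num_batches / 100)


wrapped_eval = make_eval_callback(user_eval, full_epoch_batches=500)
print(wrapped_eval(None, 10, False))   # 0.1
print(wrapped_eval(None, None, True))  # 1.0
```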

Compressing using Spatial SVD in auto mode with multiplicity = 8 for rank rounding

def spatial_svd_auto_mode():

    sess = tf.compat.v1.Session()
    # Construct graph
    with sess.graph.as_default():
        _ = VGG16(weights=None, input_shape=(224, 224, 3))
        init = tf.compat.v1.global_variables_initializer()
    sess.run(init)

    # ignore first Conv2D op
    conv2d = sess.graph.get_operation_by_name('block1_conv1/Conv2D')
    modules_to_ignore = [conv2d]

    greedy_params = GreedySelectionParameters(target_comp_ratio=Decimal('0.8'),
                                              num_comp_ratio_candidates=10,
                                              use_monotonic_fit=True,
                                              saved_eval_scores_dict=None)

    auto_params = SpatialSvdParameters.AutoModeParams(greedy_select_params=greedy_params,
                                                      modules_to_ignore=modules_to_ignore)

    params = SpatialSvdParameters(input_op_names=['input_1'], output_op_names=['predictions/Softmax'],
                                  mode=SpatialSvdParameters.Mode.auto, params=auto_params, multiplicity=8)
    input_shape = (1, 224, 224, 3)  # channels_last (NHWC), as required by compress_model

    # Single call to compress the model
    compr_model_sess, stats = ModelCompressor.compress_model(sess=sess,
                                                             working_dir=str('./'),
                                                             eval_callback=evaluate_model,
                                                             eval_iterations=10,
                                                             input_shape=input_shape,
                                                             compress_scheme=CompressionScheme.spatial_svd,
                                                             cost_metric=CostMetric.mac,
                                                             parameters=params,
                                                             trainer=None)

    print(stats)    # Stats object can be pretty-printed easily

Compressing using Spatial SVD in manual mode

def spatial_svd_manual_mode():

    sess = tf.compat.v1.Session()
    # Construct graph
    with sess.graph.as_default():
        _ = VGG16(weights=None, input_shape=(224, 224, 3))
        init = tf.compat.v1.global_variables_initializer()
    sess.run(init)

    # Pick two convs to compress as examples
    conv2d = sess.graph.get_operation_by_name('block1_conv1/Conv2D')
    conv2d_1 = sess.graph.get_operation_by_name('block1_conv2/Conv2D')

    # Specify the necessary parameters
    manual_params = SpatialSvdParameters.ManualModeParams([ModuleCompRatioPair(module=conv2d, comp_ratio=0.5),
                                                           ModuleCompRatioPair(module=conv2d_1, comp_ratio=0.4)])

    params = SpatialSvdParameters(input_op_names=['input_1'], output_op_names=['predictions/Softmax'],
                                  mode=SpatialSvdParameters.Mode.manual, params=manual_params)

    input_shape = (1, 224, 224, 3)  # channels_last (NHWC), as required by compress_model

    # Single call to compress the model
    compr_model_sess, stats = ModelCompressor.compress_model(sess=sess,
                                                             working_dir=str('./'),
                                                             eval_callback=evaluate_model,
                                                             eval_iterations=10,
                                                             input_shape=input_shape,
                                                             compress_scheme=CompressionScheme.spatial_svd,
                                                             cost_metric=CostMetric.mac,
                                                             parameters=params,
                                                             trainer=None)

    print(stats)    # Stats object can be pretty-printed easily

Compressing using Channel Pruning in auto mode

def channel_pruning_auto_mode():

    sess = tf.compat.v1.Session()
    # Construct graph
    with sess.graph.as_default():
        _ = VGG16(weights=None, input_shape=(224, 224, 3))
        init = tf.compat.v1.global_variables_initializer()
    sess.run(init)

    # ignore first Conv2D op
    conv2d = sess.graph.get_operation_by_name('block1_conv1/Conv2D')
    modules_to_ignore = [conv2d]

    greedy_params = GreedySelectionParameters(target_comp_ratio=Decimal('0.8'),
                                              num_comp_ratio_candidates=2,
                                              use_monotonic_fit=True,
                                              saved_eval_scores_dict=None)

    auto_params = ChannelPruningParameters.AutoModeParams(greedy_select_params=greedy_params,
                                                          modules_to_ignore=modules_to_ignore)

    # Create random dataset (float32, matching the model's input dtype)
    batch_size = 32
    input_data = np.random.rand(100, 224, 224, 3).astype(np.float32)
    dataset = tf.data.Dataset.from_tensor_slices(input_data)
    dataset = dataset.batch(batch_size=batch_size)

    params = ChannelPruningParameters(input_op_names=['input_1'],
                                      output_op_names=['predictions/Softmax'],
                                      data_set=dataset,
                                      batch_size=batch_size,
                                      num_reconstruction_samples=50,
                                      allow_custom_downsample_ops=False,
                                      mode=ChannelPruningParameters.Mode.auto,
                                      params=auto_params,
                                      multiplicity=8)

    # Single call to compress the model
    results = ModelCompressor.compress_model(sess,
                                             working_dir=None,
                                             eval_callback=evaluate_model,
                                             eval_iterations=10,
                                             input_shape=(32, 224, 224, 3),
                                             compress_scheme=CompressionScheme.channel_pruning,
                                             cost_metric=CostMetric.mac,
                                             parameters=params)

    compressed_model, stats = results
    print(compressed_model)
    print(stats)     # Stats object can be pretty-printed easily

Compressing using Channel Pruning in manual mode

def channel_pruning_manual_mode():

    sess = tf.compat.v1.Session()

    # Construct graph
    with sess.graph.as_default():
        _ = VGG16(weights=None, input_shape=(224, 224, 3))
        init = tf.compat.v1.global_variables_initializer()
    sess.run(init)

    # Create random dataset (float32, matching the model's input dtype)
    batch_size = 32
    input_data = np.random.rand(100, 224, 224, 3).astype(np.float32)
    dataset = tf.data.Dataset.from_tensor_slices(input_data)
    dataset = dataset.batch(batch_size=batch_size)

    # Pick two convs to compress as examples
    block1_conv2_op = sess.graph.get_operation_by_name('block1_conv2/Conv2D')
    block2_conv2_op = sess.graph.get_operation_by_name('block2_conv2/Conv2D')

    list_of_module_comp_ratio_pairs = [ModuleCompRatioPair(block1_conv2_op, 0.5),
                                       ModuleCompRatioPair(block2_conv2_op, 0.5)]

    manual_params = ChannelPruningParameters.ManualModeParams(
        list_of_module_comp_ratio_pairs=list_of_module_comp_ratio_pairs)

    params = ChannelPruningParameters(input_op_names=['input_1'],
                                      output_op_names=['predictions/Softmax'],
                                      data_set=dataset,
                                      batch_size=batch_size,
                                      num_reconstruction_samples=50,
                                      allow_custom_downsample_ops=False,
                                      mode=ChannelPruningParameters.Mode.manual,
                                      params=manual_params,
                                      multiplicity=8)

    # Single call to compress the model
    results = ModelCompressor.compress_model(sess,
                                             working_dir=None,
                                             eval_callback=evaluate_model,
                                             eval_iterations=10,
                                             input_shape=(32, 224, 224, 3),
                                             compress_scheme=CompressionScheme.channel_pruning,
                                             cost_metric=CostMetric.mac,
                                             parameters=params)

    compressed_model, stats = results
    print(compressed_model)
    print(stats)     # Stats object can be pretty-printed easily

Weight SVD Top-level API

class aimet_tensorflow.svd.Svd(graph, checkpoint, metric, output_file='./svd_graph', svd_type='svd', num_layers=0, layers=None, layer_ranks=None, num_ranks=20, gpu=True, debug=False, no_evaluation=False, layer_selection_threshold=0.6)

A class for performing singular value decomposition on a TensorFlow model.

The Svd class enables model compression through singular value decomposition (SVD). It analyzes convolution and fully connected layers to find the ranks that best balance compression against the accuracy of the network.
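As an illustration of the underlying technique, the sketch below factors a weight matrix with NumPy's SVD, replacing one m x n fully connected layer with an m x r layer followed by an r x n layer. This reduces both weights and MACs whenever r(m + n) < mn. It is a generic truncated-SVD sketch, not the Svd class's implementation, and the function name is hypothetical.

```python
import numpy as np


def weight_svd(w, rank):
    """Split w (m x n) into a (m x rank) and b (rank x n) so that a @ b approximates w."""
    u, s, vt = np.linalg.svd(w, full_matrices=False)
    a = u[:, :rank] * s[:rank]  # scale each kept left singular vector by its singular value
    b = vt[:rank, :]            # keep the matching right singular vectors
    return a, b


rng = np.random.default_rng(0)
w = rng.normal(size=(6, 2)) @ rng.normal(size=(2, 5))  # a 6 x 5 matrix of rank 2
a, b = weight_svd(w, rank=2)
print(np.allclose(a @ b, w))    # True: rank 2 captures w exactly
print(w.size, a.size + b.size)  # 30 22: weights before and after
```

For real layers the weight matrix is full rank, so choosing r is a trade-off between accuracy and cost; that trade-off is what rank selection (num_ranks, error_margin) explores.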

Constructor for the Svd class

Constructs an Svd object from the named arguments detailed below.

Parameters
  • graph – The file path to the meta graph.

  • checkpoint – The file path to the TensorFlow checkpoint file.

  • metric – The metric to use for determining the optimal compression: either ‘mac’, which minimizes multiply-accumulate operations, or ‘memory’, which minimizes the overall memory footprint. Defaults to ‘memory’.

  • output_file – The file path for saving the compressed TensorFlow graph. AIMET saves to the specified directory, using output_file as a filename prefix.

  • svd_type – Indicates which algorithm should be used, either ‘svd’ or ‘ssvd’. Defaults to ‘svd’.

  • num_layers – The number of layers to compress. Defaults to 0, which uses a heuristic to determine the optimal number of layers to compress.

  • layers – A list of op names to compress. All other layers will be ignored. Overrides num_layers and sets it to the length of this list.

  • layer_ranks – Required only if no_evaluation is set to True. A list of (layer name, rank) tuples giving the rank for each layer specified in the layers argument.

  • num_ranks – The number of ranks (compression points) to evaluate for compression. Defaults to 20. Value should be greater than 2.

  • gpu – Indicates whether the algorithm should run on the GPU or the CPU. Defaults to True (GPU); set to False to use the CPU.

  • debug – If True, debug messages will be printed. Defaults to False.

  • no_evaluation – If True, ranks are set manually by the user through layer_ranks instead of being selected by evaluation. Defaults to False.

  • layer_selection_threshold – Threshold (0-1) used to select the top layers in the network

Raises

ValueError: An error occurred processing one of the input parameters.


Svd.compress_net(generator, eval_names=None, run_graph=<function evaluate_graph>, eval_func=<function default_eval_func>, error_margin=2, iterations=100)

Compresses the network using SVD

Runs rank selection on the network, and compresses it using the method and parameters passed during construction of the Svd object.

Parameters
  • generator – The generator used to provide data to the network during analysis and compression

  • eval_names – The list of names to use for calculating model performance

  • run_graph – The function to use for running data through the graph and evaluating the network’s performance. This function must return only a single number representing the average performance of the model over the dataset batches. See the evaluate_graph function in the graph_eval module for the prototype

  • eval_func – The function to use for evaluating the network performance. This function should always return a single number that can be used for comparing the performance of different graphs. (The default is accuracy)

  • error_margin – The acceptable degradation in network accuracy from the original. 1 for 1% drop, etc. Defaults to 2%.

  • iterations – The number of iterations (data batches) to run through the network for analysis

Returns

An object containing compression statistics

Raises
  • ValueError: An invalid parameter was passed

  • RuntimeError: An error occurred analyzing or compressing the network. The exception carries the underlying error and other diagnostic information.
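One way to picture error_margin is as a filter over candidate compression points: keep the most compressed candidate whose accuracy is still within the margin of the original network. The sketch below assumes error_margin is measured in absolute percentage points of accuracy and that candidates are supplied as (rank, accuracy) pairs; both are assumptions for illustration, and this is not the Svd class's rank-selection code.

```python
def select_rank(baseline_accuracy, candidates, error_margin=2.0):
    """Return the smallest rank whose accuracy drop stays within error_margin (in %).

    baseline_accuracy: accuracy of the uncompressed network, in [0, 1]
    candidates: iterable of (rank, accuracy) pairs
    """
    allowed = baseline_accuracy - error_margin / 100.0
    within = [rank for rank, acc in candidates if acc >= allowed]
    return min(within) if within else None  # None: no candidate is acceptable


print(select_rank(0.92, [(4, 0.85), (8, 0.905), (16, 0.915)]))  # 8
```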


Code Examples for Weight SVD

Required imports

import os
from aimet_tensorflow import svd as s
from aimet_tensorflow.common import tfrecord_generator as tf_gen
from aimet_tensorflow.common.tfrecord_generator import MnistParser

Compressing using Weight SVD in auto mode

def weight_svd_auto_mode():

    # Allocate the generator you wish to use to provide the network with data
    generator = tf_gen.TfRecordGenerator(tfrecords=[os.path.join('data', 'mnist', 'validation.tfrecords')],
                                         parser=MnistParser())

    # Allocate the SVD instance and compress the network
    svd = s.Svd(graph=os.path.join('models', 'mnist_save.meta'), checkpoint=os.path.join('models', 'mnist_save'),
                output_file=os.path.join('svd', 'svd_graph'), layers=[], num_ranks=20,
                layer_selection_threshold=0.95, metric=s.CostMetric.memory)

    stats = svd.compress_net(generator=generator, iterations=10)

    stats.pretty_print() # Print the stats for Weight SVD compression

Compressing using Weight SVD in manual mode

def weight_svd_manual_mode():

    # Allocate the generator you wish to use to provide the network with data
    generator = tf_gen.TfRecordGenerator(tfrecords=[os.path.join('data', 'mnist', 'validation.tfrecords')],
                                         parser=MnistParser())

    # Only compress Conv2D_1 and MatMul_1, with ranks 31 and 9 respectively
    # no_evaluation must be True in manual mode

    layers = ['Conv2D_1', 'MatMul_1']
    layer_ranks = [('Conv2D_1', 31), ('MatMul_1', 9)]

    svd = s.Svd(graph=os.path.join('models', 'mnist_save.meta'), checkpoint=os.path.join('models', 'mnist_save'),
                output_file=os.path.join('svd', 'svd_graph'), layers=layers, layer_ranks=layer_ranks, num_ranks=20,
                no_evaluation=True, metric=s.CostMetric.memory)

    stats = svd.compress_net(generator=generator, iterations=10)

    stats.pretty_print() # Print the stats for Weight SVD compression