AIMET PyTorch Compression API
Introduction
- AIMET supports the following model compression techniques for PyTorch models:
  - Weight SVD
  - Spatial SVD
  - Channel Pruning
- For all of these compression techniques, there are two modes in which you can invoke the AIMET API:
  - Auto Mode: In Auto mode, AIMET determines the optimal way to compress each layer of the model given an overall target compression ratio. The Greedy Compression Ratio Selection algorithm is used to pick appropriate compression ratios for each layer.
  - Manual Mode: In Manual mode, the user passes the desired compression ratio per layer to AIMET. AIMET applies the specified compression technique to each of those layers to achieve the desired per-layer compression ratio. It is recommended that the user start with Auto mode, and then tweak per-layer compression ratios using Manual mode if desired.
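The Auto-mode idea can be illustrated with a simplified, stand-alone sketch of greedy compression-ratio selection. This is not AIMET's implementation. It assumes the following:

- A table of eval scores per layer and per candidate compression ratio has already been measured. The numbers below are made up.
- Compressing a layer at ratio r leaves r times its original cost.

Each accuracy threshold maps every layer to the smallest candidate ratio that still scores at or above that threshold. The highest threshold whose resulting model-level ratio (compressed cost divided by original cost) meets the target wins. All function names are hypothetical.

```python
from decimal import Decimal

def overall_comp_ratio(layer_costs, chosen, total_cost):
    """Compressed-model cost divided by original-model cost.

    layer_costs: {layer: original cost} of the compressible layers
    chosen: {layer: per-layer comp ratio}; 1 means "left uncompressed"
    total_cost: original cost of the whole model, including ignored layers
    """
    saved = sum(cost * (1 - chosen[name]) for name, cost in layer_costs.items())
    return (total_cost - saved) / total_cost

def pick_ratios(eval_scores, threshold):
    """Per layer, the smallest candidate ratio whose eval score meets the threshold."""
    chosen = {}
    for name, scores in eval_scores.items():
        passing = [ratio for ratio, score in scores.items() if score >= threshold]
        chosen[name] = min(passing) if passing else Decimal(1)
    return chosen

def greedy_select(eval_scores, layer_costs, total_cost, target):
    """Scan thresholds from high to low; lowering the threshold can only lower the ratio."""
    thresholds = sorted({s for scores in eval_scores.values() for s in scores.values()},
                        reverse=True)
    for threshold in thresholds:
        chosen = pick_ratios(eval_scores, threshold)
        if overall_comp_ratio(layer_costs, chosen, total_cost) <= target:
            return chosen, threshold
    return None, None  # target not reachable with these candidates

# Made-up eval scores per layer: {candidate comp ratio: accuracy}
eval_scores = {
    "conv2": {Decimal("0.25"): 0.70, Decimal("0.5"): 0.90, Decimal("0.75"): 0.97},
    "fc1": {Decimal("0.25"): 0.95, Decimal("0.5"): 0.97, Decimal("0.75"): 0.98},
}
layer_costs = {"conv2": 600, "fc1": 300}  # the remaining 100 units belong to an ignored layer
chosen, threshold = greedy_select(eval_scores, layer_costs, total_cost=1000, target=Decimal("0.6"))
print(chosen, threshold)  # conv2 -> 0.5, fc1 -> 0.25 at threshold 0.9
```

A linear scan over the observed scores keeps the sketch short. Because the model-level ratio never increases as the threshold drops, a binary search over thresholds would find the same answer.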
Greedy Selection Parameters
class aimet_common.defs.GreedySelectionParameters(target_comp_ratio, num_comp_ratio_candidates=10, use_monotonic_fit=False, saved_eval_scores_dict=None)
Configuration parameters for the Greedy compression-ratio selection algorithm
- Variables
  - target_comp_ratio – Target compression ratio, expressed as a value between 0 and 1. The compression ratio is the ratio of the cost of the compressed model to the cost of the original model.
  - num_comp_ratio_candidates – Number of compression-ratio candidates to analyze per layer. More candidates allow a more granular distribution of compression, at the cost of increased run-time during analysis. Default value: 10. The value should be greater than 1.
  - use_monotonic_fit – If True, the eval scores in the eval dictionary are fitted to a monotonically increasing function. This is useful if the eval scores for some layers are not monotonically increasing. Default: False.
  - saved_eval_scores_dict – Path to the eval_scores dictionary pickle file saved in a previous run. This is useful to speed up experiments, for example when trying different target compression ratios. AIMET automatically saves the eval_scores dictionary pickle file in a ./data directory relative to the current path. The num_comp_ratio_candidates parameter is ignored when this option is used.
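To illustrate what use_monotonic_fit guards against, here is a hedged, stand-alone sketch. It makes one layer's eval scores non-decreasing in the compression ratio using the pool-adjacent-violators algorithm (isotonic regression).

This page does not state which fitting method AIMET uses. Pool-adjacent-violators is swapped in as one standard way to obtain a monotonically increasing fit. The function name and the scores are hypothetical.

```python
from decimal import Decimal

def monotonic_fit(scores):
    """Least-squares non-decreasing fit of eval scores, ordered by comp ratio.

    scores: {comp_ratio: eval_score}. Returns a dict with the same keys.
    Implements pool-adjacent-violators (isotonic regression, equal weights).
    """
    ratios = sorted(scores)
    blocks = []  # each block is [sum_of_scores, count]
    for ratio in ratios:
        blocks.append([scores[ratio], 1])
        # Pool the last two blocks while their means are out of order.
        while len(blocks) > 1 and blocks[-2][0] / blocks[-2][1] > blocks[-1][0] / blocks[-1][1]:
            total, count = blocks.pop()
            blocks[-1][0] += total
            blocks[-1][1] += count
    fitted = []
    for total, count in blocks:
        fitted.extend([total / count] * count)
    return dict(zip(ratios, fitted))

# Made-up scores where the 0.75 candidate dips below the 0.5 candidate.
raw = {Decimal("0.25"): 0.60, Decimal("0.5"): 0.80, Decimal("0.75"): 0.78, Decimal("0.9"): 0.95}
print(monotonic_fit(raw))  # 0.5 and 0.75 are pooled to their mean, about 0.79
```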
TAR Selection Parameters
Spatial SVD Configuration
Weight SVD Configuration
Channel Pruning Configuration
Configuration Definitions
class aimet_common.defs.CostMetric
Enumeration of metrics to measure cost of a model/layer
- mac = 1 – MAC: Cost modeled for compute requirements
- memory = 2 – Memory: Cost modeled for space requirements
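To make the two CostMetric values concrete, the sketch below counts multiply-accumulates (MACs) and weights (a proxy for memory) for a dense Conv2d layer and a fully connected layer. These are the standard formulas for layers without bias. Whether AIMET also counts biases or other terms is not stated on this page. The helper names and example shapes are hypothetical.

```python
def conv2d_cost(in_ch, out_ch, k_h, k_w, out_h, out_w):
    """Return (macs, params) for a dense Conv2d layer without bias."""
    params = out_ch * in_ch * k_h * k_w
    # Each output element needs in_ch * k_h * k_w multiply-accumulates.
    macs = params * out_h * out_w
    return macs, params

def linear_cost(in_features, out_features):
    """Return (macs, params) for a fully connected layer without bias."""
    params = in_features * out_features
    return params, params  # one MAC per weight for a single input vector

print(conv2d_cost(32, 64, 5, 5, 14, 14))  # (10035200, 51200)
print(linear_cost(3136, 1024))            # (3211264, 3211264)
```

Weights are shared across output positions in a convolution, so the two metrics can rank layers very differently. A fully connected layer with many weights may dominate memory while contributing relatively few MACs.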
class aimet_common.defs.CompressionScheme
Enumeration of compression schemes supported in AIMET
- channel_pruning = 3 – Channel Pruning
- spatial_svd = 2 – Spatial SVD
- weight_svd = 1 – Weight SVD
Code Examples
Required imports
import os
from decimal import Decimal
import torch
# Compression-related imports
from aimet_common.defs import CostMetric, CompressionScheme, GreedySelectionParameters, RankSelectScheme
from aimet_torch.defs import WeightSvdParameters, SpatialSvdParameters, ChannelPruningParameters, \
    ModuleCompRatioPair
from aimet_torch.compress import ModelCompressor
from aimet_torch.examples import mnist_torch_model
Evaluation function
def evaluate_model(model: torch.nn.Module, eval_iterations: int, use_cuda: bool = False) -> float:
    """
    This is intended to be the user-defined model evaluation function.
    AIMET requires the above signature. So if the user's eval function does not
    match this signature, please create a simple wrapper.

    Note: Honoring the number of iterations is not absolutely necessary.
    However, if all evaluations run over an entire epoch of validation data,
    the runtime for AIMET compression will be higher.

    :param model: Model to evaluate
    :param eval_iterations: Number of iterations to use for evaluation.
            None for entire epoch.
    :param use_cuda: If true, evaluate using GPU acceleration
    :return: single float number (accuracy) representing model's performance
    """
    # Placeholder: replace with a real evaluation loop that returns accuracy
    return .5
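The docstring above asks for a thin wrapper when an existing evaluation function has a different signature. Here is a hedged sketch of such a wrapper. In it, my_accuracy(model, batches) is a hypothetical user function that takes an iterable of (inputs, labels) batches. The wrapper adapts it to the (model, eval_iterations, use_cuda) signature and treats eval_iterations=None as a full pass. Plain Python objects stand in for the model and data so the sketch runs without torch.

```python
from itertools import islice

def my_accuracy(model, batches):
    """Hypothetical user-defined metric: fraction of (input, label) pairs predicted correctly."""
    correct = total = 0
    for inputs, labels in batches:
        for x, y in zip(inputs, labels):
            correct += int(model(x) == y)
            total += 1
    return correct / total if total else 0.0

def make_eval_callback(batches):
    """Wrap my_accuracy into the (model, eval_iterations, use_cuda) signature."""
    def evaluate(model, eval_iterations, use_cuda=False):
        # use_cuda is ignored in this CPU-only sketch.
        limited = batches if eval_iterations is None else list(islice(batches, eval_iterations))
        return float(my_accuracy(model, limited))
    return evaluate

identity_model = lambda x: x
batches = [([1, 2], [1, 2]), ([3, 4], [3, 0])]
evaluate = make_eval_callback(batches)
print(evaluate(identity_model, None))  # 0.75 over both batches
print(evaluate(identity_model, 1))     # 1.0 over the first batch only
```

The function returned by make_eval_callback can then be passed as eval_callback. Capturing the data in a closure keeps the callback's signature fixed.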
Compressing using Spatial SVD in auto mode with multiplicity = 8 for rank rounding
def spatial_svd_auto_mode():
    # load trained MNIST model
    model = torch.load(os.path.join('../', 'data', 'mnist_trained_on_GPU.pth'))

    # Specify the necessary parameters
    greedy_params = GreedySelectionParameters(target_comp_ratio=Decimal(0.8),
                                              num_comp_ratio_candidates=10)
    auto_params = SpatialSvdParameters.AutoModeParams(greedy_params,
                                                      modules_to_ignore=[model.conv1])
    params = SpatialSvdParameters(mode=SpatialSvdParameters.Mode.auto,
                                  params=auto_params, multiplicity=8)

    # Single call to compress the model
    results = ModelCompressor.compress_model(model,
                                             eval_callback=evaluate_model,
                                             eval_iterations=1000,
                                             input_shape=(1, 1, 28, 28),
                                             compress_scheme=CompressionScheme.spatial_svd,
                                             cost_metric=CostMetric.mac,
                                             parameters=params)

    compressed_model, stats = results
    print(compressed_model)
    print(stats)  # Stats object can be pretty-printed easily
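The multiplicity=8 argument above asks for ranks that are multiples of 8. Below is a hedged sketch of one plausible rounding rule: round the rank up to the next multiple and clamp it to the layer's maximum possible rank. Whether AIMET rounds up, down, or to the nearest multiple is an assumption here. round_rank is a hypothetical helper.

```python
def round_rank(rank, multiplicity, max_rank):
    """Round `rank` up to a multiple of `multiplicity`, clamped to [multiplicity, max_rank]."""
    if multiplicity <= 1:
        return rank
    rounded = -(-rank // multiplicity) * multiplicity  # integer ceiling to a multiple
    # If the layer's full rank is not itself a multiple, the clamp can return max_rank as-is.
    return min(max(rounded, multiplicity), max_rank)

for r in (3, 13, 16, 60):
    print(r, "->", round_rank(r, multiplicity=8, max_rank=62))
```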
Compressing using Spatial SVD in manual mode
def spatial_svd_manual_mode():
    # Load a trained MNIST model
    model = torch.load(os.path.join('../', 'data', 'mnist_trained_on_GPU.pth'))

    # Specify the necessary parameters
    manual_params = SpatialSvdParameters.ManualModeParams([ModuleCompRatioPair(model.conv1, 0.5),
                                                           ModuleCompRatioPair(model.conv2, 0.4)])
    params = SpatialSvdParameters(mode=SpatialSvdParameters.Mode.manual,
                                  params=manual_params)

    # Single call to compress the model
    results = ModelCompressor.compress_model(model,
                                             eval_callback=evaluate_model,
                                             eval_iterations=1000,
                                             input_shape=(1, 1, 28, 28),
                                             compress_scheme=CompressionScheme.spatial_svd,
                                             cost_metric=CostMetric.mac,
                                             parameters=params)

    compressed_model, stats = results
    print(compressed_model)
    print(stats)  # Stats object can be pretty-printed easily
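Spatial SVD, as commonly described, splits a k_h x k_w convolution into two convolutions:

- a k_h x 1 convolution producing `rank` channels,
- followed by a 1 x k_w convolution back to the original output channels.

The hedged sketch below finds the largest rank whose weight count fits within a per-layer compression ratio such as the 0.4 given to conv2 above. It assumes stride 1 and "same" padding, so that the MAC and weight-count ratios coincide. It also assumes conv2 is a 5x5 convolution from 32 to 64 channels; neither assumption is taken from this page. The helper name is hypothetical.

```python
def spatial_svd_rank(in_ch, out_ch, k_h, k_w, comp_ratio):
    """Largest rank whose decomposed weight count stays within comp_ratio of the original.

    Original kernel: (out_ch, in_ch, k_h, k_w).
    Decomposed: (rank, in_ch, k_h, 1) followed by (out_ch, rank, 1, k_w).
    """
    original = out_ch * in_ch * k_h * k_w
    per_rank = in_ch * k_h + out_ch * k_w  # weights added per unit of rank
    max_rank = min(in_ch * k_h, out_ch * k_w)
    rank = int(comp_ratio * original) // per_rank
    return max(1, min(rank, max_rank))

rank = spatial_svd_rank(32, 64, 5, 5, 0.4)
print(rank)  # 42, i.e. 42 * 480 = 20160 of the original 51200 weights
```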
Compressing using Weight SVD in auto mode
def weight_svd_auto_mode():
    # Load trained MNIST model
    model = torch.load(os.path.join('../', 'data', 'mnist_trained_on_GPU.pth'))

    # Specify the necessary parameters
    greedy_params = GreedySelectionParameters(target_comp_ratio=Decimal(0.8),
                                              num_comp_ratio_candidates=10)
    rank_select = RankSelectScheme.greedy
    auto_params = WeightSvdParameters.AutoModeParams(rank_select_scheme=rank_select,
                                                     select_params=greedy_params,
                                                     modules_to_ignore=[model.conv1])
    params = WeightSvdParameters(mode=WeightSvdParameters.Mode.auto,
                                 params=auto_params)

    # Single call to compress the model
    results = ModelCompressor.compress_model(model,
                                             eval_callback=evaluate_model,
                                             eval_iterations=1000,
                                             input_shape=(1, 1, 28, 28),
                                             compress_scheme=CompressionScheme.weight_svd,
                                             cost_metric=CostMetric.mac,
                                             parameters=params)

    compressed_model, stats = results
    print(compressed_model)
    print(stats)  # Stats object can be pretty-printed easily
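For a fully connected layer, Weight SVD replaces the weight W (out x in) with two smaller layers of shapes (r x in) and (out x r). The weight count therefore falls from in*out to r*(in + out). The hedged sketch below finds the largest rank r that fits a given per-layer ratio. It uses a 3136 -> 1024 layer as an assumed example shape, and weight_svd_rank is a hypothetical helper. It also shows that compression only pays off while r stays below in*out / (in + out).

```python
def weight_svd_rank(in_features, out_features, comp_ratio):
    """Largest rank r with r * (in + out) <= comp_ratio * in * out, capped at the full rank."""
    original = in_features * out_features
    rank = int(comp_ratio * original) // (in_features + out_features)
    return max(1, min(rank, min(in_features, out_features)))

print(weight_svd_rank(3136, 1024, 0.5))  # 385
print(weight_svd_rank(3136, 1024, 1.0))  # 771: break-even rank, far below the full rank 1024
```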
Compressing using Weight SVD in manual mode with multiplicity = 8 for rank rounding
def weight_svd_manual_mode():
    # Load a trained MNIST model
    model = torch.load(os.path.join('../', 'data', 'mnist_trained_on_GPU.pth'))

    # Specify the necessary parameters
    manual_params = WeightSvdParameters.ManualModeParams([ModuleCompRatioPair(model.conv1, 0.5),
                                                          ModuleCompRatioPair(model.conv2, 0.4)])
    params = WeightSvdParameters(mode=WeightSvdParameters.Mode.manual,
                                 params=manual_params, multiplicity=8)

    # Single call to compress the model
    results = ModelCompressor.compress_model(model,
                                             eval_callback=evaluate_model,
                                             eval_iterations=1000,
                                             input_shape=(1, 1, 28, 28),
                                             compress_scheme=CompressionScheme.weight_svd,
                                             cost_metric=CostMetric.mac,
                                             parameters=params)

    compressed_model, stats = results
    print(compressed_model)
    print(stats)  # Stats object can be pretty-printed easily
Compressing using Channel Pruning in auto mode
def channel_pruning_auto_mode():
    # Load trained MNIST model
    model = torch.load(os.path.join('../', 'data', 'mnist_trained_on_GPU.pth'))

    # Specify the necessary parameters
    greedy_params = GreedySelectionParameters(target_comp_ratio=Decimal(0.8),
                                              num_comp_ratio_candidates=10)
    auto_params = ChannelPruningParameters.AutoModeParams(greedy_params,
                                                          modules_to_ignore=[model.conv1])

    data_loader = mnist_torch_model.DataLoaderMnist(cuda=True, seed=1, shuffle=True)
    params = ChannelPruningParameters(data_loader=data_loader.train_loader,
                                      num_reconstruction_samples=500,
                                      allow_custom_downsample_ops=True,
                                      mode=ChannelPruningParameters.Mode.auto,
                                      params=auto_params)

    # Single call to compress the model
    results = ModelCompressor.compress_model(model,
                                             eval_callback=evaluate_model,
                                             eval_iterations=1000,
                                             input_shape=(1, 1, 28, 28),
                                             compress_scheme=CompressionScheme.channel_pruning,
                                             cost_metric=CostMetric.mac,
                                             parameters=params)

    compressed_model, stats = results
    print(compressed_model)
    print(stats)  # Stats object can be pretty-printed easily
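Channel pruning has to decide which input channels of a layer to drop. AIMET's own channel-selection procedure, and the layer reconstruction that num_reconstruction_samples is used for, are not described on this page. The hedged sketch below therefore swaps in a simpler, well-known heuristic: keep the input channels whose weights have the largest L1 norm. Weights are nested lists indexed [out][in][k_h][k_w]. The function name and values are hypothetical.

```python
def select_input_channels(weight, keep):
    """Return sorted indices of the `keep` input channels with the largest L1 norm.

    weight: nested lists shaped [out_ch][in_ch][k_h][k_w].
    """
    in_ch = len(weight[0])
    norms = []
    for c in range(in_ch):
        # Sum |w| over every output filter's slice for input channel c.
        norm = sum(abs(v) for out_filter in weight for row in out_filter[c] for v in row)
        norms.append((norm, c))
    # Largest norm first; ties broken by the lower channel index.
    ranked = sorted(norms, key=lambda t: (-t[0], t[1]))
    return sorted(c for _, c in ranked[:keep])

# Two output filters, three input channels, 1x1 kernels (made-up values).
weight = [
    [[[0.1]], [[-2.0]], [[0.5]]],
    [[[0.2]], [[1.0]], [[-0.4]]],
]
print(select_input_channels(weight, keep=2))  # [1, 2]: channel 0 has the smallest L1 norm
```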
Compressing using Channel Pruning in manual mode
def channel_pruning_manual_mode():
    # Load a trained MNIST model
    model = torch.load(os.path.join('../', 'data', 'mnist_trained_on_GPU.pth'))

    # Specify the necessary parameters
    manual_params = ChannelPruningParameters.ManualModeParams([ModuleCompRatioPair(model.conv2, 0.4)])

    data_loader = mnist_torch_model.DataLoaderMnist(cuda=True, seed=1, shuffle=True)
    params = ChannelPruningParameters(data_loader=data_loader.train_loader,
                                      num_reconstruction_samples=500,
                                      allow_custom_downsample_ops=True,
                                      mode=ChannelPruningParameters.Mode.manual,
                                      params=manual_params)

    # Single call to compress the model
    results = ModelCompressor.compress_model(model,
                                             eval_callback=evaluate_model,
                                             eval_iterations=1000,
                                             input_shape=(1, 1, 28, 28),
                                             compress_scheme=CompressionScheme.channel_pruning,
                                             cost_metric=CostMetric.mac,
                                             parameters=params)

    compressed_model, stats = results
    print(compressed_model)
    print(stats)  # Stats object can be pretty-printed easily
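In the manual example above, conv2 is given a compression ratio of 0.4. Because channel pruning removes input channels, a layer's MACs scale with the fraction of input channels it keeps. The hedged sketch below makes three assumptions, none of which is taken from this page:

- The kept channel count is rounded down.
- conv2 is a 5x5 convolution from 32 to 64 channels.
- conv2 has a 14x14 output.

It ignores any further saving from removing the matching output channels of the layer that feeds conv2. The helper name is hypothetical.

```python
import math

def channel_pruning_plan(in_ch, out_ch, k_h, k_w, out_h, out_w, comp_ratio):
    """Return (kept_in_channels, macs_before, macs_after) when pruning input channels."""
    kept = max(1, math.floor(comp_ratio * in_ch))  # rounding down is an assumption
    macs_before = out_ch * in_ch * k_h * k_w * out_h * out_w
    macs_after = out_ch * kept * k_h * k_w * out_h * out_w
    return kept, macs_before, macs_after

print(channel_pruning_plan(32, 64, 5, 5, 14, 14, 0.4))  # (12, 10035200, 3763200)
```

Rounding down means the achieved ratio (12/32 = 0.375 here) never exceeds the requested one.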
Example Training Object
class Trainer:
    """ Example trainer class """

    def __init__(self):
        self._layer_db = []

    def train_model(self, model, layer, train_flag=True):
        """
        Trains a model
        :param model: Model to be trained
        :param layer: Layer which has to be fine-tuned
        :param train_flag: Default: True. If true, the model gets trained
        :return: None
        """
        if train_flag:
            mnist_torch_model.train(model, epochs=1, use_cuda=True, batch_size=50, batch_callback=None)
        self._layer_db.append(layer)
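For quick dry runs it can help to exercise the same flow without any fine-tuning. Below is a hedged sketch of a stand-in that exposes the same train_model(model, layer, train_flag) method as the Trainer above but only records which layers it was handed. The class name is hypothetical. This section does not specify when, or how often, the compressor invokes train_model.

```python
class RecordingTrainer:
    """Hypothetical no-op trainer exposing the same train_model() signature as Trainer."""

    def __init__(self):
        self.layers_seen = []
        self.trained_calls = 0

    def train_model(self, model, layer, train_flag=True):
        """Record the layer; count the call as 'trained' only when train_flag is True."""
        if train_flag:
            # A real trainer would fine-tune `model` here.
            self.trained_calls += 1
        self.layers_seen.append(layer)

trainer = RecordingTrainer()
trainer.train_model(model=None, layer="conv2")
trainer.train_model(model=None, layer="fc1", train_flag=False)
print(trainer.layers_seen, trainer.trained_calls)  # ['conv2', 'fc1'] 1
```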
Compressing using Spatial SVD in auto mode with layer-wise fine tuning
def spatial_svd_auto_mode_with_layerwise_finetuning():
    # load trained MNIST model
    model = torch.load(os.path.join('../', 'data', 'mnist_trained_on_GPU.pth'))

    # Specify the necessary parameters
    greedy_params = GreedySelectionParameters(target_comp_ratio=Decimal(0.8),
                                              num_comp_ratio_candidates=10)
    auto_params = SpatialSvdParameters.AutoModeParams(greedy_params,
                                                      modules_to_ignore=[model.conv1])
    params = SpatialSvdParameters(mode=SpatialSvdParameters.Mode.auto,
                                  params=auto_params)

    # Single call to compress the model
    results = ModelCompressor.compress_model(model,
                                             eval_callback=evaluate_model,
                                             eval_iterations=1000,
                                             input_shape=(1, 1, 28, 28),
                                             compress_scheme=CompressionScheme.spatial_svd,
                                             cost_metric=CostMetric.mac,
                                             parameters=params, trainer=Trainer())

    compressed_model, stats = results
    print(compressed_model)
    print(stats)  # Stats object can be pretty-printed easily