AIMET PyTorch Quantization SIM API

AIMET Quantization Sim requires the PyTorch model definition to follow certain guidelines. These guidelines are described in detail here: Model Guidelines

AIMET provides a Model Preparer API that allows users to prepare a PyTorch model for AIMET Quantization features. The API and usage examples are described in detail here: Model Preparer API

AIMET also includes a Model Validator utility that allows users to check their model definition. Please see the API and usage examples for this utility here: Model Validator API

Top-level API

class aimet_torch.quantsim.QuantizationSimModel(model, dummy_input, quant_scheme=<QuantScheme.post_training_tf_enhanced: 2>, rounding_mode='nearest', default_output_bw=8, default_param_bw=8, in_place=False, config_file=None, default_data_type=<QuantizationDataType.int: 1>)

Implements a mechanism to add quantization simulation ops to a model. This allows for off-target simulation of inference accuracy. It also allows the model to be fine-tuned to counter the effects of quantization.

Constructor

Parameters
  • model (Module) – Model to add simulation ops to

  • dummy_input (Union[Tensor, Tuple]) – Dummy input to the model. Used to parse model graph. If the model has more than one input, pass a tuple. User is expected to place the tensors on the appropriate device.

  • quant_scheme (Union[str, QuantScheme]) – Quantization scheme. Supported options are the strings ‘tf_enhanced’ or ‘tf’, or the Quant Scheme Enum values QuantScheme.post_training_tf or QuantScheme.post_training_tf_enhanced

  • rounding_mode (str) – Rounding mode. Supported options are ‘nearest’ or ‘stochastic’

  • default_output_bw (int) – Default bitwidth (4-31) to use for quantizing all layer inputs and outputs

  • default_param_bw (int) – Default bitwidth (4-31) to use for quantizing all layer parameters

  • in_place (bool) – If True, then the given ‘model’ is modified in-place to add quant-sim nodes. This option is suggested only when the user wants to avoid creating a copy of the model

  • config_file (Optional[str]) – Path to Configuration file for model quantizers

  • default_data_type (QuantizationDataType) – Default data type to use for quantizing all layer inputs, outputs and parameters. Possible options are QuantizationDataType.int and QuantizationDataType.float. Note that the mode default_data_type=QuantizationDataType.float is only supported with default_output_bw=16 and default_param_bw=16
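
A minimal construction sketch follows; the model, input shape, and bit-widths are illustrative assumptions (a full end-to-end example appears under Code Examples further down).

import torch
from torchvision import models
from aimet_common.defs import QuantScheme
from aimet_torch.quantsim import QuantizationSimModel

# Illustrative model and input shape; any model that follows the Model Guidelines works
model = models.resnet18().eval()
dummy_input = torch.randn(1, 3, 224, 224)  # placed on the same device as the model (CPU here)

sim = QuantizationSimModel(model,
                           dummy_input=dummy_input,
                           quant_scheme=QuantScheme.post_training_tf_enhanced,
                           default_output_bw=8,
                           default_param_bw=8)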


The following API can be used to compute encodings for the model

QuantizationSimModel.compute_encodings(forward_pass_callback, forward_pass_callback_args)

Computes encodings for all quantization sim nodes in the model. It is also used to find initial encodings for Range Learning

Parameters
  • forward_pass_callback – A callback function that simply runs forward passes on the model. This callback function should use representative data for the forward pass, so the calculated encodings work for all data samples. This callback internally chooses the number of data samples it wants to use for calculating encodings.

  • forward_pass_callback_args – These argument(s) are passed to the forward_pass_callback as-is. Up to the user to determine the type of this parameter. E.g. could be simply an integer representing the number of data samples to use. Or could be a tuple of parameters or an object representing something more complex. If set to None, forward_pass_callback will be invoked with no parameters.

Returns

None
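
As a sketch, the callback below uses forward_pass_callback_args as a batch count; quant_sim and calib_data_loader are assumed to exist in the caller's scope and are not part of the AIMET API. compute_encodings invokes the callback as callback(model, forward_pass_callback_args).

def pass_calibration_data(model: torch.nn.Module, num_batches: int):
    """Run a limited number of representative batches through the model for calibration."""
    model.eval()
    with torch.no_grad():
        for batch_idx, (data, _) in enumerate(calib_data_loader):  # user-supplied DataLoader (assumed)
            model(data)
            if batch_idx + 1 >= num_batches:
                break

quant_sim.compute_encodings(forward_pass_callback=pass_calibration_data,
                            forward_pass_callback_args=20)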


The following APIs can be used to save and restore the quantized model

quantsim.save_checkpoint(quant_sim_model, file_path)

This API provides a way for the user to save a checkpoint of the quantized model, which can be loaded at a later point to continue fine-tuning. See also load_checkpoint()

Parameters
  • quant_sim_model (QuantizationSimModel) – QuantizationSimModel to save checkpoint for

  • file_path (str) – Path to the file where you want to save the checkpoint

Returns

None


quantsim.load_checkpoint(file_path)

Load the quantized model

Parameters

file_path (str) – Path to the file from which to load the checkpoint

Return type

QuantizationSimModel

Returns

A new instance of the QuantizationSimModel created after loading the checkpoint
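
A short usage sketch, assuming a quant_sim instance created as in the constructor section above and an arbitrary checkpoint path:

from aimet_torch import quantsim

# Save a checkpoint of the quant-sim model so fine-tuning can be resumed later
quantsim.save_checkpoint(quant_sim, file_path='./quantsim_checkpoint.pth')

# ... later, restore a new QuantizationSimModel instance from that checkpoint
quant_sim = quantsim.load_checkpoint(file_path='./quantsim_checkpoint.pth')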


The following API can be used to export the model to target

QuantizationSimModel.export(path, filename_prefix, dummy_input, onnx_export_args=<aimet_torch.onnx_utils.OnnxExportApiArgs object>)

This method exports the quant-sim model so it is ready to be run on-target.

Specifically, the following are saved

  1. The sim-model is exported to a regular PyTorch model without any simulation ops

  2. The quantization encodings are exported to a separate JSON-formatted file that can then be imported by the on-target runtime (if desired)

  3. Optionally, an equivalent model in ONNX format is exported. In addition, nodes in the ONNX model are named the same as the corresponding PyTorch module names. This helps with matching ONNX nodes to their quant encodings from #2.

Parameters
  • path (str) – Path where to store the model pth and encodings files

  • filename_prefix (str) – Prefix to use for filenames of the model pth and encodings files

  • dummy_input (Union[Tensor, Tuple]) – Dummy input to the model. Used to parse model graph. It is required for the dummy_input to be placed on CPU.

  • onnx_export_args (Optional[OnnxExportApiArgs]) – Optional export arguments with ONNX-specific overrides; if not provided, the model is exported via the torchscript graph

Returns

None


Encoding format is described in the Quantization Encoding Specification
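
For illustration, a sketch of an export call that also passes ONNX-specific overrides; the OnnxExportApiArgs fields used here (opset_version, input_names, output_names) are assumptions and may vary by AIMET version.

from aimet_torch.onnx_utils import OnnxExportApiArgs

# Export the pth model, the encodings JSON file, and an equivalent ONNX model.
# dummy_input must be placed on the CPU for export.
quant_sim.export(path='./',
                 filename_prefix='quantized_resnet18',
                 dummy_input=dummy_input.cpu(),
                 onnx_export_args=OnnxExportApiArgs(opset_version=11,
                                                    input_names=['input'],
                                                    output_names=['output']))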


Enum Definition

Quant Scheme Enum

class aimet_common.defs.QuantScheme

Enumeration of Quant schemes

post_training_tf = 1

TF scheme

post_training_tf_enhanced = 2

TF-enhanced scheme


Code Examples

Required imports


import logging
import torch
import torch.cuda
from torch.utils.data import DataLoader
from torchvision import models
from aimet_common.utils import AimetLogger
from aimet_common.defs import QuantScheme
from aimet_torch.utils import create_fake_data_loader
from aimet_torch.model_preparer import prepare_model
from aimet_torch.quantsim import QuantizationSimModel

Train function

def train(model: torch.nn.Module, data_loader: DataLoader) -> torch.Tensor:
    """
    This is intended to be the user-defined model train function.
    :param model: torch model
    :param data_loader: torch data loader
    :return: total loss
    """
    total_loss = 0
    model.train()
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    for (data, labels) in data_loader:
        optimizer.zero_grad()
        data = data.cuda()
        labels = labels.cuda()
        predicted = model(data)
        loss = criterion(predicted, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss

    return total_loss

Evaluation function

def evaluate(model: torch.nn.Module, forward_pass_callback_args):
    """
     This is intended to be the user-defined model evaluation function. AIMET requires the above signature. So if the
     user's eval function does not match this signature, please create a simple wrapper.
     Use representative dataset that covers diversity in training data to compute optimal encodings.

    :param model: Model to evaluate
    :param forward_pass_callback_args: These argument(s) are passed to the forward_pass_callback as-is. Up to
            the user to determine the type of this parameter. E.g. could be simply an integer representing the number
            of data samples to use. Or could be a tuple of parameters or an object representing something more complex.
            If set to None, forward_pass_callback will be invoked with no parameters.
    """
    dummy_input = torch.randn(1, 3, 224, 224).to(torch.device('cuda'))
    model.eval()
    with torch.no_grad():
        model(dummy_input)

Quantize and fine-tune a quantized model (QAT)

def quantsim_example():

    AimetLogger.set_level_for_all_areas(logging.INFO)
    model = models.resnet18().eval()
    model.cuda()
    input_shape = (1, 3, 224, 224)
    dummy_input = torch.randn(input_shape).cuda()

    # Prepare the model for Quantization SIM. This automates some changes required in the model definition,
    # for example creating independent modules for torch.nn.functional calls and reused modules
    prepared_model = prepare_model(model)

    # Instantiate Quantization SIM. This will insert simulation nodes in the model
    quant_sim = QuantizationSimModel(prepared_model, dummy_input=dummy_input,
                                     quant_scheme=QuantScheme.post_training_tf_enhanced,
                                     default_param_bw=8, default_output_bw=8,
                                     config_file='../../TrainingExtensions/common/src/python/aimet_common/quantsim_config/'
                                                 'default_config.json')

    # Compute encodings (min, max, delta, offset) for activations and parameters. Use a representative
    # dataset of roughly ~1000 examples
    quant_sim.compute_encodings(evaluate, forward_pass_callback_args=None)

    # QAT - Quantization Aware Training - Fine-tune the model for a few epochs to retain accuracy using the train loop
    data_loader = create_fake_data_loader(dataset_size=32, batch_size=16, image_size=input_shape[1:])
    _ = train(quant_sim.model, data_loader)

    # Export the model, which saves the PyTorch model without any simulation nodes and saves the encodings
    # file for both activations and parameters in JSON format
    quant_sim.export(path='./', filename_prefix='quantized_resnet18', dummy_input=dummy_input.cpu())