AIMET PyTorch Quantization SIM API

Top-level API

class aimet_torch.quantsim.QuantizationSimModel(model, input_shapes, quant_scheme=<QuantScheme.post_training_tf_enhanced: 2>, rounding_mode='nearest', default_output_bw=8, default_param_bw=8, in_place=False, config_file=None)

Implements a mechanism to add quantization simulation ops to a model. This allows off-target simulation of inference accuracy, and also allows the model to be fine-tuned to counter the effects of quantization.


  • model (Module) – Model to add simulation ops to

  • input_shapes (Union[Tuple, List[Tuple]]) – List of input shapes to the model

  • quant_scheme (Union[str, QuantScheme]) – Quantization scheme. Supported options are the strings ‘tf_enhanced’ or ‘tf’, or the QuantScheme enum values QuantScheme.post_training_tf or QuantScheme.post_training_tf_enhanced

  • rounding_mode (str) – Rounding mode. Supported options are ‘nearest’ or ‘stochastic’

  • default_output_bw (int) – Default bitwidth (4-31) to use for quantizing layer inputs and outputs

  • default_param_bw (int) – Default bitwidth (4-31) to use for quantizing layer parameters

  • in_place (bool) – If True, the given ‘model’ is modified in place to add quant-sim nodes. Use this option only when you want to avoid creating a copy of the model

  • config_file (Optional[str]) – Configuration file for model quantizers
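To make the idea of a quantization simulation op concrete, here is a minimal sketch in plain Python. It is not AIMET's implementation: it assumes a uniform grid over a given range, and the name fake_quantize and its parameters are hypothetical. It shows how a bitwidth and the two rounding modes (‘nearest’ and ‘stochastic’) act on a single value.

```python
import math
import random


def fake_quantize(x, enc_min, enc_max, bitwidth=8, rounding_mode="nearest", rng=random):
    """Snap x onto a uniform grid over [enc_min, enc_max] and return it as a float.

    Illustrative only; the function and its arguments are assumptions,
    not part of the AIMET API.
    """
    if rounding_mode not in ("nearest", "stochastic"):
        raise ValueError(f"unsupported rounding mode: {rounding_mode}")
    if enc_max <= enc_min:
        raise ValueError("enc_max must be greater than enc_min")

    num_steps = 2 ** bitwidth - 1             # e.g. 255 steps for 8 bits
    scale = (enc_max - enc_min) / num_steps   # width of one quantization step
    clamped = min(max(x, enc_min), enc_max)   # saturate values outside the range
    position = (clamped - enc_min) / scale    # real-valued position on the grid

    if rounding_mode == "nearest":
        step = math.floor(position + 0.5)
    else:
        lower = math.floor(position)
        # Round up with probability equal to the fractional part.
        step = lower + (1 if rng.random() < position - lower else 0)

    step = min(step, num_steps)               # guard against floating-point overshoot
    return enc_min + step * scale


# Nearest rounding picks the closest grid point to 0.3.
print(fake_quantize(0.3, 0.0, 1.0, bitwidth=4))
# Stochastic rounding picks one of the two neighbouring grid points at random.
print(fake_quantize(0.3, 0.0, 1.0, bitwidth=4, rounding_mode="stochastic"))
```

With a lower bitwidth the grid is coarser, so the difference between the input and the returned value grows; this difference is the quantization noise that the simulation exposes.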

The following API can be used to compute encodings for the model

QuantizationSimModel.compute_encodings(forward_pass_callback, forward_pass_callback_args)

Computes encodings for all quantization sim nodes in the model. It is also used to find initial encodings for Range Learning.

  • forward_pass_callback – A callback function that simply runs forward passes on the model. This callback function should use representative data for the forward pass, so the calculated encodings work for all data samples. This callback internally chooses the number of data samples it wants to use for calculating encodings.

  • forward_pass_callback_args – These argument(s) are passed to the forward_pass_callback as-is. The type of this parameter is up to the user: it could be an integer giving the number of data samples to use, a tuple of parameters, or an object representing something more complex. If set to None, forward_pass_callback will be invoked with no parameters.
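The callback contract described above can be sketched as follows. This is a hedged illustration that runs without PyTorch: it assumes the callback is invoked as callback(model, forward_pass_callback_args), and the names REPRESENTATIVE_DATA and CountingModel are hypothetical stand-ins for a calibration dataset and a torch.nn.Module.

```python
# Hypothetical stand-in for a small set of representative (unlabeled) inputs.
REPRESENTATIVE_DATA = [[0.1 * i, -0.1 * i] for i in range(10)]


class CountingModel:
    """Stand-in for a model: counts how many samples were passed through it."""

    def __init__(self):
        self.samples_seen = 0

    def __call__(self, sample):
        self.samples_seen += 1
        return sum(sample)


def forward_pass_callback(model, num_samples):
    """Run forward passes only; the outputs are discarded.

    num_samples plays the role of forward_pass_callback_args. Passing None
    here runs over all of REPRESENTATIVE_DATA, because [:None] is a full slice.
    """
    for sample in REPRESENTATIVE_DATA[:num_samples]:
        model(sample)


model = CountingModel()
forward_pass_callback(model, 5)
print(model.samples_seen)
```

With a real PyTorch model, such a callback would typically put the model in eval mode and run the forward passes without gradient tracking, since no labels or loss are needed to collect the statistics.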



The following APIs can be used to save and restore the quantized model


aimet_torch.quantsim.save_checkpoint(quant_sim_model, file_path)

This API provides a way for the user to save a checkpoint of the quantized model, which can be loaded at a later point, e.g. to continue fine-tuning. See also load_checkpoint().

  • quant_sim_model (QuantizationSimModel) – QuantizationSimModel to save checkpoint for

  • file_path (str) – Path to the file where you want to save the checkpoint




aimet_torch.quantsim.load_checkpoint(file_path)

Load the quantized model


file_path (str) – Path to the checkpoint file to load

Return type

QuantizationSimModel

Returns

A new instance of the QuantizationSimModel created after loading the checkpoint

The following API can be used to export the model to the target

QuantizationSimModel.export(path, filename_prefix, input_shape, set_onnx_layer_names=True)

This method exports the quant-sim model so it is ready to be run on-target.

Specifically, the following are saved:

  1. The sim-model is exported to a regular PyTorch model without any simulation ops

  2. The quantization encodings are exported to a separate JSON-formatted file that can then be imported by the on-target runtime (if desired)

  3. An equivalent model in ONNX format is exported. In addition, nodes in the ONNX model are named the same as the corresponding PyTorch module names. This helps with matching ONNX nodes to their quant encodings from #2.

  • path (str) – Path where the model pth and encodings files are stored

  • filename_prefix (str) – Prefix to use for filenames of the model pth and encodings files

  • input_shape (Union[Tuple, List[Tuple]]) – shape of the model input as a tuple. If the model takes more than one input, specify this as a list of shapes.

  • set_onnx_layer_names (bool) – Whether ONNX layer names should be set while exporting the model. Default is True
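As a rough aid for locating the three exported artifacts listed above, the following sketch builds the file paths one might expect from path and filename_prefix. The helper name expected_export_files is hypothetical, and the file extensions (.pth, .onnx, .encodings) are assumptions that should be checked against the AIMET version in use.

```python
import os


def expected_export_files(path, filename_prefix):
    """Return the file paths an export is assumed to produce.

    The extensions below are assumptions for illustration, not a guarantee
    of what a given AIMET release writes.
    """
    extensions = {
        "pytorch_model": ".pth",     # model without simulation ops
        "onnx_model": ".onnx",       # equivalent ONNX model
        "encodings": ".encodings",   # JSON-formatted quantization encodings
    }
    return {
        kind: os.path.join(path, filename_prefix + ext)
        for kind, ext in extensions.items()
    }


for kind, file_path in expected_export_files("./", "quantized_mnist").items():
    status = "found" if os.path.exists(file_path) else "missing"
    print(f"{kind}: {file_path} ({status})")
```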



Enum Definition

Quant Scheme Enum

class aimet_common.defs.QuantScheme

Enumeration of Quant schemes

post_training_tf = 1

TF scheme

post_training_tf_enhanced = 2

TF-enhanced scheme
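The enum entries above do not say how a scheme turns observed values into encodings. The ‘tf’ scheme is generally described as taking the quantization range from the observed minimum and maximum, while ‘tf_enhanced’ searches for a range that reduces quantization error. The sketch below assumes that description and illustrates only the simpler min/max case; the function name and the returned dictionary keys are hypothetical, and AIMET's actual computation may differ in detail.

```python
def compute_minmax_encoding(values, bitwidth=8):
    """Derive a uniform quantization grid from the observed min and max.

    Illustrative min/max scheme only; not AIMET's implementation.
    """
    # Extend the range so that 0.0 always lies inside it.
    enc_min = min(min(values), 0.0)
    enc_max = max(max(values), 0.0)
    if enc_max == enc_min:
        # All-zero input: widen the range slightly to avoid a zero scale.
        enc_max = enc_min + 1e-5

    num_steps = 2 ** bitwidth - 1
    scale = (enc_max - enc_min) / num_steps   # width of one quantization step
    offset = round(enc_min / scale)           # grid index of enc_min relative to zero
    return {
        "min": enc_min,
        "max": enc_max,
        "scale": scale,
        "offset": offset,
        "bitwidth": bitwidth,
    }


encoding = compute_minmax_encoding([-1.0, 0.5, 2.0], bitwidth=8)
print(encoding)
```

A min/max range is sensitive to outliers: a single extreme value stretches the grid and coarsens the resolution for all other values, which is the motivation usually given for range-searching schemes.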

Code Examples

Required imports

import torch
from aimet_torch.examples import mnist_torch_model
# Quantization related import
from aimet_torch.quantsim import QuantizationSimModel

Evaluation function

def evaluate_model(model: torch.nn.Module, eval_iterations: int, use_cuda: bool = False) -> float:
    """
    This is intended to be the user-defined model evaluation function.
    AIMET requires the above signature. So if the user's eval function does not
    match this signature, please create a simple wrapper.

    Note: Honoring the number of iterations is not absolutely necessary.
    However if all evaluations run over an entire epoch of validation data,
    the runtime for AIMET compression will obviously be higher.

    :param model: Model to evaluate
    :param eval_iterations: Number of iterations to use for evaluation.
            None for entire epoch.
    :param use_cuda: If true, evaluate using gpu acceleration
    :return: single float number (accuracy) representing model's performance
    """
    # Placeholder result; replace with a real evaluation loop
    return .5
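The "simple wrapper" mentioned in the evaluation function's description can be sketched as below. This is an illustration under stated assumptions: my_eval is a hypothetical user function with a different argument order, make_eval_callback is a hypothetical helper, and the data and model are plain-Python stand-ins so the sketch runs without PyTorch.

```python
def my_eval(data, model, max_batches=None):
    """Hypothetical user evaluation function with a non-matching signature."""
    batches = data if max_batches is None else data[:max_batches]
    correct = sum(1 for inputs, label in batches if model(inputs) == label)
    return correct / len(batches)


def make_eval_callback(data):
    """Adapt my_eval to the (model, eval_iterations, use_cuda) signature."""

    def evaluate_model(model, eval_iterations, use_cuda=False):
        # use_cuda is accepted for signature compatibility; this toy
        # evaluation runs on the CPU and ignores it.
        return my_eval(data, model, max_batches=eval_iterations)

    return evaluate_model


# Toy data: (input, expected label) pairs, and a toy "model".
toy_data = [(1, True), (-1, False), (2, False), (3, True)]
toy_model = lambda x: x > 0

evaluate_model = make_eval_callback(toy_data)
print(evaluate_model(toy_model, 2))      # accuracy over the first 2 samples
print(evaluate_model(toy_model, None))   # accuracy over all samples
```

Using a closure keeps the dataset out of the callback's argument list, so the wrapped function can be passed wherever the three-argument signature is expected.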

Quantize and fine-tune a trained model

def quantize_model(trainer_function):

    model = mnist_torch_model.Net().to(torch.device('cuda'))
    input_shape = (1, 1, 28, 28)

    sim = QuantizationSimModel(model, default_output_bw=8, default_param_bw=8, input_shapes=input_shape)

    # Quantize the untrained MNIST model
    sim.compute_encodings(forward_pass_callback=evaluate_model, forward_pass_callback_args=5)

    # Fine-tune the model's parameters using training
    trainer_function(model=sim.model, epochs=1, num_batches=100, use_cuda=True)

    # Export the model
    sim.export(path='./', filename_prefix='quantized_mnist', input_shape=input_shape)