AIMET ONNX Quantization SIM API

Top-level API

class aimet_onnx.quantsim.QuantizationSimModel(model, dummy_input=None, quant_scheme=QuantScheme.post_training_tf_enhanced, rounding_mode='nearest', default_param_bw=8, default_activation_bw=8, use_symmetric_encodings=False, use_cuda=True, device=0, config_file=None, default_data_type=QuantizationDataType.int, simplify_model=True, user_onnx_libs=None, path=None)

Creates a QuantizationSimModel by adding quantization simulation ops to a given model.

Constructor

Parameters:
  • model (ModelProto) – ONNX model or path to model

  • dummy_input (Optional[Dict[str, ndarray]]) – Dummy input to the model. If None, will attempt to auto-generate a dummy input

  • quant_scheme (QuantScheme) – Quantization scheme (e.g. QuantScheme.post_training_tf)

  • rounding_mode (str) – Rounding mode (e.g. 'nearest')

  • default_param_bw (int) – Quantization bitwidth for parameters

  • default_activation_bw (int) – Quantization bitwidth for activations

  • use_symmetric_encodings (bool) – True if symmetric encoding is used. False otherwise.

  • use_cuda (bool) – True to run the quantization ops with CUDA. False otherwise.

  • config_file (Optional[str]) – Path to the configuration file for model quantizers

  • default_data_type (QuantizationDataType) – Default data type to use for quantizing all layer inputs, outputs and parameters. Possible options are QuantizationDataType.int and QuantizationDataType.float. Note that default_data_type=QuantizationDataType.float is only supported with default_activation_bw=16 and default_param_bw=16.

  • simplify_model (bool) – If True (the default), simplify the model with the ONNX simplifier

  • user_onnx_libs (Optional[List[str]]) – List of paths to all compiled ONNX custom ops libraries

  • path (Optional[str]) – Directory to save the artifacts.
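To make default_param_bw, default_activation_bw, use_symmetric_encodings and rounding_mode='nearest' more concrete, here is a minimal NumPy sketch of standard uniform quantize-dequantize ("fake quantization") for a given min/max range. It illustrates the general technique only and is not AIMET's internal implementation; compute_encoding and quantize_dequantize are hypothetical names, and the offset convention is an assumption.

```python
import numpy as np


def compute_encoding(x_min, x_max, bitwidth=8, symmetric=False):
    """Return (scale, offset) of a uniform integer grid covering [x_min, x_max].

    Hypothetical helper implementing a textbook min/max encoding, not AIMET's code.
    """
    # Widen the range to include 0 so that zero is exactly representable
    x_min, x_max = min(x_min, 0.0), max(x_max, 0.0)
    num_steps = 2 ** bitwidth - 1                  # 255 steps for 8 bits
    if symmetric:
        # Symmetric: grid centred on zero and sized by the larger magnitude
        abs_max = max(abs(x_min), abs(x_max))
        scale = max(abs_max / (num_steps // 2), 1e-12)
        offset = -((num_steps + 1) // 2)           # -128 for 8 bits
    else:
        # Asymmetric: grid spans [x_min, x_max]
        scale = max((x_max - x_min) / num_steps, 1e-12)
        offset = int(np.round(x_min / scale))
    return scale, offset


def quantize_dequantize(x, scale, offset, bitwidth=8):
    """Simulate quantization: round to the nearest grid point, clamp, map back to float."""
    q = np.round(x / scale) - offset               # nearest rounding (NumPy rounds ties to even)
    q = np.clip(q, 0, 2 ** bitwidth - 1)           # stay inside the unsigned integer grid
    return (q + offset) * scale


x = np.array([-1.0, -0.25, 0.0, 0.5, 2.0])
scale, offset = compute_encoding(float(x.min()), float(x.max()), bitwidth=8)
x_sim = quantize_dequantize(x, scale, offset, bitwidth=8)
```

With nearest rounding, every in-range value lands within half a quantization step (scale / 2) of its original value; a symmetric encoding trades some of the asymmetric range for a zero point fixed at the centre of the grid.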


Note about quantization schemes: because ONNX Runtime is used only for optimized inference, AIMET for ONNX supports only the post-training quantization schemes, TF and TF-enhanced (QuantScheme.post_training_tf and QuantScheme.post_training_tf_enhanced), for computing the encodings.
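With the TF scheme, an encoding is derived from the minimum and maximum values observed while calibration data runs through the model, so the range can only grow as more batches are seen. The sketch below illustrates that accumulation, assuming a plain running min/max; TF-enhanced instead searches for a range that reduces quantization noise, which is not shown. MinMaxObserver is a hypothetical name, not an AIMET class.

```python
import numpy as np


class MinMaxObserver:
    """Hypothetical calibration observer tracking a tensor's running min/max (TF-style range)."""

    def __init__(self):
        self.min = np.inf
        self.max = -np.inf

    def update(self, batch):
        # Widen the tracked range so it covers every value seen so far
        self.min = min(self.min, float(np.min(batch)))
        self.max = max(self.max, float(np.max(batch)))

    def encoding(self, bitwidth=8):
        """Return (x_min, x_max, scale) of an asymmetric grid over the observed range."""
        x_min, x_max = min(self.min, 0.0), max(self.max, 0.0)   # keep 0 representable
        scale = max((x_max - x_min) / (2 ** bitwidth - 1), 1e-12)
        return x_min, x_max, scale


# 16 batches of 64 random samples stand in for ~1000 calibration samples
rng = np.random.default_rng(0)
observer = MinMaxObserver()
for _ in range(16):
    observer.update(rng.standard_normal((64, 10)))
x_min, x_max, scale = observer.encoding(bitwidth=8)
```

Because a single outlier batch widens the range for good, this is why the calibration data should be representative of what the model sees at inference time.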

Use the following API to compute encodings for the model:

QuantizationSimModel.compute_encodings(forward_pass_callback, forward_pass_callback_args)

Compute and return the encodings of each tensor quantizer

Parameters:
  • forward_pass_callback – A callback function that simply runs forward passes on the model. This callback function should use representative data for the forward pass, so the calculated encodings work for all data samples. This callback internally chooses the number of data samples it wants to use for calculating encodings.

  • forward_pass_callback_args – These argument(s) are passed to forward_pass_callback as-is. It is up to the user to determine the type of this parameter: for example, an integer giving the number of data samples to use, a tuple of parameters, or an object representing something more complex. If set to None, forward_pass_callback will be invoked with no parameters.
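A minimal sketch of a callback that uses forward_pass_callback_args as the number of batches to run. It assumes, as the session parameter of the example callback further down suggests, that the callback receives the simulation's ONNX Runtime session together with forward_pass_callback_args. FakeSession is a hypothetical stand-in so the sketch runs without a model; with AIMET the call would be sim.compute_encodings(forward_pass_callback, 4).

```python
import numpy as np


def forward_pass_callback(session, num_batches):
    """Run `num_batches` batches through `session` (num_batches = forward_pass_callback_args).

    Random inputs are a placeholder for representative calibration data.
    """
    input_name = session.get_inputs()[0].name
    rng = np.random.default_rng(0)
    for _ in range(num_batches):
        batch = rng.standard_normal((1, 3, 224, 224)).astype(np.float32)
        session.run(None, {input_name: batch})


class _FakeInput:
    """Mimics the NodeArg objects returned by InferenceSession.get_inputs()."""
    name = "input"


class FakeSession:
    """Hypothetical stand-in for onnxruntime.InferenceSession that counts run() calls."""

    def __init__(self):
        self.calls = 0

    def get_inputs(self):
        return [_FakeInput()]

    def run(self, output_names, input_feed):
        self.calls += 1
        return [input_feed["input"].mean(axis=(1, 2, 3))]


session = FakeSession()
forward_pass_callback(session, 4)   # with AIMET: sim.compute_encodings(forward_pass_callback, 4)
```

Passing the batch count through forward_pass_callback_args keeps the callback reusable: the same function can run a short calibration during experiments and a longer one for the final encodings.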


Use the following API to export the model to the target:

QuantizationSimModel.export(path, filename_prefix)

Compute encodings and export to files

Parameters:
  • path (str) – Directory in which to save the encoding files

  • filename_prefix (str) – Filename prefix for the saved encoding files
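A small wrapper around export, assuming you want the output directory created up front in case export does not create it itself. export_sim is a hypothetical helper; only QuantizationSimModel.export(path, filename_prefix) comes from the API above.

```python
import os


def export_sim(sim, out_dir, filename_prefix):
    """Create out_dir if necessary, then call sim.export() into it (hypothetical helper)."""
    os.makedirs(out_dir, exist_ok=True)   # no-op if the directory already exists
    sim.export(path=out_dir, filename_prefix=filename_prefix)
    return out_dir


# Usage after sim.compute_encodings(...):
#     export_sim(sim, './output', 'quantized_model')
```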


Code Examples

Required imports

from aimet_onnx.quantsim import QuantizationSimModel
from aimet_common.defs import QuantScheme
import numpy as np

The user should write this function to pass calibration data through the model

def pass_calibration_data(session, args=None):
    """
    The User of the QuantizationSimModel API is expected to write this function based on their data set.
    This is not a working function and is provided only as a guideline.

    :param session: ONNX Runtime session of the QuantizationSimModel (sim.session)
    :param args: Value passed as forward_pass_callback_args to compute_encodings (unused here)
    :return: None
    """

    # User action required
    # Replace the following line with your own dataset's data loader
    # (for example, an ImageNet validation data loader) yielding (input, label) numpy batches.
    data_loader = None  # Your Dataset's data loader

    # User action required
    # For computing the activation encodings, around 1000 unlabelled data samples are required.
    # Edit the following 2 lines based on your dataloader's batch size.
    # batch_size * max_batch_counter should be 1024
    batch_size = 64
    max_batch_counter = 16

    # Name of the model's (single) input tensor in the session
    input_name = session.get_inputs()[0].name

    current_batch_counter = 0
    for input_data, _ in data_loader:
        # session.run expects a dict mapping input names to numpy arrays
        session.run(None, {input_name: input_data})

        current_batch_counter += 1
        if current_batch_counter == max_batch_counter:
            break
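For a quick smoke test before real data is wired in, data_loader can be any iterable of (input_batch, label_batch) pairs of NumPy arrays. The generator below is a hypothetical placeholder producing random float32 NCHW batches. The batch shape must be one the model's input accepts (a model with a fixed batch size of 1, like the 1x3x224x224 dummy input in the next example, would need batch_size=1), and random data does not yield meaningful encodings, so use representative samples for real calibration.

```python
import numpy as np


def random_calibration_loader(num_samples=1024, batch_size=64, shape=(3, 224, 224), seed=0):
    """Yield (input_batch, label_batch) tuples in the form the calibration loop iterates over.

    Hypothetical placeholder, not part of AIMET: real calibration needs
    representative samples from your dataset.
    """
    rng = np.random.default_rng(seed)
    for start in range(0, num_samples, batch_size):
        n = min(batch_size, num_samples - start)                        # last batch may be smaller
        images = rng.standard_normal((n, *shape)).astype(np.float32)    # NCHW float32
        labels = rng.integers(0, 1000, size=n)                          # ignored during calibration
        yield images, labels
```

Using a generator means only one batch is held in memory at a time, which matters when the calibration set is large.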

Quantize the model and evaluate it

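def quantize_model(onnx_model, eval_callback):
    """
    :param onnx_model: FP32 model as an ONNX ModelProto, e.g. loaded with onnx.load('model.onnx')
    :param eval_callback: User-supplied function that evaluates a session, e.g. on a validation set
    :return: The calibrated QuantizationSimModel
    """
    input_shape = (1, 3, 224, 224)
    dummy_data = np.random.randn(*input_shape).astype(np.float32)
    # The dictionary key must match the model's input name ('input' is assumed here)
    dummy_input = {'input': dummy_data}
    sim = QuantizationSimModel(onnx_model, dummy_input, quant_scheme=QuantScheme.post_training_tf,
                               rounding_mode='nearest', default_param_bw=8, default_activation_bw=8,
                               use_symmetric_encodings=False, use_cuda=False)

    # Calibrate: run representative data through the sim to compute the encodings
    sim.compute_encodings(pass_calibration_data, None)

    # Evaluate the quant sim
    eval_callback(sim.session)
    return sim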
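The evaluation step above relies on a user-supplied routine. One possible version, measuring top-1 accuracy, is sketched here; it is hypothetical and assumes a single-input classifier whose first output holds logits of shape (batch, num_classes), and a loader yielding (images, labels) NumPy pairs. Call it as evaluate_top1(sim.session, val_loader), where val_loader is your validation loader.

```python
import numpy as np


def evaluate_top1(session, data_loader, max_batches=None):
    """Return the top-1 accuracy of `session` over `data_loader` (hypothetical helper)."""
    input_name = session.get_inputs()[0].name
    correct, total = 0, 0
    for i, (images, labels) in enumerate(data_loader):
        if max_batches is not None and i >= max_batches:
            break
        logits = session.run(None, {input_name: images})[0]   # first output: (batch, num_classes)
        correct += int(np.sum(np.argmax(logits, axis=1) == labels))
        total += len(labels)
    return correct / total if total else 0.0
```

Running the same function on the original FP32 session and on sim.session gives a direct estimate of the accuracy lost to quantization.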