AIMET ONNX Quantization SIM API

Top-level API


Note about quantization schemes: Because ONNX Runtime is used for optimized inference only, AIMET ONNX supports the post-training quantization schemes, i.e. TF (QuantScheme.post_training_tf) or TF-enhanced (QuantScheme.post_training_tf_enhanced), to compute the encodings.
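
To make "computing the encodings" concrete, the following is a minimal sketch of a min/max style encoding, assuming the TF scheme derives the quantization range from the observed minimum and maximum of a tensor. It is not AIMET's implementation; the function names `compute_minmax_encoding` and `quantize_dequantize` are hypothetical and exist only for this illustration.

```python
def compute_minmax_encoding(samples, bitwidth=8):
    """Return (scale, offset, enc_min, enc_max) for an asymmetric min/max encoding.

    Hypothetical helper for illustration only; not part of AIMET.
    """
    num_steps = 2 ** bitwidth - 1
    # Extend the range to include zero so that 0.0 is exactly representable.
    enc_min = min(0.0, min(samples))
    enc_max = max(0.0, max(samples))
    if enc_max == enc_min:
        enc_max = enc_min + 1e-5  # avoid a zero-width range
    scale = (enc_max - enc_min) / num_steps
    offset = round(enc_min / scale)  # integer offset of the grid, always <= 0
    return scale, offset, enc_min, enc_max


def quantize_dequantize(x, scale, offset, bitwidth=8):
    """Simulate quantization: snap x to the integer grid, clamp, map back to float."""
    num_steps = 2 ** bitwidth - 1
    q = round(x / scale) - offset
    q = max(0, min(num_steps, q))  # clamp to the representable integer range
    return (q + offset) * scale


samples = [-1.0, 0.0, 0.5, 3.0]
scale, offset, enc_min, enc_max = compute_minmax_encoding(samples)
simulated = [quantize_dequantize(x, scale, offset) for x in samples]
```

Values inside the calibrated range come back within half a quantization step of the original, while values outside it (for example 10.0 here) are clamped to the edge of the range. This is why the calibration data should be representative of the inputs seen at inference time.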

The following API can be used to compute encodings for the model:


The following API can be used to export the model to the target:


Code Examples

Required imports

from aimet_onnx.quantsim import QuantizationSimModel
from aimet_common.defs import QuantScheme
import numpy as np

The user should write a function like the following to pass calibration data through the model

def pass_calibration_data(session, args=None):
    """
    The user of the QuantizationSimModel API is expected to write this function based on their dataset.
    This is not a working function and is provided only as a guideline.

    :param session: Model's session
    :param args: Optional callback arguments (this example passes None to compute_encodings); unused here
    :return: None
    """

    # User action required
    # Replace the following line with your own dataset's data loader, for example the
    # ImageNet validation data loader. It should yield (input_data, label) batches.
    data_loader = None  # Your dataset's data loader

    # User action required
    # For computing the activation encodings, around 1000 unlabelled data samples are required.
    # Edit the following 2 lines based on your data loader's batch size.
    # batch_size * max_batch_counter should be about 1000 (here 64 * 16 = 1024).
    batch_size = 64  # informational: the batch size your data loader was built with
    max_batch_counter = 16

    # Name of the model's input tensor; 'input' matches the dummy_input key used in quantize_model()
    input_name = 'input'

    current_batch_counter = 0
    for input_data, _ in data_loader:
        # Labels are ignored. The session expects a dict that maps input names to numpy arrays.
        session.run(None, {input_name: input_data})

        current_batch_counter += 1
        if current_batch_counter == max_batch_counter:
            break
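
The guideline function above cannot run as written because the data loader is left as None. As a rough, self-contained sketch of the same batch-limiting logic, the example below swaps in two hypothetical stand-ins: `StubSession`, which only mimics a `run(output_names, input_feed)` method, and `make_data_loader`, which yields placeholder batches. Neither is part of AIMET or ONNX Runtime, and the input name 'input' is taken from the dummy input used later on this page.

```python
class StubSession:
    """Hypothetical stand-in for an inference session; it only counts what it is fed."""

    def __init__(self):
        self.num_runs = 0
        self.num_samples = 0

    def run(self, output_names, input_feed):
        self.num_runs += 1
        self.num_samples += len(input_feed['input'])
        return []


def make_data_loader(num_samples, batch_size):
    """Hypothetical loader: yield (inputs, labels) batches of placeholder samples."""
    for start in range(0, num_samples, batch_size):
        stop = min(start + batch_size, num_samples)
        inputs = [[0.0] * 4 for _ in range(start, stop)]
        labels = [0] * (stop - start)
        yield inputs, labels


def pass_calibration_data(session, args=None, batch_size=64, max_batch_counter=16):
    """Feed at most max_batch_counter batches to the session; labels are ignored."""
    data_loader = make_data_loader(num_samples=5000, batch_size=batch_size)
    for batch_index, (input_data, _) in enumerate(data_loader):
        if batch_index == max_batch_counter:
            break
        session.run(None, {'input': input_data})


session = StubSession()
pass_calibration_data(session)
print(session.num_runs, session.num_samples)  # prints: 16 1024
```

Stopping after a fixed number of batches keeps calibration cheap: the loader here holds 5000 samples, but only 16 * 64 = 1024 of them are passed through the session.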

Quantize the model and evaluate

def quantize_model():
    # Model() is a placeholder; replace it with your own ONNX model
    onnx_model = Model()
    input_shape = (1, 3, 224, 224)
    dummy_data = np.random.randn(*input_shape).astype(np.float32)
    dummy_input = {'input': dummy_data}
    sim = QuantizationSimModel(onnx_model, dummy_input, quant_scheme=QuantScheme.post_training_tf,
                               rounding_mode='nearest', default_param_bw=8, default_activation_bw=8,
                               use_symmetric_encodings=False, use_cuda=False)

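    # Compute encodings by passing the calibration data through the quantization simulation
    sim.compute_encodings(pass_calibration_data, None)

    # Evaluate the quant sim
    # forward_pass_function is a placeholder for your own evaluation function
    forward_pass_function(sim.session)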