Calibration

Calibration is the process of determining the scale and offset parameters for the quantizers added to your model graph. While quantization parameters for weights can be precomputed, activation quantization requires passing small, representative data samples through the model to gather range statistics, from which the scale and offset parameters are derived.
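
To make the idea concrete, here is a rough sketch (not AIMET's actual implementation) of how an asymmetric integer encoding could be derived from observed range statistics; the helper name and values are purely illustrative:

def derive_asymmetric_encoding(observed_min: float, observed_max: float, bitwidth: int = 8):
    """Illustrative only: derive a scale/offset pair from an observed activation range."""
    # Include zero in the range so that real 0.0 is exactly representable
    observed_min = min(observed_min, 0.0)
    observed_max = max(observed_max, 0.0)
    assert observed_max > observed_min, "degenerate range"
    num_steps = 2 ** bitwidth - 1
    scale = (observed_max - observed_min) / num_steps
    offset = round(observed_min / scale)  # offset corresponds to the (negative) zero point
    return scale, offset

scale, offset = derive_asymmetric_encoding(-0.8, 3.2, bitwidth=8)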

Workflow

In this example, we will load a pretrained MobileNetV2 model. You can use any other pretrained model instead.

QuantSim creation

Important

aimet_torch 2 is fully backward compatible with all the public APIs of aimet_torch 1.x. If you are using low-level components of QuantizationSimModel, please see Migrate to aimet_torch 2.

import torch
import torch.cuda
from tqdm import tqdm

To perform quantization simulation with aimet_torch, your model definition should adhere to specific guidelines. For example, functional operations (torch.nn.functional) used in the forward pass should be replaced with their equivalent torch.nn.Module counterparts. For more details on model definition guidelines, refer to PyTorch model guidelines.
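
For instance, a minimal sketch (not taken from the MobileNetV2 definition) of replacing a functional call with its module equivalent:

import torch.nn as nn
import torch.nn.functional as F

class BlockBefore(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 16, kernel_size=3)

    def forward(self, x):
        return F.relu(self.conv(x))    # functional op: no module for AIMET to attach a quantizer to

class BlockAfter(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 16, kernel_size=3)
        self.relu = nn.ReLU()          # equivalent torch.nn.Module

    def forward(self, x):
        return self.relu(self.conv(x))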

from torchvision.models import mobilenet_v2, MobileNet_V2_Weights
model = mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).cuda()

To perform quantization simulation with aimet_tensorflow, your model definition must follow specific guidelines. For instance, models defined using the subclassing API should be converted to the functional API. For more details on model definition guidelines, refer to TensorFlow model guidelines.
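
For instance, a minimal sketch (unrelated to MobileNetV2) of converting a subclassed model into a functional-API definition:

import tensorflow as tf

# Subclassing-style definition (avoid for quantization simulation)
class SubclassedNet(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.dense = tf.keras.layers.Dense(10, activation='relu')

    def call(self, inputs):
        return self.dense(inputs)

# Equivalent functional-API definition
inputs = tf.keras.Input(shape=(128,))
outputs = tf.keras.layers.Dense(10, activation='relu')(inputs)
functional_model = tf.keras.Model(inputs=inputs, outputs=outputs)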

from tensorflow.keras import applications
model = applications.MobileNetV2()

To perform quantization simulation with aimet_onnx, first export the PyTorch model to ONNX and load the resulting ONNX graph.

import math
import os
import numpy as np
import onnx
import torch
from tqdm import tqdm
from torchvision.models import MobileNet_V2_Weights, mobilenet_v2
pt_model = mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT)
input_shape = (1, 3, 224, 224)
dummy_input = torch.randn(input_shape)

# Modify file_path as needed; a temporary directory is used here
file_path = os.path.join('/tmp', 'mobilenet_v2.onnx')
torch.onnx.export(pt_model,
                  (dummy_input,),
                  file_path,
                  input_names=['input'],
                  output_names=['output'],
                  dynamic_axes={
                      'input': {0: 'batch_size'},
                      'output': {0: 'batch_size'},
                  },
                  )

# Load exported ONNX model
model = onnx.load_model(file_path)

Note

It’s recommended to apply ONNX simplification before invoking AIMET functionalities.

import onnxsim
try:
    model, _ = onnxsim.simplify(model)
except Exception:
    print('ONNX Simplifier failed. Proceeding with unsimplified model')

Now we use AIMET to create a QuantizationSimModel. AIMET inserts fake-quantization operations into the model graph and configures them.

from aimet_common.defs import QuantScheme
from aimet_common.quantsim_config.utils import get_path_for_per_channel_config
from aimet_torch.quantsim import QuantizationSimModel

input_shape = (1, 3, 224, 224)
dummy_input = torch.randn(input_shape).cuda()
sim = QuantizationSimModel(model,
                           dummy_input=dummy_input,
                           quant_scheme=QuantScheme.training_range_learning_with_tf_init,
                           default_param_bw=8,
                           default_output_bw=16,
                           config_file=get_path_for_per_channel_config())
from aimet_common.defs import QuantScheme
from aimet_common.quantsim_config.utils import get_path_for_per_channel_config
from aimet_tensorflow.keras.quantsim import QuantizationSimModel

PARAM_BITWIDTH = 8
ACTIVATION_BITWIDTH = 16
sim = QuantizationSimModel(model,
                           quant_scheme=QuantScheme.training_range_learning_with_tf_init,
                           default_param_bw=PARAM_BITWIDTH,
                           default_output_bw=ACTIVATION_BITWIDTH,
                           config_file=get_path_for_per_channel_config())
from aimet_common.defs import QuantScheme
from aimet_common.quantsim_config.utils import get_path_for_per_channel_config
from aimet_onnx.quantsim import QuantizationSimModel

PARAM_BITWIDTH = 8
ACTIVATION_BITWIDTH = 16
sim = QuantizationSimModel(model,
                           quant_scheme=QuantScheme.post_training_tf,
                           default_param_bw=PARAM_BITWIDTH,
                           default_activation_bw=ACTIVATION_BITWIDTH,
                           config_file=get_path_for_per_channel_config())

Calibration callback

Even though AIMET has added ‘quantizer’ operations to the model graph, the QuantizationSimModel object is not ready to be used yet. Before we can use the QuantizationSimModel for inference or training, we need to find appropriate scale/offset quantization parameters for each ‘quantizer’ node.

We now create a routine that passes small, representative data samples through the model. This is straightforward: use an existing training or validation data loader to extract some samples and pass them to the model.

In practice, for computing encodings we only need 500-1000 representative data samples.

from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms

def get_calibration_and_eval_data_loaders(path: str):
    transform = transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )
    dataset = datasets.ImageFolder(path, transform=transform)

    batch_size = 64
    calibration_data_size = batch_size * 16
    eval_data_size = len(dataset) - calibration_data_size

    calibration_dataset, eval_dataset = random_split(
        dataset, [calibration_data_size, eval_data_size]
    )

    calibration_data_loader = DataLoader(calibration_dataset, batch_size=batch_size)
    eval_data_loader = DataLoader(eval_dataset, batch_size=batch_size)
    return calibration_data_loader, eval_data_loader

PATH_TO_IMAGENET = '<your_imagenet_validation_data_path>'
calibration_data_loader, eval_data_loader = get_calibration_and_eval_data_loaders(PATH_TO_IMAGENET)
from typing import Any, Optional

def pass_calibration_data(model: torch.nn.Module, forward_pass_args: Optional[Any]=None):
    """
    The User of the QuantizationSimModel API is expected to write this callback based on their dataset.
    """
    data_loader = forward_pass_args

    # batch_size (64) * num_batches (16) should be 1024
    num_batches = 16

    model.eval()
    with torch.no_grad():
        for batch, (input_data, _) in enumerate(data_loader):
            inputs_batch = input_data.to("cuda")  # labels are ignored
            model(inputs_batch)
            if batch + 1 >= num_batches:  # stop after exactly num_batches batches
                break
from tensorflow.keras.applications import mobilenet_v2
from tensorflow.keras import losses, metrics, optimizers, preprocessing

BATCH_SIZE = 32
imagenet_dataset = preprocessing.image_dataset_from_directory(
    directory='<your_imagenet_validation_data_path>',
    label_mode='categorical',
    image_size=(224, 224),
    batch_size=BATCH_SIZE,
    shuffle=True,
)

imagenet_dataset = imagenet_dataset.map(
    lambda x, y: (mobilenet_v2.preprocess_input(x), y)
)

NUM_CALIBRATION_SAMPLES = 1024
calibration_dataset = imagenet_dataset.take(NUM_CALIBRATION_SAMPLES // BATCH_SIZE)
eval_dataset = imagenet_dataset.skip(NUM_CALIBRATION_SAMPLES // BATCH_SIZE)
def pass_calibration_data(model, _):
    """
    The User of the QuantizationSimModel API is expected to write this callback based on their dataset.
    """
    for inputs, _ in calibration_dataset:
        _ = model(inputs)
from datasets import load_dataset
from aimet_onnx.defs import DataLoader
from torchvision import transforms

dataset = load_dataset('ILSVRC/imagenet-1k', split='validation').shuffle()

class CustomDataLoader(DataLoader):
    def __init__(
        self,
        data: np.ndarray,
        batch_size: int,
        iterations: int,
        unlabeled: bool = True,
    ):
        super().__init__(data, batch_size, iterations)
        self._current_iteration = 0
        self._unlabeled = unlabeled

    def __iter__(self):
        self._current_iteration = 0
        return self

    def __next__(self):
        if self._current_iteration < self.iterations:
            start = self._current_iteration * self.batch_size
            end = start + self.batch_size
            self._current_iteration += 1

            batch_data = self._data[start:end]
            if self._unlabeled:
                return np.stack(batch_data['image'])
            else:
                return np.stack(batch_data['image']), np.stack(batch_data['label'])
        else:
            raise StopIteration

preprocess = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

def apply_preprocess(examples):
    # Convert to RGB and apply the preprocessing pipeline defined above
    examples['image'] = [
        preprocess(image.convert('RGB')) for image in examples['image']
    ]
    return examples

dataset.set_transform(apply_preprocess)

BATCH_SIZE = 32
NUM_CALIBRATION_SAMPLES = 1024
NUM_EVAL_SAMPLES = 50000
calibration_data_loader = CustomDataLoader(dataset, BATCH_SIZE, math.ceil(NUM_CALIBRATION_SAMPLES / BATCH_SIZE))
eval_data_loader = CustomDataLoader(dataset, BATCH_SIZE, math.ceil(NUM_EVAL_SAMPLES / BATCH_SIZE), unlabeled=False)
import onnxruntime as ort

def pass_calibration_data(session: ort.InferenceSession, _):
    """
    The User of the QuantizationSimModel API is expected to write this callback based on their dataset.
    """
    input_name = session.get_inputs()[0].name
    for inputs in tqdm(calibration_data_loader):
        session.run(None, {input_name: inputs})

Compute encodings

Now we call QuantizationSimModel.compute_encodings() to use the above callback to pass small, representative data through the quantized model. By doing so, the quantizers in the quantized model will observe the inputs and initialize their quantization encodings according to the observed input statistics. Encodings here refer to scale/offset quantization parameters.

sim.compute_encodings(pass_calibration_data, forward_pass_callback_args=calibration_data_loader)
sim.compute_encodings(pass_calibration_data, forward_pass_callback_args=None)
sim.compute_encodings(pass_calibration_data, forward_pass_callback_args=None)

Evaluation

Next, we evaluate the QuantizationSimModel to get quantized accuracy.

# Determine simulated quantized accuracy
sim.model.eval()
correct = 0
total = 0
with torch.no_grad():
    for inputs, labels in tqdm(eval_data_loader):
        inputs, labels = inputs.to('cuda'), labels.to('cuda')
        outputs = sim.model(inputs)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    print(f'Accuracy: {correct / total:.4f}')
# Determine simulated quantized accuracy
sim.model.compile(
    optimizer=optimizers.SGD(1e-6),
    loss=[losses.CategoricalCrossentropy()],
    metrics=[metrics.CategoricalAccuracy()],
)

_, accuracy = sim.model.evaluate(eval_dataset)
print(f'Quantized accuracy (W{PARAM_BITWIDTH}A{ACTIVATION_BITWIDTH}): {accuracy:.4f}')
Quantized accuracy (W8A16): 0.7013
correct_predictions = 0
total_samples = 0
for inputs, labels in tqdm(eval_data_loader):
    input_name = sim.session.get_inputs()[0].name
    pred_probs, *_ = sim.session.run(None, {input_name: inputs})
    pred_labels = np.argmax(pred_probs, axis=1)
    correct_predictions += np.sum(pred_labels == labels)
    total_samples += labels.shape[0]

accuracy = correct_predictions / total_samples
print(f'Quantized accuracy (W{PARAM_BITWIDTH}A{ACTIVATION_BITWIDTH}): {accuracy:.4f}')
Quantized accuracy (W8A16): 0.7173

Export

Lastly, export a version of the model with quantization operations removed and an encodings JSON file with quantization scale and offset parameters for the model’s activation and weight tensors.

# Export the model for on-target inference. This saves the PyTorch model without any
# simulation nodes and writes the activation and parameter encodings to a JSON file at the provided path.
sim.export(path='/tmp', filename_prefix='quantized_mobilenet_v2', dummy_input=dummy_input.cpu())
# Export the model for on-target inference. This saves the TensorFlow model without any
# simulation nodes and writes the activation and parameter encodings to a JSON file at the provided path.
sim.export(path='/tmp', filename_prefix='quantized_mobilenet_v2')
# Export the model for on-target inference. This saves the ONNX model without any
# simulation nodes and writes the activation and parameter encodings to a JSON file at the provided path.
sim.export(path='/tmp', filename_prefix='quantized_mobilenet_v2')

API

Top level APIs

class aimet_torch.quantsim.QuantizationSimModel(model, dummy_input, quant_scheme=None, rounding_mode=None, default_output_bw=8, default_param_bw=8, in_place=False, config_file=None, default_data_type=QuantizationDataType.int)[source]

Class that simulates the quantized model execution on a target hardware backend.

QuantizationSimModel simulates quantization of a given model by converting all PyTorch modules into quantized modules with input/output/parameter quantizers as necessary.

Example

>>> model = torchvision.models.resnet18()
>>> dummy_input = torch.randn(1, 3, 224, 224)
>>> sim = QuantizationSimModel(model, dummy_input)
>>> print(model)
ResNet(
  (conv1): Conv2d(
    3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
  )
  ...
)
>>> print(sim.model)
ResNet(
  (conv1): QuantizedConv2d(
    3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
    (param_quantizers): ModuleDict(
      (weight): QuantizeDequantize(shape=(), qmin=-128, qmax=127, symmetric=True)
    )
    (input_quantizers): ModuleList(
      (0): QuantizeDequantize(shape=(), qmin=0, qmax=255, symmetric=False)
    )
    (output_quantizers): ModuleList(
      (0): None
    )
  )
  ...
)

Warning

rounding_mode parameter is deprecated. Passing rounding_mode will raise a runtime error in versions >= 1.35.

Warning

The default value of quant_scheme will change from QuantScheme.post_training_tf_enhanced to QuantScheme.training_range_learning_with_tf_init in future versions, and the parameter will be deprecated in the longer term.

Parameters:
  • model (torch.nn.Module) – Model to simulate the quantized execution of

  • dummy_input (Tensor | Sequence[Tensor]) – Dummy input to be used to capture the computational graph of the model. All input tensors are expected to be already placed on the appropriate devices to run forward pass of the model.

  • quant_scheme (QuantScheme, optional) – Quantization scheme that indicates how to observe and calibrate the quantization encodings (Default: QuantScheme.post_training_tf_enhanced)

  • rounding_mode – Deprecated

  • default_output_bw (int, optional) – Default bitwidth (4-31) to use for quantizing all layer inputs and outputs unless otherwise specified in the config file. (Default: 8)

  • default_param_bw (int, optional) – Default bitwidth (4-31) to use for quantizing all layer parameters unless otherwise specified in the config file. (Default: 8)

  • in_place (bool, optional) – If True, then the given model is modified in-place into a quantized model. (Default: False)

  • config_file (str, optional) – Path to the quantization simulation config file (Default: None)

  • default_data_type (QuantizationDataType, optional) – Default data type to use for quantizing all inputs, outputs and parameters unless otherwise specified in the config file. Possible options are QuantizationDataType.int and QuantizationDataType.float. Note that the mode default_data_type=QuantizationDataType.float is only supported with default_output_bw=16 or 32 and default_param_bw=16 or 32. (Default: QuantizationDataType.int)
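
For instance, based on the constraint described above for default_data_type, a floating-point simulation could be configured roughly as follows (a sketch, not part of the tutorial):

from aimet_common.defs import QuantizationDataType

fp16_sim = QuantizationSimModel(model,
                                dummy_input=dummy_input,
                                default_output_bw=16,
                                default_param_bw=16,
                                default_data_type=QuantizationDataType.float)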

compute_encodings(forward_pass_callback, forward_pass_callback_args=<class 'aimet_torch.v2.quantsim.quantsim._NOT_SPECIFIED'>)[source]

Computes encodings for all quantizers in the model.

This API will invoke forward_pass_callback, a function written by the user that runs forward pass(es) of the quantized model with a small, representative subset of the training dataset. By doing so, the quantizers in the quantized model will observe the inputs and initialize their quantization encodings according to the observed input statistics.

This function is overloaded with the following signatures:

compute_encodings(forward_pass_callback)[source]
Parameters:

forward_pass_callback (Callable[[torch.nn.Module], Any]) – A function that takes a quantized model and runs forward passes with a small, representative subset of training dataset

compute_encodings(forward_pass_callback, forward_pass_callback_args)[source]
Parameters:
  • forward_pass_callback (Callable[[torch.nn.Module, T], Any]) – A function that takes a quantized model and runs forward passes with a small, representative subset of training dataset

  • forward_pass_callback_args (T) – The second argument to forward_pass_callback.

Example

>>> sim = QuantizationSimModel(...)
>>> _ = sim.model(input) # Can't run forward until quantizer encodings are initialized
RuntimeError: Failed to run QuantizeDequantize since quantization parameters are not initialized.
Please initialize the quantization parameters using `compute_encodings()`.
>>> def run_forward_pass(quantized_model: torch.nn.Module):
...     for input in train_dataloader:
...         with torch.no_grad():
...             _ = quantized_model(input)
...
>>> sim.compute_encodings(run_forward_pass)
>>> _ = sim.model(input) # Now runs successfully!
export(path, filename_prefix, dummy_input, *args, **kwargs)[source]

This method exports out the quant-sim model so it is ready to be run on-target.

Specifically, the following are saved:

  1. The sim-model is exported to a regular PyTorch model without any simulation ops

  2. The quantization encodings are exported to a separate JSON-formatted file that can then be imported by the on-target runtime (if desired)

  3. Optionally, an equivalent model in ONNX format is exported. In addition, nodes in the ONNX model are named the same as the corresponding PyTorch modules. This helps match ONNX nodes to their quant encodings from #2.

Parameters:
  • path (str) – path where to store model pth and encodings

  • filename_prefix (str) – Prefix to use for filenames of the model pth and encodings files

  • dummy_input (Union[Tensor, Tuple]) – Dummy input to the model. Used to parse model graph. It is required for the dummy_input to be placed on CPU.

  • onnx_export_args – Optional export argument with onnx specific overrides provided as a dictionary or OnnxExportApiArgs object. If not provided, defaults to “opset_version” = None, “input_names” = None, “output_names” = None, and for torch version < 1.10.0, “enable_onnx_checker” = False.

  • propagate_encodings – If True, encoding entries for intermediate ops (when one PyTorch ops results in multiple ONNX nodes) are filled with the same BW and data_type as the output tensor for that series of ops. Defaults to False.

  • export_to_torchscript – If True, export to torchscript. Export to onnx otherwise. Defaults to False.

  • use_embedded_encodings – If True, another onnx model embedded with fakequant nodes will be exported

  • export_model – If True, then ONNX model is exported. When False, only encodings are exported. User should disable (False) this flag only if the corresponding ONNX model already exists in the path specified

  • filename_prefix_encodings – File name prefix to be used when saving encodings. If None, defaults to the filename_prefix value
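
For example, the export call from the tutorial above could pass ONNX-specific overrides as a dictionary (the opset value here is illustrative):

sim.export(path='/tmp',
           filename_prefix='quantized_mobilenet_v2',
           dummy_input=dummy_input.cpu(),
           onnx_export_args={'opset_version': 13})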

load_encodings(encodings, strict=True, partial=True, requires_grad=None, allow_overwrite=True)
Parameters:
  • encodings (Union[Mapping, str, PathLike]) – Encoding dictionary or path to the encoding dictionary json file.

  • strict (bool) – If True, an error will be thrown if the model doesn’t have a quantizer corresponding to the specified encodings.

  • partial (bool) – If True, the encoding will be interpreted as a partial encoding, and the dangling quantizers with no corresponding encoding will be kept untouched. Otherwise, the dangling quantizers will be removed from the model.

  • requires_grad (bool) – Whether or not the quantization parameters loaded from the encodings require gradient computation during training. If None, requires_grad flag of the quantization parameters will be kept unchanged.

  • allow_overwrite (bool) – Whether or not the quantization parameters loaded from the encodings can be overwritten by compute_encodings or another load_encodings call. If None, whether the quantizer is overwritable will be kept unchanged.
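
For example, encodings exported earlier could be loaded back into a simulation; the file name below is an assumption and should match whatever sim.export() actually wrote:

# File name is illustrative; use the encodings JSON produced by sim.export() above
sim.load_encodings('/tmp/quantized_mobilenet_v2.encodings', strict=True, partial=True)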

Quant Scheme Enum

class aimet_common.defs.QuantScheme(value)[source]

Enumeration of Quant schemes

post_training_percentile = 6

For a Tensor, adjusted minimum and maximum values are selected based on the percentile value passed. The Quantization encodings are calculated using the adjusted minimum and maximum value.

post_training_tf = 1

For a Tensor, the absolute minimum and maximum value of the Tensor are used to compute the Quantization encodings.

post_training_tf_enhanced = 2

For a Tensor, searches and selects the optimal minimum and maximum value that minimizes the Quantization Noise. The Quantization encodings are calculated using the selected minimum and maximum value.

training_range_learning_with_tf_enhanced_init = 4

For a Tensor, the encoding values are initialized with the post_training_tf_enhanced scheme. Then, the encodings are learned during training.

training_range_learning_with_tf_init = 3

For a Tensor, the encoding values are initialized with the post_training_tf scheme. Then, the encodings are learned during training.

Top level APIs

class aimet_tensorflow.keras.quantsim.QuantizationSimModel(model, quant_scheme='tf_enhanced', rounding_mode='nearest', default_output_bw=8, default_param_bw=8, in_place=False, config_file=None, default_data_type=QuantizationDataType.int)[source]

Implements a mechanism to add quantization simulation ops to a model. This allows for off-target simulation of inference accuracy. It also allows the model to be fine-tuned to counter the effects of quantization.

Parameters:
  • model – Model to quantize

  • quant_scheme (Union[QuantScheme, str]) – Quantization Scheme, currently supported schemes are post_training_tf and post_training_tf_enhanced, defaults to post_training_tf_enhanced

  • rounding_mode (str) – The rounding scheme to use. One of ‘nearest’ or ‘stochastic’; defaults to ‘nearest’.

  • default_output_bw (int) – bitwidth to use for activation tensors, defaults to 8

  • default_param_bw (int) – bitwidth to use for parameter tensors, defaults to 8

  • in_place (bool) – If True, then the given ‘model’ is modified in-place to add quant-sim nodes. The only suggested use of this option is when the user wants to avoid creating a copy of the model

  • config_file (Optional[str]) – Path to a config file to use to specify rules for placing quant ops in the model

  • default_data_type (QuantizationDataType) – Default data type to use for quantizing all layer parameters. Possible options are QuantizationDataType.int and QuantizationDataType.float. Note that the mode default_data_type=QuantizationDataType.float is only supported with default_output_bw=16 and default_param_bw=16

compute_encodings(forward_pass_callback, forward_pass_callback_args)[source]

Computes encodings for all quantization sim nodes in the model.

Parameters:
  • forward_pass_callback – A callback function that is expected to run forward passes on a model. This callback function should use representative data for the forward pass, so the calculated encodings work for all data samples.

  • forward_pass_callback_args – These argument(s) are passed to the forward_pass_callback as-is. Up to the user to determine the type of this parameter. E.g. could be simply an integer representing the number of data samples to use. Or could be a tuple of parameters or an object representing something more complex.

export(path, filename_prefix, custom_objects=None, convert_to_pb=True)[source]

This method exports out the quant-sim model so it is ready to be run on-target. Specifically, the following are saved:

  1. The sim-model is exported to a regular Keras model without any simulation ops

  2. The quantization encodings are exported to a separate JSON-formatted file that can then be imported by the on-target runtime (if desired)

Parameters:
  • path – Path at which to store the exported model and encodings files

  • filename_prefix – Prefix to use for the filenames of the exported model and encodings files

  • custom_objects – If there are custom objects to load, Keras needs a dict mapping their names to the objects

load_encodings_to_sim(encoding_file_path)[source]

Loads the saved encodings to quant sim model

Parameters:

encoding_file_path (str) – path from where to load encodings file
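
For example (the file name below is an assumption; use the encodings file written by sim.export() above):

sim.load_encodings_to_sim('/tmp/quantized_mobilenet_v2.encodings')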

Quant Scheme Enum

class aimet_common.defs.QuantScheme(value)[source]

Enumeration of Quant schemes

post_training_percentile = 6

For a Tensor, adjusted minimum and maximum values are selected based on the percentile value passed. The Quantization encodings are calculated using the adjusted minimum and maximum value.

post_training_tf = 1

For a Tensor, the absolute minimum and maximum value of the Tensor are used to compute the Quantization encodings.

post_training_tf_enhanced = 2

For a Tensor, searches and selects the optimal minimum and maximum value that minimizes the Quantization Noise. The Quantization encodings are calculated using the selected minimum and maximum value.

training_range_learning_with_tf_enhanced_init = 4

For a Tensor, the encoding values are initialized with the post_training_tf_enhanced scheme. Then, the encodings are learned during training.

training_range_learning_with_tf_init = 3

For a Tensor, the encoding values are initialized with the post_training_tf scheme. Then, the encodings are learned during training.

Top level APIs

class aimet_onnx.quantsim.QuantizationSimModel(model, dummy_input=None, quant_scheme=QuantScheme.post_training_tf_enhanced, rounding_mode='nearest', default_param_bw=8, default_activation_bw=8, use_symmetric_encodings=False, use_cuda=True, device=0, config_file=None, default_data_type=QuantizationDataType.int, user_onnx_libs=None, path=None)[source]

Creates a QuantizationSimModel by adding quantization simulation ops to a given model

Constructor

Parameters:
  • model (ModelProto) – ONNX model

  • dummy_input (Optional[Dict[str, ndarray]]) – Dummy input to the model. If None, will attempt to auto-generate a dummy input

  • quant_scheme (QuantScheme) – Quantization scheme (e.g. QuantScheme.post_training_tf)

  • rounding_mode (str) – Rounding mode (e.g. nearest)

  • default_param_bw (int) – Quantization bitwidth for parameter

  • default_activation_bw (int) – Quantization bitwidth for activation

  • use_symmetric_encodings (bool) – True if symmetric encoding is used. False otherwise.

  • use_cuda (bool) – True if using CUDA to run quantization op. False otherwise.

  • config_file (Optional[str]) – Path to Configuration file for model quantizers

  • default_data_type (QuantizationDataType) – Default data type to use for quantizing all layer inputs, outputs and parameters. Possible options are QuantizationDataType.int and QuantizationDataType.float. Note that the mode default_data_type=QuantizationDataType.float is only supported with default_output_bw=16 and default_param_bw=16

  • user_onnx_libs (Optional[List[str]]) – List of paths to all compiled ONNX custom ops libraries

  • path (Optional[str]) – Directory to save the artifacts.

compute_encodings(forward_pass_callback, forward_pass_callback_args)[source]

Compute and return the encodings of each tensor quantizer

Parameters:
  • forward_pass_callback – A callback function that simply runs forward passes on the model. This callback function should use representative data for the forward pass, so the calculated encodings work for all data samples. This callback internally chooses the number of data samples it wants to use for calculating encodings.

  • forward_pass_callback_args – These argument(s) are passed to the forward_pass_callback as-is. Up to the user to determine the type of this parameter. E.g. could be simply an integer representing the number of data samples to use. Or could be a tuple of parameters or an object representing something more complex. If set to None, forward_pass_callback will be invoked with no parameters.

export(path, filename_prefix)[source]

Compute encodings and export to files

Parameters:
  • path (str) – dir to save encoding files

  • filename_prefix (str) – filename to save encoding files

Note

  • It is recommended to use onnx-simplifier before creating quantsim model.

  • Since ONNX Runtime is used for optimized inference only, the ONNX framework supports post-training quantization schemes (TF or TF-enhanced) for computing the encodings.

aimet_onnx.quantsim.load_encodings_to_sim(quant_sim_model, onnx_encoding_path, strict=True)[source]

Loads the saved encodings to quant sim model. The encoding filename to load should end in .encodings, generated as part of quantsim export.

Parameters:
  • quant_sim_model (QuantizationSimModel) – Quantized model to load encodings for. Note: The model configuration should be the same as when encodings were exported.

  • onnx_encoding_path (str) – Path of the encodings file to load.

  • strict – If set to True, an assertion is raised when the settings of the encodings to load do not line up with the QuantSim's initialized settings. If set to False, quantizer settings will be updated to align with the encodings being loaded.

Return type:

List[EncodingMismatchInfo]

Returns:

List of EncodingMismatchInfo objects containing quantizer names and mismatched settings
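
For example (the file name below is an assumption; use the encodings file written by sim.export() above):

from aimet_onnx.quantsim import load_encodings_to_sim

mismatches = load_encodings_to_sim(sim, '/tmp/quantized_mobilenet_v2.encodings', strict=True)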

Quant Scheme Enum

class aimet_common.defs.QuantScheme(value)[source]

Enumeration of Quant schemes

post_training_percentile = 6

For a Tensor, adjusted minimum and maximum values are selected based on the percentile value passed. The Quantization encodings are calculated using the adjusted minimum and maximum value.

post_training_tf = 1

For a Tensor, the absolute minimum and maximum value of the Tensor are used to compute the Quantization encodings.

post_training_tf_enhanced = 2

For a Tensor, searches and selects the optimal minimum and maximum value that minimizes the Quantization Noise. The Quantization encodings are calculated using the selected minimum and maximum value.

training_range_learning_with_tf_enhanced_init = 4

For a Tensor, the encoding values are initialized with the post_training_tf_enhanced scheme. Then, the encodings are learned during training.

training_range_learning_with_tf_init = 3

For a Tensor, the encoding values are initialized with the post_training_tf scheme. Then, the encodings are learned during training.