Quantization-aware training

Quantization-aware training (QAT) finds better-optimized solutions than post-training quantization (PTQ) by fine-tuning the model parameters in the presence of quantization noise. This higher accuracy comes with the usual costs of neural network training, including longer training times and the need for labeled data and hyperparameter search.
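Under the hood, QAT relies on "fake quantization": quantize-dequantize operations are inserted into the forward pass so the network is trained against quantization noise, while the backward pass treats the rounding as identity (the straight-through estimator). The following minimal PyTorch sketch illustrates the idea only; it is not the AIMET implementation, and the scale, offset, and bit range shown are illustrative.

import torch

def fake_quantize(x: torch.Tensor, scale: float, offset: int,
                  qmin: int = 0, qmax: int = 255) -> torch.Tensor:
    """Quantize-dequantize x so the forward pass sees quantization noise."""
    q = torch.clamp(torch.round(x / scale) + offset, qmin, qmax)
    x_dq = (q - offset) * scale
    # Straight-through estimator: forward returns the dequantized value,
    # backward passes the gradient through as if the op were identity.
    return x + (x_dq - x).detach()

# Example: simulate 8-bit asymmetric quantization of a weight tensor.
w = torch.randn(64, 3, 3, 3, requires_grad=True)
w_q = fake_quantize(w, scale=0.02, offset=128)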

Variants of QAT

There are two main variants of QAT: without range learning and with range learning.

Without range learning

In this approach, activation quantization parameters (such as scale and offset) are not learnable and remain fixed throughout training.

With range learning

In this approach, activation quantization parameters are treated as learnable parameters. While this dynamic adjustment can further reduce quantization noise, this advantage comes at the cost of significantly increased memory usage.

In both variants, quantization parameters for model parameters (such as weight) are treated as learnable.
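In AIMET, the variant is selected through the quant_scheme argument when constructing QuantizationSimModel. A minimal sketch of both choices, assuming the model and dummy_input defined in the workflow below; the scheme names come from aimet_common.defs.QuantScheme and the default W8A8 bitwidths are used.

from aimet_common.defs import QuantScheme
from aimet_torch.quantsim import QuantizationSimModel

# Without range learning: activation encodings are calibrated once and stay fixed
sim = QuantizationSimModel(model, dummy_input,
                           quant_scheme=QuantScheme.post_training_tf_enhanced)

# With range learning: activation encodings are initialized by calibration and
# then updated as learnable parameters during training
sim = QuantizationSimModel(model, dummy_input,
                           quant_scheme=QuantScheme.training_range_learning_with_tf_init)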

Typical recommendations

  • Initialization: Apply PTQ techniques (such as Sequential MSE) before starting QAT.

    This is especially important if there is a large drop in INT8 performance compared to the FP baseline. QAT is a fine-tuning technique that relies on a reasonably well-performing quantized model as a starting point. Without a solid baseline, its benefit tends to be limited.

  • Learning rate: Use a small learning rate.

    Start with a small learning rate, and reduce it by a factor of 10 every few epochs (a learning-rate schedule is sketched after this list). The main goal of QAT is fine-tuning. Since quantization parameters are often sensitive to even minor updates, a small learning rate is typically recommended for stable convergence.

  • Target layers for QAT: Whenever possible, apply QAT selectively to layers that are sensitive to quantization.

    Applying QAT to all layers is not only memory-intensive but can also negatively impact convergence. Quantization parameters that were already near-optimal may drift away from the optimum during QAT. For instance, INT16 quantization typically does not require QAT due to its high precision. In contrast, lower-bit quantization formats such as INT8 or INT4 are more likely to benefit from QAT, as they are more susceptible to quantization noise. The sketch after this list also shows one way to restrict QAT to selected layers by freezing the quantizers of the others.
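The sketch below illustrates the last two recommendations. It assumes sim is the QuantizationSimModel created in Step 2 of the workflow; the list of insensitive layer names is purely illustrative and model-specific, and the quantizer attributes (param_quantizers, input_quantizers, output_quantizers) are those shown in the API example further down.

import torch

# Hypothetical names of layers that already quantize well and can be left out of QAT
insensitive_layers = {"features.0.0", "features.18.0"}

optimizer = torch.optim.SGD(sim.model.parameters(), lr=1e-5)
# Reduce the learning rate by a factor of 10 every 2 epochs
# (call scheduler.step() once per epoch in the training loop).
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.1)

# Freeze the learnable encodings of insensitive layers so that QAT only
# fine-tunes the quantization-sensitive ones.
for name, module in sim.model.named_modules():
    if name in insensitive_layers:
        for attr in ("param_quantizers", "input_quantizers", "output_quantizers"):
            quantizers = getattr(module, attr, None)
            if quantizers is not None:
                for p in quantizers.parameters():
                    p.requires_grad_(False)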

Workflow

Step 1: Setup

Set up the model, data loader, and callback functions.

import itertools
import torch
import torchvision
from tqdm import tqdm
from aimet_torch.batch_norm_fold import fold_all_batch_norms

# General setup that can be changed as needed
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = torchvision.models.mobilenet_v2(pretrained=True).eval().to(device)

PATH_TO_IMAGENET = ...
data = torchvision.datasets.ImageNet(PATH_TO_IMAGENET, split="train")
data_loader = torch.utils.data.DataLoader(data, batch_size=64)

dummy_input = torch.randn(1, 3, 224, 224, device=device)
fold_all_batch_norms(model, dummy_input.shape)

@torch.no_grad()
def pass_calibration_data(model: torch.nn.Module):
    # Pass N batches of calibration data through the model
    for images, _ in itertools.islice(data_loader, 10):
        _ = model(images.to(device))

@torch.no_grad()
def evaluate(model, data_loader):
    # Basic ImageNet evaluation function
    correct = 0
    for data, labels in tqdm(data_loader):
        data, labels = data.to(device), labels.to(device)
        logits = model(data)
        correct += (logits.argmax(1) == labels).sum().item()
    return correct / len(data_loader.dataset)

from aimet_common.defs import QuantScheme
from aimet_tensorflow.keras.quantsim import QuantizationSimModel
from tensorflow.keras import applications, losses, metrics, optimizers, preprocessing
from tensorflow.keras.applications import mobilenet_v2

model = applications.MobileNetV2()

# Set up dataset
BATCH_SIZE = 32
imagenet_dataset = preprocessing.image_dataset_from_directory(
    directory='<your_imagenet_validation_data_path>',
    labels='inferred',
    label_mode='categorical',
    image_size=(224, 224),
    batch_size=BATCH_SIZE,
    shuffle=True,
)

imagenet_dataset = imagenet_dataset.map(
    lambda x, y: (mobilenet_v2.preprocess_input(x), y)
)

NUM_CALIBRATION_SAMPLES = 2048
calibration_dataset = imagenet_dataset.take(NUM_CALIBRATION_SAMPLES // BATCH_SIZE)
eval_dataset = imagenet_dataset.skip(NUM_CALIBRATION_SAMPLES // BATCH_SIZE)

def pass_calibration_data(model, _):
    """
The user of the QuantizationSimModel API is expected to write this callback based on their dataset.
    """
    for inputs, _ in calibration_dataset:
        model(inputs)

Not supported.

Step 2: Compute initial quantization parameters

Compute initial quantization parameters and evaluate accuracy.

from aimet_torch.quantsim import QuantizationSimModel
sim = QuantizationSimModel(model, dummy_input)
sim.compute_encodings(pass_calibration_data)

accuracy = evaluate(sim.model.eval(), data_loader)
print(f"Quantized accuracy (W8A8): {accuracy}")
Quantized accuracy (W8A8): 0.68016
PARAM_BITWIDTH = 8
ACTIVATION_BITWIDTH = 8
QUANT_SCHEME = QuantScheme.training_range_learning_with_tf_init
sim = QuantizationSimModel(
    model,
    quant_scheme=QUANT_SCHEME,
    default_param_bw=PARAM_BITWIDTH,
    default_output_bw=ACTIVATION_BITWIDTH,
)
sim.compute_encodings(pass_calibration_data, None)
sim.model.compile(
    optimizer=optimizers.SGD(learning_rate=1e-5),
    loss=[losses.CategoricalCrossentropy()],
    metrics=[metrics.CategoricalAccuracy()],
)
_, accuracy = sim.model.evaluate(eval_dataset)
print(f'Quantized accuracy (W8A8): {accuracy:.4f}')
Quantized accuracy (W8A8): 0.6583

Not supported.

Step 3: Run quantization-aware training

Train the model to fine-tune quantization parameters.

# Training loop can be replaced with any custom training loop
def train(model, data_loader, optimizer, loss_fn, num_epochs):
    for _ in range(num_epochs):
        for data, labels in tqdm(data_loader):
            data, labels = data.to(device), labels.to(device)
            logits = model(data)
            loss = loss_fn(logits, labels)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(sim.model.parameters(), lr=1e-5)
train(sim.model.train(), data_loader, optimizer, loss_fn, num_epochs=2)
accuracy = evaluate(sim.model.eval(), data_loader)
print(f"Model accuracy after QAT: {accuracy}")
Model accuracy after QAT: 0.70838
sim.model.fit(calibration_dataset, epochs=10)
_, accuracy = sim.model.evaluate(eval_dataset)
print(f'Model accuracy after QAT: {accuracy:.4f}')
Model accuracy after QAT: 0.6910

Not supported.
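Once accuracy is acceptable, the fine-tuned model and its quantization encodings can be exported for on-target deployment with sim.export (see the API reference below). A minimal PyTorch sketch; the output directory and filename prefix are illustrative, and dummy_input must be placed on the CPU.

import os

os.makedirs("./output", exist_ok=True)
sim.export(path="./output",
           filename_prefix="mobilenet_v2_qat",
           dummy_input=dummy_input.cpu())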

API

Top level APIs

class aimet_torch.QuantizationSimModel(model, dummy_input, quant_scheme=None, rounding_mode=None, default_output_bw=8, default_param_bw=8, in_place=False, config_file=None, default_data_type=QuantizationDataType.int)[source]

Class that simulates the quantized model execution on a target hardware backend.

QuantizationSimModel simulates quantization of a given model by converting all PyTorch modules into quantized modules with input/output/parameter quantizers as necessary.

Example

>>> model = torchvision.models.resnet18()
>>> dummy_input = torch.randn(1, 3, 224, 224)
>>> sim = QuantizationSimModel(model, dummy_input)
>>> print(model)
ResNet(
  (conv1): Conv2d(
    3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
  )
  ...
)
>>> print(sim.model)
ResNet(
  (conv1): QuantizedConv2d(
    3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
    (param_quantizers): ModuleDict(
      (weight): QuantizeDequantize(shape=(), qmin=-128, qmax=127, symmetric=True)
    )
    (input_quantizers): ModuleList(
      (0): QuantizeDequantize(shape=(), qmin=0, qmax=255, symmetric=False)
    )
    (output_quantizers): ModuleList(
      (0): None
    )
  )
  ...
)

Warning

The rounding_mode parameter is deprecated. Passing rounding_mode will raise a runtime error in version 1.35 and later.

Warning

The default value of quant_scheme has changed from QuantScheme.post_training_tf_enhanced to QuantScheme.training_range_learning_with_tf_init as of version 2.0.0, and will be deprecated in the longer term.

Parameters:
  • model (torch.nn.Module) – Model to simulate the quantized execution of

  • dummy_input (Tensor | Sequence[Tensor]) – Dummy input to be used to capture the computational graph of the model. All input tensors are expected to be already placed on the appropriate devices to run forward pass of the model.

  • quant_scheme (QuantScheme, optional) – Quantization scheme that indicates how to observe and calibrate the quantization encodings (Default: QuantScheme.post_training_tf_enhanced)

  • rounding_mode – Deprecated

  • default_output_bw (int, optional) – Default bitwidth (4-31) to use for quantizing all layer inputs and outputs unless otherwise specified in the config file. (Default: 8)

  • default_param_bw (int, optional) – Default bitwidth (4-31) to use for quantizing all layer parameters unless otherwise specified in the config file. (Default: 8)

  • in_place (bool, optional) – If True, then the given model is modified in-place into a quantized model. (Default: False)

  • config_file (str, optional) – File path or alias of the configuration file. Alias can be one of { default, htp_v66, htp_v68, htp_v69, htp_v73, htp_v75, htp_v79, htp_v81 } (Default: “default”)

  • default_data_type (QuantizationDataType, optional) – Default data type to use for quantizing all inputs, outputs and parameters unless otherwise specified in the config file. Possible options are QuantizationDataType.int and QuantizationDataType.float. Note that the mode default_data_type=QuantizationDataType.float is only supported with default_output_bw=16 or 32 and default_param_bw=16 or 32. (Default: QuantizationDataType.int)

compute_encodings(forward_pass_callback, forward_pass_callback_args=<class 'aimet_torch.v2.quantsim.quantsim._NOT_SPECIFIED'>)[source]

Computes encodings for all quantizers in the model.

This API will invoke forward_pass_callback, a function written by the user that runs forward pass(es) of the quantized model with a small, representative subset of the training dataset. By doing so, the quantizers in the quantized model will observe the inputs and initialize their quantization encodings according to the observed input statistics.

This function is overloaded with the following signatures:

compute_encodings(forward_pass_callback)[source]
Parameters:

forward_pass_callback (Callable[[torch.nn.Module], Any]) – A function that takes a quantized model and runs forward passes with a small, representative subset of training dataset

compute_encodings(forward_pass_callback, forward_pass_callback_args)[source]
Parameters:
  • forward_pass_callback (Callable[[torch.nn.Module, T], Any]) – A function that takes a quantized model and runs forward passes with a small, representative subset of training dataset

  • forward_pass_callback_args (T) – The second argument to forward_pass_callback.

Example

>>> sim = QuantizationSimModel(...)
>>> _ = sim.model(input) # Can't run forward until quantizer encodings are initialized
RuntimeError: Failed to run QuantizeDequantize since quantization parameters are not initialized.
Please initialize the quantization parameters using `compute_encodings()`.
>>> def run_forward_pass(quantized_model: torch.nn.Module):
...     for input in train_dataloader:
...         with torch.no_grad():
...             _ = quantized_model(input)
...
>>> sim.compute_encodings(run_forward_pass)
>>> _ = sim.model(input) # Now runs successfully!
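The two-argument overload forwards a user-defined value to the callback, which is convenient for parameterizing the amount of calibration data. A sketch assuming the device and data_loader from the workflow above:

import itertools
import torch

@torch.no_grad()
def run_forward_pass(quantized_model: torch.nn.Module, num_batches: int):
    # Calibrate on the first `num_batches` batches of representative data
    for images, _ in itertools.islice(data_loader, num_batches):
        _ = quantized_model(images.to(device))

sim.compute_encodings(run_forward_pass, 10)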
export(path, filename_prefix, dummy_input, *args, **kwargs)[source]

This method exports the quant-sim model so it is ready to be run on-target.

Specifically, the following are saved:

  1. The sim-model is exported to a regular PyTorch model without any simulation ops

  2. The quantization encodings are exported to a separate JSON-formatted file that can then be imported by the on-target runtime (if desired)

  3. Optionally, an equivalent model in ONNX format is exported. Nodes in the ONNX model are named after the corresponding PyTorch modules, which helps match ONNX nodes to their quantization encodings from #2.

Parameters:
  • path (str) – Path at which to store the model pth and encodings files

  • filename_prefix (str) – Prefix to use for filenames of the model pth and encodings files

  • dummy_input (Union[Tensor, Tuple]) – Dummy input to the model. Used to parse model graph. It is required for the dummy_input to be placed on CPU.

  • onnx_export_args – Optional export argument with onnx specific overrides provided as a dictionary or OnnxExportApiArgs object. If not provided, defaults to “opset_version” = None, “input_names” = None, “output_names” = None, and for torch version < 1.10.0, “enable_onnx_checker” = False.

  • propagate_encodings – If True, encoding entries for intermediate ops (when one PyTorch op results in multiple ONNX nodes) are filled with the same BW and data_type as the output tensor for that series of ops. Defaults to False.

  • export_to_torchscript – If True, export to torchscript. Export to onnx otherwise. Defaults to False.

  • use_embedded_encodings – If True, an additional ONNX model with embedded fake-quantization nodes is exported

  • export_model – If True, the ONNX model is exported; if False, only the encodings are exported. Set this flag to False only if the corresponding ONNX model already exists at the specified path

  • filename_prefix_encodings – File name prefix to use when saving encodings. If None, defaults to the filename_prefix value

Top level APIs

class aimet_tensorflow.keras.quantsim.QuantizationSimModel(model, quant_scheme='tf_enhanced', rounding_mode='nearest', default_output_bw=8, default_param_bw=8, in_place=False, config_file=None, default_data_type=QuantizationDataType.int)[source]

Implements a mechanism to add quantization simulation ops to a model. This allows off-target simulation of inference accuracy and lets the model be fine-tuned to counter the effects of quantization.

Parameters:
  • model – Model to quantize

  • quant_scheme (Union[QuantScheme, str]) – Quantization scheme. Currently supported schemes are post_training_tf and post_training_tf_enhanced; defaults to post_training_tf_enhanced

  • rounding_mode (str) – The rounding scheme to use. One of ‘nearest’ or ‘stochastic’; defaults to ‘nearest’

  • default_output_bw (int) – bitwidth to use for activation tensors, defaults to 8

  • default_param_bw (int) – bitwidth to use for parameter tensors, defaults to 8

  • in_place (bool) – If True, the given model is modified in-place to add quant-sim nodes. The only suggested use of this option is when the user wants to avoid creating a copy of the model

  • config_file (Optional[str]) – Path to a config file to use to specify rules for placing quant ops in the model

  • default_data_type (QuantizationDataType) – Default data type to use for quantizing all layer parameters. Possible options are QuantizationDataType.int and QuantizationDataType.float. Note that the mode default_data_type=QuantizationDataType.float is only supported with default_output_bw=16 and default_param_bw=16

compute_encodings(forward_pass_callback, forward_pass_callback_args)[source]

Computes encodings for all quantization sim nodes in the model.

Parameters:
  • forward_pass_callback – A callback function that is expected to run forward passes on a model. This callback function should use representative data for the forward pass, so the calculated encodings work for all data samples.

  • forward_pass_callback_args – These arguments are passed to forward_pass_callback as-is. It is up to the user to determine the type of this parameter; for example, it could be an integer representing the number of data samples to use, a tuple of parameters, or an object representing something more complex.
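A minimal Keras sketch of construction and calibration, assuming model is a Keras model and calibration_dataset is a representative tf.data.Dataset as in the workflow above; the string scheme name and batch count are illustrative.

from aimet_tensorflow.keras.quantsim import QuantizationSimModel

sim = QuantizationSimModel(model, quant_scheme='tf_enhanced',
                           default_param_bw=8, default_output_bw=8)

def pass_calibration_data(sim_model, num_batches):
    # Run forward passes on a small, representative subset of the dataset
    for inputs, _ in calibration_dataset.take(num_batches):
        sim_model(inputs)

sim.compute_encodings(pass_calibration_data, 8)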

export(path, filename_prefix, custom_objects=None, convert_to_pb=True)[source]

This method exports the quant-sim model so it is ready to be run on-target. Specifically, the following are saved:

  1. The sim-model is exported to a regular Keras model without any simulation ops

  2. The quantization encodings are exported to a separate JSON-formatted file that can then be imported by the on-target runtime (if desired)

Parameters:
  • path – Path at which to store the model pth and encodings files

  • filename_prefix – Prefix to use for filenames of the model pth and encodings files

  • custom_objects – If there are custom objects to load, Keras needs a dict of them to map them

load_encodings_to_sim(encoding_file_path)[source]

Loads the saved encodings into the quant sim model.

Parameters:

encoding_file_path (str) – Path from which to load the encodings file
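A one-line usage sketch; the encodings path is hypothetical and assumed to come from a previous export call.

sim.load_encodings_to_sim("./output/mobilenet_v2_qat.encodings")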

Quant Scheme Enum

class aimet_common.defs.QuantScheme(value)[source]

Quantization schemes

classmethod from_str(alias)[source]

Returns QuantScheme object from string alias

Return type:

QuantScheme

Not supported.