Quantization-aware training¶
Quantization-aware training (QAT) finds better-optimized solutions than post-training quantization (PTQ) by fine-tuning the model parameters in the presence of quantization noise. This higher accuracy comes at the usual cost of neural network training, including longer training times and the need for labeled data and hyperparameter search.
QAT modes¶
There are two versions of QAT: without range learning and with range learning.
- Without range learning:
In QAT without range learning, encoding values for activation quantizers are found once during calibration and are not updated again.
- With range learning:
In QAT with range learning, encoding values for activation quantizers are set during calibration and can be updated during training, resulting in better scale and offset quantization parameters.
In both versions, parameter quantizer encoding values continue to be updated with the parameters themselves during training.
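The mode is selected through the quant_scheme argument when constructing the QuantizationSimModel. A minimal sketch using the PyTorch API (model and dummy_input as defined in the workflow below):
from aimet_common.defs import QuantScheme
from aimet_torch.quantsim import QuantizationSimModel
# QAT without range learning: activation encodings stay fixed after calibration
sim = QuantizationSimModel(model, dummy_input, quant_scheme=QuantScheme.post_training_tf_enhanced)
# QAT with range learning: activation encodings are initialized at calibration and learned during training
sim = QuantizationSimModel(model, dummy_input, quant_scheme=QuantScheme.training_range_learning_with_tf_init)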
QAT recommendations¶
Here are some guidelines that can improve performance and speed convergence with QAT:
- Initialization
It often helps to first apply PTQ techniques before applying QAT, especially if there is a large drop in INT8 performance from the FP32 baseline.
- Hyper-parameters
Number of epochs: 15-20 epochs are usually sufficient for convergence.
Learning rate: Comparable to (or one order of magnitude higher than) the FP32 model’s final learning rate at convergence. Results in AIMET are with a learning rate on the order of 1e-6.
Learning rate schedule: Divide the learning rate by 10 every 5-10 epochs, as shown in the sketch below.
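As an illustration of these guidelines only, the following PyTorch sketch configures an SGD optimizer and a step schedule that divides the learning rate by 10 every 5 epochs; it assumes the sim, train, data_loader, and loss_fn objects defined later in the workflow, and the exact values should be tuned per model:
import torch
optimizer = torch.optim.SGD(sim.model.parameters(), lr=1e-5)
# Divide the learning rate by 10 every 5 epochs
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)
for epoch in range(20):  # 15-20 epochs are usually sufficient
    train(sim.model, data_loader, optimizer, loss_fn)
    scheduler.step()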
Workflow¶
Setup¶
Set up the model, data loader, and training loops for training. The PyTorch listing appears first, followed by the equivalent TensorFlow/Keras setup.
import torch
import torchvision
from torch.utils.data import DataLoader
from tqdm import tqdm
from aimet_torch.batch_norm_fold import fold_all_batch_norms
# General setup that can be changed as needed
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = torchvision.models.mobilenet_v2(pretrained=True).eval().to(device)
batch_size = 64
PATH_TO_IMAGENET = ...
data = torchvision.datasets.ImageNet(PATH_TO_IMAGENET, split="train")
data_loader = DataLoader(data, batch_size=batch_size)
dummy_input = torch.randn(1, 3, 224, 224).to(device)
fold_all_batch_norms(model, dummy_input.shape)
# Callback function to pass calibration data through the model
def pass_calibration_data(model: torch.nn.Module, batches):
    """
    The user of the QuantizationSimModel API is expected to write this callback based on their dataset.
    """
    with torch.no_grad():
        for batch, (images, _) in enumerate(data_loader):
            images = images.to(device)
            model(images)
            if batch >= batches:
                break
# Basic ImageNet evaluation function
def evaluate(model, data_loader):
    model.eval()
    correct = 0
    with torch.no_grad():
        for data, labels in tqdm(data_loader):
            data, labels = data.to(device), labels.to(device)
            logits = model(data)
            correct += (logits.argmax(1) == labels).type(torch.float).sum().item()
    accuracy = correct / len(data_loader.dataset)
    return accuracy
# TensorFlow/Keras variant of the same setup
from aimet_common.defs import QuantScheme
from aimet_tensorflow.keras.quantsim import QuantizationSimModel
from tensorflow.keras import applications, losses, metrics, optimizers, preprocessing
from tensorflow.keras.applications import mobilenet_v2
model = applications.MobileNetV2()
# Set up dataset
BATCH_SIZE = 32
imagenet_dataset = preprocessing.image_dataset_from_directory(
    directory='<your_imagenet_validation_data_path>',
    labels='inferred',
    label_mode='categorical',
    image_size=(224, 224),
    batch_size=BATCH_SIZE,
    shuffle=True,
)
imagenet_dataset = imagenet_dataset.map(
    lambda x, y: (mobilenet_v2.preprocess_input(x), y)
)
NUM_CALIBRATION_SAMPLES = 2048
calibration_dataset = imagenet_dataset.take(NUM_CALIBRATION_SAMPLES // BATCH_SIZE)
eval_dataset = imagenet_dataset.skip(NUM_CALIBRATION_SAMPLES // BATCH_SIZE)
def pass_calibration_data(model, _):
    """
    The user of the QuantizationSimModel API is expected to write this callback based on their dataset.
    """
    for inputs, _ in calibration_dataset:
        model(inputs)
Compute the initial quantization parameters¶
from aimet_common.defs import QuantScheme
from aimet_torch.quantsim import QuantizationSimModel
sim = QuantizationSimModel(model, dummy_input, quant_scheme=QuantScheme.training_range_learning_with_tf_init)
calibration_batches = 10
sim.compute_encodings(pass_calibration_data, calibration_batches)
accuracy = evaluate(sim.model, data_loader)
print(f"Quantized accuracy (W8A8): {accuracy}")
Quantized accuracy (W8A8): 0.68016
# TensorFlow/Keras variant
PARAM_BITWIDTH = 8
ACTIVATION_BITWIDTH = 8
QUANT_SCHEME = QuantScheme.training_range_learning_with_tf_init
sim = QuantizationSimModel(
    model,
    quant_scheme=QUANT_SCHEME,
    default_param_bw=PARAM_BITWIDTH,
    default_output_bw=ACTIVATION_BITWIDTH,
)
sim.compute_encodings(pass_calibration_data, None)
sim.model.compile(
    optimizer=optimizers.SGD(learning_rate=1e-5),
    loss=[losses.CategoricalCrossentropy()],
    metrics=[metrics.CategoricalAccuracy()],
)
_, accuracy = sim.model.evaluate(eval_dataset)
print(f'Quantized accuracy (W8A8): {accuracy:.4f}')
Quantized accuracy (W8A8): 0.6583
Fine-tune quantized model¶
# Training loop can be replaced with any custom training loop
def train(model, data_loader, optimizer, loss_fn):
    model.train()
    for data, labels in tqdm(data_loader):
        data, labels = data.to(device), labels.to(device)
        logits = model(data)
        loss = loss_fn(logits, labels)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(sim.model.parameters(), lr=1e-5)

epochs = 2
for epoch in range(epochs):
    train(sim.model, data_loader, optimizer, loss_fn)
# TensorFlow/Keras variant: fine-tune with model.fit
sim.model.fit(calibration_dataset, epochs=10)
Evaluation¶
Next, we evaluate the QuantizationSimModel to get the quantized accuracy.
accuracy = evaluate(sim.model, data_loader)
print(f"Model accuracy after QAT: {accuracy}")
Model accuracy after QAT: 0.70838
# TensorFlow/Keras variant
_, accuracy = sim.model.evaluate(eval_dataset)
print(f'Model accuracy after QAT: {accuracy:.4f}')
Model accuracy after QAT: 0.6910
Export¶
After fine-tuning the model’s quantized accuracy with QAT, export a version of the model with quantization operations removed and an encodings JSON file with quantization scale and offset parameters for the model’s activation and weight tensors.
sim.export(path="./", filename_prefix="quantized_mobilenetv2", dummy_input=dummy_input.cpu())
sim.export(path='/tmp', filename_prefix='quantized_mobilenet_v2')
Multi-GPU support¶
To use multi-GPU with QAT:

1. Create a QuantizationSimModel for your pre-trained PyTorch model (not in DataParallel mode).
2. Perform QuantizationSimModel.compute_encodings(). (NOTE: Do not use a forward function that moves the model to multi-GPU and back.)
3. Move the QuantizationSimModel to DataParallel:
# "sim" here refers to the QuantizationSimModel object.
sim.model = torch.nn.DataParallel(sim.model)
4. Perform evaluation and/or training.
5. Export for on-target inference.
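Putting these steps together, a minimal sketch (reusing the model, dummy_input, calibration callback, and training utilities from the workflow above) might look like the following; export then proceeds as shown in the Export section:
# 1. Create the sim on a single device (not wrapped in DataParallel)
sim = QuantizationSimModel(model, dummy_input, quant_scheme=QuantScheme.training_range_learning_with_tf_init)
# 2. Compute encodings with a single-GPU forward pass
sim.compute_encodings(pass_calibration_data, calibration_batches)
# 3. Move the quantized model to DataParallel
sim.model = torch.nn.DataParallel(sim.model)
# 4. Evaluate and/or fine-tune as usual
for epoch in range(epochs):
    train(sim.model, data_loader, optimizer, loss_fn)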
API¶
Top level APIs
- class aimet_torch.quantsim.QuantizationSimModel(model, dummy_input, quant_scheme=None, rounding_mode=None, default_output_bw=8, default_param_bw=8, in_place=False, config_file=None, default_data_type=QuantizationDataType.int)[source]
Class that simulates the quantized model execution on a target hardware backend.
QuantizationSimModel simulates quantization of a given model by converting all PyTorch modules into quantized modules with input/output/parameter quantizers as necessary.
Example
>>> model = torchvision.models.resnet18()
>>> dummy_input = torch.randn(1, 3, 224, 224)
>>> sim = QuantizationSimModel(model, dummy_input)
>>> print(model)
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  ...
)
>>> print(sim.model)
ResNet(
  (conv1): QuantizedConv2d(
    3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
    (param_quantizers): ModuleDict(
      (weight): QuantizeDequantize(shape=(), qmin=-128, qmax=127, symmetric=True)
    )
    (input_quantizers): ModuleList(
      (0): QuantizeDequantize(shape=(), qmin=0, qmax=255, symmetric=False)
    )
    (output_quantizers): ModuleList(
      (0): None
    )
  )
  ...
)
Warning
The rounding_mode parameter is deprecated. Passing rounding_mode will throw a runtime error in version 1.35 and later.
Warning
The default value of quant_scheme will change from QuantScheme.post_training_tf_enhanced to QuantScheme.training_range_learning_with_tf_init in future versions, and the quant_scheme parameter will be deprecated in the longer term.
- Parameters:
model (torch.nn.Module) – Model to simulate the quantized execution of
dummy_input (Tensor | Sequence[Tensor]) – Dummy input to be used to capture the computational graph of the model. All input tensors are expected to be already placed on the appropriate devices to run forward pass of the model.
quant_scheme (QuantScheme, optional) – Quantization scheme that indicates how to observe and calibrate the quantization encodings (Default: QuantScheme.post_training_tf_enhanced)
rounding_mode – Deprecated
default_output_bw (int, optional) – Default bitwidth (4-31) to use for quantizing all layer inputs and outputs unless otherwise specified in the config file. (Default: 8)
default_param_bw (int, optional) – Default bitwidth (4-31) to use for quantizing all layer parameters unless otherwise specified in the config file. (Default: 8)
in_place (bool, optional) – If True, then the given model is modified in-place into a quantized model. (Default: False)
config_file (str, optional) – Path to the quantization simulation config file (Default: None)
default_data_type (QuantizationDataType, optional) – Default data type to use for quantizing all inputs, outputs and parameters unless otherwise specified in the config file. Possible options are QuantizationDataType.int and QuantizationDataType.float. Note that the mode default_data_type=QuantizationDataType.float is only supported with default_output_bw=16 or 32 and default_param_bw=16 or 32. (Default: QuantizationDataType.int)
- compute_encodings(forward_pass_callback, forward_pass_callback_args=<class 'aimet_torch.v2.quantsim.quantsim._NOT_SPECIFIED'>)[source]
Computes encodings for all quantizers in the model.
This API will invoke forward_pass_callback, a function written by the user that runs forward pass(es) of the quantized model with a small, representative subset of the training dataset. By doing so, the quantizers in the quantized model will observe the inputs and initialize their quantization encodings according to the observed input statistics.
This function is overloaded with the following signatures:
- compute_encodings(forward_pass_callback)[source]
- Parameters:
forward_pass_callback (Callable[[torch.nn.Module], Any]) – A function that takes a quantized model and runs forward passes with a small, representative subset of training dataset
- compute_encodings(forward_pass_callback, forward_pass_callback_args)[source]
- Parameters:
forward_pass_callback (Callable[[torch.nn.Module, T], Any]) – A function that takes a quantized model and runs forward passes with a small, representative subset of training dataset
forward_pass_callback_args (T) – The second argument to forward_pass_callback.
Example
>>> sim = QuantizationSimModel(...)
>>> _ = sim.model(input)  # Can't run forward until quantizer encodings are initialized
RuntimeError: Failed to run QuantizeDequantize since quantization parameters are not initialized.
Please initialize the quantization parameters using `compute_encodings()`.
>>> def run_forward_pass(quantized_model: torch.nn.Module):
...     for input in train_dataloader:
...         with torch.no_grad():
...             _ = quantized_model(input)
...
>>> sim.compute_encodings(run_forward_pass)
>>> _ = sim.model(input)  # Now runs successfully!
- export(path, filename_prefix, dummy_input, *args, **kwargs)[source]
This method exports the quant-sim model so it is ready to be run on-target.
Specifically, the following are saved:
The sim-model is exported to a regular PyTorch model without any simulation ops
The quantization encodings are exported to a separate JSON-formatted file that can then be imported by the on-target runtime (if desired)
Optionally, an equivalent model in ONNX format is exported. In addition, nodes in the ONNX model are named the same as the corresponding PyTorch module names. This helps with matching ONNX nodes to their quantization encodings from the second item above.
- Parameters:
path (str) – path where to store model pth and encodings
filename_prefix (str) – Prefix to use for filenames of the model pth and encodings files
dummy_input (Union[Tensor, Tuple]) – Dummy input to the model. Used to parse the model graph. The dummy_input is required to be placed on CPU.
onnx_export_args – Optional export argument with ONNX-specific overrides provided as a dictionary or OnnxExportApiArgs object. If not provided, defaults to “opset_version” = None, “input_names” = None, “output_names” = None, and for torch version < 1.10.0, “enable_onnx_checker” = False.
propagate_encodings – If True, encoding entries for intermediate ops (when one PyTorch ops results in multiple ONNX nodes) are filled with the same BW and data_type as the output tensor for that series of ops. Defaults to False.
export_to_torchscript – If True, export to torchscript. Export to onnx otherwise. Defaults to False.
use_embedded_encodings – If True, another onnx model embedded with fakequant nodes will be exported
export_model – If True, then ONNX model is exported. When False, only encodings are exported. User should disable (False) this flag only if the corresponding ONNX model already exists in the path specified
filename_prefix_encodings – File name prefix to be used when saving encodings. If None, then user defaults to filename_prefix value
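For illustration, a hedged sketch of an export call that passes ONNX-specific overrides as a plain dictionary; the opset value and filename prefix are placeholders, not required settings, and sim and dummy_input are the objects from the workflow above:
sim.export(
    path="./",
    filename_prefix="quantized_mobilenetv2",
    dummy_input=dummy_input.cpu(),           # dummy_input must be on CPU
    onnx_export_args={"opset_version": 13},  # optional ONNX-specific overrides
)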
- load_encodings(encodings, strict=True, partial=True, requires_grad=None, allow_overwrite=True)
- Parameters:
encodings (Union[Mapping, str, PathLike]) – Encoding dictionary or path to the encoding dictionary json file.
strict (bool) – If True, an error will be thrown if the model doesn’t have a quantizer corresponding to the specified encodings.
partial (bool) – If True, the encoding will be interpreted as a partial encoding, and the dangling quantizers with no corresponding encoding will be kept untouched. Otherwise, the dangling quantizers will be removed from the model.
requires_grad (bool) – Whether or not the quantization parameters loaded from the encodings require gradient computation during training. If None, the requires_grad flag of the quantization parameters will be kept unchanged.
allow_overwrite (bool) – Whether or not the quantization parameters loaded from the encodings can be overwritten by compute_encodings or another load_encodings. If None, whether the quantizer is overwriteable will be kept unchanged.
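As a sketch, a previously exported encodings file could be reloaded as follows; the file path is a placeholder and sim is the QuantizationSimModel from the workflow above:
sim.load_encodings(
    "./quantized_mobilenetv2.encodings",  # assumed path to an encodings JSON file
    strict=False,        # ignore encodings without a matching quantizer
    requires_grad=True,  # allow the loaded scale/offset to be fine-tuned further
)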
Quant Scheme Enum
- class aimet_common.defs.QuantScheme(value)[source]
Enumeration of Quant schemes
- post_training_percentile = 6
For a Tensor, adjusted minimum and maximum values are selected based on the percentile value passed. The Quantization encodings are calculated using the adjusted minimum and maximum value.
- post_training_tf = 1
For a Tensor, the absolute minimum and maximum value of the Tensor are used to compute the Quantization encodings.
- post_training_tf_enhanced = 2
For a Tensor, searches and selects the optimal minimum and maximum value that minimizes the Quantization Noise. The Quantization encodings are calculated using the selected minimum and maximum value.
- training_range_learning_with_tf_enhanced_init = 4
For a Tensor, the encoding values are initialized with the post_training_tf_enhanced scheme. Then, the encodings are learned during training.
- training_range_learning_with_tf_init = 3
For a Tensor, the encoding values are initialized with the post_training_tf scheme. Then, the encodings are learned during training.
Top level APIs
- class aimet_tensorflow.keras.quantsim.QuantizationSimModel(model, quant_scheme='tf_enhanced', rounding_mode='nearest', default_output_bw=8, default_param_bw=8, in_place=False, config_file=None, default_data_type=QuantizationDataType.int)[source]
Implements a mechanism to add quantization simulation ops to a model. This allows for off-target simulation of inference accuracy. It also allows the model to be fine-tuned to counter the effects of quantization.
- Parameters:
model – Model to quantize
quant_scheme (Union[QuantScheme, str]) – Quantization scheme; currently supported schemes are post_training_tf and post_training_tf_enhanced, defaults to post_training_tf_enhanced
rounding_mode (str) – The rounding scheme to use. One of: ‘nearest’ or ‘stochastic’, defaults to ‘nearest’.
default_output_bw (int) – bitwidth to use for activation tensors, defaults to 8
default_param_bw (int) – bitwidth to use for parameter tensors, defaults to 8
in_place (bool) – If True, then the given ‘model’ is modified in-place to add quant-sim nodes. The only suggested use of this option is when the user wants to avoid creating a copy of the model
config_file (Optional[str]) – Path to a config file to use to specify rules for placing quant ops in the model
default_data_type (QuantizationDataType) – Default data type to use for quantizing all layer parameters. Possible options are QuantizationDataType.int and QuantizationDataType.float. Note that the mode default_data_type=QuantizationDataType.float is only supported with default_output_bw=16 and default_param_bw=16
- compute_encodings(forward_pass_callback, forward_pass_callback_args)[source]
Computes encodings for all quantization sim nodes in the model.
- Parameters:
forward_pass_callback – A callback function that is expected to run forward passes on a model. This callback function should use representative data for the forward pass, so the calculated encodings work for all data samples.
forward_pass_callback_args – These argument(s) are passed to the forward_pass_callback as-is. Up to the user to determine the type of this parameter. E.g. could be simply an integer representing the number of data samples to use. Or could be a tuple of parameters or an object representing something more complex.
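For example, forward_pass_callback_args could simply be the number of calibration batches to run. A hedged sketch, assuming the calibration_dataset and sim objects from the Keras workflow above (the callback name and batch count are placeholders):
def pass_calibration_data(model, num_batches):
    # Run representative data through the model so quantizers can observe ranges
    for batch, (inputs, _) in enumerate(calibration_dataset):
        model(inputs)
        if batch + 1 >= num_batches:
            break

sim.compute_encodings(pass_calibration_data, forward_pass_callback_args=64)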
- export(path, filename_prefix, custom_objects=None, convert_to_pb=True)[source]
This method exports the quant-sim model so it is ready to be run on-target. Specifically, the following are saved:
The sim-model is exported to a regular Keras model without any simulation ops
The quantization encodings are exported to a separate JSON-formatted file that can then be imported by the on-target runtime (if desired)
- Parameters:
path – path where to store model pth and encodings
filename_prefix – Prefix to use for filenames of the model pth and encodings files
custom_objects – If there are custom objects to load, Keras needs a dict of them to map them
- load_encodings_to_sim(encoding_file_path)[source]
Loads the saved encodings into the quant sim model.
- Parameters:
encoding_file_path (str) – path from where to load the encodings file
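For example, the encodings exported earlier could be reloaded into a new sim; the filename below follows the export call in the workflow above and the “.encodings” extension is an assumption:
sim.load_encodings_to_sim('/tmp/quantized_mobilenet_v2.encodings')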