AIMET PyTorch Quantization SIM API
User Guide Link
To learn more about Quantization Simulation, please see Quantization Sim.
Examples Notebook Link
For an end-to-end notebook showing how to use PyTorch Quantization-Aware Training, please see here.
Guidelines
AIMET Quantization Sim requires the PyTorch model definition to follow certain guidelines. These are described in detail in the Model Guidelines.
AIMET provides a Model Preparer API that lets users prepare a PyTorch model for AIMET quantization features. The API and usage examples are described in detail in the Model Preparer API documentation.
AIMET also includes a Model Validator utility that lets users check their model definition. For the API and usage examples of this utility, see the Model Validator API documentation.
Top-level API
class aimet_torch.quantsim.QuantizationSimModel(model, dummy_input, quant_scheme=<QuantScheme.post_training_tf_enhanced: 2>, rounding_mode='nearest', default_output_bw=8, default_param_bw=8, in_place=False, config_file=None, default_data_type=<QuantizationDataType.int: 1>)
Implements a mechanism to add quantization simulation ops to a model. This allows for off-target simulation of inference accuracy. It also allows the model to be fine-tuned to counter the effects of quantization.
Constructor for QuantizationSimModel.
- Parameters
model (Module) – Model to add simulation ops to
dummy_input (Union[Tensor, Tuple]) – Dummy input to the model. Used to parse the model graph. If the model has more than one input, pass a tuple. The user is expected to place the tensors on the appropriate device.
quant_scheme (Union[str, QuantScheme]) – Quantization scheme. The quantization scheme is used to compute the quantization encodings. There are multiple schemes available. Please refer to the QuantScheme enum definition.
rounding_mode (str) – Rounding mode. Supported options are ‘nearest’ or ‘stochastic’
default_output_bw (int) – Default bitwidth (4-31) to use for quantizing all layer inputs and outputs
default_param_bw (int) – Default bitwidth (4-31) to use for quantizing all layer parameters
in_place (bool) – If True, the given ‘model’ is modified in-place to add quant-sim nodes. The only suggested use of this option is when the user wants to avoid creating a copy of the model
config_file (Optional[str]) – Path to the configuration file for model quantizers
default_data_type (QuantizationDataType) – Default data type to use for quantizing all layer inputs, outputs and parameters. Possible options are QuantizationDataType.int and QuantizationDataType.float. Note that default_data_type=QuantizationDataType.float is only supported with default_output_bw=16 and default_param_bw=16
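To make the bitwidth and rounding_mode parameters more concrete, here is a minimal pure-Python sketch of what a quantization simulation op does to a single value: it snaps the value onto a uniform grid and maps it back to floating point. The function name fake_quantize and its arguments are hypothetical and chosen for illustration; this is not AIMET's implementation, which operates on tensors.

```python
import math
import random


def fake_quantize(x, x_min, x_max, bitwidth=8, rounding_mode="nearest", rng=None):
    """Quantize x onto a uniform grid over [x_min, x_max], then dequantize it.

    Illustrative only: a hypothetical helper, not part of AIMET.
    """
    if rounding_mode not in ("nearest", "stochastic"):
        raise ValueError("rounding_mode must be 'nearest' or 'stochastic'")
    if x_max <= x_min:
        raise ValueError("x_max must be greater than x_min")

    num_steps = 2 ** bitwidth - 1           # e.g. 255 steps for 8 bits
    scale = (x_max - x_min) / num_steps     # width of one quantization step

    x = min(max(x, x_min), x_max)           # clamp to the representable range
    position = (x - x_min) / scale          # fractional grid position

    if rounding_mode == "nearest":
        level = math.floor(position + 0.5)
    else:
        # Stochastic rounding: round up with probability equal to the
        # fractional part, so the result is unbiased on average.
        rng = rng or random
        lower = math.floor(position)
        level = lower + (1 if rng.random() < position - lower else 0)

    level = min(max(level, 0), num_steps)
    return x_min + level * scale


step = 2.0 / 255
nearest = fake_quantize(0.26, -1.0, 1.0, bitwidth=8)
stochastic = fake_quantize(0.26, -1.0, 1.0, bitwidth=8,
                           rounding_mode="stochastic", rng=random.Random(0))
clamped = fake_quantize(5.0, -1.0, 1.0, bitwidth=8)
```

With nearest rounding the result is at most half a step away from the input; with stochastic rounding it is one of the two neighbouring grid points. Values outside the range are clamped to its edge.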
The following API can be used to Compute Encodings for Model
QuantizationSimModel.compute_encodings(forward_pass_callback, forward_pass_callback_args)
Computes encodings for all quantization sim nodes in the model. It is also used to find initial encodings for Range Learning.
- Parameters
forward_pass_callback – A callback function that simply runs forward passes on the model. This callback function should use representative data for the forward pass, so the calculated encodings work for all data samples. This callback internally chooses the number of data samples it wants to use for calculating encodings.
forward_pass_callback_args – These argument(s) are passed to the forward_pass_callback as-is. It is up to the user to determine the type of this parameter. For example, it could simply be an integer representing the number of data samples to use, a tuple of parameters, or an object representing something more complex. If set to None, forward_pass_callback will be invoked with no parameters.
- Returns
None
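The hand-off between compute_encodings and the callback can be sketched without AIMET installed. The following is a toy illustration under stated assumptions: RangeTracker and run_calibration are hypothetical stand-ins for the sim model and for the hand-off step, and the data is made up. The callback follows the (model, args) shape used by pass_calibration_data in the code example further down this page.

```python
class RangeTracker:
    """Hypothetical stand-in for a sim model: records the value range it sees."""

    def __init__(self):
        self.seen_min = float("inf")
        self.seen_max = float("-inf")
        self.batches = 0

    def __call__(self, batch):
        self.seen_min = min(self.seen_min, min(batch))
        self.seen_max = max(self.seen_max, max(batch))
        self.batches += 1


def pass_calibration_data(sim_model, num_batches):
    """Callback: feeds at most num_batches batches and ignores the outputs."""
    toy_data = [[0.1, -0.4], [2.5, 0.0], [-3.0, 1.2], [9.9, -9.9]]
    for batch in toy_data[:num_batches]:
        sim_model(batch)


def run_calibration(sim_model, forward_pass_callback, forward_pass_callback_args):
    """Mimics only the hand-off: the args object is passed through unchanged."""
    forward_pass_callback(sim_model, forward_pass_callback_args)
    return sim_model.seen_min, sim_model.seen_max


tracker = RangeTracker()
observed_range = run_calibration(tracker, pass_calibration_data, 3)
```

Here the args object is a plain integer (the number of batches), so only the first three toy batches contribute to the observed range.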
The following APIs can be used to save and restore the quantized model
quantsim.save_checkpoint(quant_sim_model, file_path)
This API provides a way to save a checkpoint of the quantized model, which can be loaded at a later point, e.g. to continue fine-tuning. See also load_checkpoint().
- Parameters
quant_sim_model (QuantizationSimModel) – QuantizationSimModel to save the checkpoint for
file_path (str) – Path to the file where you want to save the checkpoint
- Returns
None
quantsim.load_checkpoint(file_path)
Load the quantized model
- Parameters
file_path (str) – Path to the file from which to load the checkpoint
- Return type
QuantizationSimModel
- Returns
A new instance of the QuantizationSimModel created after loading the checkpoint
The following API can be used to Export the Model to target
QuantizationSimModel.export(path, filename_prefix, dummy_input, onnx_export_args=None, propagate_encodings=False, export_to_torchscript=False)
This method exports out the quant-sim model so it is ready to be run on-target.
Specifically, the following are saved:
1. The sim-model is exported to a regular PyTorch model without any simulation ops
2. The quantization encodings are exported to a separate JSON-formatted file that can then be imported by the on-target runtime (if desired)
3. Optionally, an equivalent model in ONNX format is exported. In addition, nodes in the ONNX model are named the same as the corresponding PyTorch module names. This helps with matching ONNX nodes to their quant encodings from #2.
- Parameters
path (str) – Path where to store the model pth and encodings
filename_prefix (str) – Prefix to use for filenames of the model pth and encodings files
dummy_input (Union[Tensor, Tuple]) – Dummy input to the model. Used to parse the model graph. The dummy_input is required to be placed on CPU.
onnx_export_args (Union[OnnxExportApiArgs, Dict, None]) – Optional export argument with ONNX-specific overrides, provided as a dictionary or an OnnxExportApiArgs object. If not provided, defaults to “opset_version” = None, “input_names” = None, “output_names” = None, and for torch version < 1.10.0, “enable_onnx_checker” = False.
propagate_encodings (bool) – If True, encoding entries for intermediate ops (when one PyTorch op results in multiple ONNX nodes) are filled with the same BW and data_type as the output tensor for that series of ops. Defaults to False.
export_to_torchscript (bool) – If True, export to torchscript. Export to ONNX otherwise. Defaults to False.
Encoding format is described in the Quantization Encoding Specification
Enum Definition
Quant Scheme Enum
class aimet_common.defs.QuantScheme
Enumeration of Quant schemes
post_training_percentile = 6
For a Tensor, adjusted minimum and maximum values are selected based on the percentile value passed. The Quantization encodings are calculated using the adjusted minimum and maximum value.
post_training_tf = 1
For a Tensor, the absolute minimum and maximum value of the Tensor are used to compute the Quantization encodings.
post_training_tf_enhanced = 2
For a Tensor, searches and selects the optimal minimum and maximum value that minimizes the Quantization Noise. The Quantization encodings are calculated using the selected minimum and maximum value.
training_range_learning_with_tf_enhanced_init = 4
For a Tensor, the encoding values are initialized with the post_training_tf_enhanced scheme. Then, the encodings are learned during training.
training_range_learning_with_tf_init = 3
For a Tensor, the encoding values are initialized with the post_training_tf scheme. Then, the encodings are learned during training.
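As a rough illustration of how the choice of range affects the encodings, the sketch below contrasts a plain min/max range (in the spirit of post_training_tf) with a percentile-clipped range (in the spirit of post_training_percentile), and derives a scale/offset pair from each. All function names are hypothetical, the symmetric tail clipping is an assumption, and the scale/offset formula is a common asymmetric-quantization convention rather than a statement of AIMET's exact computation. The noise-minimizing search of post_training_tf_enhanced is not covered.

```python
def min_max_range(values):
    """Use the absolute minimum and maximum of the observed values."""
    return min(values), max(values)


def percentile_range(values, percentile):
    """Clip both tails so that roughly the central `percentile` percent remains.

    Assumption for illustration: tails are clipped symmetrically.
    """
    ordered = sorted(values)
    tail = (100.0 - percentile) / 100.0
    k = round(tail * (len(ordered) - 1))    # number of samples cut per tail
    return ordered[k], ordered[len(ordered) - 1 - k]


def compute_encoding(range_min, range_max, bitwidth=8):
    """Turn a (min, max) range into a scale and an integer offset."""
    # Extend the range to include zero so that 0.0 sits on the grid.
    range_min = min(range_min, 0.0)
    range_max = max(range_max, 0.0)
    if range_max == range_min:
        raise ValueError("range must have non-zero width")
    scale = (range_max - range_min) / (2 ** bitwidth - 1)
    offset = round(range_min / scale)
    return scale, offset


samples = [float(i) for i in range(101)]     # toy activations 0.0 .. 100.0
full_range = min_max_range(samples)
clipped_range = percentile_range(samples, 90.0)
encoding = compute_encoding(-1.0, 254.0, bitwidth=8)
```

A clipped range yields a smaller scale (finer resolution) at the cost of saturating outliers, which is the trade-off the percentile scheme exposes.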
Code Example - Quantization Aware Training (QAT)
This example shows how to use AIMET to perform QAT (Quantization-Aware Training). QAT is an AIMET feature that adds quantization simulation ops (sometimes also called fake quantization ops) to a trained ML model and uses a standard training pipeline to fine-tune or train the model for a few epochs. The resulting model should show improved accuracy on quantized ML accelerators.
In the mode shown here, simply referred to as QAT, quantization parameters such as per-tensor scale/offsets for activations are computed once. During fine-tuning, the model weights are updated to minimize the effects of quantization in the forward pass, while the quantization parameters are kept constant.
Required imports
import torch
import torch.cuda
Load the PyTorch Model
For this example, we are going to load a pretrained ResNet18 model from torchvision. Similarly, you can load any pretrained PyTorch model instead.
from torchvision.models import resnet18
model = resnet18(pretrained=True)
model = model.cuda()
Prepare the model for Quantization simulation
AIMET quantization simulation requires the user’s model definition to follow certain guidelines. For example, functionals defined in the forward pass should be changed to the equivalent torch.nn.Module. The AIMET user guide lists all these guidelines. The following ModelPreparer API uses the graph transformation feature available in PyTorch 1.9 and later, and automates the model definition changes required to comply with the above guidelines.
For more details, please refer to the Model Preparer API.
from aimet_torch.model_preparer import prepare_model
prepared_model = prepare_model(model)
Create the Quantization Simulation Model
Now we use AIMET to create a QuantizationSimModel. This means that AIMET will insert fake quantization ops in the model graph and configure them. The constructor parameters are described in the Top-level API section above.
from aimet_common.defs import QuantScheme
from aimet_torch.quantsim import QuantizationSimModel
input_shape = (1, 3, 224, 224)
dummy_input = torch.randn(input_shape).cuda()
quant_sim = QuantizationSimModel(prepared_model, dummy_input=dummy_input,
                                 quant_scheme=QuantScheme.post_training_tf_enhanced,
                                 default_param_bw=8, default_output_bw=8,
                                 config_file='../../TrainingExtensions/common/src/python/aimet_common/quantsim_config/'
                                             'default_config.json')
An example user-created function that is called back from compute_encodings()
Even though AIMET has added ‘quantizer’ nodes to the model graph, the model is not ready to be used yet. Before we can use the sim model for inference or training, we need to find appropriate scale/offset quantization parameters for each ‘quantizer’ node. For activation quantization nodes, we need to pass unlabeled data samples through the model to collect range statistics which will then let AIMET calculate appropriate scale/offset quantization parameters. This process is sometimes referred to as calibration. AIMET simply refers to it as ‘computing encodings’.
So we create a routine to pass unlabeled data samples through the model. This should be fairly simple - use the existing train or validation data loader to extract some samples and pass them to the model. We don’t need to compute any loss metric etc. So we can just ignore the model output for this purpose. A few pointers regarding the data samples
In practice, we need a very small percentage of the overall data samples for computing encodings. For example, the training dataset for ImageNet has 1M samples. For computing encodings we only need 500 or 1000 samples.
It may be beneficial if the samples used for computing encoding are well distributed. It’s not necessary that all classes need to be covered etc. since we are only looking at the range of values at every layer activation. However, we definitely want to avoid an extreme scenario like all ‘dark’ or ‘light’ samples are used - e.g. only using pictures captured at night might not give ideal results.
def pass_calibration_data(sim_model, forward_pass_args=None):
    """
    The user of the QuantizationSimModel API is expected to write this function based on their data set.
    This is not a working function and is provided only as a guideline.

    :param sim_model: model with quantization simulation ops to run forward passes on
    :param forward_pass_args: other arguments for the forward pass (unused in this example)
    :return: None
    """
    # User action required
    # The following line of code is an example of how to use the ImageNet data's validation data loader.
    # Replace the following line with your own dataset's validation data loader.
    data_loader = ImageNetDataPipeline.get_val_dataloader()

    # User action required
    # For computing the activation encodings, around 1000 unlabelled data samples are required.
    # Edit the following 2 lines based on your batch size.
    # batch_size * max_batch_counter should be 1024
    batch_size = 64
    max_batch_counter = 16

    # Run the samples on the same device as the model
    device = next(sim_model.parameters()).device

    sim_model.eval()

    current_batch_counter = 0
    with torch.no_grad():
        for input_data, target_data in data_loader:

            inputs_batch = input_data.to(device)  # labels are ignored
            sim_model(inputs_batch)

            current_batch_counter += 1
            if current_batch_counter == max_batch_counter:
                break
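The sample-count guidance above (a small, well-spread subset, with batch_size * max_batch_counter close to 1024) can be sketched in a few lines of plain Python. The helper name calibration_plan and its arguments are hypothetical. Picking uniformly random indices is one simple way to avoid a skewed subset, not a requirement stated by AIMET.

```python
import math
import random


def calibration_plan(dataset_size, batch_size, target_samples=1024, seed=0):
    """Return (max_batch_counter, sample_indices) for a calibration run.

    Hypothetical helper for illustration; not part of AIMET.
    """
    target = min(target_samples, dataset_size)
    max_batch_counter = math.ceil(target / batch_size)

    # Draw distinct indices uniformly at random so the subset is spread over
    # the whole dataset instead of, say, the first few classes on disk.
    rng = random.Random(seed)
    num_samples = min(max_batch_counter * batch_size, dataset_size)
    sample_indices = rng.sample(range(dataset_size), num_samples)
    return max_batch_counter, sample_indices


max_batch_counter, sample_indices = calibration_plan(1_000_000, batch_size=64)
```

For a dataset of one million samples and a batch size of 64 this gives 16 batches, matching the values hard-coded in the function above.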
Compute the Quantization Encodings
Now we call AIMET to use the above routine to pass data through the model and then subsequently compute the quantization encodings. Encodings here refer to scale/offset quantization parameters.
quant_sim.compute_encodings(pass_calibration_data, forward_pass_callback_args=None)
Finetune the Quantization Simulation Model
To perform quantization aware training (QAT), we simply train the model for a few more epochs (typically 15-20). As with any training job, hyper-parameters need to be searched for optimal results. Good starting points are to use a learning rate on the same order as the ending learning rate when training the original model, and to drop the learning rate by a factor of 10 every 5 epochs or so.
For the purpose of this example, we are going to train only for 1 epoch. But feel free to change these parameters as you see fit.
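The step schedule suggested above (drop the learning rate by a factor of 10 every 5 epochs or so) can be written down as a small function. This is a sketch under assumptions: the name stepped_learning_rate is hypothetical, and the milestones (5, 10) are taken from the learning_rate_schedule value used in this example, on the assumption that they denote the epochs at which the rate drops.

```python
def stepped_learning_rate(initial_lr, epoch, milestones=(5, 10), factor=0.1):
    """Learning rate for a given epoch under a step-decay schedule.

    The rate is multiplied by `factor` once for every milestone reached.
    """
    drops = sum(1 for milestone in milestones if epoch >= milestone)
    return initial_lr * factor ** drops


# Learning rate per epoch for a 15-epoch fine-tuning run
schedule = [stepped_learning_rate(5e-7, epoch) for epoch in range(15)]
```

Epochs 0 to 4 use the initial rate, epochs 5 to 9 use one tenth of it, and epochs 10 to 14 use one hundredth.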
# Run on the GPU when one is available
use_cuda = torch.cuda.is_available()

# User action required
# The following line of code illustrates that the model is getting finetuned.
# Replace the following finetune() function with your pipeline's finetune() function.
ImageNetDataPipeline.finetune(quant_sim.model, epochs=1, learning_rate=5e-7, learning_rate_schedule=[5, 10],
                              use_cuda=use_cuda)

# Determine simulated accuracy
accuracy = ImageNetDataPipeline.evaluate(quant_sim.model, use_cuda)
print(accuracy)
Export the model
We now have an improved model after QAT. The next step is to take this model to target. For this purpose, we need to export the model with the updated weights and without the fake quant ops. We also need to export the encodings (scale/offset quantization parameters) that were updated during training, since we employed QAT. AIMET QuantizationSimModel provides an export API for this purpose.
# Export the model which saves pytorch model without any simulation nodes and saves encodings file for both
# activations and parameters in JSON format
quant_sim.export(path='./', filename_prefix='quantized_resnet18', dummy_input=dummy_input.cpu())