AIMET PyTorch Quantization SIM API¶
AIMET Quantization Sim requires the model definition to use certain constructs and avoid others. These constraints are described in detail here.
AIMET also includes a Model Validator tool that lets users check their model definition and find constructs that might need to be replaced. The API and usage examples for this tool are also described on the same page.
Top-level API¶
- class aimet_torch.quantsim.QuantizationSimModel(model, dummy_input, quant_scheme=<QuantScheme.post_training_tf_enhanced: 2>, rounding_mode='nearest', default_output_bw=8, default_param_bw=8, in_place=False, config_file=None)¶

  Implements a mechanism to add quantization simulation ops to a model. This allows for off-target simulation of inference accuracy. It also allows the model to be fine-tuned to counter the effects of quantization.

  Constructor

  - Parameters
    - model (Module) – Model to add simulation ops to
    - dummy_input (Union[Tensor, Tuple]) – Dummy input to the model, used to parse the model graph. If the model has more than one input, pass a tuple. The user is expected to place the tensors on the appropriate device.
    - quant_scheme (QuantScheme) – Quantization scheme. Supported options are QuantScheme.post_training_tf, QuantScheme.post_training_tf_enhanced (default), QuantScheme.training_range_learning_with_tf_init and QuantScheme.training_range_learning_with_tf_enhanced_init. Note that passing a string is deprecated.
    - rounding_mode (str) – Rounding mode. Supported options are 'nearest' or 'stochastic'.
    - default_output_bw (int) – Default bitwidth (4-31) to use for quantizing all layer inputs and outputs
    - default_param_bw (int) – Default bitwidth (4-31) to use for quantizing all layer parameters
    - in_place (bool) – If True, the given 'model' is modified in place to add quant-sim nodes. This option is suggested only when the user wants to avoid creating a copy of the model.
    - config_file (Optional[str]) – Path to a config file that specifies rules for how quantizers should be configured, e.g. whether parameter or activation quantization should be symmetric. If not specified, a default config file is used.
Note about Quantization Schemes: AIMET offers multiple quantization schemes:

- Post Training Quantization – The encodings of the model are computed using the TF or TF-Enhanced scheme.
- Trainable Quantization – The min/max of the encodings are learned during training:
  - Range Learning with TF initialization – Uses the TF scheme to initialize the encodings; during training, these encodings are fine-tuned to improve the accuracy of the model.
  - Range Learning with TF-Enhanced initialization – Uses the TF-Enhanced scheme to initialize the encodings; during training, these encodings are fine-tuned to improve the accuracy of the model.
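As a minimal illustration of the constructor, consider the sketch below. The toy model and the 28x28 input shape are placeholders for illustration, not part of the AIMET API:

import torch
from aimet_torch.quantsim import QuantizationSimModel

# Toy single-conv model, purely for illustration
model = torch.nn.Sequential(torch.nn.Conv2d(1, 4, kernel_size=3), torch.nn.ReLU())

# sim.model is a copy of 'model' (in_place=False by default) with
# quantization simulation ops inserted around its layers
sim = QuantizationSimModel(model, dummy_input=torch.rand(1, 1, 28, 28),
                           default_output_bw=8, default_param_bw=8)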
The following API can be used to compute encodings for the model
- QuantizationSimModel.compute_encodings(forward_pass_callback, forward_pass_callback_args)¶

  Computes encodings for all quantization sim nodes in the model. This is also used to find the initial encodings for Range Learning.

  - Parameters
    - forward_pass_callback – A callback function that simply runs forward passes on the model. This callback function should use representative data for the forward pass, so the calculated encodings work for all data samples. The callback internally chooses the number of data samples it wants to use for calculating encodings.
    - forward_pass_callback_args – These argument(s) are passed to forward_pass_callback as-is. It is up to the user to determine the type of this parameter; e.g. it could simply be an integer representing the number of data samples to use, or a tuple of parameters, or an object representing something more complex. If set to None, forward_pass_callback is invoked with no parameters.
  - Returns
    None
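For illustration, a forward-pass callback might look like the following sketch, using the sim object constructed above. In practice the callback should iterate over representative calibration data; the callback name, the random tensors, the input shape and the batch count here are placeholders:

import torch

def pass_calibration_data(model: torch.nn.Module, num_batches: int):
    """Runs forward passes so AIMET can observe the ranges of tensors flowing
    through the sim ops; the outputs themselves are discarded."""
    model.eval()
    with torch.no_grad():
        for _ in range(num_batches):
            model(torch.rand(1, 1, 28, 28))  # substitute real calibration batches

sim.compute_encodings(forward_pass_callback=pass_calibration_data,
                      forward_pass_callback_args=10)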
The following APIs can be used to save and restore the quantized model
- quantsim.save_checkpoint(quant_sim_model, file_path)¶

  This API provides a way for the user to save a checkpoint of the quantized model, which can be loaded at a later point to continue fine-tuning. See also load_checkpoint().

  - Parameters
    - quant_sim_model (QuantizationSimModel) – QuantizationSimModel to save a checkpoint for
    - file_path (str) – Path to the file where you want to save the checkpoint
  - Returns
    None
- quantsim.load_checkpoint(file_path)¶

  Load the quantized model from a checkpoint.

  - Parameters
    - file_path (str) – Path to the file from which to load the checkpoint
  - Return type
    QuantizationSimModel
  - Returns
    A new instance of the QuantizationSimModel created after loading the checkpoint
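A minimal save/restore round trip might look like this sketch, where sim is a QuantizationSimModel created as shown above and the checkpoint path is an arbitrary example:

from aimet_torch import quantsim

# Save a checkpoint part-way through fine-tuning...
quantsim.save_checkpoint(quant_sim_model=sim, file_path='./quantsim_checkpoint.pth')

# ...and restore it later to resume fine-tuning
sim = quantsim.load_checkpoint(file_path='./quantsim_checkpoint.pth')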
The following API can be used to export the model to target
- QuantizationSimModel.export(path, filename_prefix, dummy_input, onnx_export_args=<aimet_torch.onnx_utils.OnnxExportApiArgs object>)¶

  This method exports the quant-sim model so it is ready to be run on-target.

  Specifically, the following are saved:

  1. The sim model is exported to a regular PyTorch model without any simulation ops.
  2. The quantization encodings are exported to a separate JSON-formatted file that can then be imported by the on-target runtime (if desired).
  3. Optionally, an equivalent model in ONNX format is exported. In addition, nodes in the ONNX model are named the same as the corresponding PyTorch module names. This helps match ONNX nodes to their quantization encodings from item 2.

  - Parameters
    - path (str) – Path where to store the model pth and encodings
    - filename_prefix (str) – Prefix to use for the filenames of the model pth and encodings files
    - dummy_input (Union[Tensor, Tuple]) – Dummy input to the model, used to parse the model graph. The dummy_input is required to be placed on the CPU.
    - onnx_export_args (Optional[OnnxExportApiArgs]) – Optional export arguments with ONNX-specific overrides; if not provided, the export happens via the torchscript graph.
  - Returns
    None
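For example, an export call might look like the following sketch. The output path, filename prefix and input shape are assumptions; note that dummy_input must be placed on the CPU:

import torch

# Saves the plain PyTorch model, the JSON encodings file and, with the default
# onnx_export_args, an equivalent ONNX model under './output/'
sim.export(path='./output/', filename_prefix='quantized_model',
           dummy_input=torch.rand(1, 1, 28, 28))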
The encoding format is described in the Quantization Encoding Specification.
Enum Definition¶
Quant Scheme Enum
- class aimet_common.defs.QuantScheme¶

  Enumeration of Quant schemes

  - post_training_tf = 1¶
    TF scheme (absolute min-max)
  - post_training_tf_enhanced = 2¶
    TF-Enhanced scheme (SQNR-based approach to discard outliers)
  - training_range_learning_with_tf_init = 3¶
    Learn appropriate encodings (scale/offset) during QAT. Uses the TF scheme (absolute min/max) to initialize.
  - training_range_learning_with_tf_enhanced_init = 4¶
    Learn appropriate encodings (scale/offset) during QAT. Uses the TF-Enhanced scheme (SQNR) to initialize.
Code Example #1 - Post Training Quantization¶
Required imports
import torch
from aimet_torch.examples import mnist_torch_model
# Quantization related import
from aimet_torch.quantsim import QuantizationSimModel
Evaluation function
def evaluate_model(model: torch.nn.Module, eval_iterations: int, use_cuda: bool = False) -> float:
    """
    This is intended to be the user-defined model evaluation function.
    AIMET requires the above signature. So if the user's eval function does not
    match this signature, please create a simple wrapper.

    Note: Honoring the number of iterations is not absolutely necessary.
    However if all evaluations run over an entire epoch of validation data,
    the runtime for AIMET compression will obviously be higher.

    :param model: Model to evaluate
    :param eval_iterations: Number of iterations to use for evaluation.
            None for entire epoch.
    :param use_cuda: If true, evaluate using gpu acceleration
    :return: single float number (accuracy) representing model's performance
    """
    return .5
Quantize and fine-tune a trained model
def quantize_model(trainer_function):
    model = mnist_torch_model.Net().to(torch.device('cuda'))

    sim = QuantizationSimModel(model, default_output_bw=8, default_param_bw=8,
                               dummy_input=torch.rand(1, 1, 28, 28).cuda(),
                               config_file='../../../TrainingExtensions/common/src/python/aimet_common/quantsim_config/'
                                           'default_config.json')

    # Compute quantization encodings for the MNIST model
    sim.compute_encodings(forward_pass_callback=evaluate_model, forward_pass_callback_args=5)

    # Fine-tune the model's parameters using training
    trainer_function(model=sim.model, epochs=1, num_batches=100, use_cuda=True)

    # Export the model (the dummy input for export must be placed on the CPU)
    sim.export(path='./', filename_prefix='quantized_mnist', dummy_input=torch.rand(1, 1, 28, 28))
Code Example #2 - Trainable Quantization¶
Required imports
import torch
# Quantization Aware Training related imports
from aimet_common.defs import QuantScheme
import aimet_torch.examples.mnist_torch_model as mnist_model
from aimet_torch.quantsim import QuantizationSimModel
Evaluation function to be used for computing initial encodings
def evaluate_model(model: torch.nn.Module, eval_iterations: int, use_cuda: bool = False) -> float:
    """
    This is intended to be the user-defined model evaluation function.
    AIMET requires the above signature. So if the user's eval function does not
    match this signature, please create a simple wrapper.

    Note: Honoring the number of iterations is not absolutely necessary.
    However if all evaluations run over an entire epoch of validation data,
    the runtime for AIMET compression will obviously be higher.

    :param model: Model to evaluate
    :param eval_iterations: Number of iterations to use for evaluation.
            None for entire epoch.
    :param use_cuda: If true, evaluate using gpu acceleration
    :return: single float number (accuracy) representing model's performance
    """
    return .5
Quantize and fine-tune a trained model to learn min/max ranges
def quantization_aware_training_range_learning(forward_pass):
    model = mnist_model.Net().to(device='cuda')

    sim = QuantizationSimModel(model, quant_scheme=QuantScheme.training_range_learning_with_tf_init,
                               default_output_bw=8, default_param_bw=8,
                               dummy_input=torch.rand(1, 1, 28, 28).cuda())

    # Initialize the model with encodings
    sim.compute_encodings(forward_pass, forward_pass_callback_args=5)

    # Train the model to fine-tune the encodings
    sim.model.train()
    mnist_model.train(sim.model, epochs=1, num_batches=100, use_cuda=True)
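    # Sketch (not part of the original example): after training, the fine-tuned
    # model and its learned encodings can be exported exactly as in Code
    # Example #1; export requires a CPU dummy input
    sim.export(path='./', filename_prefix='quantized_mnist', dummy_input=torch.rand(1, 1, 28, 28))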