AIMET PyTorch AdaRound API
Examples Notebook Link
For an end-to-end notebook showing how to use PyTorch AdaRound, please see here.
Top-level API
aimet_torch.adaround.adaround_weight.Adaround.apply_adaround(model, dummy_input, params, path, filename_prefix, default_param_bw=4, param_bw_override_list=None, ignore_quant_ops_list=None, default_quant_scheme=<QuantScheme.post_training_tf_enhanced: 2>, default_config_file=None)
Returns the model with optimized weight rounding of every module (Conv and Linear) and also saves the corresponding quantization encodings to a separate JSON-formatted file that can then be imported by QuantSim for inference or QAT.
- Parameters
  - model (Module) – Model to AdaRound
  - dummy_input (Union[Tensor, Tuple]) – Dummy input to the model, used to parse the model graph. If the model has more than one input, pass a tuple. The user is expected to place the tensors on the appropriate device.
  - params (AdaroundParameters) – Parameters for AdaRound
  - path (str) – Path where the parameter encodings are stored
  - filename_prefix (str) – Prefix to use for the filename of the encodings file
  - default_param_bw (int) – Default bitwidth (4-31) to use for quantizing layer parameters
  - param_bw_override_list (Optional[List[Tuple[Module, int]]]) – List of tuples. Each tuple is a module and the corresponding parameter bitwidth to be used for that module.
  - ignore_quant_ops_list (Optional[List[Module]]) – Ops listed here are skipped during the quantization needed for AdaRound. Do not specify Conv and Linear modules in this list; doing so will affect accuracy.
  - default_quant_scheme (QuantScheme) – Quantization scheme. Supported options are QuantScheme.post_training_tf and QuantScheme.post_training_tf_enhanced.
  - default_config_file (Optional[str]) – Default configuration file for model quantizers
- Return type
Module
- Returns
Model with Adarounded weights and saves corresponding parameter encodings JSON file at provided path
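To make the interaction between default_param_bw and param_bw_override_list concrete, here is a minimal, framework-free sketch. It is not AIMET's implementation: modules are stood in for by hypothetical name strings, and the helper resolve_param_bitwidths is made up for illustration. It only encodes what the parameter descriptions state: every module gets the default bitwidth unless an override tuple names it, and bitwidths must lie in 4-31.

```python
from typing import Dict, List, Optional, Sequence, Tuple


def resolve_param_bitwidths(
    module_names: Sequence[str],
    default_param_bw: int = 4,
    param_bw_override_list: Optional[List[Tuple[str, int]]] = None,
) -> Dict[str, int]:
    """Hypothetical helper: map each module name to the parameter bitwidth it would use."""

    def check(bw: int) -> int:
        # The documented range for parameter bitwidths is 4-31
        if not 4 <= bw <= 31:
            raise ValueError(f"bitwidth {bw} is outside the supported range 4-31")
        return bw

    default_bw = check(default_param_bw)
    # Each override is a (module, bitwidth) pair; here the module is a name string
    overrides = {name: check(bw) for name, bw in (param_bw_override_list or [])}
    return {name: overrides.get(name, default_bw) for name in module_names}


# Keep the classifier at 8 bits while the rest of the network uses 4-bit weights
print(resolve_param_bitwidths(["conv1", "layer1.0.conv1", "fc"], 4, [("fc", 8)]))
# {'conv1': 4, 'layer1.0.conv1': 4, 'fc': 8}
```

In the real API the override list holds torch.nn.Module objects rather than names; the lookup idea is the same.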
Adaround Parameters
class aimet_torch.adaround.adaround_weight.AdaroundParameters(data_loader, num_batches, default_num_iterations=None, default_reg_param=0.01, default_beta_range=(20, 2), default_warm_start=0.2, forward_fn=None)
Configuration parameters for Adaround
- Parameters
  - data_loader (DataLoader[+T_co]) – Data loader
  - num_batches (int) – Number of batches to be used for AdaRound. A commonly recommended value for this parameter is the smaller of (1) len(data_loader) and (2) ceil(2000/batch_size).
  - default_num_iterations (Optional[int]) – Number of iterations to AdaRound each layer. The default value is 10K for models with 8-bit or wider weights, and 15K for models with weights narrower than 8 bits.
  - default_reg_param (float) – Regularization parameter, trading off between rounding loss and reconstruction loss. Default 0.01
  - default_beta_range (Tuple) – Start and stop beta parameter for annealing of rounding loss (start_beta, end_beta). Default (20, 2)
  - default_warm_start (float) – Warm-up period, during which the rounding loss has zero effect. Default 20% (0.2)
  - forward_fn (Optional[Callable[[Module, Any], Any]]) – Optional adapter function that performs the forward pass given a model and the inputs yielded from the data loader. The function expects the model as its first argument and the inputs to the model as its second argument.
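The interplay of default_reg_param, default_beta_range and default_warm_start can be sketched in plain Python. The sketch follows the rounding regularizer from the AdaRound paper (Nagel et al., 2020), f_reg = sum(1 - |2h(V) - 1|^beta), where h(V) is a rectified sigmoid with zeta = 1.1 and gamma = -0.1. The cosine shape of the beta decay is an assumption made here; only the start/end values, the warm-up fraction and the zero-effect warm-up come from the parameter descriptions above. This is not AIMET's code.

```python
import math
from typing import Sequence, Tuple


def rectified_sigmoid(v: float, zeta: float = 1.1, gamma: float = -0.1) -> float:
    """h(V): a stretched sigmoid clipped to [0, 1]; 1 means round up, 0 means round down."""
    sigmoid = 0.5 * (1.0 + math.tanh(v / 2.0))  # numerically stable sigmoid
    return min(1.0, max(0.0, sigmoid * (zeta - gamma) + gamma))


def beta_at(cur_iter: int, max_iter: int,
            beta_range: Tuple[float, float] = (20, 2), warm_start: float = 0.2) -> float:
    """Anneal beta from start_beta to end_beta after warm-up (cosine shape assumed)."""
    start_beta, end_beta = beta_range
    warm_end = warm_start * max_iter
    rel = (cur_iter - warm_end) / (max_iter - warm_end)  # 0 right after warm-up, 1 at the end
    return end_beta + 0.5 * (start_beta - end_beta) * (1 + math.cos(rel * math.pi))


def rounding_loss(vs: Sequence[float], cur_iter: int, max_iter: int, reg_param: float = 0.01,
                  beta_range: Tuple[float, float] = (20, 2), warm_start: float = 0.2) -> float:
    """Regularizer that pushes every h(V) towards 0 or 1; it contributes nothing during warm-up."""
    if cur_iter < warm_start * max_iter:
        return 0.0
    beta = beta_at(cur_iter, max_iter, beta_range, warm_start)
    return reg_param * sum(1 - abs(2 * rectified_sigmoid(v) - 1) ** beta for v in vs)


print(beta_at(2000, 10000), beta_at(10000, 10000))  # 20.0 2.0
```

With a large beta, 1 - |2h - 1|^beta stays close to 1 for most h, so h can still move freely; as beta shrinks, the term increasingly rewards h values sitting at 0 or 1, which turns the soft rounding into a hard up/down decision.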
Enum Definition
Quant Scheme Enum
class aimet_common.defs.QuantScheme
Enumeration of Quant schemes
- post_training_percentile = 6
  For a Tensor, adjusted minimum and maximum values are selected based on the percentile value passed. The quantization encodings are calculated using the adjusted minimum and maximum values.
- post_training_tf = 1
  For a Tensor, the absolute minimum and maximum values of the Tensor are used to compute the quantization encodings.
- post_training_tf_enhanced = 2
  For a Tensor, searches for and selects the optimal minimum and maximum values that minimize the quantization noise. The quantization encodings are calculated using the selected minimum and maximum values.
- training_range_learning_with_tf_enhanced_init = 4
  For a Tensor, the encoding values are initialized with the post_training_tf_enhanced scheme. Then, the encodings are learned during training.
- training_range_learning_with_tf_init = 3
  For a Tensor, the encoding values are initialized with the post_training_tf scheme. Then, the encodings are learned during training.
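Since an encoding is a scale and an offset derived from a chosen minimum and maximum, the difference between post_training_tf and post_training_percentile can be shown with a small stdlib sketch. The nearest-rank percentile rule, widening the range to include zero, and the offset = round(min / scale) convention are assumptions for illustration; AIMET's exact arithmetic may differ, and post_training_tf_enhanced (a search over candidate ranges) is not shown.

```python
import math
from typing import List, Tuple


def tf_min_max(values: List[float]) -> Tuple[float, float]:
    """post_training_tf: absolute minimum and maximum of the tensor."""
    return min(values), max(values)


def percentile_min_max(values: List[float], percentile: float) -> Tuple[float, float]:
    """post_training_percentile (sketch): clip both tails with nearest-rank percentiles."""
    ordered = sorted(values)
    n = len(ordered)
    hi_idx = max(0, math.ceil(percentile / 100 * n) - 1)
    lo_idx = max(0, math.ceil((100 - percentile) / 100 * n) - 1)
    return ordered[lo_idx], ordered[hi_idx]


def scale_offset(t_min: float, t_max: float, bitwidth: int = 8) -> Tuple[float, int]:
    """Asymmetric uniform encoding for the chosen range."""
    t_min, t_max = min(t_min, 0.0), max(t_max, 0.0)  # keep 0.0 inside the range
    scale = (t_max - t_min) / (2 ** bitwidth - 1) or 1e-8  # avoid a zero scale for all-zero tensors
    return scale, round(t_min / scale)


activations = [float(i) for i in range(100)] + [1000.0]  # one large outlier
print(tf_min_max(activations))                # (0.0, 1000.0)
print(percentile_min_max(activations, 99.0))  # (1.0, 99.0)
```

With the outlier, the tf range stretches the 8-bit grid over 0 to 1000, so each step is about 3.9; the percentile range keeps steps near 0.39 for the bulk of the values at the cost of clipping the outlier.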
Code Example - Adaptive Rounding (AdaRound)
This example shows how to use AIMET to perform Adaptive Rounding (AdaRound).
Load the model
For this example, we are going to load a pretrained ResNet18 model from torchvision. Similarly, you can load any pretrained PyTorch model instead.
import torch
from torchvision import models
model = models.resnet18(pretrained=True).eval()
Prepare the model for Quantization simulation
AIMET quantization simulation requires the user’s model definition to follow certain guidelines. For example, functionals used in the forward pass should be changed to equivalent torch.nn.Module modules. The AIMET user guide lists all of these guidelines. The following ModelPreparer API uses the graph transformation feature (torch.fx) available in PyTorch 1.9+ to automate the model definition changes required to comply with these guidelines.
For more details, refer to the Model Preparer API documentation:
from aimet_torch.model_preparer import prepare_model
prepared_model = prepare_model(model)
Apply AdaRound
We can now apply AdaRound to this model.
Some of the parameters for AdaRound are described below:
data_loader: AdaRound needs a data loader to supply data samples for the layer-by-layer optimization that learns the rounding vectors. Either a training or a validation data loader can be passed in.
num_batches: The number of batches from the data loader used for AdaRound. Typically we want AdaRound to use around 2000 samples, so with a batch size of 32 this translates to about 63 batches (ceil(2000/32)). To speed up execution, this example uses a batch size of 1 and only 4 batches.
default_num_iterations: The number of iterations to AdaRound each layer. The documented default is 10K for 8-bit or wider weights and 15K for narrower weights, and we strongly recommend not reducing it. In this example we use 32 only to speed up execution.
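The arithmetic behind these recommendations can be written down directly. The helper names below are hypothetical; the formulas restate the AdaroundParameters documentation (num_batches as the smaller of len(data_loader) and ceil(2000/batch_size); 10K iterations for 8-bit or wider weights, 15K below 8 bits).

```python
import math


def recommended_num_batches(loader_len: int, batch_size: int, target_samples: int = 2000) -> int:
    """Smaller of len(data_loader) and ceil(target_samples / batch_size)."""
    return min(loader_len, math.ceil(target_samples / batch_size))


def documented_default_iterations(param_bw: int) -> int:
    """10K iterations for 8-bit or wider weights, 15K for narrower weights."""
    return 10_000 if param_bw >= 8 else 15_000


print(recommended_num_batches(loader_len=40_000, batch_size=32))  # 63
print(documented_default_iterations(4))                           # 15000
```

For the 4-bit setting used in this example, the documented default would therefore be 15K iterations; the value 32 is purely for a fast demonstration run.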
from aimet_common.defs import QuantScheme
from aimet_torch.quantsim import QuantizationSimModel
from aimet_torch.adaround.adaround_weight import Adaround, AdaroundParameters
# User action required
# The following line of code is an example of how to use the ImageNet data's training data loader.
# Replace the following line with your own dataset's training data loader.
data_loader = ImageNetDataPipeline.get_train_dataloader()
params = AdaroundParameters(data_loader=data_loader, num_batches=4, default_num_iterations=32,
                            default_reg_param=0.01, default_beta_range=(20, 2))
input_shape = (1, 3, 224, 224)
dummy_input = torch.randn(input_shape)
# Returns model with adarounded weights and their corresponding encodings
adarounded_model = Adaround.apply_adaround(prepared_model, dummy_input, params, path='./',
                                           filename_prefix='resnet18', default_param_bw=4,
                                           default_quant_scheme=QuantScheme.post_training_tf_enhanced,
                                           default_config_file=None)
Create the Quantization Simulation Model
Now we use the AdaRounded model to create a QuantizationSimModel. This means that AIMET inserts fake quantization ops into the model graph and configures them. A few of the parameters are explained here:
default_param_bw: The QuantizationSimModel must be created with the same parameter bitwidth that was used in the apply_adaround() call.
Freezing the parameter encodings: After creating the QuantizationSimModel, the set_and_freeze_param_encodings() API must be called before the compute_encodings() API. While applying AdaRound, the parameter values were rounded up or down based on the encodings that AdaRound created internally. For Quantization Simulation accuracy, it is important to freeze these encodings. If the parameter encodings are NOT frozen, the call to compute_encodings() will alter them, and the Quantization Simulation accuracy will not be correct.
# Use the same quant scheme and parameter bitwidth that were passed to apply_adaround()
sim = QuantizationSimModel(adarounded_model, dummy_input=dummy_input,
                           quant_scheme=QuantScheme.post_training_tf_enhanced,
                           default_param_bw=4, default_output_bw=8)
# Set and freeze the parameter encodings so they use the same quantization grid as AdaRound;
# compute_encodings() is invoked afterwards
sim.set_and_freeze_param_encodings(encoding_path='./resnet18.encodings')
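A plain-Python sketch of why freezing matters. The numbers are hypothetical: AdaRound is assumed to have chosen the integers below on a 4-bit grid with scale 0.14 and offset -6, and a scale of 0.15 stands in for what a fresh compute_encodings() pass might produce. A quantize-dequantize round trip on the frozen grid returns the AdaRounded weights unchanged, whereas the different grid re-rounds them and discards part of what AdaRound learned.

```python
from typing import List


def quant_dequant(x: float, scale: float, offset: int, bitwidth: int) -> float:
    """Snap x to the integer grid, clamp to [0, 2**bitwidth - 1], then map it back."""
    q = max(0, min(2 ** bitwidth - 1, round(x / scale) - offset))
    return (q + offset) * scale


BW, FROZEN_SCALE, FROZEN_OFFSET = 4, 0.14, -6
adaround_ints: List[int] = [-6, 2, 3, 9]  # hypothetical up/down decisions made by AdaRound
adarounded = [i * FROZEN_SCALE for i in adaround_ints]

frozen = [quant_dequant(w, FROZEN_SCALE, FROZEN_OFFSET, BW) for w in adarounded]
regridded = [quant_dequant(w, 0.15, -6, BW) for w in adarounded]  # hypothetical new encoding

print(max(abs(a - b) for a, b in zip(frozen, adarounded)))     # 0.0
print(max(abs(a - b) for a, b in zip(regridded, adarounded)))  # about 0.06
```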
An example user-created function that is called back from compute_encodings()
Even though AIMET has added ‘quantizer’ nodes to the model graph, the model is not ready to be used yet. Before we can use the sim model for inference or training, we need to find appropriate scale/offset quantization parameters for each ‘quantizer’ node. For activation quantization nodes, we need to pass unlabeled data samples through the model to collect range statistics which will then let AIMET calculate appropriate scale/offset quantization parameters. This process is sometimes referred to as calibration. AIMET simply refers to it as ‘computing encodings’.
So we create a routine to pass unlabeled data samples through the model. This should be fairly simple: use the existing train or validation data loader to extract some samples and pass them to the model. We don’t need to compute any loss metric, so we can just ignore the model output for this purpose. A few pointers regarding the data samples:
In practice, we need a very small percentage of the overall data samples for computing encodings. For example, the training dataset for ImageNet has over 1M samples, yet for computing encodings we only need 500 or 1000 samples.
It may be beneficial if the samples used for computing encodings are well distributed. It is not necessary to cover all classes, since we are only looking at the range of values at every layer activation. However, we definitely want to avoid an extreme scenario such as using only ‘dark’ or only ‘light’ samples; for example, using only pictures captured at night might not give ideal results.
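What collecting range statistics means for a single activation quantizer can be sketched without any framework. The MinMaxObserver class and calibrate() helper are hypothetical; each number in a batch stands in for one sample's activation value. The observer tracks the running minimum and maximum (the statistic post_training_tf relies on) and stops once a small sample budget is reached. AIMET's quantizers, particularly with post_training_tf_enhanced, collect richer statistics than this.

```python
from typing import Iterable, Optional, Sequence


class MinMaxObserver:
    """Hypothetical running min/max tracker for one activation tensor."""

    def __init__(self) -> None:
        self.min: Optional[float] = None
        self.max: Optional[float] = None

    def update(self, batch: Sequence[float]) -> None:
        if not batch:
            return
        lo, hi = min(batch), max(batch)
        self.min = lo if self.min is None else min(self.min, lo)
        self.max = hi if self.max is None else max(self.max, hi)


def calibrate(observer: MinMaxObserver, batches: Iterable[Sequence[float]], max_samples: int) -> int:
    """Feed batches until max_samples values have been seen; returns the count seen."""
    seen = 0
    for batch in batches:
        observer.update(batch)
        seen += len(batch)
        if seen >= max_samples:
            break  # a small subset of the data is enough for range statistics
    return seen


obs = MinMaxObserver()
print(calibrate(obs, [[0.1, 0.5], [-0.2, 0.3], [2.0, 9.0]], max_samples=4))  # 4
print(obs.min, obs.max)  # -0.2 0.5
```

The third batch, with its outlier, is never consumed, which also illustrates why a skewed choice of calibration samples directly changes the resulting range.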
def pass_calibration_data(sim_model, use_cuda):
    """
    The user of the QuantizationSimModel API is expected to write this function based on their data set.
    This is not a working function and is provided only as a guideline.
    compute_encodings() calls it as pass_calibration_data(sim.model, forward_pass_callback_args).
    :param sim_model: The quantization simulation model to pass calibration data through
    :param use_cuda: Value of forward_pass_callback_args; True to run the forward passes on the GPU
    :return: None
    """
    # User action required
    # The following line is an example of how to use the ImageNet data's validation data loader.
    # Replace the following line with your own dataset's validation data loader.
    data_loader = ImageNetDataPipeline.get_val_dataloader()
    # User action required
    # For computing the activation encodings, around 1000 unlabelled data samples are required.
    # Edit the following 2 lines based on your batch size.
    # batch_size * max_batch_counter should be 1024
    batch_size = 64
    max_batch_counter = 16
    device = torch.device('cuda' if use_cuda else 'cpu')
    sim_model.eval()
    current_batch_counter = 0
    with torch.no_grad():
        for input_data, target_data in data_loader:
            inputs_batch = input_data.to(device)  # labels are ignored
            sim_model(inputs_batch)
            current_batch_counter += 1
            if current_batch_counter == max_batch_counter:
                break
Compute the Quantization Encodings
Now we call AIMET to use the above routine to pass data through the model and then subsequently compute the quantization encodings. Encodings here refer to scale/offset quantization parameters.
use_cuda = False  # set to True only if the model and data are on a GPU
sim.compute_encodings(pass_calibration_data, forward_pass_callback_args=use_cuda)
Determine Simulated Accuracy
Now the QuantizationSim model is ready to be used for inference. We can pass this model to an evaluation routine, which will give us a simulated quantized accuracy score for the configured bitwidths (4-bit weights in this example, matching default_param_bw).
# User action required: replace with your own dataset's evaluation routine
accuracy = ImageNetDataPipeline.evaluate(sim.model, use_cuda)
print(accuracy)
Export the model
AdaRound has given us an improved model. The next step is to take this model to target. For this purpose, we need to export the model with the updated weights and without the fake quant ops. We also need to export the encodings (scale/offset quantization parameters). AIMET QuantizationSimModel provides an export API for this purpose.
# Export the model: saves the PyTorch model without any simulation nodes and saves an encodings file
# for both activations and parameters in JSON format
sim.export(path='./', filename_prefix='quantized_resnet18', dummy_input=dummy_input.cpu())