Adaptive Rounding (AdaRound)
This notebook contains a working example of AIMET adaptive rounding (AdaRound).
AIMET quantization features typically use the “nearest rounding” technique: each weight value is quantized to the nearest integer value on the quantization grid.
AdaRound optimizes a loss function using unlabeled training data to decide whether to quantize a specific weight to the closer integer value or the farther one. Using AdaRound, quantized accuracy is closer to the FP32 model than with nearest rounding.
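To make the idea concrete, the following toy sketch (plain Python, no AIMET) compares nearest rounding with choosing, for each weight, whether to round down or up so that a tiny linear layer's output error is minimized. The helper names are hypothetical. The exhaustive search is a stand-in that only works for a handful of weights; AdaRound itself solves this problem with a gradient-based, layer-wise optimization rather than brute force.

```python
import itertools
import math


def layer_output(weights, x):
    """Output of a toy one-neuron linear layer: the dot product w . x."""
    return sum(w * xi for w, xi in zip(weights, x))


def output_error(q_weights, weights, inputs):
    """Mean squared difference between quantized and FP outputs over calibration inputs."""
    return sum((layer_output(q_weights, x) - layer_output(weights, x)) ** 2
               for x in inputs) / len(inputs)


def round_nearest(weights, scale):
    """Nearest rounding: each weight independently goes to the closest grid point."""
    return [scale * round(w / scale) for w in weights]


def round_best(weights, scale, inputs):
    """Try every down/up combination and keep the one with the lowest output error.
    Exhaustive search (2**n candidates), so only usable for toy-sized layers."""
    floors = [math.floor(w / scale) for w in weights]
    best_err, best_q = None, None
    for choice in itertools.product((0, 1), repeat=len(weights)):
        q = [scale * (f + c) for f, c in zip(floors, choice)]
        err = output_error(q, weights, inputs)
        if best_err is None or err < best_err:
            best_err, best_q = err, q
    return best_q


weights, scale, inputs = [0.6, 0.6], 1.0, [[1.0, 1.0]]
nearest = round_nearest(weights, scale)      # both weights round up
chosen = round_best(weights, scale, inputs)  # one weight rounds down instead
print(nearest, output_error(nearest, weights, inputs))
print(chosen, output_error(chosen, weights, inputs))
```

Rounding each weight up is individually the closest choice, but together the two overshoot the layer output (2.0 instead of 1.2). Rounding one of them down lands at 1.0, which is much closer. This is the effect that AdaRound exploits.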
Overall flow
The example follows these high-level steps:
Instantiate the example evaluation and training pipeline
Load the FP32 model and evaluate the model to find the baseline FP32 accuracy
Create a quantization simulation model (with fake quantization ops) and evaluate the quantized simulation model
Apply AdaRound and evaluate the simulation model to get a post-AdaRound quantized accuracy score
Note
This notebook does not show state-of-the-art results. For example, it uses a relatively quantization-friendly model (ResNet18). Also, some optimization parameters, such as the number of AdaRound iterations and batches, are chosen to improve execution speed in the notebook.
Dataset
This example does image classification on the ImageNet dataset. If you already have a version of the dataset, use that. Otherwise, download the dataset, for example from https://image-net.org/challenges/LSVRC/2012/index.
Note
The dataloader provided in this example relies on these features of the ImageNet dataset:
Subfolders train for the training samples and val for the validation samples. See the PyTorch dataset description for more details.
One subdirectory per class, and one file per image sample.
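If you are unsure whether a local copy matches this layout, a quick check like the following can catch problems before the dataloader fails. It is a sketch that uses only the Python standard library. The function name check_imagenet_layout is hypothetical and is not part of AIMET or the example utilities.

```python
import os


def check_imagenet_layout(dataset_dir):
    """Verify the train/ and val/ layout described above.

    Returns {split: number_of_class_subdirectories}. Raises if a split folder is
    missing, contains loose files instead of class subdirectories, or has an
    empty class subdirectory.
    """
    counts = {}
    for split in ("train", "val"):
        split_dir = os.path.join(dataset_dir, split)
        if not os.path.isdir(split_dir):
            raise FileNotFoundError(f"missing subfolder: {split_dir}")
        entries = sorted(os.listdir(split_dir))
        loose_files = [e for e in entries if not os.path.isdir(os.path.join(split_dir, e))]
        if loose_files:
            raise ValueError(f"{split_dir} should hold only class subdirectories; found {loose_files[:5]}")
        empty = [e for e in entries if not os.listdir(os.path.join(split_dir, e))]
        if empty:
            raise ValueError(f"empty class subdirectories in {split_dir}: {empty[:5]}")
        counts[split] = len(entries)
    return counts
```

For example, `print(check_imagenet_layout(DATASET_DIR))` should report 1000 class folders for each split on the full ILSVRC2012 dataset.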
Note
To speed up the execution of this notebook, you can use a reduced subset of the ImageNet dataset. For example, the entire ILSVRC2012 dataset has 1000 classes, up to about 1,300 training samples per class (roughly 1.28M in total), and 50 validation samples per class. For the purpose of running this notebook, however, you can reduce the dataset to, say, two samples per class.
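One way to build such a reduced copy is to copy the first few files of every class folder into a new directory with the same layout. The sketch below does this with the standard library. make_imagenet_subset and its parameters are hypothetical helpers, and the source directory is assumed to already follow the train/ and val/ layout described above.

```python
import os
import shutil


def make_imagenet_subset(src_dir, dst_dir, samples_per_class=2, splits=("train", "val")):
    """Copy the first `samples_per_class` files (sorted by name) of every class folder
    from src_dir into dst_dir, preserving the <split>/<class>/<file> layout.
    Returns the number of files copied."""
    copied = 0
    for split in splits:
        split_src = os.path.join(src_dir, split)
        for class_name in sorted(os.listdir(split_src)):
            class_src = os.path.join(split_src, class_name)
            if not os.path.isdir(class_src):
                continue  # skip stray files at the split level
            class_dst = os.path.join(dst_dir, split, class_name)
            os.makedirs(class_dst, exist_ok=True)
            for file_name in sorted(os.listdir(class_src))[:samples_per_class]:
                shutil.copy2(os.path.join(class_src, file_name), class_dst)
                copied += 1
    return copied
```

With 1000 classes and two samples per class in both splits, this copies 4000 files. You would then point DATASET_DIR at the new directory.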
Edit the cell below to specify the directory where the downloaded ImageNet dataset is saved.
[ ]:
DATASET_DIR = '/path/to/dataset/' # Replace this path with a real directory
1. Instantiate the example training and validation pipeline
Use the following validation pipeline for the image classification task. It provides a validation data loader and an evaluate() routine.
Things to note:
AIMET does not put limitations on how the training and validation pipeline is written. AIMET modifies the user’s model to create a QuantizationSim model, which is still a PyTorch model. The QuantizationSim model can be used in place of the original model when doing inference or training.
AIMET does not put limitations on the interface of the evaluate() or train() methods. You should be able to use your existing evaluate and train routines as-is.
[ ]:
import os
import torch
from Examples.common import image_net_config
from Examples.torch.utils.image_net_evaluator import ImageNetEvaluator
from Examples.torch.utils.image_net_data_loader import ImageNetDataLoader
class ImageNetDataPipeline:

    @staticmethod
    def get_val_dataloader() -> torch.utils.data.DataLoader:
        """
        Instantiates a validation dataloader for ImageNet dataset and returns it
        """
        data_loader = ImageNetDataLoader(DATASET_DIR,
                                         image_size=image_net_config.dataset['image_size'],
                                         batch_size=image_net_config.evaluation['batch_size'],
                                         is_training=False,
                                         num_workers=image_net_config.evaluation['num_workers']).data_loader
        return data_loader

    @staticmethod
    def evaluate(model: torch.nn.Module, use_cuda: bool) -> float:
        """
        Given a torch model, evaluates its Top-1 accuracy on the dataset
        :param model: the model to evaluate
        :param use_cuda: whether or not the GPU should be used.
        """
        evaluator = ImageNetEvaluator(DATASET_DIR, image_size=image_net_config.dataset['image_size'],
                                      batch_size=image_net_config.evaluation['batch_size'],
                                      num_workers=image_net_config.evaluation['num_workers'])

        return evaluator.evaluate(model, iterations=None, use_cuda=use_cuda)
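evaluate() reports Top-1 accuracy, which is the fraction of samples whose highest-scoring class matches the label. The framework-free sketch below shows that computation on plain Python lists. It is illustrative only, and the name top1_accuracy is hypothetical. It does not reproduce ImageNetEvaluator, whose exact output format (for example, fraction or percentage) is not shown in this notebook.

```python
def top1_accuracy(logits, labels):
    """Fraction of samples whose arg-max class equals the ground-truth label.

    logits: one list of class scores per sample; labels: one int per sample.
    """
    if len(logits) != len(labels) or not labels:
        raise ValueError("need the same, non-zero number of score rows and labels")
    correct = 0
    for scores, label in zip(logits, labels):
        predicted = max(range(len(scores)), key=lambda i: scores[i])
        correct += int(predicted == label)
    return correct / len(labels)


print(top1_accuracy([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]], [1, 0, 0]))
```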
2. Load the model and evaluate to get a baseline FP32 accuracy score
2.1 Load a pretrained resnet18 model from torchvision.
You can load any pretrained PyTorch model instead.
[ ]:
from torchvision.models import resnet18
model = resnet18(pretrained=True)
AIMET quantization simulation requires the model definition to follow certain guidelines. For example, functionals defined in the forward pass should be changed to the equivalent torch.nn.Module. The AIMET user guide lists all these guidelines.
2.2 Use the following ModelPreparer API call to automate the model definition changes required to comply with the AIMET guidelines.
The call uses the graph transformation feature available in PyTorch 1.9+.
[ ]:
from aimet_torch.model_preparer import prepare_model
model = prepare_model(model)
2.3 Decide whether to place the model on a CPU or CUDA device.
This example uses CUDA if it is available. You can change this logic and force a device placement if needed.
[ ]:
use_cuda = False
if torch.cuda.is_available():
    use_cuda = True
    model.to(torch.device('cuda'))
2.4 Compute the floating point 32-bit (FP32) accuracy of this model using the evaluate() routine.
[ ]:
accuracy = ImageNetDataPipeline.evaluate(model, use_cuda)
print(accuracy)
3. Create a quantization simulation model and determine quantized accuracy
Fold Batch Normalization layers
Before calculating the simulated quantized accuracy using QuantizationSimModel, fold the BatchNormalization (BN) layers into adjacent Convolutional layers. The BN layers that cannot be folded are left as they are.
BN folding improves inference performance on quantized runtimes but can degrade accuracy on these platforms. This step simulates this on-target drop in accuracy.
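The fold itself is a small algebraic rewrite. For one output channel, an eval-mode BN layer computes gamma * (z - mean) / sqrt(var + eps) + beta on the convolution output z. That expression can be absorbed into the channel's weights and bias. The pure-Python sketch below shows the arithmetic, using a hypothetical helper, fold_bn_channel, with scalar per-channel parameters. fold_all_batch_norms() applies a tensor equivalent internally, and its implementation is not reproduced here. In floating point the fold is exact; the accuracy effect appears only once the folded weights are quantized.

```python
import math


def fold_bn_channel(weight, bias, gamma, beta, mean, var, eps=1e-5):
    """Fold eval-mode BatchNorm into one conv output channel.

    weight: flat list of that channel's kernel values; bias: its scalar bias.
    Returns (folded_weight, folded_bias) such that
    conv_folded(x) == gamma * (conv(x) - mean) / sqrt(var + eps) + beta.
    """
    factor = gamma / math.sqrt(var + eps)
    folded_weight = [w * factor for w in weight]
    folded_bias = (bias - mean) * factor + beta
    return folded_weight, folded_bias


def conv_channel(weight, bias, patch):
    """One output value of a convolution: dot product of kernel and input patch, plus bias."""
    return sum(w * p for w, p in zip(weight, patch)) + bias


w, b = [0.5, -1.0, 2.0], 0.1
gamma, beta, mean, var = 1.5, -0.2, 0.3, 4.0
patch = [1.0, 2.0, 0.5]
bn_out = gamma * (conv_channel(w, b, patch) - mean) / math.sqrt(var + 1e-5) + beta
fw, fb = fold_bn_channel(w, b, gamma, beta, mean, var)
print(bn_out, conv_channel(fw, fb, patch))  # equal up to float rounding
```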
3.1 Use the following code to call AIMET to fold the BN layers in-place on the given model.
[ ]:
from aimet_torch.batch_norm_fold import fold_all_batch_norms
_ = fold_all_batch_norms(model, input_shapes=(1, 3, 224, 224))
Create the Quantization Sim Model
3.2 Use AIMET to create a QuantizationSimModel.
In this step, AIMET inserts fake quantization ops in the model graph and configures them.
Key parameters:
Setting default_output_bw to 8 performs all activation quantizations in the model using integer 8-bit precision
Setting default_param_bw to 8 performs all parameter quantizations in the model using integer 8-bit precision
See QuantizationSimModel in the AIMET API documentation for a full explanation of the parameters.
[ ]:
from aimet_common.defs import QuantScheme
from aimet_torch.v1.quantsim import QuantizationSimModel
dummy_input = torch.rand(1, 3, 224, 224)  # Shape for each ImageNet sample is (3 channels) x (224 height) x (224 width)
if use_cuda:
    dummy_input = dummy_input.cuda()

sim = QuantizationSimModel(model=model,
                           quant_scheme=QuantScheme.post_training_tf_enhanced,
                           dummy_input=dummy_input,
                           default_output_bw=8,
                           default_param_bw=8)
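The fake quantization ops simulate integer precision while the model stays in floating point. Each value is rounded onto a bitwidth-limited integer grid, clamped, and mapped back to float. The sketch below shows this quantize-dequantize step for an unsigned asymmetric grid. It is a generic illustration, not AIMET's quantizer code. The parameter names (scale, zero_point), the helper grid_for_range, and the example range [-1, 1] are assumptions chosen for illustration. The sketch also shows why the bitwidth matters: fewer bits give a coarser grid and a larger rounding error.

```python
def fake_quantize(x, scale, zero_point, bitwidth=8):
    """Quantize x onto an unsigned `bitwidth`-bit grid, clamp, then dequantize to float."""
    q_max = 2 ** bitwidth - 1
    q = round(x / scale) + zero_point
    q = min(max(q, 0), q_max)  # clamp to the representable integer range
    return (q - zero_point) * scale


def grid_for_range(min_val, max_val, bitwidth):
    """Hypothetical helper: scale and zero point covering [min_val, max_val], min_val < 0 < max_val."""
    scale = (max_val - min_val) / (2 ** bitwidth - 1)
    return scale, round(-min_val / scale)


for bw in (8, 4):
    scale, zp = grid_for_range(-1.0, 1.0, bw)
    errors = [abs(fake_quantize(v, scale, zp, bw) - v) for v in (-0.9, -0.3, 0.0, 0.25, 0.7)]
    print(f"{bw}-bit: scale={scale:.4f}, max error={max(errors):.4f}")
```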
3.3 Print the model to verify the modifications AIMET has made.
Note that AIMET has added quantization wrapper layers.
Note
Use sim.model to access the modified PyTorch model. By default, AIMET creates a copy of the original model prior to modifying it. There is a parameter to override this behavior.
[ ]:
print(sim.model)
Note also that AIMET has configured the added fake quantization nodes, which AIMET refers to as “quantizers”.
3.4 Print the sim object to see the quantizers.
[ ]:
print(sim)
AIMET has added quantizer nodes to the model graph, but before the sim model can be used for inference or training, scale and offset quantization parameters must be calculated for each quantizer node by passing unlabeled data samples through the model to collect range statistics. This process is sometimes referred to as calibration. AIMET refers to it as “computing encodings”.
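Conceptually, computing an encoding means tracking the range that a tensor takes over the calibration data and turning that range into a scale and offset. The sketch below does this with a plain running min/max, which corresponds roughly to the simpler min/max scheme (QuantScheme.post_training_tf). The QuantScheme.post_training_tf_enhanced scheme used in this notebook instead searches for a range that reduces quantization noise, so its numbers would differ. The helper name compute_encoding and the zero-point convention are assumptions for illustration, not AIMET internals.

```python
def compute_encoding(batches, bitwidth=8):
    """Derive (scale, zero_point) from the running min/max seen over calibration batches.

    batches: iterable of lists of floats (e.g. one layer's activations per batch).
    """
    observed_min, observed_max = float("inf"), float("-inf")
    for batch in batches:  # a single pass over the calibration data
        observed_min = min(observed_min, min(batch))
        observed_max = max(observed_max, max(batch))
    # Extend the range to include 0 so that zero is exactly representable on the grid
    observed_min = min(observed_min, 0.0)
    observed_max = max(observed_max, 0.0)
    scale = (observed_max - observed_min) / (2 ** bitwidth - 1)
    if scale == 0.0:
        scale = 1.0  # all-zero tensor: any positive scale represents it exactly
    zero_point = round(-observed_min / scale)
    return scale, zero_point


print(compute_encoding([[-0.5, 0.25], [1.0, 0.1]]))
```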
3.5 Create a routine to pass unlabeled data samples through the model.
The following code is one way to write a routine that passes unlabeled samples through the model to compute encodings. It uses the existing train or validation data loader to extract samples and pass them to the model. Since there is no need to compute loss metrics, it ignores the model output.
[ ]:
def pass_calibration_data(sim_model, use_cuda):
    data_loader = ImageNetDataPipeline.get_val_dataloader()
    batch_size = data_loader.batch_size

    if use_cuda:
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    sim_model.eval()
    samples = 1000

    batch_cntr = 0
    with torch.no_grad():
        for input_data, target_data in data_loader:

            inputs_batch = input_data.to(device)
            sim_model(inputs_batch)

            batch_cntr += 1
            if (batch_cntr * batch_size) > samples:
                break
A few notes regarding the data samples:
Only a very small fraction of the data samples is needed. For example, the ImageNet training dataset has about 1.28M samples; 500 or 1000 suffice to compute encodings.
The samples should be reasonably well distributed. While it’s not necessary to cover all classes, avoid extreme scenarios like using only dark or only light samples. That is, using only pictures captured at night, say, could skew the results.
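One simple way to keep the calibration set reasonably well distributed is to draw a few samples from every class, rather than taking the first N items of a sorted list. The sketch below does this for a list of (path, label) pairs. select_calibration_samples is a hypothetical helper. Class balance alone does not guarantee variety in lighting or other conditions, so this complements a look at the data rather than replacing it.

```python
import random
from collections import defaultdict


def select_calibration_samples(samples, per_class, seed=0):
    """Return up to `per_class` (path, label) pairs per class, shuffled with a fixed seed."""
    by_class = defaultdict(list)
    for path, label in samples:
        by_class[label].append((path, label))
    rng = random.Random(seed)  # fixed seed: the same subset on every run
    selected = []
    for label in sorted(by_class):
        items = list(by_class[label])
        rng.shuffle(items)
        selected.extend(items[:per_class])
    rng.shuffle(selected)  # mix classes so that each batch sees several of them
    return selected


samples = [(f"img_{c}_{i}.JPEG", c) for c in range(3) for i in range(10)]
subset = select_calibration_samples(samples, per_class=2)
print(len(subset))
```

With the 1000 ImageNet classes and per_class=1, this yields 1000 samples, which is in line with the 500 to 1000 suggested above.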
3.6 Call AIMET to use the routine to pass data through the model and compute the quantization encodings.
Encodings here refer to scale and offset quantization parameters.
[ ]:
sim.compute_encodings(forward_pass_callback=pass_calibration_data,
                      forward_pass_callback_args=use_cuda)
The QuantizationSim model is now ready to be used for inference or training.
3.7 Pass the model to the same evaluation routine as before to calculate a simulated quantized accuracy score for INT8 quantization for comparison with the FP32 score.
[ ]:
accuracy = ImageNetDataPipeline.evaluate(sim.model, use_cuda)
print(accuracy)
4. Apply AdaRound
4.1 Use the code below to apply AdaRound to the BN-folded FP32 model (not to the sim model).
Some key parameters:
data_loader: A training or validation data loader. AdaRound needs a data loader in order to use data samples to learn the rounding vectors.
num_batches: The number of batches of data that AdaRound uses. About 2000 samples is typical; with a batch size of 32, for example, that is about 63 batches. To speed up execution, this example uses a single batch (num_batches=1).
default_num_iterations: The number of optimization iterations applied to each layer. The default value is 10000, and we strongly recommend using at least this number. This example uses 32 to speed up execution.
[ ]:
from aimet_torch.v1.adaround.adaround_weight import Adaround, AdaroundParameters
data_loader = ImageNetDataPipeline.get_val_dataloader()
params = AdaroundParameters(data_loader=data_loader, num_batches=1, default_num_iterations=32)

dummy_input = torch.rand(1, 3, 224, 224)
if use_cuda:
    dummy_input = dummy_input.cuda()

os.makedirs('./output/', exist_ok=True)
ada_model = Adaround.apply_adaround(model, dummy_input, params,
                                    path="output",
                                    filename_prefix='adaround',
                                    default_param_bw=8,
                                    default_quant_scheme=QuantScheme.post_training_tf_enhanced)
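Under the hood, AdaRound learns one continuous variable V per weight. In the formulation of the AdaRound paper (Nagel et al., 2020, "Up or Down? Adaptive Rounding for Post-Training Quantization"), a rectified sigmoid h(V) in [0, 1] is added to floor(w / s). A regularizer then pushes every h(V) toward exactly 0 or 1 as optimization proceeds, so the soft rounding converges to a hard up-or-down decision. The scalar sketch below shows these pieces. It follows the paper's equations rather than AIMET's implementation, omits the clipping to the integer grid limits and the layer-output reconstruction loss, and uses hypothetical function names. default_num_iterations sets how many optimization steps are spent on these variables for each layer.

```python
import math

ZETA, GAMMA = 1.1, -0.1  # stretch limits of the rectified sigmoid, as in the paper


def rectified_sigmoid(v):
    """h(V): a sigmoid stretched to (GAMMA, ZETA) and clipped to [0, 1]."""
    s = 1.0 / (1.0 + math.exp(-v))
    return min(max(s * (ZETA - GAMMA) + GAMMA, 0.0), 1.0)


def init_v(w, scale):
    """Initialize V so that h(V) equals the fractional part of w / scale,
    i.e. the soft-quantized weight starts out equal to the FP weight."""
    rest = w / scale - math.floor(w / scale)
    return -math.log((ZETA - GAMMA) / (rest - GAMMA) - 1.0)


def soft_quantize(w, v, scale):
    """Differentiable stand-in used during optimization: s * (floor(w/s) + h(V))."""
    return scale * (math.floor(w / scale) + rectified_sigmoid(v))


def hard_quantize(w, v, scale):
    """Final decision after optimization: round up where h(V) >= 0.5, else down."""
    return scale * (math.floor(w / scale) + (1 if rectified_sigmoid(v) >= 0.5 else 0))


def rounding_regularizer(vs, beta):
    """sum(1 - |2 h(V) - 1| ** beta): zero only when every h(V) is exactly 0 or 1.
    The paper anneals beta from a high to a low value during optimization."""
    return sum(1.0 - abs(2.0 * rectified_sigmoid(v) - 1.0) ** beta for v in vs)


w, scale = 0.37, 0.1
v = init_v(w, scale)
print(soft_quantize(w, v, scale), hard_quantize(w, v, scale))
print(rounding_regularizer([v], beta=2.0), rounding_regularizer([10.0, -10.0], beta=2.0))
```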
4.2 Quantize the AdaRounded model.
Note
Two important points about the following code:
Parameter bitwidth precision: The QuantizationSimModel must be created with the same parameter bitwidth precision that was used in apply_adaround().
Freezing the parameter encodings: After creating the QuantizationSimModel, you must call set_and_freeze_param_encodings() before calling compute_encodings(). During AdaRound, the parameters are rounded based on the parameter encodings that apply_adaround() creates internally and saves to the adaround.encodings file. To maintain accuracy, freeze these encodings so that the call to compute_encodings() does not alter the parameter encodings and negate the AdaRound accuracy gain.
[ ]:
sim = QuantizationSimModel(model=ada_model,
                           dummy_input=dummy_input,
                           quant_scheme=QuantScheme.post_training_tf_enhanced,
                           default_output_bw=8,
                           default_param_bw=8)

sim.set_and_freeze_param_encodings(encoding_path=os.path.join("output", 'adaround.encodings'))

sim.compute_encodings(forward_pass_callback=pass_calibration_data,
                      forward_pass_callback_args=use_cuda)
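The reason that freezing matters can be seen in a small numeric sketch. After AdaRound, each weight already sits on the grid defined by the parameter encoding it was optimized against. Quantizing it again with that same scale leaves it unchanged. Quantizing it with a freshly computed scale, however, applies a second, plain nearest rounding that moves the values away from the learned up/down choices. The weights and both scales below are made-up values for illustration, and requantize is a hypothetical helper, not an AIMET call.

```python
def requantize(values, scale):
    """Nearest-round each value onto the grid {k * scale}."""
    return [scale * round(v / scale) for v in values]


def max_abs_change(before, after):
    """Largest absolute difference between two equally long lists."""
    return max(abs(a - b) for a, b in zip(before, after))


frozen_scale = 0.1  # hypothetical encoding used during AdaRound
adarounded = [0.0, 0.2, 0.2, -0.3]  # hypothetical AdaRound output: multiples of 0.1
recomputed_scale = 0.15  # hypothetical scale that a fresh compute_encodings() might pick

# Same scale: every value is already on the grid, so nothing moves
print(max_abs_change(adarounded, requantize(adarounded, frozen_scale)))
# Different scale: values are rounded a second time and pick up extra error
print(max_abs_change(adarounded, requantize(adarounded, recomputed_scale)))
```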
4.3 Compute the accuracy of the AdaRounded model.
Evaluate the simulation model as before to determine simulated quantized accuracy.
[ ]:
accuracy = ImageNetDataPipeline.evaluate(sim.model, use_cuda)
print(accuracy)
There might be little gain in accuracy after this limited application of AdaRound. Experiment with the hyperparameters, such as num_batches and default_num_iterations, to get better results.
Next steps
Export the model and encodings.
Export the model with the updated weights but without the fake quant ops.
Export the encodings (scale and offset quantization parameters). AIMET QuantizationSimModel provides an export API for this purpose.
The following code performs these exports.
[ ]:
os.makedirs('./output/', exist_ok=True)
dummy_input = dummy_input.cpu()
sim.export(path='./output/', filename_prefix='resnet18_after_adaround', dummy_input=dummy_input)
For more information
See the AIMET API docs for details about the AIMET APIs and optional parameters.
See the other example notebooks to learn how to use other AIMET post-training quantization techniques.