Quantization-Aware Training

This notebook shows a working code example of how to use AIMET to perform quantization-aware training (QAT). QAT is an AIMET feature that adds quantization simulation ops (sometimes called fake quantization ops) to a trained ML model and uses a standard training pipeline to fine-tune or train the model for a few epochs. The resulting model should show improved accuracy on quantized ML accelerators.

AIMET supports two different types of QAT:

  1. QAT (simply referred to as QAT): quantization parameters, such as per-tensor scale/offsets for activations, are computed once. During fine-tuning, the model weights are updated to minimize the effects of quantization in the forward pass, while the quantization parameters are kept constant.

  2. QAT with range-learning: quantization parameters, such as per-tensor scale/offsets for activations, are computed initially. Then both the quantization parameters and the model weights are jointly updated during fine-tuning to minimize the effects of quantization in the forward pass.
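As a rough illustration of what a quantization simulation op does with a scale and an offset, the following minimal sketch quantizes a single float to an unsigned integer grid and maps it back. The function name `fake_quantize`, the example parameter values and the sign convention used for the offset are assumptions made for this sketch; it is not AIMET's implementation.

```python
def fake_quantize(x, scale, offset, bitwidth=8):
    """Quantize x to an unsigned integer grid, then map it back to float."""
    q_min, q_max = 0, 2 ** bitwidth - 1
    # Round to the nearest grid point and shift by the offset (zero point).
    q = round(x / scale) + offset
    # Clamp values that fall outside the representable integer range.
    q = max(q_min, min(q_max, q))
    # Dequantize: the result is a float that lies on the quantization grid.
    return (q - offset) * scale


scale, offset = 0.1, 128   # hypothetical per-tensor parameters
for value in (0.234, -3.06, 50.0):
    print(value, "->", fake_quantize(value, scale, offset))
```

Values inside the range are snapped to the nearest grid point, while values outside it (50.0 here) saturate at the edge of the grid. In plain QAT the scale and offset of the activations stay fixed during fine-tuning; in QAT with range-learning they would be trained along with the weights.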

This notebook specifically shows a working code example for #1 above. You can find a separate notebook for #2 in the same folder.

Overall flow

This notebook covers the following:

  1. Instantiate the example evaluation and training pipeline

  2. Load the FP32 model and evaluate the model to find the baseline FP32 accuracy

  3. Create a quantization simulation model (with fake quantization ops inserted) and evaluate this simulation model to get a quantized accuracy score

  4. Fine-tune the quantization simulation model and evaluate the simulation model to get a post-finetuned quantized accuracy score

What this notebook is not

  • This notebook is not designed to show state-of-the-art QAT results. For example, it uses a relatively quantization-friendly model (ResNet18). Also, some optimization parameters, like the number of epochs to fine-tune, are deliberately chosen so that the notebook executes more quickly.


Dataset

This notebook relies on the ImageNet dataset for the task of image classification. If you already have a version of the dataset readily available, please use that. Otherwise, please download the dataset from an appropriate location (e.g. https://image-net.org/challenges/LSVRC/2012/index.php#).

Note1: The ImageNet dataset typically has the following characteristics, and the dataloader provided in this example notebook relies on them:

  • Subfolders ‘train’ for the training samples and ‘val’ for the validation samples. Please see the PyTorch dataset description for more details.

  • A subdirectory per class, and a file per image sample.

Note2: To speed up the execution of this notebook, you may use a reduced subset of the ImageNet dataset. For example, the entire ILSVRC2012 dataset has 1000 classes, 1000 training samples per class and 50 validation samples per class. For the purpose of running this notebook, you could reduce the dataset to, say, 2 samples per class. This exercise is left up to the reader and is not necessary.
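One possible way to build such a reduced subset is sketched below. It assumes the ‘train’/‘val’ layout with one subdirectory per class described in Note1; the helper name `make_subset` and the destination path are hypothetical and not part of AIMET or the example utilities.

```python
import os
import shutil


def make_subset(src_dir, dst_dir, samples_per_class=2):
    """Copy the first few images of every class folder into dst_dir."""
    for split in ("train", "val"):
        split_dir = os.path.join(src_dir, split)
        if not os.path.isdir(split_dir):
            continue  # skip splits that are not present
        for class_name in sorted(os.listdir(split_dir)):
            class_dir = os.path.join(split_dir, class_name)
            if not os.path.isdir(class_dir):
                continue  # ignore stray files next to the class folders
            out_dir = os.path.join(dst_dir, split, class_name)
            os.makedirs(out_dir, exist_ok=True)
            # Sort so that the selection is reproducible between runs.
            for file_name in sorted(os.listdir(class_dir))[:samples_per_class]:
                shutil.copy2(os.path.join(class_dir, file_name), out_dir)


# Example call (hypothetical paths):
# make_subset('/path/to/dataset/', '/path/to/dataset_subset/', samples_per_class=2)
```

DATASET_DIR would then point at the destination directory instead of the full dataset.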

Edit the cell below and specify the directory where the downloaded ImageNet dataset is saved.

[ ]:
DATASET_DIR = '/path/to/dataset/'         # Please replace this with a real directory

1. Example evaluation and training pipeline

The following is an example training and validation loop for this image classification task.

  • Does AIMET have any limitations on how the training and validation pipelines are written? Not really. We will see later that AIMET will modify the user’s model to create a QuantizationSim model, which is still a PyTorch model. This QuantizationSim model can be used in place of the original model when doing inference or training.

  • Does AIMET put any limitation on the interface of the evaluate() or train() methods? Not really. You should be able to use your existing evaluate and train routines as-is.

[ ]:
import os
import torch
from Examples.common import image_net_config
from Examples.torch.utils.image_net_evaluator import ImageNetEvaluator
from Examples.torch.utils.image_net_trainer import ImageNetTrainer
from Examples.torch.utils.image_net_data_loader import ImageNetDataLoader

class ImageNetDataPipeline:

    @staticmethod
    def get_val_dataloader() -> torch.utils.data.DataLoader:
        """
        Instantiates a validation dataloader for ImageNet dataset and returns it
        """
        data_loader = ImageNetDataLoader(DATASET_DIR,
                                         image_size=image_net_config.dataset['image_size'],
                                         batch_size=image_net_config.evaluation['batch_size'],
                                         is_training=False,
                                         num_workers=image_net_config.evaluation['num_workers']).data_loader
        return data_loader

    @staticmethod
    def evaluate(model: torch.nn.Module, use_cuda: bool) -> float:
        """
        Given a torch model, evaluates its Top-1 accuracy on the dataset
        :param model: the model to evaluate
        :param use_cuda: whether or not the GPU should be used.
        """
        evaluator = ImageNetEvaluator(DATASET_DIR, image_size=image_net_config.dataset['image_size'],
                                      batch_size=image_net_config.evaluation['batch_size'],
                                      num_workers=image_net_config.evaluation['num_workers'])

        return evaluator.evaluate(model, iterations=None, use_cuda=use_cuda)

    @staticmethod
    def finetune(model: torch.nn.Module, epochs, learning_rate, learning_rate_schedule, use_cuda):
        """
        Given a torch model, finetunes the model to improve its accuracy
        :param model: the model to finetune
        :param epochs: The number of epochs used during the finetuning step.
        :param learning_rate: The learning rate used during the finetuning step.
        :param learning_rate_schedule: The learning rate schedule used during the finetuning step.
        :param use_cuda: whether or not the GPU should be used.
        """
        trainer = ImageNetTrainer(DATASET_DIR, image_size=image_net_config.dataset['image_size'],
                                  batch_size=image_net_config.train['batch_size'],
                                  num_workers=image_net_config.train['num_workers'])

        trainer.train(model, max_epochs=epochs, learning_rate=learning_rate,
                      learning_rate_schedule=learning_rate_schedule, use_cuda=use_cuda)

2. Load the model and evaluate to get a baseline FP32 accuracy score

For this example notebook, we are going to load a pretrained resnet18 model from torchvision. Similarly, you can load any pretrained PyTorch model instead.

[ ]:
from torchvision.models import resnet18

model = resnet18(pretrained=True)

AIMET quantization simulation requires the user’s model definition to follow certain guidelines. For example, functionals defined in the forward pass should be changed to the equivalent torch.nn.Module. The AIMET user guide lists all these guidelines. The following ModelPreparer API uses the graph transformation feature available in PyTorch 1.9+ to automate the model definition changes required to comply with these guidelines.

[ ]:
from aimet_torch.model_preparer import prepare_model

model = prepare_model(model)

We should decide whether to place the model on a CPU or CUDA device. This example code will use CUDA if available in your current execution environment. You can change this logic and force a device placement if needed.

[ ]:
use_cuda = False
if torch.cuda.is_available():
    use_cuda = True
    model.to(torch.device('cuda'))

Let’s determine the FP32 (32-bit floating point) accuracy of this model using the evaluate() routine.

[ ]:
accuracy = ImageNetDataPipeline.evaluate(model, use_cuda)
print(accuracy)

3. Create a quantization simulation model and determine quantized accuracy

Fold Batch Normalization layers

Before we determine the simulated quantized accuracy using QuantizationSimModel, we will fold the BatchNormalization (BN) layers in the model. These layers get folded into adjacent Convolutional layers. The BN layers that cannot be folded are left as they are.

Why do we need to do this? On quantized runtimes (like TFLite, Snapdragon Neural Processing SDK, etc.), it is common practice to fold the BN layers. Doing so results in an inferences/sec speedup since unnecessary computation is avoided. From a floating-point inference perspective, a BN-folded model is mathematically equivalent to the model with BN layers and produces the same accuracy. However, folding the BN layers can increase the range of the weight tensors of the adjacent layers, and this can have a negative impact on the quantized accuracy of the model (especially when using INT8 or lower precision). So, we want to simulate that on-target behavior by doing BN folding here.
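To make both points concrete (same floating-point output, possibly larger weight range), here is a minimal sketch that folds a BN layer into a single scalar weight and bias. The helper names and all numeric values are hypothetical, chosen only for illustration; the real folding operates per output channel on full weight tensors.

```python
import math


def batch_norm(y, gamma, beta, mean, var, eps=1e-5):
    """Inference-time batch normalization of a single value."""
    return gamma * (y - mean) / math.sqrt(var + eps) + beta


def fold_bn(weight, bias, gamma, beta, mean, var, eps=1e-5):
    """Fold BN statistics into the weight and bias of the preceding layer."""
    factor = gamma / math.sqrt(var + eps)
    return weight * factor, (bias - mean) * factor + beta


# Hypothetical values for one output channel.
weight, bias = 0.5, 0.1
gamma, beta, mean, var = 2.0, 0.3, 0.2, 0.04

folded_weight, folded_bias = fold_bn(weight, bias, gamma, beta, mean, var)

x = 1.7
original = batch_norm(weight * x + bias, gamma, beta, mean, var)
folded = folded_weight * x + folded_bias
print(original, folded)        # the two outputs agree up to rounding error
print(weight, folded_weight)   # the folded weight is larger in magnitude here
```

With these values the folded weight is roughly ten times larger than the original one, which is the kind of range increase that can hurt quantized accuracy.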

The following code calls AIMET to fold the BN layers in-place on the given model.

[ ]:
from aimet_torch.batch_norm_fold import fold_all_batch_norms

_ = fold_all_batch_norms(model, input_shapes=(1, 3, 224, 224))

Create Quantization Sim Model

Now we use AIMET to create a QuantizationSimModel. This basically means that AIMET will insert fake quantization ops in the model graph and will configure them. A few of the parameters are explained here:

  • quant_scheme: We set this to QuantScheme.post_training_tf_enhanced. Supported options are the strings ‘tf_enhanced’ or ‘tf’, or the QuantScheme enum values QuantScheme.post_training_tf or QuantScheme.post_training_tf_enhanced.

  • default_output_bw: Setting this to 8 means that we are asking AIMET to perform all activation quantizations in the model using integer 8-bit precision.

  • default_param_bw: Setting this to 8 means that we are asking AIMET to perform all parameter quantizations in the model using integer 8-bit precision.

There are other parameters that are set to default values in this example. Please check the AIMET API documentation of QuantizationSimModel to see reference documentation for all the parameters.

[ ]:
from aimet_common.defs import QuantScheme
from aimet_torch.quantsim import QuantizationSimModel

dummy_input = torch.rand(1, 3, 224, 224)    # Shape for each ImageNet sample is (3 channels) x (224 height) x (224 width)
if use_cuda:
    dummy_input = dummy_input.cuda()

sim = QuantizationSimModel(model=model,
                           quant_scheme=QuantScheme.post_training_tf_enhanced,
                           dummy_input=dummy_input,
                           default_output_bw=8,
                           default_param_bw=8)

We can check the modifications AIMET has made to the model graph. One way is to print the model, and we can see that AIMET has added quantization wrapper layers. Note: use sim.model to access the modified PyTorch model. By default, AIMET creates a copy of the original model prior to modifying it. There is a parameter to override this behavior.

[ ]:
print(sim.model)

We can also check how AIMET has configured the added fake quantization nodes, which AIMET refers to as ‘quantizers’. You can see this by printing the sim object.

[ ]:
print(sim)

Even though AIMET has added ‘quantizer’ nodes to the model graph, the model is not ready to be used yet. Before we can use the sim model for inference or training, we need to find appropriate scale/offset quantization parameters for each ‘quantizer’ node. For activation quantization nodes, we need to pass unlabeled data samples through the model to collect range statistics, which will then let AIMET calculate appropriate scale/offset quantization parameters. This process is sometimes referred to as calibration. AIMET simply refers to it as ‘computing encodings’.
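To give a feel for how range statistics can turn into scale/offset parameters, the sketch below derives an encoding from the minimum and maximum of some observed values. This is a plain min/max approach written for illustration only; the function name `compute_encoding` and the sample values are assumptions, and AIMET's quantization schemes may select the range differently.

```python
def compute_encoding(observed, bitwidth=8):
    """Derive (scale, offset) for an unsigned grid from observed values."""
    levels = 2 ** bitwidth - 1
    # Include zero in the range so that it is exactly representable.
    lo = min(min(observed), 0.0)
    hi = max(max(observed), 0.0)
    # Fall back to a scale of 1.0 if every observed value is zero.
    scale = (hi - lo) / levels if hi > lo else 1.0
    # The offset is the integer grid point that represents 0.0.
    offset = round(-lo / scale)
    return scale, offset


activations = [-1.0, 0.5, 2.0, 4.1]   # hypothetical activation values
scale, offset = compute_encoding(activations)
print(scale, offset)
```

The wider the observed range, the larger the scale and the coarser the grid, which is why the choice of calibration samples matters.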

So we create a routine to pass unlabeled data samples through the model. This should be fairly simple: use the existing train or validation data loader to extract some samples and pass them to the model. We don’t need to compute any loss metric, so we can just ignore the model output for this purpose. A few pointers regarding the data samples:

  • In practice, we need a very small percentage of the overall data samples for computing encodings. For example, the training dataset for ImageNet has roughly 1M samples. For computing encodings we only need 500 or 1000 samples.

  • It may be beneficial if the samples used for computing encodings are well distributed. It is not necessary for all classes to be covered, since we are only looking at the range of values at every layer activation. However, we definitely want to avoid an extreme scenario where only ‘dark’ or only ‘light’ samples are used; e.g. only using pictures captured at night might not give ideal results.
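If you want some control over how the calibration samples are spread across classes, one simple option is to pick them round-robin from per-class lists, as in the sketch below. The function name `pick_calibration_samples` and the dictionary layout (class name mapped to a list of sample identifiers) are assumptions made for this illustration and are not used elsewhere in this notebook.

```python
import random


def pick_calibration_samples(samples_by_class, total=1000, seed=0):
    """Pick up to `total` samples, cycling over the classes round-robin."""
    rng = random.Random(seed)
    # Shuffle inside every class so that no particular files are favoured.
    pools = []
    for name in sorted(samples_by_class):
        pool = list(samples_by_class[name])
        rng.shuffle(pool)
        pools.append(pool)
    picked = []
    # Keep cycling until enough samples are picked or all pools are empty.
    while len(picked) < total and any(pools):
        for pool in pools:
            if pool and len(picked) < total:
                picked.append(pool.pop())
    return picked


# Hypothetical toy data: 4 classes with 10 sample names each.
toy = {"c%d" % i: ["c%d_%d" % (i, j) for j in range(10)] for i in range(4)}
print(pick_calibration_samples(toy, total=8))
```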

The following shows an example of a routine that passes unlabeled samples through the model for computing encodings. This routine can be written in many different ways; this is just an example.

[ ]:
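def pass_calibration_data(sim_model, use_cuda):
    """
    Passes roughly 1000 unlabeled samples through the model so that range statistics can be collected
    :param sim_model: the model to run the forward passes on
    :param use_cuda: whether or not the GPU should be used.
    """
    data_loader = ImageNetDataPipeline.get_val_dataloader()
    batch_size = data_loader.batch_size

    device = torch.device('cuda') if use_cuda else torch.device('cpu')

    sim_model.eval()
    samples = 1000

    batch_cntr = 0
    with torch.no_grad():
        # Labels are not needed for calibration, so they are ignored
        for input_data, _ in data_loader:

            inputs_batch = input_data.to(device)
            sim_model(inputs_batch)

            # Stop once more than `samples` images have been passed through the model
            batch_cntr += 1
            if (batch_cntr * batch_size) > samples:
                break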

Now we call AIMET to use the above routine to pass data through the model and then subsequently compute the quantization encodings. Encodings here refer to scale/offset quantization parameters.

[ ]:
sim.compute_encodings(forward_pass_callback=pass_calibration_data,
                      forward_pass_callback_args=use_cuda)

Now the QuantizationSim model is ready to be used for inference or training. First we can pass this model to the same evaluation routine we used before. The evaluation routine will now give us a simulated quantized accuracy score for INT8 quantization instead of the FP32 accuracy score we saw before.

[ ]:
accuracy = ImageNetDataPipeline.evaluate(sim.model, use_cuda)
print(accuracy)

4. Perform QAT

To perform quantization aware training (QAT), we simply train the model for a few more epochs (typically 15-20). As with any training job, hyper-parameters need to be searched for optimal results. Good starting points are to use a learning rate on the same order as the ending learning rate when training the original model, and to drop the learning rate by a factor of 10 every 5 epochs or so.
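The "drop by a factor of 10 every 5 epochs" rule of thumb can be written down as a small step-decay function. This is only a sketch of the arithmetic; the function name `step_decay` and the starting learning rate are assumptions, and the example trainer used in this notebook may implement its schedule differently.

```python
def step_decay(initial_lr, epoch, step=5, factor=0.1):
    """Learning rate after dropping it by `factor` every `step` epochs."""
    return initial_lr * factor ** (epoch // step)


# Hypothetical starting learning rate, printed for a few epochs.
for epoch in (0, 4, 5, 10, 14):
    print(epoch, step_decay(1e-5, epoch))
```

Epochs 0 to 4 use the starting rate, epochs 5 to 9 use one tenth of it, and epochs 10 to 14 use one hundredth.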

For the purpose of this example notebook, we are going to train only for 1 epoch. But feel free to change these parameters as you see fit.

[ ]:
ImageNetDataPipeline.finetune(sim.model, epochs=1, learning_rate=5e-7, learning_rate_schedule=[5, 10], use_cuda=use_cuda)

After we are done with QAT, we can run quantization simulation inference against the validation dataset at the end to observe any improvements in accuracy.

[ ]:
finetuned_accuracy = ImageNetDataPipeline.evaluate(sim.model, use_cuda)
print(finetuned_accuracy)

Depending on your settings, you may have observed a slight gain in accuracy after one epoch of training. Of course, this was just an example. Please try this against the model of your choice and play with the hyper-parameters to get the best results.

So we have an improved model after QAT. The next step would be to actually take this model to target. For this purpose, we need to export the model with the updated weights and without the fake quant ops. We also need to export the encodings (scale/offset quantization parameters) that go with the fine-tuned model. AIMET QuantizationSimModel provides an export API for this purpose.

[ ]:
os.makedirs('./output/', exist_ok=True)
dummy_input = dummy_input.cpu()
sim.export(path='./output/', filename_prefix='resnet18_after_qat', dummy_input=dummy_input)

Summary

We hope this notebook was useful for understanding how to use AIMET for performing QAT.

A few additional resources:

  • Refer to the AIMET API docs for more details on the APIs and optional parameters.

  • Refer to the other example notebooks to understand how to use AIMET post-training quantization techniques and QAT with range-learning.