Quantization-Aware Training with Range Learning

This notebook shows a working code example of how to use AIMET to perform Quantization-Aware Training (QAT) with range-learning. QAT with range-learning is an AIMET feature that adds quantization simulation ops to a pre-trained model and uses a standard training pipeline to fine-tune both the model and the quantization parameters for a few epochs. While plain QAT fine-tunes only the model parameters, QAT with range-learning also learns the encoding min/max of the parameter quantizers (hence the name range-learning). The resulting model should show improved accuracy on quantized ML accelerators.

The quantization parameters (encoding min/max, and the scale/offset derived from them) for activations are computed once initially. During QAT, the model weights and the quantization parameters are then jointly updated to minimize the effects of quantization in the forward pass.

Overall flow

This notebook covers the following:

  1. Instantiate the example evaluation and training pipeline
  2. Load a pretrained FP32 model and determine the baseline FP32 accuracy
  3. Create a quantization simulation model (with fake quantization ops inserted) and evaluate this simulation model to get a quantized accuracy score
  4. Fine-tune the quantization simulation model using QAT with range-learning and evaluate it again to get the post-fine-tuning quantized accuracy score

What this notebook is not

  • This notebook is not designed to show state-of-the-art QAT results. For example, it uses a relatively quantization-friendly model like ResNet50. Also, some optimization parameters, such as the number of epochs, are deliberately chosen so that the notebook executes quickly.

Example evaluation and training pipeline

The following is an example training and validation loop for this image classification task.

  • Does AIMET have any limitations on how the training and validation pipeline for QAT with range-learning is written? Yes, there is a limitation, but only on the training pipeline, due to Keras restrictions for range-learning. You cannot use a custom training loop for QAT with range-learning; doing so would prevent the encoding min/max from updating during training. Instead, the only way to achieve range-learning is to (a sketch follows after this list):

    1. Compile the quantization simulation model directly with sim.compile

    2. Run QAT directly on the simulation model with sim.fit

  • Does AIMET put any limitation on the interface of the evaluate() or train() methods for QAT with range-learning? Only on the train method. You should be able to use your existing evaluation routine as-is, but there is a restriction on training, as mentioned above.
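As a concrete sketch of the supported pattern (sim refers to the QuantizationSimModel created later in this notebook; the commented-out loop is a hypothetical illustration of what to avoid):

# Supported: compile and train directly on the simulation model, so AIMET's
# trainable encoding min/max update along with the weights (sections 4 and 5).
sim.compile(optimizer="adam", loss="categorical_crossentropy")
sim.fit(dataset_train, epochs=1)

# Not supported for range-learning (illustration only): a custom loop such as
# the following updates the weights but leaves the encoding min/max frozen.
#
# for inputs, labels in dataset_train:
#     with tf.GradientTape() as tape:
#         loss = loss_fn(labels, sim.model(inputs))
#     grads = tape.gradient(loss, sim.model.trainable_weights)
#     optimizer.apply_gradients(zip(grads, sim.model.trainable_weights))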

Dataset

This notebook relies on the ImageNet dataset for the task of image classification. If you already have a version of the dataset readily available, please use that. Otherwise, please download the dataset from an appropriate location (e.g. https://image-net.org/challenges/LSVRC/2012/index.php#).

Note: To speed up the execution of this notebook, you may use a reduced subset of the ImageNet dataset. For example, the entire ILSVRC2012 dataset has 1000 classes, 1000 training samples per class, and 50 validation samples per class. For the purpose of running this notebook, you could reduce the dataset to, say, 2 samples per class. This exercise is left up to the reader and is not necessary.
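If you do want to build such a reduced copy, a minimal sketch along these lines could be used (make_subset and the paths are hypothetical; the 2-samples-per-class choice is just an example):

[ ]:
import os
import shutil

def make_subset(src_dir, dst_dir, samples_per_class=2):
    # Copy the first N images of every class directory into a reduced dataset.
    for class_name in os.listdir(src_dir):
        class_src = os.path.join(src_dir, class_name)
        if not os.path.isdir(class_src):
            continue
        class_dst = os.path.join(dst_dir, class_name)
        os.makedirs(class_dst, exist_ok=True)
        for file_name in sorted(os.listdir(class_src))[:samples_per_class]:
            shutil.copy(os.path.join(class_src, file_name), class_dst)

# Example (hypothetical paths):
# make_subset('/data/imagenet/train', '/data/imagenet_small/train')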

Edit the cell below and specify the directory where the downloaded ImageNet dataset is saved.

[ ]:
DATASET_DIR = '/path/to/dir'        # Please replace this with a real directory
BATCH_SIZE = 128
IMAGE_SIZE = (224, 224)
[ ]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import tensorflow as tf

1. Load the dataset

We use the Keras image_dataset_from_directory utility to create the training and validation datasets and assign them to dataset_train and dataset_valid respectively.

[ ]:
from tensorflow.keras.applications.resnet import preprocess_input

dataset_train = tf.keras.preprocessing.image_dataset_from_directory(
    directory=os.path.join(DATASET_DIR, "train"),
    labels="inferred",
    label_mode="categorical",
    batch_size=BATCH_SIZE,
    shuffle=True,
    image_size=IMAGE_SIZE
)
dataset_valid = tf.keras.preprocessing.image_dataset_from_directory(
    directory=os.path.join(DATASET_DIR, "val"),
    labels="inferred",
    label_mode="categorical",
    batch_size=BATCH_SIZE,
    shuffle=False,
    image_size=IMAGE_SIZE
)

# The Keras ResNet50 expects its own input preprocessing; apply it once here
# so training, evaluation, and calibration all see preprocessed images.
dataset_train = dataset_train.map(lambda x, y: (preprocess_input(x), y))
dataset_valid = dataset_valid.map(lambda x, y: (preprocess_input(x), y))
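As an optional sanity check, you can inspect one batch to confirm the shapes match what the model expects (the label width assumes the full 1000 ImageNet classes):

[ ]:
for images, labels in dataset_valid.take(1):
    print(images.shape)    # (BATCH_SIZE, 224, 224, 3)
    print(labels.shape)    # (BATCH_SIZE, 1000) one-hot, from label_mode="categorical"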

2. Load a pretrained FP32 model

For this example notebook, we are going to load a pretrained ResNet50 model from Keras. You can load any other pretrained Keras model instead.

[ ]:
from tensorflow.keras.applications.resnet import ResNet50

model = ResNet50(weights='imagenet')
# metrics=["accuracy"] so that evaluate() reports accuracy in addition to loss
model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])

3. Determine the baseline FP32 accuracy

Let’s determine the FP32 (floating point 32-bit) accuracy of this model using the evaluate() routine.

[ ]:
model.evaluate(dataset_valid)

4. Create a QuantizationSim Model and determine quantized accuracy

Fold Batch Normalization layers

Before we determine the simulated quantized accuracy using QuantizationSimModel, we will fold the BatchNormalization (BN) layers in the model. These layers get folded into adjacent Convolutional layers. The BN layers that cannot be folded are left as they are.

Why do we need to do this? On quantized runtimes (like TFLite, the Snapdragon Neural Processing SDK, etc.), it is common practice to fold the BN layers. Doing so results in an inferences/sec speedup since unnecessary computation is avoided. From a floating-point compute perspective, a BN-folded model is mathematically equivalent to the model with BN layers and produces the same inference results. However, folding the BN layers can increase the range of the tensor values for the weight parameters of the adjacent layers, and this can have a negative impact on the quantized accuracy of the model (especially at INT8 or lower precision). So we want to simulate that on-target behavior by doing BN folding here.
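To make the folding arithmetic concrete, here is a minimal NumPy sketch (illustrative only; fold_bn_into_conv is a hypothetical helper, and AIMET's fold_all_batch_norms below does this for you). Note how the per-channel scale multiplies the conv weights, which is what can widen their value range:

[ ]:
import numpy as np

def fold_bn_into_conv(conv_w, conv_b, gamma, beta, mean, var, eps=1e-3):
    # BN(conv(x)) = gamma * (W*x + b - mean) / sqrt(var + eps) + beta,
    # which equals a conv with scaled weights and a shifted bias:
    scale = gamma / np.sqrt(var + eps)           # one scale per output channel
    folded_w = conv_w * scale                    # conv_w shape: (kh, kw, in_ch, out_ch)
    folded_b = (conv_b - mean) * scale + beta
    return folded_w, folded_b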

The following code calls AIMET to fold the BN layers of a given model. NOTE: During folding, a new model is returned. Please use the returned model for the rest of the pipeline.

[ ]:
from aimet_tensorflow.keras.batch_norm_fold import fold_all_batch_norms

_, model = fold_all_batch_norms(model)

Create Quantization Sim Model

Now we use AIMET to create a QuantizationSimModel. This basically means that AIMET will insert fake quantization ops in the model graph and configure them. A few of the parameters are explained here:

  • quant_scheme: We set this to “training_range_learning_with_tf_init”. This is the key setting that enables “range learning”. With this quant scheme, AIMET uses the TF quant scheme to initialize the quantization parameters (scale/offset), and then sets those parameters to be trainable so they can continue to be updated during fine-tuning.

  • Another choice for quant_scheme is “training_range_learning_with_tf_enhanced_init”. Similar to the above, but the initialization of scale/offset is done using the TF-Enhanced scheme. Since in both schemes the quantization parameters are trainable, there is not much benefit to using this choice instead of “training_range_learning_with_tf_init”.

  • default_output_bw: Setting this to 8 means that we are asking AIMET to perform all activation quantization in the model using integer 8-bit precision.

  • default_param_bw: Setting this to 8 means that we are asking AIMET to perform all parameter quantization in the model using integer 8-bit precision.

There are other parameters that are set to default values in this example. Please refer to the AIMET API documentation of QuantizationSimModel for details on all the parameters.

[ ]:
from aimet_tensorflow.keras.quantsim import QuantizationSimModel
from aimet_common.defs import QuantScheme

sim = QuantizationSimModel(model=model,
                           quant_scheme=QuantScheme.training_range_learning_with_tf_init,
                           rounding_mode="nearest",
                           default_output_bw=8,
                           default_param_bw=8)
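Optionally, you can inspect what was inserted. In recent AIMET releases the wrapped Keras model is exposed as sim.model (an assumption worth verifying against your installed version); its summary should show AIMET's quantization wrapper layers:

[ ]:
sim.model.summary()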

Compute Encodings

Even though AIMET has wrapped the layers to act as being ‘quantized’, the model is not ready to be used yet. Before we can use the sim model for inference or training, we need to find appropriate scale/offset quantization parameters for each ‘quantizer’ layer. For activation quantization layers, we need to pass unlabeled data samples through the model to collect range statistics, which then let AIMET calculate appropriate scale/offset quantization parameters. This process is sometimes referred to as calibration. AIMET simply refers to it as ‘computing encodings’.

So we create a routine to pass unlabeled data samples through the model. This should be fairly simple: use the existing train or validation data loader to extract some samples and pass them to the model. We don’t need to compute any loss metric, so we can simply ignore the model output. A few pointers regarding the data samples:

  • In practice, we need a very small percentage of the overall data samples for computing encodings. For example, the training dataset for ImageNet has 1M samples; for computing encodings we only need 500 or 1000 samples.

  • It may be beneficial if the samples used for computing encodings are well distributed. It is not necessary that all classes be covered, since we are only looking at the range of values at every layer activation. However, we definitely want to avoid an extreme scenario, such as using only ‘dark’ or ‘light’ samples - e.g. using only pictures captured at night might not give ideal results.

The following shows an example of a routine that passes unlabeled samples through the model for computing encodings. This routine can be written in many ways; this is just an example.

[ ]:
from tensorflow.keras.utils import Progbar

def pass_calibration_data(sim_model, samples):
    # Pass roughly `samples` unlabeled images through the model so AIMET can
    # collect range statistics. The dataset pipeline already applies ResNet50's
    # preprocess_input (see the dataset cell above), so images are used as-is.
    dataset = dataset_valid
    progbar = Progbar(samples)

    batch_cntr = 0
    for inputs, _ in dataset:
        sim_model(inputs)

        batch_cntr += 1
        progbar_stat_update = \
            batch_cntr * BATCH_SIZE if (batch_cntr * BATCH_SIZE) < samples else samples
        progbar.update(progbar_stat_update)
        # Stop once we have seen at least `samples` images.
        if (batch_cntr * BATCH_SIZE) >= samples:
            break

Now we call AIMET to use the above routine to pass data through the model and compute the quantization encodings. Encodings here refer to scale/offset quantization parameters.

[ ]:
sim.compute_encodings(forward_pass_callback=pass_calibration_data,
                      forward_pass_callback_args=1000)

Compile the model

Configure the model for training and evaluation. The model must be compiled before evaluation.

[ ]:
sim.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])

Evaluate the performance of the quantized model

Next, we can evaluate the performance of the quantized model.

[ ]:
sim.evaluate(dataset_valid)

5. Perform QAT

To perform quantization-aware training (QAT), we simply train the model for a few more epochs (typically 15-20). As with any training job, hyperparameters need to be tuned for optimal results. Good starting points are to use a learning rate on the same order as the ending learning rate of the original model's training, and to drop the learning rate by a factor of 10 every 5 epochs or so (a sketch of such a schedule follows). For the purpose of this example notebook, we are going to train for only 1 epoch, but feel free to change these parameters as you see fit.
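As an illustration of the step decay described above (the initial learning rate of 1e-4 and the 5-epoch drop interval are placeholder assumptions, not AIMET recommendations), a standard Keras LearningRateScheduler callback could look like this:

[ ]:
def step_decay(epoch, lr, initial_lr=1e-4, drop_every=5):
    # Drop the learning rate by a factor of 10 every `drop_every` epochs.
    return initial_lr * (0.1 ** (epoch // drop_every))

lr_callback = tf.keras.callbacks.LearningRateScheduler(step_decay)
# When training for more epochs, pass [lr_callback, quantized_callback]
# (defined in the next cell) as the callbacks argument to sim.fit.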

[ ]:
quantized_callback = tf.keras.callbacks.TensorBoard(log_dir="./log/quantized")
history = sim.fit(dataset_train, epochs=1, validation_data=dataset_valid, callbacks=[quantized_callback])

6. Evaluate validation accuracy after QAT

Next, let’s evaluate the validation accuracy of our model after QAT.

[ ]:
sim.evaluate(dataset_valid)

7. Export the encodings

Finally, let’s compute and export the encodings of the model after performing QAT. When comparing the encodings file generated by this step with one computed before fine-tuning, there should be some differences; these differences are an artifact of QAT.

[ ]:
sim.compute_encodings(forward_pass_callback=pass_calibration_data,
                      forward_pass_callback_args=1000)
sim.export('./data', 'model_after_qat')
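To confirm the export succeeded, you can list the output directory. Typically sim.export writes a model file plus a JSON encodings file named with the given prefix, though the exact artifacts can vary across AIMET versions:

[ ]:
import os
print(os.listdir('./data'))    # expect files whose names start with 'model_after_qat'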

Summary

We hope this notebook was useful for understanding how to use AIMET to perform QAT with range-learning.

A few additional resources:

  • Refer to the AIMET API docs for more details on the APIs and optional parameters
  • Refer to the other example notebooks to learn how to use AIMET post-training quantization techniques and vanilla QAT (without range-learning)