Quantization-Aware Training with BatchNorm Re-estimation

This notebook shows a working code example of how to use AIMET to perform QAT (quantization-aware training) with batchnorm re-estimation. Batchnorm re-estimation is a technique for countering potential instability of batchnorm statistics (i.e. running mean and variance) during QAT. More specifically, batchnorm re-estimation recalculates the batchnorm statistics based on the model after QAT. By doing so, we aim to make our model learn batchnorm statistics from stable outputs after QAT, rather than from likely noisy outputs during QAT.

Overall flow

This notebook covers the following steps:

  1. Create a quantization simulation model with fake quantization ops inserted.

  2. Finetune and evaluate the quantization simulation model.

  3. Re-estimate batchnorm statistics and compare the eval score before and after re-estimation.

  4. Fold the re-estimated batchnorm layers and export the quantization simulation model.

What this notebook is not

In this notebook, we will focus on how to apply batchnorm re-estimation after QAT, rather than covering all the details of QAT itself. For more information about QAT, please refer to the QAT notebook or the QAT range learning notebook.


Dataset

This notebook relies on the ImageNet dataset for the task of image classification. If you already have a version of the dataset readily available, please use that. Otherwise, please download the dataset from an appropriate location (e.g. https://image-net.org/challenges/LSVRC/2012/index.php#) and convert it into tfrecords.

Note1: The ImageNet tfrecords dataset typically has the following characteristics, and the dataloader provided in this example notebook relies on them:

  • A folder containing tfrecords files starting with ‘train*’ for training files and ‘valid*’ for validation files.

  • Each tfrecord file should have features: ‘image/encoded’ for image data and ‘image/class/label’ for its corresponding class.

Note2: To speed up the execution of this notebook, you may use a reduced subset of the ImageNet dataset. E.g. the entire ILSVRC2012 dataset has 1000 classes, 1000 training samples per class and 50 validation samples per class. But for the purpose of running this notebook, you could reduce the dataset to, say, 2 samples per class and then convert it into tfrecords. This exercise is left up to the reader and is not necessary.
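
As a rough illustration of the subset idea in Note2, the sketch below keeps a fixed number of samples per class from a listing of (file path, label) pairs. The helper name sample_per_class and the file names are hypothetical, and converting the selected files into tfrecords is not shown.

```python
import random
from collections import defaultdict
from typing import Dict, List, Sequence, Tuple


def sample_per_class(items: Sequence[Tuple[str, int]], per_class: int,
                     seed: int = 0) -> List[Tuple[str, int]]:
    """Return at most `per_class` randomly chosen (path, label) pairs for every label."""
    by_label: Dict[int, List[Tuple[str, int]]] = defaultdict(list)
    for path, label in items:
        by_label[label].append((path, label))

    rng = random.Random(seed)  # fixed seed so that the subset is reproducible
    subset: List[Tuple[str, int]] = []
    for label in sorted(by_label):
        candidates = by_label[label]
        subset.extend(rng.sample(candidates, min(per_class, len(candidates))))
    return subset


# Hypothetical listing: 3 classes with 5 images each.
listing = [(f"class{c}/img{i}.JPEG", c) for c in range(3) for i in range(5)]
subset = sample_per_class(listing, per_class=2)
print(len(subset))  # 6 (2 samples for each of the 3 classes)
```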

Edit the cell below and specify the directory where the downloaded ImageNet dataset is saved.

[ ]:
TFRECORDS_DIR = '/path/to/dataset/'         # Please replace this with a real directory

1. Example evaluation and training pipeline

The following is an example training and validation loop for this image classification task.

  • Does AIMET have any limitations on how the training, validation pipeline is written? Not really. We will see later that AIMET will modify the user’s session graph to create a QuantizationSim model which is still a Tensorflow graph. This QuantizationSim model can be used in place of the original model when doing inference or training.

  • Does AIMET put any limitation on the interface of the evaluate() or train() methods? Not really. You should be able to use your existing evaluate and train routines as-is.

[ ]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow.compat.v1 as tf
tf.disable_eager_execution()
tf.logging.set_verbosity(tf.logging.ERROR)
from typing import List

from Examples.common import image_net_config
from Examples.tensorflow.utils.image_net_evaluator import ImageNetDataLoader
from Examples.tensorflow.utils.image_net_evaluator import ImageNetEvaluator
from Examples.tensorflow.utils.image_net_trainer import ImageNetTrainer

class ImageNetDataPipeline:
    """
    Provides APIs for model evaluation and finetuning using ImageNet Dataset.
    """

    @staticmethod
    def get_val_dataloader():
        """
        Instantiates a validation dataloader for ImageNet dataset and returns it
        """
        data_loader = ImageNetDataLoader(TFRECORDS_DIR,
                                         image_size=image_net_config.dataset['image_size'],
                                         batch_size=image_net_config.evaluation['batch_size'],
                                         format_bgr=True)

        return data_loader

    @staticmethod
    def evaluate(sess: tf.Session) -> float:
        """
        Given a TF session, evaluates its Top-1 accuracy on the validation dataset
        :param sess: The sess graph to be evaluated.
        :return: The Top-1 accuracy on the validation dataset.
        """
        evaluator = ImageNetEvaluator(TFRECORDS_DIR, training_inputs=['keras_learning_phase:0'],
                                      data_inputs=['input_1:0'], validation_inputs=['labels:0'],
                                      image_size=image_net_config.dataset['image_size'],
                                      batch_size=image_net_config.evaluation['batch_size'],
                                      format_bgr=True)

        return evaluator.evaluate(sess)


    @staticmethod
    def finetune(sess: tf.Session, update_ops_name: List[str], epochs: int, learning_rate: float, decay_steps: int):
        """
        Given a TF session, finetunes it to improve its accuracy
        :param sess: The sess graph to fine-tune.
        :param update_ops_name: list of name of update ops (mostly BatchNorms' moving averages).
                                tf.GraphKeys.UPDATE_OPS collections is always used
                                in addition to this list
        :param epochs: The number of epochs used during the finetuning step.
        :param learning_rate: The learning rate used during the finetuning step.
        :param decay_steps: A number used to adjust(decay) the learning rate after every decay_steps epochs in training.
        """
        trainer = ImageNetTrainer(TFRECORDS_DIR, training_inputs=['keras_learning_phase:0'],
                                  data_inputs=['input_1:0'], validation_inputs=['labels:0'],
                                  image_size=image_net_config.dataset['image_size'],
                                  batch_size=image_net_config.train['batch_size'],
                                  num_epochs=epochs, format_bgr=True)

        trainer.train(sess, update_ops_name=update_ops_name, learning_rate=learning_rate, decay_steps=decay_steps)

2. Load FP32 model

AIMET currently supports BatchNorm re-estimation on Tensorflow sessions. In this example notebook, we are going to load a pretrained ResNet50 model from Keras and convert it to work with a Tensorflow session. Similarly, you can load any pretrained Tensorflow model. Please refer to the QAT notebook for more detail.

[ ]:
from tensorflow.compat.v1.keras.applications.resnet import ResNet50

tf.keras.backend.clear_session()
model = ResNet50(weights='imagenet', input_shape=(224, 224, 3))
sess = tf.keras.backend.get_session()

# Following lines are additional steps to make keras model work with AIMET.
from Examples.tensorflow.utils.add_computational_nodes_in_graph import add_image_net_computational_nodes_in_graph
add_image_net_computational_nodes_in_graph(sess, model.output.name, image_net_config.dataset['images_classes'])

We need names of input and output of the model to work with AIMET.

[ ]:
input_op_names = [model.input.op.name]
output_op_names = [model.output.op.name]

BatchNorm Rewriter

Later in this notebook, we will make changes to the parameters of BatchNorm layers to improve performance. However, depending on how the BatchNorm layers were configured, this might be difficult to achieve.

AIMET provides modify_sess_bn_mutable, which rewrites the BatchNorm layers so that their parameters are easier to modify.

[ ]:
from aimet_tensorflow.utils.op.bn_mutable import modify_sess_bn_mutable
modify_sess_bn_mutable(sess, input_op_names, output_op_names, training_tf_placeholder=False)

Let’s determine the FP32 (floating point 32-bit) accuracy of this model using the evaluate() routine

[ ]:
accuracy = ImageNetDataPipeline.evaluate(sess=sess)
print(accuracy)

3. Create a quantization simulation model and Perform QAT

Create Quantization Sim Model

Now we use AIMET to create a QuantizationSimModel. This basically means that AIMET will insert fake quantization ops in the model graph and will configure them. A few of the parameters are explained here:

  • quant_scheme: We set this to “QuantScheme.training_range_learning_with_tf_init”, which initializes the quantization ranges with the ‘tf’ scheme and then learns them during QAT. Post-training schemes are also supported, either as the strings ‘tf_enhanced’ or ‘tf’ or as the Quant Scheme Enum values QuantScheme.post_training_tf or QuantScheme.post_training_tf_enhanced.

  • default_output_bw: Setting this to 8 essentially means that we are asking AIMET to perform all activation quantizations in the model using integer 8-bit precision. The cell below does not pass this argument explicitly and relies on AIMET’s default.

  • default_param_bw: Setting this to 8 essentially means that we are asking AIMET to perform all parameter quantizations in the model using integer 8-bit precision. The cell below does not pass this argument explicitly and relies on AIMET’s default.

There are other parameters that are set to default values in this example. Please check the AIMET API documentation of QuantizationSimModel to see reference documentation for all the parameters.
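
To make the idea of a fake quantization op more concrete, here is a minimal pure-Python sketch of quantize-then-dequantize on a single value. It illustrates the general technique only and is not AIMET's implementation. The function name fake_quantize and the example range [-1, 1] are assumptions chosen for illustration.

```python
def fake_quantize(x: float, x_min: float, x_max: float, bitwidth: int = 8) -> float:
    """Snap x onto a uniform grid with 2**bitwidth points spanning roughly [x_min, x_max]."""
    if x_max <= x_min:
        raise ValueError("x_max must be greater than x_min")
    levels = 2 ** bitwidth - 1            # 255 steps for 8 bits
    scale = (x_max - x_min) / levels      # width of one quantization step
    offset = round(-x_min / scale)        # integer grid index that represents 0.0
    q = round(x / scale) + offset         # quantize to an integer grid index
    q = max(0, min(levels, q))            # clamp to the representable range
    return (q - offset) * scale           # dequantize back to floating point


# In-range values move by at most half a step; out-of-range values are clamped.
for value in (-0.5, 0.1234, 0.9, 3.0):
    print(value, fake_quantize(value, -1.0, 1.0))
```

Because the output is still a floating-point number, the rest of the graph runs unchanged while seeing the rounding and clamping error that integer inference would introduce.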

NOTE: Unlike in other QAT example scripts, we didn’t fold batchnorm layers before QAT. This is because we aim to finetune our model with batchnorm layers present and re-estimate the batchnorm statistics for better accuracy. The batchnorm layers will be folded after re-estimation.

[ ]:
import json
from aimet_common.defs import QuantScheme
from aimet_tensorflow.quantsim import QuantizationSimModel

default_config_per_channel = {
            "defaults":
                {
                    "ops":
                        {
                            "is_output_quantized": "True"
                        },
                    "params":
                        {
                            "is_quantized": "True",
                            "is_symmetric": "True"
                        },
                    "strict_symmetric": "False",
                    "unsigned_symmetric": "True",
                    "per_channel_quantization": "True"
                },

            "params":
                {
                    "bias":
                        {
                            "is_quantized": "False"
                        }
                },

            "op_type":
                {
                    "Squeeze":
                        {
                            "is_output_quantized": "False"
                        },
                    "Pad":
                        {
                            "is_output_quantized": "False"
                        },
                    "Mean":
                        {
                            "is_output_quantized": "False"
                        }
                },

            "supergroups":
                [
                    {
                        "op_list": ["Conv", "Relu"]
                    },
                    {
                        "op_list": ["Conv", "Clip"]
                    },
                    {
                        "op_list": ["Conv", "BatchNormalization", "Relu"]
                    },
                    {
                        "op_list": ["Add", "Relu"]
                    },
                    {
                        "op_list": ["Gemm", "Relu"]
                    }
                ],

            "model_input":
                {
                    "is_input_quantized": "True"
                },

            "model_output":
                {}
        }

config_file_path = "/tmp/default_config_per_channel.json"
with open(config_file_path, "w") as f:
    json.dump(default_config_per_channel, f)

sim = QuantizationSimModel(sess, input_op_names, output_op_names, use_cuda=True,
                           quant_scheme=QuantScheme.training_range_learning_with_tf_init,
                           config_file=config_file_path)

Compute Encodings

Even though AIMET has added ‘quantizer’ nodes to the model graph, the model is not ready to be used yet. Before we can use the sim model for inference or training, we need to find appropriate scale/offset quantization parameters for each ‘quantizer’ node. For activation quantization nodes, we need to pass unlabeled data samples through the model to collect range statistics which will then let AIMET calculate appropriate scale/offset quantization parameters. This process is sometimes referred to as calibration. AIMET simply refers to it as ‘computing encodings’.

So we create a routine to pass unlabeled data samples through the model. This should be fairly simple: use the existing train or validation data loader to extract some samples and pass them to the model. We don’t need to compute any loss metric, so we can just ignore the model output for this purpose. A few pointers regarding the data samples:

  • In practice, we need a very small percentage of the overall data samples for computing encodings. For example, the training dataset for ImageNet has 1M samples. For computing encodings we only need 500 or 1000 samples.

  • It may be beneficial if the samples used for computing encodings are well distributed. It’s not necessary that all classes be covered, since we are only looking at the range of values at every layer activation. However, we definitely want to avoid an extreme scenario where all samples are ‘dark’ or ‘light’; e.g. only using pictures captured at night might not give ideal results.

The following shows an example of a routine that passes unlabeled samples through the model for computing encodings. This routine can be written in many different ways; this is just an example.

[ ]:
def pass_calibration_data(session: tf.compat.v1.Session, _):
    data_loader = ImageNetDataPipeline.get_val_dataloader()
    batch_size = data_loader.batch_size

    input_label_tensors = [session.graph.get_tensor_by_name('input_1:0'),
                           session.graph.get_tensor_by_name('labels:0')]

    train_tensors = [session.graph.get_tensor_by_name('keras_learning_phase:0')]
    train_tensors_dict = dict.fromkeys(train_tensors, False)

    eval_outputs = [session.graph.get_operation_by_name('top1-acc').outputs[0]]

    samples = 500

    batch_cntr = 0
    for input_label in data_loader:
        input_label_tensors_dict = dict(zip(input_label_tensors, input_label))

        feed_dict = {**input_label_tensors_dict, **train_tensors_dict}

        with session.graph.as_default():
            _ = session.run(eval_outputs, feed_dict=feed_dict)

        batch_cntr += 1
        if (batch_cntr * batch_size) > samples:
            break

sim.compute_encodings(forward_pass_callback=pass_calibration_data,
                      forward_pass_callback_args=None)
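
To illustrate what collecting range statistics can mean, the following sketch derives a scale and offset from the minimum and maximum values observed over a few calibration batches. It is a simplified min/max scheme written for illustration only; the function name compute_encoding and the sample values are hypothetical, and AIMET's own quant schemes may compute their encodings differently.

```python
from typing import Iterable, Sequence, Tuple


def compute_encoding(batches: Iterable[Sequence[float]], bitwidth: int = 8) -> Tuple[float, int]:
    """Derive (scale, offset) from the min/max values observed over calibration batches."""
    observed_min, observed_max = float("inf"), float("-inf")
    for batch in batches:
        observed_min = min(observed_min, min(batch))
        observed_max = max(observed_max, max(batch))

    # Keep 0.0 inside the range so that zero is exactly representable.
    observed_min = min(observed_min, 0.0)
    observed_max = max(observed_max, 0.0)
    if observed_max == observed_min:
        raise ValueError("calibration data must contain at least one non-zero value")

    levels = 2 ** bitwidth - 1
    scale = (observed_max - observed_min) / levels
    offset = round(-observed_min / scale)   # integer grid index that represents 0.0
    return scale, offset


# Hypothetical activation values from two calibration batches.
calibration_batches = [[-2.0, 0.5], [1.0, 6.0]]
scale, offset = compute_encoding(calibration_batches)
print(scale, offset)
```

Only the extreme values matter in this scheme, which is why a few hundred reasonably varied samples are usually enough.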

Now the QuantizationSim model is ready to be used for inference or training. First we can pass this model to the same evaluation routine we used before. The evaluation routine will now give us a simulated quantized accuracy score for INT8 quantization instead of the FP32 accuracy score we saw before.

[ ]:
accuracy = ImageNetDataPipeline.evaluate(sim.session)
print(accuracy)

Perform QAT

To perform quantization aware training (QAT), we simply train the model for a few more epochs (typically 15-20). As with any training job, hyper-parameters need to be searched for optimal results. Good starting points are to use a learning rate on the same order as the ending learning rate when training the original model, and to drop the learning rate by a factor of 10 every 5 epochs or so.
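
The schedule suggested above (drop the learning rate by a factor of 10 every 5 epochs) can be sketched as a small helper. The function name step_decay_lr and the starting value 1e-5 are hypothetical, and this is not necessarily how the trainer used in this notebook applies decay_steps internally.

```python
def step_decay_lr(initial_lr: float, epoch: int, decay_every: int = 5, factor: float = 0.1) -> float:
    """Learning rate for a 0-based epoch when multiplying by `factor` every `decay_every` epochs."""
    return initial_lr * factor ** (epoch // decay_every)


# Epochs 0-4 use the initial rate, epochs 5-9 one tenth of it, and so on.
for epoch in (0, 4, 5, 10, 14):
    print(epoch, step_decay_lr(1e-5, epoch))
```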

For the purpose of this example notebook, we are going to train only for 1 epoch. But feel free to change these parameters as you see fit.

[ ]:
update_ops_name = [op.name for op in model.updates] # Used for finetuning
ImageNetDataPipeline.finetune(sim.session, update_ops_name=update_ops_name, epochs=1, learning_rate=5e-7, decay_steps=5)

After we are done with QAT, we can run quantization simulation inference against the validation dataset at the end to observe any improvements in accuracy.

[ ]:
finetuned_accuracy = ImageNetDataPipeline.evaluate(sim.session)
print(finetuned_accuracy)

4. Perform BatchNorm Reestimation

Re-estimate BatchNorm Statistics

AIMET provides a helper function, reestimate_bn_stats, for re-estimating batchnorm statistics. The call in the cell below passes the following parameters:

  • sim: The QuantizationSimModel whose BatchNorm statistics are re-estimated.

  • start_op_names: Names of the input ops of the model.

  • output_op_names: Names of the output ops of the model.

  • bn_re_estimation_dataset: A batched dataset of input images used for re-estimation. Training data is the typical choice; this example builds the dataset from the ImageNetDataLoader used earlier.

  • bn_num_batches: The number of batches to be used for re-estimation (100 in this example).

[ ]:
from aimet_tensorflow.bn_reestimation import reestimate_bn_stats
import numpy as np

data_loader = ImageNetDataLoader(TFRECORDS_DIR,
                                 image_size=image_net_config.dataset['image_size'],
                                 batch_size=image_net_config.evaluation['batch_size'],
                                 format_bgr=True)

# Collect the input images only; labels are not needed for re-estimation
arrays = []
for input_label in data_loader:
    arrays.append(input_label[0])
real_inputs = np.vstack(arrays)

# Re-batch the images into a dataset for reestimate_bn_stats
dataset = tf.compat.v1.data.Dataset.from_tensor_slices(real_inputs)
bn_re_estimation_dataset = dataset.batch(32)

reestimate_bn_stats(sim, start_op_names=input_op_names, output_op_names=output_op_names,
                    bn_re_estimation_dataset=bn_re_estimation_dataset, bn_num_batches=100)

finetuned_accuracy_bn_reestimated = ImageNetDataPipeline.evaluate(sim.session)
print(finetuned_accuracy_bn_reestimated)
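
Conceptually, re-estimation replaces the running statistics accumulated during noisy QAT steps with statistics measured on the finished model. The sketch below shows one simple way to do that for a single channel, by averaging the per-batch mean and variance over a fixed number of batches. It is an assumption-laden illustration with a hypothetical function name and made-up values, not a description of what reestimate_bn_stats does internally.

```python
from statistics import fmean, pvariance
from typing import Sequence, Tuple


def reestimate_bn_statistics(batches: Sequence[Sequence[float]]) -> Tuple[float, float]:
    """Average the per-batch mean and (biased) variance of one channel over all batches."""
    if not batches:
        raise ValueError("at least one batch is required")
    batch_means = [fmean(batch) for batch in batches]
    batch_vars = [pvariance(batch) for batch in batches]
    return fmean(batch_means), fmean(batch_vars)


# Hypothetical activations of one BatchNorm input channel, collected from two
# forward passes through the model with its weights held fixed.
activations = [[1.0, 3.0], [2.0, 6.0]]
new_mean, new_var = reestimate_bn_statistics(activations)
print(new_mean, new_var)  # 3.0 2.5
```

Since the weights no longer change while these batches are processed, every batch is measured against the same model, which is what makes the resulting statistics stable.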

Fold BatchNorm Layers

So far, we have improved our quantization simulation model through QAT and batchnorm re-estimation. The next step would be to actually take this model to target. But first, we should fold the batchnorm layers for our model to run on target devices more efficiently.

[ ]:
from aimet_tensorflow.batch_norm_fold import fold_all_batch_norms_to_scale

fold_all_batch_norms_to_scale(sim, input_op_names, output_op_names)
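
For readers unfamiliar with batchnorm folding, the arithmetic behind it can be sketched in a few lines: an inference-mode BatchNorm that follows a linear layer is an affine transform per output channel, so it can be absorbed into that layer's weights and bias. The sketch below shows this generic arithmetic only. The function names and all numeric values are hypothetical, and it does not show how fold_all_batch_norms_to_scale treats the quantizer encodings of the sim model.

```python
import math
from typing import List, Sequence, Tuple


def fold_bn_into_linear(weights: Sequence[Sequence[float]], bias: Sequence[float],
                        gamma: Sequence[float], beta: Sequence[float],
                        mean: Sequence[float], var: Sequence[float],
                        eps: float = 1e-5) -> Tuple[List[List[float]], List[float]]:
    """Fold per-output-channel BatchNorm parameters into the preceding layer."""
    folded_w, folded_b = [], []
    for w_row, b, g, bt, m, v in zip(weights, bias, gamma, beta, mean, var):
        s = g / math.sqrt(v + eps)               # per-channel multiplier applied by BatchNorm
        folded_w.append([w * s for w in w_row])  # scale every weight of this output channel
        folded_b.append((b - m) * s + bt)        # shift and scale the bias accordingly
    return folded_w, folded_b


def linear(weights: Sequence[Sequence[float]], bias: Sequence[float],
           x: Sequence[float]) -> List[float]:
    """Compute y[c] = sum_i weights[c][i] * x[i] + bias[c]."""
    return [sum(w * xi for w, xi in zip(w_row, x)) + b for w_row, b in zip(weights, bias)]


# Hypothetical 2-input, 2-output layer followed by an inference-mode BatchNorm.
W, b = [[1.0, 2.0], [0.5, -1.0]], [0.1, -0.2]
gamma, beta, mean, var = [1.5, 0.8], [0.0, 0.3], [0.4, -0.1], [2.0, 0.5]
x = [0.7, -1.2]
eps = 1e-5

# Reference: run the layer, then apply BatchNorm explicitly.
reference = [(y - m) / math.sqrt(v + eps) * g + bt
             for y, g, bt, m, v in zip(linear(W, b, x), gamma, beta, mean, var)]

# Folded: a single layer with adjusted weights and bias.
W_folded, b_folded = fold_bn_into_linear(W, b, gamma, beta, mean, var, eps)
folded = linear(W_folded, b_folded, x)

# The two results agree up to floating-point rounding.
print(max(abs(r - f) for r, f in zip(reference, folded)))
```

Folding removes the separate BatchNorm computation at inference time, which is why it is done before taking the model to target.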

5. Export Model

As the final step, we will export the model to run it on actual target devices. AIMET QuantizationSimModel provides an export API for this purpose.

[ ]:
os.makedirs('./output/', exist_ok=True)
sim.export(path='./output/', filename_prefix='resnet50_after_qat')

Summary

We hope this notebook helped you understand how to use the batchnorm re-estimation feature of AIMET.

A few additional resources:

  • Refer to the AIMET API docs for more details on the APIs and optional parameters.

  • Refer to the other example notebooks to understand how to use AIMET post-training quantization techniques and QAT methods.