AutoQuant

This notebook shows a working code example of how to use the AIMET AutoQuant feature.

AIMET offers a suite of neural network post-training quantization techniques. Applying these techniques in a specific sequence often results in better accuracy and performance. Without the AutoQuant feature, the AIMET user needs to manually try out various combinations of AIMET quantization features. This manual process is error-prone and often time-consuming.

The AutoQuant feature analyzes the model, determines the sequence of AIMET quantization techniques, and applies them. In addition, the user can specify in the AutoQuant API the amount of accuracy drop that can be tolerated. As soon as the quantized model's accuracy is within this threshold, AutoQuant stops applying additional quantization techniques. In summary, the AutoQuant feature saves time by automating the quantization of neural networks.
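The stop-when-good-enough behavior described above can be pictured with a small, library-free sketch. This is not AIMET's implementation: the function apply_until_acceptable, the technique names, and the scores are hypothetical stand-ins, and the sketch assumes the target score is the FP32 score minus the allowed drop.

```python
from typing import Callable, List, Tuple


def apply_until_acceptable(fp32_score: float,
                           allowed_drop: float,
                           techniques: List[Tuple[str, Callable[[], float]]]):
    """Try techniques in order and stop once the score is within the allowed drop."""
    target = fp32_score - allowed_drop  # assumed definition of "good enough"
    applied = []
    best_score = float("-inf")
    for name, run in techniques:
        score = run()  # hypothetical: apply the technique and return its eval score
        applied.append(name)
        best_score = max(best_score, score)
        if score >= target:
            break  # threshold reached, skip the remaining techniques
    return applied, best_score


# Made-up scores, for illustration only.
steps = [
    ("technique_a", lambda: 0.70),
    ("technique_b", lambda: 0.755),
    ("technique_c", lambda: 0.76),
]
applied, score = apply_until_acceptable(fp32_score=0.76, allowed_drop=0.01,
                                        techniques=steps)
print(applied, score)
```

In this made-up run the third technique is never tried, because the second one already lands within the allowed drop.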

Overall flow

This notebook covers the following:

  1. Instantiate the example evaluation and training pipeline
  2. Load a pretrained FP32 model
  3. Determine the baseline FP32 accuracy
  4. Define constants and helper functions
  5. Apply AutoQuant

What this notebook is not

  • This notebook is not designed to show state-of-the-art AutoQuant results. For example, it uses a relatively quantization-friendly model, ResNet50. Also, some optimization parameters are deliberately chosen to make the notebook execute more quickly.


Dataset

This notebook relies on the ImageNet dataset for the task of image classification. If you already have a version of the dataset readily available, please use it. Otherwise, please download the dataset from an appropriate location (e.g. https://image-net.org/challenges/LSVRC/2012/index.php#) and convert it into tfrecords.

Note1: The ImageNet tfrecords dataset typically has the following characteristics, and the dataloader provided in this example notebook relies on them:

  • A folder containing tfrecords files whose names start with ‘train’ for training files and ‘valid’ for validation files.

  • Each tfrecord file should have the features ‘image/encoded’ for the image data and ‘image/class/label’ for its corresponding class.
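Before running the notebook, it can help to confirm that the dataset folder follows the naming convention above. The sketch below uses only the Python standard library; the helper split_tfrecord_files and the file names in the demo are hypothetical, and the check looks only at file names, not at the features stored inside each record.

```python
import tempfile
from pathlib import Path
from typing import Dict, List


def split_tfrecord_files(dataset_dir: str) -> Dict[str, List[str]]:
    """Group the files in dataset_dir by their 'train' / 'valid' name prefix."""
    groups: Dict[str, List[str]] = {"train": [], "valid": []}
    for path in sorted(Path(dataset_dir).iterdir()):
        for prefix in groups:
            if path.is_file() and path.name.startswith(prefix):
                groups[prefix].append(path.name)
    return groups


# Demo on a temporary folder with made-up file names.
with tempfile.TemporaryDirectory() as tmp:
    for name in ("train-00000-of-00002", "train-00001-of-00002",
                 "valid-00000-of-00001", "README.txt"):
        (Path(tmp) / name).touch()
    groups = split_tfrecord_files(tmp)

print(groups)
```

An empty 'train' or 'valid' list for your real directory would suggest that the files need to be renamed or regenerated.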

Note2: To speed up the execution of this notebook, you may use a reduced subset of the ImageNet dataset. For example, the entire ILSVRC2012 dataset has 1000 classes, roughly 1000 training samples per class, and 50 validation samples per class. For the purpose of running this notebook, you could reduce the dataset to, say, 2 samples per class and then convert it into tfrecords. This exercise is left up to the reader and is not necessary.
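One possible way to pick such a subset, before converting to tfrecords, is to keep the first few samples seen for each class. This is only a sketch: take_per_class is a hypothetical helper, and the file names and labels are made up.

```python
from collections import defaultdict
from typing import Dict, Hashable, Iterable, List, Tuple


def take_per_class(samples: Iterable[Tuple[str, Hashable]],
                   per_class: int) -> List[Tuple[str, Hashable]]:
    """Keep at most per_class (path, label) pairs for every label."""
    counts: Dict[Hashable, int] = defaultdict(int)
    subset = []
    for path, label in samples:
        if counts[label] < per_class:
            counts[label] += 1
            subset.append((path, label))
    return subset


# Made-up (file, label) pairs.
samples = [("a0.JPEG", "cat"), ("a1.JPEG", "cat"), ("a2.JPEG", "cat"),
           ("b0.JPEG", "dog"), ("b1.JPEG", "dog"), ("c0.JPEG", "fox")]
subset = take_per_class(samples, per_class=2)
print(subset)
```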

Edit the cell below and specify the directory where the downloaded ImageNet dataset is saved.

[ ]:
DATASET_DIR = '/path/to/tfrecords/dir/'       # Please replace this with a real directory

We suppress TensorFlow's INFO and WARNING log messages and disable eager execution. We also set the logging verbosity to ERROR, so TensorFlow displays only messages labeled ERROR (or more critical).

[ ]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import tensorflow as tf
tf.compat.v1.disable_eager_execution()
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

1. Example evaluation and training pipeline

The following is an example training and validation loop for this image classification task.

  • Does AIMET have any limitations on how the training and validation pipeline is written? Not really. We will see later that AIMET modifies the user’s model to create a QuantizationSim model, which is still a TensorFlow model. This QuantizationSim model can be used in place of the original model when doing inference or training.

  • Does AIMET put any limitation on the interface of evaluate() or train() methods? Not really. You should be able to use your existing evaluate and train routines as-is.

[ ]:
from Examples.common import image_net_config
from Examples.tensorflow.utils.image_net_evaluator import ImageNetDataLoader
from Examples.tensorflow.utils.image_net_evaluator import ImageNetEvaluator

class ImageNetDataPipeline:
    """
    Provides APIs for model evaluation and fine-tuning using ImageNet Dataset.
    """

    @staticmethod
    def get_val_dataloader():
        """
        Instantiates a validation dataloader for ImageNet dataset and returns it
        """
        data_loader = ImageNetDataLoader(DATASET_DIR,
                                         image_size=image_net_config.dataset['image_size'],
                                         batch_size=image_net_config.evaluation['batch_size'],
                                         format_bgr=True)

        return data_loader

    @staticmethod
    def evaluate(sess: tf.compat.v1.Session) -> float:
        """
        Given a TF session, evaluates its Top-1 accuracy on the validation dataset
        :param sess: The TF session whose graph is to be evaluated.
        :return: The Top-1 accuracy on the validation dataset.
        """
        evaluator = ImageNetEvaluator(DATASET_DIR, training_inputs=['keras_learning_phase:0'],
                                      data_inputs=['input_1:0'], validation_inputs=['labels:0'],
                                      image_size=image_net_config.dataset['image_size'],
                                      batch_size=image_net_config.evaluation['batch_size'],
                                      format_bgr=True)

        return evaluator.evaluate(sess)

2. Load a pretrained FP32 model

For this example notebook, we are going to load a pretrained ResNet50 model from Keras and convert it to a TensorFlow session. Similarly, you can load any pretrained TensorFlow model instead.

Calling clear_session() releases the global state: this helps avoid clutter from old models and layers, especially when memory is limited.

[ ]:
from tensorflow.keras.applications.resnet import ResNet50

tf.keras.backend.clear_session()

input_shape = (224, 224, 3)
model = ResNet50(weights='imagenet', input_shape=input_shape)

The following utility method in AIMET sets BN layers in the model to eval mode. This allows AIMET to more easily read the BN parameters from the graph. Eventually we will fold BN layers into adjacent conv layers.

[ ]:
from aimet_tensorflow.utils.graph import update_keras_bn_ops_trainable_flag

model = update_keras_bn_ops_trainable_flag(model, False, load_save_path='./')

AIMET features currently support TensorFlow sessions. add_image_net_computational_nodes_in_graph adds an output layer, softmax, and loss functions to the ResNet50 model graph.

[ ]:
from Examples.tensorflow.utils.add_computational_nodes_in_graph import add_image_net_computational_nodes_in_graph

sess = tf.compat.v1.keras.backend.get_session()

# Creates the computation graph of ResNet within the tensorflow session.
add_image_net_computational_nodes_in_graph(sess, model.output.name, image_net_config.dataset['images_classes'])

Since all tensorflow input and output tensors have names, we identify the tensors needed by AIMET APIs here.

[ ]:
input_tensor_name = model.input.name
input_op_name, _ = input_tensor_name.split(":")
output_tensor_name = model.output.name
output_op_name, _ = output_tensor_name.split(":")

We check whether a CUDA device is available to TensorFlow. This example code will use CUDA if it is available in your current execution environment.

[ ]:
use_cuda = tf.test.is_gpu_available(cuda_only=True)

3. Determine the baseline FP32 accuracy

Let’s determine the FP32 (32-bit floating point) accuracy of this model using the evaluate() routine.

[ ]:
accuracy = ImageNetDataPipeline.evaluate(sess=sess)
print(accuracy)

4. Define constants and helper functions

In this section the constants and helper functions needed to run this example are defined.

  • EVAL_DATASET_SIZE: A typical value is 5000. To execute this example faster, this value has been set to 50.

  • CALIBRATION_DATASET_SIZE: A typical value is 2000. To execute this example faster, this value has been set to 20.

  • BATCH_SIZE: The batch size, set by the user. In this example, it is set to 10.

The helper function _create_sampled_dataset() returns a Dataset containing the requested number of samples, drawn from the given dataset and batched.

[ ]:
EVAL_DATASET_SIZE = 50
CALIBRATION_DATASET_SIZE = 20
BATCH_SIZE = 10

_sampled_datasets = {}

def _create_sampled_dataset(dataset: tf.compat.v1.data.Dataset,
                            num_samples: int) -> tf.compat.v1.data.Dataset:
    if num_samples in _sampled_datasets:
        return _sampled_datasets[num_samples]

    with dataset._graph.as_default():
        SHUFFLE_BUFFER_SIZE = 300 # NOTE: Adjust the buffer size as necessary.
        SHUFFLE_SEED = 22222
        dataset = dataset.shuffle(buffer_size=SHUFFLE_BUFFER_SIZE, seed=SHUFFLE_SEED)\
                         .take(num_samples)\
                         .batch(BATCH_SIZE)
        _sampled_datasets[num_samples] = dataset
        return dataset
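The shuffle, take, and batch steps in the helper can be mimicked on a plain Python list to see what shape of output to expect. This is only an analogy: sample_and_batch is a hypothetical helper, and it shuffles the whole list, whereas tf.data shuffles within a buffer of SHUFFLE_BUFFER_SIZE elements.

```python
import random
from typing import List, Sequence


def sample_and_batch(items: Sequence[int], num_samples: int,
                     batch_size: int, seed: int = 22222) -> List[List[int]]:
    """Seeded shuffle, keep num_samples items, then split them into batches."""
    rng = random.Random(seed)
    shuffled = list(items)
    rng.shuffle(shuffled)           # full shuffle, unlike tf.data's buffered shuffle
    taken = shuffled[:num_samples]  # analogous to .take(num_samples)
    return [taken[i:i + batch_size]  # analogous to .batch(batch_size)
            for i in range(0, len(taken), batch_size)]


batches = sample_and_batch(range(1000), num_samples=50, batch_size=10)
print(len(batches), [len(b) for b in batches])
```

With 50 samples and a batch size of 10, this yields 5 batches of 10 items each, and a fixed seed makes the selection repeatable.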

Prepare unlabeled dataset

The AutoQuant feature uses an unlabeled dataset to perform quantization. The cell below shows how to get an unlabeled Dataset object from a labeled Dataset.

[ ]:
eval_dataset = ImageNetDataPipeline.get_val_dataloader().dataset

with eval_dataset._graph.as_default():
    image_dataset = eval_dataset.map(lambda images, labels: images)
    unlabeled_dataset = image_dataset.batch(BATCH_SIZE)

Prepare the evaluation callback function

The eval_callback() function takes the session object to evaluate and the number of samples to use as arguments. If the num_samples argument is None, EVAL_DATASET_SIZE samples are used to evaluate the model.

[ ]:
import numpy as np
from aimet_tensorflow.utils.common import iterate_tf_dataset
from typing import Optional


def eval_callback(sess: tf.compat.v1.Session,
                  num_samples: Optional[int] = None) -> float:
    if num_samples is None:
        num_samples = EVAL_DATASET_SIZE

    sampled_dataset = _create_sampled_dataset(eval_dataset, num_samples)

    with sess.graph.as_default():
        sess.run(tf.compat.v1.global_variables_initializer())
        input_tensor = sess.graph.get_tensor_by_name(input_tensor_name)
        output_tensor = sess.graph.get_tensor_by_name(output_tensor_name)

        num_correct_predictions = 0
        for images, labels in iterate_tf_dataset(sampled_dataset):
            prob = sess.run(output_tensor, feed_dict={input_tensor: images})
            predictions = np.argmax(prob, axis=1)
            num_correct_predictions += np.sum(predictions == labels)

        return int(num_correct_predictions) / num_samples
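The counting logic inside eval_callback() can be checked in isolation. The sketch below computes Top-1 accuracy in plain Python on made-up probabilities; top1_accuracy is a hypothetical helper, and it assumes the labels are integer class indices, as the comparison against the argmax implies.

```python
from typing import Sequence


def top1_accuracy(probabilities: Sequence[Sequence[float]],
                  labels: Sequence[int]) -> float:
    """Fraction of rows whose highest-probability class equals the label."""
    predictions = [max(range(len(row)), key=row.__getitem__)
                   for row in probabilities]
    num_correct = sum(int(p == y) for p, y in zip(predictions, labels))
    return num_correct / len(labels)


# Made-up model outputs for four samples and three classes.
probs = [[0.1, 0.7, 0.2],
         [0.8, 0.1, 0.1],
         [0.3, 0.3, 0.4],
         [0.2, 0.5, 0.3]]
labels = [1, 0, 1, 1]
acc = top1_accuracy(probs, labels)
print(acc)  # 0.75
```

The result is a fraction between 0 and 1, the same scale that eval_callback() returns.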

5. Apply AutoQuant

As a first step, the AutoQuant object is created.

The allowed_accuracy_drop parameter tells AutoQuant how much accuracy drop the user can tolerate. AutoQuant applies a series of quantization features, and as soon as the accuracy is within the allowed drop, it stops applying any further quantization features. Please refer to the AutoQuant User Guide and API documentation for complete details.

[ ]:
from aimet_tensorflow.auto_quant import AutoQuant

auto_quant = AutoQuant(allowed_accuracy_drop=0.01,
                       unlabeled_dataset=unlabeled_dataset,
                       eval_callback=eval_callback)

Optionally set AdaRound Parameters

The AutoQuant feature internally uses default parameters to execute the AdaRound step. If and only if necessary, the default AdaRound Parameters should be modified using the API shown below.

Note: Reducing the num_iterations parameter from its default of 10000 (for example, to 2000) makes this example execute faster. The cell below does not set it explicitly, so the default applies.

[ ]:
from aimet_tensorflow.adaround.adaround_weight import AdaroundParameters

ADAROUND_DATASET_SIZE = 2000
adaround_dataset = _create_sampled_dataset(image_dataset, ADAROUND_DATASET_SIZE)
adaround_params = AdaroundParameters(adaround_dataset,
                                     num_batches=ADAROUND_DATASET_SIZE // BATCH_SIZE)
auto_quant.set_adaround_params(adaround_params)
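If you run this notebook on a reduced dataset, it is worth checking that the dataset can actually supply the number of batches requested through num_batches, because take() yields at most the requested number of samples and fewer if the dataset is smaller. The arithmetic can be sketched as follows; available_batches is a hypothetical helper and the dataset sizes are made up.

```python
import math


def available_batches(dataset_size: int, requested_samples: int,
                      batch_size: int) -> int:
    """Number of batches produced after take(requested_samples).batch(batch_size)."""
    taken = min(dataset_size, requested_samples)
    return math.ceil(taken / batch_size)


ADAROUND_DATASET_SIZE = 2000
BATCH_SIZE = 10
num_batches = ADAROUND_DATASET_SIZE // BATCH_SIZE

full = available_batches(50000, ADAROUND_DATASET_SIZE, BATCH_SIZE)
reduced = available_batches(500, ADAROUND_DATASET_SIZE, BATCH_SIZE)
print(num_batches, full, reduced)  # 200 200 50
```

In the second case only 50 batches exist while 200 are requested, so either ADAROUND_DATASET_SIZE or the dataset itself would need adjusting.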

Run AutoQuant

This step applies the AutoQuant feature. The best possible quantized model, the associated eval_score and the path to the AdaRound encoding files are returned.

[ ]:
sess, accuracy, encoding_path =\
    auto_quant.apply(tf.compat.v1.keras.backend.get_session(),
                     starting_op_names=[input_op_name],
                     output_op_names=[output_op_name])

[ ]:
print(accuracy)
print(encoding_path)

Summary

We hope this notebook helped you understand how to use the AIMET AutoQuant feature.

A few additional resources:

  • Refer to the AIMET API docs for more details on the APIs and parameters.

  • Refer to the other example notebooks to understand how to use the AIMET CLE and AdaRound features in a standalone fashion.