AutoQuant

This notebook contains an example of how to use the AIMET AutoQuant feature.

AIMET offers a suite of neural network post-training quantization (PTQ) techniques that can be applied in succession. However, finding the right sequence of techniques to apply is time-consuming and can be challenging for non-expert users. We instead recommend AutoQuant to save time and effort.

AutoQuant is an API that analyzes the model and automatically applies various PTQ techniques based on best-practices heuristics. You specify a tolerable accuracy drop, and AutoQuant applies PTQ techniques cumulatively until the target accuracy is satisfied.

Overall flow

This example performs the following steps:

  1. Instantiate the example evaluation and training pipeline

  2. Load a pretrained FP32 model

  3. Determine the baseline FP32 accuracy

  4. Define constants and helper functions

  5. Apply AutoQuant

Note

This notebook does not show state-of-the-art results. For example, it uses a relatively quantization-friendly model (ResNet50). Also, some parameters, such as the evaluation and calibration dataset sizes, are chosen to improve execution speed in the notebook.


Dataset

This example does image classification on the ImageNet dataset. If you already have a version of the data set, use that. Otherwise download the data set, for example from https://image-net.org/challenges/LSVRC/2012/index .

Note

To speed up the execution of this notebook, you can use a reduced subset of the ImageNet dataset. For example, the full ILSVRC2012 dataset has 1000 classes, with up to 1300 training samples and 50 validation samples per class. For the purpose of running this notebook, you can reduce the dataset to, say, two samples per class.
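One way to build such a subset is to copy the first few images of every class directory. This is a sketch under assumptions: it presumes a layout of `<split>/<class_name>/<image files>` (check what your data loader expects), and `make_imagenet_subset` and the paths in the usage note are hypothetical.

```python
import os
import shutil


def make_imagenet_subset(src_dir: str, dst_dir: str, samples_per_class: int = 2) -> int:
    """Copy the first `samples_per_class` images of every class directory.

    Assumes a hypothetical layout of <src_dir>/<class_name>/<image files>.
    Returns the number of files copied.
    """
    copied = 0
    for class_name in sorted(os.listdir(src_dir)):
        class_src = os.path.join(src_dir, class_name)
        if not os.path.isdir(class_src):
            continue  # skip stray files at the top level
        class_dst = os.path.join(dst_dir, class_name)
        os.makedirs(class_dst, exist_ok=True)
        # Sort for a deterministic choice of images per class
        for file_name in sorted(os.listdir(class_src))[:samples_per_class]:
            shutil.copy2(os.path.join(class_src, file_name),
                         os.path.join(class_dst, file_name))
            copied += 1
    return copied
```

For example, `make_imagenet_subset('/path/to/ILSVRC2012/val', '/path/to/subset/val', samples_per_class=2)` would copy 2 images for each of the 1000 classes, 2000 files in total, assuming that layout.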

Edit the cell below to specify the directory where the downloaded ImageNet dataset is saved.

[ ]:
DATASET_DIR = '/path/to/dataset/'         # Replace this path with a real directory

1. Instantiate the example training and validation pipeline

Use the following validation pipeline for the image classification task.

Things to note:

  • AIMET does not put limitations on how the training and validation pipeline is written. AIMET creates a QuantizationSim model from the user's model, and the result is still a Keras model. The QuantizationSim model can be used in place of the original model when doing inference or training.

  • AIMET does not put limitations on the interface of the evaluate() or train() methods. You should be able to use your existing evaluate and train routines as-is.

[ ]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import tensorflow as tf
from aimet_tensorflow.keras.auto_quant import AutoQuant

from typing import Optional
from Examples.common import image_net_config
from Examples.tensorflow.utils.keras.image_net_dataset import ImageNetDataset
from Examples.tensorflow.utils.keras.image_net_evaluator import ImageNetEvaluator


class ImageNetDataPipeline:
    """
    Provides APIs for model evaluation and finetuning using ImageNet Dataset.
    """

    @staticmethod
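    def get_val_dataset(batch_size: Optional[int] = None) -> ImageNetDataset:
        """
        Instantiates a validation dataloader for ImageNet dataset and returns it
        :param batch_size: The batch size. If None, the evaluation batch size from image_net_config is used
        :return: An ImageNetDataset whose `dataset` attribute is the batched tf.data.Dataset
        """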
        if batch_size is None:
            batch_size = image_net_config.evaluation['batch_size']

        data_loader = ImageNetDataset(DATASET_DIR,
                                      image_size=image_net_config.dataset['image_size'],
                                      batch_size=batch_size)

        return data_loader

    @staticmethod
    def evaluate(model, iterations=None) -> float:
        """
        Given a Keras model, evaluates its Top-1 accuracy on the validation dataset
        :param model: The Keras model to be evaluated.
        :param iterations: The number of iterations to run. If None, all the data will be used
        :return: The Top-1 accuracy of the model on the evaluated data
        """
        evaluator = ImageNetEvaluator(DATASET_DIR,
                                      image_size=image_net_config.dataset["image_size"],
                                      batch_size=image_net_config.evaluation["batch_size"])

        return evaluator.evaluate(model=model, iterations=iterations)

2. Load a pretrained FP32 model

Load a pretrained ResNet50 model from Keras Applications.

You can load any pretrained Keras model instead.

[ ]:
from tensorflow.keras.applications.resnet import ResNet50

model = ResNet50(weights='imagenet')

3. Determine the baseline FP32 accuracy

Determine the floating point 32-bit (FP32) accuracy of this model using the evaluate() routine.

[ ]:
ImageNetDataPipeline.evaluate(model=model)

4. Define constants and helper functions

4.1 Define the following constants:

  • EVAL_DATASET_SIZE: The number of samples used to evaluate the model. A typical value is 5000. In this example, the value has been set to 50 for faster execution.

  • CALIBRATION_DATASET_SIZE: The number of samples used for calibration. A typical value is 2000. In this example, the value has been set to 20 for faster execution.

  • BATCH_SIZE: The batch size of the evaluation dataset. Set to 10 in this example.

[ ]:
EVAL_DATASET_SIZE = 50
CALIBRATION_DATASET_SIZE = 20
BATCH_SIZE = 10
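The evaluation dataset in this notebook is batched (each element holds BATCH_SIZE images), so `tf.data.Dataset.take(n)` selects n batches, not n samples. A sample count such as EVAL_DATASET_SIZE therefore has to be converted to a number of batches. A plain-Python sketch of that conversion; `num_batches_for` is a hypothetical helper:

```python
import math


def num_batches_for(num_samples: int, batch_size: int) -> int:
    """Number of whole batches needed to cover at least `num_samples` samples."""
    return math.ceil(num_samples / batch_size)


EVAL_DATASET_SIZE = 50
CALIBRATION_DATASET_SIZE = 20
BATCH_SIZE = 10

print(num_batches_for(EVAL_DATASET_SIZE, BATCH_SIZE))         # 5
print(num_batches_for(CALIBRATION_DATASET_SIZE, BATCH_SIZE))  # 2
```

Rounding up means a partial final batch still counts, so at least the requested number of samples is covered.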

4.2 Use the constants to create the evaluation dataset, and derive an unlabeled dataset from it by dropping the labels.

[ ]:
eval_dataset = ImageNetDataPipeline.get_val_dataset(BATCH_SIZE).dataset
unlabeled_dataset = eval_dataset.map(lambda images, labels: images)
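The `map(lambda images, labels: images)` call keeps only the image part of each (images, labels) element, so the resulting dataset carries no labels. The same transformation in plain Python, with made-up values and a hypothetical `drop_labels` generator standing in for the tf.data pipeline:

```python
from typing import Iterable, Iterator, List, Tuple

# Each element mirrors one (images, labels) batch from the labeled dataset.
LabeledBatch = Tuple[List[float], List[int]]


def drop_labels(batches: Iterable[LabeledBatch]) -> Iterator[List[float]]:
    """Yield only the images of each batch, like dataset.map(lambda images, labels: images)."""
    for images, _labels in batches:
        yield images


labeled = [([0.1, 0.2], [1, 0]), ([0.3, 0.4], [0, 1])]
print(list(drop_labels(labeled)))  # [[0.1, 0.2], [0.3, 0.4]]
```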

4.3 Prepare the evaluation callback function.

The eval_callback() function takes the model to evaluate and the number of samples to use as arguments, and returns the model's accuracy. If the num_samples argument is None, EVAL_DATASET_SIZE samples of the evaluation dataset are used.

[ ]:
import math
from typing import Optional


def eval_callback(model: tf.keras.Model,
                  num_samples: Optional[int] = None) -> float:
    if num_samples is None:
        num_samples = EVAL_DATASET_SIZE

    # eval_dataset is batched, so take() counts batches, not samples
    num_batches = math.ceil(num_samples / BATCH_SIZE)
    sampled_dataset = eval_dataset.take(num_batches)

    # Model should be compiled before evaluation
    model.compile(optimizer=tf.keras.optimizers.Adam(),
                  loss=tf.keras.losses.CategoricalCrossentropy(),
                  metrics=[tf.keras.metrics.CategoricalAccuracy()])
    # With one loss and one metric, evaluate() returns [loss, accuracy]
    _, acc = model.evaluate(sampled_dataset)

    return acc
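To make the callback's contract concrete, here is a framework-free sketch: a callable that takes a model and an optional sample count and returns Top-1 accuracy. The toy batches, `argmax_model`, and `make_eval_callback` are hypothetical stand-ins for the Keras model and tf.data.Dataset used in this notebook. Because the data is batched, the sample count is rounded up to whole batches.

```python
import math
from typing import Callable, List, Optional, Sequence, Tuple

# One batch = (inputs, labels); a toy stand-in for a batched tf.data.Dataset.
Batch = Tuple[List[List[float]], List[int]]
Model = Callable[[List[float]], int]


def make_eval_callback(batches: Sequence[Batch], batch_size: int,
                       default_num_samples: int) -> Callable[[Model, Optional[int]], float]:
    """Build a callback with the (model, num_samples) -> accuracy shape."""
    def eval_callback(model: Model, num_samples: Optional[int] = None) -> float:
        if num_samples is None:
            num_samples = default_num_samples
        num_batches = math.ceil(num_samples / batch_size)  # whole batches only
        correct = total = 0
        for inputs, labels in batches[:num_batches]:
            for x, y in zip(inputs, labels):
                correct += int(model(x) == y)  # Top-1: predicted class equals label
                total += 1
        return correct / total
    return eval_callback


def argmax_model(x: List[float]) -> int:
    """Toy 'model': predicts the index of the largest input value."""
    return max(range(len(x)), key=lambda i: x[i])


batches = [([[0.9, 0.1], [0.2, 0.8]], [0, 1]),
           ([[0.6, 0.4], [0.3, 0.7]], [1, 1])]
callback = make_eval_callback(batches, batch_size=2, default_num_samples=4)
print(callback(argmax_model))                 # 0.75
print(callback(argmax_model, num_samples=2))  # 1.0
```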

5. Apply AutoQuant

5.1 Create an AutoQuant object.

The AutoQuant feature uses an unlabeled dataset to quantize the model. This example passes unlabeled_dataset, which was created in step 4.2 by dropping the labels from the evaluation dataset.

The allowed_accuracy_drop parameter indicates the tolerable accuracy drop. AutoQuant applies a series of quantization features until the target accuracy (FP32 accuracy - allowed accuracy drop) is satisfied. When the target accuracy is reached, AutoQuant returns immediately without applying further PTQ techniques. See the AutoQuant User Guide and AutoQuant API documentation for details.
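The stopping rule can be sketched in plain Python. This is a simplified illustration, not AIMET's implementation: `apply_until_target` is a hypothetical helper, the technique names are illustrative, and each PTQ technique is modeled as a function that returns a new, made-up accuracy.

```python
from typing import Callable, List, Tuple

# Hypothetical stand-in for a PTQ technique: maps the current accuracy to a new one.
Technique = Tuple[str, Callable[[float], float]]


def apply_until_target(fp32_accuracy: float,
                       allowed_accuracy_drop: float,
                       quantized_accuracy: float,
                       techniques: List[Technique]) -> Tuple[float, List[str]]:
    """Apply techniques cumulatively until the target accuracy is met (simplified sketch)."""
    target_accuracy = fp32_accuracy - allowed_accuracy_drop
    accuracy = quantized_accuracy
    applied: List[str] = []
    for name, technique in techniques:
        if accuracy >= target_accuracy:
            break  # target satisfied: return without applying further techniques
        accuracy = technique(accuracy)  # each technique builds on the previous result
        applied.append(name)
    return accuracy, applied


# Made-up accuracies for illustration only; these are not measured results.
techniques: List[Technique] = [
    ("batchnorm_folding", lambda acc: acc + 0.03),
    ("cross_layer_equalization", lambda acc: acc + 0.03),
    ("adaround", lambda acc: acc + 0.03),
]
accuracy, applied = apply_until_target(fp32_accuracy=0.76,
                                       allowed_accuracy_drop=0.01,
                                       quantized_accuracy=0.70,
                                       techniques=techniques)
print(applied)  # ['batchnorm_folding', 'cross_layer_equalization']
```

With a target of 0.75, the third technique is never applied because the accuracy already meets the target after the second one.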

[ ]:
auto_quant = AutoQuant(allowed_accuracy_drop=0.01,
                       unlabeled_dataset=unlabeled_dataset,
                       eval_callback=eval_callback)

5.2 Set AdaRound Parameters (optional)

AutoQuant uses predefined default parameters for AdaRound. These values were determined empirically and work well with common models.

If necessary, you can use custom parameters for AdaRound. This example passes a custom AdaRound dataset and the corresponding number of batches.

[ ]:
from aimet_tensorflow.adaround.adaround_weight import AdaroundParameters

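ADAROUND_DATASET_SIZE = 2000
# unlabeled_dataset is batched, so take() counts batches, not samples
ADAROUND_NUM_BATCHES = ADAROUND_DATASET_SIZE // BATCH_SIZE
adaround_dataset = unlabeled_dataset.take(ADAROUND_NUM_BATCHES)
adaround_params = AdaroundParameters(adaround_dataset,
                                     num_batches=ADAROUND_NUM_BATCHES)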
auto_quant.set_adaround_params(adaround_params)

5.3 Run AutoQuant Optimization

This step runs AutoQuant optimization. AutoQuant returns the following:

  • The best possible quantized model

  • The corresponding evaluation score

  • The path to the encoding file

[ ]:
model, accuracy, encoding_path = auto_quant.apply(model)

For more information

See the AIMET API docs for details about the AIMET APIs and optional parameters.

See the other example notebooks to learn how to use other AIMET post-training quantization techniques.