AutoQuant¶
This notebook shows a working code example of how to use the AIMET AutoQuant feature.
AIMET offers a suite of neural network post-training quantization techniques. Applying these techniques in a specific sequence often results in better accuracy and performance. Without the AutoQuant feature, the AIMET user has to manually try out various combinations of AIMET quantization features; this manual process is error-prone and often time-consuming.
The AutoQuant feature analyzes the model, determines the sequence of AIMET quantization techniques to apply, and applies them. In addition, the user can specify, via the AutoQuant API, the amount of accuracy drop that can be tolerated. As soon as this threshold accuracy is reached, AutoQuant stops applying further quantization techniques. In summary, the AutoQuant feature saves time and automates the quantization of neural networks.
Overall flow¶
This notebook covers the following:
1. Instantiate the example evaluation and training pipeline
2. Load a pretrained FP32 model
3. Determine the baseline FP32 accuracy
4. Define constants and helper functions
5. Run AutoQuant
What this notebook is not¶
This notebook is not designed to show state-of-the-art AutoQuant results. For example, it uses a relatively quantization-friendly model like ResNet18. Also, some optimization parameters are deliberately chosen so that the notebook executes more quickly.
NOTE: This notebook is for auto_quant_v2.AutoQuant. For examples of auto_quant.AutoQuant (which will be deprecated), see autoquant_v1.ipynb.
Dataset¶
This notebook relies on the ImageNet dataset for the task of image classification. If you already have a version of the dataset readily available, please use that. Otherwise, please download the dataset from an appropriate location (e.g. https://image-net.org/challenges/LSVRC/2012/index.php#).
Note 1: The ImageNet dataset typically has the following characteristics, and the dataloader provided in this example notebook relies on them:
- Subfolders ‘train’ for the training samples and ‘val’ for the validation samples. Please see the PyTorch dataset description for more details.
- A subdirectory per class, and one file per image sample
Note 2: To speed up the execution of this notebook, you may use a reduced subset of the ImageNet dataset. For example, the entire ILSVRC2012 dataset has 1000 classes, 1000 training samples per class, and 50 validation samples per class. For the purpose of running this notebook, you could reduce the dataset to, say, 2 samples per class. This exercise is left up to the reader and is not necessary; a sketch of one way to do it follows.
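For example, a reduced copy of the validation split could be created with a small helper like the one sketched below. This helper is illustrative (make_reduced_subset is not part of AIMET) and assumes the standard ImageFolder layout of one subdirectory per class:
[ ]:
import os
import shutil

def make_reduced_subset(src_root, dst_root, samples_per_class=2):
    """Copy the first `samples_per_class` images of every class folder."""
    for class_name in sorted(os.listdir(src_root)):
        src_class = os.path.join(src_root, class_name)
        if not os.path.isdir(src_class):
            continue
        dst_class = os.path.join(dst_root, class_name)
        os.makedirs(dst_class, exist_ok=True)
        for filename in sorted(os.listdir(src_class))[:samples_per_class]:
            shutil.copy(os.path.join(src_class, filename), dst_class)

# e.g. make_reduced_subset('/path/to/imagenet/val', '/path/to/reduced/val')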
Edit the cell below and specify the directory where the downloaded ImageNet dataset is saved.
[ ]:
DATASET_DIR = '/path/to/dataset' # Please replace this with a real directory
1. Example evaluation and training pipeline¶
The following is an example training and validation loop for this image classification task.
Does AIMET have any limitations on how the training and validation pipeline is written? Not really. We will see later that AIMET will modify the user’s model to create a QuantizationSim model, which is still a PyTorch model. This QuantizationSim model can be used in place of the original model when doing inference or training.
Does AIMET put any limitation on the interface of the evaluate() or train() methods? Not really. You should be able to use your existing evaluate and train routines as-is.
[ ]:
import os
import sys
sys.path.append("../../..")
import torch
from Examples.common import image_net_config
from Examples.torch.utils.image_net_evaluator import ImageNetEvaluator
from Examples.torch.utils.image_net_trainer import ImageNetTrainer
from Examples.torch.utils.image_net_data_loader import ImageNetDataLoader
class ImageNetDataPipeline:
@staticmethod
def get_val_dataloader() -> torch.utils.data.DataLoader:
"""
Instantiates a validation dataloader for ImageNet dataset and returns it
"""
data_loader = ImageNetDataLoader(DATASET_DIR,
image_size=image_net_config.dataset['image_size'],
batch_size=image_net_config.evaluation['batch_size'],
is_training=False,
num_workers=image_net_config.evaluation['num_workers']).data_loader
return data_loader
@staticmethod
def evaluate(model: torch.nn.Module, use_cuda: bool) -> float:
"""
Given a torch model, evaluates its Top-1 accuracy on the dataset
:param model: the model to evaluate
:param use_cuda: whether or not the GPU should be used.
"""
evaluator = ImageNetEvaluator(DATASET_DIR, image_size=image_net_config.dataset['image_size'],
batch_size=image_net_config.evaluation['batch_size'],
num_workers=image_net_config.evaluation['num_workers'])
return evaluator.evaluate(model, iterations=None, use_cuda=use_cuda)
@staticmethod
def finetune(model: torch.nn.Module, epochs, learning_rate, learning_rate_schedule, use_cuda):
"""
Given a torch model, finetunes the model to improve its accuracy
:param model: the model to finetune
:param epochs: The number of epochs used during the finetuning step.
:param learning_rate: The learning rate used during the finetuning step.
:param learning_rate_schedule: The learning rate schedule used during the finetuning step.
:param use_cuda: whether or not the GPU should be used.
"""
trainer = ImageNetTrainer(DATASET_DIR, image_size=image_net_config.dataset['image_size'],
batch_size=image_net_config.train['batch_size'],
num_workers=image_net_config.train['num_workers'])
trainer.train(model, max_epochs=epochs, learning_rate=learning_rate,
learning_rate_schedule=learning_rate_schedule, use_cuda=use_cuda)
2. Load a pretrained FP32 model¶
For this example, we are going to load a pretrained ResNet18 model from torchvision. You can load any other pretrained PyTorch model instead.
[ ]:
from torchvision.models import resnet18
model = resnet18(pretrained=True).eval()
use_cuda = False
if torch.cuda.is_available():
use_cuda = True
model.to(torch.device('cuda'))
3. Determine the baseline FP32 accuracy¶
Let’s determine the FP32 (floating point 32-bit) accuracy of this model using the evaluate() routine.
We should decide whether to place the model on a CPU or CUDA device. This example code will use CUDA if available in your current execution environment. You can change this logic and force a device placement if needed.
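For instance, the cell below applies the device placement explicitly; set use_cuda to False beforehand to force CPU execution. This is a small illustrative snippet, not part of the original flow:
[ ]:
# Explicit device placement (re-applies the choice made when the model was loaded)
device = torch.device('cuda' if use_cuda else 'cpu')
model.to(device)
print(f"Model placed on: {device}")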
[ ]:
accuracy = ImageNetDataPipeline.evaluate(model, use_cuda)
print(accuracy)
4. Define constants and helper functions¶
In this section, the constants and helper functions needed to run this example are defined.
- EVAL_DATASET_SIZE: A typical value is 5000. To execute this example faster, this value has been set to 50.
- CALIBRATION_DATASET_SIZE: A typical value is 2000. To execute this example faster, this value has been set to 20.
- BATCH_SIZE: The user sets the batch size. As an example, it is set to 10.
The helper function _create_sampled_data_loader() returns a DataLoader based on the dataset and the number of samples provided.
[ ]:
import random
from torch.utils.data import DataLoader, SubsetRandomSampler

EVAL_DATASET_SIZE = 50
CALIBRATION_DATASET_SIZE = 20
BATCH_SIZE = 10

_subset_samplers = {}

def _create_sampled_data_loader(dataset, num_samples):
    # Cache one sampler per sample count so repeated calls reuse the same subset
    if num_samples not in _subset_samplers:
        indices = random.sample(range(len(dataset)), num_samples)
        _subset_samplers[num_samples] = SubsetRandomSampler(indices=indices)
    return DataLoader(dataset,
                      sampler=_subset_samplers[num_samples],
                      batch_size=BATCH_SIZE)
Prepare unlabeled dataset¶
The AutoQuant feature uses an unlabeled dataset to perform quantization. The UnlabeledDatasetWrapper class below creates an unlabeled Dataset object from a labeled Dataset.
[ ]:
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
class UnlabeledDatasetWrapper(Dataset):
def __init__(self, dataset):
self._dataset = dataset
def __len__(self):
return len(self._dataset)
def __getitem__(self, index):
images, _ = self._dataset[index]
return images
from torchvision import transforms
normalize = transforms.Normalize(mean=image_net_config.dataset['images_mean'],
std=image_net_config.dataset['images_std'])
image_size = image_net_config.dataset['image_size']
val_transforms = transforms.Compose([
transforms.CenterCrop(image_size),
transforms.ToTensor(),
normalize])
from Examples.torch.utils.image_net_data_loader import ImageFolder
imagenet_dataset = ImageFolder(root=os.path.join(DATASET_DIR, 'val'), transform=val_transforms)
unlabeled_imagenet_dataset = UnlabeledDatasetWrapper(imagenet_dataset)
unlabeled_imagenet_data_loader = _create_sampled_data_loader(unlabeled_imagenet_dataset, CALIBRATION_DATASET_SIZE)
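As a quick, optional sanity check, one batch can be drawn from the sampled loader to confirm that it yields image tensors only (no labels):
[ ]:
# Each batch should be a tensor of shape [BATCH_SIZE, 3, image_size, image_size]
images = next(iter(unlabeled_imagenet_data_loader))
print(images.shape)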
Prepare the evaluation callback function¶
The eval_callback() function takes the model to evaluate and the number of samples to use as arguments. If the num_samples argument is None, the whole evaluation dataset is meant to be used. For brevity, the simple callback below always evaluates on the whole evaluation dataset.
[ ]:
from typing import Optional
def eval_callback(model: torch.nn.Module, num_samples: Optional[int] = None) -> float:
    # num_samples is part of the expected callback signature; this simple
    # implementation ignores it and always evaluates on the whole validation set
    return ImageNetDataPipeline.evaluate(model, use_cuda)
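If you want a callback that actually honors num_samples, a variant along the following lines could work. This is a sketch, not the notebook's original code: it assumes that the evaluator's iterations argument (seen in ImageNetDataPipeline.evaluate above) limits evaluation to that many batches, and eval_callback_sampled is an illustrative name.
[ ]:
def eval_callback_sampled(model: torch.nn.Module, num_samples: Optional[int] = None) -> float:
    batch_size = image_net_config.evaluation['batch_size']
    evaluator = ImageNetEvaluator(DATASET_DIR,
                                  image_size=image_net_config.dataset['image_size'],
                                  batch_size=batch_size,
                                  num_workers=image_net_config.evaluation['num_workers'])
    # Translate the sample budget into a batch count; iterations=None uses the full set
    iterations = None if num_samples is None else max(1, num_samples // batch_size)
    return evaluator.evaluate(model, iterations=iterations, use_cuda=use_cuda)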
5. Run AutoQuant¶
This step creates the AutoQuant object, which is then used to run inference and optimization in the cells below.
[ ]:
from aimet_torch.auto_quant_v2 import AutoQuant
dummy_input = torch.randn((1, 3, 224, 224))
if use_cuda:
dummy_input = dummy_input.cuda()
auto_quant = AutoQuant(model,
dummy_input=dummy_input,
data_loader=unlabeled_imagenet_data_loader,
eval_callback=eval_callback)
Set AdaRound Parameters (optional)¶
The AutoQuant feature internally uses default parameters to execute the AdaRound step. Modify the default AdaRound parameters using the API shown below only if necessary.
Note: To execute this example faster, the default value of the num_iterations parameter has been reduced from 10000 to 2000
[ ]:
from aimet_torch.adaround.adaround_weight import AdaroundParameters
ADAROUND_DATASET_SIZE = 20
adaround_data_loader = _create_sampled_data_loader(unlabeled_imagenet_dataset, ADAROUND_DATASET_SIZE)
print(len(adaround_data_loader))  # number of batches available for AdaRound
adaround_params = AdaroundParameters(adaround_data_loader, num_batches=len(adaround_data_loader), default_num_iterations=2000)
auto_quant.set_adaround_params(adaround_params)
Run AutoQuant Inference¶
This step runs AutoQuant inference. AutoQuant inference runs evaluation using the eval_callback on the vanilla quantized model, without applying any PTQ techniques. This is useful for measuring the baseline evaluation score before running AutoQuant optimization.
[ ]:
sim, initial_accuracy = auto_quant.run_inference()
print(f"- Quantized Accuracy (before optimization): {initial_accuracy}")
Run AutoQuant Optimization¶
This step runs AutoQuant optimization, which returns the best possible quantized model, the corresponding evaluation score, and the path to the encoding file. The allowed_accuracy_drop parameter indicates the tolerable amount of accuracy drop. AutoQuant applies a series of quantization features until the target accuracy (FP32 accuracy - allowed accuracy drop) is satisfied. When the target accuracy is reached, AutoQuant returns immediately without applying further PTQ techniques. Please refer to the AutoQuant User Guide and API documentation for complete details.
[ ]:
model, optimized_accuracy, encoding_path = auto_quant.optimize(allowed_accuracy_drop=0.01)
print(f"- Quantized Accuracy (after optimization): {optimized_accuracy}")
Summary¶
We hope this notebook was useful for understanding how to use the AIMET AutoQuant feature.
A few additional resources:
- Refer to the AIMET API docs for more details on the APIs and their parameters.
- Refer to the other example notebooks to learn how to use the AIMET CLE and AdaRound features in a standalone fashion.