AIMET PyTorch AutoQuant API¶
Examples Notebook Link¶
For an end-to-end notebook showing how to use PyTorch AutoQuant, please see here.
Code Examples¶
Required imports
import random
from typing import Optional
import torch
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
from torchvision import models, datasets, transforms
from aimet_torch.adaround.adaround_weight import AdaroundParameters
from aimet_torch.auto_quant import AutoQuant
Define constants and helper functions
EVAL_DATASET_SIZE = 5000
CALIBRATION_DATASET_SIZE = 2000
BATCH_SIZE = 100
_subset_samplers = {}
def _create_sampled_data_loader(dataset, num_samples):
if num_samples not in _subset_samplers:
indices = random.sample(range(len(dataset)), num_samples)
_subset_samplers[num_samples] = SubsetRandomSampler(indices=indices)
return DataLoader(dataset,
sampler=_subset_samplers[num_samples],
batch_size=BATCH_SIZE)
Prepare model and dataset
# Pretrained FP32 ResNet-18, switched to evaluation mode.
fp32_model = models.resnet18(pretrained=True)
fp32_model.eval()

# Dummy input: a random batch containing one 3 x 224 x 224 tensor.
input_shape = (1, 3, 224, 224)
dummy_input = torch.randn(*input_shape)

transform = transforms.Compose((transforms.ToTensor(),))

# NOTE: In actual use cases, a real dataset should be provided by the user.
eval_dataset = datasets.FakeData(
    size=EVAL_DATASET_SIZE,
    image_size=input_shape[1:],
    num_classes=1000,
    transform=transform,
)
Prepare unlabeled dataset
# NOTE: In actual use cases, users should implement this part themselves
# to suit their own needs.
class UnlabeledDatasetWrapper(Dataset):
    """Dataset view that drops the label of every wrapped item.

    Each item of the wrapped dataset must unpack into exactly two values;
    only the first one is returned.
    """

    def __init__(self, dataset):
        """Store the labeled ``dataset`` to be wrapped."""
        self._dataset = dataset

    def __len__(self):
        """Return the number of items in the wrapped dataset."""
        return len(self._dataset)

    def __getitem__(self, index):
        """Return the first element of the wrapped item at ``index``."""
        sample = self._dataset[index]
        images, _unused_label = sample
        return images
# Strip the labels from the evaluation dataset, then build a loader over a
# random subset of CALIBRATION_DATASET_SIZE samples.
unlabeled_dataset = UnlabeledDatasetWrapper(dataset=eval_dataset)
unlabeled_data_loader = _create_sampled_data_loader(
    dataset=unlabeled_dataset,
    num_samples=CALIBRATION_DATASET_SIZE,
)
Prepare eval callback
# NOTE: In actual use cases, users should implement this part themselves
# to suit their own needs.
def eval_callback(model: torch.nn.Module, num_samples: Optional[int] = None) -> float:
    """Return the top-1 accuracy of ``model`` on a random subset of ``eval_dataset``.

    :param model: Model to evaluate. Input batches are moved to CUDA before
        the forward pass, so the model is expected to be on the GPU.
    :param num_samples: Number of evaluation samples to use. ``None`` means
        the whole evaluation dataset.
    :return: Number of correct predictions divided by ``num_samples``.
    """
    if num_samples is None:
        num_samples = len(eval_dataset)

    eval_data_loader = _create_sampled_data_loader(eval_dataset, num_samples)

    num_correct_predictions = 0
    # Evaluation needs no gradients; disabling autograd avoids building a
    # graph for every forward pass and reduces memory use.
    with torch.no_grad():
        for images, labels in eval_data_loader:
            predictions = torch.argmax(model(images.cuda()), dim=1)
            # Compare on the CPU because the labels are never moved to the GPU.
            num_correct_predictions += int(torch.sum(predictions.cpu() == labels))

    return num_correct_predictions / num_samples
Create AutoQuant object
# Configure AutoQuant with an allowed accuracy drop of 0.01, the unlabeled
# data loader, and the evaluation callback.
auto_quant = AutoQuant(
    allowed_accuracy_drop=0.01,
    unlabeled_dataset_iterable=unlabeled_data_loader,
    eval_callback=eval_callback,
)
(Optional) Set Adaround parameters
Use the following guideline to set the num_batches parameter. The number of batches is used to evaluate the model while calculating the quantization encodings. Typically we want AdaRound to use around 2000 samples in total, i.e. batch size × num_batches ≈ 2000. For example, if the batch size is 32, num_batches should be about 64 (32 × 64 = 2048). If the batch size you are using is different, adjust num_batches accordingly.
# Number of unlabeled samples made available to AdaRound.
ADAROUND_DATASET_SIZE = 2000

adaround_data_loader = _create_sampled_data_loader(unlabeled_dataset, ADAROUND_DATASET_SIZE)

# Use every batch the loader yields.
num_adaround_batches = len(adaround_data_loader)
adaround_params = AdaroundParameters(adaround_data_loader,
                                     num_batches=num_adaround_batches)
auto_quant.set_adaround_params(adaround_params)
Run AutoQuant
# Run AutoQuant on the GPU copy of the FP32 model and unpack the three
# returned values.
model, accuracy, encoding_path = auto_quant.apply(
    fp32_model.cuda(),
    dummy_input_on_cpu=dummy_input.cpu(),
    dummy_input_on_gpu=dummy_input.cuda(),
)