AIMET PyTorch AutoQuant API
User Guide Link
To learn more about this technique, please see AutoQuant
Examples Notebook Link
For an end-to-end notebook showing how to use PyTorch AutoQuant, please see here.
Top-level API
Note
This module is also available in the experimental aimet_torch.v2 namespace with the same top-level API. To learn more about the differences between aimet_torch and aimet_torch.v2, please visit the QuantSim v2 Overview.
- class aimet_torch.auto_quant.AutoQuant(model, dummy_input, data_loader, eval_callback, param_bw=8, output_bw=8, quant_scheme=QuantScheme.post_training_tf_enhanced, rounding_mode='nearest', config_file=None, results_dir='/tmp', cache_id=None, strict_validation=True, model_prepare_required=True)[source]
Integrate and apply post-training quantization techniques.
AutoQuant includes 1) batchnorm folding, 2) cross-layer equalization, and 3) Adaround. These techniques will be applied in a best-effort manner until the model meets the evaluation goal given as allowed_accuracy_drop.
- Parameters:
model (Module) – Model to be quantized. Assumes model is on the correct device.
dummy_input (Union[Tensor, Tuple]) – Dummy input for the model. Assumes that dummy_input is on the correct device.
data_loader (DataLoader) – A collection that iterates over an unlabeled dataset, used for computing encodings.
eval_callback (Callable[[Module], float]) – Function that calculates the evaluation score.
param_bw (int) – Parameter bitwidth.
output_bw (int) – Output bitwidth.
quant_scheme (QuantScheme) – Quantization scheme.
rounding_mode (str) – Rounding mode.
config_file (Optional[str]) – Path to configuration file for model quantizers.
results_dir (str) – Directory to save the results of PTQ techniques.
cache_id (Optional[str]) – ID associated with cache results.
strict_validation (bool) – Flag set to True by default. When False, AutoQuant will proceed with execution and handle errors internally if possible. This may produce suboptimal or unintuitive results.
model_prepare_required (bool) – Flag set to True by default. If False, AutoQuant will skip the model prepare block in the pipeline.
Code Examples
import random
from typing import Optional
import torch
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
from torchvision import models, datasets, transforms
from aimet_torch.adaround.adaround_weight import AdaroundParameters
from aimet_torch.auto_quant_v2 import AutoQuant
# Step 1. Define constants and helper functions
EVAL_DATASET_SIZE = 5000
CALIBRATION_DATASET_SIZE = 2000
BATCH_SIZE = 100

# Cache of samplers, so repeated requests for the same subset size reuse the same
# random indices (e.g. repeated eval_callback calls evaluate on a consistent subset).
_subset_samplers = {}


def _create_sampled_data_loader(dataset, num_samples):
    """Return a DataLoader over a random subset of ``num_samples`` items of ``dataset``.

    Sampled indices are cached per (dataset length, num_samples). Later calls with
    the same sizes reuse the same subset. A dataset of a different length never
    receives indices drawn from another dataset's range.

    :param dataset: Map-style dataset supporting ``len()`` and integer indexing.
    :param num_samples: Number of items to sample. It must not exceed ``len(dataset)``;
        ``random.sample`` raises ``ValueError`` otherwise.
    :return: DataLoader yielding batches of ``BATCH_SIZE`` items from the subset.
    """
    # Include the dataset length in the key: keying on num_samples alone would let a
    # smaller dataset reuse out-of-range indices sampled for a larger one.
    cache_key = (len(dataset), num_samples)
    if cache_key not in _subset_samplers:
        indices = random.sample(range(len(dataset)), num_samples)
        _subset_samplers[cache_key] = SubsetRandomSampler(indices=indices)
    return DataLoader(dataset,
                      sampler=_subset_samplers[cache_key],
                      batch_size=BATCH_SIZE)
# Step 2. Prepare model and dataset
fp32_model = models.resnet18(pretrained=True).eval()
input_shape = (1, 3, 224, 224)
# Random tensor used only to trace the model; its values do not matter.
dummy_input = torch.randn(input_shape)
transform = transforms.Compose((
    transforms.ToTensor(),
))
# NOTE: In actual use cases, users should provide a real dataset.
# image_size drops the leading dimension of input_shape (the batch size of the
# dummy input), so each sample has the same per-item shape as dummy_input.
eval_dataset = datasets.FakeData(size=EVAL_DATASET_SIZE,
                                 image_size=input_shape[1:],
                                 num_classes=1000,
                                 transform=transform)
# Step 3. Prepare unlabeled dataset
# NOTE: In actual use cases, users should implement this part to suit
#       their own goals if necessary.
class UnlabeledDatasetWrapper(Dataset):
    """Dataset view that strips the labels from a labeled dataset.

    Each item of the wrapped dataset must unpack into exactly two values
    ``(input, label)``; only the input is returned.
    """

    def __init__(self, dataset):
        # Wrapped labeled dataset; indexing is delegated to it.
        self._dataset = dataset

    def __len__(self):
        return len(self._dataset)

    def __getitem__(self, index):
        sample, _label = self._dataset[index]
        return sample
# Calibration data: the evaluation images without labels, subsampled to
# CALIBRATION_DATASET_SIZE items.
unlabeled_dataset = UnlabeledDatasetWrapper(eval_dataset)
unlabeled_data_loader = _create_sampled_data_loader(
    unlabeled_dataset,
    num_samples=CALIBRATION_DATASET_SIZE,
)
# Step 4. Prepare eval callback
# NOTE: In actual use cases, users should implement this part to suit
#       their own goals if necessary.
def eval_callback(model: torch.nn.Module, num_samples: Optional[int] = None) -> float:
    """Return the top-1 accuracy of ``model`` on a random subset of ``eval_dataset``.

    :param model: Model to evaluate.
    :param num_samples: Number of samples to evaluate on; defaults to the whole
        ``eval_dataset``.
    :return: Fraction of correctly classified samples.
    """
    if num_samples is None:
        num_samples = len(eval_dataset)
    eval_data_loader = _create_sampled_data_loader(eval_dataset, num_samples)

    # Send inputs to whichever device the model's parameters live on. Fall back to
    # CUDA (where this example places the model) for a model without parameters.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cuda")

    num_correct_predictions = 0
    # Evaluation only: disable autograd so no computation graph is built per batch.
    with torch.no_grad():
        for images, labels in eval_data_loader:
            predictions = torch.argmax(model(images.to(device)), dim=1)
            num_correct_predictions += torch.sum(predictions.cpu() == labels)
    return int(num_correct_predictions) / num_samples
# Step 5. Create AutoQuant object
# The model and the dummy input are both moved to the GPU before being handed
# to AutoQuant, which expects them on the correct device.
fp32_model_on_gpu = fp32_model.cuda()
dummy_input_on_gpu = dummy_input.cuda()
auto_quant = AutoQuant(fp32_model_on_gpu, dummy_input_on_gpu, unlabeled_data_loader, eval_callback)
# Step 6. (Optional) Set adaround params
ADAROUND_DATASET_SIZE = 2000
adaround_data_loader = _create_sampled_data_loader(unlabeled_dataset, ADAROUND_DATASET_SIZE)
# Use every batch of the sampled AdaRound subset.
num_adaround_batches = len(adaround_data_loader)
adaround_params = AdaroundParameters(adaround_data_loader, num_batches=num_adaround_batches)
auto_quant.set_adaround_params(adaround_params)
# Step 7. Run AutoQuant
# Score the quantized model before any optimization is applied.
inference_result = auto_quant.run_inference()
sim, initial_accuracy = inference_result
# Apply the PTQ techniques (best effort) until the goal set by allowed_accuracy_drop is met.
optimize_result = auto_quant.optimize(allowed_accuracy_drop=0.01)
model, optimized_accuracy, encoding_path = optimize_result
print(f"- Quantized Accuracy (before optimization): {initial_accuracy:.4f}")
print(f"- Quantized Accuracy (after optimization): {optimized_accuracy:.4f}")
Note
To use auto_quant.AutoQuant (which will be deprecated), apply the following code changes to steps 5 and 7.
# Step 5. Create AutoQuant object
# The deprecated API takes the accuracy target at construction time
# (the current API passes it to optimize() instead).
auto_quant = AutoQuant(
    eval_callback=eval_callback,
    unlabeled_dataset_iterable=unlabeled_data_loader,
    allowed_accuracy_drop=0.01,
)
# Step 6. (Optional) Set adaround params
ADAROUND_DATASET_SIZE = 2000
adaround_data_loader = _create_sampled_data_loader(unlabeled_dataset, ADAROUND_DATASET_SIZE)
# Use every batch of the sampled AdaRound subset.
num_adaround_batches = len(adaround_data_loader)
adaround_params = AdaroundParameters(adaround_data_loader, num_batches=num_adaround_batches)
auto_quant.set_adaround_params(adaround_params)
# Step 7. Run AutoQuant
# The deprecated API receives the model and the dummy input (one copy on the CPU,
# one on the GPU) in apply().
# The unpacked accuracy is named optimized_accuracy so that the print below refers
# to a defined name (previously it was unpacked as `accuracy`, causing a NameError).
model, optimized_accuracy, encoding_path = auto_quant.apply(
    fp32_model.cuda(),
    dummy_input_on_cpu=dummy_input.cpu(),
    dummy_input_on_gpu=dummy_input.cuda(),
)
print(f"- Quantized Accuracy (after optimization): {optimized_accuracy:.4f}")