AIMET PyTorch AutoQuant API
Top-level API
- class aimet_torch.auto_quant.AutoQuant(allowed_accuracy_drop, unlabeled_dataset_iterable, eval_callback, default_param_bw=8, default_output_bw=8, default_quant_scheme=<QuantScheme.post_training_tf_enhanced: 2>, default_rounding_mode='nearest', default_config_file=None)
Integrate and apply post-training quantization techniques.
AutoQuant includes 1) batchnorm folding, 2) cross-layer equalization, and 3) Adaround. These techniques are applied in a best-effort manner until the model meets the evaluation goal set by allowed_accuracy_drop.
- Parameters
  - allowed_accuracy_drop (float) – Maximum allowed accuracy drop.
  - unlabeled_dataset_iterable (Union[DataLoader[+T_co], Collection[+T_co]]) – A collection (i.e., an iterable with __len__) that iterates over an unlabeled dataset used for encoding computation. Each value yielded by this iterable must be directly passable to the model. By default, this iterable is also used for Adaround unless otherwise specified by self.set_adaround_params.
  - eval_callback (Callable[[Module, Optional[int]], float]) – A function that maps a model and a number of samples to an evaluation score. This callback is expected to return a scalar value representing the model performance evaluated against exactly N samples, where N is the number of samples passed as the second argument of this callback. NOTE: If N is None, the model is expected to be evaluated against the whole evaluation dataset.
  - default_param_bw (int) – Default bitwidth (4-31) to use for quantizing layer parameters.
  - default_output_bw (int) – Default bitwidth (4-31) to use for quantizing layer inputs and outputs.
  - default_quant_scheme (QuantScheme) – Quantization scheme. Supported values are QuantScheme.post_training_tf or QuantScheme.post_training_tf_enhanced.
  - default_rounding_mode (str) – Rounding mode. Supported options are 'nearest' or 'stochastic'.
  - default_config_file (Optional[str]) – Path to the configuration file for model quantizers.
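To make default_param_bw, default_output_bw and default_rounding_mode more concrete, here is a minimal, library-free sketch of simulated uniform quantization (quantize, then dequantize back to float) using a plain min/max range, roughly in the spirit of QuantScheme.post_training_tf. It is only an illustration under those assumptions, not AIMET's implementation: the function quantize_dequantize is hypothetical, and AIMET's quantizers operate on tensors and compute their encodings with their own logic.

```python
import math
import random
from typing import List, Optional


def quantize_dequantize(values: List[float], bitwidth: int = 8,
                        rounding: str = "nearest",
                        rng: Optional[random.Random] = None) -> List[float]:
    """Hypothetical sketch: simulate uniform quantization over a min/max range."""
    if not 4 <= bitwidth <= 31:
        raise ValueError("bitwidth must be in the range 4-31")
    if not values:
        return []
    # Assumption for this sketch: stretch the range so that it always contains 0.0.
    x_min = min(min(values), 0.0)
    x_max = max(max(values), 0.0)
    num_steps = 2 ** bitwidth - 1              # e.g. 255 steps for 8 bits
    scale = (x_max - x_min) / num_steps if x_max > x_min else 1.0
    rng = rng or random.Random()

    result = []
    for x in values:
        level = (x - x_min) / scale            # position on the integer grid
        if rounding == "nearest":
            q = round(level)
        elif rounding == "stochastic":
            # Round up with probability equal to the fractional part of level.
            q = math.floor(level + rng.random())
        else:
            raise ValueError("rounding must be 'nearest' or 'stochastic'")
        q = min(max(q, 0), num_steps)          # clamp to the representable grid
        result.append(q * scale + x_min)       # dequantize back to float
    return result
```

With nearest rounding, each value moves by at most half a grid step; stochastic rounding can move a value by up to one full step but is unbiased on average. Lowering the bitwidth from 8 to 4 shrinks the grid from 256 to 16 representable levels.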
- apply(fp32_model, dummy_input_on_cpu, dummy_input_on_gpu=None, results_dir='/tmp', cache_id=None)
Apply post-training quantization techniques.
- Parameters
  - fp32_model (Module) – Model to which the PTQ techniques are applied.
  - dummy_input_on_cpu (Union[Tensor, Tuple]) – Dummy input to the model in CPU memory.
  - dummy_input_on_gpu (Union[Tensor, Tuple, None]) – Dummy input to the model in GPU memory. This parameter is required if and only if fp32_model is on GPU.
  - results_dir (str) – Directory to save the results.
  - cache_id (Optional[str]) – A string that, combined with results_dir, forms a cache ID. If specified, AutoQuant will load/save the PTQ results from/to the file system if previous PTQ results produced under the same results_dir and cache_id exist.
- Return type
  Tuple[Module, float, str]
- Returns
  Tuple of (best model, eval score, encoding path front).
- Raises
  ValueError if the model is on GPU and dummy_input_on_gpu is not specified.
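Because dummy_input_on_gpu is required exactly when fp32_model lives on the GPU, one option is to derive both dummy inputs from the model's device. The sketch below is a hypothetical helper (make_dummy_inputs is not part of AIMET); it assumes the model has at least one parameter and that all parameters share a single device.

```python
from typing import Optional, Tuple

import torch


def make_dummy_inputs(model: torch.nn.Module,
                      input_shape: Tuple[int, ...]
                      ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Hypothetical helper: build the CPU/GPU dummy inputs expected by apply()."""
    dummy_input_on_cpu = torch.randn(input_shape)
    # Assumption: the first parameter's device is the device of the whole model.
    model_device = next(model.parameters()).device
    if model_device.type == "cuda":
        # Same values, copied to the GPU that holds the model.
        return dummy_input_on_cpu, dummy_input_on_cpu.to(model_device)
    # For a CPU model, dummy_input_on_gpu keeps its default value of None.
    return dummy_input_on_cpu, None
```

The two returned values can then be passed to apply() as dummy_input_on_cpu and dummy_input_on_gpu, so the GPU input is supplied only when the model is on GPU.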
- set_adaround_params(adaround_params)
Set Adaround parameters. If this method is not called explicitly by the user, AutoQuant will use unlabeled_dataset_iterable (passed to __init__) for Adaround.
- Parameters
  - adaround_params (AdaroundParameters) – Adaround parameters.
- Return type
  None
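The best-effort strategy described for AutoQuant can be pictured with the following library-free sketch. It is not AIMET's implementation: run_best_effort_ptq, the dict standing in for a model, and the toy techniques are all hypothetical. The sketch assumes that the evaluation goal means "quantized score is at least the FP32 score minus allowed_accuracy_drop", and that techniques are applied cumulatively in the listed order with an early exit once the goal is met.

```python
from typing import Callable, Dict, List, Tuple

Model = Dict[str, float]  # Stand-in "model" for this sketch: a dict of settings.
Technique = Tuple[str, Callable[[Model], Model]]


def run_best_effort_ptq(model: Model,
                        techniques: List[Technique],
                        evaluate: Callable[[Model], float],
                        fp32_score: float,
                        allowed_accuracy_drop: float) -> Tuple[Model, float, List[str]]:
    """Apply techniques in order, stopping as soon as the target score is met.

    evaluate() is assumed to score the quantized version of the model.
    """
    target = fp32_score - allowed_accuracy_drop
    best_model, best_score = model, evaluate(model)
    applied: List[str] = []
    for name, technique in techniques:
        if best_score >= target:
            break  # Goal reached: skip the remaining (more expensive) steps.
        model = technique(model)
        applied.append(name)
        score = evaluate(model)
        if score > best_score:
            best_model, best_score = model, score
    return best_model, best_score, applied


# Toy example: each hypothetical "technique" adds a fixed gain to the model's accuracy.
def gain(amount: float) -> Callable[[Model], Model]:
    return lambda m: {**m, "acc": m["acc"] + amount}


techniques = [("batchnorm_folding", gain(0.02)),
              ("cross_layer_equalization", gain(0.03)),
              ("adaround", gain(0.01))]
model, score, applied = run_best_effort_ptq({"acc": 0.65}, techniques,
                                            evaluate=lambda m: m["acc"],
                                            fp32_score=0.70,
                                            allowed_accuracy_drop=0.01)
print(applied)  # prints ['batchnorm_folding', 'cross_layer_equalization']
```

In the toy run the target is 0.69, which is reached after cross-layer equalization, so the (typically most expensive) AdaRound step is skipped.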
Code Examples
Required imports
```python
import random
from typing import Optional

import torch
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
from torchvision import models, datasets, transforms

from aimet_torch.adaround.adaround_weight import AdaroundParameters
from aimet_torch.auto_quant import AutoQuant
```
Define constants and helper functions
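```python
EVAL_DATASET_SIZE = 5000
CALIBRATION_DATASET_SIZE = 2000
BATCH_SIZE = 100

# Samplers cached by sample count, so repeated calls with the same num_samples
# always draw from the same random subset. (All datasets in this example have
# the same length, so sharing a sampler between them is safe.)
_subset_samplers = {}

def _create_sampled_data_loader(dataset, num_samples):
    if num_samples not in _subset_samplers:
        indices = random.sample(range(len(dataset)), num_samples)
        _subset_samplers[num_samples] = SubsetRandomSampler(indices=indices)
    return DataLoader(dataset,
                      sampler=_subset_samplers[num_samples],
                      batch_size=BATCH_SIZE)
```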
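The sampler cache above is keyed only by num_samples, which works here because every dataset has EVAL_DATASET_SIZE samples. If datasets of different lengths were mixed, a cached index list could point past the end of a shorter dataset. A stdlib-only sketch of a variant keyed by (dataset length, num_samples) follows; cached_subset_indices is a hypothetical name, not part of AIMET or PyTorch.

```python
import random
from typing import Dict, List, Tuple

# Cache of random index subsets, keyed by (dataset length, number of samples).
_index_cache: Dict[Tuple[int, int], List[int]] = {}


def cached_subset_indices(dataset_len: int, num_samples: int) -> List[int]:
    """Hypothetical helper: return the same random subset for repeated requests."""
    key = (dataset_len, num_samples)  # Including the length avoids out-of-range indices.
    if key not in _index_cache:
        _index_cache[key] = random.sample(range(dataset_len), num_samples)
    return _index_cache[key]
```

The returned list could then be passed as SubsetRandomSampler(indices=...), exactly as in the helper above.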
Prepare model and dataset
```python
fp32_model = models.resnet18(pretrained=True).eval()
input_shape = (1, 3, 224, 224)
dummy_input = torch.randn(input_shape)

transform = transforms.Compose((
    transforms.ToTensor(),
))

# NOTE: In actual use cases, a real dataset should be provided by the user.
eval_dataset = datasets.FakeData(size=EVAL_DATASET_SIZE,
                                 image_size=input_shape[1:],
                                 num_classes=1000,
                                 transform=transform)
```
Prepare unlabeled dataset
```python
# NOTE: In actual use cases, users should implement this part to suit
# their own goals if necessary.
class UnlabeledDatasetWrapper(Dataset):
    """Wraps a labeled dataset and yields only the images, dropping the labels."""

    def __init__(self, dataset):
        self._dataset = dataset

    def __len__(self):
        return len(self._dataset)

    def __getitem__(self, index):
        images, _ = self._dataset[index]
        return images

unlabeled_dataset = UnlabeledDatasetWrapper(eval_dataset)
unlabeled_data_loader = _create_sampled_data_loader(unlabeled_dataset, CALIBRATION_DATASET_SIZE)
```
Prepare eval callback
```python
# NOTE: In actual use cases, users should implement this part to suit
# their own goals if necessary.
def eval_callback(model: torch.nn.Module, num_samples: Optional[int] = None) -> float:
    # None means "evaluate against the whole evaluation dataset".
    if num_samples is None:
        num_samples = len(eval_dataset)

    eval_data_loader = _create_sampled_data_loader(eval_dataset, num_samples)
    device = next(model.parameters()).device  # send inputs to the model's device

    num_correct_predictions = 0
    with torch.no_grad():  # inference only; no gradients needed
        for images, labels in eval_data_loader:
            predictions = torch.argmax(model(images.to(device)), dim=1)
            num_correct_predictions += torch.sum(predictions.cpu() == labels)

    # Top-1 accuracy over exactly num_samples samples.
    return int(num_correct_predictions) / num_samples
```
Create AutoQuant object
```python
auto_quant = AutoQuant(allowed_accuracy_drop=0.01,
                       unlabeled_dataset_iterable=unlabeled_data_loader,
                       eval_callback=eval_callback)
```
(Optional) Set Adaround parameters
Use the following guideline to set the num_batches parameter. num_batches is the number of data batches that AdaRound runs through the model while computing the quantization encodings. Typically, AdaRound should use around 2000 samples. For example, with a batch size of 32, num_batches is 64 (64 × 32 = 2048 samples). If you use a different batch size, adjust num_batches accordingly.
```python
ADAROUND_DATASET_SIZE = 2000
adaround_data_loader = _create_sampled_data_loader(unlabeled_dataset, ADAROUND_DATASET_SIZE)
# With BATCH_SIZE = 100, this loader yields 2000 / 100 = 20 batches.
adaround_params = AdaroundParameters(adaround_data_loader, num_batches=len(adaround_data_loader))
auto_quant.set_adaround_params(adaround_params)
```
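The num_batches guideline above amounts to dividing the desired sample count by the batch size. A minimal sketch, assuming you round up to the next whole batch (adaround_num_batches is a hypothetical helper, not an AIMET function):

```python
import math


def adaround_num_batches(target_samples: int, batch_size: int) -> int:
    """Smallest number of batches whose total sample count reaches target_samples."""
    if target_samples <= 0 or batch_size <= 0:
        raise ValueError("target_samples and batch_size must be positive")
    return math.ceil(target_samples / batch_size)


print(adaround_num_batches(2000, 32))   # 63 batches -> 2016 samples
print(adaround_num_batches(2000, 100))  # 20 batches, same as len(adaround_data_loader) above
```

Rounding up gives 63 batches (2016 samples) for a batch size of 32. The 64 batches (2048 samples) quoted in the guideline are an equally valid "around 2000" choice.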
Run AutoQuant
```python
model, accuracy, encoding_path = \
    auto_quant.apply(fp32_model.cuda(),
                     dummy_input_on_cpu=dummy_input.cpu(),
                     dummy_input_on_gpu=dummy_input.cuda())
```