AIMET PyTorch AutoQuant API

Top-level API

class aimet_torch.auto_quant_v2.AutoQuant(model, dummy_input, data_loader, eval_callback, param_bw=8, output_bw=8, quant_scheme=<QuantScheme.post_training_tf_enhanced: 2>, rounding_mode='nearest', config_file=None, results_dir='/tmp', cache_id=None, strict_validation=True)

Integrate and apply post-training quantization techniques.

AutoQuant includes 1) batchnorm folding, 2) cross-layer equalization, and 3) Adaround. These techniques are applied in a best-effort manner until the model meets the evaluation goal given by allowed_accuracy_drop, which is passed to optimize().
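The best-effort control flow can be pictured with the following simplified sketch. It is not AIMET's implementation: the helper best_effort_ptq, the Technique type, and the assumptions that the goal is fp32_score - allowed_accuracy_drop and that techniques are applied cumulatively are all illustrative.

```python
from typing import Callable, List, Tuple

# Hypothetical types: a "technique" is a (name, function) pair that returns a
# transformed model; eval_fn maps a model to a score. These are not AIMET APIs.
Technique = Tuple[str, Callable[[object], object]]


def best_effort_ptq(model: object,
                    techniques: List[Technique],
                    eval_fn: Callable[[object], float],
                    fp32_score: float,
                    allowed_accuracy_drop: float) -> Tuple[object, float, List[str]]:
    """Apply techniques one after another until the accuracy target is met.

    Assumes the target is fp32_score - allowed_accuracy_drop and that each
    technique is applied on top of the previous ones.
    """
    target = fp32_score - allowed_accuracy_drop
    current = model
    best_model, best_score = current, eval_fn(current)
    applied: List[str] = []
    for name, apply_technique in techniques:
        if best_score >= target:
            break  # goal met: skip the remaining techniques
        current = apply_technique(current)
        applied.append(name)
        score = eval_fn(current)
        if score > best_score:  # keep the best model seen so far
            best_model, best_score = current, score
    return best_model, best_score, applied


# Toy usage: the "model" is just an integer accuracy in percent.
if __name__ == "__main__":
    steps = [("batchnorm_folding", lambda m: m + 2),
             ("cross_layer_equalization", lambda m: m + 3),
             ("adaround", lambda m: m + 4)]
    print(best_effort_ptq(70, steps, lambda m: m, fp32_score=76, allowed_accuracy_drop=1))
```

In the toy run, the target of 75 is reached after the second technique, so the third one is never applied.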

Parameters
  • model (Module) – Model to be quantized. Assumes the model is already on the correct device

  • dummy_input (Union[Tensor, Tuple]) – Dummy input for the model. Assumes dummy_input is already on the correct device

  • data_loader (DataLoader[+T_co]) – A collection that iterates over an unlabeled dataset, used for computing encodings

  • eval_callback (Callable[[Module], float]) – Function that calculates the evaluation score

  • param_bw (int) – Parameter bitwidth

  • output_bw (int) – Output bitwidth

  • quant_scheme (QuantScheme) – Quantization scheme

  • rounding_mode (str) – Rounding mode

  • config_file (Optional[str]) – Path to configuration file for model quantizers

  • results_dir (str) – Directory to save the results of PTQ techniques

  • cache_id (Optional[str]) – ID associated with cache results

  • strict_validation (bool) – Flag set to True by default. When False, AutoQuant will proceed with execution and handle errors internally if possible. This may produce suboptimal or unintuitive results.
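To make param_bw, output_bw and rounding_mode more concrete, here is a minimal quantize-dequantize sketch in plain Python. It assumes a simple asymmetric min/max range (similar in spirit to QuantScheme.post_training_tf; post_training_tf_enhanced chooses the range differently), and the function name quantize_dequantize is hypothetical, not part of AIMET.

```python
import math
import random
from typing import List, Optional


def quantize_dequantize(values: List[float],
                        bitwidth: int = 8,
                        rounding_mode: str = "nearest",
                        rng: Optional[random.Random] = None) -> List[float]:
    """Round values onto a 2**bitwidth-level grid and map them back to floats."""
    num_steps = 2 ** bitwidth - 1            # 255 steps between 256 levels at 8 bits
    lo = min(min(values), 0.0)               # keep 0.0 exactly representable
    hi = max(max(values), 0.0)
    scale = (hi - lo) / num_steps if hi > lo else 1.0
    offset = round(lo / scale)               # integer grid index that represents lo
    rng = rng or random.Random(0)
    result = []
    for v in values:
        x = v / scale
        if rounding_mode == "nearest":
            q = round(x)
        elif rounding_mode == "stochastic":
            # Round up with probability equal to the fractional part of x.
            floor_x = math.floor(x)
            q = floor_x + (1 if rng.random() < x - floor_x else 0)
        else:
            raise ValueError(f"unsupported rounding mode: {rounding_mode!r}")
        q = min(max(q, offset), offset + num_steps)  # clamp to the 2**bitwidth levels
        result.append(q * scale)
    return result
```

With nearest rounding, each value moves by at most half a grid step, so a lower bitwidth (a coarser grid) directly increases the worst-case error.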

run_inference()

Creates a quantization simulation model and performs inference on it to measure its accuracy.

Return type

Tuple[QuantizationSimModel, float]

Returns

QuantizationSimModel, model accuracy as float

optimize(allowed_accuracy_drop=0.0)

Integrate and apply post-training quantization techniques.

Parameters

allowed_accuracy_drop (float) – Maximum allowed accuracy drop

Return type

Tuple[Module, float, str]

Returns

Tuple of (best model, eval score, encoding path)

set_adaround_params(adaround_params)

Set Adaround parameters. If this method is not called explicitly by the user, AutoQuant will use data_loader (passed to __init__) for Adaround.

Parameters

adaround_params (AdaroundParameters) – Adaround parameters.

Return type

None

set_export_params(onnx_export_args=-1, propagate_encodings=None)

Set parameters for QuantizationSimModel.export.

Parameters
  • onnx_export_args (OnnxExportApiArgs) – Optional export arguments with ONNX-specific overrides. If not provided, export is done via the TorchScript graph.

  • propagate_encodings (Optional[bool]) – If True, encoding entries for intermediate ops (when one PyTorch op results in multiple ONNX nodes) are filled with the same BW and data_type as the output tensor for that series of ops.

Return type

None
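A rough sketch of what propagate_encodings=True means, assuming a hypothetical encoding-entry layout with bitwidth and dtype keys and a hypothetical helper fill_intermediate_encodings; AIMET's actual export code and encoding file format may differ.

```python
from typing import Dict, List

Encoding = Dict[str, object]  # hypothetical entry, e.g. {"bitwidth": 8, "dtype": "int"}


def fill_intermediate_encodings(encodings: Dict[str, Encoding],
                                series_outputs: List[str]) -> Dict[str, Encoding]:
    """Give intermediate ONNX tensors the bitwidth/dtype of the series' final output.

    series_outputs lists, in order, the output tensor names of the ONNX nodes that
    one PyTorch op was exported as; the last name is the op's real output.
    """
    final = encodings[series_outputs[-1]]
    filled = dict(encodings)  # leave the caller's dict untouched
    for name in series_outputs[:-1]:
        if name not in filled:
            filled[name] = {"bitwidth": final["bitwidth"], "dtype": final["dtype"]}
    return filled
```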

set_model_preparer_params(modules_to_exclude=None, module_classes_to_exclude=None, concrete_args=None)

Set parameters for model preparer.

Parameters
  • modules_to_exclude (Optional[List[Module]]) – List of modules to exclude when tracing.

  • module_classes_to_exclude (Optional[List[Module]]) – List of module classes to exclude when tracing.

  • concrete_args (Optional[Dict[str, Any]]) – Parameter for the model preparer. Allows you to partially specialize your function, for example to remove control flow or data structures. If the model has control flow, torch.fx won’t be able to trace it. See the torch.fx.symbolic_trace API for details.

get_quant_scheme_candidates()

Return the candidates for quant scheme search. During optimize(), the candidate with the highest accuracy will be selected among them.

Return type

Tuple[_QuantSchemePair, …]

Returns

Candidates for quant scheme search

set_quant_scheme_candidates(candidates)

Set candidates for quant scheme search. During optimize(), the candidate with the highest accuracy will be selected among them.

Parameters

candidates (Tuple[_QuantSchemePair, …]) – Candidates for quant scheme search
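The selection rule for quant scheme candidates (keep the candidate with the highest accuracy) can be sketched as follows. SchemePair is a hypothetical, simplified stand-in for the private _QuantSchemePair type, whose real fields may differ, and select_best_candidate is not an AIMET function.

```python
from typing import Callable, NamedTuple, Sequence, Tuple


class SchemePair(NamedTuple):
    """Hypothetical stand-in for _QuantSchemePair (param scheme, output scheme)."""
    param_scheme: str
    output_scheme: str


def select_best_candidate(candidates: Sequence[SchemePair],
                          evaluate: Callable[[SchemePair], float]) -> Tuple[SchemePair, float]:
    """Evaluate every candidate and return the one with the highest score."""
    if not candidates:
        raise ValueError("at least one candidate is required")
    scored = [(evaluate(c), i, c) for i, c in enumerate(candidates)]
    # Highest score wins; on a tie, the earlier candidate is kept.
    best_score, _, best = max(scored, key=lambda t: (t[0], -t[1]))
    return best, best_score
```

Breaking ties by position is a design choice of this sketch; it makes the result deterministic when two schemes score the same.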

class aimet_torch.auto_quant.AutoQuant(allowed_accuracy_drop, unlabeled_dataset_iterable, eval_callback, default_param_bw=8, default_output_bw=8, default_quant_scheme=<QuantScheme.post_training_tf_enhanced: 2>, default_rounding_mode='nearest', default_config_file=None)

Warning

auto_quant.AutoQuant is deprecated and will be replaced by auto_quant_v2.AutoQuant in later versions.

Integrate and apply post-training quantization techniques.

AutoQuant includes 1) batchnorm folding, 2) cross-layer equalization, and 3) Adaround. These techniques will be applied in a best-effort manner until the model meets the evaluation goal given as allowed_accuracy_drop.

Parameters
  • allowed_accuracy_drop (float) – Maximum allowed accuracy drop.

  • unlabeled_dataset_iterable (Union[DataLoader[+T_co], Collection[+T_co]]) – A collection (i.e. iterable with __len__) that iterates over an unlabeled dataset used for encoding computation. The values yielded by this iterable are expected to be passed directly to the model. By default, this iterable is also used for Adaround unless otherwise specified by self.set_adaround_params.

  • eval_callback (Callable[[Module, Optional[int]], float]) – A function that maps a model and a number of samples to an evaluation score. This callback is expected to return a scalar value representing the model performance evaluated against exactly N samples, where N is the number of samples passed as the second argument of this callback. NOTE: If N is None, the model is expected to be evaluated against the whole evaluation dataset.

  • default_param_bw (int) – Default bitwidth (4-31) to use for quantizing layer parameters.

  • default_output_bw (int) – Default bitwidth (4-31) to use for quantizing layer inputs and outputs.

  • default_quant_scheme (QuantScheme) – Quantization scheme. Supported values are QuantScheme.post_training_tf or QuantScheme.post_training_tf_enhanced.

  • default_rounding_mode (str) – Rounding mode. Supported options are ‘nearest’ or ‘stochastic’.

  • default_config_file (Optional[str]) – Path to configuration file for model quantizers
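The two-argument eval_callback contract of this deprecated API can be sketched in plain Python with a toy "model" that maps a float input to a class label. The names make_eval_callback, Sample and Model are hypothetical; a real callback would run a PyTorch model over a DataLoader.

```python
from typing import Callable, Optional, Sequence, Tuple

Sample = Tuple[float, int]          # hypothetical (input, label) pair
Model = Callable[[float], int]      # hypothetical model: input -> predicted label


def make_eval_callback(dataset: Sequence[Sample]) -> Callable[[Model, Optional[int]], float]:
    """Build a callback following the (model, num_samples) -> score contract."""
    def eval_callback(model: Model, num_samples: Optional[int] = None) -> float:
        # None means "evaluate against the whole evaluation dataset".
        n = len(dataset) if num_samples is None else num_samples
        if not 0 < n <= len(dataset):
            raise ValueError(f"num_samples must be in [1, {len(dataset)}], got {n}")
        subset = dataset[:n]  # evaluate against exactly n samples
        correct = sum(1 for x, label in subset if model(x) == label)
        return correct / n
    return eval_callback
```

Taking the first N samples keeps the subset fixed across calls, so scores for the same N stay comparable.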

apply(fp32_model, dummy_input_on_cpu, dummy_input_on_gpu=None, results_dir='/tmp', cache_id=None)

Apply post-training quantization techniques.

Parameters
  • fp32_model (Module) – Model to apply PTQ techniques.

  • dummy_input_on_cpu (Union[Tensor, Tuple]) – Dummy input to the model in CPU memory.

  • dummy_input_on_gpu (Union[Tensor, Tuple, None]) – Dummy input to the model in GPU memory. This parameter is required if and only if the fp32_model is on GPU.

  • results_dir (str) – Directory to save the results.

  • cache_id (Optional[str]) – A string that composes a cache ID in combination with results_dir. If specified, AutoQuant loads previous PTQ results from the file system when results produced under the same results_dir and cache_id exist, and saves new results there otherwise.

Return type

Tuple[Module, float, str]

Returns

Tuple of (best model, eval score, encoding path front).

Raises
  • ValueError if the model is on GPU and dummy_input_on_gpu is not specified.
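The results_dir/cache_id behavior of apply() amounts to a load-or-compute cache. The sketch below assumes a hypothetical on-disk layout (<results_dir>/.cache/<cache_id>/<step_name>.pkl) and a hypothetical helper cached_step; AIMET's real cache format is not described in this reference.

```python
import os
import pickle
from typing import Any, Callable, Optional


def cached_step(results_dir: str, cache_id: Optional[str], step_name: str,
                compute: Callable[[], Any]) -> Any:
    """Return a cached result for step_name, computing and saving it if absent."""
    if cache_id is None:
        return compute()                      # caching disabled
    # Hypothetical layout: <results_dir>/.cache/<cache_id>/<step_name>.pkl
    cache_dir = os.path.join(results_dir, ".cache", cache_id)
    path = os.path.join(cache_dir, f"{step_name}.pkl")
    if os.path.exists(path):
        with open(path, "rb") as f:
            return pickle.load(f)             # reuse an earlier run's result
    result = compute()
    os.makedirs(cache_dir, exist_ok=True)
    with open(path, "wb") as f:
        pickle.dump(result, f)
    return result
```

Because the key combines results_dir and cache_id, changing either one starts from an empty cache.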

set_adaround_params(adaround_params)

Set Adaround parameters. If this method is not called explicitly by the user, AutoQuant will use unlabeled_dataset_iterable (passed to __init__) for Adaround.

Parameters

adaround_params (AdaroundParameters) – Adaround parameters.

Return type

None

set_export_params(onnx_export_args=-1, propagate_encodings=None)

Set parameters for QuantizationSimModel.export.

Parameters
  • onnx_export_args (OnnxExportApiArgs) – Optional export arguments with ONNX-specific overrides. If not provided, export is done via the TorchScript graph.

  • propagate_encodings (Optional[bool]) – If True, encoding entries for intermediate ops (when one PyTorch op results in multiple ONNX nodes) are filled with the same BW and data_type as the output tensor for that series of ops.

Return type

None

Code Examples

import random
from typing import Optional

import torch
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
from torchvision import models, datasets, transforms

from aimet_torch.adaround.adaround_weight import AdaroundParameters
from aimet_torch.auto_quant_v2 import AutoQuant

# Step 1. Define constants and helper functions
EVAL_DATASET_SIZE = 5000
CALIBRATION_DATASET_SIZE = 2000
BATCH_SIZE = 100

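# Cache one SubsetRandomSampler per sample count (shared across datasets), so
# repeated calls with the same num_samples draw from the same subset of indices.
_subset_samplers = {}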

def _create_sampled_data_loader(dataset, num_samples):
    if num_samples not in _subset_samplers:
        indices = random.sample(range(len(dataset)), num_samples)
        _subset_samplers[num_samples] = SubsetRandomSampler(indices=indices)
    return DataLoader(dataset,
                      sampler=_subset_samplers[num_samples],
                      batch_size=BATCH_SIZE)

# Step 2. Prepare model and dataset
fp32_model = models.resnet18(pretrained=True).eval()

input_shape = (1, 3, 224, 224)
dummy_input = torch.randn(input_shape)

transform = transforms.Compose((
    transforms.ToTensor(),
))
# NOTE: In actual use cases, a real dataset should be provided by the user.
eval_dataset = datasets.FakeData(size=EVAL_DATASET_SIZE,
                                 image_size=input_shape[1:],
                                 num_classes=1000,
                                 transform=transform)

# Step 3. Prepare unlabeled dataset
# NOTE: In actual use cases, users should implement this part to suit
#       their own needs.
class UnlabeledDatasetWrapper(Dataset):
    def __init__(self, dataset):
        self._dataset = dataset

    def __len__(self):
        return len(self._dataset)

    def __getitem__(self, index):
        images, _ = self._dataset[index]
        return images

unlabeled_dataset = UnlabeledDatasetWrapper(eval_dataset)
unlabeled_data_loader = _create_sampled_data_loader(unlabeled_dataset, CALIBRATION_DATASET_SIZE)

# Step 4. Prepare eval callback
# NOTE: In actual use cases, users should implement this part to suit
#       their own needs.
def eval_callback(model: torch.nn.Module, num_samples: Optional[int] = None) -> float:
    if num_samples is None:
        num_samples = len(eval_dataset)

    eval_data_loader = _create_sampled_data_loader(eval_dataset, num_samples)

    num_correct_predictions = 0
    # Disable gradient tracking: evaluation only needs forward passes.
    with torch.no_grad():
        for images, labels in eval_data_loader:
            predictions = torch.argmax(model(images.cuda()), dim=1)
            num_correct_predictions += torch.sum(predictions.cpu() == labels)

    return int(num_correct_predictions) / num_samples

# Step 5. Create AutoQuant object
auto_quant = AutoQuant(fp32_model.cuda(),
                       dummy_input.cuda(),
                       unlabeled_data_loader,
                       eval_callback)

# Step 6. (Optional) Set adaround params
ADAROUND_DATASET_SIZE = 2000
adaround_data_loader = _create_sampled_data_loader(unlabeled_dataset, ADAROUND_DATASET_SIZE)
adaround_params = AdaroundParameters(adaround_data_loader, num_batches=len(adaround_data_loader))
auto_quant.set_adaround_params(adaround_params)

# Step 7. Run AutoQuant
sim, initial_accuracy = auto_quant.run_inference()
model, optimized_accuracy, encoding_path = auto_quant.optimize(allowed_accuracy_drop=0.01)

print(f"- Quantized Accuracy (before optimization): {initial_accuracy:.4f}")
print(f"- Quantized Accuracy (after optimization):  {optimized_accuracy:.4f}")

Note

To use the deprecated auto_quant.AutoQuant instead, apply the following changes to steps 5 and 7 (step 6 is unchanged).

# Step 5. Create AutoQuant object
from aimet_torch.auto_quant import AutoQuant

auto_quant = AutoQuant(allowed_accuracy_drop=0.01,
                       unlabeled_dataset_iterable=unlabeled_data_loader,
                       eval_callback=eval_callback)

# Step 6. (Optional) Set adaround params
ADAROUND_DATASET_SIZE = 2000
adaround_data_loader = _create_sampled_data_loader(unlabeled_dataset, ADAROUND_DATASET_SIZE)
adaround_params = AdaroundParameters(adaround_data_loader, num_batches=len(adaround_data_loader))
auto_quant.set_adaround_params(adaround_params)

# Step 7. Run AutoQuant
model, accuracy, encoding_path =\
    auto_quant.apply(fp32_model.cuda(),
                     dummy_input_on_cpu=dummy_input.cpu(),
                     dummy_input_on_gpu=dummy_input.cuda())

print(f"- Quantized Accuracy (after optimization):  {accuracy:.4f}")