Automatic quantization

Context

The AIMET toolkit offers a suite of neural network post-training quantization (PTQ) techniques. Applying these techniques in a specific sequence often results in better quantized accuracy and performance.

The Automatic quantization (AutoQuant) feature analyzes your pre-trained model, determines the best sequence of AIMET PTQ techniques, and applies those techniques. In the AutoQuant API, you specify the maximum accuracy drop you can tolerate; as soon as the quantized accuracy reaches this threshold, AutoQuant stops applying PTQ techniques.

Without the AutoQuant feature, you must manually try combinations of AIMET quantization techniques. This manual process is error-prone and time-consuming.

Workflow

The workflow looks like this:

(Figure: AutoQuant workflow flowchart, auto_quant_v2_flowchart.png)

Before entering the optimization workflow, AutoQuant prepares by:

  1. Checking the validity of the model and converting the model into an AIMET quantization-friendly format (Prepare Model).

  2. Selecting the best-performing quantization scheme for the given model (QuantScheme Selection).

After the preparation steps, AutoQuant proceeds to try the following techniques:

  1. BatchNorm folding

  2. Cross-layer equalization (CLE)

  3. Adaptive rounding (Adaround) (if enabled)

  4. Automatic Mixed Precision (AMP) (if enabled)

These techniques are applied in a best-effort manner until the model meets the allowed accuracy drop. If AutoQuant cannot satisfy the evaluation goal, it returns the model that produced the best results, as sketched below.
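The control flow can be summarized in pseudocode. The following is an illustrative sketch of the documented behavior, not AIMET's actual implementation; the technique names are placeholders for the steps listed above:

# Illustrative sketch only; technique names are placeholders, not AIMET APIs
def auto_quant_sketch(model, eval_callback, fp32_accuracy, allowed_accuracy_drop):
    target_accuracy = fp32_accuracy - allowed_accuracy_drop
    best_model, best_accuracy = model, float('-inf')
    for apply_technique in (batchnorm_folding, cross_layer_equalization,
                            adaround, automatic_mixed_precision):
        model = apply_technique(model)       # techniques are applied cumulatively
        accuracy = eval_callback(model)      # evaluate the quantized candidate
        if accuracy > best_accuracy:
            best_model, best_accuracy = model, accuracy
        if best_accuracy >= target_accuracy:
            break                            # accuracy goal met; stop early
    # Best-effort: if the goal was never met, return the best candidate seen
    return best_model, best_accuracy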

Code example

Step 1

Load the model for automatic quantization. In this PyTorch example, we use MobileNetV2.


import math

import torch
from datasets import load_dataset
from torchvision import transforms
from torchvision.models import MobileNet_V2_Weights, mobilenet_v2

device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).eval().to(device)

Load the model for automatic quantization. In this TensorFlow example, we use Keras MobileNetV2.

import tensorflow as tf
from aimet_common.defs import QuantizationDataType, QuantScheme
from aimet_common.quantsim_config.utils import get_path_for_per_channel_config
from aimet_tensorflow.keras.adaround_weight import AdaroundParameters
from aimet_tensorflow.keras.auto_quant_v2 import AutoQuantWithAutoMixedPrecision
from tensorflow.keras import applications, losses, metrics, preprocessing
from tensorflow.keras.applications import mobilenet_v2

model = applications.MobileNetV2()

Load the model for automatic quantization. In this ONNX example, we convert PyTorch MobileNetV2 to ONNX and use it in the subsequent code.

import math
import os
from typing import Optional

import numpy as np
import onnx
import onnxruntime as ort
import onnxsim
import torch
from aimet_common.defs import QuantizationDataType
from aimet_common.quantsim_config.utils import get_path_for_per_channel_config
from aimet_onnx.adaround.adaround_weight import AdaroundParameters
from aimet_onnx.auto_quant_v2 import AutoQuantWithAutoMixedPrecision
from aimet_onnx.defs import DataLoader
from datasets import load_dataset
from torchvision import transforms
from torchvision.models import MobileNet_V2_Weights, mobilenet_v2
from tqdm import tqdm

pt_model = mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT)
print(pt_model)

# Shape for each ImageNet sample is (3 channels) x (224 height) x (224 width)
input_shape = (1, 3, 224, 224)
dummy_input = torch.randn(input_shape)

# Modify file_path as needed; a temporary directory is used here
file_path = os.path.join('/tmp', 'mobilenet_v2.onnx')
torch.onnx.export(
    pt_model,
    (dummy_input,),
    file_path,
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={
        'input': {0: 'batch_size'},
        'output': {0: 'batch_size'},
    },
)
# Load exported ONNX model
model = onnx.load_model(file_path)

Step 2

Prepare dataset and dataloader (PyTorch)

dataset = load_dataset(
    'ILSVRC/imagenet-1k',
    split='validation',
)

preprocess = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# Use a distinct name so the torchvision `transforms` module is not shadowed
def preprocess_examples(examples):
    examples['image'] = [
        preprocess(image.convert('RGB')) for image in examples['image']
    ]
    return examples


dataset.set_transform(preprocess_examples)

BATCH_SIZE = 1
EVAL_DATASET_SIZE = 64
CALIBRATION_DATASET_SIZE = 32

class CustomDataLoader:
    def __init__(self, data, batch_size: int, iterations: int, unlabeled: bool = True):
        self._data = data
        self.batch_size = batch_size
        self.iterations = iterations
        self._unlabeled = unlabeled
        self._batch_index = 0

    def __iter__(self):
        self._batch_index = 0
        return self

    def __len__(self):
        return self.iterations

    def __next__(self):
        if self._batch_index >= self.iterations:
            raise StopIteration
        start_idx = self._batch_index * self.batch_size
        end_idx = start_idx + self.batch_size
        self._batch_index += 1

        batch = self._data[start_idx:end_idx]
        images = torch.stack(batch['image']).to(device)
        if self._unlabeled:
            return images
        return images, torch.tensor(batch['label'])


unlabeled_data_loader = CustomDataLoader(
    dataset, BATCH_SIZE, math.ceil(CALIBRATION_DATASET_SIZE / BATCH_SIZE)
)
dummy_input = torch.randn(1, 3, 224, 224).to(device)

Prepare dataset (TensorFlow)

BATCH_SIZE = 32
imagenet_dataset = preprocessing.image_dataset_from_directory(
    directory='<your_imagenet_validation_data_path>',
    label_mode='categorical',
    image_size=(224, 224),
    batch_size=BATCH_SIZE,
    shuffle=True,
)

imagenet_dataset = imagenet_dataset.map(
    lambda x, y: (mobilenet_v2.preprocess_input(x), y)
)

NUM_CALIBRATION_SAMPLES = 2048
unlabeled_dataset = imagenet_dataset.take(NUM_CALIBRATION_SAMPLES // BATCH_SIZE).map(
    lambda x, _: x
)
eval_dataset = imagenet_dataset.skip(NUM_CALIBRATION_SAMPLES // BATCH_SIZE)

Prepare model and dataloader (ONNX)

try:
    model, _ = onnxsim.simplify(model)
except Exception:
    print('ONNX Simplifier failed. Proceeding with unsimplified model')


dataset = load_dataset(
    'ILSVRC/imagenet-1k',
    split='validation',
).shuffle()


class CustomDataLoader(DataLoader):
    def __init__(
        self,
        data: np.ndarray,
        batch_size: int,
        iterations: int,
        unlabeled: bool = True,
    ):
        super().__init__(data, batch_size, iterations)
        self._current_iteration = 0
        self._unlabeled = unlabeled

    def __iter__(self):
        self._current_iteration = 0
        return self

    def __next__(self):
        if self._current_iteration < self.iterations:
            start = self._current_iteration * self.batch_size
            end = start + self.batch_size
            self._current_iteration += 1

            batch_data = self._data[start:end]
            if self._unlabeled:
                return np.stack(batch_data['image'])
            else:
                return np.stack(batch_data['image']), np.stack(batch_data['label'])
        else:
            raise StopIteration


preprocess = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)


# Use a distinct name so the torchvision `transforms` module is not shadowed
def preprocess_examples(examples):
    examples['image'] = [
        preprocess(image.convert('RGB')) for image in examples['image']
    ]
    return examples


dataset.set_transform(preprocess_examples)

BATCH_SIZE = 32
NUM_CALIBRATION_SAMPLES = 2048
NUM_EVAL_SAMPLES = 50000
unlabeled_data_loader = CustomDataLoader(
    dataset, BATCH_SIZE, math.ceil(NUM_CALIBRATION_SAMPLES / BATCH_SIZE)
)
eval_data_loader = CustomDataLoader(
    dataset, BATCH_SIZE, math.ceil(NUM_EVAL_SAMPLES / BATCH_SIZE), unlabeled=False
)

Step 3

Prepare eval callback (PyTorch)

In actual use cases, implement this part to serve your own goals while maintaining the function signature.

from typing import Optional

def eval_callback(model: torch.nn.Module, num_of_samples: Optional[int] = None) -> float:
    data_loader = CustomDataLoader(
        dataset, BATCH_SIZE, math.ceil(EVAL_DATASET_SIZE / BATCH_SIZE), unlabeled=False
    )
    if num_of_samples:
        iterations = math.ceil(num_of_samples / data_loader.batch_size)
    else:
        iterations = len(data_loader)

    correct_predictions = 0
    total_samples = 0
    with torch.no_grad():
        for batch_index, (images, labels) in enumerate(data_loader):
            if batch_index >= iterations:
                break
            pred_probs = model(images)
            pred_labels = torch.argmax(pred_probs, dim=1)
            correct_predictions += int(torch.sum(pred_labels.cpu() == labels))
            total_samples += labels.numel()
    return correct_predictions / total_samples

Prepare eval callback (TensorFlow)

def eval_callback(model: tf.keras.Model, _) -> float:
    # Model should be compiled before evaluation
    model.compile(
        loss=losses.CategoricalCrossentropy(), metrics=[metrics.CategoricalAccuracy()]
    )
    _, acc = model.evaluate(eval_dataset)

    return acc

Prepare eval callback (ONNX)

def eval_callback(
    session: ort.InferenceSession, num_of_samples: Optional[int] = None
) -> float:
    correct_predictions = 0
    total_samples = 0
    input_name = session.get_inputs()[0].name
    for inputs, labels in tqdm(eval_data_loader):
        pred_probs, *_ = session.run(None, {input_name: inputs})
        pred_labels = np.argmax(pred_probs, axis=1)
        correct_predictions += np.sum(pred_labels == labels)
        total_samples += labels.shape[0]
        # Stop early once the requested number of samples has been evaluated
        if num_of_samples is not None and total_samples >= num_of_samples:
            break

    accuracy = correct_predictions / total_samples
    return accuracy

Step 4

Create AutoQuant object (PyTorch)

from aimet_torch.auto_quant import AutoQuantWithAutoMixedPrecision

auto_quant = AutoQuantWithAutoMixedPrecision(
    model, dummy_input, unlabeled_data_loader, eval_callback
)

Create AutoQuant object (TensorFlow)

auto_quant = AutoQuantWithAutoMixedPrecision(
    model,
    eval_callback,
    unlabeled_dataset,
    param_bw=4,
    output_bw=8,
    quant_scheme=QuantScheme.post_training_tf,
    config_file=get_path_for_per_channel_config(),
)

Create AutoQuant object (ONNX)

dummy_input = {'input': np.random.randn(*input_shape).astype(np.float32)}
auto_quant = AutoQuantWithAutoMixedPrecision(
    model,
    dummy_input,
    unlabeled_data_loader,
    eval_callback,
    param_bw=4,
    output_bw=8,
    config_file=get_path_for_per_channel_config(),
)

Step 5

Set AdaRound params (PyTorch)

from aimet_torch.adaround.adaround_weight import AdaroundParameters

ADAROUND_DATASET_SIZE = 128
adaround_data_loader = CustomDataLoader(
    dataset, BATCH_SIZE, math.ceil(ADAROUND_DATASET_SIZE / BATCH_SIZE)
)
adaround_params = AdaroundParameters(
    adaround_data_loader, num_batches=len(adaround_data_loader)
)
auto_quant.set_adaround_params(adaround_params)

Set AdaRound params (TensorFlow)

adaround_params = AdaroundParameters(
    unlabeled_dataset, num_batches=NUM_CALIBRATION_SAMPLES // BATCH_SIZE
)
auto_quant.set_adaround_params(adaround_params)

Set AdaRound params (ONNX)

adaround_params = AdaroundParameters(
    unlabeled_data_loader, num_batches=len(unlabeled_data_loader)
)
auto_quant.set_adaround_params(adaround_params)

Step 6

Set AMP params (PyTorch)

from aimet_common.defs import QuantizationDataType

W8A8 = (
    (8, QuantizationDataType.int),  # A: int8
    (8, QuantizationDataType.int),  # W: int8
)
W8A16 = (
    (16, QuantizationDataType.int),  # A: int16
    (8, QuantizationDataType.int),  # W: int8
)
auto_quant.set_mixed_precision_params(candidates=[W8A16, W8A8])

Set AMP params (TensorFlow)

W4A8 = (
    (8, QuantizationDataType.int),  # A: int8
    (4, QuantizationDataType.int),  # W: int4
)
W8A8 = (
    (8, QuantizationDataType.int),  # A: int8
    (8, QuantizationDataType.int),  # W: int8
)
auto_quant.set_mixed_precision_params(candidates=[W4A8, W8A8])

Set AMP params (ONNX)

W4A8 = (
    (8, QuantizationDataType.int),  # A: int8
    (4, QuantizationDataType.int),  # W: int4
)
W8A8 = (
    (8, QuantizationDataType.int),  # A: int8
    (8, QuantizationDataType.int),  # W: int8
)
auto_quant.set_mixed_precision_params(candidates=[W4A8, W8A8])

Step 7

Run AutoQuant (PyTorch)

sim, initial_accuracy = auto_quant.run_inference()
model, optimized_accuracy, encoding_path, pareto_front = auto_quant.optimize(allowed_accuracy_drop=0.01)

print(f"- Quantized Accuracy (before optimization): {initial_accuracy:.4f}")
print(f"- Quantized Accuracy (after optimization):  {optimized_accuracy:.4f}")

Run AutoQuant (TensorFlow)

sim, initial_accuracy = auto_quant.run_inference()
model, optimized_accuracy, encoding_path, pareto_front = auto_quant.optimize(
    allowed_accuracy_drop=0.01
)

print(f'- Quantized Accuracy (before optimization): {initial_accuracy:.4f}')
print(f'- Quantized Accuracy (after optimization):  {optimized_accuracy:.4f}')

Example output:

- Quantized Accuracy (before optimization): 0.0235
- Quantized Accuracy (after optimization):  0.7164

Run AutoQuant (ONNX)

sim, initial_accuracy = auto_quant.run_inference()
model, optimized_accuracy, encoding_path, pareto_front = auto_quant.optimize(
    allowed_accuracy_drop=0.01
)

print(f'- Quantized Accuracy (before optimization): {initial_accuracy:.4f}')
print(f'- Quantized Accuracy (after optimization):  {optimized_accuracy:.4f}')

Example output:

- Quantized Accuracy (before optimization): 0.0235
- Quantized Accuracy (after optimization):  0.7164

API

Top-level API (PyTorch)

class aimet_torch.auto_quant.AutoQuantWithAutoMixedPrecision(model, dummy_input, data_loader, eval_callback, param_bw=8, output_bw=8, quant_scheme=QuantScheme.post_training_tf_enhanced, rounding_mode='nearest', config_file=None, results_dir='/tmp', cache_id=None, strict_validation=True, model_prepare_required=True)[source]

Integrate and apply post-training quantization techniques.

AutoQuant includes 1) batchnorm folding, 2) cross-layer equalization, 3) Adaround, and 4) Automatic Mixed Precision (if enabled). These techniques will be applied in a best-effort manner until the model meets the evaluation goal given as allowed_accuracy_drop.

Parameters:
  • model (Module) – Model to be quantized. Assumes model is on the correct device

  • dummy_input (Union[Tensor, Tuple]) – Dummy input for the model. Assumes that dummy_input is on the correct device

  • data_loader (DataLoader) – A collection that iterates over an unlabeled dataset, used for computing encodings

  • eval_callback (Callable[[Module], float]) – Function that calculates the evaluation score

  • param_bw (int) – Parameter bitwidth

  • output_bw (int) – Output bitwidth

  • quant_scheme (QuantScheme) – Quantization scheme

  • rounding_mode (str) – Rounding mode

  • config_file (Optional[str]) – Path to configuration file for model quantizers

  • results_dir (str) – Directory to save the results of PTQ techniques

  • cache_id (Optional[str]) – ID associated with cache results

  • strict_validation (bool) – Flag set to True by default. When False, AutoQuant will proceed with execution and handle errors internally if possible. This may produce suboptimal or unintuitive results.

  • model_prepare_required (bool) – Flag set to True by default. If False, AutoQuant will skip the model preparer block in the pipeline.

run_inference()[source]

Creates a quantization model and performs inference

Return type:

Tuple[QuantizationSimModel, float]

Returns:

QuantizationSimModel, model accuracy as float

optimize(allowed_accuracy_drop=0.0)[source]

Integrate and apply post-training quantization techniques.

Parameters:

allowed_accuracy_drop (float) – Maximum allowed accuracy drop

Return type:

Tuple[Module, float, str, List[Tuple[int, float, QuantizerGroup, Tuple]]]

Returns:

Tuple of (best model, eval score, encoding path, pareto front). Pareto front is None if AMP is not enabled or AutoQuant exits without performing AMP.
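
For example, the returned tuple can be unpacked and the pareto front inspected as follows. This is a sketch; the 0.01 allowed accuracy drop is arbitrary, and each pareto-front entry follows the type signature above:

model, accuracy, encoding_path, pareto_front = auto_quant.optimize(
    allowed_accuracy_drop=0.01
)
if pareto_front is not None:  # None unless AMP actually ran
    for relative_bit_ops, eval_score, quantizer_group, candidate in pareto_front:
        print(relative_bit_ops, eval_score, quantizer_group, candidate)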

set_adaround_params(adaround_params)[source]

Set Adaround parameters. If this method is not called explicitly by the user, AutoQuant will use data_loader (passed to __init__) for Adaround.

Parameters:

adaround_params (AdaroundParameters) – Adaround parameters.

Return type:

None

set_export_params(onnx_export_args=-1, propagate_encodings=None)[source]

Set parameters for QuantizationSimModel.export.

Parameters:
  • onnx_export_args (OnnxExportApiArgs) – Optional export argument with ONNX-specific overrides. If not provided, the model is exported via the TorchScript graph.

  • propagate_encodings (Optional[bool]) – If True, encoding entries for intermediate ops (when one PyTorch op results in multiple ONNX nodes) are filled with the same bitwidth and data type as the output tensor for that series of ops.

Return type:

None
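
A minimal usage sketch, assuming OnnxExportApiArgs is importable from aimet_torch.onnx_utils (the exact import path may vary across AIMET versions):

from aimet_torch.onnx_utils import OnnxExportApiArgs

# Export through ONNX with a pinned opset instead of the TorchScript graph
auto_quant.set_export_params(
    onnx_export_args=OnnxExportApiArgs(opset_version=11),
    propagate_encodings=False,
)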

set_mixed_precision_params(candidates, num_samples_for_phase_1=128, forward_fn=<function _default_forward_fn>, num_samples_for_phase_2=None)[source]

Set mixed precision parameters. NOTE: Automatic mixed precision will NOT be enabled unless this method is explicitly called by the user.

Parameters:
  • candidates (List[Tuple[Tuple[int, QuantizationDataType], Tuple[int, QuantizationDataType]]]) – List of tuples of candidate bitwidths and datatypes.

  • num_samples_for_phase_1 (Optional[int]) – Number of samples to be used for performance evaluation in AMP phase 1.

  • forward_fn (Callable) – Function that runs a forward pass and returns the output tensor, which will be used for SQNR computation in phase 1. This function is expected to take (1) a model and (2) a single batch yielded from the data loader, and to return a single torch.Tensor object representing the output of the model. The default forward function is roughly equivalent to lambda model, batch: model(batch). (See the sketch following this method's description.)

  • num_samples_for_phase_2 (Optional[int]) – Number of samples to be used for performance evaluation in AMP phase 2.

Return type:

None
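
If the data loader yields something other than a bare input tensor, pass a custom forward_fn. A sketch, assuming the loader yields (image, label) pairs and reusing the W8A16/W8A8 candidates from Step 6:

def forward_fn(model, batch):
    images, _ = batch  # assumption: the data loader yields (image, label) pairs
    return model(images)

auto_quant.set_mixed_precision_params(
    candidates=[W8A16, W8A8], forward_fn=forward_fn
)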

set_model_preparer_params(modules_to_exclude=None, concrete_args=None)[source]

Set parameters for model preparer.

Parameters:
  • modules_to_exclude (Optional[List[Module]]) – List of modules to exclude when tracing.

  • concrete_args (Optional[Dict[str, Any]]) – Parameter for the model preparer. Allows you to partially specialize your function, for example to remove control flow or data structures. If the model has control flow, torch.fx won't be able to trace it. See the torch.fx.symbolic_trace API for details, and the sketch below.
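
A usage sketch; the module and argument names below are hypothetical and depend on your model's forward() signature:

auto_quant.set_model_preparer_params(
    # Hypothetical: exclude a submodule that torch.fx cannot trace
    modules_to_exclude=[model.custom_block],
    # Hypothetical: pin a control-flow argument of forward() to a constant
    concrete_args={'use_aux_head': False},
)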

get_quant_scheme_candidates()[source]

Return the candidates for quant scheme search. During optimize(), the candidate with the highest accuracy will be selected among them.

Return type:

Tuple[_QuantSchemePair, ...]

Returns:

Candidates for quant scheme search

set_quant_scheme_candidates(candidates)[source]

Set candidates for quant scheme search. During optimize(), the candidate with the highest accuracy will be selected among them.

Parameters:

candidates (Tuple[_QuantSchemePair, ...]) – Candidates for quant scheme search
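
Because _QuantSchemePair is internal, the practical pattern is to filter the default candidates rather than construct new ones. A sketch:

candidates = auto_quant.get_quant_scheme_candidates()
# For example, keep only the first candidate to shorten the search
auto_quant.set_quant_scheme_candidates(candidates[:1])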

Top-level API (TensorFlow)

class aimet_tensorflow.keras.auto_quant_v2.AutoQuantWithAutoMixedPrecision(model, eval_callback, dataset, param_bw=8, output_bw=8, quant_scheme=QuantScheme.post_training_tf_enhanced, rounding_mode='nearest', config_file=None, results_dir='/tmp', cache_id=None, strict_validation=True)[source]

Integrate and apply post-training quantization techniques.

AutoQuant includes 1) batchnorm folding, 2) cross-layer equalization, 3) Adaround, and 4) Automatic Mixed Precision (if enabled). These techniques will be applied in a best-effort manner until the model meets the evaluation goal given as allowed_accuracy_drop.

Parameters:
  • model (Model) – Model to be quantized. Assumes model is on the correct device

  • eval_callback (Callable[[Model], float]) – A function that maps a model and the number of samples to an evaluation score. This callback is expected to return a scalar value representing the model performance evaluated against exactly N samples, where N is the number of samples passed as the second argument of this callback. NOTE: If N is None, the model is expected to be evaluated against the whole evaluation dataset. (See the sketch after this parameter list.)

  • dataset (DatasetV2) – An unlabeled dataset for encoding computation. By default, this dataset will also be used for Adaround unless otherwise specified via self.set_adaround_params.

  • param_bw (int) – Parameter bitwidth

  • output_bw (int) – Output bitwidth

  • quant_scheme (QuantScheme) – Quantization scheme

  • rounding_mode (str) – Rounding mode

  • config_file (Optional[str]) – Path to configuration file for model quantizers

  • results_dir (str) – Directory to save the results of PTQ techniques

  • cache_id (Optional[str]) – ID associated with cache results

  • strict_validation (bool) – Flag set to True by default. When False, AutoQuant will proceed with execution and handle errors internally if possible. This may produce suboptimal or unintuitive results.
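
The Step 3 TensorFlow example above ignores N. An eval_callback that honors the N-samples contract might look like this sketch, reusing eval_dataset and BATCH_SIZE from Step 2:

import math
from typing import Optional

def eval_callback(model: tf.keras.Model, num_samples: Optional[int] = None) -> float:
    data = eval_dataset
    if num_samples is not None:
        # Evaluate against roughly num_samples samples (batch granularity)
        data = eval_dataset.take(math.ceil(num_samples / BATCH_SIZE))
    model.compile(
        loss=losses.CategoricalCrossentropy(),
        metrics=[metrics.CategoricalAccuracy()],
    )
    _, acc = model.evaluate(data)
    return acc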

run_inference()[source]

Creates a quantization model and performs inference

Return type:

Tuple[QuantizationSimModel, float]

Returns:

QuantizationSimModel, model accuracy as float

optimize(allowed_accuracy_drop=0.0)[source]

Integrate and apply post-training quantization techniques.

Parameters:

allowed_accuracy_drop (float) – Maximum allowed accuracy drop

Return type:

Tuple[Model, float, str, List[Tuple[int, float, QuantizerGroup, Tuple]]]

Returns:

Tuple of (best model, eval score, encoding path, pareto front). Pareto front is None if AMP is not enabled or AutoQuant exits without performing AMP.

set_adaround_params(adaround_params)[source]

Set Adaround parameters. If this method is not called explicitly by the user, AutoQuant will use dataset (passed to __init__) for Adaround.

Parameters:

adaround_params (AdaroundParameters) – Adaround parameters.

set_mixed_precision_params(candidates, num_samples_for_phase_1=128, forward_fn=<function _default_forward_fn>, num_samples_for_phase_2=None)[source]

Set mixed precision parameters. NOTE: Automatic mixed precision will NOT be enabled unless this method is explicitly called by the user.

Parameters:
  • candidates (List[Tuple[Tuple[int, QuantizationDataType], Tuple[int, QuantizationDataType]]]) – List of tuples of candidate bitwidths and datatypes.

  • num_samples_for_phase_1 (Optional[int]) – Number of samples to be used for performance evaluation in AMP phase 1.

  • forward_fn (Callable) – Function that runs a forward pass and returns the output tensor, which will be used for SQNR computation in phase 1. This function is expected to take (1) a model and (2) a single batch yielded from the dataset, and to return a single tf.Tensor object representing the output of the model.

  • num_samples_for_phase_2 (Optional[int]) – Number of samples to be used for performance evaluation in AMP phase 2.

get_quant_scheme_candidates()[source]

Return the candidates for quant scheme search. During optimize(), the candidate with the highest accuracy will be selected among them.

Return type:

Tuple[_QuantSchemePair, ...]

Returns:

Candidates for quant scheme search

set_quant_scheme_candidates(candidates)[source]

Set candidates for quant scheme search. During optimize(), the candidate with the highest accuracy will be selected among them.

Parameters:

candidates (Tuple[_QuantSchemePair, ...]) – Candidates for quant scheme search

Top-level API (ONNX)

class aimet_onnx.auto_quant_v2.AutoQuantWithAutoMixedPrecision(model, dummy_input, data_loader, eval_callback, param_bw=8, output_bw=8, quant_scheme=QuantScheme.post_training_tf_enhanced, rounding_mode='nearest', use_cuda=True, device=0, config_file=None, results_dir='/tmp', cache_id=None, strict_validation=True)[source]

Integrate and apply post-training quantization techniques.

AutoQuant includes 1) batchnorm folding, 2) cross-layer equalization, 3) Adaround, and 4) Automatic Mixed Precision (if enabled). These techniques will be applied in a best-effort manner until the model meets the evaluation goal given as allowed_accuracy_drop.

Parameters:
  • model (ONNXModel) – Model to be quantized.

  • dummy_input (Dict[str, ndarray]) – Dummy input dict for the model.

  • data_loader (DataLoader) – A collection that iterates over an unlabeled dataset, used for computing encodings

  • eval_callback (Callable[[InferenceSession, int], float]) – Function that calculates the evaluation score given the model session

  • param_bw (int) – Parameter bitwidth

  • output_bw (int) – Output bitwidth

  • quant_scheme (QuantScheme) – Quantization scheme

  • rounding_mode (str) – Rounding mode

  • use_cuda (bool) – True if using CUDA to run quantization op. False otherwise.

  • config_file (Optional[str]) – Path to configuration file for model quantizers

  • results_dir (str) – Directory to save the results of PTQ techniques

  • cache_id (Optional[str]) – ID associated with cache results

  • strict_validation (bool) – Flag set to True by default. When False, AutoQuant will proceed with execution and handle errors internally if possible. This may produce suboptimal or unintuitive results.

run_inference()[source]

Creates a quantization model and performs inference

Return type:

Tuple[QuantizationSimModel, float]

Returns:

QuantizationSimModel, model accuracy as float

optimize(allowed_accuracy_drop=0.0)[source]

Integrate and apply post-training quantization techniques.

Parameters:

allowed_accuracy_drop (float) – Maximum allowed accuracy drop

Return type:

Tuple[ONNXModel, float, str, List[Tuple[int, float, QuantizerGroup, Tuple]]]

Returns:

Tuple of (best model, eval score, encoding path, pareto front). Pareto front is None if AMP is not enabled or AutoQuant exits without performing AMP.

set_adaround_params(adaround_params)[source]

Set Adaround parameters. If this method is not called explicitly by the user, AutoQuant will use data_loader (passed to __init__) for Adaround.

Parameters:

adaround_params (AdaroundParameters) – Adaround parameters.

Return type:

None

set_mixed_precision_params(candidates, num_samples_for_phase_1=128, forward_fn=<function _default_forward_fn>, num_samples_for_phase_2=None)[source]

Set mixed precision parameters. NOTE: Automatic mixed precision will NOT be enabled unless this method is explicitly called by the user.

Parameters:
  • candidates (List[Tuple[Tuple[int, QuantizationDataType], Tuple[int, QuantizationDataType]]]) – List of tuples of candidate bitwidths and datatypes.

  • num_samples_for_phase_1 (Optional[int]) – Number of samples to be used for performance evaluation in AMP phase 1.

  • forward_fn (Callable) – Function that runs a forward pass and returns the output tensor, which will be used for SQNR computation in phase 1. This function is expected to take (1) a model and (2) a single batch yielded from the data loader, and to return a single np.ndarray object representing the output of the model.

  • num_samples_for_phase_2 (Optional[int]) – Number of samples to be used for performance evaluation in AMP phase 2.

Return type:

None

get_quant_scheme_candidates()[source]

Return the candidates for quant scheme search. During optimize(), the candidate with the highest accuracy will be selected among them.

Return type:

Tuple[_QuantSchemePair, ...]

Returns:

Candidates for quant scheme search

set_quant_scheme_candidates(candidates)[source]

Set candidates for quant scheme search. During optimize(), the candidate with the highest accuracy will be selected among them.

Parameters:

candidates (Tuple[_QuantSchemePair, ...]) – Candidates for quant scheme search