AIMET ONNX AutoQuant API

Top-level API

class aimet_onnx.auto_quant_v2.AutoQuant(model, dummy_input, data_loader, eval_callback, param_bw=8, output_bw=8, quant_scheme=QuantScheme.post_training_tf_enhanced, rounding_mode='nearest', use_cuda=True, device=0, config_file=None, results_dir='/tmp', cache_id=None, strict_validation=True)

Integrate and apply post-training quantization techniques.

AutoQuant includes 1) batchnorm folding, 2) cross-layer equalization, and 3) Adaround. These techniques are applied on a best-effort basis until the model meets the evaluation goal, which is given as the allowed_accuracy_drop argument of optimize().

Parameters:
  • model (Union[ModelProto, ONNXModel]) – Model to be quantized.

  • dummy_input (Dict[str, ndarray]) – Dummy input dict for the model.

  • data_loader (Iterable[Union[ndarray, List[ndarray], Tuple[ndarray]]]) – A collection that iterates over an unlabeled dataset, used for computing encodings.

  • eval_callback (Callable[[InferenceSession, int], float]) – Function that calculates the evaluation score, given the model's inference session and the number of samples to evaluate.

  • param_bw (int) – Parameter bitwidth

  • output_bw (int) – Output bitwidth

  • quant_scheme (QuantScheme) – Quantization scheme

  • rounding_mode (str) – Rounding mode

  • use_cuda (bool) – True to run the quantization ops on CUDA, False otherwise.

  • config_file (Optional[str]) – Path to configuration file for model quantizers

  • results_dir (str) – Directory to save the results of PTQ techniques

  • cache_id (Optional[str]) – ID associated with cache results

  • strict_validation (bool) – Flag set to True by default. When False, AutoQuant will proceed with execution and handle errors internally if possible. This may produce suboptimal or unintuitive results.
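To make the eval_callback contract concrete, here is a minimal sketch of a callable with the documented shape: it takes a session and a sample count and returns a float. FakeSession, TOY_DATA, and the input name "input" are hypothetical stand-ins so that the snippet runs without onnxruntime; in practice the first argument is the onnxruntime.InferenceSession that AutoQuant passes in.

```python
from typing import Dict, List, Optional, Sequence, Tuple

# Hypothetical labelled samples: (feature vector, class index).
TOY_DATA: List[Tuple[List[float], int]] = [
    ([0.9, 0.1], 0),
    ([0.2, 0.8], 1),
    ([0.6, 0.4], 1),  # the stand-in model below gets this one wrong
    ([0.3, 0.7], 1),
]


class FakeSession:
    """Hypothetical stand-in with the same run(output_names, input_feed)
    call shape as onnxruntime.InferenceSession; it echoes its input as logits."""

    def run(self, output_names, input_feed: Dict[str, Sequence[Sequence[float]]]):
        return [list(input_feed["input"])]


def eval_callback(session, num_samples: Optional[int] = None) -> float:
    """Return top-1 accuracy over the first num_samples samples (all if None)."""
    samples = TOY_DATA if num_samples is None else TOY_DATA[:num_samples]
    correct = 0
    for features, label in samples:
        # Feed a batch of one sample and take the first output's first row.
        logits = session.run(None, {"input": [features]})[0][0]
        predicted = max(range(len(logits)), key=logits.__getitem__)
        correct += int(predicted == label)
    return correct / len(samples) if samples else 0.0


print(eval_callback(FakeSession()))     # accuracy over all four samples
print(eval_callback(FakeSession(), 2))  # accuracy over the first two samples
```

A higher return value should mean a better model, since optimize() compares scores against allowed_accuracy_drop.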

run_inference()

Creates a quantization simulation model (QuantizationSimModel) and runs inference to measure its accuracy.

Return type:

Tuple[QuantizationSimModel, float]

Returns:

Tuple of (QuantizationSimModel, model accuracy as a float)

optimize(allowed_accuracy_drop=0.0)

Integrate and apply post-training quantization techniques.

Parameters:

allowed_accuracy_drop (float) – Maximum allowed accuracy drop

Return type:

Tuple[ONNXModel, float, str]

Returns:

Tuple of (best model, eval score, encoding path)
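The best-effort behaviour described for AutoQuant can be pictured as a loop that tries each technique in turn and stops as soon as the score is within allowed_accuracy_drop of a reference score. The sketch below is illustrative only and is not AIMET's implementation: best_effort_optimize, the ToyModel type, the technique callables, and the toy scores are all hypothetical.

```python
from typing import Callable, Dict, List, Tuple

ToyModel = Dict[str, float]  # hypothetical stand-in for a real model object


def best_effort_optimize(
    model: ToyModel,
    techniques: List[Tuple[str, Callable[[ToyModel], ToyModel]]],
    evaluate: Callable[[ToyModel], float],
    fp32_score: float,
    allowed_accuracy_drop: float,
) -> Tuple[ToyModel, float, List[str]]:
    """Try techniques in order, keep only those that improve the score,
    and stop early once the score reaches fp32_score - allowed_accuracy_drop."""
    target = fp32_score - allowed_accuracy_drop
    best_model, best_score = model, evaluate(model)
    applied: List[str] = []
    for name, technique in techniques:
        if best_score >= target:
            break  # goal met, no need to try further techniques
        candidate = technique(best_model)
        score = evaluate(candidate)
        if score > best_score:
            best_model, best_score = candidate, score
            applied.append(name)
    return best_model, best_score, applied


# Toy techniques that shift a fake score up or down by fixed amounts.
toy_techniques = [
    ("bn_fold", lambda m: {"score": m["score"] + 0.125}),
    ("cle", lambda m: {"score": m["score"] - 0.0625}),
    ("adaround", lambda m: {"score": m["score"] + 0.25}),
    ("extra", lambda m: {"score": m["score"] + 0.0625}),
]
best, score, applied = best_effort_optimize(
    {"score": 0.5}, toy_techniques, lambda m: m["score"],
    fp32_score=0.9, allowed_accuracy_drop=0.05,
)
print(applied, score)
```

In this toy run the "cle" step is discarded because it lowers the score, and "extra" is never tried because the target is already met after "adaround".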

set_adaround_params(adaround_params)

Set Adaround parameters. If this method is not called explicitly by the user, AutoQuant will use data_loader (passed to __init__) for Adaround.

Parameters:

adaround_params (AdaroundParameters) – Adaround parameters.

Return type:

None

get_quant_scheme_candidates()

Return the candidates for quant scheme search. During optimize(), the candidate with the highest accuracy will be selected among them.

Return type:

Tuple[_QuantSchemePair, ...]

Returns:

Candidates for quant scheme search

set_quant_scheme_candidates(candidates)

Set candidates for quant scheme search. During optimize(), the candidate with the highest accuracy will be selected among them.

Parameters:

candidates (Tuple[_QuantSchemePair, ...]) – Candidates for quant scheme search
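Selecting the candidate with the highest accuracy amounts to an argmax over evaluation scores. A minimal sketch follows, using hypothetical candidates written as (param scheme name, output scheme name) tuples and a hypothetical scoring table; the actual fields of _QuantSchemePair and the way AutoQuant scores each candidate are not specified here.

```python
from typing import Callable, Dict, Sequence, Tuple

Candidate = Tuple[str, str]  # hypothetical: (param scheme, output scheme)


def select_best_candidate(
    candidates: Sequence[Candidate],
    score_fn: Callable[[Candidate], float],
) -> Tuple[Candidate, float]:
    """Score every candidate and return the one with the highest score.
    Ties are resolved in favour of the earliest candidate."""
    if not candidates:
        raise ValueError("at least one candidate is required")
    scored = [(score_fn(c), c) for c in candidates]
    best_score, best = max(scored, key=lambda pair: pair[0])
    return best, best_score


# Toy scores standing in for real evaluation results.
toy_scores: Dict[Candidate, float] = {
    ("tf", "tf"): 0.70,
    ("tf_enhanced", "tf"): 0.74,
    ("tf_enhanced", "tf_enhanced"): 0.72,
}
winner, winner_score = select_best_candidate(list(toy_scores), toy_scores.__getitem__)
print(winner, winner_score)
```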

Code Examples

import math
from typing import Optional

import numpy as np
import onnxruntime as ort

from aimet_onnx.auto_quant_v2 import AutoQuant
from aimet_onnx.adaround.adaround_weight import AdaroundParameters

# Step 1. Define constants
EVAL_DATASET_SIZE = 5000
CALIBRATION_DATASET_SIZE = 500
BATCH_SIZE = 32

# Step 2. Prepare model and dataloader
# NOTE: Model() is a placeholder; substitute your own ONNX model (ModelProto or ONNXModel).
onnx_model = Model()

input_shape = (1, 3, 224, 224)
dummy_data = np.random.randn(*input_shape).astype(np.float32)
dummy_input = {'input': dummy_data}

# NOTE: Use your own dataloader (DataLoader and data are placeholders here).
#       It should iterate over an unlabelled dataset; its data will be fed
#       directly as input to the ONNX model's inference session.
unlabelled_data_loader = DataLoader(data=data, batch_size=BATCH_SIZE,
                                    iterations=math.ceil(CALIBRATION_DATASET_SIZE / BATCH_SIZE))

# Step 3. Prepare eval callback
# NOTE: In the actual use cases, the users should implement this part to serve
#       their own goals, maintaining the function signature.
def eval_callback(session: ort.InferenceSession, num_of_samples: Optional[int] = None) -> float:
    data_loader = EvalDataLoader()
    if num_of_samples:
        iterations = math.ceil(num_of_samples / data_loader.batch_size)
    else:
        iterations = len(data_loader)
    batch_cntr = 1
    acc_top1 = 0
    acc_top5 = 0
    for input_data, target in data_loader:
        pred = session.run(None, {'input': input_data})

        batch_avg_top_1_5 = accuracy(pred, target, topk=(1, 5))

        acc_top1 += batch_avg_top_1_5[0].item()
        acc_top5 += batch_avg_top_1_5[1].item()

        batch_cntr += 1
        if batch_cntr > iterations:
            break
    acc_top1 /= iterations
    acc_top5 /= iterations
    return acc_top1

# Step 4. Create AutoQuant object
auto_quant = AutoQuant(onnx_model,
                       dummy_input,
                       unlabelled_data_loader,
                       eval_callback)

# Step 5. (Optional) Set AdaRound params
ADAROUND_DATASET_SIZE = 2000
adaround_data_loader = DataLoader(data=data, batch_size=BATCH_SIZE,
                                  iterations=math.ceil(ADAROUND_DATASET_SIZE / BATCH_SIZE))
adaround_params = AdaroundParameters(adaround_data_loader, num_batches=len(adaround_data_loader))
auto_quant.set_adaround_params(adaround_params)

# Step 6. Run AutoQuant
sim, initial_accuracy = auto_quant.run_inference()
model, optimized_accuracy, encoding_path = auto_quant.optimize(allowed_accuracy_drop=0.01)

print(f"- Quantized Accuracy (before optimization): {initial_accuracy:.4f}")
print(f"- Quantized Accuracy (after optimization):  {optimized_accuracy:.4f}")
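The example above leaves DataLoader undefined. One possible stand-in is sketched below, assuming data is any sliceable sequence (for instance a NumPy array whose first axis indexes samples). The class and its data, batch_size, and iterations arguments simply mirror the call in Step 2; they are hypothetical and not part of AIMET.

```python
import math
from typing import Iterator, Optional, Sequence


class DataLoader:
    """Yield consecutive batches sliced from data, at most `iterations` of them."""

    def __init__(self, data: Sequence, batch_size: int, iterations: Optional[int] = None):
        if batch_size <= 0:
            raise ValueError("batch_size must be positive")
        self.data = data
        self.batch_size = batch_size
        # Never promise more batches than the data can provide.
        available = math.ceil(len(data) / batch_size)
        self.iterations = available if iterations is None else min(iterations, available)

    def __len__(self) -> int:
        return self.iterations

    def __iter__(self) -> Iterator[Sequence]:
        for i in range(self.iterations):
            start = i * self.batch_size
            yield self.data[start:start + self.batch_size]


loader = DataLoader(data=list(range(10)), batch_size=4, iterations=2)
print(len(loader))   # number of batches the loader will yield
print(list(loader))  # the batches themselves
```

Because the class defines __len__, it also works with the len(adaround_data_loader) call in Step 5.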