AIMET ONNX AutoQuant API
User Guide Link
To learn more about this technique, see the AutoQuant user guide.
Top-level API
- class aimet_onnx.auto_quant_v2.AutoQuant(model, dummy_input, data_loader, eval_callback, param_bw=8, output_bw=8, quant_scheme=QuantScheme.post_training_tf_enhanced, rounding_mode='nearest', use_cuda=True, device=0, config_file=None, results_dir='/tmp', cache_id=None, strict_validation=True)
 Integrate and apply post-training quantization techniques.
AutoQuant includes 1) batchnorm folding, 2) cross-layer equalization, and 3) Adaround. These techniques will be applied in a best-effort manner until the model meets the evaluation goal given as allowed_accuracy_drop.
- Parameters:
  - model (Union[ModelProto, ONNXModel]) – Model to be quantized.
  - dummy_input (Dict[str, ndarray]) – Dummy input dict for the model.
  - data_loader (Iterable[Union[ndarray, List[ndarray], Tuple[ndarray]]]) – A collection that iterates over an unlabeled dataset, used for computing encodings.
  - eval_callback (Callable[[InferenceSession, int], float]) – Function that calculates the evaluation score given the model session.
  - param_bw (int) – Parameter bitwidth.
  - output_bw (int) – Output bitwidth.
  - quant_scheme (QuantScheme) – Quantization scheme.
  - rounding_mode (str) – Rounding mode.
  - use_cuda (bool) – True if using CUDA to run quantization ops; False otherwise.
  - config_file (Optional[str]) – Path to configuration file for model quantizers.
  - results_dir (str) – Directory to save the results of PTQ techniques.
  - cache_id (Optional[str]) – ID associated with cache results.
  - strict_validation (bool) – True by default. When False, AutoQuant proceeds with execution and handles errors internally where possible, which may produce suboptimal or unintuitive results.
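For reference, a constructor call exercising the optional arguments might look like the following (a minimal sketch: the model, data loader, and eval callback are assumed to be prepared as in the code example below, and the config file path is hypothetical):
auto_quant = AutoQuant(onnx_model,
                       dummy_input,
                       unlabelled_data_loader,
                       eval_callback,
                       param_bw=8,
                       output_bw=16,
                       config_file='path/to/quantsim_config.json',  # hypothetical path
                       results_dir='./autoquant_results',
                       cache_id='my_model_v1',
                       strict_validation=True)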
- run_inference()
  Creates a quantization model and performs inference.
  - Return type: Tuple[QuantizationSimModel, float]
  - Returns: QuantizationSimModel and model accuracy as float
- optimize(allowed_accuracy_drop=0.0)
  Integrate and apply post-training quantization techniques.
  - Parameters: allowed_accuracy_drop (float) – Maximum allowed accuracy drop
  - Return type: Tuple[ONNXModel, float, str]
  - Returns: Tuple of (best model, eval score, encoding path)
- set_adaround_params(adaround_params)
  Set Adaround parameters. If this method is not called explicitly by the user, AutoQuant will use data_loader (passed to __init__) for Adaround.
  - Parameters: adaround_params (AdaroundParameters) – Adaround parameters.
  - Return type: None
- get_quant_scheme_candidates()
  Return the candidates for quant scheme search. During optimize(), the candidate with the highest accuracy will be selected among them.
  - Return type: Tuple[_QuantSchemePair, ...]
  - Returns: Candidates for quant scheme search
- set_quant_scheme_candidates(candidates)
  Set candidates for quant scheme search. During optimize(), the candidate with the highest accuracy will be selected among them.
  - Parameters: candidates (Tuple[_QuantSchemePair, ...]) – Candidates for quant scheme search
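By default, optimize() searches over a built-in set of quant scheme pairs; the getter/setter above let you narrow that search space before calling optimize(). A minimal sketch, assuming each _QuantSchemePair exposes a param_quant_scheme field (an illustration, not a documented attribute):
from aimet_common.defs import QuantScheme

# Restrict the search to candidates whose parameter quantization uses
# the percentile scheme, then let optimize() pick the best of them.
candidates = auto_quant.get_quant_scheme_candidates()
percentile_only = tuple(c for c in candidates
                        if c.param_quant_scheme == QuantScheme.post_training_percentile)
auto_quant.set_quant_scheme_candidates(percentile_only)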
Note: It is recommended to run onnx-simplifier on the model before applying AutoQuant.
Code Examples
import math
from typing import Optional

import numpy as np
import onnxruntime as ort
from onnxsim import simplify

from aimet_onnx.adaround.adaround_weight import AdaroundParameters
from aimet_onnx.auto_quant_v2 import AutoQuant
# Step 1. Define constants
EVAL_DATASET_SIZE = 5000
CALIBRATION_DATASET_SIZE = 500
BATCH_SIZE = 32
# Step 2. Prepare model and dataloader
# NOTE: `Model()` is a placeholder for your own model-loading code
#       (e.g. onnx.load(...)); it should produce an ONNX ModelProto.
onnx_model = Model()
# Simplify the model
onnx_model, _ = simplify(onnx_model)
input_shape = (1, 3, 224, 224)
dummy_data = np.random.randn(*input_shape).astype(np.float32)
dummy_input = {'input': dummy_data}
# NOTE: Use your dataloader. It should iterate over unlabelled dataset.
#       Its data will be directly fed as input to the onnx model's inference session.
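# NOTE: `DataLoader` below is a hypothetical batching utility, not part of
#       aimet_onnx. A minimal sketch that yields `iterations` batches of
#       input arrays from an in-memory `data` array might look like:
class DataLoader:
    def __init__(self, data, batch_size, iterations):
        self.data = data
        self.batch_size = batch_size
        self.iterations = iterations

    def __iter__(self):
        for i in range(self.iterations):
            yield self.data[i * self.batch_size:(i + 1) * self.batch_size]

    def __len__(self):
        return self.iterations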
unlabelled_data_loader = DataLoader(data=data, batch_size=BATCH_SIZE,
                                    iterations=math.ceil(CALIBRATION_DATASET_SIZE / BATCH_SIZE))
# Step 3. Prepare eval callback
# NOTE: In actual use cases, implement this part to serve your own
#       evaluation goal, maintaining the function signature.
def eval_callback(session: ort.InferenceSession, num_of_samples: Optional[int] = None) -> float:
    # NOTE: `EvalDataLoader()` is a placeholder for your own labelled
    #       evaluation-set loader, yielding (input_data, target) pairs.
    data_loader = EvalDataLoader()
    if num_of_samples:
        iterations = math.ceil(num_of_samples / data_loader.batch_size)
    else:
        iterations = len(data_loader)
    batch_cntr = 1
    acc_top1 = 0
    acc_top5 = 0
    for input_data, target in data_loader:
        pred = session.run(None, {'input': input_data})
        batch_avg_top_1_5 = accuracy(pred, target, topk=(1, 5))
        acc_top1 += batch_avg_top_1_5[0].item()
        acc_top5 += batch_avg_top_1_5[1].item()
        batch_cntr += 1
        if batch_cntr > iterations:
            break
    acc_top1 /= iterations
    acc_top5 /= iterations
    return acc_top1
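# NOTE: `accuracy` used above is a user-supplied helper, not part of
#       aimet_onnx. A minimal numpy sketch for batch top-k accuracy
#       (assuming `pred` is the list returned by session.run and `target`
#       holds integer class labels) might look like:
def accuracy(pred, target, topk=(1,)):
    logits = pred[0]                              # first model output
    labels = np.asarray(target).reshape(-1, 1)
    ranked = np.argsort(logits, axis=1)[:, ::-1]  # classes sorted by score, descending
    return [np.mean(np.any(ranked[:, :k] == labels, axis=1)) for k in topk]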
# Step 4. Create AutoQuant object
auto_quant = AutoQuant(onnx_model,
                       dummy_input,
                       unlabelled_data_loader,
                       eval_callback)
# Step 5. (Optional) Set AdaRound params
ADAROUND_DATASET_SIZE = 2000
adaround_data_loader = DataLoader(data=data, batch_size=BATCH_SIZE,
                                  iterations=math.ceil(ADAROUND_DATASET_SIZE / BATCH_SIZE))
adaround_params = AdaroundParameters(adaround_data_loader, num_batches=len(adaround_data_loader))
auto_quant.set_adaround_params(adaround_params)
# Step 6. Run AutoQuant
sim, initial_accuracy = auto_quant.run_inference()
model, optimized_accuracy, encoding_path = auto_quant.optimize(allowed_accuracy_drop=0.01)
print(f"- Quantized Accuracy (before optimization): {initial_accuracy:.4f}")
print(f"- Quantized Accuracy (after optimization):  {optimized_accuracy:.4f}")