AIMET ONNX AutoQuant API

Top-level API

Code Examples

import math
import onnxruntime as ort
import numpy as np

from aimet_onnx.auto_quant_v2 import AutoQuant
from aimet_onnx.adaround.adaround_weight import AdaroundParameters

# Step 1. Define constants
EVAL_DATASET_SIZE = 5000
CALIBRATION_DATASET_SIZE = 500
BATCH_SIZE = 32

# Step 2. Prepare model and dataloader
onnx_model = Model()

input_shape = (1, 3, 224, 224)
dummy_data = np.random.randn(*input_shape).astype(np.float32)
dummy_input = {'input': dummy_data}

# NOTE: Use your dataloader. It should iterate over unlabelled dataset.
#       Its data will be directly fed as input to the onnx model's inference session.
unlabelled_data_loader = DataLoader(data=data, batch_size=BATCH_SIZE,
                                    iterations=math.ceil(CALIBRATION_DATASET_SIZE / BATCH_SIZE))

# Step 3. Prepare eval callback
# NOTE: In the actual use cases, the users should implement this part to serve
#       their own goals, maintaining the function signature.
def eval_callback(session: ort.InferenceSession, num_of_samples: Optional[int] = None) -> float:
    data_loader = EvalDataLoader()
    if num_of_samples:
        iterations = math.ceil(num_of_samples / data_loader.batch_size)
    else:
        iterations = len(data_loader)
    batch_cntr = 1
    acc_top1 = 0
    acc_top5 = 0
    for input_data, target in data_loader:
        pred = session.run(None, {'input': input_data})

        batch_avg_top_1_5 = accuracy(pred, target, topk=(1, 5))

        acc_top1 += batch_avg_top_1_5[0].item()
        acc_top5 += batch_avg_top_1_5[1].item()

        batch_cntr += 1
        if batch_cntr > iterations:
            break
    acc_top1 /= iterations
    acc_top5 /= iterations
    return acc_top1

# Step 4. Create AutoQuant object
auto_quant = AutoQuant(onnx_model,
                       dummy_input,
                       unlabelled_data_loader,
                       eval_callback)

# Step 5. (Optional) Set AdaRound params
ADAROUND_DATASET_SIZE = 2000
adaround_data_loader = DataLoader(data=data, batch_size=BATCH_SIZE,
                                  iterations=math.ceil(ADAROUND_DATASET_SIZE / BATCH_SIZE))
adaround_params = AdaroundParameters(adaround_data_loader, num_batches=len(adaround_data_loader))
auto_quant.set_adaround_params(adaround_params)

# Step 6. Run AutoQuant
sim, initial_accuracy = auto_quant.run_inference()
model, optimized_accuracy, encoding_path = auto_quant.optimize(allowed_accuracy_drop=0.01)

print(f"- Quantized Accuracy (before optimization): {initial_accuracy:.4f}")
print(f"- Quantized Accuracy (after optimization):  {optimized_accuracy:.4f}")