AIMET ONNX Quantization SIM API
Top-level API
Note about quantization schemes: Because ONNX Runtime is used only for optimized inference, the ONNX variant of AIMET supports only post-training quantization schemes, i.e. TF or TF-enhanced, to compute the encodings.
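To make the notion of an "encoding" concrete, the sketch below derives an asymmetric 8-bit encoding from the observed minimum and maximum of a tensor, which is the general idea behind a min/max-based scheme such as TF. It is a simplified illustration under that assumption, not AIMET's implementation; the helper names compute_minmax_encoding and quantize_dequantize are hypothetical.

```python
import numpy as np


def compute_minmax_encoding(tensor, bitwidth=8):
    """Derive an asymmetric encoding (scale, offset) from the observed range."""
    num_steps = 2 ** bitwidth - 1
    # Extend the range to include zero so that 0.0 lies inside the grid.
    t_min = min(float(tensor.min()), 0.0)
    t_max = max(float(tensor.max()), 0.0)
    # Guard against a degenerate all-zero tensor.
    if t_max == t_min:
        t_max = t_min + 1e-5
    scale = (t_max - t_min) / num_steps
    # Integer grid index that corresponds to t_min (always <= 0).
    offset = round(t_min / scale)
    return scale, offset


def quantize_dequantize(tensor, scale, offset, bitwidth=8):
    """Simulate quantization noise: snap to the integer grid, then map back to floats."""
    num_steps = 2 ** bitwidth - 1
    q = np.clip(np.round(tensor / scale) - offset, 0, num_steps)
    return (q + offset) * scale


x = np.linspace(-1.0, 3.0, 9, dtype=np.float32)
scale, offset = compute_minmax_encoding(x)
x_hat = quantize_dequantize(x, scale, offset)
max_error = float(np.max(np.abs(x - x_hat)))
print(scale, offset, max_error)
```

For values inside the calibrated range, the round-trip error is bounded by half a quantization step (scale / 2). A scheme that searches for a better range than the raw min/max, as TF-enhanced is described as doing, trades some clipping error for a smaller step size.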
The following API can be used to compute encodings for the model
The following API can be used to export the model to the target
Code Examples
Required imports
import numpy as np
from aimet_onnx.quantsim import QuantizationSimModel
from aimet_common.defs import QuantScheme
The user should write this function to pass calibration data to the model
def pass_calibration_data(session):
    """
    The user of the QuantizationSimModel API is expected to write this function based on their dataset.
    This is not a working function and is provided only as a guideline.

    :param session: Model's session
    :return: None
    """
    # User action required
    # Replace the following line with your own dataset's validation data loader
    # (for example, an ImageNet validation data loader).
    data_loader = None  # Your dataset's data loader

    # User action required
    # Computing the activation encodings requires around 1000 unlabelled data samples.
    # Edit the following 2 lines based on your data loader's batch size so that
    # batch_size * max_batch_counter is about 1024.
    batch_size = 64
    max_batch_counter = 16

    # Name of the model's input tensor (assumes a single-input model)
    input_name = session.get_inputs()[0].name

    current_batch_counter = 0
    for input_data, _ in data_loader:
        # session.run expects a dict that maps input names to NumPy arrays
        session.run(None, {input_name: input_data})

        current_batch_counter += 1
        if current_batch_counter == max_batch_counter:
            break
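The guideline above leaves the data loader unspecified. As a rough, self-contained illustration of the same loop, the sketch below feeds randomly generated batches to a stub session and stops after a fixed number of batches. StubSession, random_batches, the input name "input", and the (session, args) callback signature are all assumptions made for this sketch; a real run would use the simulation model's session and a real dataset.

```python
import numpy as np


class StubSession:
    """Hypothetical stand-in for an inference session; counts the samples it receives."""

    def __init__(self):
        self.num_samples = 0

    def run(self, output_names, input_feed):
        # input_feed maps input names to batched arrays, as in the loop above.
        for batch in input_feed.values():
            self.num_samples += batch.shape[0]
        return []


def random_batches(num_batches, batch_size):
    """Yield (input, label) pairs of random data; labels are unused for calibration."""
    rng = np.random.default_rng(0)
    for _ in range(num_batches):
        data = rng.standard_normal((batch_size, 3, 8, 8)).astype(np.float32)
        yield data, None


def pass_calibration_data(session, args):
    """Run at most max_batches batches through the session to collect statistics."""
    data_loader, input_name, max_batches = args
    for batch_index, (input_data, _) in enumerate(data_loader):
        if batch_index >= max_batches:
            break
        session.run(None, {input_name: input_data})


session = StubSession()
# 16 batches of 64 samples gives the 1024 calibration samples suggested above.
pass_calibration_data(session, (random_batches(20, 64), "input", 16))
print(session.num_samples)
```

Passing the loader, input name, and batch limit through a single args tuple keeps the callback free of global state, which makes it easier to reuse with different datasets.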
Quantize the model and evaluate
def quantize_model():
    # Model() is a placeholder for the user's ONNX model
    onnx_model = Model()
    input_shape = (1, 3, 224, 224)
    dummy_data = np.random.randn(*input_shape).astype(np.float32)
    dummy_input = {'input': dummy_data}
    sim = QuantizationSimModel(onnx_model, dummy_input, quant_scheme=QuantScheme.post_training_tf,
                               rounding_mode='nearest', default_param_bw=8, default_activation_bw=8,
                               use_symmetric_encodings=False, use_cuda=False)

    # Compute encodings using the calibration function defined above
    sim.compute_encodings(pass_calibration_data, None)

    # Evaluate the quant sim
    # forward_pass_function is a user-defined evaluation function
    forward_pass_function(sim.session)
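The evaluation function called above is left to the user. One possible shape for it, sketched here under stated assumptions, is a top-1 accuracy loop over labelled batches. The extra data_loader and input_name parameters, the StubSession class, and the toy data are hypothetical and exist only so the example runs without a real model; the stub simply returns its input as logits.

```python
import numpy as np


class StubSession:
    """Hypothetical stand-in for an inference session; echoes its input as logits."""

    def run(self, output_names, input_feed):
        (batch,) = input_feed.values()
        return [batch]


def forward_pass_function(session, data_loader, input_name):
    """Return top-1 accuracy of the session's first output over labelled batches."""
    correct = 0
    total = 0
    for input_data, labels in data_loader:
        logits = session.run(None, {input_name: input_data})[0]
        predictions = np.argmax(logits, axis=1)
        correct += int(np.sum(predictions == labels))
        total += len(labels)
    return correct / total if total else 0.0


# Toy data: two batches of two samples each, with two classes.
inputs = np.array([[1.0, 0.0], [0.0, 1.0]], dtype=np.float32)
batches = [
    (inputs, np.array([0, 1])),  # both predictions match the labels
    (inputs, np.array([1, 1])),  # only the second prediction matches
]
accuracy = forward_pass_function(StubSession(), batches, "input")
print(accuracy)
```

Running the same function on the original floating-point model and on the simulation session gives a direct estimate of the accuracy change caused by quantization.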