Layer output generation¶
Context¶
Layer output generation is an API that captures and saves intermediate layer outputs of your pre-trained model. The model can be the original (FP32) model or a QuantizationSimModel.
The layer outputs are named according to the model (PyTorch, ONNX, or TensorFlow) exported by the QuantSim export API, QuantizationSimModel.export().
This enables layer output comparison between quantization simulated (QuantSim) models and quantized models on target runtimes like Qualcomm® AI Engine Direct to debug accuracy mismatch issues at the layer level (per operation).
Workflow¶
The layer output generation framework follows the same workflow for all model frameworks:
Import the API
Load the model exported from AIMET
Obtain inputs
Generate layer outputs
Choose your framework below for code examples.
Step 1: Importing the API¶
Import the API.
# PyTorch
import torch
from aimet_torch.v1.quantsim import QuantizationSimModel, load_encodings_to_sim
from aimet_torch.layer_output_utils import LayerOutputUtil, NamingScheme
from aimet_torch.onnx_utils import OnnxExportApiArgs

# TensorFlow (Keras)
import tensorflow as tf
from aimet_tensorflow.keras.quantsim import QuantizationSimModel
from aimet_tensorflow.keras.layer_output_utils import LayerOutputUtil

# ONNX
import onnx
from onnxruntime import InferenceSession
from onnxsim import simplify
from aimet_onnx.quantsim import QuantizationSimModel, load_encodings_to_sim
from aimet_onnx.layer_output_utils import LayerOutputUtil
Step 2: Loading a model¶
Load the original or QuantSim model exported from AIMET.
# PyTorch
# Load the model on the CPU. Ensure the model definition is present in the PYTHONPATH so the model can be loaded successfully.
# If the model was exported on the CPU, load it this way.
model = torch.load('path/to/aimet_export_artifacts/model.pth')
# Or
# If the model was exported on a GPU, load it onto the CPU this way.
# model = torch.load('path/to/aimet_export_artifacts/model.pth', map_location=torch.device('cpu'))

dummy_input = torch.rand(1, 3, 224, 224)

# Use the same arguments that were used when the QuantSim model was exported. For simplicity, only mandatory arguments are passed below.
quantsim = QuantizationSimModel(model=model, dummy_input=dummy_input)

# Load the exported encodings into the quantsim object.
load_encodings_to_sim(quantsim, 'path/to/aimet_export_artifacts/model_torch.encodings')

# Check that the original and quantsim models run properly before using the layer output generation API.
_ = model(dummy_input)
_ = quantsim.model(dummy_input)
# TensorFlow (Keras)
# Load the model.
model = tf.keras.models.load_model('path/to/aimet_export_artifacts/model.h5')

# Use the same arguments that were used when the QuantSim model was exported. For simplicity, only mandatory arguments are passed below.
quantsim = QuantizationSimModel(model)

# Load the exported encodings into the quantsim object.
quantsim.load_encodings_to_sim('path/to/aimet_export_artifacts/model.encodings')

# Check that the original and quantsim models run properly before using the layer output generation API.
# dummy_input is a sample batch matching the model's input shape, e.g. np.random.rand(1, 224, 224, 3).
_ = model.predict(dummy_input)
_ = quantsim.predict(dummy_input)
# ONNX
# Load the model.
model = onnx.load('path/to/aimet_export_artifacts/model.onnx')

# Simplify the model.
model, _ = simplify(model)

# Use the same arguments that were used when the QuantSim model was exported. For simplicity, only mandatory arguments are passed below.
# dummy_input_dict maps each model input name to a sample array (see the sketch after this block).
quantsim = QuantizationSimModel(model=model, dummy_input=dummy_input_dict, use_cuda=False)

# Load the exported encodings into the quantsim object.
load_encodings_to_sim(quantsim, 'path/to/aimet_export_artifacts/model.encodings')

# Check that the original and quantsim models run properly before using the layer output generation API.
_ = InferenceSession(model.SerializeToString()).run(None, dummy_input_dict)
_ = quantsim.session.run(None, dummy_input_dict)
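The ONNX snippet above relies on dummy_input_dict, which is not defined in the example. The sketch below shows one hypothetical way to build it from the ONNX model's graph inputs, assuming all inputs are float32 and treating any dynamic dimensions as 1; adapt it to your model's actual input shapes and dtypes.

# Hypothetical helper: build dummy_input_dict from the ONNX graph inputs.
import numpy as np

dummy_input_dict = {}
for graph_input in model.graph.input:
    # Treat symbolic/dynamic dimensions (dim_value == 0) as 1.
    shape = [d.dim_value if d.dim_value > 0 else 1
             for d in graph_input.type.tensor_type.shape.dim]
    dummy_input_dict[graph_input.name] = np.random.rand(*shape).astype(np.float32)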
Step 3: Obtaining inputs¶
Obtain inputs from which to generate intermediate layer outputs.
# Use the same input pre-processing pipeline that was used when computing the quantization encodings.
# (Identical for the PyTorch, TensorFlow, and ONNX flows.)
input_batches = get_pre_processed_inputs()
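The helper get_pre_processed_inputs() is a placeholder for your own data pipeline and is not part of AIMET. Below is a minimal, hypothetical PyTorch sketch, assuming an ImageNet-style image folder and standard torchvision preprocessing; replace it with the exact pipeline that was used to compute the quantization encodings.

# Hypothetical implementation of get_pre_processed_inputs() for the PyTorch flow.
import torch
from torchvision import datasets, transforms

def get_pre_processed_inputs(data_dir='path/to/calibration_images', batch_size=4, num_batches=8):
    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    dataset = datasets.ImageFolder(data_dir, transform=preprocess)
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)
    # Collect a fixed number of batches; each batch is passed to generate_layer_outputs() in Step 4.
    return [images for (images, _), _ in zip(loader, range(num_batches))]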
Step 4: Generating layer outputs¶
Generate the specified layer outputs.
# PyTorch
# Use the original model to get fp32 layer outputs.
fp32_layer_output_util = LayerOutputUtil(model=model, dir_path='./fp32_layer_outputs', naming_scheme=NamingScheme.ONNX,
                                         dummy_input=dummy_input, onnx_export_args=OnnxExportApiArgs())

# Use the quantsim model to get quantsim layer outputs.
quantsim_layer_output_util = LayerOutputUtil(model=quantsim.model, dir_path='./quantsim_layer_outputs', naming_scheme=NamingScheme.ONNX,
                                             dummy_input=dummy_input, onnx_export_args=OnnxExportApiArgs())

for input_batch in input_batches:
    fp32_layer_output_util.generate_layer_outputs(input_batch)
    quantsim_layer_output_util.generate_layer_outputs(input_batch)
# TensorFlow (Keras)
# Use the original model to get fp32 layer outputs.
fp32_layer_output_util = LayerOutputUtil(model=model, save_dir='fp32_layer_outputs')

# Use the quantsim model to get quantsim layer outputs.
quantsim_layer_output_util = LayerOutputUtil(model=quantsim.model, save_dir='quantsim_layer_outputs')

for input_batch in input_batches:
    fp32_layer_output_util.generate_layer_outputs(input_batch=input_batch)
    quantsim_layer_output_util.generate_layer_outputs(input_batch=input_batch)
# ONNX
# Note: Generate layer outputs for the fp32 model before creating the quantsim model, because the fp32 model itself is modified to create its quantsim version.
# Use the original model to get fp32 layer outputs.
fp32_layer_output_util = LayerOutputUtil(model=model, dir_path='./fp32_layer_outputs')

# Use the quantsim model to get quantsim layer outputs.
quantsim_layer_output_util = LayerOutputUtil(model=quantsim.model.model, dir_path='./quantsim_layer_outputs')

for input_batch in input_batches:
    fp32_layer_output_util.generate_layer_outputs(input_batch)
    quantsim_layer_output_util.generate_layer_outputs(input_batch)
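Once both sets of layer outputs are on disk, they can be compared offline to find the first layer where quantization noise becomes significant. The sketch below is illustrative only: it assumes each layer output was saved as a .npy file with matching file names in the two directories, which may differ from the actual on-disk layout produced by LayerOutputUtil; adjust the loading logic accordingly.

# Hypothetical offline comparison of saved fp32 vs. quantsim layer outputs using SQNR.
import os
import numpy as np

def sqnr_db(reference, noisy):
    noise = reference - noisy
    return 10.0 * np.log10(np.sum(reference ** 2) / (np.sum(noise ** 2) + 1e-30))

fp32_dir = './fp32_layer_outputs'
quantsim_dir = './quantsim_layer_outputs'

for file_name in sorted(os.listdir(fp32_dir)):
    if not file_name.endswith('.npy'):
        continue
    fp32_out = np.load(os.path.join(fp32_dir, file_name))
    quantsim_out = np.load(os.path.join(quantsim_dir, file_name))
    # A low SQNR indicates a layer whose output is strongly affected by quantization.
    print(f'{file_name}: SQNR = {sqnr_db(fp32_out, quantsim_out):.2f} dB')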
API¶
- class aimet_torch.layer_output_utils.LayerOutputUtil(model, dir_path, naming_scheme=NamingScheme.PYTORCH, dummy_input=None, onnx_export_args=None)[source]¶
Implementation to capture and save outputs of intermediate layers of a model (fp32/quantsim).
Constructor for LayerOutputUtil.
- Parameters:
model (Module) – Model whose layer outputs are needed.
dir_path (str) – Directory wherein the layer outputs will be saved.
naming_scheme (NamingScheme) – Naming scheme to be followed to name the layer outputs. There are multiple schemes, corresponding to the exported model (pytorch, onnx, or torchscript). Refer to the NamingScheme enum definition.
dummy_input (Union[Tensor, Tuple, List, None]) – Dummy input to the model. Required if naming_scheme is NamingScheme.ONNX or NamingScheme.TORCHSCRIPT.
onnx_export_args (Union[OnnxExportApiArgs, Dict, None]) – Should be the same as that passed to the quantsim export API, so that the layer-output names are consistent with those in the exported onnx model. Required if naming_scheme is NamingScheme.ONNX.
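For reference, with the default NamingScheme.PYTORCH scheme neither dummy_input nor onnx_export_args is needed; a minimal instantiation might look like the following (the directory path is illustrative).

from aimet_torch.layer_output_utils import LayerOutputUtil, NamingScheme

# Default naming scheme: the torch model's layer names are used, so dummy_input and onnx_export_args can be omitted.
layer_output_util = LayerOutputUtil(model=quantsim.model, dir_path='./quantsim_layer_outputs',
                                    naming_scheme=NamingScheme.PYTORCH)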
The following API can be used to generate layer outputs:
- LayerOutputUtil.generate_layer_outputs(input_instance)[source]¶
This method captures the output of every layer of a model and saves the single input instance and the corresponding layer outputs to disk.
- Parameters:
input_instance (Union[Tensor, List[Tensor], Tuple[Tensor]]) – Single input instance for which we want to obtain layer outputs.
- Returns:
None
Naming Scheme Enum
- class aimet_torch.layer_output_utils.NamingScheme(value)[source]¶
Enumeration of layer-output naming schemes.
- ONNX = 2¶
Names outputs according to exported onnx model. Layer output names are generally numeric.
- PYTORCH = 1¶
Names outputs according to exported pytorch model. Layer names are used.
- TORCHSCRIPT = 3¶
Names outputs according to exported torchscript model. Layer output names are generally numeric.