Cross-layer equalization¶
Context¶
Quantizing a floating-point model to a lower bitwidth introduces quantization noise in the weights and activations, which often degrades model accuracy. To minimize quantization noise, AIMET recommends a quantization workflow that includes a variety of post-training quantization (PTQ) techniques. You can learn more about these techniques here.
AIMET includes a cross-layer equalization (CLE) tool that applies the following PTQ techniques:
- Batch Norm Folding
This feature folds batch norm layers into adjacent convolutional and linear layers. For more on BNF see Batch norm folding.
- Cross Layer Scaling
In some models, the parameter ranges for different channels in a layer show a wide variance. See the first chart in the following figure.
Cross-layer scaling attempts to equalize the distribution of weights per channel of consecutive layers. This gives different channels a similar range so that the same quantization parameters can be used for weights across all channels. See the second chart in the figure.
- High Bias Fold
Cross layer scaling may result in high bias parameter values for some layers. This technique folds some of the bias of a layer into the subsequent layer’s parameters. This feature requires batch norm parameters to operate on and is not applied otherwise.
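The scaling rule behind cross-layer scaling can be sketched in plain NumPy. This is a simplified illustration of the rescaling idea, not AIMET's implementation; the helper name `cross_layer_scale` is hypothetical, and it assumes a ReLU-like (positively scale-equivariant) activation between the two layers so the rescaling preserves the network function, and nonzero per-channel ranges.

```python
import numpy as np

def cross_layer_scale(w1, b1, w2):
    """Equalize per-channel weight ranges of two consecutive conv layers.

    w1: first layer weights, shape (out_ch, in_ch, kh, kw)
    b1: first layer bias, shape (out_ch,)
    w2: second layer weights, shape (out2_ch, out_ch, kh, kw)
    """
    # Absolute range of each shared channel: output channels of w1,
    # input channels of w2 (assumed nonzero here for simplicity)
    r1 = np.abs(w1).reshape(w1.shape[0], -1).max(axis=1)
    r2 = np.abs(w2).max(axis=(0, 2, 3))
    # Per-channel scale factor that makes both ranges equal to sqrt(r1 * r2)
    s = np.sqrt(r1 * r2) / r2
    w1_eq = w1 / s[:, None, None, None]
    b1_eq = b1 / s
    w2_eq = w2 * s[None, :, None, None]
    return w1_eq, b1_eq, w2_eq

rng = np.random.default_rng(0)
w1, b1 = rng.standard_normal((8, 3, 3, 3)), rng.standard_normal(8)
w2 = rng.standard_normal((16, 8, 3, 3))
w1_eq, b1_eq, w2_eq = cross_layer_scale(w1, b1, w2)
# After scaling, the per-channel ranges of the two layers match
print(np.abs(w1_eq).reshape(8, -1).max(axis=1))
print(np.abs(w2_eq).max(axis=(0, 2, 3)))
```

Dividing the first layer's channel by `s` and multiplying the matching input channel of the second layer by `s` leaves the composed function unchanged, while both per-channel ranges become `sqrt(r1 * r2)`.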
Workflow¶
Setup¶
Load the model.
This code example uses MobileNetV2.
import torch
from torchvision.models import mobilenet_v2
# General setup that can be changed as needed
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = mobilenet_v2(pretrained=True).eval().to(device)
input_shape = (1, 3, 224, 224)
This Keras code example also uses MobileNetV2.
We recommend applying the TensorFlow prepare_model API before applying AIMET functionalities. After preparation the model contains consecutive convolutions, which can be optimized through cross-layer equalization.
from aimet_tensorflow.keras.cross_layer_equalization import equalize_model
from aimet_tensorflow.keras.model_preparer import prepare_model
from tensorflow.keras import applications
model = applications.MobileNetV2()
print(model.summary())
prepared_model = prepare_model(model)
print('*** Before cross-layer equalization ***')
print('\nprepared_model.layers[1]:')
print(type(prepared_model.layers[1]))
print('\nprepared_model.layers[4]:')
print(type(prepared_model.layers[4]))
print('\nPrev Conv weight')
print(prepared_model.layers[1].get_weights()[0])
print('\nNext Conv weight')
print(prepared_model.layers[4].get_weights()[0])
Model: "mobilenetv2_1.00_224"
__________________________________________________________________________________________________
Layer (type)                               Output Shape           Param #   Connected to
==================================================================================================
input_1 (InputLayer)                       [(None, 224, 224, 3)]  0         []
Conv1 (Conv2D)                             (None, 112, 112, 32)   864       ['input_1[0][0]']
bn_Conv1 (BatchNormalization)              (None, 112, 112, 32)   128       ['Conv1[0][0]']
Conv1_relu (ReLU)                          (None, 112, 112, 32)   0         ['bn_Conv1[0][0]']
expanded_conv_depthwise (DepthwiseConv2D)  (None, 112, 112, 32)   288       ['Conv1_relu[0][0]']
...
*** Before cross-layer equalization ***
prepared_model.layers[1]:
<class 'keras.layers.convolutional.conv2d.Conv2D'>
prepared_model.layers[4]:
<class 'keras.layers.convolutional.depthwise_conv2d.DepthwiseConv2D'>
Prev Conv weight
[[[[-1.71659231e-01 -3.33731920e-01 5.30122258e-02 -5.93232973e-21
2.08742931e-01 -1.20433941e-01 1.75700430e-02 -3.10708203e-22
-9.62498877e-03 1.90229788e-01 -3.67278278e-01 3.95997976e-22
...
3.87471542e-02 -3.67677957e-02 -3.23011987e-02 -4.83861901e-02
1.23156421e-02 -5.57984132e-03 -6.53976866e-04 -1.92511864e-02
-2.09685047e-22 1.19186290e-01 -2.52912678e-02 2.02078857e-02]]]]
Next Conv weight
[[[[-9.15259957e-01]
[ 6.11176670e-01]
[-4.27415752e+00]
...
[-1.17871511e+00]
[ 2.55578518e+00]
[ 3.69716495e-01]]]]
Load the model for cross-layer equalization. This example converts PyTorch MobileNetV2 to ONNX and uses it in the subsequent code.
We recommend simplifying the ONNX model before applying AIMET functionalities. After simplification, the model contains consecutive convolutions, which can be optimized through cross-layer equalization.
import os
import onnx
import onnxsim
import torch
from aimet_onnx.cross_layer_equalization import equalize_model
from torchvision.models import MobileNet_V2_Weights, mobilenet_v2
pt_model = mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT)
print(pt_model)
# Shape for each ImageNet sample is (3 channels) x (224 height) x (224 width)
input_shape = (1, 3, 224, 224)
dummy_input = torch.randn(input_shape)
# Modify file_path as needed; a temporary directory is used here
file_path = os.path.join('/tmp', 'mobilenet_v2.onnx')
torch.onnx.export(
    pt_model,
    (dummy_input,),
    file_path,
)
# Load exported ONNX model
model = onnx.load_model(file_path)
# Simplify the model
try:
    model, _ = onnxsim.simplify(model)
except Exception:
    print('ONNX Simplifier failed. Proceeding with unsimplified model')
initializers = {init.name: init for init in model.graph.initializer}
prev_conv_weight = onnx.numpy_helper.to_array(
    initializers[model.graph.node[4].input[1]]
)
next_conv_weight = onnx.numpy_helper.to_array(
    initializers[model.graph.node[5].input[1]]
)
print("*** Before cross-layer equalization ***")
print("\nmodel.graph.node[4]:")
print(model.graph.node[4].name)
print("\nmodel.graph.node[5]:")
print(model.graph.node[5].name)
print("\nPrev Conv weight")
print(prev_conv_weight)
print("\nNext Conv weight")
print(next_conv_weight)
MobileNetV2(
(features): Sequential(
(0): Conv2dNormActivation(
(0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU6(inplace=True)
)
...
)
*** Before cross-layer equalization ***
model.graph.node[4]:
/features/features.1/conv/conv.1/Conv
model.graph.node[5]:
/features/features.2/conv/conv.0/conv.0.0/Conv
Prev Conv weight
[[[[ 1.83640555e-01]]
[[ 6.34215236e-01]]
[[ 8.44993666e-02]]
...
[[-6.70130579e-17]]
[[-1.37757687e-02]]
[[ 9.16839484e-03]]]]
Next Conv weight
[[[[-8.41059163e-02]]
[[-1.12039044e-01]]
[[-2.72468403e-02]]
...
[[ 9.46642041e-01]]
[[ 4.35139937e-03]]
[[ 2.57021021e-02]]]]
Execution¶
Apply cross-layer equalization.
Execute the AIMET cross-layer equalization API function.
from aimet_torch.cross_layer_equalization import equalize_model
# Performs BatchNorm folding, cross layer scaling and high bias folding
equalize_model(model, input_shape)
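The batch norm folding step that equalize_model performs first can be sketched in plain PyTorch. This is an illustrative re-derivation of the folding arithmetic, not AIMET's implementation; the helper name `fold_bn_into_conv` is hypothetical.

```python
import torch
import torch.nn as nn

def fold_bn_into_conv(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> nn.Conv2d:
    """Fold a BatchNorm2d that follows a Conv2d into the conv's parameters."""
    fused = nn.Conv2d(conv.in_channels, conv.out_channels, conv.kernel_size,
                      conv.stride, conv.padding, conv.dilation, conv.groups,
                      bias=True)
    with torch.no_grad():
        # y = gamma * (conv(x) - mean) / sqrt(var + eps) + beta
        scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
        fused.weight.copy_(conv.weight * scale[:, None, None, None])
        bias = conv.bias if conv.bias is not None else torch.zeros(conv.out_channels)
        fused.bias.copy_((bias - bn.running_mean) * scale + bn.bias)
    return fused

# Folding preserves the output in eval mode
conv, bn = nn.Conv2d(3, 8, 3, bias=False).eval(), nn.BatchNorm2d(8).eval()
with torch.no_grad():
    bn.running_mean.uniform_(-1, 1)
    bn.running_var.uniform_(0.5, 2.0)
fused = fold_bn_into_conv(conv, bn).eval()
x = torch.randn(1, 3, 16, 16)
assert torch.allclose(bn(conv(x)), fused(x), atol=1e-5)
```

Because the batch norm uses fixed running statistics at inference, it is an affine transform per output channel and can be absorbed into the preceding convolution's weight and bias exactly.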
Execute the AIMET cross-layer equalization API function.
cle_applied_model = equalize_model(prepared_model)
print('*** After cross-layer equalization ***')
print('\nPrev Conv weight')
print(cle_applied_model.layers[1].get_weights()[0])
print('\nNext Conv weight')
print(cle_applied_model.layers[3].get_weights()[0])
*** After cross-layer equalization ***
Prev Conv weight
[[[[-3.01457286e-01 -1.49024737e+00 6.10569119e-01 -1.29590677e-19
1.51547194e-01 -1.51446089e-01 1.38100997e-01 -4.89249423e-21
-5.16245179e-02 4.64579314e-01 -2.44408584e+00 1.22219264e-20
...
1.67510852e-01 -2.60713138e-02 -1.05549544e-01 -2.53403008e-01
1.39502389e-02 -1.54620111e-02 -1.97294299e-02 -9.41715762e-02
-6.88260233e-21 8.95088911e-02 -1.87630311e-01 2.48399768e-02]]]]
Next Conv weight
[[[[-1.00347728e-01]
[ 6.30402938e-02]
[-9.67416465e-01]
...
[-7.88373709e-01]
[ 6.75162792e-01]
[ 1.48045555e-01]]]]
Execute the AIMET cross-layer equalization API function.
# Cross-layer equalization modifies the model in place
equalize_model(model=model)
# Re-read the initializers, since equalization may replace the weight tensors
initializers = {init.name: init for init in model.graph.initializer}
prev_conv_weight = onnx.numpy_helper.to_array(
    initializers[model.graph.node[4].input[1]]
)
next_conv_weight = onnx.numpy_helper.to_array(
    initializers[model.graph.node[5].input[1]]
)
print("*** After cross-layer equalization ***")
print("\nPrev Conv weight")
print(prev_conv_weight)
print("\nNext Conv weight")
print(next_conv_weight)
*** After cross-layer equalization ***
Prev Conv weight
[[[[ 6.28238320e-02]]
[[ 2.16966406e-01]]
[[ 2.89074164e-02]]
...
[[-2.44632760e-17]]
[[-5.02887694e-03]]
[[ 3.34694423e-03]]]]
Next Conv weight
[[[[-2.4585028e-01]]
[[-3.5856506e-01]]
[[-3.3467390e-02]]
...
[[ 1.2930528e+00]]
[[ 1.6213797e-02]]
[[ 7.0406616e-02]]]]
API¶
Top-level API

- aimet_torch.cross_layer_equalization.equalize_model(model, input_shapes=None, dummy_input=None)[source]¶
High-level API to perform Cross-Layer Equalization (CLE) on the given model. The model is equalized in place.
- Parameters:
  - model (Module) – Model to equalize
  - input_shapes (Union[Tuple, List[Tuple], None]) – Shape of the input (can be a tuple or a list of tuples if multiple inputs)
  - dummy_input (Union[Tensor, Tuple, None]) – A dummy input to the model. Can be a Tensor or a Tuple of Tensors. dummy_input will be placed on CPU if not already.