Batch norm folding
Context
Batch norm folding (BNF) is a technique widely used in deep learning inference runtimes, including Qualcomm® AI Engine Direct. In BNF, batch normalization layers are folded, where possible, into the weights and biases of adjacent convolution or linear layers, which eliminates unnecessary computations.
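In the common case where a batch norm directly follows a convolution or linear layer, folding is a per-output-channel rescale and shift. With BN parameters γ and β, running mean μ, running variance σ² and epsilon ε, the folded weights are W'_o = W_o · γ_o / √(σ²_o + ε) and the folded bias is b'_o = (b_o − μ_o) · γ_o / √(σ²_o + ε) + β_o. The pure-Python sketch below applies these formulas to a tiny linear layer (equivalent to a 1×1 convolution) and checks that the folded layer reproduces layer + BN in inference mode. The helper names are hypothetical and this is not AIMET's implementation.

```python
import math

def linear(weights, bias, x):
    """y[o] = sum_i weights[o][i] * x[i] + bias[o]"""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weights, bias)]

def batch_norm(y, gamma, beta, mean, var, eps=1e-5):
    """Inference-mode BN: uses running statistics, one set per channel."""
    return [g * (v - m) / math.sqrt(s + eps) + bt
            for v, g, bt, m, s in zip(y, gamma, beta, mean, var)]

def fold_bn_into_linear(weights, bias, gamma, beta, mean, var, eps=1e-5):
    """Return (weights, bias) of one layer equivalent to layer -> BN."""
    new_w, new_b = [], []
    for row, b, g, bt, m, s in zip(weights, bias, gamma, beta, mean, var):
        scale = g / math.sqrt(s + eps)           # per-output-channel scale
        new_w.append([w * scale for w in row])   # W'_o = W_o * scale_o
        new_b.append((b - m) * scale + bt)       # b'_o = (b_o - mu_o) * scale_o + beta_o
    return new_w, new_b

# 2 output channels, 3 inputs, zero original bias (like a conv with bias=False)
W = [[0.5, -1.0, 2.0], [1.5, 0.25, -0.75]]
b = [0.0, 0.0]
gamma, beta = [1.2, 0.8], [0.1, -0.3]
mean, var = [0.4, -0.2], [2.0, 0.5]
x = [1.0, 2.0, -1.0]

reference = batch_norm(linear(W, b, x), gamma, beta, mean, var)
W_f, b_f = fold_bn_into_linear(W, b, gamma, beta, mean, var)
folded = linear(W_f, b_f, x)
print(reference, folded)
```

Note that the folded layer ends up with a nonzero bias even though the original layer had none, which is why a conv created with bias=False gains a bias after folding.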
To accurately simulate inference in these runtimes, perform BNF on the floating-point model before applying quantization. Folding increases throughput (inferences per second) and often improves the accuracy of the quantized model, because it removes redundant computations and the requantization of the intermediate layer output. AIMET lets you apply BNF to the model before quantization so that the quantization simulation (QuantSim) model reflects this on-target behavior.
Workflow
Procedure
Step 1
Load the model.
For PyTorch, this example loads the pretrained MobileNetV2 model from torchvision.
import torch
from torchvision.models import MobileNet_V2_Weights, mobilenet_v2
# General setup that can be changed as needed
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).eval().to(device)
dummy_input = torch.randn(1, 3, 224, 224).to(device)
print(model)
MobileNetV2(
(features): Sequential(
(0): Conv2dNormActivation(
(0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU6(inplace=True)
)
...
)
For TensorFlow, this example loads the pretrained Keras MobileNetV2 model.
from aimet_tensorflow.keras.batch_norm_fold import fold_all_batch_norms
from aimet_tensorflow.keras.model_preparer import prepare_model
from tensorflow.keras import applications
model = applications.MobileNetV2()
model.summary()
Model: "mobilenetv2_1.00_224"
__________________________________________________________________________________________________
Layer (type)                   Output Shape           Param #   Connected to
==================================================================================================
input_1 (InputLayer)           [(None, 224, 224, 3)]  0         []
Conv1 (Conv2D)                 (None, 112, 112, 32)   864       ['input_1[0][0]']
bn_Conv1 (BatchNormalization)  (None, 112, 112, 32)   128       ['Conv1[0][0]']
Conv1_relu (ReLU)              (None, 112, 112, 32)   0         ['bn_Conv1[0][0]']
...
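The Param # column above can be checked by hand, which also shows what folding has to absorb. Conv1 holds 32 kernels of size 3×3 over 3 input channels with no bias, and bn_Conv1 holds four per-channel vectors (gamma, beta, moving mean, moving variance). After folding, the kernel keeps its shape and the BN's effect needs one bias value per output channel. This is plain arithmetic; the folded count assumes the folded layer stores such a bias, which the outputs on this page do not show for Keras.

```python
# Reproduce "Param #" for the first Conv/BN pair of Keras MobileNetV2
kernel_h, kernel_w, in_channels, out_channels = 3, 3, 3, 32

conv1_params = kernel_h * kernel_w * in_channels * out_channels  # Conv1 has no bias
bn_params = 4 * out_channels  # gamma, beta, moving_mean, moving_variance

# Assumed folded layer: same kernel shape plus one bias value per output channel
folded_params = conv1_params + out_channels

print(conv1_params, bn_params, folded_params)  # 864 128 896
```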
For ONNX, this example exports the pretrained PyTorch MobileNetV2 model to ONNX and then uses the exported ONNX model.
import os
import onnx
import onnxsim
import torch
from aimet_onnx.batch_norm_fold import fold_all_batch_norms_to_weight
from torchvision.models import MobileNet_V2_Weights, mobilenet_v2
pt_model = mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT)
print(pt_model)
# Shape for each ImageNet sample is (3 channels) x (224 height) x (224 width)
input_shape = (1, 3, 224, 224)
dummy_input = torch.randn(input_shape)
# Change file_path as needed; this example writes to a temporary directory
file_path = os.path.join('/tmp', 'mobilenet_v2.onnx')
torch.onnx.export(
pt_model,
(dummy_input,),
file_path,
do_constant_folding=False,
input_names=['input'],
output_names=['output'],
dynamic_axes={
'input': {0: 'batch_size'},
'output': {0: 'batch_size'},
},
)
# Load exported ONNX model
model = onnx.load_model(file_path)
MobileNetV2(
(features): Sequential(
(0): Conv2dNormActivation(
(0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU6(inplace=True)
)
...
)
Step 2
Prepare the model, if required by the model framework.
No preparation step is needed for PyTorch.
AIMET provides the TensorFlow prepare_model API, which pre-processes the user model if necessary.
prepared_model = prepare_model(model)
print('*** Before batch norm folding ***')
print('\nprepared_model.layers[1]:')
print(type(prepared_model.layers[1]))
print('\nprepared_model.layers[2]:')
print(type(prepared_model.layers[2]))
print('\nConv weight')
print(prepared_model.layers[1].get_weights()[0])
*** Before batch norm folding ***
prepared_model.layers[1]:
<class 'keras.layers.convolutional.conv2d.Conv2D'>
prepared_model.layers[2]:
<class 'keras.layers.normalization.batch_normalization.BatchNormalization'>
Conv weight
[[[[-1.71659231e-01 -3.33731920e-01 5.30122258e-02 -5.93232973e-21
2.08742931e-01 -1.20433941e-01 1.75700430e-02 -3.10708203e-22
-9.62498877e-03 1.90229788e-01 -3.67278278e-01 3.95997976e-22
...
3.87471542e-02 -3.67677957e-02 -3.23011987e-02 -4.83861901e-02
1.23156421e-02 -5.57984132e-03 -6.53976866e-04 -1.92511864e-02
-2.09685047e-22 1.19186290e-01 -2.52912678e-02 2.02078857e-02]]]]
We recommend that you simplify the ONNX model as follows.
# AIMET supports both forward and backward folding, but the ONNX simplifier only performs backward folding.
# Disable the simplifier's BN fusion (`fuse_bn_into_conv`) via `skipped_optimizers` so that AIMET
# performs all of the folding in Step 3.
try:
    model, _ = onnxsim.simplify(model, skipped_optimizers=['fuse_bn_into_conv'])
except Exception:
    print('ONNX Simplifier failed. Proceeding with unsimplified model')
initializers = {init.name: init for init in model.graph.initializer}
conv_weight_name = model.graph.node[0].input[1]
conv_weight = initializers[conv_weight_name]
conv_weight_array = onnx.numpy_helper.to_array(conv_weight)
print("*** Before batch norm folding ***")
print("\nmodel.graph.node[0]:")
print(model.graph.node[0].name)
print("\nmodel.graph.node[1]:")
print(model.graph.node[1].name)
print("\nConv weight")
print(conv_weight_array)
*** Before batch norm folding ***
model.graph.node[0]:
/features/features.0/features.0.0/Conv
model.graph.node[1]:
/features/features.0/features.0.1/BatchNormalization
Conv weight
[[[[-6.31080866e-02 -1.87656835e-01 -1.51876003e-01]
[-4.93787616e-01 -6.42477691e-01 -5.89348674e-01]
[-6.80053532e-01 -9.74478185e-01 -7.63172388e-01]]
...
[[ 1.24257803e-02 -4.73242160e-03 -1.81884710e-02]
[ 2.32141271e-01 7.22583652e-01 1.21250950e-01]
[-2.59643137e-01 -7.18673885e-01 -9.19778645e-02]]]]
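The comment in the ONNX simplification code distinguishes forward from backward folding. Assuming backward folding means folding a BN into the layer that produces its input (conv → BN, as in MobileNetV2) and forward folding means folding it into the layer that consumes its output (BN → conv), the forward case can be sketched in pure Python for a linear layer: the BN scale multiplies the matching input columns of the weight matrix, and the BN shift is pushed through the weights into the bias. The helper names are hypothetical and this is not AIMET's implementation. For convolutions with zero padding the forward fold is not exact near the borders, because the padded zeros never pass through the BN.

```python
import math

def linear(weights, bias, x):
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weights, bias)]

def batch_norm(x, gamma, beta, mean, var, eps=1e-5):
    return [g * (v - m) / math.sqrt(s + eps) + bt
            for v, g, bt, m, s in zip(x, gamma, beta, mean, var)]

def fold_bn_forward(weights, bias, gamma, beta, mean, var, eps=1e-5):
    """Fold a BN that feeds this layer (BN -> linear) into the layer."""
    scale = [g / math.sqrt(s + eps) for g, s in zip(gamma, var)]       # per input channel
    shift = [bt - m * sc for bt, m, sc in zip(beta, mean, scale)]       # BN(x) = scale * x + shift
    new_w = [[w * sc for w, sc in zip(row, scale)] for row in weights]   # scale input columns
    new_b = [b + sum(w * sh for w, sh in zip(row, shift))                # push shift through W
             for row, b in zip(weights, bias)]
    return new_w, new_b

W = [[0.5, -1.0], [2.0, 0.25], [-0.75, 1.5]]  # 3 outputs, 2 inputs
b = [0.1, -0.2, 0.3]
gamma, beta = [1.1, 0.9], [0.05, -0.4]
mean, var = [0.3, -0.6], [1.5, 0.8]
x = [0.7, -1.3]

reference = linear(W, b, batch_norm(x, gamma, beta, mean, var))
W_f, b_f = fold_bn_forward(W, b, gamma, beta, mean, var)
folded = linear(W_f, b_f, x)
print(reference, folded)
```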
Step 3
Perform the batch norm folding.
For PyTorch, execute the AIMET BNF API.
print("*** Before batch norm folding ***")
print("\nmodel.features[0][0]:")
print(model.features[0][0])
print("\nmodel.features[0][1]:")
print(model.features[0][1])
from aimet_torch.batch_norm_fold import fold_all_batch_norms
fold_all_batch_norms(model, dummy_input.shape, dummy_input=dummy_input)
print("*** After batch norm folding ***")
print("\nmodel.features[0][0]:")
print(model.features[0][0])
print("\nmodel.features[0][1]:")
print(model.features[0][1])
*** Before batch norm folding ***
model.features[0][0]:
Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
model.features[0][1]:
BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
*** After batch norm folding ***
model.features[0][0]:
Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
model.features[0][1]:
Identity()
For TensorFlow, execute the AIMET BNF API, which returns a new model with the batch normalization layers folded.
_, folded_model = fold_all_batch_norms(prepared_model)
print('*** After batch norm folding ***')
print('\nfolded_model.layers[1]:')
print(type(folded_model.layers[1]))
print('\nfolded_model.layers[2]:')
print(type(folded_model.layers[2]))
print('\nConv weight')
print(folded_model.layers[1].get_weights()[0])
*** After batch norm folding ***
folded_model.layers[1]:
<class 'keras.layers.convolutional.conv2d.Conv2D'>
folded_model.layers[2]:
<class 'keras.layers.activation.relu.ReLU'>
Conv weight
[[[[-3.01457286e-01 -1.49024737e+00 6.10569119e-01 -1.29590677e-19
1.51547194e-01 -1.51446089e-01 1.38100997e-01 -4.89249423e-21
-5.16245179e-02 4.64579314e-01 -2.44408584e+00 1.22219264e-20
...
1.67510852e-01 -2.60713138e-02 -1.05549544e-01 -2.53403008e-01
1.39502389e-02 -1.54620111e-02 -1.97294299e-02 -9.41715762e-02
-6.88260233e-21 8.95088911e-02 -1.87630311e-01 2.48399768e-02]]]]
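For ONNX, execute the AIMET BNF API. It updates the model in place.
_ = fold_all_batch_norms_to_weight(model=model)
# Rebuild the initializer lookup in case folding added or replaced initializers
initializers = {init.name: init for init in model.graph.initializer}
conv_weight_name = model.graph.node[0].input[1]
conv_weight = initializers[conv_weight_name]
conv_weight_array = onnx.numpy_helper.to_array(conv_weight)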
print("*** After batch norm folding ***")
print("\nmodel.graph.node[0]:")
print(model.graph.node[0].name)
print("\nmodel.graph.node[1]:")
print(model.graph.node[1].name)
print("\nConv weight")
print(conv_weight_array)
*** After batch norm folding ***
model.graph.node[0]:
/features/features.0/features.0.0/Conv
model.graph.node[1]:
/features/features.0/features.0.2/Clip
Conv weight
[[[[-2.00183112e-02 -5.95260113e-02 -4.81760912e-02]
[-1.56632766e-01 -2.03798249e-01 -1.86945379e-01]
[-2.15717569e-01 -3.09111059e-01 -2.42083430e-01]]
...
[[ 1.21066449e-02 -4.61087702e-03 -1.77213307e-02]
[ 2.26179108e-01 7.04025269e-01 1.18136823e-01]
[-2.52974629e-01 -7.00215936e-01 -8.96155685e-02]]]]
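One way to read the printed weights: folding a BN into the preceding conv multiplies every weight of output channel o by the same factor γ_o / √(σ²_o + ε). ONNX Conv weights are laid out as (out_channels, in_channels, kH, kW), so the first printed 3×3 block belongs to output channel 0 and the last printed block to output channel 31. The sketch below divides the after-folding values by the before-folding values copied from the Step 2 and Step 3 printouts and confirms that each block is scaled by a single constant (about 0.3172 for channel 0 and about 0.9743 for channel 31). Only the visible values are used; the elided ones are not reconstructed.

```python
# First printed 3x3 slice (output channel 0), before and after folding
before_ch0 = [-6.31080866e-02, -1.87656835e-01, -1.51876003e-01,
              -4.93787616e-01, -6.42477691e-01, -5.89348674e-01,
              -6.80053532e-01, -9.74478185e-01, -7.63172388e-01]
after_ch0 = [-2.00183112e-02, -5.95260113e-02, -4.81760912e-02,
             -1.56632766e-01, -2.03798249e-01, -1.86945379e-01,
             -2.15717569e-01, -3.09111059e-01, -2.42083430e-01]

# Last printed 3x3 slice (output channel 31), before and after folding
before_ch31 = [1.24257803e-02, -4.73242160e-03, -1.81884710e-02,
               2.32141271e-01, 7.22583652e-01, 1.21250950e-01,
               -2.59643137e-01, -7.18673885e-01, -9.19778645e-02]
after_ch31 = [1.21066449e-02, -4.61087702e-03, -1.77213307e-02,
              2.26179108e-01, 7.04025269e-01, 1.18136823e-01,
              -2.52974629e-01, -7.00215936e-01, -8.96155685e-02]

def channel_scales(before, after):
    """Element-wise after/before ratios; constant if BN scaled the whole channel."""
    return [a / b for a, b in zip(after, before)]

ratios_ch0 = channel_scales(before_ch0, after_ch0)
ratios_ch31 = channel_scales(before_ch31, after_ch31)
for name, ratios in (("channel 0", ratios_ch0), ("channel 31", ratios_ch31)):
    print(name, "scale from", round(min(ratios), 4), "to", round(max(ratios), 4))
```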
API
Top-level API
- aimet_torch.batch_norm_fold.fold_all_batch_norms(model, input_shapes, dummy_input=None)
Fold all batch_norm layers in a model into the weight of the corresponding conv layers
- Parameters:
model (Module) – Model
input_shapes (Union[Tuple, List[Tuple]]) – Input shapes for the model (can be one or multiple inputs)
dummy_input (Union[Tensor, Tuple, None]) – A dummy input to the model. Can be a Tensor or a Tuple of Tensors
- Return type:
List[Tuple[Union[Linear, Conv1d, Conv2d, ConvTranspose2d], Union[BatchNorm1d, BatchNorm2d]]]
- Returns:
A list of pairs of layers [(Conv/Linear, BN layer that got folded)]
Top-level API
- aimet_tensorflow.keras.batch_norm_fold.fold_all_batch_norms(model)
Fold all batch_norm layers in a model into corresponding conv/linear layers
- Parameters:
model (Model) – model to find all batch norms for
- Return type:
Tuple[List[Tuple[Union[Conv2D, Dense, Conv2DTranspose, DepthwiseConv2D], BatchNormalization]], Model]
- Returns:
A tuple of (1) a list of conv/linear layers with their associated BN op / activation info and (2) a new model with the batch normalization layers folded
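To picture which pairs such a call can report, the hypothetical sketch below walks a toy graph (a dictionary from layer name to its type and consumers) and reports a (conv/linear, BN) pair only when the conv/linear layer's sole consumer is a batch norm; if the conv output also fed another layer, folding the BN into it would change that other layer's input. It covers only the conv → BN direction and is much simpler than AIMET's own graph analysis. The first three layer names follow the Keras summary in Step 1; the remaining layers are invented for illustration.

```python
# Hypothetical toy graph: layer name -> (layer type, names of consuming layers)
graph = {
    "Conv1": ("Conv2D", ["bn_Conv1"]),
    "bn_Conv1": ("BatchNormalization", ["Conv1_relu"]),
    "Conv1_relu": ("ReLU", ["branch_conv", "skip_conv"]),
    "branch_conv": ("Conv2D", ["branch_bn", "side_output"]),  # output also used elsewhere
    "branch_bn": ("BatchNormalization", []),
    "side_output": ("ReLU", []),
    "skip_conv": ("Conv2D", ["skip_bn"]),
    "skip_bn": ("BatchNormalization", []),
}

# Layer types taken from the return type annotation above
FOLDABLE_TYPES = {"Conv2D", "Dense", "Conv2DTranspose", "DepthwiseConv2D"}

def find_conv_bn_pairs(graph):
    """Return (conv, bn) pairs where the conv/linear layer's only consumer is a BN."""
    pairs = []
    for name, (layer_type, consumers) in graph.items():
        if layer_type in FOLDABLE_TYPES and len(consumers) == 1:
            consumer = consumers[0]
            if graph[consumer][0] == "BatchNormalization":
                pairs.append((name, consumer))
    return pairs

print(find_conv_bn_pairs(graph))
```

Here branch_conv is skipped because its output also feeds side_output.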
Top-level API
- aimet_onnx.batch_norm_fold.fold_all_batch_norms_to_weight(model)
Fold all possible batch_norm layers in a model into the weight of the corresponding conv layers
- Parameters:
model (ModelProto) – ONNX model to perform BN folding on
- Return type:
typing.List
- Returns:
A list of pairs of layers [(Conv/Linear, BN layer that got folded)]