Batch norm folding

Context

Batch norm folding is a technique widely used in deep learning inference runtimes, including the Qualcomm® AI Engine Direct. Wherever possible, these runtimes fold batch normalization layers into the weights and biases of adjacent convolution layers to eliminate unnecessary computations. To simulate inference on these runtimes accurately, it is generally advisable to fold batch norms in the floating-point model before applying quantization. Doing so not only increases inference throughput by avoiding unnecessary computations, but also often improves the accuracy of the quantized model, because the separate batch norm operation and its extra requantization step disappear. The workflow below performs batch norm folding to simulate this on-target behavior.
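The arithmetic behind folding is short. In inference mode, a batch norm layer computes y = γ(z − μ)/√(σ² + ε) + β per output channel, where z = Wx + b is the output of the preceding layer. Defining the per-channel scale s = γ/√(σ² + ε), this equals (sW)x + s(b − μ) + β, so the BN layer can be absorbed by scaling each output channel's weights by s and replacing the bias with s(b − μ) + β. The NumPy sketch below checks this identity on a fully connected layer; the variable names and random values are illustrative assumptions, not AIMET code. A convolution folds the same way along its output-channel axis.

```python
import numpy as np

rng = np.random.default_rng(0)
in_features, out_features = 5, 4

# Fully connected layer: z = W @ x + b
W = rng.normal(size=(out_features, in_features))
b = rng.normal(size=out_features)

# Inference-mode batch norm parameters (one value per output channel)
gamma = rng.uniform(0.5, 1.5, size=out_features)
beta = rng.normal(size=out_features)
running_mean = rng.normal(size=out_features)
running_var = rng.uniform(0.5, 2.0, size=out_features)
eps = 1e-5


def linear_then_bn(x):
    """Reference computation: linear layer followed by an unfolded BN layer."""
    z = W @ x + b
    return gamma * (z - running_mean) / np.sqrt(running_var + eps) + beta


# Fold: scale each output row of W and rewrite the bias
scale = gamma / np.sqrt(running_var + eps)
W_folded = W * scale[:, None]
b_folded = (b - running_mean) * scale + beta

x = rng.normal(size=in_features)
reference = linear_then_bn(x)
folded = W_folded @ x + b_folded
print("max abs difference:", np.max(np.abs(reference - folded)))
```

The folded layer produces the same output with one matrix multiply and one bias add, which is the computation the runtime saves.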

Workflow

Code example

Step 1

Load the model for batch norm folding. This PyTorch code example uses MobileNetV2.

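import torch
from torchvision.models import MobileNet_V2_Weights, mobilenet_v2

# General setup that can be changed as needed
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# `pretrained=True` is deprecated in recent torchvision releases; use the weights enum instead
model = mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).eval().to(device)
dummy_input = torch.randn(1, 3, 224, 224).to(device)
print(model)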

MobileNetV2(
  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU6(inplace=True)
    )
    ...
)
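Folding relies on the model being in evaluation mode, which is why the snippet above calls .eval(). In eval mode, BatchNorm2d normalizes with its stored running_mean and running_var, so it reduces to a fixed per-channel scale and shift; in training mode it uses the statistics of the current batch and cannot be replaced by constant weights. The sketch below checks this with plain PyTorch; the randomly initialized statistics are illustrative assumptions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm2d(4)

# Give the layer non-trivial running statistics and affine parameters
with torch.no_grad():
    bn.running_mean.uniform_(-1.0, 1.0)
    bn.running_var.uniform_(0.5, 2.0)
    bn.weight.uniform_(0.5, 1.5)
    bn.bias.uniform_(-0.5, 0.5)

x = torch.randn(8, 4, 5, 5)

with torch.no_grad():
    # Eval mode: normalize with the stored running statistics
    bn.eval()
    eval_out = bn(x)

    # The same result written as a constant per-channel affine map
    scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
    shift = bn.bias - bn.running_mean * scale
    affine_out = x * scale.view(1, -1, 1, 1) + shift.view(1, -1, 1, 1)

    # Training mode: normalize with this batch's statistics (and update the running stats)
    bn.train()
    train_out = bn(x)

print("eval == affine:", torch.allclose(eval_out, affine_out, atol=1e-5))
print("train == eval:", torch.allclose(train_out, eval_out, atol=1e-3))
```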

Load the model for batch norm folding. This TensorFlow (Keras) code example uses MobileNetV2.

from aimet_tensorflow.keras.batch_norm_fold import fold_all_batch_norms
from aimet_tensorflow.keras.model_preparer import prepare_model
from tensorflow.keras import applications

model = applications.MobileNetV2()
model.summary()  # summary() prints the table itself and returns None
Model: "mobilenetv2_1.00_224"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to
==================================================================================================
 input_1 (InputLayer)           [(None, 224, 224, 3  0           []
                                )]

 Conv1 (Conv2D)                 (None, 112, 112, 32  864         ['input_1[0][0]']
                                )

 bn_Conv1 (BatchNormalization)  (None, 112, 112, 32  128         ['Conv1[0][0]']
                                )

 Conv1_relu (ReLU)              (None, 112, 112, 32  0           ['bn_Conv1[0][0]']
                                )
 ...

Load the model for batch norm folding. This ONNX code example exports the PyTorch MobileNetV2 model to ONNX and uses the exported model in the subsequent steps.

import os

import onnx
import onnxsim
import torch
from aimet_onnx.batch_norm_fold import fold_all_batch_norms_to_weight
from torchvision.models import MobileNet_V2_Weights, mobilenet_v2

pt_model = mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT)
print(pt_model)

# Shape for each ImageNet sample is (3 channels) x (224 height) x (224 width)
input_shape = (1, 3, 224, 224)
dummy_input = torch.randn(input_shape)

# Modify file_path as needed; this example writes to a temporary directory
file_path = os.path.join('/tmp', 'mobilenet_v2.onnx')
torch.onnx.export(
    pt_model,
    (dummy_input,),
    file_path,
    do_constant_folding=False,
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={
        'input': {0: 'batch_size'},
        'output': {0: 'batch_size'},
    },
)
# Load exported ONNX model
model = onnx.load_model(file_path)
MobileNetV2(
  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU6(inplace=True)
    )
    ...
)

Step 2

No preparation step is needed for PyTorch.

For TensorFlow, AIMET provides the prepare_model API, which preprocesses the user model if necessary.

prepared_model = prepare_model(model)

print('*** Before batch norm folding ***')

print('\nprepared_model.layers[1]:')
print(type(prepared_model.layers[1]))

print('\nprepared_model.layers[2]:')
print(type(prepared_model.layers[2]))

print('\nConv weight:')
print(prepared_model.layers[1].get_weights()[0])
*** Before batch norm folding ***

prepared_model.layers[1]:
<class 'keras.layers.convolutional.conv2d.Conv2D'>

prepared_model.layers[2]:
<class 'keras.layers.normalization.batch_normalization.BatchNormalization'>

Conv weight:
[[[[-1.71659231e-01 -3.33731920e-01  5.30122258e-02 -5.93232973e-21
     2.08742931e-01 -1.20433941e-01  1.75700430e-02 -3.10708203e-22
    -9.62498877e-03  1.90229788e-01 -3.67278278e-01  3.95997976e-22
  ...
     3.87471542e-02 -3.67677957e-02 -3.23011987e-02 -4.83861901e-02
     1.23156421e-02 -5.57984132e-03 -6.53976866e-04 -1.92511864e-02
    -2.09685047e-22  1.19186290e-01 -2.52912678e-02  2.02078857e-02]]]]

For ONNX, we recommend simplifying the model with ONNX Simplifier before applying AIMET features.

# Unlike AIMET, which supports both forward and backward folding, ONNX Simplifier only performs backward folding.
# Therefore, we skip its BN-fusion pass via `skipped_optimizers` and let AIMET perform the folding later
try:
    model, _ = onnxsim.simplify(model, skipped_optimizers=['fuse_bn_into_conv'])
except Exception:
    print('ONNX Simplifier failed. Proceeding with unsimplified model')

initializers = {init.name: init for init in model.graph.initializer}
conv_weight_name = model.graph.node[0].input[1]
conv_weight = initializers[conv_weight_name]
conv_weight_array = onnx.numpy_helper.to_array(conv_weight)

print("*** Before batch norm folding ***")

print("\nmodel.graph.node[0]:")
print(model.graph.node[0].name)

print("\nmodel.graph.node[1]:")
print(model.graph.node[1].name)

print("\nConv weight:")
print(conv_weight_array)
*** Before batch norm folding ***

model.graph.node[0]:
name: "/features/features.0/features.0.0/Conv"

model.graph.node[1]:
name: "/features/features.0/features.0.1/BatchNormalization"

Conv weight:
[[[[-6.31080866e-02 -1.87656835e-01 -1.51876003e-01]
   [-4.93787616e-01 -6.42477691e-01 -5.89348674e-01]
   [-6.80053532e-01 -9.74478185e-01 -7.63172388e-01]]
  ...
  [[ 1.24257803e-02 -4.73242160e-03 -1.81884710e-02]
   [ 2.32141271e-01  7.22583652e-01  1.21250950e-01]
   [-2.59643137e-01 -7.18673885e-01 -9.19778645e-02]]]]
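The comment in the simplification step distinguishes backward folding, where a BN layer that follows a conv/linear layer is folded back into it (the Conv → BatchNormalization pattern in the graph above), from forward folding, where a BN layer that precedes a conv/linear layer is folded forward into it. The NumPy sketch below illustrates forward folding for a fully connected layer, assuming an inference-mode BN layer: the BN scale s multiplies the input columns of W, and the BN shift t = β − μs is pushed through W into the bias. The names and values are illustrative. For convolutions with zero padding this rewrite is not exact at the borders, because the padded zeros are no longer transformed by the BN layer; the sketch does not address that case.

```python
import numpy as np

rng = np.random.default_rng(1)
in_features, out_features = 4, 3

# Inference-mode BN over the *input* features, followed by a linear layer
gamma = rng.uniform(0.5, 1.5, size=in_features)
beta = rng.normal(size=in_features)
running_mean = rng.normal(size=in_features)
running_var = rng.uniform(0.5, 2.0, size=in_features)
eps = 1e-5
W = rng.normal(size=(out_features, in_features))
b = rng.normal(size=out_features)


def bn_then_linear(x):
    """Reference computation: unfolded BN layer followed by a linear layer."""
    x_bn = gamma * (x - running_mean) / np.sqrt(running_var + eps) + beta
    return W @ x_bn + b


# BN written as an affine map: x_bn = s * x + t
s = gamma / np.sqrt(running_var + eps)
t = beta - running_mean * s

# Forward fold: scale the input columns of W and push the shift into the bias
W_folded = W * s[None, :]
b_folded = W @ t + b

x = rng.normal(size=in_features)
reference = bn_then_linear(x)
folded = W_folded @ x + b_folded
print("max abs difference:", np.max(np.abs(reference - folded)))
```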

Step 3

Call the AIMET batch norm folding API, fold_all_batch_norms, and compare the layers before and after folding.

print("*** Before batch norm folding ***")

print("\nmodel.features[0][0]:")
print(model.features[0][0])

print("\nmodel.features[0][1]:")
print(model.features[0][1])

from aimet_torch.batch_norm_fold import fold_all_batch_norms
fold_all_batch_norms(model, dummy_input.shape, dummy_input=dummy_input)

print("*** After batch norm folding ***")

print("\nmodel.features[0][0]:")
print(model.features[0][0])

print("\nmodel.features[0][1]:")
print(model.features[0][1])

*** Before batch norm folding ***

model.features[0][0]:
Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)

model.features[0][1]:
BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)


*** After batch norm folding ***

model.features[0][0]:
Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))

model.features[0][1]:
Identity()
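The printout shows two effects of folding: the convolution now carries a bias (the bias=False flag is gone from its representation), and the batch norm has been replaced with Identity(). The sketch below reproduces that effect by hand for a single Conv2d/BatchNorm2d pair in eval mode, using plain PyTorch. fold_bn_into_conv is a hypothetical helper written for illustration; it is not the AIMET implementation, which also handles other layer types and graph structures.

```python
import torch
import torch.nn as nn


def fold_bn_into_conv(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> None:
    """Hypothetical helper: fold an eval-mode BatchNorm2d into the preceding Conv2d, in place."""
    with torch.no_grad():
        # Per-output-channel scale s = gamma / sqrt(running_var + eps)
        scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
        # Scale every output channel (axis 0) of the convolution weight
        conv.weight.mul_(scale.view(-1, 1, 1, 1))
        old_bias = conv.bias if conv.bias is not None else torch.zeros_like(bn.running_mean)
        # New bias: s * (b - running_mean) + beta
        conv.bias = nn.Parameter((old_bias - bn.running_mean) * scale + bn.bias)


torch.manual_seed(0)
block = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(8),
).eval()

# Give the BN layer non-trivial statistics so the comparison is meaningful
with torch.no_grad():
    block[1].running_mean.uniform_(-1.0, 1.0)
    block[1].running_var.uniform_(0.5, 2.0)
    block[1].weight.uniform_(0.5, 1.5)
    block[1].bias.uniform_(-0.5, 0.5)

x = torch.randn(2, 3, 16, 16)
with torch.no_grad():
    reference = block(x)

fold_bn_into_conv(block[0], block[1])
block[1] = nn.Identity()  # the BN layer is now redundant

with torch.no_grad():
    folded = block(x)

print(block)
print("outputs match:", torch.allclose(reference, folded, atol=1e-5))
```

The printed Sequential should show the convolution without bias=False and Identity() in place of the batch norm, mirroring the AIMET output above.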

Call the AIMET batch norm folding API, fold_all_batch_norms, on the prepared model.

_, folded_model = fold_all_batch_norms(prepared_model)

print('*** After batch norm folding ***')

print('\nfolded_model.layers[1]:')
print(type(folded_model.layers[1]))

print('\nfolded_model.layers[2]:')
print(type(folded_model.layers[2]))

print('\nConv weight')
print(folded_model.layers[1].get_weights()[0])
*** After batch norm folding ***

folded_model.layers[1]:
<class 'keras.layers.convolutional.conv2d.Conv2D'>

folded_model.layers[2]:
<class 'keras.layers.activation.relu.ReLU'>

Conv weight
[[[[-3.01457286e-01 -1.49024737e+00  6.10569119e-01 -1.29590677e-19
     1.51547194e-01 -1.51446089e-01  1.38100997e-01 -4.89249423e-21
    -5.16245179e-02  4.64579314e-01 -2.44408584e+00  1.22219264e-20
     ...
     1.67510852e-01 -2.60713138e-02 -1.05549544e-01 -2.53403008e-01
     1.39502389e-02 -1.54620111e-02 -1.97294299e-02 -9.41715762e-02
    -6.88260233e-21  8.95088911e-02 -1.87630311e-01  2.48399768e-02]]]]
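Unlike PyTorch, which stores convolution weights as (out_channels, in_channels, kernel_h, kernel_w), Keras stores a Conv2D kernel as (kernel_h, kernel_w, in_channels, out_channels). Folding still scales each output channel by γ/√(σ² + ε), but in Keras layout that scale broadcasts over the last axis, which is why neighboring entries in the innermost rows printed above change by different factors. The NumPy sketch below, using illustrative random values, checks that scaling the last axis of a Keras-layout kernel matches scaling the first axis of the equivalent PyTorch-layout weight. Depthwise kernels use a different layout and are not covered here.

```python
import numpy as np

rng = np.random.default_rng(2)
kh, kw, c_in, c_out = 3, 3, 3, 8

# Keras Conv2D kernel layout: (kernel_h, kernel_w, in_channels, out_channels)
keras_kernel = rng.normal(size=(kh, kw, c_in, c_out))
# Equivalent PyTorch layout: (out_channels, in_channels, kernel_h, kernel_w)
torch_weight = keras_kernel.transpose(3, 2, 0, 1)

# Per-output-channel BN scale (Keras BatchNormalization defaults to epsilon=1e-3)
gamma = rng.uniform(0.5, 1.5, size=c_out)
moving_variance = rng.uniform(0.5, 2.0, size=c_out)
epsilon = 1e-3
scale = gamma / np.sqrt(moving_variance + epsilon)

# Keras layout: the scale broadcasts over the last axis automatically
keras_folded = keras_kernel * scale
# PyTorch layout: reshape the scale so it applies along axis 0
torch_folded = torch_weight * scale[:, None, None, None]

print("layouts agree:", np.allclose(keras_folded.transpose(3, 2, 0, 1), torch_folded))
```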

Call the AIMET batch norm folding API, fold_all_batch_norms_to_weight, on the simplified model.

_ = fold_all_batch_norms_to_weight(model=model)

conv_weight = initializers[conv_weight_name]
conv_weight_array = onnx.numpy_helper.to_array(conv_weight)

print("*** After batch norm folding ***")

print("\nmodel.graph.node[0]:")
print(model.graph.node[0].name)

print("\nmodel.graph.node[1]:")
print(model.graph.node[1].name)

print("\nConv weight:")
print(conv_weight_array)
*** After batch norm folding ***

model.graph.node[0]:
name: "/features/features.0/features.0.0/Conv"

model.graph.node[1]:
name: "/features/features.0/features.0.2/Clip"

Conv weight:
[[[[-2.00183112e-02 -5.95260113e-02 -4.81760912e-02]
   [-1.56632766e-01 -2.03798249e-01 -1.86945379e-01]
   [-2.15717569e-01 -3.09111059e-01 -2.42083430e-01]]
   ...
  [[ 1.21066449e-02 -4.61087702e-03 -1.77213307e-02]
   [ 2.26179108e-01  7.04025269e-01  1.18136823e-01]
   [-2.52974629e-01 -7.00215936e-01 -8.96155685e-02]]]]
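One way to confirm that the ONNX fold acted as a per-channel rescaling is to divide the folded weights by the original ones: within one output channel, every ratio should equal that channel's γ/√(σ² + ε). The NumPy snippet below does this using only the values printed above for the first 3x3 kernel of the first output channel and the last 3x3 kernel of the last output channel; the elided rows are not reconstructed. per_channel_ratio is a hypothetical helper for this check.

```python
import numpy as np

# First kernel of output channel 0, before and after folding (values printed above)
before_first = np.array([
    [-6.31080866e-02, -1.87656835e-01, -1.51876003e-01],
    [-4.93787616e-01, -6.42477691e-01, -5.89348674e-01],
    [-6.80053532e-01, -9.74478185e-01, -7.63172388e-01],
])
after_first = np.array([
    [-2.00183112e-02, -5.95260113e-02, -4.81760912e-02],
    [-1.56632766e-01, -2.03798249e-01, -1.86945379e-01],
    [-2.15717569e-01, -3.09111059e-01, -2.42083430e-01],
])

# Last kernel of the last output channel, before and after folding
before_last = np.array([
    [1.24257803e-02, -4.73242160e-03, -1.81884710e-02],
    [2.32141271e-01, 7.22583652e-01, 1.21250950e-01],
    [-2.59643137e-01, -7.18673885e-01, -9.19778645e-02],
])
after_last = np.array([
    [1.21066449e-02, -4.61087702e-03, -1.77213307e-02],
    [2.26179108e-01, 7.04025269e-01, 1.18136823e-01],
    [-2.52974629e-01, -7.00215936e-01, -8.96155685e-02],
])


def per_channel_ratio(before, after):
    """Hypothetical helper: return after/before, checking it is constant within the channel."""
    ratios = after / before
    assert np.allclose(ratios, ratios.mean(), rtol=1e-5), "ratios differ within the channel"
    return float(ratios.mean())


first_scale = per_channel_ratio(before_first, after_first)
last_scale = per_channel_ratio(before_last, after_last)
print(f"channel 0 scale: {first_scale:.4f}, last channel scale: {last_scale:.4f}")
```

With the numbers above, the ratios are constant within each channel (about 0.317 for the first channel and about 0.974 for the last), consistent with a per-output-channel BN scale.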

API

Top-level API

aimet_torch.batch_norm_fold.fold_all_batch_norms(model, input_shapes, dummy_input=None)

Fold all batch_norm layers in a model into the weights of the corresponding conv/linear layers.

Parameters:
  • model (Module) – PyTorch model in which to fold batch norm layers

  • input_shapes (Union[Tuple, List[Tuple]]) – Input shapes for the model (can be one or multiple inputs)

  • dummy_input (Union[Tensor, Tuple, None]) – A dummy input to the model. Can be a Tensor or a Tuple of Tensors

Return type:

List[Tuple[Union[Linear, Conv1d, Conv2d, ConvTranspose2d], Union[BatchNorm1d, BatchNorm2d]]]

Returns:

A list of pairs of layers [(Conv/Linear, BN layer that got folded)]

Top-level API

aimet_tensorflow.keras.batch_norm_fold.fold_all_batch_norms(model)

Fold all batch_norm layers in a model into corresponding conv/linear layers

Parameters:

model (Model) – Keras model in which to fold all batch norm layers

Return type:

Tuple[List[Tuple[Union[Conv2D, Dense, Conv2DTranspose, DepthwiseConv2D], BatchNormalization]], Model]

Returns:

A tuple of (a list of conv/linear layers with their associated BN ops and activation info, a new model with the batch normalization layers folded)

Top-level API

aimet_onnx.batch_norm_fold.fold_all_batch_norms_to_weight(model)

Fold all possible batch_norm layers in a model into the weight of the corresponding conv layers

Parameters:

model (ModelProto) – ONNX model on which to perform batch norm folding

Return type:

List

Returns:

A list of pairs of layers [(Conv/Linear, BN layer that got folded)]