AIMET PyTorch Cross Layer Equalization Primitive API
Introduction
If you want to change the order of the Cross Layer Equalization steps, skip some of them, or manually adjust the list of layers to be equalized, use the following APIs.
Higher-level APIs apply one or more features one after the other. They automatically find the layers to be folded or scaled.
Lower-level APIs let you manually specify the list of layers to be folded or scaled. You must pass the layers in the same order in which they appear in the model.
Note: Before using High Bias Fold, Cross Layer Scaling (CLS) must be applied, and the scaling factors obtained from CLS must be passed to High Bias Fold. In addition, if the model has batch norm layers, they must be folded first and the folding information saved so that it can be passed to the High Bias Fold API.
ClsSetInfo Definition
Higher-Level APIs for Cross Layer Equalization
API for Batch Norm Folding
API for Cross Layer Scaling
API for High Bias Folding
Code Examples for Higher-Level APIs
Required imports
import torch
from torchvision import models
from aimet_torch import batch_norm_fold
from aimet_torch import cross_layer_equalization
from aimet_torch import utils
Cross Layer Equalization in auto mode calling each API
def cross_layer_equalization_auto_step_by_step():
    """Apply Cross Layer Equalization to ResNet-18 by calling each higher-level API in turn.

    The steps run in the order high-bias fold requires: batch norm folding
    first (saving the folded BN layers), then ReLU6 -> ReLU replacement, then
    cross-layer scaling, and finally high-bias fold, which consumes both the
    CLS set info and the saved BN layers.

    Returns:
        The model object that all of the APIs above were applied to.
    """
    model = models.resnet18(pretrained=True)
    model = model.eval()
    input_shape = (1, 3, 224, 224)

    # Fold batchnorm layers. Each returned pair is indexed as (layer, BN layer);
    # keep a layer -> BN mapping because high-bias fold needs it later.
    folded_pairs = batch_norm_fold.fold_all_batch_norms(model, input_shape)
    bn_dict = {pair[0]: pair[1] for pair in folded_pairs}

    # Replace any ReLU6 layers with ReLU
    utils.replace_modules_of_type1_with_type2(model, torch.nn.ReLU6, torch.nn.ReLU)

    # Perform cross-layer scaling on applicable layer sets. The returned set
    # info carries the scaling factors that high-bias fold consumes.
    cls_set_info_list = cross_layer_equalization.CrossLayerScaling.scale_model(model, input_shape)

    # Perform high-bias fold
    cross_layer_equalization.HighBiasFold.bias_fold(cls_set_info_list, bn_dict)

    return model
Lower-Level APIs for Cross Layer Equalization
API for Batch Norm Folding
API for Cross Layer Scaling
API for High Bias Folding
Code Examples for Lower-Level APIs
Required imports
from torchvision import models
from aimet_torch.examples.mobilenet import MobileNetV2
from aimet_torch import batch_norm_fold
from aimet_torch import cross_layer_equalization
from aimet_torch import utils
Cross Layer Equalization in manual mode
def cross_layer_equalization_manual():
    """Apply Cross Layer Equalization to ResNet-18 using manually chosen layers.

    Uses the lower-level APIs:
      * batch norm folding on an explicit list of (conv, BN) pairs,
      * cross-layer scaling on explicit pairs of consecutive conv layers,
      * high-bias fold on ClsSetInfo objects built from the scaling factors
        returned by cross-layer scaling.

    Returns:
        The model object that all of the APIs above were applied to.
    """
    # Imported locally: the "Required imports" listed for the lower-level
    # examples omit torch, but torch.nn.ReLU6 / torch.nn.ReLU are used below.
    import torch

    model = models.resnet18(pretrained=True)
    model = model.eval()

    # Batch Norm Fold
    # Create a list of conv/linear and BN layers for folding forward or backward
    layer_list = [(model.conv1, model.bn1),
                  (model.layer1[0].conv1, model.layer1[0].bn1)]

    # Save the corresponding BN layers (needed only for high bias folding)
    bn_dict = {conv: bn for conv, bn in layer_list}

    batch_norm_fold.fold_given_batch_norms(model, layer_list)

    # Replace any ReLU6 layers with ReLU
    utils.replace_modules_of_type1_with_type2(model, torch.nn.ReLU6, torch.nn.ReLU)

    # Cross Layer Scaling
    # Create a list of consecutive conv layers to be equalized
    consecutive_layer_list = [(model.conv1, model.layer1[0].conv1),
                              (model.layer1[0].conv1, model.layer1[0].conv2)]
    # One scaling factor entry per layer set, indexed in the same order as above.
    scaling_factor_list = cross_layer_equalization.CrossLayerScaling.scale_cls_sets(consecutive_layer_list)

    # High Bias Fold
    # Create a list of consecutive conv layer pairs for which the previous
    # layer's bias is folded into the next layer's bias.
    ClsSetInfo = cross_layer_equalization.ClsSetInfo
    ClsPairInfo = cross_layer_equalization.ClsSetInfo.ClsSetLayerPairInfo
    # NOTE(review): the trailing True presumably flags a ReLU between the two
    # layers of the pair; confirm against ClsSetLayerPairInfo's signature.
    cls_set_info_list = [
        ClsSetInfo(ClsPairInfo(model.conv1, model.layer1[0].conv1, scaling_factor_list[0], True)),
        ClsSetInfo(ClsPairInfo(model.layer1[0].conv1, model.layer1[0].conv2, scaling_factor_list[1], True)),
    ]
    cross_layer_equalization.HighBiasFold.bias_fold(cls_set_info_list, bn_dict)

    return model
Cross Layer Equalization in manual mode for Depthwise Separable layer
def cross_layer_equalization_depthwise_layers():
    """Apply Cross Layer Equalization to a MobileNetV2 depthwise-separable block.

    Same lower-level flow as the manual example, but the cross-layer scaling
    set contains three consecutive convs (features[0][0], features[1].conv[0],
    features[1].conv[3]) instead of a pair. The scaling factors for that single
    set are then split across two ClsPairInfo entries for high-bias fold.

    Returns:
        The model object that all of the APIs above were applied to.
    """
    # Imported locally: the "Required imports" listed for the lower-level
    # examples omit torch, but torch.device and torch.nn.* are used below.
    import torch

    model = MobileNetV2().to(torch.device('cpu'))
    model.eval()

    # Batch Norm Fold
    # Create a list of conv/linear and BN layers for folding forward or backward
    layer_list = [(model.features[0][0], model.features[0][1]),
                  (model.features[1].conv[0], model.features[1].conv[1]),
                  (model.features[1].conv[3], model.features[1].conv[4])]

    # Save the corresponding BN layers (needed only for high bias folding)
    bn_dict = {conv: bn for conv, bn in layer_list}

    batch_norm_fold.fold_given_batch_norms(model, layer_list)

    # Replace any ReLU6 layers with ReLU
    utils.replace_modules_of_type1_with_type2(model, torch.nn.ReLU6, torch.nn.ReLU)

    # Cross Layer Scaling
    # Create a list of consecutive conv layers to be equalized (one 3-layer set)
    consecutive_layer_list = [(model.features[0][0], model.features[1].conv[0], model.features[1].conv[3])]
    scaling_factor_list = cross_layer_equalization.CrossLayerScaling.scale_cls_sets(consecutive_layer_list)
    # The single 3-layer set yields an entry indexed below as [0][0] and [0][1];
    # presumably one factor per adjacent layer pair in the set (confirm).
    first_factor, second_factor = scaling_factor_list[0][0], scaling_factor_list[0][1]

    # High Bias Fold
    # Create a list of consecutive conv layer pairs for which the previous
    # layer's bias is folded into the next layer's bias.
    ClsSetInfo = cross_layer_equalization.ClsSetInfo
    ClsPairInfo = cross_layer_equalization.ClsSetInfo.ClsSetLayerPairInfo
    # NOTE(review): the trailing True presumably flags a ReLU between the two
    # layers of the pair; confirm against ClsSetLayerPairInfo's signature.
    cls_set_info_list = [
        ClsSetInfo(ClsPairInfo(model.features[0][0], model.features[1].conv[0], first_factor, True)),
        ClsSetInfo(ClsPairInfo(model.features[1].conv[0], model.features[1].conv[3], second_factor, True)),
    ]
    cross_layer_equalization.HighBiasFold.bias_fold(cls_set_info_list, bn_dict)

    return model