AIMET Keras Cross Layer Equalization Primitive API
Introduction
The following APIs can be used to change the order of the Cross Layer Equalization (CLE) steps, skip some of them, or manually adjust the list of layers to be equalized.
The higher-level APIs apply one or more CLE features one after the other and automatically find the layers to be folded or scaled.
The lower-level APIs let the user manually adjust the list of layers to be folded or scaled. The user must pass the layers in the order in which they appear in the model.
Note: Cross Layer Scaling (CLS) must be applied before High Bias Fold, and the scaling factors obtained from CLS must be passed to High Bias Fold. If the model contains batchnorm layers, they must be folded first and the folding information saved so that it can be passed to the High Bias Fold API.
Higher Level APIs for Cross Layer Equalization
API for Batch Norm Folding
aimet_tensorflow.keras.batch_norm_fold.fold_all_batch_norms(model)
    Fold all batch_norm layers in a model into the corresponding conv/linear layers.
    Parameters:
        model (Model) – model to find all batch norms for
    Return type:
        Tuple[List[Tuple[Union[Conv2D, Dense, Conv2DTranspose, DepthwiseConv2D], BatchNormalization]], Model]
    Returns:
        A tuple of (a list of conv/linear layers paired with their associated batch norm layers, a new model with the Batch Normalization layers folded)
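For intuition, the arithmetic behind folding a batch norm into the layer that precedes it can be sketched in NumPy. This is a minimal sketch, not AIMET's implementation: it assumes a dense layer with weights laid out as (out_features, in_features) (Keras Dense kernels are stored transposed, as (in, out)) and inference-mode BN with the Keras default epsilon of 1e-3. `fold_bn_into_dense` is a hypothetical helper.

```python
import numpy as np


def fold_bn_into_dense(weight, bias, gamma, beta, mean, var, eps=1e-3):
    """Fold a BatchNormalization that follows a dense layer into that layer.

    weight: (out_features, in_features), bias: (out_features,)
    gamma, beta, mean, var: (out_features,), one entry per output channel
    """
    scale = gamma / np.sqrt(var + eps)          # per-output-channel BN multiplier
    folded_weight = weight * scale[:, None]      # scale each output row
    folded_bias = (bias - mean) * scale + beta   # shift the bias the same way
    return folded_weight, folded_bias


rng = np.random.default_rng(0)
w, b = rng.normal(size=(4, 3)), rng.normal(size=4)
gamma, beta = rng.uniform(0.5, 2.0, size=4), rng.normal(size=4)
mean, var = rng.normal(size=4), rng.uniform(0.1, 1.0, size=4)
x = rng.normal(size=(5, 3))

# Reference: dense layer followed by inference-mode batch norm
reference = gamma * ((x @ w.T + b) - mean) / np.sqrt(var + 1e-3) + beta

fw, fb = fold_bn_into_dense(w, b, gamma, beta, mean, var)
folded = x @ fw.T + fb
print(np.allclose(reference, folded))  # True
```

After folding, the BN layer can be dropped because its effect lives entirely in the layer's weights and bias, which is why `fold_all_batch_norms` returns a new model without the BN layers.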
API for Cross Layer Scaling
aimet_tensorflow.keras.cross_layer_equalization.CrossLayerScaling.scale_model(model)
    Uses cross-layer scaling to scale all applicable layers in the given model.
    Parameters:
        model (Model) – tf.keras.Model whose applicable layers are scaled
    Return type:
        List[ClsSetInfo]
    Returns:
        CLS information for each CLS set
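The scaling applied to each CLS set can be illustrated for two consecutive dense layers separated by a ReLU. This sketch assumes the per-channel formula from the CLE paper (Nagel et al., 2019), s_i = sqrt(r1_i * r2_i) / r2_i, where r1_i is the weight range of output channel i of the first layer and r2_i the range of input channel i of the second. `cross_layer_scale` is a hypothetical helper, and AIMET's internal computation may differ in detail. Because ReLU is positively homogeneous (relu(z / s) = relu(z) / s for s > 0), the output is unchanged while the two per-channel ranges become equal.

```python
import numpy as np


def cross_layer_scale(w1, b1, w2):
    """Equalize per-channel weight ranges of two consecutive layers.

    w1: (C, D) weights of the first layer, b1: (C,) its bias,
    w2: (E, C) weights of the second layer. Channel i is output i of
    layer 1 and input i of layer 2.
    """
    r1 = np.abs(w1).max(axis=1)            # range of each output channel of layer 1
    r2 = np.abs(w2).max(axis=0)            # range of each input channel of layer 2
    s = np.sqrt(r1 * r2) / r2              # s_i = sqrt(r1_i * r2_i) / r2_i
    return w1 / s[:, None], b1 / s, w2 * s[None, :], s


def relu(z):
    return np.maximum(z, 0.0)


rng = np.random.default_rng(1)
w1 = rng.normal(size=(4, 3)) * np.array([[0.1], [1.0], [5.0], [20.0]])  # uneven channel ranges
b1 = rng.normal(size=4)
w2 = rng.normal(size=(2, 4))
x = rng.normal(size=(6, 3))

before = relu(x @ w1.T + b1) @ w2.T
sw1, sb1, sw2, s = cross_layer_scale(w1, b1, w2)
after = relu(x @ sw1.T + sb1) @ sw2.T

print(np.allclose(before, after))                                     # True
print(np.allclose(np.abs(sw1).max(axis=1), np.abs(sw2).max(axis=0)))  # True
```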
API for High Bias Folding
aimet_tensorflow.keras.cross_layer_equalization.HighBiasFold.bias_fold(cls_set_info_list, bn_layers)
    Folds bias values greater than 3 * sigma into the next layer’s bias.
    Parameters:
        cls_set_info_list (List[ClsSetInfo]) – list of info elements for each CLS set
        bn_layers (Dict[Conv2D, BatchNormalization]) – key: conv/linear layer, value: the corresponding folded BN layer
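The bias absorption behind high bias folding can be sketched the same way. Assuming layer 1's pre-activations are roughly Gaussian with the folded BN's shift beta as mean and |gamma| as standard deviation, the amount c = max(0, beta - 3|gamma|) almost always passes through the ReLU unchanged, so it can be subtracted from layer 1's bias and added to layer 2's bias as W2 @ c. The optional `scale` argument shows why the CLS scale factors are needed: if CLS divided layer 1's channels by s, the BN statistics shrink by s too. `high_bias_fold` is a hypothetical helper, and the inputs below are chosen so that every pre-activation stays above c, which makes the fold exact.

```python
import numpy as np


def high_bias_fold(b1, w2, b2, bn_beta, bn_gamma, scale=None):
    """Absorb the part of layer 1's bias that always passes the ReLU into layer 2.

    b1: (C,) bias of layer 1 after BN folding (and CLS, if applied)
    w2: (E, C), b2: (E,) weights and bias of layer 2
    bn_beta, bn_gamma: (C,) parameters of the BN that was folded into layer 1
    scale: optional (C,) CLS scale factors that layer 1's channels were divided by
    """
    if scale is not None:
        # CLS divided layer 1's outputs by scale, so its output statistics shrink too
        bn_beta, bn_gamma = bn_beta / scale, bn_gamma / scale
    c = np.maximum(0.0, bn_beta - 3.0 * np.abs(bn_gamma))  # absorbable amount per channel
    return b1 - c, b2 + w2 @ c


rng = np.random.default_rng(2)
beta = np.array([5.0, 4.0, 0.5, -1.0])   # BN shift folded into layer 1
gamma = np.ones(4)                        # BN scale folded into layer 1
s = np.array([2.0, 0.5, 1.0, 1.0])        # example CLS scale factors for layer 1
w1 = rng.uniform(-0.2, 0.2, size=(4, 3)) / s[:, None]   # layer 1 after CLS
b1 = beta / s                                            # its bias after CLS
w2, b2 = rng.normal(size=(2, 4)), rng.normal(size=2)
x = rng.uniform(-1.0, 1.0, size=(6, 3))   # keeps every pre-activation above c


def forward(bias1, bias2):
    return np.maximum(x @ w1.T + bias1, 0.0) @ w2.T + bias2


new_b1, new_b2 = high_bias_fold(b1, w2, b2, beta, gamma, scale=s)
print(np.allclose(forward(b1, b2), forward(new_b1, new_b2)))  # True
```

Shrinking layer 1's bias this way reduces the range of its activations, which is the point of the step; on real data the 3-sigma threshold makes the fold approximate rather than exact.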
Code Examples for Higher Level APIs
Required imports
import tensorflow as tf
from aimet_tensorflow.keras.batch_norm_fold import fold_all_batch_norms
from aimet_tensorflow.keras.cross_layer_equalization import HighBiasFold, CrossLayerScaling
from aimet_tensorflow.keras.utils.model_transform_utils import replace_relu6_with_relu
Perform Cross Layer Equalization in auto mode step by step
def cross_layer_equalization_auto_stepwise():
    """
    Individual api calls to perform cross layer equalization one step at a time. Pairs to fold and
    scale are found automatically.
    1. Replace Relu6 with Relu
    2. Fold batch norms
    3. Perform cross layer scaling
    4. Perform high bias fold
    """
    # Load the model to equalize
    model = tf.keras.applications.resnet50.ResNet50(weights=None, classes=10)

    # 1. Replace Relu6 layer with Relu
    model_for_cle, _ = replace_relu6_with_relu(model)

    # 2. Fold all batch norms; keep the returned model, which has the BN layers folded
    folded_pairs, model_for_cle = fold_all_batch_norms(model_for_cle)
    bn_dict = {}
    for conv_or_linear, bn in folded_pairs:
        bn_dict[conv_or_linear] = bn

    # 3. Perform cross-layer scaling on applicable layer groups of the folded model
    cls_set_info_list = CrossLayerScaling.scale_model(model_for_cle)

    # 4. Perform high bias fold
    HighBiasFold.bias_fold(cls_set_info_list, bn_dict)

    return model_for_cle
Lower Level APIs for Cross Layer Equalization
API for Batch Norm Folding on subsets of convolution-batchnorm layer pairs
aimet_tensorflow.keras.batch_norm_fold.fold_given_batch_norms(model, layer_pairs)
    Fold a given set of batch_norm layers into conv/linear layers.
    Parameters:
        model (Model) – either a Keras Model or a QuantizationSimModel’s model
        layer_pairs (List[Union[Tuple[Union[Conv2D, Dense, Conv2DTranspose, DepthwiseConv2D], BatchNormalization, bool], Tuple[BatchNormalization, Union[Conv2D, Dense, Conv2DTranspose, DepthwiseConv2D], bool]]]) – tuples of conv and BN layers with an is_batch_norm_second flag
    Return type:
        Optional[Model]
    Returns:
        new model with the batch norm layers folded if model is a functional model, else None
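When is_batch_norm_second is False, the BN layer comes before the conv/linear layer and is folded forward into it. For a dense layer the algebra is W(a*x + t) + b = (W*a)x + (Wt + b), with a = gamma / sqrt(var + eps) and t = beta - mean*a. The sketch below checks this in NumPy using a hypothetical helper `fold_bn_into_next_dense` with weights laid out as (out_features, in_features); it is not AIMET's code. For a convolution with zero padding, this forward fold is only exact away from the borders, because the padded zeros are not shifted by t.

```python
import numpy as np


def fold_bn_into_next_dense(weight, bias, gamma, beta, mean, var, eps=1e-3):
    """Fold a BatchNormalization that precedes a dense layer into that layer.

    weight: (out_features, in_features), bias: (out_features,)
    gamma, beta, mean, var: (in_features,), one entry per input channel
    """
    a = gamma / np.sqrt(var + eps)        # BN(x) = a * x + t
    t = beta - mean * a
    folded_weight = weight * a[None, :]   # scale each input column by a
    folded_bias = bias + weight @ t       # push the BN shift through the weights
    return folded_weight, folded_bias


rng = np.random.default_rng(4)
w, b = rng.normal(size=(2, 3)), rng.normal(size=2)
gamma, beta = rng.uniform(0.5, 2.0, size=3), rng.normal(size=3)
mean, var = rng.normal(size=3), rng.uniform(0.1, 1.0, size=3)
x = rng.normal(size=(5, 3))

bn_out = gamma * (x - mean) / np.sqrt(var + 1e-3) + beta   # BN first ...
reference = bn_out @ w.T + b                               # ... then the dense layer
fw, fb = fold_bn_into_next_dense(w, b, gamma, beta, mean, var)
print(np.allclose(reference, x @ fw.T + fb))  # True
```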
API for Cross Layer Scaling on subset of conv layer groups
aimet_tensorflow.keras.cross_layer_equalization.CrossLayerScaling.scale_cls_sets(cls_sets)
    Scale each CLS set.
    Parameters:
        cls_sets (List[Union[Tuple[Conv2D, Conv2D], Tuple[Conv2D, DepthwiseConv2D, Conv2D]]]) – CLS sets to scale
    Return type:
        List[Union[ndarray, Tuple[ndarray, ndarray]]]
    Returns:
        List of scale factors corresponding to each scaled CLS set
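For a conv -> depthwise conv -> conv set, scale_cls_sets returns a tuple of two scale-factor arrays instead of one. A minimal sketch of why, assuming one common generalization of the pairwise formula (AIMET's exact computation may differ): with per-channel ranges r1, r2, r3 of the three layers, choosing s12 = r1 / cbrt(r1*r2*r3) and s23 = cbrt(r1*r2*r3) / r3 gives all three layers the same range cbrt(r1*r2*r3). The layers are modeled as dense math with one depthwise multiplier per channel and ReLUs between them; `scale_depthwise_separable` is a hypothetical helper.

```python
import numpy as np


def relu(z):
    return np.maximum(z, 0.0)


def scale_depthwise_separable(w1, b1, dw, db, w3):
    """Equalize a conv -> depthwise -> pointwise CLS set (modeled with dense math).

    w1: (C, D) first-layer weights, b1: (C,) its bias
    dw: (C,) one depthwise weight per channel, db: (C,) its bias
    w3: (E, C) pointwise-layer weights
    """
    r1 = np.abs(w1).max(axis=1)      # output-channel ranges of layer 1
    r2 = np.abs(dw)                  # per-channel ranges of the depthwise layer
    r3 = np.abs(w3).max(axis=0)      # input-channel ranges of layer 3
    target = np.cbrt(r1 * r2 * r3)   # common range all three layers end up with
    s12 = r1 / target
    s23 = target / r3
    scaled = (w1 / s12[:, None], b1 / s12,   # layer 1 outputs divided by s12
              dw * s12 / s23, db / s23,       # undo s12 on input, divide output by s23
              w3 * s23[None, :])              # undo s23 on layer 3's input
    return scaled, (s12, s23)


def forward(x, w1, b1, dw, db, w3):
    h1 = relu(x @ w1.T + b1)
    h2 = relu(h1 * dw + db)
    return h2 @ w3.T


rng = np.random.default_rng(3)
w1 = rng.normal(size=(4, 3)) * np.array([[0.1], [1.0], [5.0], [20.0]])
b1, dw, db = rng.normal(size=4), rng.normal(size=4), rng.normal(size=4)
w3 = rng.normal(size=(2, 4))
x = rng.normal(size=(6, 3))

(nw1, nb1, ndw, ndb, nw3), (s12, s23) = scale_depthwise_separable(w1, b1, dw, db, w3)
print(np.allclose(forward(x, w1, b1, dw, db, w3), forward(x, nw1, nb1, ndw, ndb, nw3)))  # True
print(np.allclose(np.abs(nw1).max(axis=1), np.abs(ndw)))  # True
```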
API for High Bias Folding
aimet_tensorflow.keras.cross_layer_equalization.HighBiasFold.bias_fold(cls_set_info_list, bn_layers)
    Folds bias values greater than 3 * sigma into the next layer’s bias.
    Parameters:
        cls_set_info_list (List[ClsSetInfo]) – list of info elements for each CLS set
        bn_layers (Dict[Conv2D, BatchNormalization]) – key: conv/linear layer, value: the corresponding folded BN layer
Custom Datatype used
class aimet_tensorflow.keras.cross_layer_equalization.ClsSetInfo(cls_pair_1, cls_pair_2=None)
    This class holds information about the layers in a CLS set, along with the corresponding scaling factors and other information, such as whether there is a ReLU activation function between the layers of the CLS set.
    The constructor takes two pairs if a depthwise-separable layer is being folded.
    Parameters:
        cls_pair_1 (ClsSetLayerPairInfo) – pair between two conv layers, or between a conv and a depthwise conv
        cls_pair_2 (Optional[ClsSetLayerPairInfo]) – pair between a depthwise conv and a pointwise conv
    class ClsSetLayerPairInfo(layer1, layer2, scale_factor, relu_activation_between_layers)
        Models a pair of layers that were scaled using CLS, along with related information.
        Parameters:
            layer1 (Conv2D) – layer whose bias is folded
            layer2 (Conv2D) – layer into which the previous layer’s bias is folded
            scale_factor (ndarray) – scale factor found by Cross Layer Scaling, used to scale the BN parameters
            relu_activation_between_layers (bool) – whether the activation between layer1 and layer2 is ReLU
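How CLS sets, their scale factors and ClsSetInfo objects fit together can be mirrored in plain Python. The dataclasses below are hypothetical stand-ins for ClsSetInfo and ClsSetLayerPairInfo, not the AIMET classes, and the builder assumes, for illustration, that a three-layer set supplies an (s12, s23) tuple of scale factors and a pair of ReLU flags and yields two layer pairs, while a two-layer set yields one.

```python
from dataclasses import dataclass
from typing import Any, List, Optional, Sequence


@dataclass
class LayerPairInfo:
    """Hypothetical stand-in for ClsSetInfo.ClsSetLayerPairInfo."""
    layer1: str
    layer2: str
    scale_factor: Any
    relu_activation_between_layers: bool


@dataclass
class SetInfo:
    """Hypothetical stand-in for ClsSetInfo: one pair, or two for depthwise-separable sets."""
    cls_pair_1: LayerPairInfo
    cls_pair_2: Optional[LayerPairInfo] = None


def build_set_info_list(cls_sets: Sequence[Sequence[str]],
                        scale_factors: Sequence[Any],
                        relu_flags: Sequence[Any]) -> List[SetInfo]:
    infos = []
    for layers, scale, relu in zip(cls_sets, scale_factors, relu_flags):
        if len(layers) == 3:
            # conv -> depthwise -> conv: two scale arrays and two ReLU flags
            (s12, s23), (relu12, relu23) = scale, relu
            infos.append(SetInfo(LayerPairInfo(layers[0], layers[1], s12, relu12),
                                 LayerPairInfo(layers[1], layers[2], s23, relu23)))
        else:
            infos.append(SetInfo(LayerPairInfo(layers[0], layers[1], scale, relu)))
    return infos


sets = [("conv1", "conv2"), ("conv3", "dw_conv", "pw_conv")]
scales = [[0.5, 2.0], ([1.0, 4.0], [0.25, 1.0])]
relus = [True, (True, False)]
info_list = build_set_info_list(sets, scales, relus)
print(info_list[0].cls_pair_2 is None)   # True
print(info_list[1].cls_pair_2.layer1)    # dw_conv
```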
Code Example for Lower Level APIs
Required imports
import tensorflow as tf
from aimet_tensorflow.keras.batch_norm_fold import fold_given_batch_norms
from aimet_tensorflow.keras.cross_layer_equalization import HighBiasFold, CrossLayerScaling
from aimet_tensorflow.keras.utils.model_transform_utils import replace_relu6_with_relu
Perform Cross Layer Equalization in manual mode
def cross_layer_equalization_manual():
    """
    Individual api calls to perform cross layer equalization one step at a time. Pairs to fold and
    scale are provided by the user.
    1. Replace Relu6 with Relu
    2. Fold batch norms
    3. Perform cross layer scaling
    4. Perform high bias fold
    """
    # Load the model to equalize
    model = tf.keras.applications.resnet50.ResNet50(weights=None, classes=10)

    # replace any ReLU6 layers with ReLU
    model_for_cle, _ = replace_relu6_with_relu(model)

    # pick potential pairs of conv and bn ops for fold
    layer_pairs = get_example_layer_pairs_resnet50_for_folding(model_for_cle)

    # fold given layers; for a functional model such as ResNet50 this returns a new model with
    # the folded BN layers removed (the layer indices used below still refer to model_for_cle)
    folded_model = fold_given_batch_norms(model_for_cle, layer_pairs=layer_pairs)

    # Cross Layer Scaling
    # Create a list of consecutive conv layers to be equalized
    consecutive_layer_list = get_consecutive_layer_list_from_resnet50_for_scaling(model_for_cle)

    # invoke api to perform scaling on given list of cls pairs
    scaling_factor_list = CrossLayerScaling.scale_cls_sets(consecutive_layer_list)

    # get info from bn fold and cross layer scaling in format required for high bias fold
    folded_pairs, cls_set_info_list = format_info_for_high_bias_fold(layer_pairs,
                                                                      consecutive_layer_list,
                                                                      scaling_factor_list)

    HighBiasFold.bias_fold(cls_set_info_list, folded_pairs)

    # return the model with the BN layers folded out when one was created
    return folded_model if folded_model is not None else model_for_cle
Example helper methods to perform CLE in manual mode
Helper to pick layers for batchnorm fold
def get_example_layer_pairs_resnet50_for_folding(model: tf.keras.Model):
    """
    Function to pick example conv-batchnorm layer pairs for folding.
    :param model: Keras model containing conv batchnorm pairs to fold
    :return: pairs of conv and batchnorm layers for batch norm folding in Resnet50 model.
    """
    conv_op_1 = model.layers[2]
    bn_op_1 = model.layers[3]
    conv_op_2 = model.layers[7]
    bn_op_2 = model.layers[8]
    conv_op_3 = model.layers[10]
    bn_op_3 = model.layers[11]

    # make a list of (conv layer, bn layer, is_batch_norm_second) tuples; True indicates that
    # the bn layer follows the conv layer. Three example pairs are shown below.
    layer_pairs = [(conv_op_1, bn_op_1, True),
                   (conv_op_2, bn_op_2, True),
                   (conv_op_3, bn_op_3, True)]
    return layer_pairs
Helper to pick layers for cross layer scaling
def get_consecutive_layer_list_from_resnet50_for_scaling(model: tf.keras.Model):
    """
    Helper function to pick example consecutive layer list for scaling.
    :param model: tf.keras.Model
    :return: sample layers for scaling as consecutive_layer_list from Resnet50 model
    """
    conv_op_1 = model.layers[2]
    conv_op_2 = model.layers[7]
    conv_op_3 = model.layers[10]
    consecutive_layer_list = [(conv_op_1, conv_op_2), (conv_op_2, conv_op_3)]
    return consecutive_layer_list
Helper to format data from batchnorm fold and cross layer scaling for usage by high bias fold
def format_info_for_high_bias_fold(layer_pairs, consecutive_layer_list, scaling_factor_list):
    """
    Helper function that formats data from cross layer scaling and bn fold for usage by high bias fold
    :param layer_pairs: info obtained after batchnorm fold
    :param consecutive_layer_list: info obtained after cross layer scaling
    :param scaling_factor_list: scaling params corresponding to consecutive_layer_list
    :return: data formatted for high bias fold
    """
    # HighBiasFold.bias_fold expects a dict mapping each conv/linear layer to the BN layer folded
    # into it, so keep only the pairs in which the BN layer follows the conv layer
    folded_pairs = {}
    for (first_layer, second_layer, is_batch_norm_second) in layer_pairs:
        if is_batch_norm_second:
            folded_pairs[first_layer] = second_layer

    # List that holds one boolean per cross layer scaling set: True if there is a ReLU activation
    # between the layers of that set. The user is expected to fill in this list manually; both
    # example ResNet50 sets used here have a ReLU between their layers.
    is_relu_activation_in_cls_sets = [True] * len(consecutive_layer_list)

    # Convert to a list of cls-set-info elements
    cls_set_info_list = CrossLayerScaling.create_cls_set_info_list(consecutive_layer_list,
                                                                   scaling_factor_list,
                                                                   is_relu_activation_in_cls_sets)
    return folded_pairs, cls_set_info_list