AIMET TensorFlow Cross Layer Equalization Primitive API

Introduction

If a user wants to change the order in which the Cross Layer Equalization steps are applied, skip some of them, or manually choose the layers to be equalized, the following APIs can be used.

The higher-level APIs apply one or more of these features one after the other and automatically find the layers to be folded or scaled.

The lower-level APIs let the user manually choose the layers to be folded or scaled. The layers must be passed in the order in which they appear in the model.

Note: High Bias Fold must be preceded by Cross Layer Scaling (CLS), and the scaling factors obtained from CLS must be passed to High Bias Fold. If the model contains batch norm layers, they must be folded first and the folding information saved, because it is also passed to the High Bias Fold API.

Higher Level APIs for Cross Layer Equalization

API for Batch Norm Folding

aimet_tensorflow.keras.batch_norm_fold.fold_all_batch_norms(model)

Fold all batch_norm layers in a model into corresponding conv/linear layers

Parameters:

model (Model) – model to find all batch norms for

Return type:

Tuple[List[Tuple[Union[Conv2D, Dense, Conv2DTranspose, DepthwiseConv2D], BatchNormalization]], Model]

Returns:

A tuple containing a list of (conv/linear layer, associated BN layer) pairs and a new model with the Batch Normalization layers folded
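The arithmetic behind folding a BatchNormalization layer that follows a conv/linear layer can be sketched in NumPy. This is a minimal sketch for a Dense-style layer (y = x @ W + b, with W of shape (in, out)), assuming inference-mode BN with parameters gamma, beta, moving mean and moving variance; fold_bn_into_dense is a hypothetical helper, not part of AIMET.

```python
import numpy as np


def fold_bn_into_dense(W, b, gamma, beta, mean, var, eps=1e-3):
    """Hypothetical helper: fold an inference-mode BN that follows y = x @ W + b.

    W has shape (in_features, out_features); the BN parameters have shape (out_features,).
    Returns (W_folded, b_folded) such that x @ W_folded + b_folded == BN(x @ W + b).
    """
    scale = gamma / np.sqrt(var + eps)       # per-output-channel multiplier
    W_folded = W * scale                     # scales each output column of W
    b_folded = (b - mean) * scale + beta     # the BN shift is absorbed into the bias
    return W_folded, b_folded


def batch_norm(y, gamma, beta, mean, var, eps=1e-3):
    # Inference-mode batch normalization over the last axis
    return gamma * (y - mean) / np.sqrt(var + eps) + beta


rng = np.random.default_rng(0)
x = rng.normal(size=(4, 3))
W, b = rng.normal(size=(3, 5)), rng.normal(size=5)
gamma, beta = rng.normal(size=5), rng.normal(size=5)
mean, var = rng.normal(size=5), rng.uniform(0.5, 2.0, size=5)

W_f, b_f = fold_bn_into_dense(W, b, gamma, beta, mean, var)
reference = batch_norm(x @ W + b, gamma, beta, mean, var)
folded = x @ W_f + b_f
print(np.allclose(reference, folded))  # True
```

For a Keras Conv2D kernel of shape (kh, kw, in, out), the same broadcasting applies because the output channel is the last axis.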

API for Cross Layer Scaling

aimet_tensorflow.keras.cross_layer_equalization.CrossLayerScaling.scale_model(model)

Uses cross-layer scaling to scale all applicable layers in the given model.

Parameters:

model (Model) – tf.keras.Model to scale

Return type:

List[ClsSetInfo]

Returns:

CLS information for each CLS set
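What scaling a single CLS pair does can be sketched in NumPy. This minimal sketch assumes the per-channel scale used in the data-free quantization formulation of cross-layer equalization, s = sqrt(r1 * r2) / r2, where r1 is the range of each output channel of the first layer and r2 the range of each input channel of the second; cls_scale_pair is a hypothetical helper and AIMET's internal implementation may differ in details.

```python
import numpy as np


def cls_scale_pair(W1, b1, W2):
    """Hypothetical sketch of cross-layer scaling for y = relu(x @ W1 + b1) @ W2,
    with W1 of shape (in, mid) and W2 of shape (mid, out).

    Channel i of the intermediate activation is divided by s[i] in layer 1 and multiplied
    back by s[i] in layer 2, which equalizes the ranges of W1[:, i] and W2[i, :].
    Assumes no channel has an all-zero weight range.
    """
    r1 = np.max(np.abs(W1), axis=0)          # range of each output channel of layer 1
    r2 = np.max(np.abs(W2), axis=1)          # range of each input channel of layer 2
    s = np.sqrt(r1 / r2)                     # equals sqrt(r1 * r2) / r2
    return W1 / s, b1 / s, W2 * s[:, None], s


def relu(z):
    return np.maximum(z, 0.0)


rng = np.random.default_rng(3)
W1 = rng.normal(size=(3, 4)) * np.array([10.0, 1.0, 0.1, 1.0])  # deliberately unbalanced channels
b1 = rng.normal(size=4)
W2 = rng.normal(size=(4, 2))
x = rng.normal(size=(6, 3))

W1_s, b1_s, W2_s, s = cls_scale_pair(W1, b1, W2)
print(np.allclose(relu(x @ W1 + b1) @ W2, relu(x @ W1_s + b1_s) @ W2_s))  # True
print(np.allclose(np.max(np.abs(W1_s), axis=0), np.max(np.abs(W2_s), axis=1)))  # True
```

The output is unchanged because s is positive and relu(z / s) * s == relu(z); this is why a ReLU (or no activation) between the two layers is required.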

API for High Bias Folding

aimet_tensorflow.keras.cross_layer_equalization.HighBiasFold.bias_fold(cls_set_info_list, bn_layers)

Folds bias values greater than 3 * sigma into the next layer’s bias.

Parameters:
  • cls_set_info_list (List[ClsSetInfo]) – List of info elements for each CLS set

  • bn_layers (Dict[Conv2D, BatchNormalization]) – Key: conv/linear layer; value: corresponding folded BN layer
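The idea behind high bias folding can be sketched in NumPy for two Dense-style layers with a ReLU in between. This minimal sketch assumes that the pre-activation of each channel is distributed roughly as N(beta, gamma^2), using the parameters of the BN folded into the first layer rescaled by the CLS scale factor, and that c = max(0, beta - 3 * |gamma|) is the amount absorbed; high_bias_fold is a hypothetical helper, not AIMET's implementation.

```python
import numpy as np


def high_bias_fold(b1, W2, b2, bn_gamma, bn_beta, scale):
    """Hypothetical sketch of high bias absorption for y = relu(x @ W1 + b1) @ W2 + b2.

    bn_gamma/bn_beta are the parameters of the BN that was folded into layer 1, and
    scale is the CLS scale factor by which layer 1's output channels were divided.
    c_i = max(0, beta_i - 3 * |gamma_i|) is assumed to lie below (almost) all
    pre-activation values of channel i.
    """
    gamma = bn_gamma / scale                 # BN statistics follow the CLS rescaling
    beta = bn_beta / scale
    c = np.maximum(0.0, beta - 3.0 * np.abs(gamma))
    b1_new = b1 - c                          # remove c from layer 1 ...
    b2_new = b2 + c @ W2                     # ... and add its effect to layer 2
    return b1_new, b2_new


# Toy setup: a large bias (5.0) relative to the spread of the pre-activation (gamma = 0.1)
rng = np.random.default_rng(4)
W1, W2, b2 = rng.normal(size=(3, 4)), rng.normal(size=(4, 2)), rng.normal(size=2)
gamma, beta, scale = np.full(4, 0.1), np.full(4, 5.0), np.ones(4)
b1 = beta.copy()                             # toy assumption: folded layer-1 bias equals beta
x = rng.uniform(-0.01, 0.01, size=(8, 3))    # keeps x @ W1 + b1 close to 5.0, above c = 4.7

b1_new, b2_new = high_bias_fold(b1, W2, b2, gamma, beta, scale)
before = np.maximum(x @ W1 + b1, 0.0) @ W2 + b2
after = np.maximum(x @ W1 + b1_new, 0.0) @ W2 + b2_new
print(np.allclose(before, after))  # True
```

The fold is exact only while every pre-activation stays above c, because relu(z - c) + c == relu(z) requires z >= c >= 0; placing the threshold three standard deviations below the mean keeps the error small for most inputs.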

Code Examples for Higher Level APIs

Required imports

import tensorflow as tf
from aimet_tensorflow.keras.batch_norm_fold import fold_all_batch_norms
from aimet_tensorflow.keras.cross_layer_equalization import HighBiasFold, CrossLayerScaling
from aimet_tensorflow.keras.utils.model_transform_utils import replace_relu6_with_relu

Perform Cross Layer Equalization in auto mode step by step

def cross_layer_equalization_auto_stepwise():
    """
    Individual api calls to perform cross layer equalization one step at a time. Pairs to fold and
    scale are found automatically.
    1. Replace Relu6 with Relu
    2. Fold batch norms
    3. Perform cross layer scaling
    4. Perform high bias fold
    """

    # Load the model to equalize
    model = tf.keras.applications.resnet50.ResNet50(weights=None, classes=10)

    # 1. Replace Relu6 layer with Relu
    model_for_cle, _ = replace_relu6_with_relu(model)

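    # 2. Fold all batch norms. For a functional model such as ResNet50, a new model without
    #    the BN layers is returned; the remaining steps operate on that model.
    folded_pairs, model_for_cle = fold_all_batch_norms(model_for_cle)

    # map each conv/linear layer to the BN layer that was folded into it
    bn_dict = {}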
    for conv_or_linear, bn in folded_pairs:
        bn_dict[conv_or_linear] = bn

    # 3. Perform cross-layer scaling on applicable layer groups
    cls_set_info_list = CrossLayerScaling.scale_model(model_for_cle)

    # 4. Perform high bias fold
    HighBiasFold.bias_fold(cls_set_info_list, bn_dict)

    return model_for_cle

Lower Level APIs for Cross Layer Equalization

API for Batch Norm Folding on subsets of convolution-batchnorm layer pairs

aimet_tensorflow.keras.batch_norm_fold.fold_given_batch_norms(model, layer_pairs)

Fold a given set of batch_norm layers into conv_linear layers

Parameters:
  • model (Model) – Either a Keras Model or a QuantizationSimModel’s model

  • layer_pairs (List[Union[Tuple[Union[Conv2D, Dense, Conv2DTranspose, DepthwiseConv2D], BatchNormalization, bool], Tuple[BatchNormalization, Union[Conv2D, Dense, Conv2DTranspose, DepthwiseConv2D], bool]]]) – List of (conv, bn, is_batch_norm_second) or (bn, conv, is_batch_norm_second) tuples; is_batch_norm_second is True when the BN layer follows the conv/linear layer

Return type:

Optional[Model]

Returns:

new model with batch norm layers folded if model is a functional model, else None
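The is_batch_norm_second flag selects the folding direction. A minimal NumPy sketch of the less common case, where the BN layer comes first (is_batch_norm_second=False) and is folded forward into the following Dense-style layer; fold_bn_forward_into_dense is a hypothetical helper and assumes inference-mode BN.

```python
import numpy as np


def fold_bn_forward_into_dense(W, b, gamma, beta, mean, var, eps=1e-3):
    """Hypothetical helper: fold an inference-mode BN that precedes y = x @ W + b
    (BN first, i.e. is_batch_norm_second=False). W has shape (in_features, out_features)
    and the BN parameters have shape (in_features,)."""
    a = gamma / np.sqrt(var + eps)           # per-input-channel multiplier of the BN
    d = beta - mean * a                      # per-input-channel offset of the BN
    W_folded = W * a[:, None]                # scale each input row of W
    b_folded = b + d @ W                     # the BN offset passes through W into the bias
    return W_folded, b_folded


rng = np.random.default_rng(2)
x = rng.normal(size=(4, 3))
W, b = rng.normal(size=(3, 5)), rng.normal(size=5)
gamma, beta = rng.normal(size=3), rng.normal(size=3)
mean, var = rng.normal(size=3), rng.uniform(0.5, 2.0, size=3)

bn_out = gamma * (x - mean) / np.sqrt(var + 1e-3) + beta
W_f, b_f = fold_bn_forward_into_dense(W, b, gamma, beta, mean, var)
print(np.allclose(bn_out @ W + b, x @ W_f + b_f))  # True
```

For a Conv2D with zero padding, this algebra is only exact away from the borders, because the folded bias assumes every kernel tap sees a BN-offset input while the padded taps originally saw zeros.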


API for Cross Layer Scaling on a subset of conv layer groups

aimet_tensorflow.keras.cross_layer_equalization.CrossLayerScaling.scale_cls_sets(cls_sets)

Scales each CLS set.

Parameters:
  • cls_sets (List[Union[Tuple[Conv2D, Conv2D], Tuple[Conv2D, DepthwiseConv2D, Conv2D]]]) – CLS sets to scale

Return type:

List[Union[ndarray, Tuple[ndarray, ndarray]]]

Returns:

List of scale factors corresponding to each scaled CLS set
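For a three-layer set (conv, depthwise conv, conv), scale_cls_sets returns a tuple of two scale-factor arrays. A minimal sketch of where two factors come from, assuming the goal is to make the per-channel weight ranges of all three layers equal to their geometric mean; cls_scale_factors_three is a hypothetical helper, and the formula is derived from that equal-range assumption rather than taken from AIMET.

```python
import numpy as np


def cls_scale_factors_three(r0, r1, r2):
    """Hypothetical sketch: per-channel scale factors for a (conv, depthwise conv, conv) CLS set.

    r0: range of each output channel of the first conv
    r1: range of each channel of the depthwise conv
    r2: range of each input channel of the last conv
    Layer 0's outputs are divided by s01, the depthwise channel is multiplied by s01
    and divided by s12, and the last conv's inputs are multiplied by s12.
    """
    g = np.cbrt(r0 * r1 * r2)                # common target range for all three layers
    s01 = r0 / g
    s12 = g / r2
    return s01, s12


r0 = np.array([8.0, 1.0])
r1 = np.array([1.0, 8.0])
r2 = np.array([1.0, 1.0])
s01, s12 = cls_scale_factors_three(r0, r1, r2)
print(r0 / s01, r1 * s01 / s12, r2 * s12)  # [2. 2.] [2. 2.] [2. 2.]
```

A two-layer set needs only one factor per channel, which matches the plain ndarray entries in the return type.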


API for High Bias Folding

aimet_tensorflow.keras.cross_layer_equalization.HighBiasFold.bias_fold(cls_set_info_list, bn_layers)

Folds bias values greater than 3 * sigma into the next layer’s bias.

Parameters:
  • cls_set_info_list (List[ClsSetInfo]) – List of info elements for each CLS set

  • bn_layers (Dict[Conv2D, BatchNormalization]) – Key: conv/linear layer; value: corresponding folded BN layer


Custom Datatype used

class aimet_tensorflow.keras.cross_layer_equalization.ClsSetInfo(cls_pair_1, cls_pair_2=None)

This class holds information about the layers in a CLS set, along with the corresponding scaling factors and other information, such as whether there is a ReLU activation function between the layers of the CLS set.

The constructor takes two pairs if a depthwise-separable layer is being folded.

Parameters:
  • cls_pair_1 (ClsSetLayerPairInfo) – Pair between two conv layers, or between a conv and a depthwise conv layer

  • cls_pair_2 (Optional[ClsSetLayerPairInfo]) – Pair between the depthwise conv and the pointwise conv layer

class ClsSetLayerPairInfo(layer1, layer2, scale_factor, relu_activation_between_layers)

Models a pair of layers that were scaled using CLS, along with related information.

Parameters:
  • layer1 (Conv2D) – Layer whose bias is folded

  • layer2 (Conv2D) – Layer into which the previous layer’s bias is folded

  • scale_factor (ndarray) – Scale Factor found from Cross Layer Scaling to scale BN parameters

  • relu_activation_between_layers (bool) – True if the activation between layer1 and layer2 is ReLU
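To make the relationship between CLS sets, scale factors and ReLU flags concrete, here is a minimal sketch that binds three parallel lists into per-set records. PairInfo, SetInfo and bind_cls_info are hypothetical stand-ins that mirror the documented fields of ClsSetInfo and ClsSetLayerPairInfo; they are not AIMET classes. That a three-layer (depthwise-separable) set carries a tuple of two scale factors follows the scale_cls_sets return type; the matching tuple of two ReLU flags is an assumption.

```python
from dataclasses import dataclass
from typing import Any, List, Optional, Sequence, Tuple, Union

import numpy as np


@dataclass
class PairInfo:
    """Hypothetical stand-in for ClsSetLayerPairInfo (same documented fields)."""
    layer1: Any
    layer2: Any
    scale_factor: np.ndarray
    relu_activation_between_layers: bool


@dataclass
class SetInfo:
    """Hypothetical stand-in for ClsSetInfo: one pair, or two for a depthwise-separable set."""
    cls_pair_1: PairInfo
    cls_pair_2: Optional[PairInfo] = None


def bind_cls_info(cls_sets: Sequence[tuple],
                  scale_factors: Sequence[Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]],
                  relu_flags: Sequence[Union[bool, Tuple[bool, bool]]]) -> List[SetInfo]:
    """Zip the three parallel lists into SetInfo records; all lists must have equal length."""
    assert len(cls_sets) == len(scale_factors) == len(relu_flags)
    infos = []
    for layers, scale, relu in zip(cls_sets, scale_factors, relu_flags):
        if len(layers) == 3:
            # (conv, depthwise, pointwise): two scale factors and two ReLU flags
            first = PairInfo(layers[0], layers[1], scale[0], relu[0])
            second = PairInfo(layers[1], layers[2], scale[1], relu[1])
            infos.append(SetInfo(first, second))
        else:
            infos.append(SetInfo(PairInfo(layers[0], layers[1], scale, relu)))
    return infos


# Layer names stand in for Keras layer objects in this illustration
infos = bind_cls_info([("conv_a", "conv_b"), ("conv_c", "dw", "pw")],
                      [np.ones(4), (np.ones(8), np.ones(8))],
                      [True, (True, False)])
print(infos[1].cls_pair_2.relu_activation_between_layers)  # False
```

The length check illustrates why a list of ReLU flags with one entry per CLS set has to be supplied when these records are built from manually chosen layers.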


Code Example for Lower level APIs

Required imports

import tensorflow as tf
from aimet_tensorflow.keras.batch_norm_fold import fold_given_batch_norms
from aimet_tensorflow.keras.cross_layer_equalization import HighBiasFold, CrossLayerScaling
from aimet_tensorflow.keras.utils.model_transform_utils import replace_relu6_with_relu

Perform Cross Layer Equalization in manual mode

def cross_layer_equalization_manual():
    """
    Individual api calls to perform cross layer equalization one step at a time. Pairs to fold and
    scale are provided by the user.
    1. Replace Relu6 with Relu
    2. Fold batch norms
    3. Perform cross layer scaling
    4. Perform high bias fold
    """

    # Load the model to equalize
    model = tf.keras.applications.resnet50.ResNet50(weights=None, classes=10)

    # replace any ReLU6 layers with ReLU
    model_for_cle, _ = replace_relu6_with_relu(model)

    # pick potential pairs of conv and bn ops for fold
    layer_pairs = get_example_layer_pairs_resnet50_for_folding(model_for_cle)

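    # fold given layers; for a functional model such as ResNet50 a new model without the folded
    # BN layers is returned, otherwise the model is modified in place and None is returned
    folded_model = fold_given_batch_norms(model_for_cle, layer_pairs=layer_pairs)

    # Cross Layer Scaling
    # Create a list of consecutive conv layers to be equalized. The helper picks layers by
    # index from model_for_cle, whose layer list still contains the BN layers.
    consecutive_layer_list = get_consecutive_layer_list_from_resnet50_for_scaling(model_for_cle)

    # invoke api to perform scaling on given list of cls pairs
    scaling_factor_list = CrossLayerScaling.scale_cls_sets(consecutive_layer_list)

    # get info from bn fold and cross layer scaling in format required for high bias fold
    folded_pairs, cls_set_info_list = format_info_for_high_bias_fold(layer_pairs,
                                                                     consecutive_layer_list,
                                                                     scaling_factor_list)

    HighBiasFold.bias_fold(cls_set_info_list, folded_pairs)

    # return the model with the BN layers removed, if a new one was created
    return folded_model if folded_model is not None else model_for_cle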

Example helper methods to perform CLE in manual mode

Helper to pick layers for batchnorm fold

def get_example_layer_pairs_resnet50_for_folding(model: tf.keras.Model):
    """
    Function to pick example conv-batchnorm layer pairs for folding.
    :param model: Keras model containing conv batchnorm pairs to fold
    :return: pairs of conv and batchnorm layers for batch norm folding in Resnet50 model.
    """

    conv_op_1 = model.layers[2]
    bn_op_1 = model.layers[3]

    conv_op_2 = model.layers[7]
    bn_op_2 = model.layers[8]

    conv_op_3 = model.layers[10]
    bn_op_3 = model.layers[11]

    # make a list of (conv, bn, is_batch_norm_second) tuples. The flag is True because
    # each BN layer follows its conv layer and is therefore folded backward into it.
    # three example pairs of conv and bn layers are shown below
    layer_pairs = [(conv_op_1, bn_op_1, True),
                   (conv_op_2, bn_op_2, True),
                   (conv_op_3, bn_op_3, True)]

    return layer_pairs
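The helper above hardcodes layer indices. As a sketch of how such indices could be derived instead, the hypothetical function below scans an ordered list of layer class names for a BatchNormalization layer that directly follows a conv/linear layer. Adjacency in model.layers is only a heuristic: a real implementation should also confirm in the graph that the BN layer consumes the conv layer's output (for example, not across a ResNet branch). The list of class names is illustrative and assumed to match the first layers of Keras ResNet50.

```python
from typing import List, Sequence, Tuple

# Class names treated as foldable conv/linear layers in this sketch
CONV_TYPES = {"Conv2D", "DepthwiseConv2D", "Conv2DTranspose", "Dense"}


def find_conv_bn_index_pairs(layer_types: Sequence[str]) -> List[Tuple[int, int]]:
    """Hypothetical helper: return (conv_index, bn_index) for every BatchNormalization
    that directly follows a conv/linear layer in an ordered list of layer class names."""
    pairs = []
    for i in range(len(layer_types) - 1):
        if layer_types[i] in CONV_TYPES and layer_types[i + 1] == "BatchNormalization":
            pairs.append((i, i + 1))
    return pairs


# Class names of the first layers of a ResNet50-like model (illustrative only)
types = ["InputLayer", "ZeroPadding2D", "Conv2D", "BatchNormalization", "Activation",
         "ZeroPadding2D", "MaxPooling2D", "Conv2D", "BatchNormalization", "Activation",
         "Conv2D", "BatchNormalization"]
print(find_conv_bn_index_pairs(types))  # [(2, 3), (7, 8), (10, 11)]
```

With a Keras model, the class names could be collected with [type(layer).__name__ for layer in model.layers].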

Helper to pick layers for cross layer scaling

def get_consecutive_layer_list_from_resnet50_for_scaling(model: tf.keras.Model):
    """
    helper function to pick example consecutive layer list for scaling.
    :param model: tf.keras.Model
    :return: sample layers for scaling as consecutive_layer_list from Resnet50 model
    """
    conv_op_1 = model.layers[2]
    conv_op_2 = model.layers[7]
    conv_op_3 = model.layers[10]

    consecutive_layer_list = [(conv_op_1, conv_op_2), (conv_op_2, conv_op_3)]
    return consecutive_layer_list

Helper to format data from batchnorm fold and cross layer scaling for usage by high bias fold

def format_info_for_high_bias_fold(layer_pairs, consecutive_layer_list, scaling_factor_list):
    """
    Helper function that formats data from cross layer scaling and bn fold for usage by high bias fold
    :param layer_pairs: (conv, bn, is_batch_norm_second) tuples that were passed to batch norm fold
    :param consecutive_layer_list: layer sets that were passed to cross layer scaling
    :param scaling_factor_list: scaling params corresponding to consecutive_layer_list
    :return: data formatted for high bias fold
    """

    # HighBiasFold.bias_fold expects a dict mapping each conv/linear layer to its folded BN layer
    folded_pairs = {}
    for (conv_layer, bn_layer, _is_batch_norm_second) in layer_pairs:
        folded_pairs[conv_layer] = bn_layer

    # List that holds one boolean per cross layer scaling set: True if there is a ReLU activation
    # between the layers of that set. The user is expected to fill in this list manually, with one
    # entry per set; here every example set from the ResNet50 helper is assumed to contain a ReLU.
    is_relu_activation_in_cls_sets = [True] * len(consecutive_layer_list)

    # Convert to a list of cls-set-info elements
    cls_set_info_list = CrossLayerScaling.create_cls_set_info_list(consecutive_layer_list,
                                                                   scaling_factor_list,
                                                                   is_relu_activation_in_cls_sets)

    return folded_pairs, cls_set_info_list