AIMET TensorFlow Cross Layer Equalization APIs

Introduction

AIMET Cross Layer Equalization for TensorFlow supports three techniques:
  • BatchNorm Folding

  • Cross Layer Scaling

  • High Bias Fold
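BatchNorm folding merges an inference-mode BatchNorm into the weights and bias of the layer that precedes it, so the BatchNorm op can be removed without changing the model output. The NumPy sketch below illustrates the arithmetic for a single dense layer. It assumes the Keras Dense weight layout (y = x @ W + b) and the Keras default BatchNorm epsilon of 1e-3; the function name fold_batchnorm and the random test values are illustrative and not part of the AIMET API. AIMET applies the equivalent transformation to the candidate layers it detects in the TensorFlow graph.

```python
import numpy as np


def fold_batchnorm(weight, bias, gamma, beta, mean, var, eps=1e-3):
    """Fold an inference-mode BatchNorm into the preceding dense layer.

    Hypothetical helper for illustration only (not an AIMET API).
    weight: (in_features, out_features), Keras Dense layout (y = x @ W + b).
    gamma, beta, mean, var: one value per output channel.
    """
    scale = gamma / np.sqrt(var + eps)       # per-output-channel multiplier
    folded_weight = weight * scale           # broadcasts over the output axis
    folded_bias = (bias - mean) * scale + beta
    return folded_weight, folded_bias


rng = np.random.default_rng(0)
x = rng.normal(size=(4, 8))
w = rng.normal(size=(8, 5))
b = rng.normal(size=5)
gamma = rng.uniform(0.5, 2.0, size=5)
beta = rng.normal(size=5)
mean = rng.normal(size=5)
var = rng.uniform(0.1, 1.0, size=5)

# Reference: dense layer followed by an inference-mode BatchNorm
reference = gamma * ((x @ w + b) - mean) / np.sqrt(var + 1e-3) + beta

# Folded: a single dense layer with adjusted weights and bias
w_f, b_f = fold_batchnorm(w, b, gamma, beta, mean, var)
folded = x @ w_f + b_f
print(np.allclose(reference, folded))  # True
```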

Cross Layer Equalization API

Listed below is a comprehensive API that applies all available cross layer equalization techniques. It automatically detects candidate layers and applies the techniques to them. If the model contains no BatchNorm layers, BatchNorm folding and high bias folding are skipped.
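Cross layer scaling rescales the output channels of one layer and the matching input channels of the next layer so that their per-channel weight ranges become equal, which makes per-tensor weight quantization less lossy. Because ReLU(z / s) = ReLU(z) / s for any s > 0, the network output is unchanged. The NumPy sketch below is a minimal illustration for two dense layers joined by a ReLU. It assumes the per-channel range is the maximum absolute weight and uses the scale s_i = sqrt(r1_i / r2_i) from the Data-Free Quantization paper (Nagel et al., 2019); the function cross_layer_scale is hypothetical, not an AIMET API, and AIMET's implementation may differ in detail.

```python
import numpy as np


def cross_layer_scale(w1, b1, w2):
    """Equalize per-channel weight ranges of two dense layers joined by ReLU.

    Hypothetical helper for illustration only (not an AIMET API).
    w1: (in, hidden), b1: (hidden,), w2: (hidden, out), Keras Dense layout.
    """
    r1 = np.max(np.abs(w1), axis=0)   # range of each output channel of layer 1
    r2 = np.max(np.abs(w2), axis=1)   # range of each input channel of layer 2
    s = np.sqrt(r1 / r2)              # positive per-channel scale
    # Divide layer 1's output channels by s, multiply layer 2's input channels by s
    return w1 / s, b1 / s, w2 * s[:, None], s


def relu(z):
    return np.maximum(z, 0.0)


rng = np.random.default_rng(1)
x = rng.normal(size=(4, 6))
# Deliberately uneven channel ranges in layer 1
w1 = rng.normal(size=(6, 3)) * np.array([0.1, 1.0, 10.0])
b1 = rng.normal(size=3)
w2 = rng.normal(size=(3, 2))
b2 = rng.normal(size=2)

w1_s, b1_s, w2_s, s = cross_layer_scale(w1, b1, w2)
before = relu(x @ w1 + b1) @ w2 + b2
after = relu(x @ w1_s + b1_s) @ w2_s + b2

print(np.allclose(before, after))  # True: output is preserved
# True: both layers now share the range sqrt(r1 * r2) per channel
print(np.allclose(np.max(np.abs(w1_s), axis=0), np.max(np.abs(w2_s), axis=1)))
```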

API(s) for Cross Layer Equalization

aimet_tensorflow.cross_layer_equalization.equalize_model(sess, start_op_names, output_op_names)

High-level API to perform Cross-Layer Equalization (CLE) on the given model. The model is equalized in place.

Parameters
  • sess (Session) – tf.compat.v1.Session containing the model to equalize

  • start_op_names (Union[str, List[str]]) – Names of the starting ops in the given model

  • output_op_names (Union[str, List[str]]) – List of output op names of the model, used to help ConnectedGraph determine valid ops (for example, to ignore training ops).

Return type

Session

Returns

Updated session after BatchNorm folding, cross layer scaling, and high bias folding.
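High bias folding, the last of the three transformations reflected in the returned session, handles biases that cross layer scaling can inflate. Following the bias-absorption rule of Nagel et al. (2019), a per-channel amount c = max(0, beta - 3 * gamma), derived from the BatchNorm shift (beta) and scale (gamma), is subtracted from the first layer's bias, and its effect (c @ W2 in Keras layout) is added to the second layer's bias. The result is exact for every input whose pre-activation in each channel is at least c, and only approximate for the rare inputs below it. Because c comes from BatchNorm statistics, this step needs BatchNorm layers to be present. The NumPy sketch below assumes Gaussian pre-activations described by those statistics; the name high_bias_fold is hypothetical, and AIMET's implementation may differ in detail.

```python
import numpy as np


def high_bias_fold(b1, w2, b2, bn_gamma, bn_beta):
    """Absorb part of layer 1's bias into layer 2 across a ReLU.

    Hypothetical helper for illustration only (not an AIMET API).
    Assumes each pre-activation channel of layer 1 is roughly
    N(bn_beta, bn_gamma**2). w2 uses the Keras Dense layout (hidden, out).
    """
    # Per-channel amount that (almost) every pre-activation value exceeds
    c = np.maximum(0.0, bn_beta - 3.0 * np.abs(bn_gamma))
    return b1 - c, b2 + c @ w2, c


def relu(z):
    return np.maximum(z, 0.0)


rng = np.random.default_rng(2)
bn_gamma = np.array([0.5, 1.0, 0.2])
bn_beta = np.array([4.0, -1.0, 2.0])
w2 = rng.normal(size=(3, 2))
b2 = rng.normal(size=2)

# Simulated layer-1 pre-activations z = h + b1, where b1 carries the whole shift
b1 = bn_beta.copy()
z = bn_beta + bn_gamma * rng.normal(size=(1000, 3))
h = z - b1  # bias-free part of layer 1's output

b1_f, b2_f, c = high_bias_fold(b1, w2, b2, bn_gamma, bn_beta)
before = relu(h + b1) @ w2 + b2
after = relu(h + b1_f) @ w2 + b2_f

# Exact wherever each channel has z >= c (channels with c == 0 are always exact)
exact = np.all((z >= c) | (c == 0), axis=1)
print(np.allclose(before[exact], after[exact]))  # True
print(exact.mean())  # fraction of samples for which the fold is exact
```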

Code Example

Required imports

# tensorflow
import tensorflow as tf

from tensorflow.keras.applications.resnet50 import ResNet50

# Cross layer Equalization related imports
from aimet_tensorflow.cross_layer_equalization import equalize_model

Comprehensive Cross Layer Equalization in auto mode

def cross_layer_equalization_auto():
    """ Perform cross layer equalization in auto mode """

    # load a model
    tf.keras.backend.clear_session()
    _ = ResNet50(weights='imagenet', input_shape=(224, 224, 3))
    sess = tf.compat.v1.keras.backend.get_session()

    # starting and output op names used to invoke the CLE API
    input_op_name = 'input_1'
    output_op_name = 'fc1000/Softmax'

    # Equalize a model with BatchNorms
    # Performs BatchNorm fold, replacing Relu6 with Relu, cross layer scaling and high bias fold
    # Use the new session returned for further evaluations on the TF graph
    with sess.as_default():
        new_session = equalize_model(sess, input_op_name, output_op_name)
    sess.close()

    # return the equalized session so the caller can evaluate or quantize it
    return new_session

Primitive APIs

To apply the techniques individually, use the following APIs: