AIMET TensorFlow Cross Layer Equalization APIs

Introduction

AIMET functionality for TensorFlow Cross Layer Equalization supports three techniques:
  • BatchNorm Folding

  • Cross Layer Scaling

  • High Bias Fold
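The first technique can be shown with a short example. The sketch below folds inference-mode BatchNorm parameters into the weights and bias of the preceding layer. It uses NumPy and a dense layer for brevity, and a convolution folds the same way per output channel. The function `fold_batchnorm` is hypothetical and is not part of AIMET. The default `eps=1e-3` follows the Keras `BatchNormalization` default.

```python
import numpy as np

def fold_batchnorm(weight, bias, gamma, beta, mean, var, eps=1e-3):
    """Fold inference-mode BatchNorm into the preceding dense layer (illustrative only).

    weight: (in_features, out_features), bias: (out_features,)
    gamma, beta, mean, var: per-output-feature BatchNorm parameters, shape (out_features,)
    """
    scale = gamma / np.sqrt(var + eps)        # per-output-channel multiplier
    folded_weight = weight * scale            # broadcasts over the output axis
    folded_bias = (bias - mean) * scale + beta
    return folded_weight, folded_bias

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 3))
w, b = rng.normal(size=(3, 5)), rng.normal(size=5)
gamma, beta = rng.uniform(0.5, 2.0, size=5), rng.normal(size=5)
mean, var = rng.normal(size=5), rng.uniform(0.5, 2.0, size=5)

# Reference: dense layer followed by BatchNorm in inference mode
reference = gamma * ((x @ w + b) - mean) / np.sqrt(var + 1e-3) + beta

fw, fb = fold_batchnorm(w, b, gamma, beta, mean, var)
print(np.allclose(x @ fw + fb, reference))  # True: one dense layer reproduces dense + BatchNorm
```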

Cross Layer Equalization API

Listed below is a comprehensive API that applies all available cross layer equalization techniques. It automatically detects candidate layers and applies the techniques to them. If a model contains no BatchNorm layers, BatchNorm fold and high bias fold are skipped.
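High bias fold relies on a BatchNorm layer's β and γ to estimate how far each channel's pre-activation usually sits above zero. Without BatchNorm layers, those statistics are not available. The following is a minimal NumPy sketch under stated assumptions. Layer 1 is a dense layer whose BatchNorm has already been folded, and it is followed by a ReLU and a dense layer 2. For each channel, c = max(0, β - 3|γ|) is subtracted from layer 1's bias and compensated in layer 2's bias. `high_bias_fold` and the synthetic data are hypothetical and are not AIMET code. The fold is exact only when no pre-activation falls below c, so the demo clips its synthetic pre-activations at three standard deviations below the mean.

```python
import numpy as np

def high_bias_fold(bias1, weight2, bias2, bn_gamma, bn_beta):
    """Move part of layer 1's bias into layer 2 (illustrative only).

    Assumes layer 1 -> ReLU -> dense layer 2 with weight2 of shape (in, out),
    and that layer 1's pre-activations are roughly Gaussian with
    mean bn_beta and standard deviation |bn_gamma|.
    """
    c = np.maximum(0.0, bn_beta - 3.0 * np.abs(bn_gamma))  # amount safely below the pre-activation
    new_bias1 = bias1 - c
    new_bias2 = bias2 + c @ weight2   # compensate for the shifted ReLU output
    return new_bias1, new_bias2

rng = np.random.default_rng(1)
gamma, beta = rng.uniform(0.5, 1.0, size=4), rng.uniform(0.0, 6.0, size=4)
w2, b2 = rng.normal(size=(4, 2)), rng.normal(size=2)

# Hypothetical layer-1 output: a data-dependent part plus the folded bias b1 (= beta here),
# clipped so that no pre-activation falls below beta - 3 * gamma
data_part = gamma * np.clip(rng.normal(size=(8, 4)), -3.0, None)
b1 = beta

original = np.maximum(data_part + b1, 0.0) @ w2 + b2
new_b1, new_b2 = high_bias_fold(b1, w2, b2, gamma, beta)
folded = np.maximum(data_part + new_b1, 0.0) @ w2 + new_b2
print(np.allclose(original, folded))  # True under the clipping assumption
```

On real data, a small fraction of pre-activations may fall below c. In that case, the fold is an approximation rather than an exact identity.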

API(s) for Cross Layer Equalization

aimet_tensorflow.cross_layer_equalization.equalize_model(sess, start_op_names, output_op_names)

High-level API to perform Cross-Layer Equalization (CLE) on the given model. The model is equalized in place.

Parameters:
  • sess (Session) – tf.compat.v1.Session containing the model to equalize

  • start_op_names (Union[str, List[str]]) – Names of starting ops in the given model

  • output_op_names (Union[str, List[str]]) – Names of the model's output ops. ConnectedGraph uses them to determine which ops are valid (for example, to ignore training ops).

Return type:

Session

Returns:

Updated session after BatchNorm fold, cross layer scaling and high bias fold.
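Cross layer scaling rescales each channel shared by two consecutive layers so that their weight ranges match. The network's output does not change because ReLU is positively homogeneous. The following is a minimal NumPy sketch for two dense layers joined by a ReLU. It assumes nonzero per-channel ranges r1 (layer 1 output channels) and r2 (layer 2 input channels), with the scale s = sqrt(r1 * r2) / r2. `cross_layer_scale` is hypothetical and is not AIMET's implementation. For convolutions, the output channels of the first layer and the input channels of the second play the same roles.

```python
import numpy as np

def cross_layer_scale(w1, b1, w2):
    """Equalize per-channel weight ranges of two dense layers joined by a ReLU (illustrative only).

    w1: (in, mid), b1: (mid,), w2: (mid, out). Channel i is column i of w1 and row i of w2.
    """
    r1 = np.max(np.abs(w1), axis=0)   # range of each output channel of layer 1
    r2 = np.max(np.abs(w2), axis=1)   # range of each input channel of layer 2 (assumed nonzero)
    s = np.sqrt(r1 * r2) / r2         # positive per-channel scale factors
    # relu(v / s) == relu(v) / s for s > 0, so dividing layer 1 by s and
    # multiplying layer 2's matching rows by s leaves the output unchanged
    return w1 / s, b1 / s, w2 * s[:, None], s

def relu(v):
    return np.maximum(v, 0.0)

rng = np.random.default_rng(2)
x = rng.normal(size=(6, 3))
w1 = rng.normal(size=(3, 4)) * [0.1, 1.0, 5.0, 20.0]   # deliberately unbalanced channel ranges
b1 = rng.normal(size=4)
w2, b2 = rng.normal(size=(4, 2)), rng.normal(size=2)

before = relu(x @ w1 + b1) @ w2 + b2
w1s, b1s, w2s, s = cross_layer_scale(w1, b1, w2)
after = relu(x @ w1s + b1s) @ w2s + b2

print(np.allclose(before, after))  # True: the network function is unchanged
print(np.allclose(np.abs(w1s).max(axis=0), np.abs(w2s).max(axis=1)))  # True: ranges now match
```

After scaling, both ranges for channel i equal sqrt(r1_i * r2_i). ReLU6 does not satisfy relu6(v / s) == relu6(v) / s because its clipping point does not scale. This matches the Relu6-to-Relu replacement in the code example.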

Code Example

Required imports

import tensorflow as tf

from tensorflow.keras.applications.resnet50 import ResNet50

# Cross layer Equalization related imports
from aimet_tensorflow.cross_layer_equalization import equalize_model

Comprehensive Cross Layer Equalization in auto mode

def cross_layer_equalization_auto():
    """ perform auto cross layer equalization """

    # load a model
    tf.keras.backend.clear_session()
    _ = ResNet50(weights='imagenet', input_shape=(224, 224, 3))
    sess = tf.compat.v1.keras.backend.get_session()

    # names of the model's input and output ops, passed to the CLE API
    input_op_name = 'input_1'
    output_op_name = 'fc1000/Softmax'

    # Equalize a model with BatchNorms.
    # Performs BatchNorm fold, replaces Relu6 with Relu, then applies cross layer scaling and high bias fold.
    # equalize_model returns a new session; use it for further evaluations on the TF graph
    with sess.as_default():
        new_session = equalize_model(sess, input_op_name, output_op_name)
    sess.close()
    return new_session

Primitive APIs

To apply the techniques individually, use the following APIs: