AIMET TensorFlow Cross Layer Equalization APIs
Introduction
AIMET functionality for TensorFlow Cross Layer Equalization supports three techniques:
- BatchNorm Folding
- Cross Layer Scaling
- High Bias Fold
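The first two techniques can be illustrated with a small pure-Python sketch. It is not AIMET code. The function names (`fold_batchnorm`, `cross_layer_scale`, `dense`, `two_layer`) and the list-of-rows weight layout are hypothetical. Cross layer scaling here follows the rescaling rule from the Data-Free Quantization paper (Nagel et al., 2019), which we assume AIMET's implementation is based on. For each channel i shared by layer 1, a ReLU, and layer 2, it picks s_i = sqrt(r1_i * r2_i) / r2_i. It then divides layer 1's row i and bias by s_i and multiplies layer 2's column i by s_i. Because ReLU(x / s) = ReLU(x) / s for s > 0, the network computes the same function while both sides of the channel end up with range sqrt(r1_i * r2_i).

```python
import math


def fold_batchnorm(weight, bias, gamma, beta, mean, var, eps=1e-3):
    """Fold y = gamma * (W x + b - mean) / sqrt(var + eps) + beta into W', b'.

    weight is a list of rows, one per output channel (out_channels x in_channels).
    """
    new_w, new_b = [], []
    for row, b, g, bt, m, v in zip(weight, bias, gamma, beta, mean, var):
        scale = g / math.sqrt(v + eps)          # per-output-channel BN scale
        new_w.append([w * scale for w in row])  # W'_i = scale_i * W_i
        new_b.append((b - m) * scale + bt)      # b'_i = scale_i * (b_i - mean_i) + beta_i
    return new_w, new_b


def cross_layer_scale(w1, b1, w2):
    """Rescale the channels shared by layer 1 -> ReLU -> layer 2.

    Channel i is row i of w1 (a layer-1 output) and column i of w2 (a layer-2 input).
    Assumes no channel is all zeros on either side (otherwise r2 would be 0).
    """
    scales = []
    for i, row in enumerate(w1):
        r1 = max(abs(w) for w in row)    # range of layer-1 output channel i
        r2 = max(abs(r[i]) for r in w2)  # range of layer-2 input channel i
        scales.append(math.sqrt(r1 * r2) / r2)
    w1_new = [[w / s for w in row] for row, s in zip(w1, scales)]  # S^-1 W1
    b1_new = [b / s for b, s in zip(b1, scales)]                   # S^-1 b1
    w2_new = [[w * s for w, s in zip(row, scales)] for row in w2]  # W2 S
    return w1_new, b1_new, w2_new


def dense(w, b, x):
    """y = W x + b for a weight matrix stored as a list of rows."""
    return [sum(wi * xi for wi, xi in zip(row, x)) + bi for row, bi in zip(w, b)]


def two_layer(w1, b1, w2, x):
    """y = W2 relu(W1 x + b1), used to check that scaling preserves the function."""
    h = [max(0.0, v) for v in dense(w1, b1, x)]
    return dense(w2, [0.0] * len(w2), h)


# Usage: one channel with a large range, one with a tiny range
w1, b1 = [[4.0, -2.0], [0.1, 0.05]], [0.5, 0.01]
w2 = [[0.25, 3.0], [-0.5, 1.0]]
w1s, b1s, w2s = cross_layer_scale(w1, b1, w2)
x = [1.0, -2.0]
print(two_layer(w1, b1, w2, x))     # original output
print(two_layer(w1s, b1s, w2s, x))  # same values, up to float rounding
```

After scaling, the largest absolute weight in row i of `w1s` equals the largest absolute weight in column i of `w2s`, so per-tensor quantization wastes less range on either layer.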
Cross Layer Equalization API
The API listed below applies all of the available cross layer equalization techniques in a single call. It automatically detects the candidate layers and applies the techniques to them. If the model contains no BatchNorm layers, BatchNorm folding and high bias folding are skipped.
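High bias folding depends on BatchNorm statistics. The sketch below shows one way to see this, using the high-bias absorption rule from the Data-Free Quantization paper (Nagel et al., 2019). We assume AIMET's high bias fold is based on this rule, but the source does not say so. The BatchNorm shift β and scale γ are treated as the mean and standard deviation of layer 1's pre-activation. The amount c = max(0, β − 3|γ|) is moved out of layer 1's bias, and W2·c is added to layer 2's bias. Without a BatchNorm layer there are no β and γ to compute c from. The helper names `high_bias_fold` and `layer_pair` are hypothetical and are not AIMET APIs.

```python
def high_bias_fold(b1, w2, b2, bn_beta, bn_gamma):
    """Move part of layer 1's bias into layer 2 (DFQ-style high-bias absorption).

    bn_beta / bn_gamma: BatchNorm shift and scale of layer 1's pre-activation,
    used as its per-channel mean and standard deviation (an assumption).
    w2 is stored as rows, one per layer-2 output channel.
    """
    # c_i = max(0, beta_i - 3 * |gamma_i|): a value the pre-activation of channel i
    # is expected to exceed almost always if it is roughly Gaussian
    c = [max(0.0, beta - 3.0 * abs(gamma)) for beta, gamma in zip(bn_beta, bn_gamma)]
    new_b1 = [b - ci for b, ci in zip(b1, c)]  # b1' = b1 - c
    # b2' = b2 + W2 c
    new_b2 = [b + sum(w * ci for w, ci in zip(row, c)) for row, b in zip(w2, b2)]
    return new_b1, new_b2, c


def layer_pair(w1, b1, w2, b2, x):
    """y = W2 relu(W1 x + b1) + b2."""
    h = [max(0.0, sum(w * xi for w, xi in zip(row, x)) + b) for row, b in zip(w1, b1)]
    return [sum(w * hi for w, hi in zip(row, h)) + b for row, b in zip(w2, b2)]


# Usage: channel 0 has a large positive bias, channel 1 does not
w1 = [[1.0, 0.0], [0.0, 1.0]]
b1, bn_beta, bn_gamma = [5.0, 0.2], [5.0, 0.2], [1.0, 1.0]
w2, b2 = [[1.0, 2.0]], [0.0]
new_b1, new_b2, c = high_bias_fold(b1, w2, b2, bn_beta, bn_gamma)
print(c)  # [2.0, 0.0]
x = [0.5, 1.0]
print(layer_pair(w1, b1, w2, b2, x))          # original output
print(layer_pair(w1, new_b1, w2, new_b2, x))  # same value, up to float rounding
```

The rewrite is exact whenever every pre-activation of layer 1 is at least c, because then relu(z) = relu(z − c) + c. For inputs that push a channel below c, it is only an approximation, which is why the 3γ margin is used.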
API(s) for Cross Layer Equalization
aimet_tensorflow.cross_layer_equalization.equalize_model(sess, start_op_names, output_op_names)
High-level API to perform Cross-Layer Equalization (CLE) on the given model. The model is equalized in place.
- Parameters
sess (Session) – tf.compat.v1.Session with the model to equalize
start_op_names (Union[str, List[str]]) – Names of the starting ops in the given model
output_op_names (Union[str, List[str]]) – List of output op names of the model, used to help ConnectedGraph determine valid ops (for example, to ignore training ops).
- Return type
Session
- Returns
Updated session after BatchNorm folding, cross layer scaling and high bias folding.
Code Example
Required imports
# tensorflow
import tensorflow as tf
from tensorflow.keras.applications.resnet50 import ResNet50
# Cross layer Equalization related imports
from aimet_tensorflow.cross_layer_equalization import equalize_model
Cross Layer Equalization in auto mode (comprehensive)
def cross_layer_equalization_auto():
    """Perform cross layer equalization in auto mode."""
    # load a model
    tf.keras.backend.clear_session()
    _ = ResNet50(weights='imagenet', input_shape=(224, 224, 3))
    sess = tf.compat.v1.keras.backend.get_session()

    # names of the model's input and output ops, passed to the CLE API
    input_op_name = 'input_1'
    output_op_name = 'fc1000/Softmax'

    # Equalize a model with BatchNorms:
    # performs BatchNorm fold, replaces Relu6 with Relu, then cross layer scaling and high bias fold.
    # Use the new session returned for further evaluations on the TF graph.
    with sess.as_default():
        new_session = equalize_model(sess, input_op_name, output_op_name)

    sess.close()
    return new_session
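The comment in the example says that Relu6 is replaced with Relu. The source does not give the reason, but a likely one is that cross layer scaling relies on the activation commuting with a positive per-channel scale, act(x / s) * s == act(x). ReLU satisfies this. ReLU6 does not, because its clip at 6 moves once the input is rescaled. The pure-Python check below, with hypothetical helper names, demonstrates the difference.

```python
def relu(x):
    return max(0.0, x)


def relu6(x):
    return min(max(0.0, x), 6.0)


def commutes_with_scaling(act, xs, s):
    """True if act(x / s) * s == act(x) for every x in xs, i.e. dividing a
    channel by s before the activation and multiplying by s after it is a no-op."""
    return all(abs(act(x / s) * s - act(x)) < 1e-9 for x in xs)


xs = [-3.0, 0.5, 2.0, 5.0, 10.0]
print(commutes_with_scaling(relu, xs, 4.0))   # True
print(commutes_with_scaling(relu6, xs, 4.0))  # False: relu6(10 / 4) * 4 = 10, but relu6(10) = 6
```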
Primitive APIs
If the user would like to apply the techniques individually, the following APIs can be used: