AIMET TensorFlow BatchNorm Re-estimation APIs

Introduction

Batch Norm (BN) Re-estimation re-estimates the statistics of BN layers after performing QAT. Using the re-estimated statistics, the BN layers are folded into the preceding Conv and Linear layers.

Top-level APIs

API for BatchNorm Re-estimation

aimet_tensorflow.bn_reestimation.reestimate_bn_stats(sim, start_op_names, output_op_names, bn_re_estimation_dataset, bn_num_batches=100)[source]

Reestimate BatchNorm statistics (running mean and var).

Parameters
  • sim (QuantizationSimModel) – QuantizationSimModel object.

  • start_op_names (List[str]) – List of starting op names of the model

  • output_op_names (List[str]) – List of output op names of the model

  • bn_re_estimation_dataset (DatasetV1) – Training dataset to be used for re-estimation

  • bn_num_batches (int) – The number of batches to be used for re-estimation

Return type

Handle

Returns

Handle that undoes the effect of BN re-estimation upon handle.remove()
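
The returned handle can be used to roll back the re-estimated statistics. A minimal sketch of this pattern, assuming a quant_sim and the op name lists created as in the code example later on this page:


    from aimet_tensorflow.bn_reestimation import reestimate_bn_stats

    # Re-estimate BN statistics; keep the handle so the update can be undone
    handle = reestimate_bn_stats(quant_sim, start_op_names=input_op_names,
                                 output_op_names=output_op_names,
                                 bn_re_estimation_dataset=bn_re_estimation_dataset,
                                 bn_num_batches=100)

    # ... evaluate the model with the re-estimated statistics ...

    # Restore the original BN statistics if the re-estimated ones are not kept
    handle.remove()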

API for BatchNorm fold to scale

aimet_tensorflow.batch_norm_fold.fold_all_batch_norms_to_scale(sim, starting_op_names, output_op_names)[source]

Fold all batch_norm layers in a model into the quantization scale parameter of the corresponding conv layers.

Parameters
  • sim (QuantizationSimModel) – tf quantized model

  • starting_op_names (List[str]) – List of starting op names of the model

  • output_op_names (List[str]) – List of output op names of the model

Code Example - BN-Reestimation

Step 1. Load the model

For this example, we are going to load a pretrained ResNet50 model.


import tensorflow as tf

def load_fp32_model():

    from tensorflow.compat.v1.keras.applications.resnet import ResNet50

    tf.keras.backend.clear_session()
    model = ResNet50(weights='imagenet', input_shape=(224, 224, 3))
    sess = tf.keras.backend.get_session()

    # The following lines are additional steps to make the Keras model work with AIMET.
    from Examples.common import image_net_config
    from Examples.tensorflow.utils.add_computational_nodes_in_graph import add_image_net_computational_nodes_in_graph
    add_image_net_computational_nodes_in_graph(sess, model.output.name, image_net_config.dataset['images_classes'])

    input_op_names = [model.input.op.name]
    output_op_names = [model.output.op.name]

    return sess, input_op_names, output_op_names
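
The helper can then be called once to obtain the session and the op name lists used in the remaining steps:


    sess, input_op_names, output_op_names = load_fp32_model()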

Step 2. Create QuantSim with Range Learning and Per Channel Quantization Enabled

  1. For an example of creating QuantSim with Range Learning QuantScheme, please see here

  2. For how to enable Per Channel Quantization, please see here
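
The snippet below is a minimal sketch combining both settings. The config file path and its JSON contents are assumptions to adapt; per-channel quantization is enabled through the per_channel_quantization field of an AIMET quantsim config file.


    import json
    from aimet_common.defs import QuantScheme
    from aimet_tensorflow.quantsim import QuantizationSimModel

    # Assumed quantsim config enabling per-channel quantization (path is illustrative)
    per_channel_config = {
        "defaults": {
            "ops": {"is_output_quantized": "True"},
            "params": {"is_quantized": "True"},
            "per_channel_quantization": "True"
        },
        "params": {},
        "op_type": {},
        "supergroups": [],
        "model_input": {},
        "model_output": {}
    }
    with open('/tmp/per_channel_config.json', 'w') as f:
        json.dump(per_channel_config, f)

    # Create QuantSim with a Range Learning quant scheme and the per-channel config
    quant_sim = QuantizationSimModel(sess, input_op_names, output_op_names,
                                     quant_scheme=QuantScheme.training_range_learning_with_tf_init,
                                     default_output_bw=8, default_param_bw=8,
                                     config_file='/tmp/per_channel_config.json')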

Step 3. Perform QAT


    update_ops_name = [op.name for op in model.updates] # Used for fine-tuning

    # User action required
    # The following line of code is an example of how to use the example ImageNetDataPipeline's finetune function.
    # Replace the following line with your own pipeline's train function.
    ImageNetDataPipeline.finetune(quant_sim.session, update_ops_name=update_ops_name, epochs=1, learning_rate=5e-7, decay_steps=5)
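
Step 4 a below consumes a tf.compat.v1 Dataset named bn_re_estimation_dataset. A minimal sketch of how such a dataset could be built is shown here; the batch size and the random placeholder data are assumptions, and batches drawn from the real training set should be used instead.


    import numpy as np
    import tensorflow as tf

    # Placeholder data only; substitute images from the real training set
    batch_size = 32
    images = np.random.rand(batch_size * 4, 224, 224, 3).astype(np.float32)
    bn_re_estimation_dataset = tf.compat.v1.data.Dataset.from_tensor_slices(images).batch(batch_size)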

Step 4 a. Perform BatchNorm Re-estimation


    from aimet_tensorflow.bn_reestimation import reestimate_bn_stats

    reestimate_bn_stats(quant_sim, start_op_names=input_op_names, output_op_names=output_op_names,
                        bn_re_estimation_dataset=bn_re_estimation_dataset, bn_num_batches=100)

Step 4 b. Perform BatchNorm Fold to scale


    from aimet_tensorflow.batch_norm_fold import fold_all_batch_norms_to_scale

    fold_all_batch_norms_to_scale(quant_sim, input_op_names, output_op_names)

Step 5. Export the model and encodings and test on target

For how to export the model and encodings, please see here
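
A minimal sketch of the export call, assuming the quant_sim from the earlier steps; the output path and filename prefix are placeholders:


    # Export the quantized graph and the computed encodings for on-target inference
    quant_sim.export(path='/tmp', filename_prefix='resnet50_after_bn_reestimation')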