AIMET TensorFlow BatchNorm Re-estimation APIs
Examples Notebook Link
For an end-to-end notebook showing how to use Keras Quantization-Aware Training with BatchNorm Re-estimation, please see here.
Introduction
Batch Norm (BN) Re-estimation re-estimates the statistics of BN layers after performing QAT. Using the re-estimated statistics, the BN layers are then folded into the preceding Conv and Linear layers.
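To make "re-estimating" concrete, the sketch below recomputes one channel's BN statistics from activations observed after QAT. It averages the per-batch means and per-batch variances over at most `num_batches` batches. This is an assumption made for clarity: it is plain Python, `reestimate_channel_stats` is a hypothetical name, and AIMET's actual procedure may accumulate the statistics differently.

```python
from statistics import fmean
from typing import Iterable, List, Tuple


def reestimate_channel_stats(batches: Iterable[List[float]],
                             num_batches: int = 100) -> Tuple[float, float]:
    """Hypothetical helper: re-estimate one BN channel's mean and variance.

    Averages the per-batch mean and per-batch (biased) variance over at most
    `num_batches` batches of activations for that channel.
    """
    batch_means: List[float] = []
    batch_vars: List[float] = []
    for i, batch in enumerate(batches):
        if i >= num_batches:
            break  # only the first num_batches batches contribute
        mean = fmean(batch)
        batch_means.append(mean)
        batch_vars.append(fmean((x - mean) ** 2 for x in batch))
    if not batch_means:
        raise ValueError("at least one batch is required")
    return fmean(batch_means), fmean(batch_vars)


# Two batches of activations for a single channel.
print(reestimate_channel_stats([[1.0, 3.0], [2.0, 4.0]]))  # (2.5, 1.0)
# Only the first batch is used when num_batches=1.
print(reestimate_channel_stats([[1.0, 3.0], [2.0, 4.0]], num_batches=1))  # (2.0, 1.0)
```

Averaging per-batch variances ignores the spread between batch means, so the result can differ from the variance of all samples pooled together (1.25 for the four values above).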
Top-level APIs
API for BatchNorm Re-estimation
- aimet_tensorflow.bn_reestimation.reestimate_bn_stats(sim, start_op_names, output_op_names, dataset, num_batches=100)
Re-estimate BatchNorm statistics (running mean and variance).
- Parameters
sim (QuantizationSimModel) – QuantizationSimModel object.
start_op_names (List[str]) – List of starting op names of the model.
output_op_names (List[str]) – List of output op names of the model.
dataset (DatasetV1) – Training dataset.
num_batches (int) – The number of batches to be used for reestimation.
- Return type
Handle
- Returns
Handle that undoes the effect of BN re-estimation when handle.remove() is called.
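The returned Handle lets you roll back re-estimation. The sketch below mimics that documented contract on a plain dictionary of BN statistics: it overwrites the statistics and returns an object whose remove() restores the originals. `StatsHandle` and `apply_bn_stats` are hypothetical names, the keys `moving_mean`/`moving_variance` are only illustrative, and this is not AIMET's implementation.

```python
from typing import Callable, Dict

BNStats = Dict[str, Dict[str, float]]


class StatsHandle:
    """Hypothetical handle whose remove() undoes a stats update exactly once."""

    def __init__(self, undo: Callable[[], None]) -> None:
        self._undo = undo
        self._removed = False

    def remove(self) -> None:
        if not self._removed:
            self._undo()
            self._removed = True


def apply_bn_stats(bn_stats: BNStats, new_stats: BNStats) -> StatsHandle:
    """Overwrite BN statistics in place; return a handle that restores the originals."""
    # Snapshot (copy) the current statistics before they are overwritten.
    saved = {name: dict(values) for name, values in bn_stats.items()}
    for name, values in new_stats.items():
        bn_stats[name].update(values)

    def undo() -> None:
        for name, values in saved.items():
            bn_stats[name].clear()
            bn_stats[name].update(values)

    return StatsHandle(undo)


stats = {"bn1": {"moving_mean": 0.0, "moving_variance": 1.0}}
handle = apply_bn_stats(stats, {"bn1": {"moving_mean": 0.3, "moving_variance": 0.8}})
print(stats["bn1"]["moving_mean"])  # 0.3
handle.remove()
print(stats["bn1"]["moving_mean"])  # 0.0
```

Copying each layer's statistics into the snapshot keeps later in-place updates from corrupting the saved originals.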
API for BatchNorm fold to scale
- aimet_tensorflow.batch_norm_fold.fold_all_batch_norms_to_scale(sim, starting_op_names, output_op_names)
Fold all batch_norm layers in a model into the quantization scale parameter of the corresponding conv layers.
- Parameters
sim (QuantizationSimModel) – tf quantized model.
starting_op_names (List[str]) – List of starting op names of the model.
output_op_names (List[str]) – List of output op names of the model.
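Standard BN folding multiplies each output channel of the preceding layer's weights by k = gamma / sqrt(moving_variance + epsilon) and shifts the bias. With per-channel weight quantization, that multiplication can instead be absorbed by the channel's quantization scale, leaving the integer weights untouched, which is one way to read "fold to scale". The plain-Python sketch below illustrates this for a single channel under stated assumptions: symmetric signed 8-bit quantization, a positive fold factor, and hypothetical helper names. It is not AIMET's code.

```python
import math
from typing import List


def bn_fold_factor(gamma: float, moving_variance: float, epsilon: float = 1e-3) -> float:
    """Per-channel factor k = gamma / sqrt(moving_variance + epsilon) from standard BN folding."""
    return gamma / math.sqrt(moving_variance + epsilon)


def folded_bias(bias: float, beta: float, moving_mean: float, k: float) -> float:
    """Bias after folding BN: beta + (bias - moving_mean) * k."""
    return beta + (bias - moving_mean) * k


def quantize(values: List[float], scale: float, bitwidth: int = 8) -> List[int]:
    """Symmetric signed quantization: round to nearest, then clamp to the integer range."""
    qmax = 2 ** (bitwidth - 1) - 1
    qmin = -qmax - 1
    return [max(qmin, min(qmax, round(v / scale))) for v in values]


def dequantize(q: List[int], scale: float) -> List[float]:
    return [x * scale for x in q]


def fold_into_scale(scale: float, k: float) -> float:
    """Absorb the BN factor into the channel's scale; the integer weights stay the same."""
    if k <= 0:
        raise ValueError("this sketch only handles a positive fold factor")
    return scale * k


weights = [0.5, -1.25, 0.75, 2.0]  # one output channel of the preceding conv layer
scale = 0.25                       # that channel's quantization scale
k = bn_fold_factor(gamma=1.0, moving_variance=4.0, epsilon=0.0)  # 0.5

q = quantize(weights, scale)  # [2, -5, 3, 8]
new_scale = fold_into_scale(scale, k)
# The unchanged integers, dequantized with the folded scale, equal the quantized weights times k.
print(dequantize(q, new_scale))  # [0.25, -0.625, 0.375, 1.0]
print(folded_bias(bias=1.5, beta=0.25, moving_mean=0.5, k=k))  # 0.75
```

A negative gamma gives a negative k, and symmetric quantization needs a positive scale, so a real implementation must handle that case separately (for example by moving the sign into the integer weights); this sketch simply rejects it.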
Code Example - BN-Reestimation
Step 1. Load the model
For this example, we are going to load a pretrained ResNet50 model.
import tensorflow.compat.v1 as tf  # TF1-style session API, as used by the session-based AIMET TensorFlow flow
from Examples.common import image_net_config  # assumed location of the ImageNet config in the AIMET Examples

def load_fp32_model():
    from tensorflow.compat.v1.keras.applications.resnet import ResNet50

    tf.keras.backend.clear_session()
    model = ResNet50(weights='imagenet', input_shape=(224, 224, 3))
    sess = tf.keras.backend.get_session()

    # Following lines are additional steps to make the Keras model work with AIMET.
    from Examples.tensorflow.utils.add_computational_nodes_in_graph import add_image_net_computational_nodes_in_graph
    add_image_net_computational_nodes_in_graph(sess, model.output.name, image_net_config.dataset['images_classes'])

    input_op_names = [model.input.op.name]
    output_op_names = [model.output.op.name]
    return sess, input_op_names, output_op_names
Step 2. Create QuantSim with Range Learning and Per Channel Quantization Enabled
For an example of creating QuantSim with the Range Learning QuantScheme, please see here.
For how to enable Per Channel Quantization, please see here.
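Per-channel quantization in AIMET is typically switched on through the quantsim JSON configuration file. The sketch below writes such a file from Python. The key layout reflects our understanding of AIMET's default config format and should be checked against the configuration documentation referenced above; `write_quantsim_config` and the file name are hypothetical. The resulting path would then be passed to QuantizationSimModel (assumed here to be its `config_file` argument) alongside a range-learning quant scheme.

```python
import json
import os
import tempfile

# Assumed AIMET quantsim config layout; "per_channel_quantization" under
# "defaults" is the key assumed to enable per-channel weight quantization.
PER_CHANNEL_CONFIG = {
    "defaults": {
        "ops": {"is_output_quantized": "True"},
        "params": {"is_quantized": "True", "is_symmetric": "True"},
        "per_channel_quantization": "True",
    },
    "params": {},
    "op_type": {},
    "supergroups": [],
    "model_input": {"is_input_quantized": "True"},
    "model_output": {},
}


def write_quantsim_config(config: dict, directory: str) -> str:
    """Hypothetical helper: serialize a quantsim config dict and return its path."""
    path = os.path.join(directory, "per_channel_quantsim_config.json")
    with open(path, "w", encoding="utf-8") as f:
        json.dump(config, f, indent=4)
    return path


config_path = write_quantsim_config(PER_CHANNEL_CONFIG, tempfile.mkdtemp())
print(os.path.basename(config_path))  # per_channel_quantsim_config.json
```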
Step 3. Perform QAT
# `model` is the Keras model created in Step 1 (keep a reference to it, since load_fp32_model() does not return it)
update_ops_name = [op.name for op in model.updates]  # Used for finetuning

# User action required
# The following line of code is an example of how to use the example ImageNetDataPipeline's finetune function.
# Replace the following line with your own pipeline's train function.
ImageNetDataPipeline.finetune(quant_sim.session, update_ops_name=update_ops_name, epochs=1, learning_rate=5e-7, decay_steps=5)
Step 4 a. Perform BatchNorm Re-estimation
from aimet_tensorflow.bn_reestimation import reestimate_bn_stats

# bn_re_estimation_dataset: a dataset drawn from the training data (the `dataset` parameter above)
handle = reestimate_bn_stats(quant_sim, start_op_names=input_op_names, output_op_names=output_op_names,
                             dataset=bn_re_estimation_dataset, num_batches=100)
Step 4 b. Perform BatchNorm Fold to scale
from aimet_tensorflow.batch_norm_fold import fold_all_batch_norms_to_scale
fold_all_batch_norms_to_scale(quant_sim, input_op_names, output_op_names)
Step 5. Export the model and encodings and test on target
For how to export the model and encodings, please see here.