AIMET Keras BatchNorm Re-estimation APIs

Introduction

AIMET functionality for Keras BatchNorm Re-estimation recalculates the batchnorm statistics on the model after QAT has completed. The goal is for the model to learn its batchnorm statistics from stable outputs after QAT, rather than from the likely noisy outputs observed during QAT.

Top-level APIs

API for BatchNorm Re-estimation

aimet_tensorflow.keras.bn_reestimation.reestimate_bn_stats(model, bn_re_estimation_dataset, bn_num_batches=100)

Top-level API for the end user to call directly.

Parameters
  • model (Model) – tf.keras.Model

  • bn_re_estimation_dataset (DatasetV2) – Training dataset

  • bn_num_batches (int) – The number of batches to be used for reestimation

Return type

Handle

Returns

Handle that undoes the effect of BN re-estimation upon handle.remove()
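
The returned handle makes the re-estimation reversible. The following is a minimal sketch of that pattern; the QuantizationSimModel qsim and the tf.data.Dataset bn_dataset are assumed to have been created beforehand and are not part of this API.

from aimet_tensorflow.keras.bn_reestimation import reestimate_bn_stats

# Re-estimate BN statistics and keep the handle that is returned
# (qsim and bn_dataset are assumed names from an earlier QAT workflow)
handle = reestimate_bn_stats(qsim.model, bn_dataset, bn_num_batches=20)

# ... evaluate or export the model with the re-estimated statistics ...

# Roll back to the BatchNorm statistics the model had before re-estimation
handle.remove()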

API for BatchNorm fold to scale

aimet_tensorflow.keras.batch_norm_fold.fold_all_batch_norms_to_scale(sim)

Fold all batch_norm layers in a model into the quantization scale parameter of the corresponding conv layers

Parameters

sim (QuantizationSimModel) – QuantizationSimModel to be folded

Return type

List[Tuple[QcQuantizeWrapper, QcQuantizeWrapper]]

Returns

A list of pairs of layers [(Conv/Linear, BN layer that got folded)]
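
As an illustration of how the return value can be used, the sketch below folds the BatchNorm layers of a QuantizationSimModel (assumed here to be named qsim) and prints the layer pairs that were folded; reading the name attribute of the wrappers is an assumption based on their being Keras layers.

from aimet_tensorflow.keras.batch_norm_fold import fold_all_batch_norms_to_scale

# Fold every BatchNorm layer into the quantization scale of its corresponding conv/linear layer
folded_pairs = fold_all_batch_norms_to_scale(qsim)

# Each entry pairs a conv/linear wrapper with the BN wrapper that was folded into it
for conv_wrapper, bn_wrapper in folded_pairs:
    print(conv_wrapper.name, '<-', bn_wrapper.name)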

Code Example

Required imports

import tensorflow as tf

from aimet_tensorflow.keras.bn_reestimation import reestimate_bn_stats
from aimet_tensorflow.keras.batch_norm_fold import fold_all_batch_norms_to_scale

Prepare BatchNorm Re-estimation dataset

batch_size = 4
# Build a small re-estimation dataset from the training data (x_train is assumed to be available)
dataset = tf.data.Dataset.from_tensor_slices(x_train[0:100])
dataset = dataset.batch(batch_size=batch_size)
dummy_inputs = x_train[0:4]

Perform BatchNorm Re-estimation

# qsim is the QuantizationSimModel produced earlier in the QAT workflow
reestimate_bn_stats(qsim.model, dataset, 1)

Perform BatchNorm Fold to scale

# Fold the re-estimated BatchNorm layers into the quantization scale of the corresponding conv layers
fold_all_batch_norms_to_scale(qsim)

Limitations

Please see the limitations listed in the AIMET Keras ModelPreparer API documentation.