AIMET Keras BatchNorm Re-estimation APIs
Examples Notebook Link
For an end-to-end notebook showing how to use Keras Quantization-Aware Training with BatchNorm Re-estimation, please see here.
Introduction
AIMET's Keras BatchNorm Re-estimation recalculates a model's BatchNorm statistics (the moving mean and moving variance) after QAT. The goal is for the BatchNorm layers to use statistics computed from the stable outputs of the model after QAT, rather than from the likely noisy outputs produced during QAT.
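The idea can be illustrated with a minimal NumPy sketch. It assumes you have already collected the activations that feed a single BatchNorm layer from the post-QAT model, one array per batch, and it simply averages the per-channel batch mean and variance across batches. The helper name reestimate_bn_stats_sketch is hypothetical; this is not AIMET's implementation.

```python
import numpy as np


def reestimate_bn_stats_sketch(bn_inputs_per_batch):
    """Average per-channel batch statistics over several batches.

    Hypothetical helper (not part of AIMET). Each element of
    bn_inputs_per_batch is the activation tensor feeding one BatchNorm
    layer for one batch, shaped (batch, ..., channels) as in Keras
    (channels-last). Returns (moving_mean, moving_variance).
    """
    means, variances = [], []
    for x in bn_inputs_per_batch:
        x = np.asarray(x, dtype=np.float64)
        # Reduce over every axis except the trailing channel axis.
        axes = tuple(range(x.ndim - 1))
        means.append(x.mean(axis=axes))
        variances.append(x.var(axis=axes))
    # Plain average of the per-batch statistics across all batches.
    return np.mean(means, axis=0), np.mean(variances, axis=0)


# Usage: 100 batches of post-QAT activations with 16 channels.
rng = np.random.default_rng(0)
batches = [rng.normal(loc=2.0, scale=3.0, size=(4, 8, 8, 16)) for _ in range(100)]
moving_mean, moving_variance = reestimate_bn_stats_sketch(batches)
print(moving_mean.shape)  # (16,)
```

Averaging per-batch statistics mirrors how BatchNorm's moving statistics are built from batch statistics during training; the new values would then be written back into the layer's moving_mean and moving_variance.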
Top-level APIs
API for BatchNorm Re-estimation
aimet_tensorflow.keras.bn_reestimation.reestimate_bn_stats(model, bn_re_estimation_dataset, bn_num_batches=100)
Top-level API for end users to call directly.
- Parameters
  - model (Model) – tf.keras.Model
  - bn_re_estimation_dataset (DatasetV2) – Training dataset
  - bn_num_batches (int) – The number of batches to use for re-estimation
- Return type: Handle
- Returns: Handle that undoes the effect of BN re-estimation when handle.remove() is called
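To illustrate the return value, here is a pure-Python sketch of the save-and-restore pattern that handle.remove() implies. Plain dicts stand in for BatchNorm layers, and BNStatsHandle and apply_new_bn_stats are hypothetical names; AIMET's actual Handle class may store and restore state differently.

```python
class BNStatsHandle:
    """Hypothetical stand-in for the Handle returned by reestimate_bn_stats."""

    def __init__(self, bn_layers):
        self._bn_layers = bn_layers
        # Snapshot the statistics that are about to be overwritten.
        self._saved = [(layer["moving_mean"], layer["moving_variance"]) for layer in bn_layers]

    def remove(self):
        # Restore the snapshot, undoing the re-estimation.
        for layer, (mean, var) in zip(self._bn_layers, self._saved):
            layer["moving_mean"] = mean
            layer["moving_variance"] = var


def apply_new_bn_stats(bn_layers, new_stats):
    """Hypothetical helper: overwrite BN statistics and return an undo handle."""
    handle = BNStatsHandle(bn_layers)  # snapshot before mutating
    for layer, (mean, var) in zip(bn_layers, new_stats):
        layer["moving_mean"] = mean
        layer["moving_variance"] = var
    return handle


# Usage: a dict stands in for one BatchNorm layer.
bn = {"moving_mean": 0.0, "moving_variance": 1.0}
handle = apply_new_bn_stats([bn], [(0.5, 2.0)])
print(bn)  # {'moving_mean': 0.5, 'moving_variance': 2.0}
handle.remove()
print(bn)  # {'moving_mean': 0.0, 'moving_variance': 1.0}
```

With the real API, the equivalent is handle = reestimate_bn_stats(...) followed later by handle.remove(). If the statistics were mutable arrays updated in place, the snapshot would need to copy them rather than keep references.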
API for BatchNorm fold to scale
aimet_tensorflow.keras.batch_norm_fold.fold_all_batch_norms_to_scale(sim)
Fold all BatchNorm layers in a model into the quantization scale parameters of the corresponding conv layers.
- Parameters
  - sim (QuantizationSimModel) – QuantizationSimModel to be folded
- Return type: List[Tuple[QcQuantizeWrapper, QcQuantizeWrapper]]
- Returns: A list of pairs of layers [(Conv/Linear, BN layer that got folded)]
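The arithmetic behind BatchNorm folding can be checked with a small NumPy sketch. With k = gamma / sqrt(var + eps) per output channel, BN(x @ W + b) equals x @ (W * k) + (b - mean) * k + beta. The second check shows one plausible reading of folding "into the quantization scale": multiplying the per-channel weight scale by k leaves the integer weights unchanged. The helper names and the dense-layer setup are assumptions for illustration, not AIMET's implementation, and the sketch assumes positive gamma, since a negative k cannot be absorbed into a positive scale directly.

```python
import numpy as np


def fold_bn_into_kernel(kernel, bias, gamma, beta, mean, var, eps=1e-3):
    """Fold BN(x @ kernel + bias) into a single affine layer.

    Hypothetical helper. kernel's last axis is the output-channel axis
    (Keras Dense and Conv2D layout); BN parameters are per output channel.
    eps defaults to the Keras BatchNormalization default of 1e-3.
    Returns (folded_kernel, folded_bias, k), k being the per-channel factor.
    """
    k = gamma / np.sqrt(var + eps)  # per-output-channel multiplier
    folded_kernel = kernel * k      # broadcasts over the last axis
    folded_bias = (bias - mean) * k + beta
    return folded_kernel, folded_bias, k


def quantize(weights, scale):
    # Map weights onto the integer grid of a per-channel scale (no clamping, for brevity).
    return np.round(weights / scale)


rng = np.random.default_rng(0)
x = rng.normal(size=(5, 3))
scale = np.full(4, 0.1)                                 # per-channel weight scale
w_int = rng.integers(-8, 8, size=(3, 4)).astype(float)  # integer weights
kernel = scale * (w_int + 0.2)                          # float weights, away from rounding midpoints
bias = rng.normal(size=4)
gamma = rng.uniform(0.5, 2.0, size=4)                   # positive gammas (see assumption above)
beta = rng.normal(size=4)
mean = rng.normal(size=4)
var = rng.uniform(0.5, 2.0, size=4)

# Reference: dense layer followed by inference-mode BatchNorm.
y_ref = gamma * (x @ kernel + bias - mean) / np.sqrt(var + 1e-3) + beta
folded_kernel, folded_bias, k = fold_bn_into_kernel(kernel, bias, gamma, beta, mean, var)
y_fold = x @ folded_kernel + folded_bias
print(np.allclose(y_ref, y_fold))  # True

# Folding k into the scale instead leaves the integer weights unchanged.
print(np.array_equal(quantize(folded_kernel, scale * k), quantize(kernel, scale)))  # True
```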
Code Example
Required imports
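import tensorflow as tf
from aimet_tensorflow.keras.quantsim import QuantizationSimModel
from aimet_tensorflow.keras.bn_reestimation import reestimate_bn_stats
from aimet_tensorflow.keras.batch_norm_fold import fold_all_batch_norms_to_scale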
Prepare BatchNorm Re-estimation dataset
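# x_train: training inputs (e.g. a NumPy array), assumed to be loaded already
batch_size = 4
# Use the first 100 training samples, batched into 25 batches of 4
dataset = tf.data.Dataset.from_tensor_slices(x_train[0:100])
dataset = dataset.batch(batch_size=batch_size)
# A single batch of sample inputs (not used by the calls below)
dummy_inputs = x_train[0:4]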
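It can help to check how much data the re-estimation call actually sees. This pure-Python sketch mirrors the batching above: 100 samples in batches of 4 give 25 batches, so bn_num_batches=1 would use only the first 4 samples, assuming reestimate_bn_stats consumes batches in order from the start of the dataset (an assumption, not confirmed by this page). batch_indices is a hypothetical helper.

```python
def batch_indices(num_samples, batch_size, num_batches):
    """Hypothetical helper: sample indices in the first num_batches batches.

    Mirrors tf.data's batch(batch_size) without drop_remainder: consecutive,
    non-overlapping batches, with a possibly shorter final batch.
    """
    batches = [list(range(start, min(start + batch_size, num_samples)))
               for start in range(0, num_samples, batch_size)]
    return batches[:num_batches]


all_batches = batch_indices(100, 4, num_batches=100)
print(len(all_batches))                      # 25
print(batch_indices(100, 4, num_batches=1))  # [[0, 1, 2, 3]]
```

Raising bn_num_batches (the API default is 100) draws the statistics from more data, at the cost of a longer pass.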
Perform BatchNorm Re-estimation
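# qsim: a QuantizationSimModel that has already been through QAT (created earlier, not shown)
# Re-estimate BN statistics from 1 batch; keep the handle in case the change must be undone
handle = reestimate_bn_stats(qsim.model, dataset, bn_num_batches=1)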
Perform BatchNorm Fold to scale
# Fold the re-estimated BN layers into the quantization scales of their corresponding conv layers
fold_all_batch_norms_to_scale(qsim)
Limitations
Please see the AIMET Keras ModelPreparer API limitations.