AIMET PyTorch BatchNorm Re-estimation APIs


Batch Norm (BN) Re-estimation recomputes the statistics (running mean and variance) of BN layers after performing QAT. Using the re-estimated statistics, the BN layers are then folded into the preceding Conv and Linear layers.

Top-level APIs

API for BatchNorm Re-estimation

aimet_torch.bn_reestimation.reestimate_bn_stats(model, dataloader, num_batches=100, forward_fn=None)

Re-estimate BatchNorm statistics (running mean and variance).

Parameters:

  • model (Module) – Model whose BN stats are re-estimated.

  • dataloader (DataLoader) – Training dataset.

  • num_batches (int) – The number of batches to be used for re-estimation.

  • forward_fn (Optional[Callable[[Module, Any], Any]]) – Optional adapter function that performs a forward pass given a model and an input batch yielded from the data loader.

Return type:

Handle

Returns:

Handle that undoes the effect of BN re-estimation upon handle.remove().
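To make the description above concrete, here is a minimal pure-PyTorch sketch of what re-estimating BN statistics can involve. It is not AIMET's implementation. `reestimate_bn_stats_sketch` and `_RestoreHandle` are hypothetical names. The sketch assumes re-estimation means the following steps: reset each BN layer's running statistics, set `momentum=None` so that PyTorch keeps an equal-weight cumulative average, and run up to `num_batches` forward passes without gradients while non-BN layers stay in eval mode.

```python
import torch
import torch.nn as nn
from torch.nn.modules.batchnorm import _BatchNorm


class _RestoreHandle:
    """Hypothetical handle: remove() restores the BN stats saved before re-estimation."""

    def __init__(self, undo):
        self._undo = undo

    def remove(self):
        self._undo()


def reestimate_bn_stats_sketch(model, dataloader, num_batches=100, forward_fn=None):
    """Recompute BN running mean/var from `dataloader`; return a handle that undoes it."""
    bn_layers = [m for m in model.modules()
                 if isinstance(m, _BatchNorm) and m.track_running_stats]
    # Snapshot the state we are about to change so that handle.remove() can restore it
    saved = [(bn, bn.running_mean.clone(), bn.running_var.clone(),
              bn.num_batches_tracked.clone(), bn.momentum, bn.training)
             for bn in bn_layers]
    was_training = model.training
    # Default adapter: each batch is passed to the model as-is
    forward_fn = forward_fn or (lambda m, batch: m(batch))

    model.eval()                  # non-BN layers (e.g. dropout) run in inference mode
    for bn in bn_layers:
        bn.reset_running_stats()  # running_mean=0, running_var=1, counter=0
        bn.momentum = None        # None -> cumulative, equal-weight average over batches
        bn.train()                # BN only updates running stats in training mode
    with torch.no_grad():
        for i, batch in enumerate(dataloader):
            if i >= num_batches:
                break
            forward_fn(model, batch)

    # Restore the model's mode and each BN layer's mode and momentum; keep the new stats
    model.train(was_training)
    for bn, _, _, _, momentum, training in saved:
        bn.momentum = momentum
        bn.train(training)

    def undo():
        with torch.no_grad():
            for bn, mean, var, count, _, _ in saved:
                bn.running_mean.copy_(mean)
                bn.running_var.copy_(var)
                bn.num_batches_tracked.copy_(count)

    return _RestoreHandle(undo)
```

With `momentum=None`, every batch contributes equally. This makes the result an estimate for the final weights rather than a decaying history of earlier ones. Calling `handle.remove()` puts back the statistics that were present before the call.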

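API for BatchNorm fold to scale

aimet_torch.batch_norm_fold.fold_all_batch_norms_to_scale(sim)

Fold all BatchNorm layers in a model into the quantization scale parameter of the corresponding Conv/Linear layers.

Parameters:

  • sim (QuantizationSimModel) – QuantizationSimModel whose BN layers are folded.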

Return type:

List[Tuple[QcQuantizeWrapper, QcQuantizeWrapper]]

Returns:

A list of (Conv/Linear layer, folded BN layer) pairs.
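The idea behind folding BN into the quantization scale can be checked numerically. The sketch below is a conceptual NumPy illustration, not AIMET's code. It assumes symmetric per-output-channel weight quantization and a Linear layer followed by BN in inference mode. It also assumes γ > 0, so that the folded scales stay positive. The helper names are hypothetical. Inference-mode BN multiplies each output channel by γ/√(σ²+ε). That factor can be absorbed into the channel's quantization scale, and the BN shift can be absorbed into the bias, while the integer weights stay unchanged.

```python
import numpy as np


def quantize_per_channel(w, bitwidth=8):
    """Symmetric per-output-channel quantization of an (out, in) weight matrix.

    Returns the integer grid values and one scale per output channel."""
    qmax = 2 ** (bitwidth - 1) - 1
    scale = np.abs(w).max(axis=1) / qmax
    q = np.clip(np.round(w / scale[:, None]), -qmax - 1, qmax)
    return q, scale


def fold_bn_into_scale(scale, bias, gamma, beta, mean, var, eps=1e-5):
    """Absorb inference-mode BN into per-channel scales and the bias.

    The integer weights are left untouched; only scale and bias change."""
    factor = gamma / np.sqrt(var + eps)   # per-channel BN multiplier
    return scale * factor, (bias - mean) * factor + beta


# Check: quantized Linear followed by BN == quantized Linear with folded scale/bias
rng = np.random.default_rng(0)
w, b = rng.normal(size=(4, 6)), rng.normal(size=4)
gamma, beta = rng.uniform(0.5, 2.0, 4), rng.normal(size=4)   # gamma > 0 assumed
mean, var = rng.normal(size=4), rng.uniform(0.5, 2.0, 4)
x = rng.normal(size=(10, 6))

q, s = quantize_per_channel(w)
y = x @ (q * s[:, None]).T + b                         # quantized Linear
ref = gamma * (y - mean) / np.sqrt(var + 1e-5) + beta  # followed by BN

s_fold, b_fold = fold_bn_into_scale(s, b, gamma, beta, mean, var)
out = x @ (q * s_fold[:, None]).T + b_fold             # folded layer, same integers q
print(np.allclose(ref, out))  # True
```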

Code Example - BN Re-estimation

Step 1. Load the model

For this example, we are going to load a pretrained ResNet18 model from torchvision.

import torch

def load_fp32_model():

    from torchvision.models import resnet18
    from aimet_torch.model_preparer import prepare_model

    # Run on GPU when available, otherwise fall back to CPU
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    model = resnet18(pretrained=True).to(device)
    # Make the model compatible with AIMET features (e.g. functional ops -> modules)
    model = prepare_model(model)

    return model, use_cuda

Step 2. Create QuantSim with Range Learning and Per Channel Quantization Enabled

  1. For an example of creating QuantSim with Range Learning QuantScheme, please see here

  2. For how to enable Per Channel Quantization, please see here
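In AIMET, per-channel quantization is typically enabled through the QuantSim JSON config file. The snippet below sketches how to write such a file from Python. The exact keys are an assumption based on AIMET's default config format, so compare them with the config shipped with your AIMET release. `write_quantsim_config` is a hypothetical helper. The commented-out lines show how the file and the range-learning scheme `QuantScheme.training_range_learning_with_tf_init` would be passed to `QuantizationSimModel`. In those lines, `model`, `forward_pass_callback` and `forward_pass_callback_args` are placeholders you supply.

```python
import json

# Assumed minimal QuantSim config that enables per-channel quantization of
# parameters; verify the keys against your AIMET release's default config.
PER_CHANNEL_CONFIG = {
    "defaults": {
        "ops": {"is_output_quantized": "True"},
        "params": {"is_quantized": "True", "is_symmetric": "True"},
        "per_channel_quantization": "True",
    },
    "params": {"bias": {"is_quantized": "False"}},
    "op_type": {},
    "supergroups": [],
    "model_input": {"is_input_quantized": "True"},
    "model_output": {},
}


def write_quantsim_config(path):
    """Write the config dict as JSON and return the path (hypothetical helper)."""
    with open(path, "w") as f:
        json.dump(PER_CHANNEL_CONFIG, f, indent=4)
    return path


# Hedged usage with AIMET (not executed here):
# from aimet_common.defs import QuantScheme
# from aimet_torch.quantsim import QuantizationSimModel
# sim = QuantizationSimModel(model, dummy_input=torch.rand(1, 3, 224, 224),
#                            quant_scheme=QuantScheme.training_range_learning_with_tf_init,
#                            config_file=write_quantsim_config("per_channel_config.json"))
# sim.compute_encodings(forward_pass_callback, forward_pass_callback_args)
```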

Step 3. Perform QAT

    # User action required
    # The following line of code is an example of how to use an example ImageNetPipeline's train function.
    # Replace the following line with your own pipeline's train function.
    ImageNetPipeline.train(sim.model, epochs=1, learning_rate=5e-7, learning_rate_schedule=[5, 10], use_cuda=use_cuda)
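`ImageNetPipeline.train` is a placeholder from the example pipeline. If you do not have such a pipeline, a minimal fine-tuning loop like the following can stand in for it. This is a sketch under several assumptions. `train_qat` is a hypothetical name. The loader is assumed to yield `(images, labels)` pairs for classification. SGD with momentum is an arbitrary choice. `learning_rate_schedule` is interpreted as the epochs at which the learning rate drops by 10×.

```python
import torch
import torch.nn as nn


def train_qat(model, train_loader, epochs=1, learning_rate=5e-7,
              learning_rate_schedule=(5, 10), device="cpu"):
    """Minimal fine-tuning loop (hypothetical stand-in for ImageNetPipeline.train).

    Returns the mean training loss of the last epoch."""
    model.to(device).train()
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9)
    # Multiply the learning rate by 0.1 at each epoch listed in learning_rate_schedule
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=list(learning_rate_schedule), gamma=0.1)

    total, n = 0.0, 0
    for _ in range(epochs):
        total, n = 0.0, 0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            loss = criterion(model(images), labels)
            loss.backward()
            optimizer.step()
            total += loss.item() * labels.size(0)
            n += labels.size(0)
        scheduler.step()
    return total / max(n, 1)
```

A call such as `train_qat(sim.model, train_loader, epochs=1, learning_rate=5e-7, device="cuda" if use_cuda else "cpu")` mirrors the call in Step 3.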

Step 4 a. Perform BatchNorm Re-estimation

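    from aimet_torch.bn_reestimation import reestimate_bn_stats

    # Adapter passed as forward_fn: the train loader yields (images, labels),
    # so only the images are moved to the model's device and fed through it.
    def forward_fn(model, inputs):
        images, _ = inputs
        model(images.to(next(model.parameters()).device))

    # User action required
    # The following line of code is an example of how to use the ImageNet data's training data loader.
    # Replace the following line with your own dataset's training data loader.
    train_loader = ImageNetDataPipeline.get_train_dataloader()

    # Re-estimate the BN statistics of the QAT-trained model (100 batches by default)
    reestimate_bn_stats(sim.model, train_loader, forward_fn=forward_fn)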

Step 4 b. Perform BatchNorm Fold to scale

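    from aimet_torch.batch_norm_fold import fold_all_batch_norms_to_scale

    # Fold the re-estimated BN layers into the quantization scales of their preceding Conv/Linear layers
    fold_all_batch_norms_to_scale(sim)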

Step 5. Export the model and encodings and test on target

For how to export the model and encodings, please see here