AIMET PyTorch BatchNorm Re-estimation APIs


Batch Norm (BN) re-estimation recomputes the statistics of BN layers after QAT. Using the re-estimated statistics, the BN layers are then folded into the preceding Conv and Linear layers.

Top-level APIs

API for BatchNorm Re-estimation

aimet_torch.bn_reestimation.reestimate_bn_stats(model, dataloader, num_batches=100, forward_fn=None)

Re-estimate BatchNorm statistics (running mean and var).

  • model (Module) – Model whose BN statistics are re-estimated.

  • dataloader (DataLoader) – Data loader over the training dataset.

  • num_batches (int) – The number of batches to be used for re-estimation.

  • forward_fn (Optional[Callable[[Module, Any], Any]]) – Optional adapter function that performs a forward pass given a model and an input batch yielded from the data loader.

Return type:

Handle

Handle that undoes the effect of BN re-estimation upon handle.remove().
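To make the re-estimate-then-undo behaviour concrete, here is a minimal pure-Python sketch. It is not AIMET's implementation: ToyBatchNorm, RestoreHandle, and reestimate_toy_bn_stats are hypothetical names, the "layer" has a single channel, and the sketch assumes the new statistics are the plain average of per-batch means and unbiased per-batch variances over the first num_batches batches.

```python
from statistics import fmean, variance


class ToyBatchNorm:
    """Single-channel stand-in for a BN layer; holds only running statistics."""

    def __init__(self, running_mean=0.0, running_var=1.0):
        self.running_mean = running_mean
        self.running_var = running_var


class RestoreHandle:
    """Hypothetical handle: remove() puts the saved statistics back."""

    def __init__(self, bn, saved_mean, saved_var):
        self._bn = bn
        self._saved = (saved_mean, saved_var)

    def remove(self):
        self._bn.running_mean, self._bn.running_var = self._saved


def reestimate_toy_bn_stats(bn, batches, num_batches=100):
    """Replace running stats with the average of per-batch mean and variance."""
    handle = RestoreHandle(bn, bn.running_mean, bn.running_var)
    used = batches[:num_batches]
    bn.running_mean = fmean([fmean(b) for b in used])
    # Assumption: unbiased (n - 1) variance per batch, averaged across batches.
    bn.running_var = fmean([variance(b) for b in used])
    return handle


bn = ToyBatchNorm(running_mean=5.0, running_var=9.0)  # stale statistics
batches = [[1.0, 3.0], [2.0, 6.0], [0.0, 4.0]]

handle = reestimate_toy_bn_stats(bn, batches, num_batches=2)
reestimated = (bn.running_mean, bn.running_var)

handle.remove()  # undo the re-estimation
restored = (bn.running_mean, bn.running_var)
```

Only the first two batches are used because num_batches=2. Calling handle.remove() afterwards restores the statistics that were in place before re-estimation, which is the contract described for the returned handle.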

API for BatchNorm fold to scale


aimet_torch.batch_norm_fold.fold_all_batch_norms_to_scale(sim)

Fold all batch_norm layers in a model into the quantization scale parameter of the corresponding conv layers.

  • sim (QuantizationSimModel) – QuantizationSimModel

Return type:

List[Tuple[QcQuantizeWrapper, QcQuantizeWrapper]]


A list of pairs of layers [(Conv/Linear, BN layer that got folded)]
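The arithmetic behind folding a BN layer into a per-channel quantization scale can be sketched for a single output channel as follows. This is an illustration under stated assumptions, not AIMET's code: fold_bn_into_scale and quantize are hypothetical helpers, the quantizer is a simple symmetric round-to-nearest grid, and the BN multiplier gamma / sqrt(var + eps) is assumed positive.

```python
import math


def fold_bn_into_scale(weights, bias, delta, gamma, beta, mean, var, eps=1e-5):
    """Fold one output channel's BN into its weights, bias, and quantization scale."""
    s = gamma / math.sqrt(var + eps)   # BN multiplier for this channel
    folded_weights = [w * s for w in weights]
    folded_bias = (bias - mean) * s + beta
    folded_delta = delta * abs(s)      # the quantization step absorbs |s|
    return folded_weights, folded_bias, folded_delta


def quantize(weights, delta):
    """Map weights to integers on a symmetric grid with step delta."""
    return [round(w / delta) for w in weights]


# One output channel of a linear layer followed by BN (illustrative values).
weights, bias, delta = [0.5, -0.25, 0.125], 0.1, 0.125
gamma, beta, mean, var = 3.0, 0.3, 0.4, 4.0
x = [1.0, 2.0, 4.0]

folded_weights, folded_bias, folded_delta = fold_bn_into_scale(
    weights, bias, delta, gamma, beta, mean, var
)

# Original path: linear layer, then BN.
y = sum(w * xi for w, xi in zip(weights, x)) + bias
bn_out = gamma * (y - mean) / math.sqrt(var + 1e-5) + beta

# Folded path: a single linear layer with folded parameters.
folded_out = sum(w * xi for w, xi in zip(folded_weights, x)) + folded_bias

ints_before = quantize(weights, delta)
ints_after = quantize(folded_weights, folded_delta)
```

In this sketch the folded layer produces the same output as the layer followed by BN, and the integer weights are unchanged because the weights and the quantization step are rescaled by the same factor. With a negative multiplier the integers would flip sign rather than stay identical.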

Code Example - BN-Reestimation

Step 1. Load the model

For this example, we are going to load a pretrained ResNet18 model from torchvision.

def load_fp32_model():

    import torch
    from torchvision.models import resnet18
    from aimet_torch.model_preparer import prepare_model

    # Run on the GPU when one is available, otherwise fall back to the CPU.
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    # Load the pretrained weights, then let prepare_model rewrite the model for AIMET.
    model = resnet18(pretrained=True).to(device)
    model = prepare_model(model)

    return model, use_cuda

Step 2. Create QuantSim with Range Learning and Per Channel Quantization Enabled

Step 3. Perform QAT

    # User action required
    # The following line of code is an example of how to use an example ImageNetPipeline's train function.
    # Replace the following line with your own pipeline's train function.
    ImageNetPipeline.train(sim.model, epochs=1, learning_rate=5e-7, learning_rate_schedule=[5, 10], use_cuda=use_cuda)

Step 4 a. Perform BatchNorm Re-estimation

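    from aimet_torch.bn_reestimation import reestimate_bn_stats

    # User action required
    # The following line of code is an example of how to use the ImageNet data's training data loader.
    # Replace the following line with your own dataset's training data loader.
    train_loader = ImageNetDataPipeline.get_train_dataloader()

    # Adapter for batches of the form (images, labels); adjust it to your data loader's batch format.
    def forward_fn(model, batch):
        images, _ = batch
        if use_cuda:
            images = images.cuda()
        return model(images)

    reestimate_bn_stats(sim.model, train_loader, forward_fn=forward_fn)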

Step 4 b. Perform BatchNorm Fold to scale

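    from aimet_torch.batch_norm_fold import fold_all_batch_norms_to_scale

    # Fold the BN layers of the QuantizationSimModel into the quantization scale parameters.
    fold_all_batch_norms_to_scale(sim)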


Step 5. Export the model and encodings and test on target

