Batch norm re-estimation

Context

Batch norm re-estimation (BN re-estimation) uses a small subset of training data to re-estimate the statistics of the batch norm (BN) layers in a model. AIMET then folds the BN layers into the preceding convolution or linear layers.

BN re-estimation is recommended under the following conditions:

  • When batch norm folding (BNF) reduces performance

  • In models where the main issue is weight quantization

  • In quantization of depth-wise separable layers, as their batch norm statistics are sensitive to oscillations

Workflow

Prerequisites

To use BN re-estimation, you must:

  • Load a trained model

  • Create a training dataloader for the model

  • Hold off on folding the batch norm layers until after quantization aware training (QAT)

Execution

Setup

import torch
from torchvision.models import mobilenet_v2
from torch.utils.data import DataLoader
from datasets import load_dataset

# General setup that can be changed as needed
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = mobilenet_v2(pretrained=True).eval().to(device)
batch_size = 32
data = load_dataset("imagenet-1k", streaming=True, split="train")
data_loader = DataLoader(data, batch_size=batch_size, num_workers=4)
dummy_input = torch.randn(1, 3, 224, 224).to(device)
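Note that a streaming Hugging Face dataset yields dictionaries rather than the (image, label) tensor batches that the training loop below expects. A minimal collate sketch that bridges the two; the "image" and "label" keys and the preprocessing transform are assumptions about the dataset, not part of the original example:

from torchvision import transforms

# Standard ImageNet preprocessing; adjust to match the model's training recipe
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

def collate_fn(batch):
    # Each sample is assumed to be a dict with "image" (PIL) and "label" (int)
    images = torch.stack([preprocess(s["image"].convert("RGB")) for s in batch])
    labels = torch.tensor([s["label"] for s in batch])
    return images, labels

data_loader = DataLoader(data, batch_size=batch_size, num_workers=4, collate_fn=collate_fn)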


Step 1

Create the quantization simulation model (QuantizationSimModel).

When creating the QuantizationSimModel, ensure that per-channel quantization is enabled. Update the config file if needed.

from aimet_torch.quantsim import QuantizationSimModel
from aimet_common.quantsim_config.utils import get_path_for_per_channel_config

sim = QuantizationSimModel(model=model, dummy_input=dummy_input, config_file=get_path_for_per_channel_config())
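Before training, the quantizers need initial encodings from a calibration pass over a few batches. A minimal sketch, assuming the callback-style compute_encodings API and the data loader defined above (the batch count of 16 is an arbitrary choice):

def pass_calibration_data(model, num_batches):
    # Run a few batches through the model so the quantizers can observe tensor ranges
    model.eval()
    with torch.no_grad():
        for i, (images, _) in enumerate(data_loader):
            if i >= num_batches:
                break
            model(images.to(device))

sim.compute_encodings(pass_calibration_data, forward_pass_callback_args=16)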


Step 2

Perform quantization-aware training (QAT).

QAT involves training the model for a few additional epochs (usually 15-20). Pay attention to the hyperparameters used during training; QAT is typically run with a smaller learning rate than the original training.

num_epochs = 20
# Optimize the parameters of the simulated model, not the original model
optimizer = torch.optim.SGD(sim.model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5)
loss_fn = torch.nn.CrossEntropyLoss()

sim.model.train()
for epoch in range(num_epochs):
    for images, labels in data_loader:
        images, labels = images.to(device), labels.to(device)
        output = sim.model(images)
        loss = loss_fn(output, labels)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
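The accuracies reported before and after re-estimation come from a standard top-1 evaluation. A minimal sketch of such an evaluation loop; val_loader is a hypothetical validation data loader, not defined in this example:

def evaluate(model, loader, max_batches=None):
    # Compute top-1 accuracy over a data loader (or its first max_batches batches)
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for i, (images, labels) in enumerate(loader):
            if max_batches is not None and i >= max_batches:
                break
            preds = model(images.to(device)).argmax(dim=1)
            correct += (preds == labels.to(device)).sum().item()
            total += labels.numel()
    return correct / total

accuracy = evaluate(sim.model, val_loader)  # val_loader: hypothetical validation loader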

Model accuracy before BN re-estimation: 0.0428


Step 3

Re-estimate the BN statistics and fold the BN layers.

from aimet_torch.bn_reestimation import reestimate_bn_stats
from aimet_torch.batch_norm_fold import fold_all_batch_norms_to_scale

reestimate_bn_stats(sim.model, data_loader)
fold_all_batch_norms_to_scale(sim)

Model accuracy after BN re-estimation: 0.5876
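reestimate_bn_stats returns a handle that can restore the original BN statistics (see the API below). If re-estimation hurts accuracy, the change can be rolled back before folding; a short sketch:

handle = reestimate_bn_stats(sim.model, data_loader, num_batches=100)

# If the re-estimated statistics turn out to be worse, undo them before folding
# handle.remove()

fold_all_batch_norms_to_scale(sim)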


Step 4

If BN re-estimation resulted in satisfactory accuracy, export the model.

path = './'
filename = 'mobilenet'
sim.export(path=path, filename_prefix=filename, dummy_input=dummy_input.cpu())
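The export call writes a copy of the model with quantization operations removed along with a JSON file of quantization encodings; the exact artifacts (for example, an ONNX model) depend on the AIMET version and export options.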


API

Top-level API

aimet_torch.bn_reestimation.reestimate_bn_stats(model, dataloader, num_batches=100, forward_fn=None)

Reestimate BatchNorm statistics (running mean and var).

Parameters:
  • model (Module) – Model whose BN statistics are to be re-estimated.

  • dataloader (DataLoader) – Training dataset.

  • num_batches (int) – The number of batches to be used for reestimation.

  • forward_fn (Optional[Callable[[Module, Any], Any]]) – Optional adapter function that performs a forward pass given a model and an input batch yielded from the data loader.

Return type:

Handle

Returns:

Handle that undoes the effect of BN re-estimation upon handle.remove().
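When the data loader yields something other than a bare input batch, forward_fn adapts it. A minimal sketch, assuming the loader yields (images, labels) tuples as in the setup above:

def forward_fn(model, batch):
    # Unpack the batch from the data loader and run a forward pass;
    # only the inputs are needed for re-estimating BN statistics
    images, _ = batch
    model(images.to(device))

handle = reestimate_bn_stats(sim.model, data_loader, num_batches=100, forward_fn=forward_fn)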
