Batch norm re-estimation¶
Context¶
Batch norm re-estimation (BN re-estimation) uses a small subset of training data to re-estimate the statistics of the batch norm (BN) layers in a model; a conceptual sketch of this procedure follows the list below. AIMET then folds the BN layers into the preceding convolution or linear layers.
BN re-estimation is recommended under the following conditions:
When batch norm folding (BNF) degrades quantized accuracy
In models where the main issue is weight quantization
In quantization of depth-wise separable layers, as their batch norm statistics are sensitive to oscillations
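Conceptually, re-estimating BN statistics means running a small number of training batches through the model with only the BN layers in training mode, so that their running mean and variance are recomputed while all weights stay fixed. The following is a minimal sketch of that idea in plain PyTorch; it is not the AIMET implementation (in practice, use reestimate_bn_stats as shown in the workflow below):
import torch

@torch.no_grad()
def sketch_reestimate_bn(model, data_loader, num_batches=100):
    # Put only the BN layers in training mode so forward passes update
    # their running statistics; no weights are modified.
    model.eval()
    for module in model.modules():
        if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
            module.reset_running_stats()
            module.momentum = None  # None = cumulative moving average
            module.train()
    for i, (images, _) in enumerate(data_loader):
        if i >= num_batches:
            break
        model(images)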
Workflow¶
Prerequisites¶
To use BN re-estimation, you must:
Load a trained model
Create a training dataloader for the model
Hold off on folding the batch norm layers until after quantization-aware training (QAT)
Execution¶
Setup¶
import torch
from torchvision.models import mobilenet_v2
from torch.utils.data import DataLoader
from datasets import load_dataset
# General setup that can be changed as needed
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = mobilenet_v2(pretrained=True).eval().to(device)
# Note: the streaming dataset is assumed to be mapped to (image, label)
# tensor batches (for example, via a transform or collate_fn); that
# plumbing is omitted here for brevity.
batch_size = 32
data = load_dataset('imagenet-1k', streaming=True, split="train")
data_loader = DataLoader(data, batch_size=batch_size, num_workers=4)
dummy_input = torch.randn(1, 3, 224, 224).to(device)
Step 1¶
Create the quantization simulation model (QuantizationSimModel).
When creating the QuantizationSimModel, ensure that per-channel quantization is enabled. Update the config file if needed.
from aimet_torch.quantsim import QuantizationSimModel
from aimet_common.quantsim_config.utils import get_path_for_per_channel_config
sim = QuantizationSimModel(model=model, dummy_input=dummy_input, config_file=get_path_for_per_channel_config())
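Note that the quantizers in sim need initial encodings before the simulated model can be trained or evaluated. If your workflow has not computed them yet, a calibration pass along the following lines is typically required first. This is a minimal sketch; the callback and the number of calibration batches are assumptions, not part of the original workflow on this page.
# Assumed calibration step: run a few representative batches through
# the model so AIMET can compute initial quantization encodings.
def forward_pass(model, _):
    with torch.no_grad():
        for i, (images, _) in enumerate(data_loader):
            model(images.to(device))
            if i >= 10:
                break

sim.compute_encodings(forward_pass, forward_pass_callback_args=None)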
Step 2¶
Perform quantization-aware training (QAT).
QAT involves training the model for a few additional epochs (usually 15-20). Pay close attention to the hyperparameters used during training.
num_epochs = 20
# Optimize the parameters of the simulated model, sim.model,
# since that is the model being trained.
optimizer = torch.optim.SGD(sim.model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5)
loss_fn = torch.nn.CrossEntropyLoss()

for epoch in range(num_epochs):
    for images, labels in data_loader:
        images, labels = images.to(device), labels.to(device)
        output = sim.model(images)
        loss = loss_fn(output, labels)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
Model accuracy before BN re-estimation: 0.0428
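The accuracy figures quoted before and after BN re-estimation come from a standard top-1 evaluation pass. A minimal sketch of such a helper is shown below; eval_loader is a hypothetical labeled validation data loader that is not defined in the setup above.
# Hypothetical top-1 accuracy helper; eval_loader is assumed to yield
# (images, labels) batches from a labeled validation set.
@torch.no_grad()
def evaluate(model, eval_loader):
    model.eval()
    correct = total = 0
    for images, labels in eval_loader:
        images, labels = images.to(device), labels.to(device)
        predictions = model(images).argmax(dim=1)
        correct += (predictions == labels).sum().item()
        total += labels.numel()
    return correct / total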
Step 3¶
Re-estimate the BN statistics and fold the BN layers.
from aimet_torch.bn_reestimation import reestimate_bn_stats
from aimet_torch.batch_norm_fold import fold_all_batch_norms_to_scale
reestimate_bn_stats(sim.model, data_loader)
fold_all_batch_norms_to_scale(sim)
Model accuracy after BN re-estimation: 0.5876
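reestimate_bn_stats returns a handle that undoes the re-estimation on handle.remove() (see the API section below). This is useful for comparing accuracy with and without the re-estimated statistics before committing to the fold:
# Sketch: keep the handle to optionally roll back the re-estimation.
handle = reestimate_bn_stats(sim.model, data_loader)
# ... evaluate the model with the re-estimated statistics ...
handle.remove()  # restores the BN statistics captured before re-estimation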
Step 4¶
If BN re-estimation resulted in satisfactory accuracy, export the model.
path = './'
filename = 'mobilenet'
sim.export(path=path, filename_prefix=filename, dummy_input=dummy_input.cpu())
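Depending on the AIMET version and export options, sim.export typically writes a model file (for example, ONNX) together with a JSON encodings file that records the quantization parameters for each layer.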
API¶
Top-level API
- aimet_torch.bn_reestimation.reestimate_bn_stats(model, dataloader, num_batches=100, forward_fn=None)[source]¶
Reestimate BatchNorm statistics (running mean and var).
- Parameters:
  - model (Module) – Model to reestimate the BN stats for.
  - dataloader (DataLoader) – Training dataset.
  - num_batches (int) – The number of batches to be used for reestimation.
  - forward_fn (Optional[Callable[[Module, Any], Any]]) – Optional adapter function that performs a forward pass given a model and an input batch yielded from the data loader.
- Return type:
  Handle
- Returns:
  Handle that undoes the effect of BN reestimation upon handle.remove().
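If the data loader yields something other than plain input batches, use the forward_fn hook to adapt the forward pass. A minimal sketch, assuming a hypothetical data loader that yields dictionaries with an "image" key:
# Hypothetical adapter: the data loader is assumed to yield dicts
# with an "image" key rather than (images, labels) tuples.
def dict_forward_fn(model, batch):
    return model(batch["image"].to(device))

reestimate_bn_stats(sim.model, data_loader, num_batches=100, forward_fn=dict_forward_fn)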