AIMET PyTorch BatchNorm Re-estimation APIs
Examples Notebook Link
For an end-to-end notebook showing how to use PyTorch Quantization-Aware Training followed by BatchNorm Re-estimation, please see here.
Introduction
Batch Norm (BN) Re-estimation re-estimates the statistics (running mean and variance) of BN layers after performing Quantization-Aware Training (QAT). Using the re-estimated statistics, the BN layers are then folded into the preceding Conv and Linear layers.
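The idea behind re-estimation can be sketched in plain PyTorch: reset each BN layer's running statistics, put only the BN layers in training mode so they track batch statistics, and run a fixed number of batches through the model without gradients. Setting `momentum=None` makes PyTorch keep a cumulative average over all batches instead of an exponential moving average. The helper `reestimate_bn_stats_sketch` is hypothetical and only illustrates the concept; it is not AIMET's implementation, which you should call through `reestimate_bn_stats`.

```python
import itertools

import torch
from torch import nn


def reestimate_bn_stats_sketch(model, batches, num_batches=100, forward_fn=None):
    """Hypothetical helper: recompute BN running mean/var using the current weights."""
    forward_fn = forward_fn or (lambda m, batch: m(batch))
    bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)
    bn_layers = [m for m in model.modules() if isinstance(m, bn_types)]
    was_training = model.training
    saved_momentum = {bn: bn.momentum for bn in bn_layers}

    model.eval()                   # dropout etc. stay in inference mode
    for bn in bn_layers:
        bn.reset_running_stats()   # mean=0, var=1, num_batches_tracked=0
        bn.momentum = None         # None -> cumulative average over all batches seen
        bn.train()                 # BN uses batch stats and updates its running stats

    with torch.no_grad():
        for batch in itertools.islice(batches, num_batches):
            forward_fn(model, batch)

    for bn in bn_layers:           # restore the original momentum
        bn.momentum = saved_momentum[bn]
    model.train(was_training)      # restore the original train/eval mode


# Usage on synthetic data (illustrative only)
torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4))
data = torch.randn(64, 4) * 3 + 1
reestimate_bn_stats_sketch(model, data.split(16), num_batches=4)
```

With equally sized batches, the cumulative average of the batch means equals the mean over all the data, so the re-estimated `running_mean` matches the statistics of the current (post-QAT) weights rather than those accumulated during training.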
Top-level APIs
API for BatchNorm Re-estimation
- aimet_torch.bn_reestimation.reestimate_bn_stats(model, dataloader, num_batches=100, forward_fn=None)
Re-estimate BatchNorm statistics (running mean and variance).
- Parameters:
  - model (Module) – Model whose BN statistics are re-estimated.
  - dataloader (DataLoader) – Training dataset.
  - num_batches (int) – Number of batches to use for re-estimation.
  - forward_fn (Optional[Callable[[Module, Any], Any]]) – Optional adapter function that performs a forward pass given a model and an input batch yielded from the data loader.
- Return type: Handle
- Returns: Handle that undoes the effect of BN re-estimation upon handle.remove().
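The docstring states that the returned handle reverts the re-estimation when `handle.remove()` is called. One plausible way such a handle can work is to snapshot the BN buffers before re-estimation and copy them back on `remove()`. `BNStatsHandle` below is a hypothetical illustration of that behavior in plain PyTorch; AIMET's internal handle may be implemented differently.

```python
import torch
from torch import nn


class BNStatsHandle:
    """Hypothetical undo handle: snapshots BN buffers and restores them on remove()."""

    _BN_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)

    def __init__(self, model: nn.Module):
        self._saved = {}
        for module in model.modules():
            if isinstance(module, self._BN_TYPES):
                self._saved[module] = {
                    "running_mean": module.running_mean.clone(),
                    "running_var": module.running_var.clone(),
                    "num_batches_tracked": module.num_batches_tracked.clone(),
                }

    def remove(self) -> None:
        # Copy the snapshots back in place so existing buffer references stay valid
        with torch.no_grad():
            for module, buffers in self._saved.items():
                module.running_mean.copy_(buffers["running_mean"])
                module.running_var.copy_(buffers["running_var"])
                module.num_batches_tracked.copy_(buffers["num_batches_tracked"])


# With AIMET, per the docstring above (not executed here):
#   handle = reestimate_bn_stats(model, train_loader, num_batches=100, forward_fn=forward_fn)
#   handle.remove()  # revert to the BN statistics from before re-estimation

torch.manual_seed(0)
bn = nn.BatchNorm1d(3)
model = nn.Sequential(bn)
handle = BNStatsHandle(model)
model(torch.randn(8, 3) + 5)  # a training-mode forward pass updates the running stats
handle.remove()               # running stats return to their initial values
```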
API for BatchNorm fold to scale
- aimet_torch.batch_norm_fold.fold_all_batch_norms_to_scale(sim)
Fold all BatchNorm layers in a model into the quantization scale parameter of the corresponding Conv/Linear layers.
- Parameters:
  - sim (QuantizationSimModel) – QuantizationSimModel whose BN layers are folded.
- Return type: List[Tuple[QcQuantizeWrapper, QcQuantizeWrapper]]
- Returns: A list of pairs of layers [(Conv/Linear, BN layer that got folded)].
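For reference, the standard inference-time BN fold for a preceding Linear layer scales each output channel c by k_c = gamma_c / sqrt(var_c + eps) and adjusts the bias to k_c * (b_c - mean_c) + beta_c; for a Conv layer the same factor applies to each output filter. Because the fold uses the BN running mean and variance, those statistics need to be accurate, which is why re-estimation comes first. The pure-Python helpers below are hypothetical and fold into float weights. Per its docstring, `fold_all_batch_norms_to_scale` instead folds into the quantization scale parameter; we assume the per-channel factor k_c is what gets absorbed there, but the exact mechanics are AIMET's.

```python
import math


def fold_bn_into_linear(weight, bias, gamma, beta, mean, var, eps=1e-5):
    """Return (weight, bias) of a Linear layer with the following BN folded in.

    weight is a list of rows (one per output channel); the other arguments are
    per-output-channel lists.
    """
    folded_w, folded_b = [], []
    for row, b, g, bt, mu, v in zip(weight, bias, gamma, beta, mean, var):
        k = g / math.sqrt(v + eps)          # per-output-channel scale factor
        folded_w.append([k * w for w in row])
        folded_b.append(k * (b - mu) + bt)
    return folded_w, folded_b


def linear(weight, bias, x):
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]


def batch_norm(y, gamma, beta, mean, var, eps=1e-5):
    # Inference-mode BN: normalize with the running statistics
    return [g * (yc - mu) / math.sqrt(v + eps) + bt
            for yc, g, bt, mu, v in zip(y, gamma, beta, mean, var)]


W = [[1.0, -2.0], [0.5, 3.0]]
b = [0.1, -0.2]
gamma, beta = [2.0, 0.5], [0.3, -1.0]
mean, var = [0.4, -0.6], [1.5, 0.25]
x = [0.7, -1.2]

unfolded = batch_norm(linear(W, b, x), gamma, beta, mean, var)
fW, fb = fold_bn_into_linear(W, b, gamma, beta, mean, var)
folded = linear(fW, fb, x)  # same output as Linear followed by BN
```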
Code Example - BN-Reestimation
Step 1. Load the model
For this example, we are going to load a pretrained ResNet18 model from torchvision.
import torch
from torchvision.models import resnet18
from aimet_torch.model_preparer import prepare_model


def load_fp32_model():
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    # Pretrained ImageNet weights (newer torchvision versions prefer weights=... over pretrained=True)
    model = resnet18(pretrained=True).to(device)
    # Let AIMET's model preparer rewrite the model into a form QuantSim can wrap
    model = prepare_model(model)
    return model, use_cuda
Step 2. Create QuantSim with Range Learning and Per Channel Quantization Enabled
For an example of creating QuantSim with the Range Learning QuantScheme, please see here.
For how to enable Per Channel Quantization, please see here.
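As a rough sketch of this step, assuming AIMET's JSON quantsim config format and the `QuantScheme.training_range_learning_with_tf_init` scheme: per-channel quantization is switched on through the config file's `per_channel_quantization` default, and range learning through the quant scheme passed to `QuantizationSimModel`. The helper names `write_per_channel_config` and `create_range_learning_sim` and the `calibration_fn` callback are hypothetical, and the config keys are an assumption to check against the linked pages for your AIMET version. AIMET is imported lazily so the config helper runs on its own.

```python
import json
import os
import tempfile


def write_per_channel_config(path: str) -> str:
    """Hypothetical helper: write a minimal quantsim config enabling per-channel quantization.

    The schema is assumed from AIMET's JSON quantsim config files; verify it for your version.
    """
    config = {
        "defaults": {
            "ops": {"is_output_quantized": "True"},
            "params": {"is_quantized": "True", "is_symmetric": "True"},
            "per_channel_quantization": "True",
        },
        "params": {"bias": {"is_quantized": "False"}},
        "op_type": {},
        "supergroups": [],
        "model_input": {"is_input_quantized": "True"},
        "model_output": {},
    }
    with open(path, "w") as f:
        json.dump(config, f, indent=4)
    return path


def create_range_learning_sim(model, dummy_input, config_path, calibration_fn):
    """Hypothetical helper: build a QuantSim with a range-learning scheme (needs AIMET)."""
    from aimet_common.defs import QuantScheme
    from aimet_torch.quantsim import QuantizationSimModel

    quant_sim = QuantizationSimModel(
        model,
        dummy_input,
        quant_scheme=QuantScheme.training_range_learning_with_tf_init,
        default_param_bw=8,
        default_output_bw=8,
        config_file=config_path,
    )
    # calibration_fn(model, args) is a hypothetical callback that runs
    # representative data through the model to initialize the encodings
    quant_sim.compute_encodings(calibration_fn, None)
    return quant_sim


config_path = write_per_channel_config(
    os.path.join(tempfile.gettempdir(), "per_channel_config.json"))
# With AIMET installed (not executed here):
#   quant_sim = create_range_learning_sim(model, torch.rand(1, 3, 224, 224), config_path, calibration_fn)
```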
Step 3. Perform QAT
# User action required
# The following line shows how to call the train function of an example ImageNetPipeline.
# Replace it with your own pipeline's train function.
ImageNetPipeline.train(quant_sim.model, epochs=1, learning_rate=5e-7, learning_rate_schedule=[5, 10], use_cuda=use_cuda)
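If you do not have a pipeline like `ImageNetPipeline`, a plain PyTorch fine-tuning loop over `quant_sim.model` can stand in for it. `train_qat` below is a hypothetical sketch: it uses SGD with cross-entropy, omits the learning-rate schedule, and assumes the loader yields `(images, labels)` pairs. It also assumes that, with a range-learning scheme, the quantizer encodings are among `model.parameters()` and are therefore updated together with the weights.

```python
import torch
from torch import nn


def train_qat(model: nn.Module, loader, epochs: int = 1, learning_rate: float = 5e-7,
              device: str = "cpu") -> float:
    """Hypothetical stand-in for a pipeline's train function; returns the last batch loss."""
    model.to(device)
    model.train()
    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9)
    loss_fn = nn.CrossEntropyLoss()
    last_loss = float("nan")
    for _ in range(epochs):
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            loss = loss_fn(model(images), labels)
            loss.backward()
            optimizer.step()
            last_loss = loss.item()
    return last_loss


# Usage on a toy model; with AIMET you would pass quant_sim.model and your training loader
torch.manual_seed(0)
toy = nn.Linear(4, 3)
xs, ys = torch.randn(32, 4), torch.randint(0, 3, (32,))
toy_loader = list(zip(xs.split(8), ys.split(8)))
before = toy.weight.detach().clone()
final_loss = train_qat(toy, toy_loader, epochs=2, learning_rate=0.1)
```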
Step 4 a. Perform BatchNorm Re-estimation
from aimet_torch.bn_reestimation import reestimate_bn_stats

def forward_fn(model, inputs):
    # Assumes the loader yields (images, labels); only the images are fed to the model
    images, _ = inputs
    model(images.to(next(model.parameters()).device))

# User action required
# The following line shows how to use the ImageNet data's training data loader.
# Replace it with your own dataset's training data loader.
train_loader = ImageNetDataPipeline.get_train_dataloader()
reestimate_bn_stats(quant_sim.model, train_loader, num_batches=100, forward_fn=forward_fn)
Step 4 b. Perform BatchNorm Fold to scale
from aimet_torch.batch_norm_fold import fold_all_batch_norms_to_scale
fold_all_batch_norms_to_scale(quant_sim)
Step 5. Export the model and encodings and test on target
For how to export the model and encodings, please see here.