Batch norm re-estimation¶
Context¶
If applying batch norm folding to your model degrades accuracy, the batch norm re-estimation feature may help. This feature uses a small subset of training data to re-estimate the statistics (running mean and variance) of the batch norm (BN) layers in a model. Using the re-estimated statistics, the BN layers are then folded into the preceding convolution or linear layers; a conceptual sketch follows the list below.
BN re-estimation is also recommended in the following cases:
Models where the main issue is weight quantization
Quantization of depthwise separable layers, as their batch norm statistics are sensitive to oscillations
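To make the mechanism concrete, here is a minimal PyTorch sketch of what re-estimation amounts to conceptually. This is an illustration only, not the AIMET API: a few training batches are run through the model in train mode so the BN layers refresh their running statistics, without updating any weights.
import torch

@torch.no_grad()
def reestimate_bn_conceptually(model, data_loader, num_batches=100):
    # Illustrative only: in train mode, BatchNorm layers refresh their
    # running mean/variance from each incoming batch.
    model.train()
    for i, (images, _) in enumerate(data_loader):
        if i >= num_batches:
            break
        model(images)  # the forward pass alone updates BN running stats
    model.eval()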
Workflow¶
Prerequisites¶
To use BN re-estimation, you must:
Load a trained model
Create a training dataloader for the model
Hold off on folding the batch norm layers until after quantization aware training (QAT)
Setup¶
PyTorch
import torch
from torchvision.models import mobilenet_v2
from torch.utils.data import DataLoader
from datasets import load_dataset
# General setup that can be changed as needed
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = mobilenet_v2(pretrained=True).eval().to(device)
batch_size = 32
data = load_dataset('imagenet-1k', streaming=True, split="train")
data_loader = DataLoader(data, batch_size=batch_size, num_workers=4)
dummy_input = torch.randn(1, 3, 224, 224).to(device)
TensorFlow
from aimet_common.defs import QuantScheme
from aimet_common.quantsim_config.utils import get_path_for_per_channel_config
from aimet_tensorflow.keras.batch_norm_fold import fold_all_batch_norms_to_scale
from aimet_tensorflow.keras.bn_reestimation import reestimate_bn_stats
from aimet_tensorflow.keras.quantsim import QuantizationSimModel
from tensorflow.keras import applications, losses, metrics, optimizers, preprocessing
from tensorflow.keras.applications import mobilenet_v2
model = applications.MobileNetV2()
# Set up dataset
BATCH_SIZE = 32
imagenet_dataset = preprocessing.image_dataset_from_directory(
    directory='<your_imagenet_validation_data_path>',
    label_mode='categorical',
    image_size=(224, 224),
    batch_size=BATCH_SIZE,
    shuffle=True,
)
imagenet_dataset = imagenet_dataset.map(
    lambda x, y: (mobilenet_v2.preprocess_input(x), y)
)
NUM_CALIBRATION_SAMPLES = 2048
calibration_dataset = imagenet_dataset.take(NUM_CALIBRATION_SAMPLES // BATCH_SIZE)
eval_dataset = imagenet_dataset.skip(NUM_CALIBRATION_SAMPLES // BATCH_SIZE)
Step 1¶
Create the QuantizationSimModel
When creating the QuantizationSimModel, ensure that per-channel quantization is enabled; update the config file if necessary.
PyTorch
from aimet_torch.quantsim import QuantizationSimModel
from aimet_common.quantsim_config.utils import get_path_for_per_channel_config
sim = QuantizationSimModel(model=model, dummy_input=dummy_input, config_file=get_path_for_per_channel_config())
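Before QAT, the sim model needs initial encodings, as the TensorFlow variant below computes via compute_encodings. A minimal PyTorch sketch, assuming a small calibration budget drawn from the data_loader defined in Setup:
def forward_pass(model, _):
    model.eval()
    with torch.no_grad():
        for i, (images, _) in enumerate(data_loader):
            model(images.to(device))
            if i + 1 >= 16:  # assumed budget of 16 calibration batches
                break

sim.compute_encodings(forward_pass, None)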
TensorFlow
def pass_calibration_data(model, _):
    for inputs, _ in calibration_dataset:
        _ = model(inputs)
sim = QuantizationSimModel(
    model,
    quant_scheme=QuantScheme.training_range_learning_with_tf_init,
    default_param_bw=4,
    default_output_bw=8,
    config_file=get_path_for_per_channel_config(),
)
sim.compute_encodings(pass_calibration_data, None)
Step 2¶
Perform QAT
This involves training your model for a few additional epochs (usually 15 to 20). When training, pay attention to the hyperparameters; in particular, QAT typically uses a much smaller learning rate than the original training run.
PyTorch
num_epochs = 20
optimizer = torch.optim.SGD(sim.model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5)
loss_fn = torch.nn.CrossEntropyLoss()
for epoch in range(num_epochs):
    for images, labels in data_loader:
        images, labels = images.to(device), labels.to(device)
        output = sim.model(images)
        loss = loss_fn(output, labels)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
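The TensorFlow variant below reports accuracy with evaluate(); for the PyTorch path, a quick top-1 check could look like the following sketch (eval_loader is an assumed validation DataLoader, not defined in the Setup above):
correct = total = 0
sim.model.eval()
with torch.no_grad():
    for images, labels in eval_loader:  # assumed validation loader
        preds = sim.model(images.to(device)).argmax(dim=1)
        correct += (preds == labels.to(device)).sum().item()
        total += labels.size(0)
print(f"Model accuracy before BN re-estimation: {correct / total:.4f}")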
TensorFlow
sim.model.compile(
    optimizer=optimizers.SGD(learning_rate=1e-5),
    loss=[losses.CategoricalCrossentropy()],
    metrics=[metrics.CategoricalAccuracy()],
)
sim.model.fit(calibration_dataset, epochs=10)
_, accuracy = sim.model.evaluate(eval_dataset)
print(f'Model accuracy before BN re-estimation: {accuracy:.4f}')
Model accuracy before BN re-estimation: 0.0428
Step 3¶
Re-estimate the BN statistics and fold the BN layers.
PyTorch
from aimet_torch.bn_reestimation import reestimate_bn_stats
from aimet_torch.batch_norm_fold import fold_all_batch_norms_to_scale
reestimate_bn_stats(sim.model, data_loader)
fold_all_batch_norms_to_scale(sim)
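As documented in the API section below, reestimate_bn_stats returns a Handle, so the re-estimation can be undone (before folding) if it does not improve accuracy. A sketch:
handle = reestimate_bn_stats(sim.model, data_loader)
# ... evaluate here; if accuracy regressed, restore the original BN stats
# before folding:
# handle.remove()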
TensorFlow
unlabeled_dataset = calibration_dataset.map(lambda x, _: x)
reestimate_bn_stats(
    sim.model, unlabeled_dataset, bn_num_batches=NUM_CALIBRATION_SAMPLES // BATCH_SIZE
)
_, accuracy = sim.model.evaluate(eval_dataset)
print(f'Model accuracy after BN re-estimation: {accuracy:.4f}')
fold_all_batch_norms_to_scale(sim)
Model accuracy after BN re-estimation: 0.5876
Step 4¶
If BN re-estimation resulted in satisfactory accuracy, export the model.
PyTorch
path = './'
filename = 'mobilenet'
sim.export(path=path, filename_prefix=filename, dummy_input=dummy_input.cpu())
TensorFlow
sim.export(path='/tmp', filename_prefix='quantized_mobilenet_v2')
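In both frameworks, export writes the model artifacts together with an .encodings file containing the computed quantization parameters for use by downstream runtimes.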
API¶
Top-level API (PyTorch)
- aimet_torch.bn_reestimation.reestimate_bn_stats(model, dataloader, num_batches=100, forward_fn=None)[source]¶
Re-estimate BatchNorm statistics (running mean and variance).
- Parameters:
model (Module) – Model whose BN statistics are to be re-estimated.
dataloader (DataLoader) – Training dataset.
num_batches (int) – The number of batches to be used for re-estimation.
forward_fn (Optional[Callable[[Module, Any], Any]]) – Optional adapter function that performs a forward pass given a model and an input batch yielded from the data loader.
- Return type:
Handle
- Returns:
Handle that undoes the effect of BN re-estimation upon handle.remove().
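The forward_fn hook is useful when the data loader does not yield plain input tensors. A minimal sketch, assuming a hypothetical loader whose batches are dicts with a 'pixel_values' key:
def forward_fn(model, batch):
    # 'pixel_values' is a hypothetical batch key, not part of the examples above
    return model(batch["pixel_values"].to(device))

reestimate_bn_stats(sim.model, data_loader, forward_fn=forward_fn)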
Top-level API (TensorFlow)
- aimet_tensorflow.keras.bn_reestimation.reestimate_bn_stats(model, bn_re_estimation_dataset, bn_num_batches=100)[source]¶
Top-level API for the end user to call directly.
- Parameters:
model (Model) – tf.keras.Model
bn_re_estimation_dataset (DatasetV2) – Training dataset
bn_num_batches (int) – The number of batches to be used for re-estimation
- Return type:
Handle
- Returns:
Handle that undoes the effect of BN re-estimation upon handle.remove()
- aimet_tensorflow.keras.batch_norm_fold.fold_all_batch_norms_to_scale(sim)[source]¶
Fold all batch norm layers in a model into the quantization scale parameters of the corresponding conv layers.
- Parameters:
sim (QuantizationSimModel) – QuantizationSimModel to be folded
- Return type:
List[Tuple[QcQuantizeWrapper, QcQuantizeWrapper]]
- Returns:
A list of pairs of layers [(Conv/Linear, BN layer that got folded)]
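The returned pairs can be used to verify which BN layers were folded; a small sketch (assuming the wrappers expose the standard Keras layer name attribute):
pairs = fold_all_batch_norms_to_scale(sim)
for conv_wrapper, bn_wrapper in pairs:
    print(f"{bn_wrapper.name} folded into {conv_wrapper.name}")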