Batch norm re-estimation

Context

If folding the batch norm layers of your model degrades its accuracy, the batch norm re-estimation feature may help. This feature uses a small subset of training data to re-estimate the statistics of the batch norm (BN) layers in a model. The BN layers are then folded into the preceding convolution or linear layers using the re-estimated statistics.
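Conceptually, re-estimation replaces the running mean and variance stored in a BN layer with statistics measured on fresh data. The NumPy sketch below illustrates that idea only; it is not AIMET's implementation. It assumes the pre-BN activations of one layer are available as hypothetical `(batch, channels)` arrays and accumulates sums so that every calibration sample counts equally.

```python
import numpy as np


def estimate_channel_stats(batches):
    """Per-channel mean and population variance over all samples in `batches`.

    Hypothetical helper: each batch is an array shaped (batch_size, channels).
    Sums are accumulated so the result equals the statistics of all samples combined.
    """
    count, total, total_sq = 0, 0.0, 0.0
    for x in batches:
        count += x.shape[0]
        total = total + x.sum(axis=0)
        total_sq = total_sq + (x ** 2).sum(axis=0)
    mean = total / count
    var = total_sq / count - mean ** 2  # E[x^2] - E[x]^2
    return mean, var


rng = np.random.default_rng(0)
# Hypothetical pre-BN activations: 8 calibration batches of 32 samples, 4 channels
batches = [rng.normal(loc=2.0, scale=3.0, size=(32, 4)) for _ in range(8)]
mean, var = estimate_channel_stats(batches)
print("re-estimated mean:", np.round(mean, 3))
print("re-estimated var: ", np.round(var, 3))
```

Note that PyTorch BN layers store an unbiased variance estimate in `running_var`; the sketch uses the population variance for simplicity.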

BN re-estimation is also recommended in the following cases:

  • Models where the main issue is weight quantization

  • Quantization of depthwise separable layers, whose batch norm statistics are sensitive to weight oscillations during QAT
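To see why weight quantization calls for fresh BN statistics, the NumPy sketch below quantizes a small layer to 4 bits per output channel and normalizes its outputs with either the stale floating-point statistics or re-measured ones. The data is random and hypothetical, the layer has 9 weights per output channel (the same count as a 3x3 depthwise kernel), and the sketch does not model QAT oscillations themselves; it only shows the statistics mismatch that re-estimation removes.

```python
import numpy as np


def quantize_per_channel(w, bitwidth=4):
    """Symmetric per-output-channel fake quantization: quantize to integers, then dequantize."""
    qmax = 2 ** (bitwidth - 1) - 1                       # 7 for 4 bits
    scale = np.abs(w).max(axis=1, keepdims=True) / qmax  # one scale per output channel (row)
    return np.clip(np.round(w / scale), -qmax - 1, qmax) * scale


rng = np.random.default_rng(0)
x = rng.normal(size=(1024, 9))  # hypothetical layer inputs, 9 per output channel
w = rng.normal(size=(8, 9))     # hypothetical weights, 8 output channels
w_q = quantize_per_channel(w)

y_fp = x @ w.T                  # outputs the stored (stale) BN statistics describe
y_q = x @ w_q.T                 # outputs the BN layer actually sees after quantization

stale_std = y_fp.std(axis=0)
fresh_std = y_q.std(axis=0)     # what re-estimation would measure

# Per-channel std of the normalized outputs: 1 (up to rounding) only with fresh statistics
normalized_std_stale = (y_q / stale_std).std(axis=0)
normalized_std_fresh = (y_q / fresh_std).std(axis=0)
print("stale:", np.round(normalized_std_stale, 3))
print("fresh:", np.round(normalized_std_fresh, 3))
```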

Workflow

Prerequisites

To use BN re-estimation, you must:

  • Load a trained model

  • Create a training dataloader for the model

  • Hold off on folding the batch norm layers until after quantization aware training (QAT)

Setup

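PyTorch

import torch
from torch.utils.data import DataLoader
from torchvision.models import mobilenet_v2, MobileNet_V2_Weights
from datasets import load_dataset

# General setup that can be changed as needed
device = "cuda:0" if torch.cuda.is_available() else "cpu"
weights = MobileNet_V2_Weights.IMAGENET1K_V1
model = mobilenet_v2(weights=weights).eval().to(device)
batch_size = 32
preprocess = weights.transforms()  # the weights' inference preprocessing: resize, center crop, normalize

def collate(examples):
    # Turn a list of {"image", "label"} records into (images, labels) tensors
    images = torch.stack([preprocess(ex["image"].convert("RGB")) for ex in examples])
    labels = torch.tensor([ex["label"] for ex in examples])
    return images, labels

# Streams the ImageNet-1k training split (requires access to the gated dataset on the Hugging Face Hub)
data = load_dataset('imagenet-1k', streaming=True, split="train")
data_loader = DataLoader(data, batch_size=batch_size, num_workers=4, collate_fn=collate)
dummy_input = torch.randn(1, 3, 224, 224).to(device)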

TensorFlow (Keras)

from aimet_common.defs import QuantScheme
from aimet_common.quantsim_config.utils import get_path_for_per_channel_config
from aimet_tensorflow.keras.batch_norm_fold import fold_all_batch_norms_to_scale
from aimet_tensorflow.keras.bn_reestimation import reestimate_bn_stats
from aimet_tensorflow.keras.quantsim import QuantizationSimModel
from tensorflow.keras import applications, losses, metrics, optimizers, preprocessing
from tensorflow.keras.applications import mobilenet_v2

model = applications.MobileNetV2()

# Set up dataset
BATCH_SIZE = 32
imagenet_dataset = preprocessing.image_dataset_from_directory(
    directory='<your_imagenet_validation_data_path>',
    label_mode='categorical',
    image_size=(224, 224),
    batch_size=BATCH_SIZE,
    shuffle=True,
)

imagenet_dataset = imagenet_dataset.map(
    lambda x, y: (mobilenet_v2.preprocess_input(x), y)
)
NUM_CALIBRATION_SAMPLES = 2048
calibration_dataset = imagenet_dataset.take(NUM_CALIBRATION_SAMPLES // BATCH_SIZE)
eval_dataset = imagenet_dataset.skip(NUM_CALIBRATION_SAMPLES // BATCH_SIZE)

Step 1

Create the QuantizationSimModel

When creating the QuantizationSimModel, make sure that per-channel quantization is enabled; update the config file if needed.
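One way to confirm that a config file enables per-channel quantization is to read it before building the sim. The sketch below assumes the AIMET-style JSON layout `{"defaults": {"per_channel_quantization": "True", ...}, ...}`, where flags are stored as the strings "True"/"False"; the helper name `per_channel_enabled` is hypothetical, and the demo file is a throwaway, not a real AIMET config.

```python
import json
import tempfile


def per_channel_enabled(config_path):
    """Return True if the (assumed) defaults.per_channel_quantization flag is "True"."""
    with open(config_path) as f:
        config = json.load(f)
    flag = config.get("defaults", {}).get("per_channel_quantization", "False")
    return str(flag).lower() == "true"


# Demo with a minimal throwaway config file (not a complete AIMET config)
with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as tmp:
    json.dump({"defaults": {"per_channel_quantization": "True"}}, tmp)
demo_enabled = per_channel_enabled(tmp.name)
print("per-channel quantization enabled:", demo_enabled)
```

With AIMET installed, the same check could be pointed at `get_path_for_per_channel_config()` or at your own edited config file.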

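PyTorch

from aimet_torch.quantsim import QuantizationSimModel
from aimet_common.quantsim_config.utils import get_path_for_per_channel_config

sim = QuantizationSimModel(model=model, dummy_input=dummy_input, config_file=get_path_for_per_channel_config())

def pass_calibration_data(model, num_batches):
    # Run a few training batches through the model (no gradients) to initialize the encodings
    model.eval()
    with torch.no_grad():
        for batch_idx, (images, _) in enumerate(data_loader):
            if batch_idx >= num_batches:
                break
            model(images.to(device))

# The second argument is passed to the callback (here: the number of calibration batches)
sim.compute_encodings(pass_calibration_data, 10)

TensorFlow (Keras)

def pass_calibration_data(model, _):
    # Run the calibration batches through the model to initialize the encodings
    for inputs, _ in calibration_dataset:
        _ = model(inputs)

sim = QuantizationSimModel(
    model,
    quant_scheme=QuantScheme.training_range_learning_with_tf_init,
    default_param_bw=4,
    default_output_bw=8,
    config_file=get_path_for_per_channel_config(),
)
sim.compute_encodings(pass_calibration_data, None)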

Step 2

Perform QAT

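QAT fine-tunes the quantized model for a few additional epochs (usually around 15 to 20). Because training starts from an already trained model, choose the hyper-parameters with care, in particular a small learning rate; the examples below use 1e-3 (PyTorch) and 1e-5 (Keras).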

PyTorch

num_epochs = 20
# Optimize sim.model's parameters: QuantizationSimModel works on a copy of the original model by default
optimizer = torch.optim.SGD(sim.model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5)
loss_fn = torch.nn.CrossEntropyLoss()

sim.model.train()
for epoch in range(num_epochs):
    for images, labels in data_loader:
        images, labels = images.to(device), labels.to(device)
        output = sim.model(images)
        loss = loss_fn(output, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

TensorFlow (Keras)

sim.model.compile(
    optimizer=optimizers.SGD(learning_rate=1e-5),
    loss=[losses.CategoricalCrossentropy()],
    metrics=[metrics.CategoricalAccuracy()],
)

sim.model.fit(calibration_dataset, epochs=10)
_, accuracy = sim.model.evaluate(eval_dataset)
print(f'Model accuracy before BN re-estimation: {accuracy:.4f}')

Output:

Model accuracy before BN re-estimation: 0.0428

Step 3

Re-estimate the BN statistics and fold the BN layers.

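PyTorch

from aimet_torch.bn_reestimation import reestimate_bn_stats
from aimet_torch.batch_norm_fold import fold_all_batch_norms_to_scale

def forward_fn(model, batch):
    # The data loader yields (images, labels); only the images are needed, on the model's device
    images, _ = batch
    model(images.to(device))

reestimate_bn_stats(sim.model, data_loader, num_batches=100, forward_fn=forward_fn)
fold_all_batch_norms_to_scale(sim)

TensorFlow (Keras)

# Pass inputs only to BN re-estimation: drop the labels
unlabeled_dataset = calibration_dataset.map(lambda x, _: x)
reestimate_bn_stats(
    sim.model, unlabeled_dataset, bn_num_batches=NUM_CALIBRATION_SAMPLES // BATCH_SIZE
)
_, accuracy = sim.model.evaluate(eval_dataset)
print(f'Model accuracy after BN re-estimation: {accuracy:.4f}')

fold_all_batch_norms_to_scale(sim)

Output:

Model accuracy after BN re-estimation: 0.5876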
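The folding step rests on simple algebra: in inference mode a BN layer is a per-channel affine map, so it can be merged into the preceding linear (or convolution) layer. The NumPy sketch below shows the plain floating-point version of that merge with hypothetical random data; `fold_all_batch_norms_to_scale` instead folds into the quantization scale of the simulated model, so treat this as an illustration of the underlying math, not of AIMET's code path.

```python
import numpy as np


def fold_bn_into_linear(weight, bias, gamma, beta, mean, var, eps=1e-5):
    """Merge y = gamma * (x @ W.T + b - mean) / sqrt(var + eps) + beta into one linear layer."""
    s = gamma / np.sqrt(var + eps)       # per-output-channel BN scale
    folded_weight = weight * s[:, None]  # scale each output row of W
    folded_bias = (bias - mean) * s + beta
    return folded_weight, folded_bias


rng = np.random.default_rng(0)
weight, bias = rng.normal(size=(4, 6)), rng.normal(size=4)  # hypothetical linear layer
gamma, beta = rng.uniform(0.5, 2.0, size=4), rng.normal(size=4)
x = rng.normal(size=(16, 6))

# Statistics re-measured on the layer's actual outputs, as re-estimation would do
pre_bn = x @ weight.T + bias
mean, var = pre_bn.mean(axis=0), pre_bn.var(axis=0)

bn_out = gamma * (pre_bn - mean) / np.sqrt(var + 1e-5) + beta  # linear layer followed by BN
w_f, b_f = fold_bn_into_linear(weight, bias, gamma, beta, mean, var)
folded_out = x @ w_f.T + b_f                                    # single folded layer
print("max abs difference:", np.abs(bn_out - folded_out).max())
```

Folding is only as good as the statistics it bakes in, which is why the workflow re-estimates them first.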

Step 4

If BN re-estimation resulted in satisfactory accuracy, export the model.

PyTorch

path = './'
filename = 'mobilenet'
sim.export(path=path, filename_prefix=filename, dummy_input=dummy_input.cpu())

TensorFlow (Keras)

sim.export(path='/tmp', filename_prefix='quantized_mobilenet_v2')

API

Top-level API (PyTorch)

aimet_torch.bn_reestimation.reestimate_bn_stats(model, dataloader, num_batches=100, forward_fn=None)

Re-estimate BatchNorm statistics (running mean and variance).

Parameters:
  • model (Module) – Model to reestimate the BN stats.

  • dataloader (DataLoader) – Data loader over the training dataset.

  • num_batches (int) – The number of batches to be used for re-estimation.

  • forward_fn (Optional[Callable[[Module, Any], Any]]) – Optional adapter function that performs a forward pass given a model and an input batch yielded from the data loader.

Return type:

Handle

Returns:

Handle that undoes the effect of BN re-estimation when handle.remove() is called.
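The returned handle follows a common undo pattern: save the values about to be overwritten, and restore them on `remove()`. The pure-Python sketch below shows that pattern on a hypothetical dictionary of BN statistics; `RestoreHandle` and `overwrite_with_handle` are illustrative names, not AIMET classes, and AIMET's own handle may work differently internally.

```python
class RestoreHandle:
    """Hypothetical undo handle: restores saved values when remove() is called."""

    def __init__(self, target, saved):
        self._target = target
        self._saved = saved

    def remove(self):
        self._target.update(self._saved)


def overwrite_with_handle(bn_stats, new_stats):
    """Overwrite entries of `bn_stats` and return a handle that undoes the change."""
    saved = {name: bn_stats[name] for name in new_stats}
    bn_stats.update(new_stats)
    return RestoreHandle(bn_stats, saved)


# Hypothetical BN state of a single layer
stats = {"running_mean": 0.0, "running_var": 1.0}
handle = overwrite_with_handle(stats, {"running_mean": 0.3, "running_var": 2.5})
after_reestimation = dict(stats)
handle.remove()
after_remove = dict(stats)
print(after_reestimation, after_remove)
```

In the workflow above the re-estimated statistics are kept and then folded, so `handle.remove()` is never called there; keeping the handle is useful when you want to compare accuracy with and without re-estimation.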

Top-level API (TensorFlow Keras)

aimet_tensorflow.keras.bn_reestimation.reestimate_bn_stats(model, bn_re_estimation_dataset, bn_num_batches=100)

Top-level API for end users to call directly.

Parameters:
  • model (Model) – The tf.keras.Model whose BN statistics are re-estimated.

  • bn_re_estimation_dataset (DatasetV2) – Training dataset.

  • bn_num_batches (int) – The number of batches to be used for re-estimation.

Return type:

Handle

Returns:

Handle that undoes the effect of BN re-estimation when handle.remove() is called.

aimet_tensorflow.keras.batch_norm_fold.fold_all_batch_norms_to_scale(sim)

Fold all batch norm layers in a model into the quantization scale parameter of the corresponding conv layers.

Parameters:

sim (QuantizationSimModel) – QuantizationSimModel to be folded

Return type:

List[Tuple[QcQuantizeWrapper, QcQuantizeWrapper]]

Returns:

A list of pairs of layers [(Conv/Linear, BN layer that got folded)]
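Folding into the quantization scale is possible because, for symmetric per-channel weight quantization, multiplying both the weights and the per-channel encoding scale by the same factor leaves the integer weights unchanged. The NumPy sketch below illustrates this under stated assumptions: hypothetical random weights, 4-bit symmetric per-channel encodings, positive per-channel BN factors `gamma / sqrt(var + eps)`, and bias handling omitted. It is a conceptual illustration, not AIMET's implementation.

```python
import numpy as np


def quantize_dequantize(w, scale, bitwidth=4):
    """Symmetric fake quantization with an explicit per-channel scale (shape (channels, 1))."""
    qmax = 2 ** (bitwidth - 1) - 1
    return np.clip(np.round(w / scale), -qmax - 1, qmax) * scale


rng = np.random.default_rng(0)
w = rng.normal(size=(8, 9))                        # hypothetical conv weights, 8 output channels
scale = (np.abs(w).max(axis=1) / 7)[:, None]       # per-channel 4-bit encoding scale
bn_factor = rng.uniform(0.5, 2.0, size=8)[:, None]  # assumed positive gamma / sqrt(var + eps)

# Quantized weights of the original layer, followed by the BN scaling
reference = quantize_dequantize(w, scale) * bn_factor
# BN factor folded into both the weights and the encoding scale: same integers, same result
folded_to_scale = quantize_dequantize(w * bn_factor, scale * bn_factor)
# BN factor folded into the weights only, encoding scale left unchanged
folded_weights_only = quantize_dequantize(w * bn_factor, scale)
print("max difference, folded to scale:  ", np.abs(folded_to_scale - reference).max())
print("max difference, weights only:     ", np.abs(folded_weights_only - reference).max())
```

Keeping the integer weights unchanged means the accuracy reached after QAT and BN re-estimation is preserved when the BN layers disappear.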