Sequential MSE

Context

Sequential MSE (SeqMSE) is a method that searches for optimal quantization encodings per operation (i.e., per layer) such that the mean squared error (MSE) between the original output activations and the corresponding quantization-aware output activations is minimized.

Since SeqMSE is search-based rather than learning-based, it has several advantages:

  • It requires only a small amount of calibration data

  • It approximates the global minimum without getting trapped in local minima

  • It is robust to overfitting
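
Concretely, the search can be pictured as a per-layer grid search over candidate encoding ranges. The sketch below is illustrative only, not AIMET's implementation: it fake-quantizes a Linear layer's weight at several candidate ranges (assuming signed 8-bit symmetric quantization for simplicity) and keeps the range whose output is closest, in MSE, to the floating-point output.

import torch
import torch.nn.functional as F

def per_layer_grid_search(weight, layer_input, fp_output, num_candidates=20):
    # Illustrative sketch only; not AIMET's implementation.
    best_loss, best_scale = float("inf"), None
    max_abs = weight.abs().max()
    for i in range(1, num_candidates + 1):
        candidate_max = max_abs * i / num_candidates   # candidate encoding range
        scale = candidate_max / 127.0                  # assume signed 8-bit symmetric
        w_q = torch.clamp(torch.round(weight / scale), -128, 127) * scale
        loss = F.mse_loss(F.linear(layer_input, w_q), fp_output)
        if loss < best_loss:
            best_loss, best_scale = loss, scale
    return best_scale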

Workflow

Prerequisites

To use Sequential MSE, you must:

  • Use PyTorch or ONNX. Sequential MSE does not support TensorFlow models.

  • Load a pre-trained model

  • Create a training or validation dataloader for the model

Procedure

Setup

import torch
from torch.utils.data import DataLoader
from datasets import load_dataset
from evaluate import evaluator

# Load the model
# General setup that can be changed as needed
from torchvision.models import mobilenet_v2, MobileNet_V2_Weights
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).eval().to(device)
batch_size = 32      # images per batch from the data loader
num_batches = 32     # number of calibration batches used by SeqMSE
data = load_dataset('imagenet-1k', streaming=True, split="train")
data_loader = DataLoader(data, batch_size=batch_size, num_workers=4)
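
Note that the streamed imagenet-1k samples are dictionaries of PIL images, which the default collate cannot batch. Below is a minimal sketch of a preprocessing collate function; the 'image' and 'label' field names are assumptions based on the dataset card.

from torchvision import transforms

preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
])

def collate_fn(batch):
    # Convert each streamed sample (a dict) into a (tensor, label) pair
    images = torch.stack([preprocess(s["image"].convert("RGB")) for s in batch])
    labels = torch.tensor([s["label"] for s in batch])
    return images, labels

data_loader = DataLoader(data, batch_size=batch_size, num_workers=4, collate_fn=collate_fn)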

Step 1

Use AIMET’s quantization simulation to create a QuantizationSimModel object.

from aimet_common.defs import QuantScheme
from aimet_common.quantsim_config.utils import get_path_for_per_channel_config
from aimet_torch.quantsim import QuantizationSimModel

dummy_input = torch.randn(1, 3, 224, 224).to(device)
sim = QuantizationSimModel(model,
                           dummy_input=dummy_input,
                           quant_scheme=QuantScheme.training_range_learning_with_tf_init,
                           default_param_bw=4,    # 4-bit weight (parameter) quantization
                           default_output_bw=8,   # 8-bit activation (output) quantization
                           config_file=get_path_for_per_channel_config())
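
Optionally, inspect the resulting model to sanity-check which quantizers were inserted. Since sim.model is a regular torch.nn.Module, a plain print shows the wrapped layers:

print(sim.model)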

Step 2

Apply SeqMSE to decide optimal quantization encodings for parameters of supported layers and operations.

# Find and freeze optimal encoding candidates for parameters of supported layer(s)/operation(s).
from aimet_torch.seq_mse import apply_seq_mse, SeqMseParams
params = SeqMseParams(num_batches=num_batches,
                      num_candidates=20,
                      inp_symmetry='symqt',
                      loss_fn='mse')

apply_seq_mse(model=model, sim=sim, data_loader=data_loader, params=params)

Step 3

Compute encodings for all activations, and for parameters of any layers and operations that SeqMSE did not initialize.

def forward_pass(model: torch.nn.Module):
    # Calibration callback: run forward passes over the calibration data
    with torch.no_grad():
        for images, _ in data_loader:
            model(images.to(device))
# End of calibration callback

# Compute quantization encodings for all activations and for parameters of
# uninitialized layer(s)/operation(s).
sim.compute_encodings(forward_pass)

Step 4

Evaluate the quantized model using ImageClassificationEvaluator.

# Determine simulated quantized accuracy (note: avoid shadowing the imported `evaluator`)
task_evaluator = evaluator("image-classification")
accuracy = task_evaluator.compute(model_or_pipeline=sim.model, data=data, metric="accuracy")
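
Alternatively, a plain evaluation loop over the quantized sim.model avoids the pipeline machinery. A minimal sketch, assuming the data loader yields (images, labels) batches as configured in the setup:

correct = total = 0
with torch.no_grad():
    for batch_idx, (images, labels) in enumerate(data_loader):
        if batch_idx >= num_batches:   # the streaming split is large; cap the evaluation
            break
        logits = sim.model(images.to(device))
        correct += (logits.argmax(dim=1).cpu() == labels).sum().item()
        total += labels.numel()
accuracy = correct / total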

Step 5

If the resulting quantized accuracy is satisfactory, export the model.

# Export the model for on-target inference. This saves the PyTorch model without
# any simulation nodes, plus an encodings file (JSON) for both activations and
# parameters, at the provided path.
path = './'
filename = 'mobilenet'
sim.export(path=path, filename_prefix="quantized_" + filename, dummy_input=dummy_input.cpu())
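
The exact artifact names depend on the AIMET version; as an assumption based on the prefix above, you would expect files such as quantized_mobilenet.onnx and quantized_mobilenet.encodings. A quick check:

import os
print(os.listdir(path))   # verify the exported model and encodings files are present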

API

Top level APIs

aimet_torch.seq_mse.apply_seq_mse(model, sim, data_loader, params, modules_to_exclude=None, checkpoints_config=None)

Sequentially minimizes the activation MSE loss, layer by layer, to decide optimal parameter quantization encodings.

1. Disable all input/output quantizers and the parameter quantizers of unsupported modules.
2. Find and freeze optimal parameter encoding candidates for the remaining supported modules.
3. Re-enable the quantizers disabled in step 1.

Example user flow:

model = Model().eval()
sim = QuantizationSimModel(…)
apply_seq_mse(…)
sim.compute_encodings(…)  # compute encodings for all activations and parameters of non-supported modules
sim.export(…)

NOTE:

1. Module references passed to modules_to_exclude should come from the FP32 model.
2. Modules in modules_to_exclude won't be quantized and are skipped when applying Sequential MSE.
3. Except for finding parameter encodings of supported modules, the config JSON file is respected and the final state of sim is left unchanged.

Parameters:
  • model (Module) – Original fp32 model

  • sim (QuantizationSimModel) – Corresponding QuantizationSimModel object

  • data_loader (DataLoader) – Data loader

  • params (SeqMseParams) – Sequential MSE parameters

  • modules_to_exclude (Optional[List[Module]]) – List of supported type module(s) to exclude when applying Sequential MSE

  • checkpoints_config (Optional[str]) – Config file used to split the fp32/quant model into checkpoints to speed up activation sampling
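
For example, to exclude MobileNetV2's final classifier layer (a hypothetical usage; note that the module reference must come from the FP32 model, per the notes above):

# Hypothetical: skip SeqMSE for the final Linear layer of the FP32 model
apply_seq_mse(model=model, sim=sim, data_loader=data_loader, params=params,
              modules_to_exclude=[model.classifier[1]])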

Sequential MSE parameters

class aimet_torch.seq_mse.SeqMseParams(num_batches, num_candidates=20, inp_symmetry='symqt', loss_fn='mse', forward_fn=<function default_forward_fn>)

Sequential MSE parameters

Parameters:
  • num_batches (int) – Number of batches.

  • num_candidates (int) – Number of candidates to perform grid search. Default 20.

  • inp_symmetry (str) – Input symmetry. Available options are ‘asym’, ‘symfp’ and ‘symqt’. Default ‘symqt’.

  • loss_fn (str) – Loss function. Available options are ‘mse’, ‘l1’ and ‘sqnr’. Default ‘mse’.

  • forward_fn (Callable) – Optional adapter function that performs forward pass given a model and inputs yielded from the data loader. The function expects model as first argument and inputs to model as second argument.
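
If your data loader yields something other than (inputs, labels) tuples, supply a custom forward_fn. A sketch, where the 'image' key is an assumption about your loader's batch format:

def dict_forward_fn(model, batch):
    # Hypothetical adapter for a loader that yields dict batches instead of tuples
    return model(batch["image"].to(device))

params = SeqMseParams(num_batches=num_batches, forward_fn=dict_forward_fn)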

default_forward_fn(model, inputs)

Default forward function.

Parameters:
  • model – PyTorch model

  • inputs – model inputs

get_loss_fn()

Returns loss function

Return type:

Callable
