Sequential MSE

Context

Sequential MSE (SeqMSE) is a quantization technique that optimizes the parameter encodings of each layer of a model individually to minimize the difference between the layer’s original and quantized outputs. Rather than relying on training, SeqMSE uses a search-based approach, offering several benefits:

  • It requires only a small amount of unlabeled data

  • It approximates the global minimum without getting trapped in local minima

  • It is robust to overfitting

Workflow

Prerequisites

To use SeqMSE, you must have the following:

  • A pre-trained PyTorch or ONNX model (TensorFlow is not supported)

  • A set of representative input samples for the model

Procedure

Setup

import torch
import torchvision
from torchvision import transforms
from torchvision.models import MobileNet_V2_Weights, mobilenet_v2

# General setup that can be changed as needed.
DATASET_ROOT = ...  # Set your path to the ImageNet dataset root directory
BATCH_SIZE = 32
NUM_CALIBRATION_SAMPLES = 128

# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Load the pretrained FP32 model in eval mode on the selected device.
model = mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).eval().to(device)

# Per-channel normalization statistics applied after ToTensor().
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Resize to 256, center-crop to 224x224, convert to a tensor, then normalize.
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Validation split, shuffled so that the first few batches form a varied calibration set.
imagenet_data = torchvision.datasets.ImageNet(DATASET_ROOT, split="val", transform=preprocess)
dataloader = torch.utils.data.DataLoader(
    imagenet_data,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4,
)

TensorFlow: Not supported. (The code above is for PyTorch; the code below is for ONNX.)

import os

import onnx
import onnxruntime as ort
import torch
import torchvision
from torchvision import transforms

# Load the FP32 PyTorch model and put it in eval mode before exporting.
# IMAGENET1K_V1 is the checkpoint the deprecated `pretrained=True` flag selected, and its
# reference preprocessing (resize 256, center-crop 224) matches `preprocess` below.
pt_model = torchvision.models.mobilenet_v2(
    weights=torchvision.models.MobileNet_V2_Weights.IMAGENET1K_V1
).eval()
input_shape = (1, 3, 224, 224)
dummy_input = torch.randn(input_shape)

# Export to ONNX with a dynamic batch dimension on both the input and the output,
# so the exported graph accepts batches of any size.
# Modify file_path as you wish
file_path = os.path.join(".", "mobilenet_v2.onnx")
torch.onnx.export(
    pt_model,
    (dummy_input,),
    file_path,
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={
        "input": {0: "batch_size"},
        "output": {0: "batch_size"},
    },
)
# Load exported ONNX model
model = onnx.load_model(file_path)

# Choose providers: prefer CUDA when ONNX Runtime offers it, always keeping CPU as a fallback.
if "CUDAExecutionProvider" in ort.get_available_providers():
    providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
else:
    providers = ["CPUExecutionProvider"]

DATASET_ROOT = ...  # Set your path to the ImageNet dataset root directory
BATCH_SIZE = 32
NUM_CALIBRATION_SAMPLES = 128

# Resize to 256, center-crop to 224x224, convert to a tensor, then normalize per channel.
preprocess = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

imagenet_data = torchvision.datasets.ImageNet(
    DATASET_ROOT,
    split="val",
    transform=preprocess,
)

dataloader = torch.utils.data.DataLoader(
    imagenet_data,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4,
)

Step 1

Create a QuantizationSimModel object for the model.

from aimet_torch.quantsim import QuantizationSimModel

# Simulated bit-widths: 4-bit parameters, 8-bit layer outputs.
PARAM_BW = 4
OUTPUT_BW = 8

# A single random 3x224x224 sample (batch of one), placed on `device`.
dummy_input = torch.randn(1, 3, 224, 224).to(device)

sim = QuantizationSimModel(
    model,
    dummy_input=dummy_input,
    default_param_bw=PARAM_BW,
    default_output_bw=OUTPUT_BW,
)

TensorFlow: Not supported. (The code above is for PyTorch; the code below is for ONNX.)

# Create the QuantizationSimModel
import aimet_onnx

# Wrap the exported ONNX model in a quantization simulation with int4 parameters and
# int8 activations, run through the ONNX Runtime execution providers selected in setup.
sim = aimet_onnx.QuantizationSimModel(
    model,
    param_type=aimet_onnx.int4,
    activation_type=aimet_onnx.int8,
    providers=providers
)

Step 2

Apply SeqMSE to find optimized parameter encodings for supported layer types.

import itertools
from aimet_torch.seq_mse import apply_seq_mse, SeqMseParams

# Get unlabeled data: SeqMSE only needs model inputs, so the labels are dropped.
# Use at least one batch so that a NUM_CALIBRATION_SAMPLES smaller than BATCH_SIZE
# does not silently produce an empty calibration set.
num_batches = max(1, NUM_CALIBRATION_SAMPLES // BATCH_SIZE)
unlabeled_data = [images for images, _ in itertools.islice(dataloader, num_batches)]

# Configure SeqMSE parameters: num_batches is the number of batches collected above;
# num_candidates is the number of candidates evaluated by the grid search (default 20).
params = SeqMseParams(num_batches=num_batches,
                      num_candidates=20)

# Find and freeze the optimal encoding candidate for parameters of supported layers/operations.
apply_seq_mse(model=model, sim=sim, data_loader=unlabeled_data, params=params)

TensorFlow: Not supported. (The code above is for PyTorch; the code below is for ONNX.)

import itertools

# Get unlabeled ONNX data as ONNX Runtime feed dicts keyed by the graph's input name.
input_name = model.graph.input[0].name
# Use at least one batch so that a NUM_CALIBRATION_SAMPLES smaller than BATCH_SIZE
# does not silently produce an empty calibration set.
num_batches = max(1, NUM_CALIBRATION_SAMPLES // BATCH_SIZE)
unlabeled_data = [{input_name: images.numpy()} for images, _ in itertools.islice(dataloader, num_batches)]

# Apply SeqMSE to the sim, optimizing its weight encodings with the unlabeled samples.
aimet_onnx.apply_seq_mse(sim, unlabeled_data)

Step 3

Compute encodings for remaining uninitialized quantizers.

import itertools


@torch.no_grad()
def forward_pass(model: torch.nn.Module):
    """Run the first ``num_batches`` batches of ``dataloader`` through ``model``.

    Used as the callback for ``sim.compute_encodings``. The outputs are discarded;
    the callback only needs to drive calibration data through the model.
    """
    # islice stops after exactly num_batches batches, so no extra batch is loaded
    # and preprocessed just to be thrown away.
    for images, _ in itertools.islice(dataloader, num_batches):
        model(images.to(device))


# Compute encodings for all activations and for parameters of uninitialized layers.
sim.compute_encodings(forward_pass)

TensorFlow: Not supported. (The code above is for PyTorch; the code below is for ONNX.)

# Compute encodings for the remaining uninitialized quantizers, reusing the same
# unlabeled input samples that were collected for SeqMSE.
sim.compute_encodings(unlabeled_data)

Step 4

Evaluate the quantized model.

# Determine simulated quantized (top-1) accuracy over the whole validation split.
from tqdm import tqdm

correct_predictions = 0
total_samples = 0
# Inference only: disabling autograd avoids recording a graph for every batch.
with torch.no_grad():
    for inputs, labels in tqdm(dataloader):
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = sim.model(inputs)
        pred_labels = outputs.argmax(dim=1)
        # .item() keeps the counters as Python ints, so the accuracy prints as a plain number.
        correct_predictions += (pred_labels == labels).sum().item()
        total_samples += labels.shape[0]

accuracy = correct_predictions / total_samples
print(f"Quantized accuracy: {accuracy}")

TensorFlow: Not supported. (The code above is for PyTorch; the code below is for ONNX.)

import numpy as np
from tqdm import tqdm

# Top-1 accuracy of the quantized ONNX model over the whole validation split.
correct_predictions = 0
total_samples = 0
for batch_images, batch_labels in tqdm(dataloader):
    feed = {input_name: batch_images.numpy()}
    (logits,) = sim.session.run(None, feed)
    labels_np = batch_labels.numpy()
    predicted = np.argmax(logits, axis=1)
    correct_predictions += np.sum(predicted == labels_np)
    total_samples += labels_np.shape[0]

accuracy = correct_predictions / total_samples
print(f"Quantized accuracy: {accuracy}")

Step 5

If the resulting quantized accuracy is satisfactory, export the model.

# Export the model for on-target inference into the current directory, using
# "quantized_mobilenet_v2" as the file-name prefix. The dummy input from Step 1 is
# passed as a CPU copy via .cpu().
sim.export(path=".", filename_prefix="quantized_mobilenet_v2", dummy_input=dummy_input.cpu())

TensorFlow: Not supported. (The code above is for PyTorch; the code below is for ONNX.)

# Export the quantized model into the current directory, using
# "quantized_mobilenet_v2" as the file-name prefix.
sim.export(path=".", filename_prefix="quantized_mobilenet_v2")

API

Top level APIs

aimet_torch.seq_mse.apply_seq_mse(model, sim, data_loader, params, modules_to_exclude=None, checkpoints_config=None)

Sequentially minimizes the activation MSE loss, layer by layer, to determine optimal parameter quantization encodings.

Steps:

  1. Disable all input/output quantizers, as well as the parameter quantizers of unsupported modules.
  2. Find and freeze the optimal parameter-encoding candidate for the remaining supported modules.
  3. Re-enable the quantizers disabled in step 1.

Example user flow:

    model = Model().eval()
    sim = QuantizationSimModel(…)
    apply_seq_mse(…)
    sim.compute_encodings(…)  # compute encodings for all activations and parameters of unsupported modules
    sim.export(…)

NOTE:

  1. Module references passed to modules_to_exclude must come from the FP32 model.
  2. Modules listed in modules_to_exclude are not quantized and are skipped when applying Sequential MSE.
  3. Apart from finding parameter encodings for supported modules, the config JSON file is respected and the final state of sim is otherwise unchanged.

Parameters:
  • model (Module) – Original fp32 model

  • sim (QuantizationSimModel) – Corresponding QuantizationSimModel object

  • data_loader (DataLoader) – Data loader

  • params (SeqMseParams) – Sequential MSE parameters

  • modules_to_exclude (Optional[List[Module]]) – List of modules (of supported types) to exclude when applying Sequential MSE

  • checkpoints_config (Optional[str]) – Config file used to split the FP32/quantized models at checkpoints, which speeds up activation sampling

Sequential MSE parameters

class aimet_torch.seq_mse.SeqMseParams(num_batches, num_candidates=20, inp_symmetry='symqt', loss_fn='mse', forward_fn=<function default_forward_fn>)

Sequential MSE parameters

Parameters:
  • num_batches (int) – Number of batches.

  • num_candidates (int) – Number of candidates to perform grid search. Default 20.

  • inp_symmetry (str) – Input symmetry. Available options are ‘asym’, ‘symfp’ and ‘symqt’. Default ‘symqt’.

  • loss_fn (str) – Loss function. Available options are ‘mse’, ‘l1’ and ‘sqnr’. Default ‘mse’.

  • forward_fn (Callable) – Optional adapter function that performs forward pass given a model and inputs yielded from the data loader. The function expects model as first argument and inputs to model as second argument.

forward_fn(model, inputs)

Default forward function. Parameters: model – PyTorch model; inputs – model inputs.

get_loss_fn()

Returns loss function

Return type:

Callable

TensorFlow: Not supported. (The API above is for PyTorch; the API below is for ONNX.)

Top level APIs

aimet_onnx.apply_seq_mse(sim, inputs, num_candidates=20)

Sequentially optimizes the QuantizationSimModel’s weight encodings to reduce MSE loss at layer outputs.

Parameters:
  • sim (QuantizationSimModel) – QuantizationSimModel instance to optimize

  • inputs (Collection[Dict[str, np.ndarray]]) – The set of input samples to use during optimization

  • num_candidates (int) – Number of encoding candidates to sweep for each weight. Decreasing this can reduce runtime but may lead to lower accuracy.