Sequential MSE

Context

Sequential MSE (SeqMSE) is a quantization technique that optimizes the parameter encodings of each layer of a model individually to minimize the difference between the layer’s original and quantized outputs. Rather than relying on training, SeqMSE uses a search-based approach, offering several benefits:

  • It requires only a small amount of unlabeled data

  • It approximates the global minimum without getting trapped in local minima

  • It is robust to overfitting

Workflow

Prerequisites

To use SeqMSE, you must have the following:

  • A pre-trained PyTorch or ONNX model

  • A set of representative input samples for the model

Procedure

Setup

import torch
import torchvision
from torchvision import transforms
from torchvision.models import MobileNet_V2_Weights, mobilenet_v2

# General setup that can be changed as needed
DATASET_ROOT = ...  # Set your path to imagenet dataset root directory
BATCH_SIZE = 32
NUM_CALIBRATION_SAMPLES = 128

# Load the pre-trained FP32 model in eval mode, on the GPU when one is available
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).eval().to(device)

# Resize, center-crop to 224x224, convert to a tensor and normalize per channel
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# ImageNet validation split, served in shuffled batches
imagenet_data = torchvision.datasets.ImageNet(DATASET_ROOT, split="val", transform=preprocess)
dataloader = torch.utils.data.DataLoader(
    imagenet_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=4
)
import os

import onnx
import onnxruntime as ort
import torch
import torchvision
from torchvision import transforms

# Load the pre-trained PyTorch model.
# ``weights=IMAGENET1K_V1`` selects the same checkpoint as the deprecated
# ``pretrained=True`` flag; ``.eval()`` makes inference mode explicit for export.
pt_model = torchvision.models.mobilenet_v2(
    weights=torchvision.models.MobileNet_V2_Weights.IMAGENET1K_V1
).eval()
input_shape = (1, 3, 224, 224)
dummy_input = torch.randn(input_shape)

# Export to ONNX with a dynamic batch dimension. Modify file_path as you wish.
file_path = os.path.join(".", "mobilenet_v2.onnx")
torch.onnx.export(
    pt_model,
    (dummy_input,),
    file_path,
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={
        "input": {0: "batch_size"},
        "output": {0: "batch_size"},
    },
)
# Load exported ONNX model
model = onnx.load_model(file_path)

# Prefer CUDA when ONNX Runtime reports it as available, otherwise use CPU only
if "CUDAExecutionProvider" in ort.get_available_providers():
    providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
else:
    providers = ["CPUExecutionProvider"]

DATASET_ROOT = ...  # Set your path to imagenet dataset root directory
BATCH_SIZE = 32
NUM_CALIBRATION_SAMPLES = 128

preprocess = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

imagenet_data = torchvision.datasets.ImageNet(
    DATASET_ROOT,
    split="val",
    transform=preprocess,
)

dataloader = torch.utils.data.DataLoader(
    imagenet_data,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4,
)

Step 1

Create a QuantizationSimModel object for the model.

from aimet_torch.quantsim import QuantizationSimModel

# Random example input in the model's 1x3x224x224 input shape, on the model's device
dummy_input = torch.randn(1, 3, 224, 224).to(device)

# 4-bit parameter (weight) quantization, 8-bit output (activation) quantization
sim = QuantizationSimModel(
    model,
    dummy_input=dummy_input,
    default_param_bw=4,
    default_output_bw=8,
)
# Create the QuantizationSimModel
import aimet_onnx

# int4 parameters, int8 activations; run with the execution providers chosen during setup
sim = aimet_onnx.QuantizationSimModel(
    model,
    param_type=aimet_onnx.int4,
    activation_type=aimet_onnx.int8,
    providers=providers,
)

Step 2

Apply SeqMSE to find optimized parameter encodings for supported layer types.

import itertools
from aimet_torch.seq_mse import apply_seq_mse, SeqMseParams

# Get unlabeled data: keep only the images from the first ``num_batches`` batches
num_batches = NUM_CALIBRATION_SAMPLES // BATCH_SIZE
unlabeled_data = [images for images, _ in itertools.islice(dataloader, num_batches)]

# Find and freeze the optimal encoding candidate for the parameters of supported
# layers/operations. The number of candidates is passed directly to apply_seq_mse.
apply_seq_mse(sim=sim, data_loader=unlabeled_data, num_candidates=20)
import itertools

# Collect unlabeled ONNX inputs: one {input_name: ndarray} feed dict per batch
num_batches = NUM_CALIBRATION_SAMPLES // BATCH_SIZE
input_name = model.graph.input[0].name
unlabeled_data = [
    {input_name: images.numpy()}
    for images, _ in itertools.islice(dataloader, num_batches)
]

# Apply SeqMSE to the sim, sweeping 20 encoding candidates per weight (the default)
aimet_onnx.apply_seq_mse(sim, unlabeled_data, num_candidates=20)

Step 3

Compute encodings for remaining uninitialized quantizers.

@torch.no_grad()
def forward_pass(model: torch.nn.Module):
    """Run the first ``num_batches`` batches of ``dataloader`` through ``model``.

    Used as the ``compute_encodings`` callback. Labels are ignored. Reads the
    module-level ``dataloader``, ``num_batches`` and ``device`` from the setup
    and Step 2 snippets.

    :param model: Model to run calibration data through.
    """
    # islice stops after exactly num_batches batches. An enumerate/break loop
    # would fetch (and preprocess) one extra batch only to discard it.
    for images, _ in itertools.islice(dataloader, num_batches):
        model(images.to(device))

# Compute encodings for all activations and for parameters of uninitialized layers.
# PyTorch: pass the forward_pass callback, which runs calibration batches through the model.
sim.compute_encodings(forward_pass)
# ONNX: pass the list of {input_name: ndarray} feed dicts collected in Step 2.
sim.compute_encodings(unlabeled_data)

Step 4

Evaluate the quantized model.

# Determine simulated quantized accuracy
from tqdm import tqdm

correct_predictions = 0
total_samples = 0
# Evaluation only: disable autograd so no computation graph is built per batch.
with torch.no_grad():
    for inputs, labels in tqdm(dataloader):
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = sim.model(inputs)
        pred_labels = outputs.argmax(dim=1)
        # .item() keeps the running count a plain Python int, so the final
        # accuracy prints as a number rather than a (possibly CUDA) tensor.
        correct_predictions += (pred_labels == labels).sum().item()
        total_samples += labels.shape[0]

accuracy = correct_predictions / total_samples
print(f"Quantized accuracy: {accuracy}")
from tqdm import tqdm
import numpy as np

correct_predictions = 0
total_samples = 0
for batch_images, batch_labels in tqdm(dataloader):
    feed = {input_name: batch_images.numpy()}
    expected = batch_labels.numpy()
    # session.run returns a list of outputs; the exported model has a single one ("output")
    (logits,) = sim.session.run(None, feed)
    correct_predictions += np.sum(np.argmax(logits, axis=1) == expected)
    total_samples += expected.shape[0]

accuracy = correct_predictions / total_samples
print(f"Quantized accuracy: {accuracy}")

Step 5

If the resulting quantized accuracy is satisfactory, export the model.

# Export the model for on-target inference.
# PyTorch: a CPU copy of the dummy input from Step 1 is passed to export.
sim.export(path=".", filename_prefix="quantized_mobilenet_v2", dummy_input=dummy_input.cpu())
# ONNX: this call passes no dummy input.
sim.export(path=".", filename_prefix="quantized_mobilenet_v2")

API

Top level APIs

aimet_torch.seq_mse.apply_seq_mse(*args, **kwargs)[source]

Sequentially minimizes the activation MSE loss, layer by layer, to determine optimal parameter quantization encodings.

1. Disable all input/output quantizers, as well as the parameter quantizers of non-supported modules.
2. Find and freeze the optimal parameter encoding candidates for the remaining supported modules.
3. Re-enable the quantizers disabled in step 1.

Example user flow:

    model = Model().eval()
    sim = QuantizationSimModel(…)
    apply_seq_mse(…)
    sim.compute_encodings(…)  # compute encodings for all activations and parameters of non-supported modules
    sim.export(…)

NOTE: Modules in modules_to_exclude won’t be quantized and will be skipped when applying sequential MSE.

Parameters:
  • sim – QuantizationSimModel object

  • data_loader – Data loader

  • num_candidates – Number of candidate encodings to evaluate for each layer

  • forward_fn – Callback function that performs a forward pass; it accepts the model and inputs as arguments

  • modules_to_exclude – List of module(s) of supported types to exclude when applying Sequential MSE

  • checkpoints_config – Config file(s) used to split the FP32/quantized models by checkpoints to speed up activation sampling

Sequential MSE parameters

class aimet_torch.seq_mse.SeqMseParams(num_batches, num_candidates=20, inp_symmetry='symqt', loss_fn='mse', forward_fn=<function default_forward_fn>)[source]

Sequential MSE parameters

Parameters:
  • num_batches (Optional[int]) – Number of batches.

  • num_candidates (int) – Number of candidates to perform grid search. Default 20.

  • inp_symmetry (str) – Input symmetry. Available options are ‘asym’, ‘symfp’ and ‘symqt’. Default ‘symqt’.

  • loss_fn (str) – Loss function. Available options are ‘mse’, ‘l1’ and ‘sqnr’. Default ‘mse’.

  • forward_fn (Callable) – Optional adapter function that performs forward pass given a model and inputs yielded from the data loader. The function expects model as first argument and inputs to model as second argument.

forward_fn(inputs)

Default forward function. Parameters: model – PyTorch model; inputs – model inputs.

get_loss_fn()[source]

Returns loss function

Return type:

Callable

Top level APIs

aimet_onnx.apply_seq_mse(sim, inputs, num_candidates=20, nodes_to_exclude=None)[source]

Sequentially optimizes the QuantizationSimModel’s weight encodings to reduce MSE loss at layer outputs.

Parameters:
  • sim (QuantizationSimModel) – QuantizationSimModel instance to optimize

  • inputs (Collection[Dict[str, np.ndarray]]) – The set of input samples to use during optimization

  • num_candidates (int) – Number of encoding candidates to sweep for each weight. Decreasing this can reduce runtime but may lead to lower accuracy.

  • nodes_to_exclude (Optional[List[str]]) – List of supported node name(s) to exclude from sequential MSE optimization