Sequential MSE
Context
Sequential MSE (SeqMSE) is a method that searches for optimal quantization encodings per operation (i.e. per layer) such that the difference between the original output activation and the corresponding quantization-aware output activation is minimized.
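Written as an equation, one simplified form of the per-layer objective is shown below. The notation is introduced here for illustration only: $W_\ell$ is the weight of layer $\ell$, $x$ ranges over calibration inputs $\mathcal{D}$ to that layer, $Q_s$ quantizes the weights with a candidate encoding $s$, and $\mathcal{S}_\ell$ is the finite set of candidates searched. It assumes a linear layer and feeds the same input $x$ to both the original and the quantized weights, which may differ from how AIMET samples layer inputs.

```latex
\hat{s}_\ell = \operatorname*{arg\,min}_{s \in \mathcal{S}_\ell} \sum_{x \in \mathcal{D}} \left\lVert W_\ell\, x - Q_s(W_\ell)\, x \right\rVert_2^2
```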
Since SeqMSE is search-based rather than learning-based, it possesses several advantages:
- It requires only a small amount of calibration data.
- It approximates the global minimum without getting trapped in local minima.
- It is robust to overfitting.
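To make the search concrete, here is a toy, pure-Python sketch of the per-layer step for a single linear layer. It assumes symmetric per-output-channel weight quantization, candidate clipping values spaced evenly up to each channel's max |w| (their count plays the role of num_candidates), and an MSE loss on the layer output. The function names are hypothetical and this is not AIMET's implementation; the full method repeats this search layer by layer.

```python
def quantize_dequantize(values, max_val, bitwidth):
    # Signed symmetric grid: integers in [-2^(b-1), 2^(b-1) - 1], step = max_val / (2^(b-1) - 1)
    qmax = 2 ** (bitwidth - 1) - 1
    qmin = -(2 ** (bitwidth - 1))
    scale = max_val / qmax
    return [min(max(round(v / scale), qmin), qmax) * scale for v in values]


def channel_output_mse(weight_row, quant_row, inputs):
    # Mean squared error between y = w . x and y_q = w_q . x over the calibration inputs
    total = 0.0
    for x in inputs:
        y = sum(w * xi for w, xi in zip(weight_row, x))
        y_q = sum(w * xi for w, xi in zip(quant_row, x))
        total += (y - y_q) ** 2
    return total / len(inputs)


def search_channel_encodings(weight, inputs, bitwidth=4, num_candidates=20):
    """Grid-search one clipping value (max_val) per output channel of a linear layer."""
    best_max_vals = []
    for row in weight:
        abs_max = max(abs(w) for w in row)
        if abs_max == 0.0:
            best_max_vals.append(0.0)  # all-zero channel: nothing to quantize
            continue
        best_val, best_loss = abs_max, float("inf")
        for i in range(1, num_candidates + 1):
            # Candidate i clips the channel at (i / num_candidates) of its max |w|
            candidate = abs_max * i / num_candidates
            quant_row = quantize_dequantize(row, candidate, bitwidth)
            loss = channel_output_mse(row, quant_row, inputs)
            if loss < best_loss:
                best_val, best_loss = candidate, loss
        best_max_vals.append(best_val)
    return best_max_vals


# Example: a 2x4 weight matrix and three 4-dimensional calibration inputs (made-up values)
weight = [[0.9, -0.05, 0.02, 0.01], [0.3, 0.28, -0.31, 0.1]]
inputs = [[1.0, 0.5, -0.2, 0.3], [0.2, -1.0, 0.4, 0.0], [0.5, 0.5, 0.5, 0.5]]
print(search_channel_encodings(weight, inputs))
```

Each output channel of a linear layer depends only on its own weight row, so the channels can be searched independently. The full-range candidate is always part of the grid, so in this sketch the chosen clipping never does worse on the calibration data than plain max-based scaling.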
Workflow
Prerequisites
To use SeqMSE, you must:
- Load a pre-trained model.
- Create a training or validation data loader for the model.
Code example
Setup
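import torch
from torch.utils.data import DataLoader
from datasets import load_dataset
from evaluate import evaluator
from torchvision.models import mobilenet_v2, MobileNet_V2_Weights
# General setup that can be changed as needed
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# Load the pre-trained FP32 model and the preprocessing that matches its weights
weights = MobileNet_V2_Weights.DEFAULT
model = mobilenet_v2(weights=weights).eval().to(device)
preprocess = weights.transforms()
batch_size = 32
num_batches = 32  # number of calibration batches used by SeqMSE and compute_encodings
data = load_dataset('imagenet-1k', streaming=True, split="train")

def collate(samples):
    # Turn each PIL image into a normalized tensor and batch the labels -> (images, labels)
    images = torch.stack([preprocess(s["image"].convert("RGB")) for s in samples])
    labels = torch.tensor([s["label"] for s in samples])
    return images, labels

data_loader = DataLoader(data, batch_size=batch_size, num_workers=4, collate_fn=collate)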
Step 1
Create a QuantizationSimModel object, which simulates quantization using AIMET’s QuantSim.
from aimet_common.defs import QuantScheme
from aimet_common.quantsim_config.utils import get_path_for_per_channel_config
from aimet_torch.quantsim import QuantizationSimModel
dummy_input = torch.randn(1, 3, 224, 224).to(device)
sim = QuantizationSimModel(model,
dummy_input=dummy_input,
quant_scheme=QuantScheme.training_range_learning_with_tf_init,
default_param_bw=4,
default_output_bw=8,
config_file=get_path_for_per_channel_config())
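Here config_file points QuantSim at a per-channel quantization config, and default_param_bw=4 leaves each weight channel only 16 integer levels. The minimal pure-Python sketch below shows why per-channel scales matter at 4 bits: with one shared (per-tensor) scale, a low-magnitude output channel can round entirely to zero, while its own scale preserves it. The helper names and the symmetric convention scale = max|w| / (2^(b-1) - 1) are assumptions for illustration, not AIMET internals.

```python
def fake_quantize(values, max_abs, bitwidth):
    # Signed symmetric quantize-dequantize with scale = max_abs / (2^(b-1) - 1) (assumed convention)
    qmax = 2 ** (bitwidth - 1) - 1
    scale = max_abs / qmax
    return [min(max(round(v / scale), -qmax - 1), qmax) * scale for v in values]


def max_error(row, qrow):
    # Largest absolute difference between original and dequantized weights
    return max(abs(a - b) for a, b in zip(row, qrow))


weight = [[0.80, -0.60, 0.40],   # large-magnitude output channel
          [0.03, -0.02, 0.01]]   # small-magnitude output channel
bw = 4  # matches default_param_bw=4

# Per-tensor: a single scale shared by every channel
tensor_max = max(abs(w) for row in weight for w in row)
per_tensor = [fake_quantize(row, tensor_max, bw) for row in weight]

# Per-channel: each output channel gets its own scale
per_channel = [fake_quantize(row, max(abs(w) for w in row), bw) for row in weight]

print("small channel, per-tensor :", per_tensor[1], max_error(weight[1], per_tensor[1]))
print("small channel, per-channel:", per_channel[1], max_error(weight[1], per_channel[1]))
```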
Step 2
Apply SeqMSE to determine optimal quantization encodings for the parameters of supported layers/operations.
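# Find and freeze the optimal encoding candidates for the parameters of supported layers/operations.
from aimet_torch.seq_mse import apply_seq_mse, SeqMseParams
def seq_mse_forward_fn(model, inputs):
    # data_loader yields (images, labels) batches; feed only the images to the model
    images, _ = inputs
    model(images.to(device))
params = SeqMseParams(num_batches=num_batches,
                      num_candidates=20,
                      inp_symmetry='symqt',
                      loss_fn='mse',
                      forward_fn=seq_mse_forward_fn)
apply_seq_mse(model=model, sim=sim, data_loader=data_loader, params=params)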
Step 3
Compute encodings for all activations and for the remaining parameters of uninitialized layers/operations.
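import itertools

def forward_pass(model: torch.nn.Module):
    # Calibration callback: feed a limited number of batches through the QuantSim model
    with torch.no_grad():
        for images, _ in itertools.islice(data_loader, num_batches):
            model(images.to(device))
# End of calibration callback
# Compute the quantization encodings for all activations and for the parameters of
# uninitialized layers/operations (the parameter encodings found by SeqMSE stay frozen)
sim.compute_encodings(forward_pass)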
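In this step compute_encodings runs forward_pass so that each remaining quantizer can observe the range of its tensor. Because the scheme here is QuantScheme.training_range_learning_with_tf_init, those initial encodings come from min/max ("TF") statistics. The sketch below covers one 8-bit asymmetric activation quantizer; minmax_encoding is a hypothetical helper, and the zero-widening and rounding details are a common convention assumed here that may not match AIMET exactly.

```python
def minmax_encoding(observed_batches, bitwidth=8):
    """Asymmetric min/max encoding for one activation quantizer (simplified sketch)."""
    lo = min(min(batch) for batch in observed_batches)
    hi = max(max(batch) for batch in observed_batches)
    # Widen the range to include 0.0 so that zero (e.g. ReLU output, padding) is exactly representable
    lo, hi = min(lo, 0.0), max(hi, 0.0)
    num_steps = 2 ** bitwidth - 1
    scale = max((hi - lo) / num_steps, 1e-12)  # guard against an all-zero range
    offset = round(lo / scale)  # integer offset, so 0.0 lands exactly on the grid
    return {"scale": scale, "offset": offset,
            "min": offset * scale, "max": (offset + num_steps) * scale}


# Activations observed for one layer over two calibration batches (made-up values)
relu_out = [[0.1, 2.0, 0.5], [0.0, 3.1, 1.2]]
print(minmax_encoding(relu_out))
```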
Step 4
Evaluate the quantized model using ImageClassificationEvaluator.
# Determine the simulated quantized accuracy: evaluate sim.model (the QuantSim model), not the FP32 model
task_evaluator = evaluator("image-classification")
accuracy = task_evaluator.compute(model_or_pipeline=sim.model, data=data, metric="accuracy")
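The Hugging Face image-classification evaluator is built around transformers pipelines, so it may not accept a plain torchvision module such as sim.model directly. In that case, a minimal top-1 accuracy loop works with any model that maps an image batch to class logits. This is a sketch: top1_accuracy is a hypothetical helper, and it assumes a loader that yields (images, labels) batches, as the calibration loop in Step 3 does.

```python
import itertools

import torch


def top1_accuracy(model, data_loader, device="cpu", max_batches=None):
    """Top-1 accuracy of `model` over (images, labels) batches (hypothetical helper)."""
    correct, total = 0, 0
    batches = itertools.islice(data_loader, max_batches)  # max_batches=None means no limit
    model.eval()
    with torch.no_grad():
        for images, labels in batches:
            logits = model(images.to(device))
            preds = logits.argmax(dim=1).cpu()  # predicted class index per sample
            correct += (preds == labels).sum().item()
            total += labels.numel()
    return correct / max(total, 1)
```

For example, top1_accuracy(sim.model, data_loader, device, max_batches=num_batches) measures accuracy on a bounded number of batches; pointing it at a held-out split rather than the calibration data gives a fairer estimate.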
Step 5
If the resulting quantized accuracy is satisfactory, export the model.
# Export the model for on-target inference.
# Saves the PyTorch model without any simulation nodes, plus an encodings file in JSON format
# for both activations and parameters, at the provided path.
path = './'
filename = 'mobilenet'
sim.export(path=path, filename_prefix="quantized_" + filename, dummy_input=dummy_input.cpu())
API
Top-level APIs
- aimet_torch.seq_mse.apply_seq_mse(model, sim, data_loader, params, modules_to_exclude=None, checkpoints_config=None)
Sequentially minimizes the activation MSE loss, layer by layer, to determine optimal parameter quantization encodings:
1) Disable all input/output quantizers, and the parameter quantizers of non-supported modules.
2) Find and freeze the optimal parameter encoding candidates for the remaining supported modules.
3) Re-enable the quantizers disabled in step 1.
Example user flow:
model = Model().eval()
sim = QuantizationSimModel(…)
apply_seq_mse(…)
sim.compute_encodings(…)  # compute encodings for all activations and parameters of non-supported modules
sim.export(…)
NOTE:
1) Module references passed to modules_to_exclude should come from the FP32 model.
2) Modules in modules_to_exclude are not quantized and are skipped when applying Sequential MSE.
3) Apart from finding parameter encodings for supported modules, the config JSON file is respected and the final state of sim is unchanged.
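The three steps can be pictured with the following control-flow sketch. ToyQuantizer, ToyModule, and toy_apply_seq_mse are hypothetical stand-ins, not AIMET classes; the sketch only shows which quantizers are toggled, and why re-enabling exactly the ones disabled in step 1 leaves the enabled/disabled state as it was (note 3).

```python
from dataclasses import dataclass, field
from typing import Callable, List, Optional


@dataclass
class ToyQuantizer:
    # Hypothetical stand-in for a QuantSim quantizer
    enabled: bool = True
    encoding: Optional[float] = None
    frozen: bool = False


@dataclass
class ToyModule:
    name: str
    supported: bool  # whether SeqMSE searches this module's parameter encodings
    param_quantizer: ToyQuantizer = field(default_factory=ToyQuantizer)
    io_quantizers: List[ToyQuantizer] = field(default_factory=lambda: [ToyQuantizer(), ToyQuantizer()])


def toy_apply_seq_mse(modules: List[ToyModule], search: Callable[[ToyModule], float]) -> None:
    # Step 1: disable input/output quantizers and param quantizers of non-supported modules
    disabled = []
    for m in modules:
        targets = list(m.io_quantizers) + ([] if m.supported else [m.param_quantizer])
        for q in targets:
            if q.enabled:  # remember only the ones this function actually turns off
                q.enabled = False
                disabled.append(q)
    # Step 2: search and freeze a parameter encoding for each supported module, in order
    for m in modules:
        if m.supported:
            m.param_quantizer.encoding = search(m)
            m.param_quantizer.frozen = True
    # Step 3: re-enable exactly the quantizers disabled in step 1
    for q in disabled:
        q.enabled = True
```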
- Parameters:
  - model (Module) – Original FP32 model.
  - sim (QuantizationSimModel) – Corresponding QuantizationSimModel object.
  - data_loader (DataLoader) – Data loader.
  - params (SeqMseParams) – Sequential MSE parameters.
  - modules_to_exclude (Optional[List[Module]]) – List of supported-type modules to exclude when applying Sequential MSE.
  - checkpoints_config (Optional[str]) – Config file used to split the FP32 and quantized models at checkpoints to speed up activation sampling.
Sequential MSE parameters
- class aimet_torch.seq_mse.SeqMseParams(num_batches, num_candidates=20, inp_symmetry='symqt', loss_fn='mse', forward_fn=<function default_forward_fn>)
Sequential MSE parameters
- Parameters:
  - num_batches (int) – Number of batches.
  - num_candidates (int) – Number of candidates to evaluate in the grid search. Default 20.
  - inp_symmetry (str) – Input symmetry. Available options are ‘asym’, ‘symfp’ and ‘symqt’. Default ‘symqt’.
  - loss_fn (str) – Loss function. Available options are ‘mse’, ‘l1’ and ‘sqnr’. Default ‘mse’.
  - forward_fn (Callable) – Optional adapter function that performs a forward pass given a model and the inputs yielded by the data loader. The function takes the model as its first argument and the model inputs as its second argument.
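The loss_fn options name standard error measures between the FP32 layer output and the quantized-weight output. The sketch below gives textbook definitions only. The exact reduction (mean or sum, per channel or per tensor) and how AIMET turns SQNR into a quantity to minimize are assumptions here: this version negates SQNR in decibels so that every function is "lower is better".

```python
import math


def mse(ref, out):
    # Mean squared error
    return sum((r - o) ** 2 for r, o in zip(ref, out)) / len(ref)


def l1(ref, out):
    # Mean absolute error
    return sum(abs(r - o) for r, o in zip(ref, out)) / len(ref)


def neg_sqnr_db(ref, out, eps=1e-12):
    # SQNR = 10 * log10(signal power / noise power); negated so that lower is better
    signal = sum(r ** 2 for r in ref)
    noise = sum((r - o) ** 2 for r, o in zip(ref, out))
    return -10.0 * math.log10((signal + eps) / (noise + eps))
```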
- forward_fn(model, inputs)
Default forward function.
  - Parameters:
    - model – PyTorch model.
    - inputs – Model inputs.
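Because forward_fn receives whatever the data loader yields, a small adapter can bridge common batch layouts. The following is a sketch: make_forward_fn and the "image" key are hypothetical choices, and the sketch relies only on what the parameter description states (the model is the first argument and the loader's batch is the second).

```python
from collections.abc import Mapping
from typing import Any, Callable


def make_forward_fn(input_key: str = "image") -> Callable[[Any, Any], Any]:
    """Build a forward_fn that accepts several batch layouts (hypothetical helper)."""

    def forward_fn(model, inputs):
        if isinstance(inputs, Mapping):
            # e.g. {"image": images, "label": labels}: feed only the image entry
            return model(inputs[input_key])
        if isinstance(inputs, (list, tuple)):
            # e.g. (images, labels): feed only the first element
            return model(inputs[0])
        # A bare tensor (or any other object) is passed through unchanged
        return model(inputs)

    return forward_fn
```

For example, SeqMseParams(num_batches=32, forward_fn=make_forward_fn()) would let the same parameters object work with loaders that yield tuples or dictionaries.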