Quickstart Guide
In this tutorial, we walk through the end-to-end process of using AIMET and PyTorch to create, calibrate, and export a simple quantized model. Note that this is intended to show the most basic AIMET workflow, not the most advanced techniques AIMET offers. The steps are summarized below, followed by a short code sketch previewing the AIMET calls used throughout the guide.
Overall flow
Define the basic floating-point PyTorch model, training, and eval loops
Prepare the trained model for quantization
Create quantization simulation (quantsim) model in AIMET to simulate the effects of quantization
Calibrate the quantsim model on training data and evaluate the quantized accuracy
Fine-tune the quantized model to improve the quantized accuracy
Export the quantized model
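As a preview, here is roughly what these steps look like as AIMET calls. This is only a sketch: model, calibration_loader, and dummy_input are placeholders, and every call is explained in detail in the sections below.
# Sketch of the end-to-end flow (placeholders: model, calibration_loader, dummy_input)
import aimet_torch.v2 as aimet
from aimet_torch import model_preparer, batch_norm_fold
from aimet_torch.v2 import quantsim

prepared_model = model_preparer.prepare_model(model)                  # 1) prepare the model
batch_norm_fold.fold_all_batch_norms(prepared_model,
                                     input_shapes=dummy_input.shape)  # 2) fold batchnorm layers
sim = quantsim.QuantizationSimModel(prepared_model,                   # 3) create the quantsim model
                                    dummy_input=dummy_input,
                                    default_output_bw=8,
                                    default_param_bw=4)
with aimet.nn.compute_encodings(sim.model):                           # 4) calibrate the quantizers
    for x, _ in calibration_loader:
        sim.model(x)
# ... optionally fine-tune sim.model (QAT), then:
sim.export("/tmp/", "my_model", dummy_input=dummy_input)              # 5) export model + encodings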
PyTorch prerequisites
To see clearly what happens inside AIMET, let’s first start with some simple PyTorch code for defining, training, and evaluating a model. The code below is adapted from PyTorch’s basic optimization tutorial. Note that AIMET does not have any special requirement on what these training/eval loops look like.
import torch
import torchvision
import torch.nn.functional as F
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# 1) Start with some data loaders to train, evaluate, and calibrate the model
fashion_mnist_train_data = torchvision.datasets.FashionMNIST('/tmp/fashion_mnist', train=True, download=True, transform=torchvision.transforms.ToTensor())
fashion_mnist_test_data = torchvision.datasets.FashionMNIST('/tmp/fashion_mnist', train=False, download=True, transform=torchvision.transforms.ToTensor())
train_loader = torch.utils.data.DataLoader(fashion_mnist_train_data, batch_size=128, shuffle=True)
test_loader = torch.utils.data.DataLoader(fashion_mnist_test_data, batch_size=128, shuffle=False)
# 2) Define a simple model to train on this dataset
class Network(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(in_channels=1, out_channels=128, kernel_size=3, padding=1, stride=2)
        self.bn_1 = torch.nn.BatchNorm2d(128)
        self.conv2 = torch.nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1, stride=2)
        self.bn_2 = torch.nn.BatchNorm2d(256)
        self.linear = torch.nn.Linear(in_features=7*7*256, out_features=10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(self.bn_1(x))
        x = self.conv2(x)
        x = F.relu(self.bn_2(x))
        x = self.linear(x.view(x.shape[0], -1))
        return F.softmax(x, dim=-1)
# 3) Define an evaluation loop for the model
def evaluate(model, data_loader):
    model.eval()
    correct = total = 0
    with torch.no_grad():  # gradients are not needed for evaluation
        for x, y in data_loader:
            x, y = x.to(device), y.to(device)
            output = model(x)
            correct += (torch.argmax(output, dim=1) == y).sum()
            total += x.shape[0]
    accuracy = correct / total * 100.
    return accuracy
Now, let’s instantiate the network and train it for a few epochs on our dataset to establish a baseline floating-point model.
# Create a model
model = Network()
# Send the model to the desired device (optional)
model.to(device)
# Define some loss function and optimizer
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
# Train for 4 epochs
model.train()
for epoch in range(4):
    for batch_idx, (x, y) in enumerate(train_loader):
        x, y = x.to(device), y.to(device)
        output = model(x)
        loss = loss_fn(output, y)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
# Evaluate the floating-point model
model.eval()
fp_accuracy = evaluate(model, test_loader)
print(f"Floating point accuracy: {fp_accuracy}")
Floating point accuracy: 91.70999908447266
Prepare the floating point model for quantization
Before we can (accurately) simulate quantization, there are a couple of important steps to take care of:
1) Model preparation
AIMET’s quantization simulation tool (QuantizationSimModel) expects the floating point model to conform to some specific guidelines. For example, QuantizationSimModel is only able to quantize math operations performed by torch.nn.Module objects, whereas torch.nn.functional calls will be (incorrectly) ignored.
If we look back at our previous model definition, we see it calls F.relu() and F.softmax() in the forward function. Does this mean we need to completely redefine our model to use AIMET? Thankfully, no. AIMET provides the model_preparer API to transform our incompatible model into a new, fully compatible model.
from aimet_torch import model_preparer
prepared_model = model_preparer.prepare_model(model)
print(prepared_model)
# Note: This transformation should not change the model's forward function at all
fp_accuracy_prepared = evaluate(prepared_model, test_loader)
assert fp_accuracy_prepared == fp_accuracy
2024-05-07 14:39:22,747 - root - INFO - AIMET
2024-05-07 14:39:22,806 - ModelPreparer - INFO - Functional : Adding new module for node: {module_relu}
2024-05-07 14:39:22,806 - ModelPreparer - INFO - Functional : Adding new module for node: {module_relu_1}
2024-05-07 14:39:22,806 - ModelPreparer - INFO - Functional : Adding new module for node: {module_softmax}
GraphModule(
  (conv1): Conv2d(1, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (bn_1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (bn_2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (linear): Linear(in_features=12544, out_features=10, bias=True)
  (module_relu): ReLU()
  (module_relu_1): ReLU()
  (module_softmax): Softmax(dim=-1)
)
def forward(self, x):
    conv1 = self.conv1(x); x = None
    bn_1 = self.bn_1(conv1); conv1 = None
    module_relu = self.module_relu(bn_1); bn_1 = None
    conv2 = self.conv2(module_relu); module_relu = None
    bn_2 = self.bn_2(conv2); conv2 = None
    module_relu_1 = self.module_relu_1(bn_2); bn_2 = None
    getattr_1 = module_relu_1.shape
    getitem = getattr_1[0]; getattr_1 = None
    view = module_relu_1.view(getitem, -1); module_relu_1 = getitem = None
    linear = self.linear(view); view = None
    module_softmax = self.module_softmax(linear); linear = None
    return module_softmax
# To see more debug info, please use `graph_module.print_readable()`
Note how the prepared model now contains distinct modules for the relu() and softmax() operations.
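As a quick sanity check, you can confirm that the functional calls really were converted into modules by looking them up on the prepared model (the names module_relu and module_softmax come from the ModelPreparer log above):
# The former F.relu()/F.softmax() calls are now standalone nn.Modules,
# which QuantizationSimModel knows how to quantize.
assert isinstance(prepared_model.module_relu, torch.nn.ReLU)
assert isinstance(prepared_model.module_softmax, torch.nn.Softmax)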
2) BatchNorm fold
When models are executed in a quantized runtime, batchnorm layers are typically folded into the weight and bias of an adjacent convolution layer whenever possible in order to remove unnecessary computations. To accurately simulate inference in these runtimes, it is generally a good idea to perform this batchnorm folding on the floating point model before applying quantization. AIMET provides the batch_norm_fold tool to do this.
from aimet_torch import batch_norm_fold
sample_input, _ = next(iter(train_loader))
batch_norm_fold.fold_all_batch_norms(prepared_model, input_shapes=sample_input.shape)
print(prepared_model)
GraphModule(
  (conv1): Conv2d(1, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (bn_1): Identity()
  (conv2): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (bn_2): Identity()
  (linear): Linear(in_features=12544, out_features=10, bias=True)
  (module_relu): ReLU()
  (module_relu_1): ReLU()
  (module_softmax): Softmax(dim=-1)
)
def forward(self, x):
    conv1 = self.conv1(x); x = None
    bn_1 = self.bn_1(conv1); conv1 = None
    module_relu = self.module_relu(bn_1); bn_1 = None
    conv2 = self.conv2(module_relu); module_relu = None
    bn_2 = self.bn_2(conv2); conv2 = None
    module_relu_1 = self.module_relu_1(bn_2); bn_2 = None
    getattr_1 = module_relu_1.shape
    getitem = getattr_1[0]; getattr_1 = None
    view = module_relu_1.view(getitem, -1); module_relu_1 = getitem = None
    linear = self.linear(view); view = None
    module_softmax = self.module_softmax(linear); linear = None
    return module_softmax
# To see more debug info, please use `graph_module.print_readable()`
Note that the model now has Identity (passthrough) layers where it previously had BatchNorm2d layers. Like the model_preparer step, this operation should not impact the model’s accuracy.
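For intuition on why folding is safe: in eval mode, a BatchNorm2d layer applies a fixed per-channel affine transform, which can be absorbed into the preceding convolution’s weight and bias. The standalone sketch below illustrates the arithmetic on a toy conv + batchnorm pair (for illustration only; AIMET’s batch_norm_fold performs this for you and handles the general case):
# Illustration: fold a BatchNorm2d into the preceding Conv2d by hand
conv = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1)
bn = torch.nn.BatchNorm2d(8)
# Give the batchnorm some non-trivial statistics and affine parameters
bn.running_mean.uniform_(-1., 1.)
bn.running_var.uniform_(0.5, 1.5)
bn.weight.data.uniform_(0.5, 1.5)
bn.bias.data.uniform_(-1., 1.)
bn.eval()

# scale = gamma / sqrt(var + eps); W' = W * scale (per channel); b' = (b - mean) * scale + beta
scale = (bn.weight / torch.sqrt(bn.running_var + bn.eps)).detach()
folded = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1)
folded.weight.data = conv.weight.data * scale.view(-1, 1, 1, 1)
folded.bias.data = (conv.bias.data - bn.running_mean) * scale + bn.bias.data

x = torch.randn(1, 3, 16, 16)
assert torch.allclose(bn(conv(x)), folded(x), atol=1e-5)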
Quantize the model
Now, we are ready to use AIMET’s QuantizationSimModel to simulate quantizing the floating point model. This involves two steps:
Add quantizers to simulate quantization noise during the model’s forward pass
Calibrate the quantizer encodings (e.g., min/max ranges) on some sample inputs
Calibration is necessary to determine the range of values each activation quantizer is likely to encounter in the model’s forward pass, and should therefore be able to represent. Theoretically, we could pass the entire training dataset through the model for calibration, but in practice we usually only need about 500-1000 representative samples to accurately estimate the ranges.
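For example, one convenient pattern is a small helper that feeds a fixed number of batches through the model; pass_calibration_data below is just an illustrative name, not an AIMET API, and the calibration loop that follows does the same thing inline with a manual break.
import itertools

def pass_calibration_data(model, data_loader, num_batches=8):
    # With a batch size of 128, 8 batches is roughly 1,000 samples,
    # enough to estimate activation ranges without a full epoch.
    with torch.no_grad():
        for x, _ in itertools.islice(data_loader, num_batches):
            model(x.to(device))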
import aimet_torch.v2 as aimet
from aimet_torch.v2 import quantsim
# QuantizationSimModel will convert each nn.Module in prepared_model into a quantized equivalent module and configure the module's quantizers
# In this case, we will quantize all parameters to 4 bits and all activations to 8 bits.
sim = quantsim.QuantizationSimModel(prepared_model,
                                    dummy_input=sample_input.to(device),
                                    default_output_bw=8,  # Simulate 8-bit activations
                                    default_param_bw=4)   # Simulate 4-bit weights
# Inside the compute_encodings context, quantizers will observe the statistics of the activations passing through them. These statistics will be used
# to compute properly calibrated encodings upon exiting the context.
with aimet.nn.compute_encodings(sim.model):
    for idx, (x, _) in enumerate(train_loader):
        x = x.to(device)
        sim.model(x)
        if idx >= 10:
            break
# Compare the accuracy before and after quantization:
quantized_accuracy = evaluate(sim.model, test_loader)
print(sim.model)
print(f"Floating point model accuracy: {fp_accuracy} %\n"
f"Quantized model accuracy: {quantized_accuracy} %")
GraphModule(
  (conv1): QuantizedConv2d(
    1, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)
    (param_quantizers): ModuleDict(
      (weight): QuantizeDequantize(shape=[1], bitwidth=4, symmetric=True)
      (bias): None
    )
    (input_quantizers): ModuleList(
      (0): QuantizeDequantize(shape=[1], bitwidth=8, symmetric=False)
    )
    (output_quantizers): ModuleList(
      (0): None
    )
  )
  (bn_1): Identity()
  (module_relu): QuantizedReLU(
    (param_quantizers): ModuleDict()
    (input_quantizers): ModuleList(
      (0): None
    )
    (output_quantizers): ModuleList(
      (0): QuantizeDequantize(shape=[1], bitwidth=8, symmetric=False)
    )
  )
  (conv2): QuantizedConv2d(
    128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)
    (param_quantizers): ModuleDict(
      (weight): QuantizeDequantize(shape=[1], bitwidth=4, symmetric=True)
      (bias): None
    )
    (input_quantizers): ModuleList(
      (0): None
    )
    (output_quantizers): ModuleList(
      (0): None
    )
  )
  (bn_2): Identity()
  (module_relu_1): QuantizedReLU(
    (param_quantizers): ModuleDict()
    (input_quantizers): ModuleList(
      (0): None
    )
    (output_quantizers): ModuleList(
      (0): QuantizeDequantize(shape=[1], bitwidth=8, symmetric=False)
    )
  )
  (linear): QuantizedLinear(
    in_features=12544, out_features=10, bias=True
    (param_quantizers): ModuleDict(
      (weight): QuantizeDequantize(shape=[1], bitwidth=4, symmetric=True)
      (bias): None
    )
    (input_quantizers): ModuleList(
      (0): None
    )
    (output_quantizers): ModuleList(
      (0): QuantizeDequantize(shape=[1], bitwidth=8, symmetric=False)
    )
  )
  (module_softmax): QuantizedSoftmax(
    dim=-1
    (param_quantizers): ModuleDict()
    (input_quantizers): ModuleList(
      (0): None
    )
    (output_quantizers): ModuleList(
      (0): QuantizeDequantize(shape=[1], bitwidth=8, symmetric=False)
    )
  )
)
def forward(self, x):
    conv1 = self.conv1(x); x = None
    bn_1 = self.bn_1(conv1); conv1 = None
    module_relu = self.module_relu(bn_1); bn_1 = None
    conv2 = self.conv2(module_relu); module_relu = None
    bn_2 = self.bn_2(conv2); conv2 = None
    module_relu_1 = self.module_relu_1(bn_2); bn_2 = None
    getattr_1 = module_relu_1.shape
    getitem = getattr_1[0]; getattr_1 = None
    view = module_relu_1.view(getitem, -1); module_relu_1 = getitem = None
    linear = self.linear(view); view = None
    module_softmax = self.module_softmax(linear); linear = None
    return module_softmax
# To see more debug info, please use `graph_module.print_readable()`
Floating point model accuracy: 91.70999908447266 %
Quantized model accuracy: 91.1500015258789 %
Here, we can see that sim.model is nothing more than the prepared_model with every layer replaced with a quantized version of the layer. The quantization behavior of each module is determined by the configuration of its held quantizers.
For example, we can see that sim.model.linear has a 4-bit weight quantizer and an 8-bit output quantizer, as specified during construction. We will discuss more advanced ways to configure these quantizers to optimize performance and accuracy in a later tutorial.
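Because each quantized module exposes its quantizers as ordinary submodules, you can inspect (or override) their settings directly. For example, assuming the attribute names shown in the printout above (bitwidth, symmetric):
# Inspect the 4-bit weight quantizer on the linear layer
weight_qtzr = sim.model.linear.param_quantizers["weight"]
print(weight_qtzr.bitwidth, weight_qtzr.symmetric)    # 4, True

# Inspect the 8-bit output quantizer attached to the softmax layer
output_qtzr = sim.model.module_softmax.output_quantizers[0]
print(output_qtzr.bitwidth, output_qtzr.symmetric)    # 8, False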
Fine-tune the model with quantization aware training
If we’re not satisfied with our accuracy after applying quantization, there are some steps we can take to further optimize the quantized accuracy. One such step is quantization aware training (QAT), during which the model is trained with the fake-quantization ops present.
Let’s repeat our floating-point training loop for one more epoch, but this time use the quantized model.
# Define some loss function and optimizer
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(sim.model.parameters(), lr=1e-4)
# Train for one more epoch on the quantsim model
for epoch in range(1):
    for batch_idx, (x, y) in enumerate(train_loader):
        x, y = x.to(device), y.to(device)
        output = sim.model(x)
        loss = loss_fn(output, y)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
# Compare the accuracy before and after QAT:
post_QAT_accuracy = evaluate(sim.model, test_loader)
print(f"Original quantized model accuracy: {quantized_accuracy} %\n"
f"Post-QAT model accuracy: {post_QAT_accuracy} %")
Original quantized model accuracy: 91.1500015258789 %
Post-QAT model accuracy: 92.05333709716797 %
Export the quantsim model
Now that we are happy with our quantized model’s accuracy, we are ready to export the model with its quantization parameters.
export_path = "/tmp/"
model_name = "fashion_mnist_model"
sample_input, _ = next(iter(train_loader))
sim.export(export_path, model_name, dummy_input=sample_input)
This export method will save the model with quantization nodes removed, along with an encodings file containing quantization parameters for each activation and weight tensor in the model. These artifacts can then be sent to a quantized runtime such as Qualcomm® Neural Processing SDK.
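If you want to verify what was written, you can list the export directory and inspect the encodings file. The exact filenames vary by AIMET version, so the snippet below simply globs for files written under model_name (an illustrative check, not part of the AIMET API):
import glob, json, os

# List every artifact exported under /tmp/ for our model name
for path in sorted(glob.glob(os.path.join(export_path, model_name + "*"))):
    print(path)

# The .encodings file is JSON describing per-tensor quantization parameters
for enc_path in glob.glob(os.path.join(export_path, model_name + "*.encodings")):
    with open(enc_path) as f:
        encodings = json.load(f)
    print(enc_path, "->", list(encodings.keys()))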