Quantization-Aware Training (QAT)
This notebook provides a working example of Quantization-Aware Training (QAT) of a Vision Transformer (ViT) using AIMET. QAT improves the accuracy of quantized models by fine-tuning model weights and/or quantization parameters while simulating quantization effects during training.
Before Getting Started: Prepare the ImageNet Dataset
To run this notebook successfully, you need the ImageNet dataset downloaded and organized in a specific directory structure. The notebook expects the dataset path to be provided via the environment variable IMAGENET_DIR, structured as follows:
IMAGENET_DIR/
├── train/
│   ├── n01440764/
│   │   ├── image1.JPEG
│   │   ├── image2.JPEG
│   │   └── ...
│   ├── n01443537/
│   │   └── ...
│   └── ...
├── val/
│   ├── n01440764/
│   │   ├── image1.JPEG
│   │   └── ...
│   └── ...
└── test/
    └── ...
For more information, see torchvision.datasets.ImageFolder and torchvision.datasets.DatasetFolder.
[ ]:
import os
os.environ["IMAGENET_DIR"] = "/path/to/imagenet" # TODO: Overwrite this path with your local imagenet directory
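Optionally, you can sanity-check the dataset layout before building any data loaders. The cell below is a minimal sketch (not part of the original workflow) that assumes the train/val/test structure shown above and simply verifies that each split directory exists.
[ ]:
import os

# Minimal sanity check (sketch): confirm the expected split directories exist
imagenet_dir = os.environ["IMAGENET_DIR"]
for split in ("train", "val", "test"):
    split_dir = os.path.join(imagenet_dir, split)
    if not os.path.isdir(split_dir):
        raise FileNotFoundError(f"Expected directory not found: {split_dir}")
    print(f"{split}: {len(os.listdir(split_dir))} entries")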
1. Set Random Seeds and Instantiate ImageNet Data Loaders
To ensure reproducibility, we set random seeds for Python, NumPy, and PyTorch. We then define a function to load the ImageNet dataset with standard preprocessing steps.
[ ]:
import random
import numpy as np
import os
import torch
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import ImageFolder
from torchvision import transforms
# Set device to GPU if available
device = "cuda" if torch.cuda.is_available() else "cpu"
# Set random seeds for reproducibility
torch.manual_seed(1234)
torch.cuda.manual_seed(1234)
random.seed(1234)
np.random.seed(1234)
def imagenet_dataset(split: str) -> Dataset:
    # Load ImageNet directory from environment variable
    IMAGENET_DIR = os.getenv("IMAGENET_DIR")
    if not IMAGENET_DIR:
        raise RuntimeError(
            "Environment variable 'IMAGENET_DIR' has not been set. "
            "Please set this variable to the path where the ImageNet dataset is downloaded "
            "and organized in the following directory structure:\n\n"
            "<IMAGENET_DIR>\n"
            " ├── test\n"
            " ├── train\n"
            " └── val\n"
        )

    # Define preprocessing transformations
    transform = transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ]
    )
    return ImageFolder(root=os.path.join(IMAGENET_DIR, split), transform=transform)
# Create DataLoaders for training and testing
test_data_loader = DataLoader(imagenet_dataset("test"), batch_size=128, shuffle=False)
train_data_loader = DataLoader(imagenet_dataset("train"), batch_size=32, shuffle=True)
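As a quick, optional check of the preprocessing pipeline (an illustrative snippet, not part of the original notebook), you can pull one batch from the training loader; with the Resize/CenterCrop transforms above, each image should come out as a 3x224x224 tensor.
[ ]:
# Illustrative check: inspect one training batch produced by the pipeline above
images, labels = next(iter(train_data_loader))
print(images.shape)  # expected: torch.Size([32, 3, 224, 224])
print(labels.shape)  # expected: torch.Size([32])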
2. Create W4A8 QuantizationSimModel with Vision Transformer (ViT)
We load a pretrained ViT model and create AIMET's QuantizationSimModel to simulate 4-bit weights and 8-bit activations (W4A8), targeting Qualcomm's Hexagon NPU (HTP V81).
[ ]:
import aimet_torch
from torchvision.models import vit_b_16
model = vit_b_16(weights="IMAGENET1K_V1").to(device=device).eval()
dummy_input, _ = next(iter(train_data_loader))
dummy_input = dummy_input.to(device=device)
# Create QuantizationSimModel with W4A8 configuration
sim = aimet_torch.QuantizationSimModel(
    model,
    dummy_input,
    default_param_bw=4,  # 4-bit weights
    default_output_bw=8,  # 8-bit activations
    in_place=True,
    config_file="htp_v81",  # AIMET config for Hexagon NPU with HTP V81
)
# Compute quantization encodings using a few training batches
with torch.no_grad(), aimet_torch.nn.compute_encodings(sim.model):
    for i, (images, _) in enumerate(train_data_loader):
        if i == 8:
            break
        _ = sim.model(images.to(device=device))
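After calibration, it can be helpful to confirm that quantizers were actually inserted into the model. The snippet below is an illustrative check, not a required AIMET step; it reuses the AffineQuantizerBase import that also appears in the QAT cell later in this notebook and simply counts the quantizer modules.
[ ]:
from aimet_torch.quantization.affine import AffineQuantizerBase

# Illustrative check: count the affine quantizers inserted by QuantizationSimModel
n_quantizers = sum(
    1 for module in sim.model.modules() if isinstance(module, AffineQuantizerBase)
)
print(f"Affine quantizers in the simulated model: {n_quantizers}")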
3. Evaluate Initial Accuracy Before QAT
Before applying QAT, we evaluate the model's accuracy in two scenarios:
Floating-point accuracy, measured by temporarily removing all quantizers.
Fake-quantized accuracy, measured with the quantization simulation model.
This helps establish a baseline to compare the impact of QAT later.
[ ]:
@torch.no_grad()
def evaluate(model: torch.nn.Module, data_loader: DataLoader):
    from tqdm import tqdm

    top1 = top5 = 0.0
    n_images = 0

    pbar = tqdm(data_loader)
    for images, labels in pbar:
        images = images.to(device=device)
        labels = labels.unsqueeze(-1).to(device=device)
        logits = model(images)

        # Accumulate top-1/top-5 hit counts for this batch
        top1 += (logits.topk(1).indices == labels).sum()
        top5 += (logits.topk(5).indices == labels).sum()
        n_images += images.size(0)

        # Running accuracy shown in the progress bar
        top1_accuracy = top1 / n_images
        top5_accuracy = top5 / n_images
        pbar.set_description(
            f"Top-1: {top1_accuracy * 100:.2f}%, Top-5: {top5_accuracy * 100:.2f}%"
        )

    # Final accuracy over the full dataset
    top1_accuracy = top1 / n_images
    top5_accuracy = top5 / n_images
    return top1_accuracy, top5_accuracy
from aimet_torch.v2.utils import remove_all_quantizers
# Evaluate floating-point accuracy
with remove_all_quantizers(sim.model):
    top1, top5 = evaluate(sim.model, test_data_loader)
print("FP Accuracy:")
print(f" * Top-1: {top1 * 100:.2f}%")
print(f" * Top-5: {top5 * 100:.2f}%")
# Evaluate fake-quantized accuracy before QAT
top1, top5 = evaluate(sim.model, test_data_loader)
print("Fake-quantized Accuracy (before QAT):")
print(f" * Top-1: {top1 * 100:.2f}%")
print(f" * Top-5: {top5 * 100:.2f}%")
4. Run QAT and Evaluate Post-QAT Accuracy
In this notebook, we perform QAT by training only the quantization parameters (not the base model weights). The model is trained using the AdamW optimizer and evaluated after 2000 iterations.
[ ]:
def train(model: torch.nn.Module, data_loader: DataLoader, n_iter: int):
    from tqdm import tqdm
    from aimet_torch.quantization.affine import AffineQuantizerBase

    # Train only quantization parameters
    optimizer = torch.optim.AdamW(
        params={
            param
            for module in model.modules()
            for param in module.parameters()
            if isinstance(module, AffineQuantizerBase)
        },
        lr=0.001,
    )

    pbar = tqdm(data_loader, total=n_iter)
    for i, (images, labels) in enumerate(pbar):
        if i == n_iter:
            break

        optimizer.zero_grad()

        images = images.to(device=device)
        labels = labels.to(device=device)
        logits = model(images)
        loss = torch.nn.functional.cross_entropy(logits, labels)
        loss.backward()
        optimizer.step()

        pbar.set_description(f"loss: {loss:.2f}")
# Run QAT training
train(sim.model.train(), train_data_loader, n_iter=2000)
# Evaluate accuracy after QAT
top1, top5 = evaluate(sim.model.eval(), test_data_loader)
print("Fake-quantized Accuracy (after QAT):")
print(f" * Top-1: {top1 * 100:.2f}%")
print(f" * Top-5: {top5 * 100:.2f}%")
Conclusion
In this notebook, we demonstrated how to apply Quantization-Aware Training (QAT) to a pretrained Vision Transformer (ViT) model using AIMET. Starting from a floating-point baseline, we simulated W4A8 quantization and observed a severe drop in accuracy. By fine-tuning only the quantization parameters through QAT, we recovered most of the lost performance, achieving near-floating-point accuracy with the quantized model. This workflow highlights the effectiveness of QAT in preparing models for deployment on resource-constrained hardware such as Qualcomm's Hexagon NPU. With proper dataset preparation, quantization simulation, and targeted fine-tuning, high-performance deep learning models can be made both efficient and accurate for real-world applications. The results are summarized below:
| Model Type                        | Top-1 accuracy | Top-5 accuracy |
|-----------------------------------|----------------|----------------|
| Floating-point                    | 81.07%         | 95.32%         |
| Fake-quantized (W4A8, before QAT) | 8.16%          | 20.06%         |
| Fake-quantized (W4A8, after QAT)  | 79.21%         | 94.54%         |
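A typical next step after QAT is to export the simulated model together with its quantization encodings for on-target deployment. The cell below is only a sketch based on AIMET's QuantizationSimModel.export API; the exact argument names, output artifacts, and whether dummy_input must be on CPU can vary between AIMET releases, so consult the documentation for your version. The output directory and filename prefix are examples, not values from the original notebook.
[ ]:
import os

# Sketch only: export the QAT-finetuned model and its quantization encodings.
# The export signature may differ between AIMET releases; check your version's docs.
export_dir = "./exported_model"  # example output directory (not from the original notebook)
os.makedirs(export_dir, exist_ok=True)

sim.export(
    path=export_dir,
    filename_prefix="vit_b_16_w4a8",
    dummy_input=dummy_input.cpu(),  # many AIMET versions expect a CPU dummy input here
)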