SpinQuant¶
Note
This feature is currently experimental. The API may change in the future.
Context¶
SpinQuant (arXiv:2405.16406) is a post-training quantization technique that reduces activation outliers by inserting orthogonal Hadamard rotations at key points in the model. Because the rotations are absorbed into adjacent weight matrices, the final model architecture is unchanged.
AIMET implements R1 rotations (fixed Hadamard, no optimization). The R1 transform fuses RMSNorm scale weights into downstream linear layers, then applies a Hadamard rotation across the residual stream to reduce outliers in the Q/K/V/O and gate/up/down projections.
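To see why the architecture is unchanged, note that for any orthogonal matrix R a linear layer satisfies Wx = (WR)(Rᵀx): the rotation folded into the weight is exactly undone by the counter-rotated activation. The snippet below is an illustrative sketch of this identity using a normalized Hadamard matrix; it is not AIMET code, and scipy is assumed only for constructing the matrix.
import torch
from scipy.linalg import hadamard

hidden = 128  # the plain Hadamard construction requires a power-of-two size
# Normalized Hadamard matrix: orthogonal, entries +/- 1/sqrt(hidden)
R = torch.tensor(hadamard(hidden), dtype=torch.float32) / hidden ** 0.5

W = torch.randn(64, hidden)   # toy linear weight
x = torch.randn(hidden)       # toy residual-stream activation

W_rot = W @ R                 # rotation absorbed into the weight
x_rot = R.T @ x               # counter-rotated activation (typically fewer outliers)

# Outputs match up to float error, so float accuracy and the architecture are preserved
assert torch.allclose(W @ x, W_rot @ x_rot, atol=1e-4)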
Supported model families:
| Model family | PyTorch | ONNX |
|---|---|---|
| LLaMA | ✓ | ✓ |
| Qwen2 | ✓ | ✓ |
| Qwen3 | ✓ | ✓ |
| Phi3 | ✓ | ✓ |
| Qwen2.5-VL | ✓ | ✓ |
Note
Support for additional model families is added continuously as new architectures are validated. See the release notes for the latest additions.
Workflow¶
Prerequisites¶
To use SpinQuant, you need:
A pre-trained model loaded from HuggingFace.
ONNX only: the model must be exported to ONNX — Step 2 handles this.
Note
For a complete working example, see
Examples/torch/quantize.py
or
Examples/onnx/quantize.py
(run with --recipe pcq_spinquant).
Procedure¶
Step 1: Load model¶
Load the HuggingFace model and wrap it with ONNXExportableModuleWithCache to enable JIT tracing with a static graph. This step is identical for the PyTorch and ONNX workflows.
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from GenAILab.qai_hub_lm.utils.model_utils import ONNXExportableModuleWithCache
SEQUENCE_LENGTH = 2048
CONTEXT_LENGTH = 4096
model_id = "meta-llama/Llama-3.2-1B-Instruct"
hf_model = AutoModelForCausalLM.from_pretrained(model_id, dtype=torch.float32)
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True, trust_remote_code=True)
# Wrap model to satisfy static graph constraints for JIT trace
traceable_model = ONNXExportableModuleWithCache(hf_model)
Step 2: Create QuantizationSimModel¶
Create a QuantizationSimModel with the desired quantization configuration. For ONNX, this step also exports the model to ONNX.
PyTorch:
from aimet_torch.common.defs import QuantScheme
from aimet_torch import QuantizationSimModel
from GenAILab.qai_hub_lm.models.generator import Generator
assembled_dummy_inputs = Generator.prepare_inputs(
    model=traceable_model,
    input_ids=torch.zeros((1, SEQUENCE_LENGTH), dtype=torch.int),
    attention_mask=torch.ones((1, SEQUENCE_LENGTH), dtype=torch.int),
    past_key_values=[],
    context_length=CONTEXT_LENGTH,
    sequence_length=SEQUENCE_LENGTH,
)

quantsim = QuantizationSimModel(
    model=traceable_model,
    quant_scheme=QuantScheme.post_training_tf,
    dummy_input=assembled_dummy_inputs,
    default_output_bw=16,
    default_param_bw=8,
    in_place=True,
    config_file="htp_v73",
)
generator = Generator(quantsim.model, tokenizer, SEQUENCE_LENGTH, CONTEXT_LENGTH)
ONNX:
import os
import tempfile
import onnx
from aimet_onnx.quantsim import QuantizationSimModel
from GenAILab.qai_hub_lm.models.base import LLM
from GenAILab.qai_hub_lm.utils.layer_cache import build_layer_cache_descriptors
from GenAILab.qai_hub_lm.models.generator import Generator
from GenAILab.qai_hub_lm.backends.onnx.torch_onnx_interface import TorchONNXInterface
from GenAILab.qai_hub_lm.backends.onnx.quantsim_utils import (
    _set_tensors_to_output_n_bit_symmmetric,
    _tie_quantizers_for_kv_cache,
    _set_lm_head_precision,
)
from GenAILab.qai_hub_lm.precision import WeightPrecision
from aimet_onnx.common.defs import int8
assembled_dummy_inputs = Generator.prepare_inputs(
    model=traceable_model,
    input_ids=torch.zeros((1, SEQUENCE_LENGTH), dtype=torch.int),
    attention_mask=torch.ones((1, SEQUENCE_LENGTH), dtype=torch.int),
    past_key_values=[],
    context_length=CONTEXT_LENGTH,
    sequence_length=SEQUENCE_LENGTH,
)

with tempfile.TemporaryDirectory() as tmpdir:
    torch.onnx.export(
        traceable_model,
        assembled_dummy_inputs,
        os.path.join(tmpdir, "model.onnx"),
        input_names=LLM.get_backbone_input_names(build_layer_cache_descriptors(hf_model.config)),
        output_names=LLM.get_backbone_output_names(build_layer_cache_descriptors(hf_model.config)),
        opset_version=17,
        dynamo=False,
    )
    onnx_model = onnx.load(os.path.join(tmpdir, "model.onnx"))

quantsim = QuantizationSimModel(
    model=onnx_model,
    quant_scheme="min_max",
    default_activation_bw=16,
    default_param_bw=8,
    config_file="htp_v73",
    providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
)
_set_tensors_to_output_n_bit_symmmetric(quantsim, kv_bits=8)
_set_lm_head_precision(quantsim, WeightPrecision(qtype=int8, granularity="PCQ"))
_tie_quantizers_for_kv_cache(quantsim)
quantsim_with_torch_interface = TorchONNXInterface(quantsim, hf_model.config)
generator = Generator(quantsim_with_torch_interface, tokenizer, SEQUENCE_LENGTH, CONTEXT_LENGTH)
Step 3: Apply SpinQuant¶
Apply SpinQuant to the model. This fuses RMSNorm scale weights into downstream linear layers and applies the R1 Hadamard rotation to all weight matrices in-place.
Important
apply_spinquant must be called before compute_encodings. The rotation modifies float
weight initializers; compute_encodings must run afterward to calibrate quantizer scales on
the rotated weights.
PyTorch:
from aimet_torch.experimental.spinquant import apply_spinquant
# apply_spinquant modifies the model in-place. Must be called BEFORE compute_encodings.
apply_spinquant(model=quantsim.model)
ONNX:
from aimet_onnx.experimental.spinquant import apply_spinquant
# apply_spinquant modifies quantsim in-place. Must be called BEFORE compute_encodings.
apply_spinquant(quantsim)
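For intuition on the RMSNorm fusion performed in this step, the norm's per-channel scale can be folded into the next linear layer by scaling that layer's input columns, leaving the norm with a unit scale. The sketch below illustrates the algebra only; it is not how AIMET manipulates the model internally:
import torch

hidden = 16
gamma = torch.rand(hidden) + 0.5        # RMSNorm scale weights
W = torch.randn(32, hidden)             # downstream linear weight
x = torch.randn(4, hidden)              # toy activations

def rms_norm(x, scale, eps=1e-6):
    return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps) * scale

# Original: RMSNorm(scale=gamma) followed by the linear layer
y_ref = rms_norm(x, gamma) @ W.T

# Fused: gamma folded into the linear weight, norm scale replaced by ones
W_fused = W * gamma                     # scales each input column j by gamma[j]
y_fused = rms_norm(x, torch.ones(hidden)) @ W_fused.T

assert torch.allclose(y_ref, y_fused, atol=1e-4)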
Step 4: Compute activation encodings¶
Calibrate activation quantizers by running the model through a representative dataset.
PyTorch:
import itertools
from tqdm import tqdm
from GenAILab.bench.datasets import Wikitext
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
train_dataset = Wikitext.load_encoded_dataset(tokenizer, CONTEXT_LENGTH, "train")
def calibration_callback(model):
    for batch in tqdm(
        itertools.islice(train_dataset, 20), total=20, desc="Calibrating"
    ):
        generator(input_ids=batch["input_ids"].to(device=device))
quantsim.compute_encodings(calibration_callback)
ONNX:
from tqdm import tqdm
from GenAILab.bench.datasets import Wikitext
from GenAILab.bench.onnx.quant_recipes import _prefill_inputs
train_dataset = Wikitext.load_encoded_dataset(tokenizer, CONTEXT_LENGTH, "train")
calib_inputs = _prefill_inputs(quantsim, generator, train_dataset, num_batches=20)
def _forward(session, _):
    for batch in tqdm(calib_inputs, total=len(calib_inputs), desc="Calibrating"):
        session.run(None, batch)
quantsim.compute_encodings(_forward, tuple())
After completing these steps, export the quantized model:
PyTorch:
quantsim.export(...)
ONNX:
quantsim.export(...)
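A typical export call looks like the sketch below. The output directory and filename prefix are placeholders, and keyword names may differ slightly between AIMET releases, so check the QuantizationSimModel.export signature for your version:
# PyTorch: writes the exported model and encodings under the given path
quantsim.export(path="./output", filename_prefix="llama32_spinquant",
                dummy_input=assembled_dummy_inputs)

# ONNX: writes the quantized ONNX model and encodings under the given path
quantsim.export(path="./output", filename_prefix="llama32_spinquant")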
API¶
Top level APIs
- aimet_torch.experimental.spinquant.apply_spinquant(model)¶
Apply SpinQuant rotation transforms to a transformer-based language model.
SpinQuant applies orthogonal Hadamard rotations to model weights to reduce quantization error. This method modifies the model in-place by:
Fusing RMS normalization layers into subsequent linear layers
Applying R1 Hadamard rotations to embeddings, attention, and MLP layers
Merging all transforms into the weight matrices
- Supported architectures:
LLaMA
Qwen2, Qwen3
Phi3
Qwen2.5-VL (Vision-Language Model)
- Parameters:
model (Module) – A HuggingFace transformer model (e.g., LlamaForCausalLM, Qwen2ForCausalLM). The model must have untied embed_tokens and lm_head weights.
- Raises:
RuntimeError – If embed_tokens and lm_head weights are tied.
Example
>>> import torch
>>> from transformers import AutoModelForCausalLM
>>> from aimet_torch.experimental.spinquant import apply_spinquant
>>> model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
>>> # Untie embedding and lm_head weights if they are tied
>>> old_weight = model.lm_head.weight
>>> model.lm_head.weight = torch.nn.Parameter(
...     old_weight.data.clone().detach().to(old_weight.device),
...     requires_grad=old_weight.requires_grad,
... )
>>> apply_spinquant(model)
Top level APIs
- aimet_onnx.experimental.spinquant.apply_spinquant(backbone_sim, visual_sim=None, embedding=None)[source]¶
Apply SpinQuant rotation transforms to an ONNX transformer model.
SpinQuant applies orthogonal Hadamard rotations to model weights to reduce quantization error. This function modifies the QuantizationSimModel(s) in-place by:
Fusing RMS normalization layers into subsequent linear layers
Applying R1 Hadamard rotations to embeddings, attention, and MLP layers
For VLMs: applying R_L to the PatchMerger linear_fc2 layer (always applied when a visual model is provided)
Must be called BEFORE sim.compute_encodings(). The rotation modifies float weight initializers; compute_encodings must run afterward to calibrate quantizer scales on the rotated weights.
- Supported architectures:
LLaMA, Qwen2, Qwen3, Phi3 (backbone only)
Qwen2.5-VL, Qwen3-VL (backbone + visual)
- Parameters:
backbone_sim (QuantizationSimModel) – QuantizationSimModel wrapping backbone.onnx.
visual_sim (Optional[QuantizationSimModel]) – Optional QuantizationSimModel wrapping visual.onnx (VLM only).
embedding (Optional[Tensor]) – Optional torch.Tensor of shape [vocab, hidden] loaded from embedding.pth (VLM only). Rotated in-place with R_L.
- Raises:
ValueError – If block detection or role classification fails, or if any expected weight is missing or has the wrong shape.
- Return type:
None
Example (LLM):
sim = QuantizationSimModel(model)
apply_spinquant(sim)
sim.compute_encodings(calibration_data)
Example (VLM):
backbone_sim = QuantizationSimModel(backbone_model)
visual_sim = QuantizationSimModel(visual_model)
embedding = torch.load("embedding.pth")  # torch.Tensor [vocab, hidden]
apply_spinquant(backbone_sim, visual_sim=visual_sim, embedding=embedding)
torch.save(embedding, "embedding.pth")  # overwrite with rotated weights
backbone_sim.compute_encodings(backbone_calibration_data)
visual_sim.compute_encodings(visual_calibration_data)