SpinQuant

Note

This feature is currently experimental. The API may change in the future.

Context

SpinQuant (arXiv:2405.16406) is a post-training quantization technique that reduces activation outliers by inserting orthogonal Hadamard rotations at key points in the model. Because the rotations are absorbed into adjacent weight matrices, the final model architecture is unchanged.
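
As a toy illustration of why this works (a numpy sketch, not the AIMET implementation): multiplying activations by an orthonormal Hadamard matrix H and folding H.T into the next weight matrix leaves the layer output unchanged, while a single outlier channel is spread evenly across all channels.

```python
import numpy as np

def hadamard(n: int) -> np.ndarray:
    """Orthonormal Hadamard matrix via the Sylvester construction (n a power of 2)."""
    H = np.array([[1.0]])
    while H.shape[0] < n:
        H = np.block([[H, H], [H, -H]])
    return H / np.sqrt(n)  # rows/columns are orthonormal: H @ H.T == I

rng = np.random.default_rng(0)
d = 8
x = rng.normal(size=(4, d))
x[:, 3] *= 100.0                     # one activation channel with large outliers
W = rng.normal(size=(d, d))          # an adjacent weight matrix

H = hadamard(d)
y_ref = x @ W                        # original computation
y_rot = (x @ H) @ (H.T @ W)          # rotated activations, inverse folded into W

assert np.allclose(y_ref, y_rot)                 # model function is unchanged
assert np.abs(x @ H).max() < np.abs(x).max()     # peak activation magnitude drops
```

Because H.T is merged into the adjacent weight matrix, no extra operations remain in the exported graph.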

AIMET implements R1 rotations (fixed Hadamard, no optimization). The R1 pass fuses RMSNorm scale weights into the downstream linear layers, then applies a Hadamard rotation across the residual stream to reduce outliers in the Q/K/V/O and gate/up/down projections.
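
The norm-fusion half of R1 can be sketched in a few lines (illustrative numpy, not the AIMET code path): an RMSNorm with learned scale g followed by a linear layer W is equivalent to a unit-scale RMSNorm followed by a linear layer with diag(g) folded into W, which is what allows the rotation to be merged cleanly afterward.

```python
import numpy as np

def rms_norm(x, g, eps=1e-6):
    """RMSNorm with per-channel scale g."""
    return x / np.sqrt((x ** 2).mean(-1, keepdims=True) + eps) * g

rng = np.random.default_rng(1)
d = 16
x = rng.normal(size=(2, d))
g = rng.uniform(0.5, 2.0, size=d)    # RMSNorm scale weights
W = rng.normal(size=(d, d))          # downstream linear layer (y = h @ W)

y_ref = rms_norm(x, g) @ W

# Fuse the scale into the linear weight; the norm itself becomes scale-free.
W_fused = np.diag(g) @ W
y_fused = rms_norm(x, np.ones(d)) @ W_fused

assert np.allclose(y_ref, y_fused)
```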

Supported architectures

Model family                          PyTorch   ONNX
LlamaForCausalLM                      ✓         ✓
Qwen2ForCausalLM, Qwen3ForCausalLM    ✓         ✓
MistralForCausalLM                    ✓         ✓
Phi3ForCausalLM                       ✓         ✓
Qwen2.5-VL, Qwen3-VL                  ✓         ✓

Note

Support for additional model families is added continuously as new architectures are validated. See the release notes for the latest additions.

Workflow

Prerequisites

To use SpinQuant, you need:

  • A pre-trained model loaded from HuggingFace.

  • ONNX only: the model must be exported to ONNX; Step 2 handles this export.

Note

For a complete working example, see Examples/torch/quantize.py or Examples/onnx/quantize.py (run with --recipe pcq_spinquant).

Procedure

Step 1: Load model

Load the HuggingFace model and wrap it with ONNXExportableModuleWithCache to enable JIT tracing with a static graph.

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from GenAILab.qai_hub_lm.utils.model_utils import ONNXExportableModuleWithCache

SEQUENCE_LENGTH = 2048
CONTEXT_LENGTH = 4096

model_id = "meta-llama/Llama-3.2-1B-Instruct"
hf_model = AutoModelForCausalLM.from_pretrained(model_id, dtype=torch.float32)
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True, trust_remote_code=True)

# Wrap model to satisfy static graph constraints for JIT trace
traceable_model = ONNXExportableModuleWithCache(hf_model)

Step 2: Create QuantizationSimModel

Create a QuantizationSimModel with the desired quantization configuration. For ONNX, this step also exports the model to ONNX.

PyTorch:

from aimet_torch.common.defs import QuantScheme
from aimet_torch import QuantizationSimModel
from GenAILab.qai_hub_lm.models.generator import Generator

assembled_dummy_inputs = Generator.prepare_inputs(
    model=traceable_model,
    input_ids=torch.zeros((1, SEQUENCE_LENGTH), dtype=torch.int),
    attention_mask=torch.ones((1, SEQUENCE_LENGTH), dtype=torch.int),
    past_key_values=[],
    context_length=CONTEXT_LENGTH,
    sequence_length=SEQUENCE_LENGTH,
)

quantsim = QuantizationSimModel(
    model=traceable_model,
    quant_scheme=QuantScheme.post_training_tf,
    dummy_input=assembled_dummy_inputs,
    default_output_bw=16,
    default_param_bw=8,
    in_place=True,
    config_file="htp_v73",
)

generator = Generator(quantsim.model, tokenizer, SEQUENCE_LENGTH, CONTEXT_LENGTH)

ONNX:

import os
import tempfile
import onnx
from aimet_onnx.quantsim import QuantizationSimModel
from GenAILab.qai_hub_lm.models.base import LLM
from GenAILab.qai_hub_lm.utils.layer_cache import build_layer_cache_descriptors
from GenAILab.qai_hub_lm.models.generator import Generator
from GenAILab.qai_hub_lm.backends.onnx.torch_onnx_interface import TorchONNXInterface
from GenAILab.qai_hub_lm.backends.onnx.quantsim_utils import (
    _set_tensors_to_output_n_bit_symmmetric,
    _tie_quantizers_for_kv_cache,
    _set_lm_head_precision,
)
from GenAILab.qai_hub_lm.precision import WeightPrecision
from aimet_onnx.common.defs import int8

assembled_dummy_inputs = Generator.prepare_inputs(
    model=traceable_model,
    input_ids=torch.zeros((1, SEQUENCE_LENGTH), dtype=torch.int),
    attention_mask=torch.ones((1, SEQUENCE_LENGTH), dtype=torch.int),
    past_key_values=[],
    context_length=CONTEXT_LENGTH,
    sequence_length=SEQUENCE_LENGTH,
)

with tempfile.TemporaryDirectory() as tmpdir:
    torch.onnx.export(
        traceable_model,
        assembled_dummy_inputs,
        os.path.join(tmpdir, "model.onnx"),
        input_names=LLM.get_backbone_input_names(build_layer_cache_descriptors(hf_model.config)),
        output_names=LLM.get_backbone_output_names(build_layer_cache_descriptors(hf_model.config)),
        opset_version=17,
        dynamo=False,
    )
    onnx_model = onnx.load(os.path.join(tmpdir, "model.onnx"))

quantsim = QuantizationSimModel(
    model=onnx_model,
    quant_scheme="min_max",
    default_activation_bw=16,
    default_param_bw=8,
    config_file="htp_v73",
    providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
)
_set_tensors_to_output_n_bit_symmmetric(quantsim, kv_bits=8)
_set_lm_head_precision(quantsim, WeightPrecision(qtype=int8, granularity="PCQ"))
_tie_quantizers_for_kv_cache(quantsim)

quantsim_with_torch_interface = TorchONNXInterface(quantsim, hf_model.config)
generator = Generator(quantsim_with_torch_interface, tokenizer, SEQUENCE_LENGTH, CONTEXT_LENGTH)

Step 3: Apply SpinQuant

Apply SpinQuant to the model. This fuses RMSNorm scale weights into downstream linear layers and applies the R1 Hadamard rotation to all weight matrices in-place.

Important

apply_spinquant must be called before compute_encodings. The rotation modifies float weight initializers; compute_encodings must run afterward to calibrate quantizer scales on the rotated weights.
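
The motivation can be seen in a toy numpy sketch (illustrative only, not AIMET code): a symmetric per-tensor int8 quantizer must size its scale to the largest weight, so spreading an outlier row with a Hadamard rotation shrinks the scale and the resulting quantization error; encodings computed on the pre-rotation weights would therefore use the wrong scale.

```python
import numpy as np

def hadamard(n):
    """Orthonormal Hadamard matrix (Sylvester construction, n a power of 2)."""
    H = np.array([[1.0]])
    while H.shape[0] < n:
        H = np.block([[H, H], [H, -H]])
    return H / np.sqrt(n)

def fake_quant_int8(w):
    """Symmetric per-tensor int8 fake-quantization."""
    scale = np.abs(w).max() / 127.0
    return np.round(w / scale) * scale

rng = np.random.default_rng(2)
d = 64
W = rng.normal(size=(d, d))
W[7, :] *= 100.0                     # one outlier row inflates the per-tensor scale

W_rot = hadamard(d).T @ W            # rotation folded into the weight, as in R1

err_plain = np.abs(W - fake_quant_int8(W)).mean()
err_rot = np.abs(W_rot - fake_quant_int8(W_rot)).mean()
assert err_rot < err_plain           # rotated weights quantize more accurately
```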

PyTorch:

from aimet_torch.experimental.spinquant import apply_spinquant

# apply_spinquant modifies the model in-place. Must be called BEFORE compute_encodings.
apply_spinquant(model=quantsim.model)

ONNX:

from aimet_onnx.experimental.spinquant import apply_spinquant

# apply_spinquant modifies quantsim in-place. Must be called BEFORE compute_encodings.
apply_spinquant(quantsim)

Step 4: Compute activation encodings

Calibrate activation quantizers by running the model through a representative dataset.

PyTorch:

import itertools
from tqdm import tqdm
from GenAILab.bench.datasets import Wikitext

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
train_dataset = Wikitext.load_encoded_dataset(tokenizer, CONTEXT_LENGTH, "train")


def calibration_callback(model):
    for batch in tqdm(
        itertools.islice(train_dataset, 20), total=20, desc="Calibrating"
    ):
        generator(input_ids=batch["input_ids"].to(device=device))


quantsim.compute_encodings(calibration_callback)

ONNX:

from tqdm import tqdm
from GenAILab.bench.datasets import Wikitext
from GenAILab.bench.onnx.quant_recipes import _prefill_inputs

train_dataset = Wikitext.load_encoded_dataset(tokenizer, CONTEXT_LENGTH, "train")
calib_inputs = _prefill_inputs(quantsim, generator, train_dataset, num_batches=20)


def _forward(session, _):
    for batch in tqdm(calib_inputs, total=len(calib_inputs), desc="Calibrating"):
        session.run(None, batch)


quantsim.compute_encodings(_forward, tuple())

After completing these steps, export the quantized model:

  • PyTorch: quantsim.export(...)

  • ONNX: quantsim.export(...)

API

Top level APIs (PyTorch)

aimet_torch.experimental.spinquant.apply_spinquant(model)

Apply SpinQuant rotation transforms to a transformer-based language model.

SpinQuant applies orthogonal Hadamard rotations to model weights to reduce quantization error. This method modifies the model in-place by:

  1. Fusing RMS normalization layers into subsequent linear layers

  2. Applying R1 Hadamard rotations to embeddings, attention, and MLP layers

  3. Merging all transforms into the weight matrices

Supported architectures:
  • LLaMA

  • Qwen2, Qwen3

  • Phi3

  • Qwen2.5-VL (Vision-Language Model)

Parameters:

model (Module) – A HuggingFace transformer model (e.g., LlamaForCausalLM, Qwen2ForCausalLM). The model must have untied embed_tokens and lm_head weights.

Raises:

RuntimeError – If embed_tokens and lm_head weights are tied.

Example

>>> from transformers import AutoModelForCausalLM
>>> from aimet_torch.experimental.spinquant import apply_spinquant
>>> model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
>>> # Untie embedding and lm_head weights if they are tied
>>> old_weight = model.lm_head.weight
>>> model.lm_head.weight = torch.nn.Parameter(
...     old_weight.data.clone().detach().to(old_weight.device),
...     requires_grad=old_weight.requires_grad,
... )
>>> apply_spinquant(model)

Top level APIs (ONNX)

aimet_onnx.experimental.spinquant.apply_spinquant(backbone_sim, visual_sim=None, embedding=None)[source]

Apply SpinQuant rotation transforms to an ONNX transformer model.

SpinQuant applies orthogonal Hadamard rotations to model weights to reduce quantization error. This function modifies the QuantizationSimModel(s) in-place by:

  1. Fusing RMS normalization layers into subsequent linear layers

  2. Applying R1 Hadamard rotations to embeddings, attention, and MLP layers

  3. For VLMs: applying the R_L rotation to the PatchMerger linear_fc2 layer (always applied for VLMs)

Must be called BEFORE sim.compute_encodings(). The rotation modifies float weight initializers; compute_encodings must run afterward to calibrate quantizer scales on the rotated weights.

Supported architectures:
  • LLaMA, Qwen2, Qwen3, Phi3 (backbone only)

  • Qwen2.5-VL, Qwen3-VL (backbone + visual)

Parameters:
  • backbone_sim (QuantizationSimModel) – QuantizationSimModel wrapping backbone.onnx.

  • visual_sim (Optional[QuantizationSimModel]) – Optional QuantizationSimModel wrapping visual.onnx (VLM only).

  • embedding (Optional[Tensor]) – Optional torch.Tensor of shape [vocab, hidden] loaded from embedding.pth (VLM only). Rotated in-place with R_L.

Raises:

ValueError – If block detection or role classification fails, or if any expected weight is missing or has the wrong shape.

Return type:

None

Example (LLM):

sim = QuantizationSimModel(model)
apply_spinquant(sim)
sim.compute_encodings(calibration_data)

Example (VLM):

backbone_sim = QuantizationSimModel(backbone_model)
visual_sim = QuantizationSimModel(visual_model)
embedding = torch.load("embedding.pth")   # torch.Tensor [vocab, hidden]
apply_spinquant(backbone_sim, visual_sim=visual_sim, embedding=embedding)
torch.save(embedding, "embedding.pth")    # overwrite with rotated weights
backbone_sim.compute_encodings(backbone_calibration_data)
visual_sim.compute_encodings(visual_calibration_data)