aimet_torch.experimental.spinquant

Top-level APIs

aimet_torch.experimental.spinquant.apply_spinquant(model)

Apply SpinQuant rotation transforms to a transformer-based language model.

SpinQuant applies orthogonal Hadamard rotations to model weights to reduce quantization error. This method modifies the model in-place by:

  1. Fusing RMS normalization layers into subsequent linear layers

  2. Applying R1 Hadamard rotations to embeddings, attention, and MLP layers

  3. Merging all transforms into the weight matrices (see the sketch after this list)
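
Steps 2 and 3 rely on the fact that an orthogonal matrix inserted between two layers cancels out mathematically while reshaping the weight distributions into a form that quantizes with less error. The snippet below is a minimal, illustrative sketch of that identity using a random orthogonal matrix and two toy linear weights; it is not the library's internal implementation, and all names and sizes are made up.

>>> import torch
>>> d = 8
>>> R, _ = torch.linalg.qr(torch.randn(d, d))      # random orthogonal rotation (stand-in for a Hadamard matrix)
>>> W1, W2 = torch.randn(d, d), torch.randn(d, d)  # two consecutive linear-layer weights
>>> x = torch.randn(1, d)
>>> y_ref = (x @ W1.T) @ W2.T                      # original two-layer output
>>> W1_rot = R.T @ W1                              # merge R into the output side of layer 1
>>> W2_rot = W2 @ R                                # merge R^T into the input side of layer 2
>>> y_rot = (x @ W1_rot.T) @ W2_rot.T              # rotated weights compute the same function
>>> torch.allclose(y_ref, y_rot, atol=1e-5)
True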

Supported architectures:
  • LLaMA

  • Qwen2, Qwen3

  • Phi3

  • Qwen2.5-VL (Vision-Language Model)

Parameters:

model (Module) – A HuggingFace transformer model (e.g., LlamaForCausalLM, Qwen2ForCausalLM). The model must have untied embed_tokens and lm_head weights.

Raises:

RuntimeError – If embed_tokens and lm_head weights are tied.
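
Whether untying is needed can be checked up front. The following is a small sketch, assuming a `model` loaded as in the example below; it relies on the standard HuggingFace `get_input_embeddings`/`get_output_embeddings` accessors and the `tie_word_embeddings` config flag.

>>> tied = model.get_output_embeddings().weight is model.get_input_embeddings().weight
>>> # Most HuggingFace configs also expose a flag for this
>>> tied = tied or getattr(model.config, "tie_word_embeddings", False)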

Example

>>> import torch
>>> from transformers import AutoModelForCausalLM
>>> from aimet_torch.experimental.spinquant import apply_spinquant
>>> model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
>>> # Untie embedding and lm_head weights if they are tied
>>> old_weight = model.lm_head.weight
>>> model.lm_head.weight = torch.nn.Parameter(
...     old_weight.detach().clone(),
...     requires_grad=old_weight.requires_grad,
... )
>>> apply_spinquant(model)
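
Because the rotations are orthogonal and are fully merged into the weights, the transformed model should compute the same function as the original up to floating-point error. A rough way to sanity-check this is to capture reference logits just before the apply_spinquant(model) call above and compare them afterwards; the input shape and tolerance below are illustrative only.

>>> input_ids = torch.randint(0, model.config.vocab_size, (1, 8))
>>> with torch.no_grad():
...     ref_logits = model(input_ids).logits.clone()   # logits before applying SpinQuant
>>> apply_spinquant(model)
>>> with torch.no_grad():
...     rot_logits = model(input_ids).logits           # logits after the rotation merge
>>> max_err = (ref_logits - rot_logits).abs().max()    # expected to be small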