aimet_onnx.experimental.spinquant

Top level APIs

aimet_onnx.experimental.spinquant.apply_spinquant(backbone_sim, visual_sim=None, embedding=None)[source]

Apply SpinQuant rotation transforms to an ONNX transformer model.

SpinQuant applies orthogonal Hadamard rotations to model weights to reduce quantization error. This function modifies the QuantizationSimModel(s) in-place by:

  1. Fusing RMS normalization layers into subsequent linear layers

  2. Applying R1 Hadamard rotations to embeddings, attention, and MLP layers

  3. For VLMs: additionally applying R_L rotations to the PatchMerger linear_fc2 layer (always applied when a visual model is provided)
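The norm-fusion step (1) works because RMSNorm's learned scale can be absorbed into the weight of the following linear layer, leaving a scale-free norm that an orthogonal rotation can commute through. A minimal NumPy sketch of that algebra (an illustration of the technique, not AIMET's implementation):

```python
import numpy as np

rng = np.random.default_rng(1)
hidden = 4
x = rng.normal(size=(3, hidden))
gamma = rng.normal(size=hidden)         # RMSNorm scale parameter
W = rng.normal(size=(hidden, hidden))   # following linear: y = rmsnorm(x) @ W

def rmsnorm(x, g):
    return x / np.sqrt((x ** 2).mean(-1, keepdims=True) + 1e-6) * g

# Fold gamma into the linear weight (scale each input channel of W).
# The remaining norm has unit scale, so it no longer blocks rotation.
W_fused = gamma[:, None] * W

y_ref = rmsnorm(x, gamma) @ W
y_fused = rmsnorm(x, np.ones(hidden)) @ W_fused
assert np.allclose(y_ref, y_fused)   # fusion is exact in float
```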

Must be called BEFORE sim.compute_encodings(). The rotation modifies float weight initializers; compute_encodings must run afterward to calibrate quantizer scales on the rotated weights.
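The rotations in step (2) are function-preserving in float but change the weight statistics the quantizers see, which is why calibration must follow. A small NumPy sketch of the core idea, assuming a single weight matrix with one outlier channel (hypothetical shapes, not AIMET's R1 placement):

```python
import numpy as np

def hadamard(n):
    # Sylvester construction; n must be a power of two
    H = np.array([[1.0]])
    while H.shape[0] < n:
        H = np.block([[H, H], [H, -H]])
    return H / np.sqrt(n)  # orthonormal: H @ H.T == I

rng = np.random.default_rng(0)
hidden = 8
W = rng.normal(size=(hidden, hidden))
W[:, 0] *= 20.0           # outlier channel that inflates the quant range

R = hadamard(hidden)
W_rot = R.T @ W           # rotate W; the inverse rotation R is applied upstream

# Exact in float: (x @ R) @ (R.T @ W) == x @ W, since R @ R.T == I
x = rng.normal(size=(2, hidden))
assert np.allclose((x @ R) @ W_rot, x @ W)

# The outlier's magnitude is spread across rows, shrinking the
# dynamic range the quantizer must cover
assert np.abs(W_rot).max() < np.abs(W).max()
```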

Supported architectures:
  • LLaMA, Qwen2, Qwen3, Phi3 (backbone only)

  • Qwen2.5-VL, Qwen3-VL (backbone + visual)

Parameters:
  • backbone_sim (QuantizationSimModel) – QuantizationSimModel wrapping backbone.onnx.

  • visual_sim (Optional[QuantizationSimModel]) – Optional QuantizationSimModel wrapping visual.onnx (VLM only).

  • embedding (Optional[Tensor]) – Optional torch.Tensor of shape [vocab, hidden] loaded from embedding.pth (VLM only). Rotated in-place with R_L.

Raises:

ValueError – If block detection or role classification fails, or if any expected weight is missing or has the wrong shape.

Return type:

None

Example (LLM):

from aimet_onnx.quantsim import QuantizationSimModel
from aimet_onnx.experimental.spinquant import apply_spinquant

sim = QuantizationSimModel(model)
apply_spinquant(sim)
sim.compute_encodings(calibration_data)

Example (VLM):

import torch
from aimet_onnx.quantsim import QuantizationSimModel
from aimet_onnx.experimental.spinquant import apply_spinquant

backbone_sim = QuantizationSimModel(backbone_model)
visual_sim = QuantizationSimModel(visual_model)
embedding = torch.load("embedding.pth")   # torch.Tensor [vocab, hidden]
apply_spinquant(backbone_sim, visual_sim=visual_sim, embedding=embedding)
torch.save(embedding, "embedding.pth")    # overwrite with rotated weights
backbone_sim.compute_encodings(backbone_calibration_data)
visual_sim.compute_encodings(visual_calibration_data)