aimet_onnx.experimental.spinquant¶
Top level APIs
- aimet_onnx.experimental.spinquant.apply_spinquant(backbone_sim, visual_sim=None, embedding=None)¶
Apply SpinQuant rotation transforms to an ONNX transformer model.
SpinQuant applies orthogonal Hadamard rotations to model weights to reduce quantization error. This function modifies the QuantizationSimModel(s) in-place by:
Fusing RMS normalization layers into the subsequent linear layers
Applying R1 Hadamard rotations to embeddings, attention, and MLP layers (both steps are illustrated in the sketch below)
For VLMs: applying R_L to the PatchMerger linear_fc2 (always applied when a visual model is given)
Must be called BEFORE sim.compute_encodings(). The rotation modifies the float weight initializers; compute_encodings must run afterward so that quantizer scales are calibrated on the rotated weights.
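The fusion and R1 rotation steps can be illustrated numerically. The following is a minimal sketch, not AIMET's implementation: it assumes a power-of-two hidden size (so a Walsh-Hadamard matrix exists), and rms_norm is a hypothetical helper defined inline for illustration.

import numpy as np
from scipy.linalg import hadamard

hidden = 64
rng = np.random.default_rng(0)
x = rng.normal(size=(1, hidden))         # example activation
gamma = rng.normal(size=hidden)          # RMSNorm scale
W = rng.normal(size=(hidden, hidden))    # subsequent linear weight (in x out)

def rms_norm(v, scale):                  # illustration only
    return v / np.sqrt(np.mean(v**2, axis=-1, keepdims=True) + 1e-6) * scale

# Fusion: fold the RMSNorm scale into the subsequent linear weight,
# leaving a scale-free normalization behind.
W_fused = gamma[:, None] * W
assert np.allclose(rms_norm(x, gamma) @ W, rms_norm(x, np.ones(hidden)) @ W_fused)

# R1 rotation: rotate the producer's output with an orthogonal Hadamard
# matrix and fold its transpose into the consumer. The network function
# is unchanged; only the weight/activation distributions (and hence the
# quantization error) change.
R1 = hadamard(hidden) / np.sqrt(hidden)  # orthogonal: R1 @ R1.T == I
assert np.allclose((x @ R1) @ (R1.T @ W), x @ W)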
- Supported architectures:
LLaMA, Qwen2, Qwen3, Phi3 (backbone only)
Qwen2.5-VL, Qwen3-VL (backbone + visual)
- Parameters:
backbone_sim (QuantizationSimModel) – QuantizationSimModel wrapping backbone.onnx.
visual_sim (Optional[QuantizationSimModel]) – Optional QuantizationSimModel wrapping visual.onnx (VLM only).
embedding (Optional[Tensor]) – Optional torch.Tensor of shape [vocab, hidden], loaded from embedding.pth (VLM only). Rotated in-place with R_L.
- Raises:
ValueError – If block detection or role classification fails, or if any expected weight is missing or has the wrong shape.
- Return type:
None
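Because the transform raises ValueError for unsupported graphs, callers may want to guard the call. A minimal sketch, assuming sim and calibration_data as in the examples below; the fallback shown (calibrating without rotations) is an application choice, not part of this API:

try:
    apply_spinquant(sim)
except ValueError as err:
    # Unsupported block structure, failed role classification, or an
    # unexpected weight shape: proceed without rotations.
    print(f"Skipping SpinQuant: {err}")
sim.compute_encodings(calibration_data)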
Example (LLM):
from aimet_onnx.quantsim import QuantizationSimModel
from aimet_onnx.experimental.spinquant import apply_spinquant

sim = QuantizationSimModel(model)
apply_spinquant(sim)                     # rotate float weights in-place
sim.compute_encodings(calibration_data)  # calibrate on rotated weights
Example (VLM):
import torch

backbone_sim = QuantizationSimModel(backbone_model)
visual_sim = QuantizationSimModel(visual_model)
embedding = torch.load("embedding.pth")  # torch.Tensor [vocab, hidden]
apply_spinquant(backbone_sim, visual_sim=visual_sim, embedding=embedding)
torch.save(embedding, "embedding.pth")   # overwrite with rotated weights
backbone_sim.compute_encodings(backbone_calibration_data)
visual_sim.compute_encodings(visual_calibration_data)