# Source code for aimet_onnx.experimental.spinquant.spinquant

# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# SPDX-License-Identifier: BSD-3-Clause

"""Top-level SpinQuant API for ONNX QuantizationSimModel."""

from typing import Optional, List
import numpy as np
import torch

from aimet_onnx.common.hadamard import get_hadamard_matrix
from aimet_onnx.common.utils import AimetLogger
from aimet_onnx.meta.operations import Op
from aimet_onnx.quantsim import QuantizationSimModel
from aimet_onnx.experimental.spinquant.apply_rotation import (
    apply_r1_rotation,
    apply_r1_rotation_merger,
    _infer_hidden_size,
    _validate_backbone_weights,
    _validate_merger_linear2,
)
from aimet_onnx.experimental.spinquant.block_identifier import (
    DecoderModelRoleMap,
    get_decoder_block_boundaries,
    get_decoder_role_map,
    find_merger_linear2,
)
from aimet_onnx.experimental.spinquant.fuse_norm import (
    ActiveNorm,
    fuse_norm_layers_into_linears,
)

# Module-level logger bound to the SpinQuant log area; used by the rotation
# helpers below to report progress (e.g. embedding rotation).
_logger = AimetLogger.get_area_logger(AimetLogger.LogAreas.SpinQuant)


def apply_spinquant(
    backbone_sim: QuantizationSimModel,
    visual_sim: Optional[QuantizationSimModel] = None,
    embedding: Optional[torch.Tensor] = None,
) -> None:
    """Apply SpinQuant rotation transforms to an ONNX transformer model.

    SpinQuant applies orthogonal Hadamard rotations to model weights to reduce
    quantization error. This function modifies the QuantizationSimModel(s)
    in-place by:

    1. Fusing RMS normalization layers into subsequent linear layers
    2. Applying R1 Hadamard rotations to embeddings, attention, and MLP layers
    3. For VLMs: applying R_L to PatchMerger linear_fc2 (non-negotiable)

    Must be called BEFORE sim.compute_encodings(). The rotation modifies float
    weight initializers; compute_encodings must run afterward to calibrate
    quantizer scales on the rotated weights.

    Supported architectures:

    - LLaMA, Qwen2, Qwen3, Phi3 (backbone only)
    - Qwen2.5-VL, Qwen3-VL (backbone + visual)

    :param backbone_sim: QuantizationSimModel wrapping backbone.onnx.
    :param visual_sim: Optional QuantizationSimModel wrapping visual.onnx (VLM only).
    :param embedding: Optional ``torch.Tensor`` of shape ``[vocab, hidden]`` loaded
        from ``embedding.pth`` (VLM only). Rotated in-place with R_L.
    :raises ValueError: If block detection or role classification fails, if any
        expected weight is missing or has the wrong shape.

    Example (LLM)::

        sim = QuantizationSimModel(model)
        apply_spinquant(sim)
        sim.compute_encodings(calibration_data)

    Example (VLM)::

        backbone_sim = QuantizationSimModel(backbone_model)
        visual_sim = QuantizationSimModel(visual_model)
        embedding = torch.load("embedding.pth")  # torch.Tensor [vocab, hidden]
        apply_spinquant(backbone_sim, visual_sim=visual_sim, embedding=embedding)
        torch.save(embedding, "embedding.pth")  # overwrite with rotated weights
        backbone_sim.compute_encodings(backbone_calibration_data)
        visual_sim.compute_encodings(visual_calibration_data)
    """
    # Analyze the backbone graph: locate decoder blocks, classify op roles,
    # and infer the residual-stream hidden size from the weights.
    bb_model = backbone_sim.model.model
    bb_cg = backbone_sim.connected_graph
    bb_boundaries, bb_active_norms = get_decoder_block_boundaries(bb_model, bb_cg)
    bb_role_map = get_decoder_role_map(bb_cg, bb_boundaries, bb_active_norms)
    bb_hidden_size = _infer_hidden_size(bb_model, bb_role_map)

    # Visual-encoder analysis only applies to VLM exports.
    if visual_sim is not None:
        v_model = visual_sim.model.model
        v_cg = visual_sim.connected_graph
        v_merger_linear2 = find_merger_linear2(v_cg)

    # The external embedding and an in-graph embed_tokens op are mutually
    # exclusive: exactly one source of the token embedding must exist so it is
    # rotated exactly once.
    if embedding is not None and bb_role_map.embed_tokens:
        raise ValueError(
            "embedding was provided but backbone contains embed_tokens op(s). "
            "Pass embedding only for VLM backbones exported with use_inputs_embeds=True "
            "(i.e. backbone has no Gather op for token embeddings)."
        )
    if embedding is None and not bb_role_map.embed_tokens:
        raise ValueError(
            "Backbone has no embed_tokens op but no external embedding was provided. "
            "Pass embedding=torch.load('embedding.pth') for VLM backbones exported with "
            "use_inputs_embeds=True so the embedding weight is rotated alongside the backbone."
        )

    # Validate upfront so no model is mutated if any expected weight is
    # missing or mis-shaped.
    _validate_backbone_weights(bb_model, bb_role_map, bb_hidden_size)
    if visual_sim is not None:
        _validate_merger_linear2(v_model, v_merger_linear2, bb_hidden_size)

    # Apply rotation
    _apply_spinquant_backbone(
        backbone_sim, bb_role_map, bb_active_norms, bb_hidden_size, embedding
    )
    if visual_sim is not None:
        _apply_spinquant_visual(visual_sim, v_merger_linear2, bb_hidden_size)
def _apply_spinquant_backbone(
    sim: QuantizationSimModel,
    role_map: DecoderModelRoleMap,
    active_norms: List[ActiveNorm],
    backbone_hidden_size: int,
    embedding: Optional[torch.Tensor] = None,
) -> None:
    """Fuse norms, apply R1 rotation, and optionally rotate external embedding in-place.

    :param sim: QuantizationSimModel wrapping backbone.onnx.
    :param role_map: Decoder role map.
    :param active_norms: Norm initializer names to fuse.
    :param backbone_hidden_size: Hidden dimension of the language model residual stream.
    :param embedding: Optional raw embedding weight tensor.
    """
    onnx_model = sim.model.model

    # Fuse norm scales into the downstream linears before rotating, then
    # apply R1 to the backbone weights.
    fuse_norm_layers_into_linears(onnx_model, active_norms)
    apply_r1_rotation(onnx_model, role_map, backbone_hidden_size)

    if embedding is not None:
        # H / sqrt(d) normalizes the Hadamard matrix to an orthogonal rotation;
        # computed in float64 to limit accumulation error.
        rotation = get_hadamard_matrix(backbone_hidden_size)
        rotation = (rotation / np.sqrt(backbone_hidden_size)).astype(np.float64)

        device = embedding.device
        weight = embedding.detach().cpu().numpy()
        rotated = (weight @ rotation).astype(weight.dtype)
        # Write the rotated values back into the caller's tensor in-place.
        embedding.data.copy_(torch.from_numpy(rotated).to(device))
        _logger.info(
            "Backbone: Rotated external embedding weight with R_L (shape %s).",
            weight.shape,
        )

    # Initializers changed; the ORT session must be rebuilt.
    sim._rebuild_session()  # pylint: disable=protected-access


def _apply_spinquant_visual(
    sim: QuantizationSimModel,
    merger_linear2: List[Op],
    backbone_hidden_size: int,
) -> None:
    """Apply R_L rotation to PatchMerger linear_fc2 in the ViT visual encoder.

    :param sim: QuantizationSimModel wrapping visual.onnx.
    :param merger_linear2: List of merger_linear2 ops from find_merger_linear2.
    :param backbone_hidden_size: Hidden dimension of the language model residual stream.
    """
    onnx_model = sim.model.model
    apply_r1_rotation_merger(onnx_model, merger_linear2, backbone_hidden_size)
    # Initializers changed; the ORT session must be rebuilt.
    sim._rebuild_session()  # pylint: disable=protected-access