Source code for aimet_torch.experimental.spinquant.spinquant_optimizer

# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# SPDX-License-Identifier: BSD-3-Clause
"""Optimizer for Spinquant"""

import torch
from transformers.models.llama.modeling_llama import LlamaForCausalLM
from transformers.models.qwen2.modeling_qwen2 import Qwen2ForCausalLM
from transformers.models.mistral.modeling_mistral import MistralForCausalLM

from aimet_common.utils import AimetLogger
from aimet_torch.experimental.spinquant.hadamard_utils import get_hadamard_matrix

_logger = AimetLogger.get_area_logger(AimetLogger.LogAreas.Quant)

RMSNORM_LINEAR_PAIRS = "RMSNORM_LINEAR_PAIRS"
R1_FUSION_PAIRS = "R1_FUSION_PAIRS"


def _default_r1_fusion_func(llm_model):
    """Default R1 fusion function"""
    r1_direction_pairs = []
    for layer in llm_model.model.layers:
        r1_direction_pairs.extend(
            [
                (layer.self_attn.q_proj, True),
                (layer.self_attn.k_proj, True),
                (layer.self_attn.v_proj, True),
                (layer.self_attn.o_proj, False),
                (layer.mlp.gate_proj, True),
                (layer.mlp.up_proj, True),
                (layer.mlp.down_proj, False),
            ]
        )
    r1_direction_pairs.extend(
        [(llm_model.model.embed_tokens, False), (llm_model.lm_head, True)]
    )
    return r1_direction_pairs


def _default_rmsnorm_linear_pairs_func(llm_model):
    """Default RMSNorm Linear pairs function"""
    rmsnorm_linear_pairs = []
    for layer in llm_model.model.layers:
        rmsnorm_linear_pairs.extend(
            [
                (
                    layer.input_layernorm,
                    [
                        layer.self_attn.q_proj,
                        layer.self_attn.k_proj,
                        layer.self_attn.v_proj,
                    ],
                )
            ]
        )
        rmsnorm_linear_pairs.extend(
            [
                (
                    layer.post_attention_layernorm,
                    [
                        layer.mlp.gate_proj,
                        layer.mlp.up_proj,
                    ],
                )
            ]
        )
    rmsnorm_linear_pairs.extend([(llm_model.model.norm, [llm_model.lm_head])])
    return rmsnorm_linear_pairs


# Dictionary mapping supported model types to the functions that produce their RMSNorm-linear fusion pairs and R1 fusion pairs.
SUPPORTED_MODULE_DICT = {
    LlamaForCausalLM: {
        RMSNORM_LINEAR_PAIRS: _default_rmsnorm_linear_pairs_func,
        R1_FUSION_PAIRS: _default_r1_fusion_func,
    },
    Qwen2ForCausalLM: {
        RMSNORM_LINEAR_PAIRS: _default_rmsnorm_linear_pairs_func,
        R1_FUSION_PAIRS: _default_r1_fusion_func,
    },
    MistralForCausalLM: {
        RMSNORM_LINEAR_PAIRS: _default_rmsnorm_linear_pairs_func,
        R1_FUSION_PAIRS: _default_r1_fusion_func,
    },
}
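
# Illustrative sketch (not part of the original module): additional model types could be
# registered in SUPPORTED_MODULE_DICT before calling apply_spinquant, provided they expose
# the same model.layers / self_attn / mlp attribute layout assumed by the default pair
# functions. GemmaForCausalLM below is a hypothetical example, not a tested configuration.
#
#     from transformers.models.gemma.modeling_gemma import GemmaForCausalLM
#
#     SUPPORTED_MODULE_DICT[GemmaForCausalLM] = {
#         RMSNORM_LINEAR_PAIRS: _default_rmsnorm_linear_pairs_func,
#         R1_FUSION_PAIRS: _default_r1_fusion_func,
#     }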


def apply_spinquant(model: torch.nn.Module):
    """
    Apply SpinQuant to the model, modifying weights in place. https://arxiv.org/pdf/2405.16406

    Currently only R1 rotations without optimization are supported. The model is updated in place.

    :param model: The model to apply SpinQuant to.
    """
    # Ensure that any user registered module types are fully registered
    for supported_module_type, module_info in SUPPORTED_MODULE_DICT.items():
        if module_info.get(RMSNORM_LINEAR_PAIRS) is None:
            raise RuntimeError(
                f"{RMSNORM_LINEAR_PAIRS} info missing for module type {supported_module_type.__name__}"
            )
        if module_info.get(R1_FUSION_PAIRS) is None:
            raise RuntimeError(
                f"{R1_FUSION_PAIRS} info missing for module type {supported_module_type.__name__}"
            )

    found_module = False
    for module in model.modules():
        if isinstance(module, tuple(SUPPORTED_MODULE_DICT.keys())):
            if module.model.embed_tokens.weight is module.lm_head.weight:
                raise RuntimeError(
                    "SpinQuant requires embed_tokens and lm_head weights to be untied. Ensure that "
                    "model.config.tie_word_embeddings or a similar relevant setting is set to False "
                    "for the model."
                )
            found_module = True
            _identify_and_fuse_rmsnorms_into_linears(module)
            _fuse_r1_rotations(module)

    if not found_module:
        _logger.warning(
            "SpinQuant optimizer did not find any modules to apply SpinQuant for in the model."
        )
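
# Usage sketch (illustrative, not part of the original module): apply_spinquant expects a
# supported causal LM whose embed_tokens and lm_head weights are untied. The checkpoint
# name and the tie_word_embeddings override below are assumptions for demonstration only.
#
#     from transformers import AutoModelForCausalLM
#
#     model = AutoModelForCausalLM.from_pretrained(
#         "meta-llama/Llama-2-7b-hf", tie_word_embeddings=False
#     )
#     apply_spinquant(model)  # fuses RMSNorm scales and R1 rotations into the weights in place
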
def _identify_and_fuse_rmsnorms_into_linears(llm_model: torch.nn.Module):
    for rmsnorm, linears in SUPPORTED_MODULE_DICT[type(llm_model)][
        RMSNORM_LINEAR_PAIRS
    ](llm_model):
        _fuse_rmsnorm_into_linear(rmsnorm, linears)


def _fuse_rmsnorm_into_linear(rmsnorm, linear_layers):
    # Fold the RMSNorm scale (and bias, if present) into each downstream linear layer
    for linear in linear_layers:
        linear_dtype = linear.weight.dtype
        W = linear.weight.data
        linear.weight.data = (W * rmsnorm.weight).to(linear_dtype)

        if hasattr(rmsnorm, "bias"):
            if linear.bias is not None:
                linear.bias.data = linear.bias.data.double() + torch.matmul(
                    W, rmsnorm.bias.double()
                )
                linear.bias.data = linear.bias.data.to(linear_dtype)

    # The norm now contributes only an identity scale
    rmsnorm.weight.data = torch.ones_like(rmsnorm.weight.data)


def _fuse_r1_rotations(llm_model: torch.nn.Module):
    modules_and_fuse_directions = SUPPORTED_MODULE_DICT[type(llm_model)][
        R1_FUSION_PAIRS
    ](llm_model)

    if modules_and_fuse_directions:
        # R1 is a Hadamard matrix sized to the model's hidden dimension, scaled to be orthogonal
        hidden_size = modules_and_fuse_directions[0][0].weight.shape[1]
        had_matrix = get_hadamard_matrix(hidden_size).to(
            modules_and_fuse_directions[0][0].weight.device
        ) / torch.sqrt(torch.tensor(hidden_size))
        for module, fuse_before in modules_and_fuse_directions:
            _fuse_r1_rotation(module, fuse_before, had_matrix)


def _fuse_r1_rotation(module, fuse_before, had_matrix):
    with torch.no_grad():
        if isinstance(module, torch.nn.Linear):
            if fuse_before:
                # Rotate the input dimension of the weight
                module.weight.copy_(module.weight @ had_matrix.T)
            else:
                # Rotate the output dimension of the weight and bias
                module.weight.copy_((module.weight.T @ had_matrix.T).T)
                if module.bias is not None:
                    module.bias.copy_((module.bias.T @ had_matrix.T).T)
        elif isinstance(module, torch.nn.Embedding):
            if not fuse_before:
                module.weight.copy_(module.weight @ had_matrix.T)
            else:
                raise RuntimeError("Embedding module is expected to fuse after only")
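
# Background note (illustrative, not part of the original module): R1 fusion relies on the
# scaled Hadamard matrix being orthogonal, so folding R1 into the output side of one layer
# and the input side of the next leaves the network function unchanged:
# (W @ R.T) @ (R @ x) == W @ x whenever R @ R.T == I. A minimal self-contained check of
# that identity, using an arbitrary orthogonal R and dimension 8:
#
#     import torch
#     n = 8
#     R = torch.linalg.qr(torch.randn(n, n)).Q  # any orthogonal matrix
#     W, x = torch.randn(4, n), torch.randn(n)
#     assert torch.allclose((W @ R.T) @ (R @ x), W @ x, atol=1e-5)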