Source code for aimet_onnx.experimental.adascale.adascale_optimizer

# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# SPDX-License-Identifier: BSD-3-Clause

"""AdaScale implementation"""

import contextlib
import copy
import gc
import os
import tempfile
from dataclasses import dataclass
from pathlib import Path
from typing import Collection, Dict, List, Tuple

import numpy as np
import onnx
import torch
import tqdm

from aimet_onnx.common.utils import AimetLogger  # pylint: disable=import-error
from aimet_onnx.experimental.adascale.utils import (
    convert_to_torch,
    change_tensor_device_placement,
)
from aimet_onnx.utils import (
    get_torch_device,
)
from aimet_onnx.quantsim import QuantizationSimModel
from aimet_onnx.experimental.adascale.find_blocks import (
    get_decoder_blocks_end_points,
)

from aimet_onnx.experimental.adascale.quantizer import (
    add_qlinear_layers,
    get_adascale_trainable_params,
    replace_with_adascale_quantizers,
)

from aimet_onnx.experimental.adascale.activation_sampler import ActivationSampler
from aimet_onnx.experimental.adascale.model_converter import (
    get_pt_block,
    copy_pt_weights_to_onnx,
    copy_pt_encodings_to_sim,
)

_logger = AimetLogger.get_area_logger(AimetLogger.LogAreas.AdaScale)


# Probability of feeding quantized (vs. FP32) activations to the block during optimization
_QT_SAMPLING_PROB = 1.0
_LOSS_FN = torch.nn.MSELoss()
# Debug knobs to optimize only a subset of blocks; leave as None to optimize all blocks
_DEBUG_NUM_PARTIAL_ITERATIONS = None
_DEBUG_NUM_PARTIAL_ITERATIONS_START = None
_DEBUG_NUM_PARTIAL_ITERATIONS_END = None


@dataclass
class AdaScaleModelConfig:
    model_type: str
    beta_gamma_lr: float = 1e-3  # lr for beta and gamma
    scales_lr: float = 5e-4  # lr for s2, s3, [s4]
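
# For model types without a pre-defined entry in adascale_model_config_dict
# below, a config can be constructed directly. A minimal sketch (the
# "my_decoder" model type and learning rates are illustrative assumptions;
# the architecture must be one that get_decoder_blocks_end_points can slice
# into decoder blocks):
#
#     custom_config = AdaScaleModelConfig(
#         model_type="my_decoder", beta_gamma_lr=1e-3, scales_lr=5e-4
#     )
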
# mapping of model type and the corresponding adascale config
adascale_model_config_dict = {
    "llama": AdaScaleModelConfig(model_type="llama", beta_gamma_lr=1e-3, scales_lr=5e-4),
    "qwen2": AdaScaleModelConfig(model_type="qwen2", beta_gamma_lr=1e-3, scales_lr=5e-4),
    "mistral": AdaScaleModelConfig(model_type="mistral", beta_gamma_lr=1e-3, scales_lr=5e-4),
    "qwen3": AdaScaleModelConfig(model_type="qwen3", beta_gamma_lr=1e-3, scales_lr=5e-4),
    "phi3": AdaScaleModelConfig(model_type="phi3", beta_gamma_lr=1e-3, scales_lr=5e-4),
}


class AdaScale:
    """
    AdaScale is a PTQ technique which performs knowledge distillation on blocks of
    modules, using the FP32 output as the reference output. AdaScale is based on
    FlexRound (https://arxiv.org/abs/2306.00317) but integrates LWC from OmniQuant.

    The optimization is performed on a block-by-block basis by comparing the
    quantized output of the block with its FP32 equivalent and by training the
    parameters (gamma, beta, s2, s3) which are temporarily introduced into every
    supported module. A block is defined as a non-leaf module which takes in one
    activation input tensor and outputs one activation tensor.

    Currently only Linear layers are supported, and all the Linears in a block are
    optimized at the same time. While performing the optimization, the activation
    quantizers are disabled, the Linear modules' weight quantizers are changed to
    specialized QDQ quantizers (with learnable parameters introduced), and the rest
    of the params are left quantized with the default QuantizeDequantize.
    """

    ADASCALE_PARAM_BW = 4

    # TODO remove this temporary solution
    # pylint: disable=unused-argument, unused-variable
    @classmethod
    def apply_adascale(
        cls,
        sim: QuantizationSimModel,
        inputs: Collection[Dict[str, np.ndarray]],
        adascale_model_config: AdaScaleModelConfig,
        num_iterations: int = 1500,
    ):
        """
        :param sim: Quantization Sim model
        :param inputs: The set of input samples to use during optimization.
        :param adascale_model_config: AdaScale model config. There are pre-defined
            configs for Llama, Qwen2, Mistral, Qwen3 and Phi3. For other models,
            construct an AdaScaleModelConfig directly.
        :param num_iterations: Number of iterations to optimize for during AdaScale

        Example usage:

            >>> model = DummyModel()
            >>> inputs = ...
            >>> adascale_model_config = adascale_model_config_dict["llama"]
            >>> sim = QuantizationSimModel(model)
            >>> apply_adascale(sim, inputs, adascale_model_config, num_iterations=num_iterations)
            >>> sim.compute_encodings(...)
            >>> sim.export(...)

        .. note::
            1. apply_adascale modifies the weights in-place in the model.
            2. compute_encodings should not be called before the apply_adascale call.
            3. Activation quantizers remain uninitialized throughout the feature, so
               compute_encodings needs to be called by the user afterwards. This way
               activation encodings are computed with the updated weights taken into
               account.

        Warning: This feature is currently considered experimental pending API changes
        """
        # pylint: disable=protected-access
        with cls._disable_activation_quantizers(sim):
            # Compute param encodings
            sim._compute_param_encodings(overwrite=False)

            blocks_end_points = get_decoder_blocks_end_points(
                sim, adascale_model_config.model_type
            )
            device = get_torch_device(sim.session)

            graph_input_names = [inp.name for inp in sim.session.get_inputs()]
            if graph_input_names != list(inputs[0].keys()):
                raise ValueError(
                    "Graph input names do not match the keys in the provided inputs."
                )

            # create a list of common input names to be used for graph slicing
            # and populating the input list
            common_input_names = []
            for name in graph_input_names:
                if "attention" in name:
                    common_input_names.append(name)
                if "position" in name:
                    common_input_names.append(name)

            del sim.session
            gc.collect()
            torch.cuda.empty_cache()

            with tempfile.TemporaryDirectory() as tempdir:
                fp32_model = copy.deepcopy(sim.model.model)
                fp32_model = QuantizationSimModel.remove_quantizers(fp32_model)
                model_path = os.path.join(tempdir, "model.onnx")
                onnx.save_model(
                    fp32_model,
                    model_path,
                    save_as_external_data=True,
                    all_tensors_to_one_file=True,
                    location=Path(model_path).name + ".data",
                )

                for idx in range(len(blocks_end_points)):
                    if (
                        _DEBUG_NUM_PARTIAL_ITERATIONS is not None
                        and idx >= _DEBUG_NUM_PARTIAL_ITERATIONS
                    ):
                        break
                    if (
                        _DEBUG_NUM_PARTIAL_ITERATIONS_START is not None
                        and _DEBUG_NUM_PARTIAL_ITERATIONS_END is not None
                        and (
                            idx < _DEBUG_NUM_PARTIAL_ITERATIONS_START
                            or idx >= _DEBUG_NUM_PARTIAL_ITERATIONS_END
                        )
                    ):
                        continue

                    _logger.info("Optimizing block: %d", idx)

                    # Query only the past_key/past_value tensors for a given block
                    block_kv_tensor_names = []
                    for name in graph_input_names:
                        if (
                            f"past_key_{idx}_in" in name
                            or f"past_value_{idx}_in" in name
                        ):
                            block_kv_tensor_names.append(name)

                    # Copy so per-block KV tensor names don't accumulate across blocks
                    block_input_names = list(common_input_names)
                    if len(block_kv_tensor_names) > 0:
                        if len(block_kv_tensor_names) != 2:
                            raise RuntimeError(
                                f"Unable to find both past_key and past_value for block {idx}."
                            )
                        block_input_names.extend(block_kv_tensor_names)

                    qsim_sess = ActivationSampler(
                        blocks_end_points[idx][0].inputs[0].name,
                        sim.model.model,
                        sim.providers,
                    )
                    fp_inputs, qsim_inputs = [], []
                    for input in inputs:  # pylint: disable=redefined-builtin
                        qsim_inputs.append(qsim_sess.sample_acts(input))
                    qsim_sess.restore_graph()
                    del qsim_sess

                    fp32_sampler = ActivationSampler(
                        blocks_end_points[idx][0].inputs[0].name,
                        model_path,
                        sim.providers,
                        tempdir,
                    )
                    for input in inputs:
                        fp_inputs.append(fp32_sampler.sample_acts(input))
                    fp32_sampler.restore_graph()
                    del fp32_sampler

                    fp_input_list = []
                    qsim_input_list = []
                    for i in range(len(fp_inputs)):
                        fp_list, qsim_list = [], []
                        fp_list.append(fp_inputs[i])
                        qsim_list.append(qsim_inputs[i])
                        for name in block_input_names:
                            fp_list.append(inputs[i][name])
                            qsim_list.append(inputs[i][name])
                        fp_input_list.append(fp_list)
                        qsim_input_list.append(qsim_list)

                    block_input_output_names = AdaScale.get_block_start_end_name(
                        blocks_end_points, idx, block_input_names
                    )
                    AdaScale.optimize_adascale_block(
                        sim,
                        fp_input_list,
                        qsim_input_list,
                        block_input_output_names,
                        adascale_model_config.beta_gamma_lr,
                        adascale_model_config.scales_lr,
                        num_iterations,
                        device,
                    )
                    del fp_input_list, qsim_input_list, fp_inputs, qsim_inputs
                    gc.collect()
                    torch.cuda.empty_cache()

        sim._rebuild_session()  # pylint: disable=protected-access

    @staticmethod
    def get_block_start_end_name(
        blocks_end_points: List[Tuple], block_idx: int, input_list_names: List[str]
    ) -> Tuple[List[str], List[str]]:
        block_inputs = [blocks_end_points[block_idx][0].inputs[0].name]
        block_input_names = block_inputs + input_list_names
        block_output_names = [blocks_end_points[block_idx][1].inputs[0].name]
        return block_input_names, block_output_names

    @staticmethod
    def optimize_adascale_block(
        sim: QuantizationSimModel,
        fp_inputs: List[np.ndarray],
        quantized_inputs: List[np.ndarray],
        block_input_output_names: Tuple[List[str], List[str]],
        beta_gamma_lr: float = 1e-3,
        scales_lr: float = 5e-4,
        num_iterations: int = 1500,
        device: torch.device = torch.device("cpu"),
    ):
        """
        :param sim: QuantizationSimModel object created using the fp32 model
        :param fp_inputs: List of input tensors to the block
        :param quantized_inputs: List of quantized input tensors to the block
        :param block_input_output_names: Tuple of the lists of input and output
            tensor names of the block
        :param beta_gamma_lr: learning rate to use for the beta/gamma params
        :param scales_lr: learning rate to use for the scales params
        :param num_iterations: Number of iterations to optimize for during AdaScale
        :param device: torch device to use for optimization

        This API performs AdaScale on the block through the following steps:

        - Using the block input and output tensor names, get the onnx block
        - Convert the above onnx block to a pytorch module
        - Apply AdaScale optimization on the above block using the hyperparameters,
          fp inputs and quantized inputs passed to the method
        - Copy back the weights and encodings to the original sim object passed to
          the method

        Important points to note:

        - fp32 model weights should be the original model weights
        - sim is updated in place with the adascaled weights
        """
        pytorch_block, pt_weights_to_onnx_initializers = get_pt_block(
            copy.deepcopy(sim.model.model), block_input_output_names
        )
        pytorch_block.requires_grad_(False)
        torch_fp_input = convert_to_torch(fp_inputs)
        torch_quant_input = convert_to_torch(quantized_inputs)

        # Record the FP32 block outputs to serve as the distillation targets
        pytorch_block.to(device)
        fp_out = []
        with torch.no_grad():
            for input_tensor in torch_fp_input:
                if isinstance(input_tensor, torch.Tensor):
                    input_tensor = [input_tensor]
                input_tensor = [
                    inp_t.to(device=device) for inp_t in input_tensor
                ]  # Create a new tensor
                out = pytorch_block(*input_tensor).detach()
                out.requires_grad_(False)
                fp_out.append(change_tensor_device_placement(out, torch.device("cpu")))

        pytorch_block = add_qlinear_layers(
            pytorch_block, bitwidth=AdaScale.ADASCALE_PARAM_BW
        )
        replace_with_adascale_quantizers(pytorch_block)

        # only set adascale params to train mode
        all_beta_gamma_parameters, all_scale_parameters = (
            get_adascale_trainable_params(pytorch_block)
        )
        adascale_params = all_beta_gamma_parameters + all_scale_parameters
        for p in adascale_params:
            p.requires_grad = True

        trainable_params = [
            {
                "params": all_beta_gamma_parameters,
                "lr": beta_gamma_lr,
            },
            {
                "params": all_scale_parameters,
                "lr": scales_lr,
            },
        ]
        optimizer = torch.optim.Adam(trainable_params)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=num_iterations, eta_min=0.0
        )

        gc.collect()
        torch.cuda.empty_cache()
        pytorch_block.to(device)

        with torch.set_grad_enabled(True):
            for iteration in tqdm.tqdm(range(num_iterations)):
                fp_input = torch_fp_input[iteration % len(torch_fp_input)]
                quant_input = torch_quant_input[iteration % len(torch_quant_input)]
                if _QT_SAMPLING_PROB == 1.0:
                    input_tensor = quant_input
                elif _QT_SAMPLING_PROB == 0.0:
                    input_tensor = fp_input
                else:
                    # Element-wise mix of quantized and FP32 activations, keeping the
                    # quantized value with probability _QT_SAMPLING_PROB
                    input_tensor = quant_input
                    input_tensor[0] = torch.where(
                        torch.rand_like(quant_input[0], dtype=quant_input[0].dtype).to(
                            device=device
                        )
                        < _QT_SAMPLING_PROB,
                        quant_input[0].to(device=device),
                        fp_input[0].to(device=device),
                    )
                if isinstance(input_tensor, torch.Tensor):
                    input_tensor = [input_tensor]
                input_tensor = [
                    inp_t.to(device=device) for inp_t in input_tensor
                ]  # Create a new tensor

                quant_out = pytorch_block(*input_tensor)
                batch_fp_out = fp_out[iteration % len(torch_fp_input)].to(device)
                loss = _LOSS_FN(quant_out, batch_fp_out)
                loss.backward()
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
                del quant_out, batch_fp_out, loss, input_tensor, fp_input, quant_input

        copy_pt_weights_to_onnx(
            pytorch_block, sim.model.model, pt_weights_to_onnx_initializers
        )
        copy_pt_encodings_to_sim(pytorch_block, sim, pt_weights_to_onnx_initializers)

        del (
            pytorch_block,
            torch_quant_input,
            torch_fp_input,
            optimizer,
            pt_weights_to_onnx_initializers,
            fp_out,
            fp_inputs,
            quantized_inputs,
        )

    @staticmethod
    @contextlib.contextmanager
    def _disable_activation_quantizers(qsim):
        """
        Disable activation quantizers

        :param qsim: Quantization simulator
        """
        enabled_activation_quantizers = [
            name
            for name in qsim.activation_names
            if qsim.qc_quantize_op_dict[name].enabled
        ]
        try:
            for name in enabled_activation_quantizers:
                qsim.qc_quantize_op_dict[name].enabled = False
            yield qsim
        finally:
            for name in enabled_activation_quantizers:
                qsim.qc_quantize_op_dict[name].enabled = True


apply_adascale = AdaScale.apply_adascale
apply_blocklevel_optimization = AdaScale.optimize_adascale_block
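
# A minimal end-to-end sketch of using the module-level aliases above, following
# the docstring example. The model, calibration inputs, and the
# compute_encodings/export arguments are placeholders, not part of this module:
#
#     sim = QuantizationSimModel(model)  # sim built from the FP32 ONNX model
#     inputs = [...]  # Collection[Dict[str, np.ndarray]] keyed by graph input names
#     apply_adascale(sim, inputs, adascale_model_config_dict["llama"], num_iterations=1500)
#     sim.compute_encodings(...)  # must run after apply_adascale, not before
#     sim.export(...)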