# -----------------------------------------------------------------------------
#
# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# ----------------------------------------------------------------------------
import hashlib
import logging
import warnings
from typing import List, Optional, Union
import numpy as np
import torch
from peft import AutoPeftModelForCausalLM, PeftModelForCausalLM, load_peft_weights
from torch import nn
from transformers import GenerationConfig, StoppingCriteria, StoppingCriteriaList
from transformers.generation.streamers import BaseStreamer
from QEfficient.base.modeling_qeff import QEFFBaseModel
from QEfficient.base.onnx_transforms import FP16ClipTransform, OnnxTransform, SplitTensorsTransform
from QEfficient.base.pytorch_transforms import PytorchTransform
from QEfficient.generation.cloud_infer import QAICInferenceSession
from QEfficient.peft.onnx_transforms import AdapterWeightsToInputsTransform
from QEfficient.peft.pytorch_transforms import PeftModelInputsTransform
from QEfficient.transformers.pytorch_transforms import CustomOpsTransform, KVCacheTransform
from QEfficient.utils import constants
from QEfficient.utils._utils import get_padding_shape_from_config
from QEfficient.utils.cache import to_hashable
logger = logging.getLogger(__name__)
class QEffAutoPeftModelForCausalLM(QEFFBaseModel):
"""
QEff class for loading models with PEFT adapters (only LoRA is supported currently).
Once exported and compiled for one adapter, the same artifacts can be reused with any other adapter that has the same base model and adapter config.
Args:
:model (nn.Module): PyTorch model
.. code-block:: python
from QEfficient import QEffAutoPeftModelForCausalLM
m = QEffAutoPeftModelForCausalLM.from_pretrained("predibase/magicoder", "magicoder")
m.export()
m.compile(prefill_seq_len=32, ctx_len=1024)
inputs = ... # A coding prompt
outputs = m.generate(**inputs)
inputs = ... # A math prompt
m.load_adapter("predibase/gsm8k", "gsm8k")
m.set_adapter("gsm8k")
outputs = m.generate(**inputs)
"""
_pytorch_transforms: List[PytorchTransform] = [CustomOpsTransform, KVCacheTransform, PeftModelInputsTransform]
_onnx_transforms: List[OnnxTransform] = [FP16ClipTransform, AdapterWeightsToInputsTransform, SplitTensorsTransform]
_hf_auto_class = AutoPeftModelForCausalLM
def __init__(self, model: nn.Module):
if not isinstance(model, PeftModelForCausalLM):
raise TypeError(f"Required pytorch module of type PeftModel, got {type(model)}")
if model.active_peft_config.peft_type != "LORA":
raise NotImplementedError("Only LoRA models are supported")
super().__init__(model)
self.num_layers = model.config.num_hidden_layers
self.exported_peft_config = None
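# Collect each adapter's LoRA weights as fp16 numpy arrays, keyed by the
# parameter name with the adapter suffix stripped, so they can later be set
# as buffers on the inference session.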
self.adapter_weights = {
adapter_name: {
name.replace(f".{adapter_name}.weight", ".weight"): param.detach().numpy().astype("float16")
for name, param in model.named_parameters()
if name.endswith(f".{adapter_name}.weight")
}
for adapter_name in model.peft_config
}
@property
def model_name(self) -> str:
mname = self.model.get_base_model().__class__.__name__ + "-lora"
if mname.startswith("QEff"):
mname = mname[4:]
return mname
@property
def model_hash(self) -> str:
# NOTE: model_config.to_diff_dict() has "_name_or_path" attribute which is the model card name or path.
# Using same card name will result in same hash. But, using a relative path for one run and
# absolute path for another run will result in different hash.
# The added complexity to resolve different paths to same location is not worth pursuing.
# Instead, advise the user to always provide same relative paths or absolute paths for local models.
# Compute the hash with: model_config, peft_config, transforms
mhash = hashlib.sha256()
mhash.update(to_hashable(self.model.get_base_model().config.to_diff_dict()))
mhash.update(to_hashable(self.model.active_peft_config.to_dict()))
mhash.update(to_hashable(self._transform_names()))
mhash = mhash.hexdigest()[:16]
return mhash
def load_adapter(self, model_id: str, adapter_name: str):
"""Loads a new adapter from huggingface hub or local path
Args:
:model_id (str): Adapter model ID from huggingface hub or local path
:adapter_name (str): Adapter name to be used to set this adapter as current
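A minimal sketch, assuming ``m`` was created and compiled as in the class-level example:
.. code-block:: python
m.load_adapter("predibase/gsm8k", "gsm8k")
m.set_adapter("gsm8k")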
"""
self.model.load_adapter(model_id, adapter_name)
self.adapter_weights[adapter_name] = {
k: v.numpy().astype("float16") for k, v in load_peft_weights(model_id).items()
}
@property
def active_adapter(self) -> str:
"Currently active adapter to be used for inference"
return self.model.active_adapter
def set_adapter(self, adapter_name: str):
"Sets active adapter from one of the loaded adapters"
if self.exported_peft_config is not None and self.exported_peft_config != self.model.peft_config[adapter_name]:
raise ValueError(
"Unable to activate incompatible adapter. "
"Use an adapter compatible with export-time adapter "
"or re-export with this adapter"
)
self.model.set_adapter(adapter_name)
def disable_adapter(self):
# TODO: Set zero tensors as adapter weights
raise NotImplementedError("Disabling adapters not supported currently")
@classmethod
def _from_pretrained(cls, pretrained_name_or_path: str, *args, **kwargs):
# Load the underlying PEFT model via the HF auto class and wrap it
model = cls._hf_auto_class.from_pretrained(pretrained_name_or_path, *args, **kwargs)
return cls(model)
@classmethod
def from_pretrained(cls, pretrained_name_or_path: str, *args, **kwargs):
"""
Args:
:pretrained_name_or_path (str): Model card name from huggingface or local path to model directory.
:args, kwargs: Additional arguments to pass to peft.AutoPeftModelForCausalLM.
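A minimal sketch, mirroring the class-level example (the model card and adapter name are illustrative):
.. code-block:: python
from QEfficient import QEffAutoPeftModelForCausalLM
m = QEffAutoPeftModelForCausalLM.from_pretrained("predibase/magicoder", "magicoder")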
"""
if kwargs.get("full_batch_size"):
raise NotImplementedError("Continuous batching currently not supported for PEFT models")
if kwargs.get("use_cache") is False:
warnings.warn("Overriding to use_cache=True")
kwargs["use_cache"] = True
obj = cls._from_pretrained(pretrained_name_or_path, *args, **kwargs)
return obj
def export(self, export_dir: Optional[str] = None) -> str:
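"""
Export the model with the currently active adapter to ONNX.
Constant folding is disabled and the adapter weights are exposed as graph inputs, so the exported graph can be reused with other compatible adapters.
Args:
:export_dir (str): Directory to store the exported ONNX artifacts (optional).
"""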
self.exported_peft_config = self.model.active_peft_config
example_shape = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN)
kv_cache_shape = get_padding_shape_from_config(self.model.config, *example_shape)
example_inputs = {
"input_ids": torch.zeros(example_shape, dtype=torch.int64),
"position_ids": torch.arange(example_shape[1], dtype=torch.int64).view(example_shape),
"past_key_values": [[] for _ in range(self.num_layers)],
}
dynamic_axes = {
"input_ids": {0: "batch_size", 1: "seq_len"},
"position_ids": {0: "batch_size", 1: "seq_len"},
}
output_names = ["logits"]
for i in range(self.num_layers):
for kv in ["key", "value"]:
example_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=torch.float32))
dynamic_axes[f"past_{kv}.{i}"] = {0: "batch_size", 2: "ctx_len"}
output_names.append(f"past_{kv}.{i}_RetainedState")
return self._export(
example_inputs,
output_names,
dynamic_axes,
export_kwargs={"do_constant_folding": False}, # To avoid merging adapter weights with base weights
onnx_transform_kwargs={"adapter_name": self.model.active_adapter},
export_dir=export_dir,
)
def compile(
self,
onnx_path: Optional[str] = None,
compile_dir: Optional[str] = None,
*,
batch_size: int = 1,
prefill_seq_len: int,
ctx_len: int,
num_devices: int = 1,
num_cores: int = 16,
mxfp6_matmul: bool = False,
mxint8_kv_cache: bool = False,
**compiler_options,
) -> str:
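"""
Compile the exported ONNX model into a QPC package that can be loaded with QAICInferenceSession. Two specializations are generated: one for prefill and one for decode.
Args:
:onnx_path (str): Pre-exported ONNX model to compile (optional).
:compile_dir (str): Directory to store the compiled artifacts (optional).
:batch_size (int): Batch size to compile for.
:prefill_seq_len (int): Sequence length of a prefill chunk.
:ctx_len (int): Maximum context length available at runtime.
:num_devices (int): Number of devices to compile for.
:num_cores (int): Number of AI cores to use.
:mxfp6_matmul (bool): Enable MXFP6 compression for MatMul weights.
:mxint8_kv_cache (bool): Enable MXINT8 compression for the KV cache.
:compiler_options: Additional options passed through to the compiler.
"""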
# Specializations
specializations = [
{"batch_size": batch_size, "seq_len": prefill_seq_len, "ctx_len": ctx_len},
{"batch_size": batch_size, "seq_len": 1, "ctx_len": ctx_len},
]
# Custom IO
custom_io = {}
kv_cache_dtype = "mxint8" if mxint8_kv_cache else "float16"
for suffix in ["", "_RetainedState"]:
for i in range(self.num_layers):
for kv in ["key", "value"]:
custom_io[f"past_{kv}.{i}{suffix}"] = kv_cache_dtype
for weight_name in self.adapter_weights[self.active_adapter]:
custom_io[f"{weight_name}{suffix}"] = "float16"
return self._compile(
onnx_path,
compile_dir,
compile_only=True,
retained_state=True,
specializations=specializations,
convert_to_fp16=True,
mxfp6_matmul=mxfp6_matmul,
custom_io=custom_io,
mdp_ts_num_devices=num_devices,
aic_num_cores=num_cores,
**compiler_options,
)
def generate(
self,
inputs: Optional[Union[torch.Tensor, np.ndarray]] = None,
device_ids: Optional[List[int]] = None,
generation_config: Optional[GenerationConfig] = None,
stopping_criteria: Optional[StoppingCriteria] = None,
streamer: Optional[BaseStreamer] = None,
**kwargs,
) -> np.ndarray:
"""
Generate tokens from the compiled binary. This method takes the same parameters as the HuggingFace transformers model.generate() method.
Args:
:inputs: input_ids
:generation_config: Generation config to merge with the model-specific config for the current generation.
:stopping_criteria: Pass custom stopping_criteria to stop at a specific point in generation.
:streamer: Streamer to put the generated tokens into.
:kwargs: Additional parameters for generation_config or to be passed to the model while generating.
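A minimal sketch, assuming ``m`` is compiled as in the class-level example and ``tokenizer`` is the base model's tokenizer:
.. code-block:: python
inputs = tokenizer("def fibonacci(n):", return_tensors="pt")
outputs = m.generate(**inputs, max_new_tokens=32)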
"""
# Initialize session
if self.qpc_session is None:
if self.qpc_path is None:
raise FileNotFoundError("Please compile the model with `model.compile(...)`")
self.qpc_session = QAICInferenceSession(str(self.qpc_path), device_ids)
# Skip buffers
retained_buffers = [x for x in self.qpc_session.output_names if x.endswith("_RetainedState")]
self.qpc_session.skip_buffers([x[: -len("_RetainedState")] for x in retained_buffers])
self.qpc_session.skip_buffers(retained_buffers)
generation_config = generation_config or self.model.generation_config
generation_config, model_kwargs = self.model._prepare_generation_config(generation_config, **kwargs)
self.model._prepare_special_tokens(generation_config)
if generation_config.do_sample:
raise NotImplementedError("do_sample=True not supported currently")
if generation_config.num_beams > 1:
raise NotImplementedError("num_beams>1 not supported currently")
if generation_config.max_new_tokens is None or generation_config.max_new_tokens <= 0:
raise ValueError("Required max_new_tokens>0 value in generation_config")
stopping_criteria = stopping_criteria or StoppingCriteriaList()
stopping_criteria = self.model._get_stopping_criteria(generation_config, stopping_criteria)
if inputs is not None:
inputs = {"input_ids": inputs}
else:
inputs = {}
inputs.update(model_kwargs)
inputs = {k: v.numpy() if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
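# Read the compiled batch size and prefill sequence length from the session's
# allowed shapes and default bindings for the "input_ids" tensor.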
batch_size = max(
[x[self.qpc_session.binding_index_map["input_ids"]][1][0] for x in self.qpc_session.allowed_shapes]
+ [self.qpc_session.bindings[self.qpc_session.binding_index_map["input_ids"]].dims[0]]
)
passed_batch_size = inputs["input_ids"].shape[0]
if passed_batch_size != batch_size:
raise ValueError(f"Model compiled for batch_size: {batch_size}, but passed batch_size: {passed_batch_size}")
prefill_seq_len = max(
[x[self.qpc_session.binding_index_map["input_ids"]][1][1] for x in self.qpc_session.allowed_shapes]
+ [self.qpc_session.bindings[self.qpc_session.binding_index_map["input_ids"]].dims[1]]
)
input_len = inputs["input_ids"].shape[1]
num_chunks = -(input_len // -prefill_seq_len) # Ceil divide without float
padded_len = num_chunks * prefill_seq_len  # Pad input length up to a multiple of prefill_seq_len
inputs["input_ids"] = np.concatenate(
[inputs["input_ids"], np.zeros((batch_size, padded_len - input_len), dtype=inputs["input_ids"].dtype)], 1
)
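# Build position_ids from the attention_mask: real tokens keep their increasing
# positions, while positions at or beyond the actual input length are set to -1.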
next_position_ids = inputs.pop("attention_mask").sum(1, keepdims=True)
inputs["position_ids"] = np.arange(padded_len).reshape(1, -1)
inputs["position_ids"] = np.where(inputs["position_ids"] < next_position_ids, inputs["position_ids"], -1)
generated_ids = np.zeros((batch_size, generation_config.max_new_tokens), dtype="int64")
if streamer:
streamer.put(inputs["input_ids"][:, :input_len])
# Set adapter weights
self.qpc_session.set_buffers(self.adapter_weights[self.active_adapter])
# Run prefill
for i in range(num_chunks):
chunk_inputs = inputs.copy()
chunk_inputs["input_ids"] = inputs["input_ids"][:, i * prefill_seq_len : (i + 1) * prefill_seq_len]
chunk_inputs["position_ids"] = inputs["position_ids"][:, i * prefill_seq_len : (i + 1) * prefill_seq_len]
outputs = self.qpc_session.run(chunk_inputs)
# Get first token
inputs["input_ids"] = outputs["logits"].argmax(2)
inputs["position_ids"] = next_position_ids
generated_ids[:, 0] = inputs["input_ids"].squeeze(1)
if streamer:
streamer.put(inputs["input_ids"])
# Skip adapter weights
self.qpc_session.skip_buffers(list(self.adapter_weights[self.active_adapter]))
# Decode loop
for num_token in range(1, generation_config.max_new_tokens):
if stopping_criteria(torch.from_numpy(inputs["input_ids"]), torch.from_numpy(outputs["logits"])).all():
break
outputs = self.qpc_session.run(inputs)
# Prepare inputs for next iteration
inputs["input_ids"] = outputs["logits"].argmax(2)
inputs["position_ids"] += 1
generated_ids[:, num_token] = inputs["input_ids"].squeeze(1)
if streamer:
streamer.put(inputs["input_ids"])
if streamer:
streamer.end()
return generated_ids