This page gives you an overview of all the APIs that you might need to integrate QEfficient
into your Python applications.
High Level API
QEFFAutoModelForCausalLM
- class QEfficient.transformers.models.modeling_auto.QEFFAutoModelForCausalLM(model: Module, continuous_batching: bool = False, **kwargs)[source]
The QEFF class is designed for manipulating any causal language model from the HuggingFace hub. Although it is possible to initialize the class directly, we highly recommend using the from_pretrained method for initialization.
Mandatory Args:
- model (nn.Module):
PyTorch model.
- continuous_batching (bool):
Whether this model will be used for continuous batching in the future. If this is not set to True here, the model cannot be exported/compiled for continuous batching later.
from QEfficient import QEFFAutoModelForCausalLM

model = QEFFAutoModelForCausalLM.from_pretrained(model_name, num_hidden_layers=2)
model.compile(prefill_seq_len=32, ctx_len=1024)
model.generate(prompts=["Hi there!!"])
- classmethod from_pretrained(pretrained_model_name_or_path, continuous_batching: bool = False, *args, **kwargs)[source]
This method serves as the easiest entry point into using QEfficient. The interface is designed to be similar to transformers.AutoModelForCausalLM. Once the model is initialized, you can use other methods such as export, compile, and generate on the same object.
- Args:
- pretrained_model_name_or_path (str):
Model card name from HuggingFace or local path to model directory.
- continuous_batching (bool):
Whether this model will be used for continuous batching in the future. If this is not set to True here, the model cannot be exported/compiled for continuous batching later.
- args, kwargs:
Additional arguments to pass to transformers.AutoModelForCausalLM.
from QEfficient import QEFFAutoModelForCausalLM

# Initialize the model using from_pretrained, similar to transformers.AutoModelForCausalLM
model = QEFFAutoModelForCausalLM.from_pretrained("gpt2")

# Now you can directly compile the model for Cloud AI 100
model.compile(num_cores=14, device_group=[0])  # Considering you have a Cloud AI 100 Standard SKU

# You can now execute the model
model.generate(prompts=["Hi there!!"])
- export(export_dir: str | None = None) str [source]
Exports the model to ONNX format using torch.onnx.export. We currently don't support exporting non-transformed models. Please refer to the convert_to_cloud_bertstyle function in the Low-Level API for a legacy function that supports this.
Optional Args:
This method does not take any additional arguments.
- Returns:
- str:
Path of the generated ONNX graph.
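For example, export can be called on its own before compiling. A minimal sketch, reusing the gpt2 model card from the example above:

from QEfficient import QEFFAutoModelForCausalLM

model = QEFFAutoModelForCausalLM.from_pretrained("gpt2")

# Export the transformed model; the returned string is the path of the ONNX graph.
onnx_path = model.export()
print(onnx_path)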
- compile(onnx_path: str | None = None, compile_dir: str | None = None, *, prefill_seq_len: int = 32, ctx_len: int = 128, batch_size: int = 1, full_batch_size: int | None = None, num_devices: int = 1, num_cores: int = 16, mxfp6_matmul: bool = False, mxint8_kv_cache: bool = False, **compiler_options) str [source]
This method compiles the exported ONNX model using the Cloud AI 100 Platform SDK compiler binary found at /opt/qti-aic/exec/qaic-exec and generates a qpc package. If the model has not been exported yet, this method will handle the export process. You can pass any other arguments that qaic-exec accepts as extra kwargs.
Optional Args:
- onnx_path (str, optional):
Path to a pre-exported ONNX model.
- compile_dir (str, optional):
Path for saving the generated qpc.
- num_cores (int):
Number of cores used to compile the model.
- num_devices (int):
Number of devices to compile for; tensor slicing is invoked when more than one device is used. Defaults to 1.
- batch_size (int, optional):
Batch size. Defaults to 1.
- prefill_seq_len (int, optional):
Length of the prefill prompt; the prompt should be shorter than prefill_seq_len. Defaults to 32.
- ctx_len (int, optional):
Maximum context length that the compiled model can remember. Defaults to 128.
- full_batch_size (int, optional):
Continuous batching batch size.
- mxfp6_matmul (bool, optional):
Whether to use mxfp6 compression for weights. Defaults to False.
- mxint8_kv_cache (bool, optional):
Whether to use mxint8 compression for the KV cache. Defaults to False.
- mos (int, optional):
Effort level to reduce on-chip memory. Defaults to -1, meaning no effort.
- aic_enable_depth_first (bool, optional):
Enables DFS with default memory size. Defaults to False.
- Returns:
- str:
Path of the compiled qpc package.
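A minimal sketch of compile using some of the optional arguments above; the values are illustrative only, and any extra kwargs are forwarded to qaic-exec as described:

from QEfficient import QEFFAutoModelForCausalLM

model = QEFFAutoModelForCausalLM.from_pretrained("gpt2")

# Compile for a single device; values below are illustrative.
qpc_path = model.compile(
    num_cores=14,
    num_devices=1,
    prefill_seq_len=32,
    ctx_len=128,
    mxfp6_matmul=True,
    mxint8_kv_cache=False,
    aic_enable_depth_first=True,  # enables DFS with default memory size
    mos=1,                        # effort level to reduce on-chip memory
)
print(qpc_path)  # path of the compiled qpc package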
- generate(tokenizer: PreTrainedTokenizerFast | PreTrainedTokenizer, prompts: List[str], device_id: List[int] | None = None, runtime: str = 'AI_100', **kwargs)[source]
This method generates output until eos or generation_len by executing the compiled qpc on Cloud AI 100 hardware cards. Execution is sequential, based on the batch_size of the compiled model and the number of prompts passed. If the number of prompts is not divisible by the batch_size, the last unfulfilled batch will be dropped.
Mandatory Args:
- prompts (List[str]):
List of prompts to run the execution.
- device_id (List[int]):
IDs of the devices for running the qpc. Pass [0] for a single-device model, or e.g. [0, 1, 2, 3] for a tensor-sliced model.
Optional Args:
- runtime (str, optional):
Only AI_100 runtime is supported as of now; ONNXRT and PyTorch support is coming soon. Defaults to "AI_100".
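A minimal sketch of generate, passing the tokenizer explicitly as in the signature above; the tokenizer and device IDs are illustrative:

from transformers import AutoTokenizer
from QEfficient import QEFFAutoModelForCausalLM

model = QEFFAutoModelForCausalLM.from_pretrained("gpt2")
model.compile(num_cores=14)

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model.generate(
    tokenizer,
    prompts=["Hi there!!", "How are you?"],
    device_id=[0],       # e.g. [0, 1, 2, 3] for a tensor-sliced model
    runtime="AI_100",
)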
QEffAutoPeftModelForCausalLM
- class QEfficient.peft.auto.QEffAutoPeftModelForCausalLM(model: Module)[source]
QEff class for loading models with PEFT adapters (only LoRA is supported currently). Once exported and compiled for one adapter, the same artifacts can be reused for another adapter with the same base model and adapter config.
- Args:
- model (nn.Module):
PyTorch model
from QEfficient import QEffAutoPeftModelForCausalLM

m = QEffAutoPeftModelForCausalLM.from_pretrained("predibase/magicoder", "magicoder")
m.export()
m.compile(prefill_seq_len=32, ctx_len=1024)

inputs = ...  # A coding prompt
outputs = m.generate(**inputs)

inputs = ...  # A math prompt
m.load_adapter("predibase/gsm8k", "gsm8k")
m.set_adapter("gsm8k")
outputs = m.generate(**inputs)
- load_adapter(model_id: str, adapter_name: str)[source]
Loads a new adapter from huggingface hub or local path
- Args:
- model_id (str):
Adapter model ID from huggingface hub or local path
- adapter_name (str):
Adapter name to be used to set this adapter as current
- property active_adapter: str
Currently active adapter to be used for inference
- classmethod from_pretrained(pretrained_name_or_path: str, *args, **kwargs)[source]
- Args:
- pretrained_name_or_path (str):
Model card name from HuggingFace or local path to model directory.
- finite_adapters (bool):
Set True to enable finite adapter mode with the QEffAutoLoraModelForCausalLM class. Please refer to QEffAutoLoraModelForCausalLM for the API specification.
- adapter_name (str):
Name used to identify the loaded adapter.
- args, kwargs:
Additional arguments to pass to peft.AutoPeftModelForCausalLM.
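A hedged sketch of finite adapter mode based on the arguments documented above; the adapter card is a placeholder, and the returned object is expected to follow the QEffAutoLoraModelForCausalLM API described later on this page:

from QEfficient import QEffAutoPeftModelForCausalLM

# finite_adapters and adapter_name are passed as documented above (illustrative values).
m = QEffAutoPeftModelForCausalLM.from_pretrained(
    "predibase/magicoder",
    adapter_name="magicoder",
    finite_adapters=True,
)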
- export(export_dir: str | None = None) str [source]
Exports the model to ONNX format using torch.onnx.export.
- Args:
- export_dir (str):
Specify the export directory. The export_dir will be suffixed with a hash corresponding to the current model.
- Returns:
- Path:
Path of the generated ONNX file.
- compile(onnx_path: str | None = None, compile_dir: str | None = None, *, batch_size: int = 1, prefill_seq_len: int, ctx_len: int, num_devices: int = 1, num_cores: int = 16, mxfp6_matmul: bool = False, mxint8_kv_cache: bool = False, **compiler_options) str [source]
Compiles the exported ONNX model to run on Cloud AI 100. If the model has not been exported yet, this method will handle the export process.
- Args:
- onnx_path (str):
ONNX file to compile.
- compile_dir (str):
Directory path to compile the qpc. A suffix is added to the directory path to avoid reusing the same qpc for different parameters.
- num_devices (int):
Number of devices to compile for. Defaults to 1.
- num_cores (int):
Number of cores to utilize in each device. Defaults to 16.
- mxfp6_matmul (bool):
Use MXFP6 to compress weights for MatMul nodes to run faster on device. Defaults to False.
- mxint8_kv_cache (bool):
Use MXINT8 to compress the KV cache on device to access and update the KV cache faster. Defaults to False.
- compiler_options:
Pass any compiler option as input. Any flag that is supported by qaic-exec can be passed. Params are converted to flags as below:
- aic_num_cores=16 -> -aic-num-cores=16
- convert_to_fp16=True -> -convert-to-fp16
QEFFAutoModelForCausalLM Args:
- full_batch_size (int):
Full batch size to allocate cache lines.
- batch_size (int):
Batch size to compile for. Defaults to 1.
- prefill_seq_len (int):
Prefill sequence length to compile for. The prompt will be chunked according to this length.
- ctx_len (int):
Context length to allocate space for KV-cache tensors.
- Returns:
- str:
Path of the compiled qpc package.
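A minimal sketch of compile showing how an extra compiler option is converted to a qaic-exec flag as described above; the adapter card and values are illustrative:

from QEfficient import QEffAutoPeftModelForCausalLM

m = QEffAutoPeftModelForCausalLM.from_pretrained("predibase/magicoder", "magicoder")

qpc_path = m.compile(
    prefill_seq_len=32,
    ctx_len=1024,
    num_cores=16,
    mxfp6_matmul=True,
    convert_to_fp16=True,  # compiler_options kwarg, converted to -convert-to-fp16
)
print(qpc_path)  # path of the compiled qpc package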
- generate(inputs: Tensor | ndarray | None = None, device_ids: List[int] | None = None, generation_config: GenerationConfig | None = None, stopping_criteria: StoppingCriteria | None = None, streamer: BaseStreamer | None = None, **kwargs) ndarray [source]
Generate tokens from the compiled binary. This method takes the same parameters as the HuggingFace transformers model.generate() method.
- Args:
- inputs:
input_ids
- generation_config:
Generation config to merge with the model-specific config for the current generation.
- stopping_criteria:
Pass custom stopping_criteria to stop at a specific point in generation.
- streamer:
Streamer to put the generated tokens into.
- kwargs:
Additional parameters for generation_config or to be passed to the model while generating.
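A minimal sketch of generate with tokenized inputs and a streamer; the base-model tokenizer and generation settings are assumptions, not part of the documented API:

from transformers import AutoTokenizer, TextStreamer
from QEfficient import QEffAutoPeftModelForCausalLM

m = QEffAutoPeftModelForCausalLM.from_pretrained("predibase/magicoder", "magicoder")
m.compile(prefill_seq_len=32, ctx_len=1024)

# Tokenizer of the adapter's base model (assumed here to be Mistral-7B-v0.1).
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
inputs = tokenizer("def fibonacci(n):", return_tensors="pt")

streamer = TextStreamer(tokenizer)
generated_ids = m.generate(**inputs, streamer=streamer, max_new_tokens=128)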
QEffAutoLoraModelForCausalLM
- class QEfficient.peft.lora.auto.QEffAutoLoraModelForCausalLM(model: Module, continuous_batching: bool = False, **kwargs)[source]
QEff class for loading models with multiple LoRA adapters. Currently only Mistral and Llama models are supported. Once exported and compiled, the qpc can perform mixed-batch inference with the provided prompt_to_adapter_mapping.
- Args:
- model (nn.Module):
PyTorch model
- continuous_batching (bool):
Whether this model will be used for continuous batching in the future. If this is not set to True here, the model cannot be exported/compiled for continuous batching later.
from QEfficient.peft.lora import QEffAutoLoraModelForCausalLM

m = QEffAutoLoraModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")
m.load_adapter("predibase/gsm8k", "gsm8k")
m.load_adapter("predibase/magicoder", "magicoder")
m.compile(num_cores=16, device_group=[0])

prompts = ["code prompt", "math prompt", "generic"]
m.generate(prompts, device_group=[0], prompt_to_adapter_mapping=["magicoder", "gsm8k", "base"])
- download_adapter(adapter_model_id: str, adapter_name: str, adapter_weight: dict | None = None, adapter_config: PeftConfig | None = None)[source]
Loads a new adapter from the HuggingFace hub or a local path into the CPU cache.
Mandatory Args:
- adapter_model_id (str):
Adapter model ID from the HuggingFace hub or a local path.
- adapter_name (str):
Adapter name to be used for the downloaded adapter.
Optional Args:
- adapter_weight (dict):
Adapter weight tensors in dictionary format.
- adapter_config (PeftConfig):
Adapter config in the format of PeftConfig.
- load_adapter(adapter_model_id: str, adapter_name: str, adapter_weight: dict | None = None, adapter_config: PeftConfig | None = None)[source]
Loads an adapter into the CPU cache and sets it as active.
Mandatory Args:
- adapter_model_id (str):
Adapter model ID from the HuggingFace hub or a local path.
- adapter_name (str):
Adapter name to be used to load this adapter.
Optional Args:
- adapter_weight (dict):
Adapter weight tensors in dictionary format.
- adapter_config (PeftConfig):
Adapter config in the format of PeftConfig.
- unload_adapter(adapter_name: str)[source]
Deactivates the adapter and removes it from the CPU cache.
Mandatory Args:
- adapter_name (str):
Adapter name to be unloaded.
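A minimal sketch of the adapter cache lifecycle using the three methods above; the adapter card is a placeholder:

from QEfficient.peft.lora import QEffAutoLoraModelForCausalLM

m = QEffAutoLoraModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")

m.download_adapter("predibase/magicoder", "magicoder")  # download into the CPU cache only
m.load_adapter("predibase/magicoder", "magicoder")      # load into the CPU cache and set as active
m.unload_adapter("magicoder")                           # deactivate and remove from the CPU cache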
- export(export_dir: str | None = None) str [source]
Exports the model to ONNX format using torch.onnx.export. We currently don't support exporting non-transformed models. Please refer to the convert_to_cloud_bertstyle function in the Low-Level API for a legacy function that supports this.
Optional Args:
This method does not take any additional arguments.
- Returns:
- str:
Path of the generated ONNX graph.
- generate(tokenizer: PreTrainedTokenizerFast | PreTrainedTokenizer, prompts: List[str], device_id: List[int] | None = None, prompt_to_adapter_mapping: List[str] | None = None, runtime: str = 'AI_100', **kwargs)[source]
This method generates output until eos or generation_len by executing the compiled qpc on Cloud AI 100 hardware cards. Execution is sequential, based on the batch_size of the compiled model and the number of prompts passed. If the number of prompts is not divisible by the batch_size, the last unfulfilled batch will be dropped.
Mandatory Args:
- tokenizer (PreTrainedTokenizerFast or PreTrainedTokenizer):
The tokenizer used in the inference.
- prompts (List[str]):
List of prompts to run the execution.
- device_id (List[int]):
IDs of the devices for running the qpc. Pass [0] for a single-device model, or e.g. [0, 1, 2, 3] for a tensor-sliced model.
- prompt_to_adapter_mapping (List[str]):
The sequence of adapter names, matched with the sequence of prompts; the corresponding adapter will be used for each prompt. Use "base" for the base model (no adapter).
Optional Args:
- runtime (str, optional):
Only AI_100 runtime is supported as of now; ONNXRT and PyTorch support is coming soon. Defaults to "AI_100".
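A minimal sketch of mixed-batch generation, passing the tokenizer explicitly as the signature above requires; the prompts, adapter card, and device IDs are illustrative:

from transformers import AutoTokenizer
from QEfficient.peft.lora import QEffAutoLoraModelForCausalLM

m = QEffAutoLoraModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")
m.load_adapter("predibase/gsm8k", "gsm8k")
m.compile(num_cores=16)

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
m.generate(
    tokenizer,
    prompts=["What is 12 * 7?", "Tell me a short story"],
    device_id=[0],
    prompt_to_adapter_mapping=["gsm8k", "base"],  # "base" uses the base model with no adapter
)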
export
- QEfficient.exporter.export_hf_to_cloud_ai_100.qualcomm_efficient_converter(model_name: str, model_kv: QEFFBaseModel | None = None, local_model_dir: str | None = None, tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast | None = None, cache_dir: str | None = None, onnx_dir_path: str | None = None, hf_token: str | None = None, seq_length: int = 32, kv: bool = True, form_factor: str = 'cloud', full_batch_size: int | None = None) Tuple[str, str] [source]
This method is an alias for QEfficient.export.
Usage 1: This method can be used by passing model_name and local_model_dir or cache_dir if required for loading from a local directory. This will download the model from HuggingFace, export it to an ONNX graph, and return the paths of the generated files (see below).
Usage 2: You can pass model_name and model_kv as an object of QEfficient.QEFFAutoModelForCausalLM; in this case, it will directly export model_kv.model to ONNX.
We will be deprecating this function and it will be replaced by QEFFAutoModelForCausalLM.export.
Mandatory Args:
- model_name (str):
The name of the model to be used.
Optional Args:
- model_kv (torch.nn.Module):
Transformed KV torch model to be used. Defaults to None.
- local_model_dir (str):
Path of the local model. Defaults to None.
- tokenizer (Union[PreTrainedTokenizer, PreTrainedTokenizerFast]):
Model tokenizer. Defaults to None.
- cache_dir (str):
Path of the cache directory. Defaults to None.
- onnx_dir_path (str):
Path to store the ONNX file. Defaults to None.
- hf_token (str):
HuggingFace token to access gated models. Defaults to None.
- seq_length (int):
The length of the sequence. Defaults to 32.
- kv (bool):
If False, it will export in BERT style (without KV cache). Defaults to True.
- form_factor (str):
Form factor of the hardware; currently only cloud is accepted. Defaults to cloud.
- Returns:
- Tuple[str, str]:
Path to the base ONNX directory and path to the generated ONNX model.
import QEfficient

base_path, onnx_model_path = QEfficient.export(model_name="gpt2")
Deprecated: This function will be deprecated in version 1.19; please use QEFFAutoModelForCausalLM.export instead.
compile
- QEfficient.compile.compile_helper.compile(onnx_path: str, qpc_path: str, num_cores: int, device_group: List[int] | None = None, aic_enable_depth_first: bool = False, mos: int = -1, batch_size: int = 1, prompt_len: int = 32, ctx_len: int = 128, mxfp6: bool = True, mxint8: bool = False, custom_io_file_path: str | None = None, full_batch_size: int | None = None, **kwargs) str [source]
Compiles the given ONNX model using the Cloud AI 100 platform SDK compiler and saves the compiled qpc package at qpc_path. Generates a tensor-slicing configuration if multiple devices are passed in device_group.
This function will be deprecated soon and will be replaced by QEFFAutoModelForCausalLM.compile.
Mandatory Args:
- onnx_path (str):
Generated ONNX model path.
- qpc_path (str):
Path for saving the compiled qpc binaries.
- num_cores (int):
Number of cores to compile the model on.
Optional Args:
- device_group (List[int]):
Used for finding the number of devices to compile for. Defaults to None.
- aic_enable_depth_first (bool):
Enables DFS with default memory size. Defaults to False.
- mos (int):
Effort level to reduce the on-chip memory. Defaults to -1.
- batch_size (int):
Batch size to compile the model for. Defaults to 1.
- full_batch_size (int):
Set full batch size to enable continuous batching mode. Defaults to None.
- prompt_len (int):
Prompt length for the model to compile. Defaults to 32.
- ctx_len (int):
Maximum context length to compile the model. Defaults to 128.
- mxfp6 (bool):
Enable compilation for MXFP6 precision. Defaults to True.
- mxint8 (bool):
Compress Present/Past KV to MXINT8 using a CustomIO config. Defaults to False.
- custom_io_file_path (str):
Path to the customIO file (formatted as a string). Defaults to None.
- Returns:
- str:
Path to the compiled qpc package.
import os

import QEfficient
base_path, onnx_model_path = QEfficient.export(model_name="gpt2")
qpc_path = QEfficient.compile(onnx_path=onnx_model_path, qpc_path=os.path.join(base_path, "qpc"), num_cores=14, device_group=[0])
Deprecated: This function will be deprecated in version 1.19; please use QEFFAutoModelForCausalLM.compile instead.
Execute
- class QEfficient.generation.text_generation_inference.CloudAI100ExecInfo(batch_size: int, generated_texts: List[str] | List[List[str]], generated_ids: List[ndarray] | ndarray, prefill_time: float, decode_perf: float, total_perf: float, total_time: float)[source]
Bases:
object
Holds all the information about Cloud AI 100 execution
- Args:
- batch_size (int):
Batch size of the QPC compilation.
- generated_texts (Union[List[List[str]], List[str]]):
Generated text(s).
- generated_ids (Union[List[np.ndarray], np.ndarray]):
Generated IDs.
- prefill_time (float):
Time for prefilling.
- decode_perf (float):
Decoding performance.
- total_perf (float):
Total performance.
- total_time (float):
Total time.
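The fields above can be read directly from the returned object. A minimal sketch, reusing the gpt2 export/compile flow documented on this page and the cloud_ai_100_exec_kv call described next:

import os

import transformers
import QEfficient

base_path, onnx_model_path = QEfficient.export(model_name="gpt2")
qpc_path = QEfficient.compile(onnx_path=onnx_model_path, qpc_path=os.path.join(base_path, "qpc"), num_cores=14, device_group=[0])
tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")

execinfo = QEfficient.cloud_ai_100_exec_kv(tokenizer=tokenizer, qpc_path=qpc_path, prompt="Hi there!!", device_id=[0])

# Inspect the fields documented above.
print(execinfo.generated_texts)
print(execinfo.prefill_time, execinfo.decode_perf, execinfo.total_perf, execinfo.total_time)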
- QEfficient.generation.text_generation_inference.cloud_ai_100_exec_kv(tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast, qpc_path: str, prompt: str | None = None, prompts_txt_file_path: str | None = None, device_id: List[int] | None = None, generation_len: int | None = None, enable_debug_logs: bool = False, stream: bool = True, write_io_dir: str | None = None, automation=False, prompt_to_lora_id_mapping: List[int] | None = None)[source]
This method generates output until eos or generation_len by executing the compiled qpc on Cloud AI 100 hardware cards. Execution is sequential, based on the batch_size of the compiled model and the number of prompts passed. If the number of prompts is not divisible by the batch_size, the last unfulfilled batch will be dropped.
Mandatory Args:
- tokenizer (Union[PreTrainedTokenizer, PreTrainedTokenizerFast]):
Model tokenizer.
- qpc_path (str):
Path to the saved generated binary file after compilation.
Optional Args:
- prompt (str):
Sample prompt for the model text generation. Defaults to None.
- prompts_txt_file_path (str):
Path of the prompt text file. Defaults to None.
- generation_len (int):
Maximum context length for the model during compilation. Defaults to None.
- device_id (List[int]):
Device IDs to be used for execution. If len(device_id) > 1, it enables a multi-card setup. If None, the auto-device-picker will be used. Defaults to None.
- enable_debug_logs (bool):
If True, it enables debugging logs. Defaults to False.
- stream (bool):
If True, enables the streamer, which returns tokens one by one as the model generates them. Defaults to True.
- write_io_dir (str):
Path to write the input and output files. Defaults to None.
- automation (bool):
If True, it prints input, output, and performance stats. Defaults to False.
- Returns:
- CloudAI100ExecInfo:
Object holding execution output and performance details.
import os

import transformers
import QEfficient

base_path, onnx_model_path = QEfficient.export(model_name="gpt2")
qpc_path = QEfficient.compile(onnx_path=onnx_model_path, qpc_path=os.path.join(base_path, "qpc"), num_cores=14, device_group=[0])
tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")
execinfo = QEfficient.cloud_ai_100_exec_kv(tokenizer=tokenizer, qpc_path=qpc_path, prompt="Hi there!!", device_id=[0])
- QEfficient.generation.text_generation_inference.fix_prompts(prompt: List[str], batch_size: int, full_batch_size: int | None = None)[source]
Adjusts the list of prompts to match the required batch size.
Mandatory Args:
- prompt (List[str]):
List of input prompts.
- batch_size (int):
The batch size to process at a time.
Optional Args:
- full_batch_size (Optional[int]):
The full batch size, if different from batch_size.
- Returns:
List[str]: Adjusted list of prompts.
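A minimal sketch of calling fix_prompts; how the list is padded or repeated to fit the batch size is an implementation detail, so only the call shape is shown:

from QEfficient.generation.text_generation_inference import fix_prompts

prompts = ["Hello", "How are you?", "Tell me a joke"]

# Adjust the prompt list so it is compatible with a batch size of 2.
adjusted = fix_prompts(prompts, batch_size=2)
print(len(adjusted), adjusted)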