Source code for QEfficient.utils.generate_inputs

# -----------------------------------------------------------------------------
#
# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------

import numpy as np
import torch

from QEfficient.utils import get_num_layers_from_config, get_padding_shape_from_config, padding_check_and_fix


[docs]class InputHandler: def __init__(self, batch_size, tokenizer, config, prompt, prompt_len, ctx_len, full_batch_size): """ Initialization ``Mandatory`` Args: :batch_size (int): Number of prompts to run in one batch. :tokenizer (Union[PreTrainedTokenizer, PreTrainedTokenizerFast]): Pass model tokenizer. :config (AutoConfig): From pretrained model. :prompt (List[str]): String to used as input prompt for the model. :prompt_len (int): Prompt length for the model to compile. :ctx_len (int): Maximum context length to compile the model. :full_batch_size (int): Continuous batching batch size """ # check and fix tokenizer viability padding_check_and_fix(tokenizer) self.tokenizer = tokenizer self.prompt = prompt self.prompt_len = prompt_len self.ctx_len = ctx_len self.full_batch_size = full_batch_size self.n_layer = get_num_layers_from_config(config) self.padding_shape = get_padding_shape_from_config( config=config, batch_size=full_batch_size if full_batch_size else batch_size, seq_len=ctx_len )
[docs] def prepare_pytorch_inputs(self): """ Function responsible for creating Prefill stage tensor inputs for PyTorch model. Return: :Dict: input_ids, position_ids, past_key_values """ inputs = self.tokenizer( self.prompt, return_tensors="pt", padding=True, ) input_ids = inputs["input_ids"] batch_size, input_len = input_ids.shape inputs.pop("attention_mask") position_ids = torch.arange(input_len).view(1, -1) inputs["input_ids"] = torch.concat( [ input_ids, torch.ones((batch_size, self.prompt_len - input_len), dtype=torch.int64) * (self.tokenizer.pad_token_id), ], 1, ) inputs["position_ids"] = torch.concat( [ position_ids, torch.ones((batch_size, self.prompt_len - input_len), dtype=torch.int64) * (-1), ], 1, ) if self.full_batch_size: inputs["input_ids"] = input_ids inputs["position_ids"] = torch.arange(input_len).view(1, input_len) inputs["batch_index"] = torch.arange(1).view(-1, 1) past_key_values = [] for i in range(self.n_layer): past_key = torch.zeros((self.padding_shape), dtype=torch.float32) past_value = torch.zeros((self.padding_shape), dtype=torch.float32) pkv = (past_key, past_value) past_key_values.append(pkv) inputs["past_key_values"] = tuple(past_key_values) return inputs
[docs] def update_pytorch_inputs(self, inputs, pt_outputs): """ Function responsible for updating Prefill stage inputs to create decode stage inputs for PyTorch model. ``Mandatory`` Args: :inputs (Dict): Pytorch inputs from previous iteration :pt_outputs (Dict): Pytorch outputs from previous iteration Return: :Dict: Updated input_ids, position_ids and past_key_values """ updated_inputs = {} if self.full_batch_size: batch_index = torch.arange(1).view(-1, 1) input_ids = pt_outputs.logits.detach().argmax(2) updated_inputs["input_ids"] = torch.full((self.full_batch_size, 1), self.tokenizer.pad_token_id) updated_inputs["input_ids"][batch_index.view(-1)] = input_ids position_ids = inputs["position_ids"].max(1, keepdim=True).values + 1 updated_inputs["position_ids"] = torch.full((self.full_batch_size, 1), 0) updated_inputs["position_ids"][batch_index.view(-1)] = position_ids updated_inputs["batch_index"] = torch.arange(self.full_batch_size).view(-1, 1) else: updated_inputs["input_ids"] = pt_outputs["logits"].argmax(-1).reshape(-1, 1) updated_inputs["position_ids"] = inputs["position_ids"].max(1, keepdim=True).values + 1 updated_inputs["past_key_values"] = tuple( [(key.detach(), value.detach()) for key, value in pt_outputs["past_key_values"]] ) return updated_inputs
[docs] def prepare_ort_inputs(self): """ Function responsible for creating Prefill stage numpy inputs for ONNX model to be run on ONNXRT. Return: :Dict: input_ids, position_ids, past_key_values """ inputs = self.tokenizer( self.prompt, return_tensors="np", padding=True, ) input_ids = inputs["input_ids"] batch_size, input_len = input_ids.shape inputs.pop("attention_mask") position_ids = np.arange(input_len).reshape(1, -1) inputs["input_ids"] = np.concatenate( [input_ids, np.full((batch_size, self.prompt_len - input_len), self.tokenizer.pad_token_id)], axis=1, ).astype(np.int64) inputs["position_ids"] = np.concatenate( [position_ids, np.full((batch_size, self.prompt_len - input_len), -1)], axis=1, ).astype(np.int64) for i in range(self.n_layer): inputs["past_key." + str(i)] = np.zeros((self.padding_shape), dtype=np.float32) inputs["past_value." + str(i)] = np.zeros((self.padding_shape), dtype=np.float32) return inputs
[docs] def update_ort_inputs(self, inputs, ort_outputs): """ Function responsible for updating Prefill stage inputs to create inputs for decode stage inputs for ONNX model to be run on ONNXRT. ``Mandatory`` Args: :inputs (Dict): NumPy inputs of Onnx model from previous iteration :ort_outputs (Dict): Numpy outputs of Onnx model from previous iteration Return: :Dict: Updated input_ids, position_ids and past_key_values """ updated_inputs = {} updated_inputs["input_ids"] = ort_outputs["logits"].argmax(-1) updated_inputs["position_ids"] = np.max(inputs["position_ids"], axis=1, keepdims=True) + 1 for i in range(self.n_layer): updated_inputs["past_key." + str(i)] = ort_outputs["past_key_values"][i * 2] updated_inputs["past_value." + str(i)] = ort_outputs["past_key_values"][i * 2 + 1] return updated_inputs
[docs] def update_ort_outputs(self, ort_outputs): """ Function responsible for updating ONNXRT session outputs. ``Mandatory`` Args: :ort_outputs (Dict): Numpy outputs of Onnx model from current iteration Return: updated_outputs (Dict): Updated past_key_values, logits """ present_key_values = [] for i in range(self.n_layer): if "past_key." + str(i) + "_RetainedState" in ort_outputs: present_key_values.append(ort_outputs["past_key." + str(i) + "_RetainedState"]) if "past_value." + str(i) + "_RetainedState" in ort_outputs: present_key_values.append(ort_outputs["past_value." + str(i) + "_RetainedState"]) outputs = {} outputs["past_key_values"] = present_key_values outputs["logits"] = ort_outputs["logits"] return outputs