Commit 2447937

Add prompt_to_lora_id_mapping adjustment in fix_prompts()
Signed-off-by: Jou-An Chen <[email protected]>
Parent: 2ba81dc

1 file changed: +25 −4 lines changed
QEfficient/generation/text_generation_inference.py

Lines changed: 25 additions & 4 deletions
@@ -192,32 +192,53 @@ def get_input_prompts(prompt: str, prompts_txt_file_path: str) -> List[str]:
     return prompt
 
 
-def fix_prompts(prompt: List[str], batch_size: int, full_batch_size: int = None):
+def fix_prompts(
+    prompt: List[str],
+    batch_size: int,
+    prompt_to_lora_id_mapping: Optional[List[int]] = None,
+    full_batch_size: int = None,
+):
     """
-    Adjusts the list of prompts to match the required batch size.
+    Adjusts the list of prompts and prompt_to_lora_id_mapping to match the required batch size.
 
     ``Mandatory`` Args:
         prompt (List[str]): List of input prompts.
         batch_size (int): The batch size to process at a time.
 
     ``Optional`` Args:
+        prompt_to_lora_id_mapping (Optional[List[int]]): Mapping to associate prompts with their respective LoRA adapter.
         full_batch_size (Optional[int]): The full batch size if different from batch_size.
 
     Returns:
         List[str]: Adjusted list of prompts.
+        List[str]: Adjusted list of prompt_to_lora_id_mapping.
     """
     exec_batch_size = full_batch_size if full_batch_size is not None else batch_size
 
     if len(prompt) < exec_batch_size:
         logger.warning("Number of prompts are less than batch size/full batch size, repeating to required batch size")
         prompt = (prompt * (exec_batch_size // len(prompt) + 1))[:exec_batch_size]
+        if prompt_to_lora_id_mapping is not None:
+            logger.warning(
+                "Prompt_to_lora_id_mapping are less than batch size/full batch size, repeating to required batch size"
+            )
+            prompt_to_lora_id_mapping = (
+                prompt_to_lora_id_mapping * (exec_batch_size // len(prompt_to_lora_id_mapping) + 1)
+            )[:exec_batch_size]
     elif full_batch_size is None and len(prompt) % batch_size != 0:
         logger.warning(
             "Number of prompts are not multiple of batch size, dropping last incomplete batch from given input prompts"
         )
         prompt = prompt[: batch_size * (len(prompt) // batch_size)]
+        if prompt_to_lora_id_mapping is not None:
+            logger.warning(
+                "prompt_to_lora_id_mapping are not multiple of batch size, dropping last incomplete batch from given input prompts"
+            )
+            prompt_to_lora_id_mapping = prompt_to_lora_id_mapping[
+                : batch_size * (len(prompt_to_lora_id_mapping) // batch_size)
+            ]
 
-    return prompt
+    return prompt, prompt_to_lora_id_mapping
 
 
 def read_prompts_txt_file(prompts_txt_file_path: str):

@@ -311,7 +332,7 @@ def cloud_ai_100_exec_kv(
     """
     batch_size, ctx_len, full_batch_size = get_compilation_dims(qpc_path)
     prompt: List[str] = get_input_prompts(prompt, prompts_txt_file_path)
-    prompt = fix_prompts(prompt, batch_size, full_batch_size)
+    prompt, prompt_to_lora_id_mapping = fix_prompts(prompt, batch_size, prompt_to_lora_id_mapping, full_batch_size)
     generate_text = TextGeneration(
         tokenizer=tokenizer,
         qpc_path=qpc_path,
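For context, here is a minimal standalone sketch of the repeat-and-truncate behavior this patch now applies to both the prompt list and the LoRA-id mapping. The helper name pad_or_trim and the sample values are illustrative, not part of the commit; the logic mirrors the fix_prompts() code above applied to a single list.

from typing import Optional

def pad_or_trim(items: list, batch_size: int, full_batch_size: Optional[int] = None) -> list:
    # Hypothetical helper mirroring the fix_prompts() adjustment for one list.
    exec_batch_size = full_batch_size if full_batch_size is not None else batch_size
    if len(items) < exec_batch_size:
        # Tile the list until it is long enough, then cut to exactly exec_batch_size.
        return (items * (exec_batch_size // len(items) + 1))[:exec_batch_size]
    if full_batch_size is None and len(items) % batch_size != 0:
        # Drop the trailing incomplete batch.
        return items[: batch_size * (len(items) // batch_size)]
    return items

prompts = pad_or_trim(["Hello", "Hi"], batch_size=1, full_batch_size=5)
mapping = pad_or_trim([0, 1], batch_size=1, full_batch_size=5)
print(prompts)  # ['Hello', 'Hi', 'Hello', 'Hi', 'Hello']
print(mapping)  # [0, 1, 0, 1, 0]

Because both lists are tiled and truncated with the same arithmetic, prompt i stays paired with adapter id mapping[i] after adjustment, which is what lets the call site in cloud_ai_100_exec_kv() unpack and pass along the tuple returned by fix_prompts().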
