
Commit cdc387a

Add prompt_to_lora_id_mapping adjustment in fix_prompts() (#242)
This addresses the issue reported in [issue#251](#251): the finite lorax feature failed to execute when the number of prompts provided was less than the full batch size. The fix applies the same adjustment strategy to `prompt_to_lora_id_mapping` as is already used for `prompt` in the `fix_prompts()` function located in `QEfficient/generation/text_generation_inference.py`. Signed-off-by: Jou-An Chen <[email protected]>
1 parent 669df06 commit cdc387a
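
As a minimal standalone sketch of the adjustment this commit introduces (illustrative values; `mapping` and `adjusted` are example names, not repository code), a short mapping list is tiled past the execution batch size and then truncated, mirroring the strategy `fix_prompts()` already applies to prompts:

# Repeat-and-truncate idiom used by the fix (standalone illustration).
mapping = [0, 1]        # two prompts, mapped to LoRA adapters 0 and 1
full_batch_size = 4     # full batch size the QPC was compiled with

# Tile the mapping past the target length, then cut it to size.
adjusted = (mapping * (full_batch_size // len(mapping) + 1))[:full_batch_size]
print(adjusted)  # [0, 1, 0, 1]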

1 file changed: +38 −0 lines

QEfficient/generation/text_generation_inference.py

@@ -233,6 +233,40 @@ def fix_prompts(prompt: List[str], batch_size: int, full_batch_size: int = None)
     return prompt
 
 
+def fix_prompt_to_lora_id_mapping(prompt_to_lora_id_mapping: List[int], batch_size: int, full_batch_size: int = None):
+    """
+    Adjusts the list of prompt_to_lora_id_mapping to match the required batch size.
+
+    ``Mandatory`` Args:
+        prompt_to_lora_id_mapping (Optional[List[int]]): Mapping to associate prompts with their respective LoRA adapter.
+        batch_size (int): The batch size to process at a time.
+
+    ``Optional`` Args:
+        full_batch_size (Optional[int]): The full batch size if different from batch_size.
+
+    Returns:
+        List[int]: Adjusted list of prompt_to_lora_id_mapping.
+    """
+    exec_batch_size = full_batch_size if full_batch_size is not None else batch_size
+
+    if len(prompt_to_lora_id_mapping) < exec_batch_size:
+        logger.warning(
+            "Prompt_to_lora_id_mapping are less than batch size/full batch size, repeating to required batch size"
+        )
+        prompt_to_lora_id_mapping = (
+            prompt_to_lora_id_mapping * (exec_batch_size // len(prompt_to_lora_id_mapping) + 1)
+        )[:exec_batch_size]
+    elif full_batch_size is None and len(prompt_to_lora_id_mapping) % batch_size != 0:
+        logger.warning(
+            "prompt_to_lora_id_mapping are not multiple of batch size, dropping last incomplete batch from given input prompts"
+        )
+        prompt_to_lora_id_mapping = prompt_to_lora_id_mapping[
+            : batch_size * (len(prompt_to_lora_id_mapping) // batch_size)
+        ]
+
+    return prompt_to_lora_id_mapping
+
+
 def read_prompts_txt_file(prompts_txt_file_path: str):
     prompt = []
     with open(prompts_txt_file_path, "r") as file:
@@ -325,6 +359,10 @@ def cloud_ai_100_exec_kv(
     batch_size, ctx_len, full_batch_size = get_compilation_dims(qpc_path)
     prompt: List[str] = get_input_prompts(prompt, prompts_txt_file_path)
     prompt = fix_prompts(prompt, batch_size, full_batch_size)
+    if prompt_to_lora_id_mapping is not None:
+        prompt_to_lora_id_mapping = fix_prompt_to_lora_id_mapping(
+            prompt_to_lora_id_mapping, batch_size, full_batch_size
+        )
     generate_text = TextGeneration(
         tokenizer=tokenizer,
         qpc_path=qpc_path,
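
A usage sketch of the new helper showing both branches (the function name and module path come from the diff above; the example values are illustrative):

from QEfficient.generation.text_generation_inference import fix_prompt_to_lora_id_mapping

# Fewer entries than the full batch size: the mapping is repeated to fill it.
fix_prompt_to_lora_id_mapping([0, 1], batch_size=1, full_batch_size=4)
# -> [0, 1, 0, 1]

# No full_batch_size and length not a multiple of batch_size:
# the trailing incomplete batch is dropped.
fix_prompt_to_lora_id_mapping([0, 1, 2, 3, 4], batch_size=2)
# -> [0, 1, 2, 3]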
