@@ -192,32 +192,53 @@ def get_input_prompts(prompt: str, prompts_txt_file_path: str) -> List[str]:
     return prompt
 
 
-def fix_prompts(prompt: List[str], batch_size: int, full_batch_size: int = None):
+def fix_prompts(
+    prompt: List[str],
+    batch_size: int,
+    prompt_to_lora_id_mapping: Optional[List[int]] = None,
+    full_batch_size: Optional[int] = None,
+):
     """
-    Adjusts the list of prompts to match the required batch size.
+    Adjusts the list of prompts and prompt_to_lora_id_mapping to match the required batch size.
 
     ``Mandatory`` Args:
         prompt (List[str]): List of input prompts.
         batch_size (int): The batch size to process at a time.
 
     ``Optional`` Args:
+        prompt_to_lora_id_mapping (Optional[List[int]]): Mapping that associates each prompt with its LoRA adapter ID.
         full_batch_size (Optional[int]): The full batch size if different from batch_size.
 
     Returns:
         List[str]: Adjusted list of prompts.
+        List[int]: Adjusted prompt_to_lora_id_mapping.
     """
     exec_batch_size = full_batch_size if full_batch_size is not None else batch_size
 
     if len(prompt) < exec_batch_size:
         logger.warning("Number of prompts are less than batch size/full batch size, repeating to required batch size")
         prompt = (prompt * (exec_batch_size // len(prompt) + 1))[:exec_batch_size]
+        if prompt_to_lora_id_mapping is not None:
+            logger.warning(
+                "prompt_to_lora_id_mapping is shorter than batch size/full batch size, repeating to required batch size"
+            )
+            prompt_to_lora_id_mapping = (
+                prompt_to_lora_id_mapping * (exec_batch_size // len(prompt_to_lora_id_mapping) + 1)
+            )[:exec_batch_size]
     elif full_batch_size is None and len(prompt) % batch_size != 0:
         logger.warning(
             "Number of prompts are not multiple of batch size, dropping last incomplete batch from given input prompts"
         )
         prompt = prompt[: batch_size * (len(prompt) // batch_size)]
+        if prompt_to_lora_id_mapping is not None:
+            logger.warning(
+                "prompt_to_lora_id_mapping is not a multiple of batch size, dropping last incomplete batch"
+            )
+            prompt_to_lora_id_mapping = prompt_to_lora_id_mapping[
+                : batch_size * (len(prompt_to_lora_id_mapping) // batch_size)
+            ]
 
-    return prompt
+    return prompt, prompt_to_lora_id_mapping
 
 
 def read_prompts_txt_file(prompts_txt_file_path: str):
@@ -311,7 +332,7 @@ def cloud_ai_100_exec_kv(
     """
     batch_size, ctx_len, full_batch_size = get_compilation_dims(qpc_path)
     prompt: List[str] = get_input_prompts(prompt, prompts_txt_file_path)
-    prompt = fix_prompts(prompt, batch_size, full_batch_size)
+    prompt, prompt_to_lora_id_mapping = fix_prompts(prompt, batch_size, prompt_to_lora_id_mapping, full_batch_size)
     generate_text = TextGeneration(
         tokenizer=tokenizer,
         qpc_path=qpc_path,
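
To illustrate the batching behavior this diff extends to the LoRA ID mapping, here is a minimal standalone sketch of the same pad-by-repetition / drop-incomplete-batch logic applied to one list at a time. The `pad_or_trim` helper and the toy values are illustrative only, not part of this PR:

```python
from typing import List, Optional

def pad_or_trim(items: list, batch_size: int, full_batch_size: Optional[int] = None) -> list:
    # Mirrors the fix_prompts logic for a single list:
    # - too few items: repeat the list until it fills the execution batch size
    # - no full_batch_size and length not a multiple of batch_size: drop the
    #   trailing incomplete batch
    exec_batch_size = full_batch_size if full_batch_size is not None else batch_size
    if len(items) < exec_batch_size:
        return (items * (exec_batch_size // len(items) + 1))[:exec_batch_size]
    if full_batch_size is None and len(items) % batch_size != 0:
        return items[: batch_size * (len(items) // batch_size)]
    return items

# Three prompts with per-prompt LoRA IDs, padded to a full batch of 4;
# applying the same operation to both lists keeps them aligned.
prompts: List[str] = pad_or_trim(["a", "b", "c"], batch_size=1, full_batch_size=4)
lora_ids: List[int] = pad_or_trim([0, 1, 2], batch_size=1, full_batch_size=4)
print(prompts)   # ['a', 'b', 'c', 'a']
print(lora_ids)  # [0, 1, 2, 0]
```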