from dataclasses import dataclass, field
from datetime import timedelta
from enum import Enum, auto
+ from typing import Dict

import numpy as np
from tqdm import tqdm

from lighteval.logging.evaluation_tracker import EvaluationTracker
from lighteval.metrics.utils.metric_utils import MetricCategory
+ from lighteval.models.abstract_model import ModelInfo
from lighteval.models.model_loader import TransformersModel, load_model
from lighteval.models.model_output import (
    GenerativeMultiturnResponse,
    GenerativeResponse,
    LoglikelihoodResponse,
    LoglikelihoodSingleTokenResponse,
    ModelResponse,
)
+ from lighteval.models.transformers.transformers_model import TransformersModelConfig
+ from lighteval.models.utils import _simplify_name
+ from lighteval.models.vllm.vllm_model import VLLMModelConfig
from lighteval.tasks.lighteval_task import LightevalTask, create_requests_from_tasks
from lighteval.tasks.registry import Registry, taskinfo_selector
from lighteval.tasks.requests import RequestType, SampleUid
@@ -142,42 +147,49 @@ def __init__(
                "--max_samples WAS SET. THESE NUMBERS ARE ONLY PARTIAL AND SHOULD NOT BE USED FOR COMPARISON UNLESS YOU KNOW WHAT YOU ARE DOING."
            )

+         self.tasks = tasks
+         self.model = model
        self.model_config = model_config
        self.evaluation_tracker = evaluation_tracker
-         self.accelerator, self.parallel_context = self._init_parallelism_manager()
-         self.model = self._init_model(model_config, model)
-
-         self.evaluation_tracker.general_config_logger.log_model_info(self.model.model_info)
-         self._init_tasks_and_requests(tasks=tasks)
+         self._init_parallelism_manager()
        self._init_random_seeds()
+
+         self.evaluation_tracker.general_config_logger.log_model_info(self._get_model_info())
+         self._init_tasks()
+
        # Final results
        self.final_dict: dict = None

+     def _get_model_info(self):
+         if isinstance(self.model_config, (VLLMModelConfig, TransformersModelConfig)):
+             # At this point we only need the model name to know the details path
+             return ModelInfo(model_name=_simplify_name(self.model_config.pretrained))
+         else:
+             return self._init_model().model_info
+
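The new `_get_model_info` keeps `__init__` cheap: when the config is a `VLLMModelConfig` or `TransformersModelConfig`, only the simplified model name is needed (to locate the details files), so the heavy model load is deferred until `_run_model` calls `_init_model`. A minimal sketch of that lazy-initialization pattern; the names below (`LazyModel`, the config layout) are illustrative, not lighteval APIs:

```python
# Cheap metadata is available immediately; the expensive object is only
# built when something actually asks for it.
class LazyModel:
    def __init__(self, config: dict):
        self.config = config
        self._model = None

    @property
    def name(self) -> str:
        # Cheap: derived from the config alone, no weights loaded.
        return self.config["pretrained"].split("/")[-1]

    def get(self):
        # Expensive: performed at most once, on first use.
        if self._model is None:
            self._model = self._load()
        return self._model

    def _load(self):
        # Stand-in for the real model loading step (hypothetical).
        return object()
```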
    def _init_parallelism_manager(self):
-         accelerator, parallel_context = None, None
+         self.accelerator, self.parallel_context = None, None
        if self.launcher_type == ParallelismManager.ACCELERATE:
            if not is_accelerate_available():
                raise ValueError("You are trying to launch an accelerate model, but accelerate is not installed")
-             accelerator = Accelerator(kwargs_handlers=[InitProcessGroupKwargs(timeout=timedelta(seconds=3000))])
-             test_all_gather(accelerator=accelerator)
+             self.accelerator = Accelerator(kwargs_handlers=[InitProcessGroupKwargs(timeout=timedelta(seconds=3000))])
+             test_all_gather(accelerator=self.accelerator)
        elif self.launcher_type == ParallelismManager.NANOTRON:
            if not is_nanotron_available():
                raise ValueError("You are trying to launch a nanotron model, but nanotron is not installed")
            dist.initialize_torch_distributed()
-             parallel_context = ParallelContext(
+             self.parallel_context = ParallelContext(
                tensor_parallel_size=self.model_config.lighteval_config.parallelism.tp,
                pipeline_parallel_size=self.model_config.lighteval_config.parallelism.pp,
                data_parallel_size=self.model_config.lighteval_config.parallelism.dp,
            )
-             test_all_gather(parallel_context=parallel_context)
+             test_all_gather(parallel_context=self.parallel_context)

-         return accelerator, parallel_context
-
-     def _init_model(self, model_config, model):
+     def _init_model(self):
        logger.info("--- LOADING MODEL ---")
-         if model_config is not None:
+         if self.model_config is not None:
            if self.parallel_context:
-                 return NanotronLightevalModel(
+                 self.model = NanotronLightevalModel(
                    checkpoint_path=os.path.dirname(self.pipeline_parameters.nanotron_checkpoint_path)
                    if self.pipeline_parameters.nanotron_checkpoint_path
                    else "",
@@ -188,46 +200,42 @@ def _init_model(self, model_config, model):
                    env_config=self.pipeline_parameters.env_config,
                )
            else:
-                 return load_model(config=model_config, env_config=self.pipeline_parameters.env_config)
-         if isinstance(model, TransformersModel):
-             return model
-         else:
-             return TransformersModel.from_model(
-                 model=model,
+                 self.model = load_model(config=self.model_config, env_config=self.pipeline_parameters.env_config)
+         if not isinstance(self.model, TransformersModel):
+             self.model = TransformersModel.from_model(
+                 model=self.model,
                use_chat_template=self.pipeline_parameters.use_chat_template,
                env_config=self.pipeline_parameters.env_config,
                accelerator=self.accelerator,
            )
+         return self.model

-     def _init_tasks_and_requests(self, tasks: str):
+     def _init_tasks(self):
        with local_ranks_zero_first() if self.launcher_type == ParallelismManager.NANOTRON else nullcontext():
            logger.info("--- LOADING TASKS ---")
            registry = Registry(
                cache_dir=self.pipeline_parameters.env_config.cache_dir,
                custom_tasks=self.pipeline_parameters.custom_tasks_directory,
            )
-             task_names_list, fewshots_dict = taskinfo_selector(tasks, registry)
-             task_dict = registry.get_task_dict(task_names_list)
-             LightevalTask.load_datasets(list(task_dict.values()), self.pipeline_parameters.dataset_loading_processes)
-
-             self.evaluation_tracker.task_config_logger.log(task_dict)
-
-             requests, docs = create_requests_from_tasks(
-                 task_dict=task_dict,
-                 fewshot_dict=fewshots_dict,
-                 num_fewshot_seeds=self.pipeline_parameters.num_fewshot_seeds,
-                 lm=self.model,
-                 max_samples=self.pipeline_parameters.max_samples,
-                 evaluation_tracker=self.evaluation_tracker,
-                 use_chat_template=self.pipeline_parameters.use_chat_template,
-                 system_prompt=self.pipeline_parameters.system_prompt,
+             self.task_names_list, self.fewshots_dict = taskinfo_selector(self.tasks, registry)
+             self.task_dict = registry.get_task_dict(self.task_names_list)
+             LightevalTask.load_datasets(
+                 list(self.task_dict.values()), self.pipeline_parameters.dataset_loading_processes
            )

-             self.task_names_list = task_names_list
-             self.task_dict = task_dict
-             self.fewshot_dict = fewshots_dict
-             self.requests = requests
-             self.docs = docs
+             self.evaluation_tracker.task_config_logger.log(self.task_dict)
+
+     def _init_requests(self):
+         self.requests, self.docs = create_requests_from_tasks(
+             task_dict=self.task_dict,
+             fewshot_dict=self.fewshots_dict,
+             num_fewshot_seeds=self.pipeline_parameters.num_fewshot_seeds,
+             lm=self.model,
+             max_samples=self.pipeline_parameters.max_samples,
+             evaluation_tracker=self.evaluation_tracker,
+             use_chat_template=self.pipeline_parameters.use_chat_template,
+             system_prompt=self.pipeline_parameters.system_prompt,
+         )

    def _init_random_seeds(self):
        logger.info("--- INIT SEEDS ---")
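Splitting `_init_tasks_and_requests` into `_init_tasks` and `_init_requests` means request building (which needs `lm=self.model`) now happens only once the model exists, while task selection and dataset loading stay in `__init__`. As a sketch of how the task-selection calls kept in `_init_tasks` compose on their own; the task string and cache directory below are made-up examples, not values from this PR:

```python
# Illustrative use of the task-selection helpers shown in the diff above.
from lighteval.tasks.lighteval_task import LightevalTask
from lighteval.tasks.registry import Registry, taskinfo_selector

registry = Registry(cache_dir="/tmp/lighteval-cache", custom_tasks=None)
task_names_list, fewshots_dict = taskinfo_selector("leaderboard|arc:challenge|25|0", registry)
task_dict = registry.get_task_dict(task_names_list)
# Datasets can be loaded without any model being instantiated.
LightevalTask.load_datasets(list(task_dict.values()), 1)
```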
@@ -280,16 +288,37 @@ def evaluate(self):
        except OSError:
            pass

+     @staticmethod
+     def _metric_category_to_request_type() -> Dict[MetricCategory, RequestType]:
+         """Maps each MetricCategory to its corresponding RequestType."""
+         return {
+             MetricCategory.TARGET_PERPLEXITY: RequestType.LOGLIKELIHOOD,
+             MetricCategory.PERPLEXITY: RequestType.LOGLIKELIHOOD_ROLLING,
+             MetricCategory.GENERATIVE_SAMPLING: RequestType.GREEDY_UNTIL,
+             MetricCategory.GENERATIVE: RequestType.GREEDY_UNTIL,
+             MetricCategory.GENERATIVE_LOGPROB: RequestType.GREEDY_UNTIL,
+             MetricCategory.MULTICHOICE: RequestType.LOGLIKELIHOOD,
+             MetricCategory.MULTICHOICE_PMI: RequestType.LOGLIKELIHOOD,
+             MetricCategory.MULTICHOICE_ONE_TOKEN: RequestType.LOGLIKELIHOOD_SINGLE_TOKEN,
+             MetricCategory.LLM_AS_JUDGE_MULTI_TURN: RequestType.GREEDY_UNTIL_MULTI_TURN,
+             MetricCategory.LLM_AS_JUDGE: RequestType.GREEDY_UNTIL,
+         }
+
+     @staticmethod
+     def _request_type_to_response() -> Dict[RequestType, type[ModelResponse]]:
+         return {
+             RequestType.LOGLIKELIHOOD: LoglikelihoodResponse,
+             RequestType.LOGLIKELIHOOD_SINGLE_TOKEN: LoglikelihoodSingleTokenResponse,
+             RequestType.LOGLIKELIHOOD_ROLLING: LoglikelihoodResponse,
+             RequestType.GREEDY_UNTIL_MULTI_TURN: GenerativeMultiturnResponse,
+             RequestType.GREEDY_UNTIL: GenerativeResponse,
+         }
+
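Together these two static maps give the chain MetricCategory → RequestType → response class that `_get_model_response_type` (further down) walks over every task's metric categories. A worked trace, assuming the maps live as staticmethods on the `Pipeline` class in `lighteval.pipeline`:

```python
from lighteval.metrics.utils.metric_utils import MetricCategory
from lighteval.pipeline import Pipeline  # assumed module/class location

# MetricCategory -> RequestType -> ModelResponse subclass, using the two maps above.
category = MetricCategory.MULTICHOICE
request_type = Pipeline._metric_category_to_request_type()[category]  # RequestType.LOGLIKELIHOOD
response_cls = Pipeline._request_type_to_response()[request_type]     # LoglikelihoodResponse

# A suite mixing e.g. MULTICHOICE and GENERATIVE tasks resolves to two different
# response classes (LoglikelihoodResponse vs GenerativeResponse), which is exactly
# the case _get_model_response_type rejects when replaying details.
```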
    def _load_responses_from_details(self):
        logger.info("--- LOADING RESPONSES FROM DETAILS ---")
        sample_id_to_responses: dict[(SampleUid, MetricCategory), list[ModelResponse]] = collections.defaultdict(list)

-         request_types = list(self.requests.keys())
-         if len(request_types) > 1:
-             raise ValueError(
-                 "Loading responses from details when there are multiple request types is currently not supported"
-             )
-         model_response_type = self._get_model_response_type(request_types[0])
+         model_response_type = self._get_model_response_type()

        details_datasets = self.evaluation_tracker.load_details_datasets(
            self.pipeline_parameters.load_responses_from_details_date_id, self.task_names_list
@@ -333,25 +362,25 @@ def _load_responses_from_details(self):
                    sample_id_to_responses[(SampleUid(task_name, f"{idx}_{0}"), metric_category)] = [response]
        return sample_id_to_responses

-     def _get_model_response_type(self, request_type):
-         if request_type == RequestType.LOGLIKELIHOOD:
-             model_response_type = LoglikelihoodResponse
-         elif request_type == RequestType.LOGLIKELIHOOD_SINGLE_TOKEN:
-             model_response_type = LoglikelihoodSingleTokenResponse
-         elif request_type == RequestType.LOGLIKELIHOOD_ROLLING:
-             model_response_type = LoglikelihoodResponse
-         elif request_type == RequestType.GREEDY_UNTIL_MULTI_TURN:
-             model_response_type = GenerativeMultiturnResponse
-         elif request_type == RequestType.GREEDY_UNTIL:
-             model_response_type = GenerativeResponse
-         else:
-             raise ValueError(
-                 f"Loading responses from details for request type {request_type} is currently not supported"
-             )
-
+     def _get_model_response_type(self):
+         model_response_type = None
+         for task in self.task_dict.values():
+             for metric_category, has_metric_category in task.has_metric_category.items():
+                 if has_metric_category:
+                     request_type = self._metric_category_to_request_type()[metric_category]
+                     new_model_response_type = self._request_type_to_response()[request_type]
+                     if model_response_type and new_model_response_type != model_response_type:
+                         raise ValueError(
+                             f"Loading responses from details with multiple model response types ({model_response_type} and {new_model_response_type}) is currently not supported"
+                         )
+                     model_response_type = new_model_response_type
        return model_response_type

    def _run_model(self):
+         # Init model and requests lazily to avoid loading the model if it is not needed
+         self._init_model()
+         self._init_requests()  # Needs the model to be initialized
+
        # Running all requests depending on the model call type (log likelihood, generative, ...)
        # to be able to batch them
        logger.info("--- RUNNING MODEL ---")