# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
+ import ast
import collections
import os
import random
from lighteval.logging.evaluation_tracker import EvaluationTracker
from lighteval.metrics.utils.metric_utils import MetricCategory
from lighteval.models.model_loader import TransformersModel, load_model
- from lighteval.models.model_output import ModelResponse
+ from lighteval.models.model_output import GenerativeMultiturnResponse, GenerativeResponse, LoglikelihoodResponse, LoglikelihoodSingleTokenResponse, ModelResponse
from lighteval.tasks.lighteval_task import LightevalTask, create_requests_from_tasks
from lighteval.tasks.registry import Registry, taskinfo_selector
- from lighteval.tasks.requests import SampleUid
+ from lighteval.tasks.requests import RequestType, SampleUid
from lighteval.utils.imports import (
    NO_ACCELERATE_ERROR_MSG,
    NO_NANOTRON_ERROR_MSG,
@@ -95,6 +96,7 @@ class PipelineParameters:
    max_samples: int | None = None
    use_chat_template: bool = False
    system_prompt: str | None = None
+     load_responses_from_details_date_id: str | None = None

    def __post_init__(self):  # noqa C901
        if self.launcher_type == ParallelismManager.ACCELERATE:
@@ -245,7 +247,11 @@ def evaluate(self):
            config=self.model_config,
        )

-         sample_id_to_responses = self._run_model()
+         if self.pipeline_parameters.load_responses_from_details_date_id:
+             sample_id_to_responses = self._load_responses_from_details()
+         else:
+             sample_id_to_responses = self._run_model()
+
        self._compute_metrics(sample_id_to_responses)

        if self.is_main_process():
@@ -261,6 +267,53 @@ def evaluate(self):
                except OSError:
                    pass

+
+     def _load_responses_from_details(self):
+         logger.info("--- LOADING RESPONSES FROM DETAILS ---")
+         sample_id_to_responses: dict[(SampleUid, MetricCategory), list[ModelResponse]] = collections.defaultdict(list)
+
+         request_types = list(self.requests.keys())
+         if len(request_types) > 1:
+             raise ValueError("Loading responses from details when there are multiple request types is currently not supported")
+         request_type = request_types[0]
+         if request_type == RequestType.LOGLIKELIHOOD:
+             model_response_type = LoglikelihoodResponse
+         elif request_type == RequestType.LOGLIKELIHOOD_SINGLE_TOKEN:
+             model_response_type = LoglikelihoodSingleTokenResponse
+         elif request_type == RequestType.LOGLIKELIHOOD_ROLLING:
+             model_response_type = LoglikelihoodResponse
+         elif request_type == RequestType.GREEDY_UNTIL_MULTI_TURN:
+             model_response_type = GenerativeMultiturnResponse
+         elif request_type == RequestType.GREEDY_UNTIL:
+             model_response_type = GenerativeResponse
+         else:
+             raise ValueError(f"Loading responses from details for request type {request_type} is currently not supported")
+
+         details_datasets = self.evaluation_tracker.load_details_datasets(self.pipeline_parameters.load_responses_from_details_date_id)
+         for task_name, dataset in details_datasets.items():
+             task: LightevalTask = self._get_task(task_name)
+             num_samples = len(dataset["predictions"])
+             max_samples = self.pipeline_parameters.max_samples if self.pipeline_parameters.max_samples else num_samples
+             if num_samples > max_samples:
+                 logger.warning(f"Skipping {num_samples - max_samples} samples for {task_name} when loading responses from details because max_samples is set to {max_samples}")
+                 num_samples = self.pipeline_parameters.max_samples
+             for metric_category, has_metric_category in task.has_metric_category.items():
+                 if not has_metric_category:
+                     continue
+                 for idx in range(num_samples):
+                     kwargs = {
+                         "result": ast.literal_eval(dataset["predictions"][idx]),
+                         "input_tokens": ast.literal_eval(dataset["input_tokens"][idx]),
+                         "generated_tokens": ast.literal_eval(dataset["cont_tokens"][idx]),
+                         "truncated_tokens_count": ast.literal_eval(dataset["truncated"][idx])[0],
+                         "padded_tokens_count": ast.literal_eval(dataset["padded"][idx])[0]
+                     }
+                     if model_response_type == GenerativeResponse:
+                         kwargs["logits"] = ast.literal_eval(dataset["pred_logits"][idx])
+                     response = model_response_type(**kwargs)
+                     sample_id_to_responses[(SampleUid(task_name, f"{idx}_{0}"), metric_category)] = [response]
+         return sample_id_to_responses
+
    def _run_model(self):
        # Running all requests depending on the model call type (log likelihood, generative, ...)
        # to be able to batch them
@@ -283,6 +336,10 @@ def _run_model(self):

        return sample_id_to_responses

+     def _get_task(self, task_name: str):
+         short_task_name = task_name.rsplit("|", 1)[0]
+         return self.task_dict[short_task_name]
+
    def _compute_metrics(self, sample_id_to_responses):
        # To compute the metrics we first group the samples and task and then by metrics.
        # This way we can batch the metrics computation for each task and metric category
@@ -307,8 +364,7 @@ def _compute_metrics(self, sample_id_to_responses):
            task_metric_category_groups[sample_id.task_name][metric_category]["docs"].append(self.docs[sample_id])

        for task_name, samples_per_metric in task_metric_category_groups.items():
-             short_task_name = task_name.rsplit("|", 1)[0]
-             task: LightevalTask = self.task_dict[short_task_name]
+             task: LightevalTask = self._get_task(task_name)

            for metric_category, samples in samples_per_metric.items():
                sample_ids = samples["ids"]
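For context, a minimal usage sketch of the new option (not part of the diff above): only `load_responses_from_details_date_id` and the other `PipelineParameters` fields visible in the hunks come from this change; the date-id value and the surrounding setup are illustrative assumptions.

```python
# Sketch only: re-score an earlier run from its saved details instead of querying the model.
# The exact date-id format shown here is an assumption, not taken from this diff.
from lighteval.pipeline import ParallelismManager, PipelineParameters

params = PipelineParameters(
    launcher_type=ParallelismManager.ACCELERATE,
    max_samples=None,
    use_chat_template=False,
    system_prompt=None,
    # Timestamp id of a previous run's details; with this set, Pipeline.evaluate()
    # calls _load_responses_from_details() instead of _run_model().
    load_responses_from_details_date_id="2025-01-01T00-00-00.000000",
)
```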