Commit a95156e

Implemented the possibility to load predictions from details files and continue evaluating from there.
1 parent f6fee3a commit a95156e

File tree

  src/lighteval/logging/evaluation_tracker.py
  src/lighteval/main_accelerate.py
  src/lighteval/main_endpoint.py
  src/lighteval/main_vllm.py
  src/lighteval/pipeline.py

5 files changed: +110 −7 lines
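
Taken together, the change lets a pipeline re-score the saved predictions of a previous run instead of querying the model again. A minimal sketch of that flow, assuming the usual EvaluationTracker / PipelineParameters / Pipeline constructor arguments (the output directory, task spec and model config below are placeholders, not something this commit prescribes):

from lighteval.logging.evaluation_tracker import EvaluationTracker
from lighteval.pipeline import ParallelismManager, Pipeline, PipelineParameters

# Tracker pointing at the output_dir of the earlier run whose details we want to reuse.
evaluation_tracker = EvaluationTracker(output_dir="./results", save_details=True)

# "latest" resolves to the most recent timestamp folder; a concrete date_id also works.
pipeline_params = PipelineParameters(
    launcher_type=ParallelismManager.ACCELERATE,
    load_responses_from_details_date_id="latest",
)

model_config = None  # placeholder: in a real run, the same model config as the original evaluation

pipeline = Pipeline(
    tasks="leaderboard|truthfulqa:mc|0|0",  # placeholder task spec
    pipeline_parameters=pipeline_params,
    evaluation_tracker=evaluation_tracker,
    model_config=model_config,
)
pipeline.evaluate()  # responses come from the saved details, not from the model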

src/lighteval/logging/evaluation_tracker.py

Lines changed: 29 additions & 2 deletions
@@ -209,9 +209,36 @@ def save_results(self, date_id: str, results_dict: dict):
         with self.fs.open(output_results_file, "w") as f:
             f.write(json.dumps(results_dict, cls=EnhancedJSONEncoder, indent=2, ensure_ascii=False))
 
-    def save_details(self, date_id: str, details_datasets: dict[str, Dataset]):
+    def _get_details_sub_folder(self, date_id: str):
         output_dir_details = Path(self.output_dir) / "details" / self.general_config_logger.model_name
-        output_dir_details_sub_folder = output_dir_details / date_id
+        if date_id == "latest":
+            # Get all folders in output_dir_details
+            if not self.fs.exists(output_dir_details):
+                raise FileNotFoundError(f"Details directory {output_dir_details} does not exist")
+
+            # List all folders and filter out files
+            folders = [f['name'] for f in self.fs.listdir(output_dir_details) if f['type'] == 'directory']
+
+            if not folders:
+                raise FileNotFoundError(f"No timestamp folders found in {output_dir_details}")
+
+            # Parse timestamps and get latest
+            date_id = max(folders)
+        return output_dir_details / date_id
+
+    def load_details_datasets(self, date_id: str) -> dict[str, Dataset]:
+        output_dir_details_sub_folder = self._get_details_sub_folder(date_id)
+        date_id = output_dir_details_sub_folder.name  # Overwrite date_id in case of latest
+        details_datasets = {}
+        for file in self.fs.glob(str(output_dir_details_sub_folder / f"details_*_{date_id}.parquet")):
+            task_name = Path(file).stem.replace("details_", "").replace(f"_{date_id}", "")
+            dataset = load_dataset("parquet", data_files=file, split="train")
+            details_datasets[task_name] = dataset
+        return details_datasets
+
+
+    def save_details(self, date_id: str, details_datasets: dict[str, Dataset]):
+        output_dir_details_sub_folder = self._get_details_sub_folder(date_id)
         self.fs.mkdirs(output_dir_details_sub_folder, exist_ok=True)
         logger.info(f"Saving details to {output_dir_details_sub_folder}")
         for task_name, dataset in details_datasets.items():
src/lighteval/main_accelerate.py

Lines changed: 4 additions & 0 deletions
@@ -67,6 +67,9 @@ def accelerate( # noqa C901
     num_fewshot_seeds: Annotated[
         int, Option(help="Number of seeds to use for few-shot evaluation.", rich_help_panel=HELP_PANEL_NAME_1)
     ] = 1,
+    load_responses_from_details_date_id: Annotated[
+        Optional[str], Option(help="Load responses from details directory.", rich_help_panel=HELP_PANEL_NAME_1)
+    ] = None,
     # === saving ===
     output_dir: Annotated[
         str, Option(help="Output directory for evaluation results.", rich_help_panel=HELP_PANEL_NAME_2)
@@ -137,6 +140,7 @@ def accelerate( # noqa C901
         max_samples=max_samples,
         use_chat_template=use_chat_template,
         system_prompt=system_prompt,
+        load_responses_from_details_date_id=load_responses_from_details_date_id,
     )
 
     # TODO (nathan): better handling of model_args

src/lighteval/main_endpoint.py

Lines changed: 12 additions & 0 deletions
@@ -179,6 +179,9 @@ def inference_endpoint(
     num_fewshot_seeds: Annotated[
         int, Option(help="Number of seeds to use for few-shot evaluation.", rich_help_panel=HELP_PANEL_NAME_1)
     ] = 1,
+    load_responses_from_details_date_id: Annotated[
+        Optional[str], Option(help="Load responses from details directory.", rich_help_panel=HELP_PANEL_NAME_1)
+    ] = None,
     # === saving ===
     output_dir: Annotated[
         str, Option(help="Output directory for evaluation results.", rich_help_panel=HELP_PANEL_NAME_2)
@@ -247,6 +250,7 @@ def inference_endpoint(
         max_samples=max_samples,
         use_chat_template=use_chat_template,
         system_prompt=system_prompt,
+        load_responses_from_details_date_id=load_responses_from_details_date_id,
     )
     pipeline = Pipeline(
         tasks=tasks,
@@ -292,6 +296,9 @@ def tgi(
     num_fewshot_seeds: Annotated[
         int, Option(help="Number of seeds to use for few-shot evaluation.", rich_help_panel=HELP_PANEL_NAME_1)
     ] = 1,
+    load_responses_from_details_date_id: Annotated[
+        Optional[str], Option(help="Load responses from details directory.", rich_help_panel=HELP_PANEL_NAME_1)
+    ] = None,
     # === saving ===
     output_dir: Annotated[
         str, Option(help="Output directory for evaluation results.", rich_help_panel=HELP_PANEL_NAME_2)
@@ -355,6 +362,7 @@ def tgi(
         max_samples=max_samples,
         use_chat_template=use_chat_template,
         system_prompt=system_prompt,
+        load_responses_from_details_date_id=load_responses_from_details_date_id,
     )
     pipeline = Pipeline(
         tasks=tasks,
@@ -400,6 +408,9 @@ def litellm(
     num_fewshot_seeds: Annotated[
         int, Option(help="Number of seeds to use for few-shot evaluation.", rich_help_panel=HELP_PANEL_NAME_1)
     ] = 1,
+    load_responses_from_details_date_id: Annotated[
+        Optional[str], Option(help="Load responses from details directory.", rich_help_panel=HELP_PANEL_NAME_1)
+    ] = None,
     # === saving ===
     output_dir: Annotated[
         str, Option(help="Output directory for evaluation results.", rich_help_panel=HELP_PANEL_NAME_2)
@@ -464,6 +475,7 @@ def litellm(
         max_samples=max_samples,
         use_chat_template=use_chat_template,
         system_prompt=system_prompt,
+        load_responses_from_details_date_id=load_responses_from_details_date_id,
     )
     pipeline = Pipeline(
         tasks=tasks,

src/lighteval/main_vllm.py

Lines changed: 4 additions & 0 deletions
@@ -63,6 +63,9 @@ def vllm(
     num_fewshot_seeds: Annotated[
         int, Option(help="Number of seeds to use for few-shot evaluation.", rich_help_panel=HELP_PANEL_NAME_1)
     ] = 1,
+    load_responses_from_details_date_id: Annotated[
+        Optional[str], Option(help="Load responses from details directory.", rich_help_panel=HELP_PANEL_NAME_1)
+    ] = None,
     # === saving ===
     output_dir: Annotated[
         str, Option(help="Output directory for evaluation results.", rich_help_panel=HELP_PANEL_NAME_2)
@@ -124,6 +127,7 @@ def vllm(
         max_samples=max_samples,
         use_chat_template=use_chat_template,
         system_prompt=system_prompt,
+        load_responses_from_details_date_id=load_responses_from_details_date_id,
     )
 
     if model_args.endswith(".yaml"):
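
The option is wired identically into the accelerate, inference-endpoint, tgi, litellm and vllm commands. A hypothetical invocation (the flag name assumes typer's usual underscore-to-dash conversion of load_responses_from_details_date_id, and the model args / task spec are placeholders):

import subprocess

# Re-score the most recent saved details instead of querying the model again.
subprocess.run(
    [
        "lighteval", "vllm",
        "pretrained=openai-community/gpt2",  # placeholder model args
        "leaderboard|truthfulqa:mc|0|0",     # placeholder task spec
        "--load-responses-from-details-date-id", "latest",
    ],
    check=True,
)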

src/lighteval/pipeline.py

Lines changed: 61 additions & 5 deletions
@@ -20,6 +20,7 @@
 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 # SOFTWARE.
 
+import ast
 import collections
 import os
 import random
@@ -34,10 +35,10 @@
 from lighteval.logging.evaluation_tracker import EvaluationTracker
 from lighteval.metrics.utils.metric_utils import MetricCategory
 from lighteval.models.model_loader import TransformersModel, load_model
-from lighteval.models.model_output import ModelResponse
+from lighteval.models.model_output import GenerativeMultiturnResponse, GenerativeResponse, LoglikelihoodResponse, LoglikelihoodSingleTokenResponse, ModelResponse
 from lighteval.tasks.lighteval_task import LightevalTask, create_requests_from_tasks
 from lighteval.tasks.registry import Registry, taskinfo_selector
-from lighteval.tasks.requests import SampleUid
+from lighteval.tasks.requests import RequestType, SampleUid
 from lighteval.utils.imports import (
     NO_ACCELERATE_ERROR_MSG,
     NO_NANOTRON_ERROR_MSG,
@@ -95,6 +96,7 @@ class PipelineParameters:
     max_samples: int | None = None
     use_chat_template: bool = False
     system_prompt: str | None = None
+    load_responses_from_details_date_id: str | None = None
 
     def __post_init__(self): # noqa C901
         if self.launcher_type == ParallelismManager.ACCELERATE:
@@ -245,7 +247,11 @@ def evaluate(self):
             config=self.model_config,
         )
 
-        sample_id_to_responses = self._run_model()
+        if self.pipeline_parameters.load_responses_from_details_date_id:
+            sample_id_to_responses = self._load_responses_from_details()
+        else:
+            sample_id_to_responses = self._run_model()
+
         self._compute_metrics(sample_id_to_responses)
 
         if self.is_main_process():
@@ -261,6 +267,53 @@ def evaluate(self):
            except OSError:
                pass
 
+
+    def _load_responses_from_details(self):
+        logger.info("--- LOADING RESPONSES FROM DETAILS ---")
+        sample_id_to_responses: dict[(SampleUid, MetricCategory), list[ModelResponse]] = collections.defaultdict(list)
+
+        request_types = list(self.requests.keys())
+        if len(request_types) > 1:
+            raise ValueError("Loading responses from details when there are multiple request types is currently not supported")
+        request_type = request_types[0]
+        if request_type == RequestType.LOGLIKELIHOOD:
+            model_response_type = LoglikelihoodResponse
+        elif request_type == RequestType.LOGLIKELIHOOD_SINGLE_TOKEN:
+            model_response_type = LoglikelihoodSingleTokenResponse
+        elif request_type == RequestType.LOGLIKELIHOOD_ROLLING:
+            model_response_type = LoglikelihoodResponse
+        elif request_type == RequestType.GREEDY_UNTIL_MULTI_TURN:
+            model_response_type = GenerativeMultiturnResponse
+        elif request_type == RequestType.GREEDY_UNTIL:
+            model_response_type = GenerativeResponse
+        else:
+            raise ValueError(f"Loading responses from details for request type {request_type} is currently not supported")
+
+        details_datasets = self.evaluation_tracker.load_details_datasets(self.pipeline_parameters.load_responses_from_details_date_id)
+        for task_name, dataset in details_datasets.items():
+            task: LightevalTask = self._get_task(task_name)
+            num_samples = len(dataset["predictions"])
+            max_samples = self.pipeline_parameters.max_samples if self.pipeline_parameters.max_samples else num_samples
+            if num_samples > max_samples:
+                logger.warning(f"Skipping {num_samples - max_samples} samples for {task_name} when loading responses from details because max_samples is set to {max_samples}")
+                num_samples = self.pipeline_parameters.max_samples
+            for metric_category, has_metric_category in task.has_metric_category.items():
+                if not has_metric_category:
+                    continue
+                for idx in range(num_samples):
+                    kwargs = {
+                        "result": ast.literal_eval(dataset["predictions"][idx]),
+                        "input_tokens": ast.literal_eval(dataset["input_tokens"][idx]),
+                        "generated_tokens": ast.literal_eval(dataset["cont_tokens"][idx]),
+                        "truncated_tokens_count": ast.literal_eval(dataset["truncated"][idx])[0],
+                        "padded_tokens_count": ast.literal_eval(dataset["padded"][idx])[0],
+                    }
+                    if model_response_type == GenerativeResponse:
+                        kwargs["logits"] = ast.literal_eval(dataset["pred_logits"][idx])
+                    response = model_response_type(**kwargs)
+                    sample_id_to_responses[(SampleUid(task_name, f"{idx}_{0}"), metric_category)] = [response]
+        return sample_id_to_responses
+
     def _run_model(self):
         # Running all requests depending on the model call type (log likelihood, generative, ...)
         # to be able to batch them
@@ -283,6 +336,10 @@ def _run_model(self):
 
         return sample_id_to_responses
 
+    def _get_task(self, task_name: str):
+        short_task_name = task_name.rsplit("|", 1)[0]
+        return self.task_dict[short_task_name]
+
     def _compute_metrics(self, sample_id_to_responses):
         # To compute the metrics we first group the samples and task and then by metrics.
         # This way we can batch the metrics computation for each task and metric category
@@ -307,8 +364,7 @@ def _compute_metrics(self, sample_id_to_responses):
             task_metric_category_groups[sample_id.task_name][metric_category]["docs"].append(self.docs[sample_id])
 
         for task_name, samples_per_metric in task_metric_category_groups.items():
-            short_task_name = task_name.rsplit("|", 1)[0]
-            task: LightevalTask = self.task_dict[short_task_name]
+            task: LightevalTask = self._get_task(task_name)
 
             for metric_category, samples in samples_per_metric.items():
                 sample_ids = samples["ids"]
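
Since the details parquet files store predictions, token lists and counters as stringified Python literals, the loader round-trips every field through ast.literal_eval before rebuilding the response objects. A toy illustration of that reconstruction for a single generative row (the row values below are made up; the field mapping mirrors what _load_responses_from_details does for GREEDY_UNTIL requests):

import ast

from lighteval.models.model_output import GenerativeResponse

# A made-up details row, as it would come back from details_<task>_<date_id>.parquet.
row = {
    "predictions": "['Paris']",
    "input_tokens": "[[101, 2054, 2003]]",
    "cont_tokens": "[[3000]]",
    "truncated": "[0]",
    "padded": "[0]",
    "pred_logits": "[-1.2]",
}

response = GenerativeResponse(
    result=ast.literal_eval(row["predictions"]),
    input_tokens=ast.literal_eval(row["input_tokens"]),
    generated_tokens=ast.literal_eval(row["cont_tokens"]),
    truncated_tokens_count=ast.literal_eval(row["truncated"])[0],
    padded_tokens_count=ast.literal_eval(row["padded"])[0],
    logits=ast.literal_eval(row["pred_logits"]),
)
print(response.result)  # ['Paris']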
