Skip to content

Commit ab38bd5

Browse files
authored
Merge branch 'main' into Document-Custom-Model-Files
2 parents 73af85b + 94fc5a2 commit ab38bd5

18 files changed

+457
-21
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ multilingual = [
109109
"jieba", # for chinese tokenizer
110110
"pyvi", # for vietnamese tokenizer
111111
]
112-
math = ["latex2sympy2_extended>=0.9.1"]
112+
math = ["latex2sympy2_extended>=0.9.3"]
113113

114114
[project.urls]
115115
Homepage = "https://github.com/huggingface/lighteval"

src/lighteval/logging/evaluation_tracker.py

Lines changed: 43 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@
3434
import torch
3535
from datasets import Dataset, load_dataset
3636
from datasets.utils.metadata import MetadataConfigs
37-
from fsspec import url_to_fs
3837
from huggingface_hub import DatasetCard, DatasetCardData, HfApi, HFSummaryWriter, hf_hub_url
3938

4039
from lighteval.logging.info_loggers import (
@@ -53,6 +52,11 @@
5352
if is_nanotron_available():
5453
from nanotron.config import GeneralArgs # type: ignore
5554

55+
try:
56+
from fsspec import url_to_fs
57+
except ImportError:
58+
from fsspec.core import url_to_fs
59+
5660

5761
class EnhancedJSONEncoder(json.JSONEncoder):
5862
"""
@@ -231,9 +235,45 @@ def save_results(self, date_id: str, results_dict: dict):
231235
with self.fs.open(output_results_file, "w") as f:
232236
f.write(json.dumps(results_dict, cls=EnhancedJSONEncoder, indent=2, ensure_ascii=False))
233237

234-
def save_details(self, date_id: str, details_datasets: dict[str, Dataset]):
238+
def _get_details_sub_folder(self, date_id: str):
235239
output_dir_details = Path(self.output_dir) / "details" / self.general_config_logger.model_name
236-
output_dir_details_sub_folder = output_dir_details / date_id
240+
if date_id in ["first", "last"]:
241+
# Get all folders in output_dir_details
242+
if not self.fs.exists(output_dir_details):
243+
raise FileNotFoundError(f"Details directory {output_dir_details} does not exist")
244+
245+
# List all folders and filter out files
246+
folders = [f["name"] for f in self.fs.listdir(output_dir_details) if f["type"] == "directory"]
247+
248+
if not folders:
249+
raise FileNotFoundError(f"No timestamp folders found in {output_dir_details}")
250+
251+
# Parse timestamps and get first or last
252+
date_id = max(folders) if date_id == "last" else min(folders)
253+
return output_dir_details / date_id
254+
255+
def load_details_datasets(self, date_id: str, task_names: list[str]) -> dict[str, Dataset]:
256+
output_dir_details_sub_folder = self._get_details_sub_folder(date_id)
257+
logger.info(f"Loading details from {output_dir_details_sub_folder}")
258+
date_id = output_dir_details_sub_folder.name # Overwrite date_id in case of latest
259+
details_datasets = {}
260+
for file in self.fs.glob(str(output_dir_details_sub_folder / f"details_*_{date_id}.parquet")):
261+
task_name = Path(file).stem.replace("details_", "").replace(f"_{date_id}", "")
262+
if "|".join(task_name.split("|")[:-1]) not in task_names:
263+
logger.info(f"Skipping {task_name} because it is not in the task_names list")
264+
continue
265+
dataset = load_dataset("parquet", data_files=file, split="train")
266+
details_datasets[task_name] = dataset
267+
268+
for task_name in task_names:
269+
if not any(task_name.startswith(task_name) for task_name in details_datasets.keys()):
270+
raise ValueError(
271+
f"Task {task_name} not found in details datasets. Check the tasks to be evaluated or the date_id used to load the details ({date_id})."
272+
)
273+
return details_datasets
274+
275+
def save_details(self, date_id: str, details_datasets: dict[str, Dataset]):
276+
output_dir_details_sub_folder = self._get_details_sub_folder(date_id)
237277
self.fs.mkdirs(output_dir_details_sub_folder, exist_ok=True)
238278
logger.info(f"Saving details to {output_dir_details_sub_folder}")
239279
for task_name, dataset in details_datasets.items():

src/lighteval/main_accelerate.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,9 @@ def accelerate( # noqa C901
6767
num_fewshot_seeds: Annotated[
6868
int, Option(help="Number of seeds to use for few-shot evaluation.", rich_help_panel=HELP_PANEL_NAME_1)
6969
] = 1,
70+
load_responses_from_details_date_id: Annotated[
71+
Optional[str], Option(help="Load responses from details directory.", rich_help_panel=HELP_PANEL_NAME_1)
72+
] = None,
7073
# === saving ===
7174
output_dir: Annotated[
7275
str, Option(help="Output directory for evaluation results.", rich_help_panel=HELP_PANEL_NAME_2)
@@ -137,6 +140,7 @@ def accelerate( # noqa C901
137140
max_samples=max_samples,
138141
use_chat_template=use_chat_template,
139142
system_prompt=system_prompt,
143+
load_responses_from_details_date_id=load_responses_from_details_date_id,
140144
)
141145

142146
# TODO (nathan): better handling of model_args

src/lighteval/main_endpoint.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,9 @@ def inference_endpoint(
179179
num_fewshot_seeds: Annotated[
180180
int, Option(help="Number of seeds to use for few-shot evaluation.", rich_help_panel=HELP_PANEL_NAME_1)
181181
] = 1,
182+
load_responses_from_details_date_id: Annotated[
183+
Optional[str], Option(help="Load responses from details directory.", rich_help_panel=HELP_PANEL_NAME_1)
184+
] = None,
182185
# === saving ===
183186
output_dir: Annotated[
184187
str, Option(help="Output directory for evaluation results.", rich_help_panel=HELP_PANEL_NAME_2)
@@ -247,6 +250,7 @@ def inference_endpoint(
247250
max_samples=max_samples,
248251
use_chat_template=use_chat_template,
249252
system_prompt=system_prompt,
253+
load_responses_from_details_date_id=load_responses_from_details_date_id,
250254
)
251255
pipeline = Pipeline(
252256
tasks=tasks,
@@ -292,6 +296,9 @@ def tgi(
292296
num_fewshot_seeds: Annotated[
293297
int, Option(help="Number of seeds to use for few-shot evaluation.", rich_help_panel=HELP_PANEL_NAME_1)
294298
] = 1,
299+
load_responses_from_details_date_id: Annotated[
300+
Optional[str], Option(help="Load responses from details directory.", rich_help_panel=HELP_PANEL_NAME_1)
301+
] = None,
295302
# === saving ===
296303
output_dir: Annotated[
297304
str, Option(help="Output directory for evaluation results.", rich_help_panel=HELP_PANEL_NAME_2)
@@ -355,6 +362,7 @@ def tgi(
355362
max_samples=max_samples,
356363
use_chat_template=use_chat_template,
357364
system_prompt=system_prompt,
365+
load_responses_from_details_date_id=load_responses_from_details_date_id,
358366
)
359367
pipeline = Pipeline(
360368
tasks=tasks,
@@ -400,6 +408,9 @@ def litellm(
400408
num_fewshot_seeds: Annotated[
401409
int, Option(help="Number of seeds to use for few-shot evaluation.", rich_help_panel=HELP_PANEL_NAME_1)
402410
] = 1,
411+
load_responses_from_details_date_id: Annotated[
412+
Optional[str], Option(help="Load responses from details directory.", rich_help_panel=HELP_PANEL_NAME_1)
413+
] = None,
403414
# === saving ===
404415
output_dir: Annotated[
405416
str, Option(help="Output directory for evaluation results.", rich_help_panel=HELP_PANEL_NAME_2)
@@ -464,6 +475,7 @@ def litellm(
464475
max_samples=max_samples,
465476
use_chat_template=use_chat_template,
466477
system_prompt=system_prompt,
478+
load_responses_from_details_date_id=load_responses_from_details_date_id,
467479
)
468480
pipeline = Pipeline(
469481
tasks=tasks,

src/lighteval/main_vllm.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,9 @@ def vllm(
6363
num_fewshot_seeds: Annotated[
6464
int, Option(help="Number of seeds to use for few-shot evaluation.", rich_help_panel=HELP_PANEL_NAME_1)
6565
] = 1,
66+
load_responses_from_details_date_id: Annotated[
67+
Optional[str], Option(help="Load responses from details directory.", rich_help_panel=HELP_PANEL_NAME_1)
68+
] = None,
6669
# === saving ===
6770
output_dir: Annotated[
6871
str, Option(help="Output directory for evaluation results.", rich_help_panel=HELP_PANEL_NAME_2)
@@ -124,6 +127,7 @@ def vllm(
124127
max_samples=max_samples,
125128
use_chat_template=use_chat_template,
126129
system_prompt=system_prompt,
130+
load_responses_from_details_date_id=load_responses_from_details_date_id,
127131
)
128132

129133
if model_args.endswith(".yaml"):

src/lighteval/metrics/dynamic_metrics.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,7 @@ def multilingual_extractive_match_metric(
191191
pred_extraction_target: Sequence[ExtractionTarget] = (ExprExtractionConfig(),),
192192
aggregation_function: Callable[[list[float]], float] = max,
193193
fallback_mode: Literal["no_fallback", "first_match"] = "first_match",
194+
extraction_mode: Literal["first_match", "any_match"] = "any_match",
194195
precision: int = 6,
195196
) -> SampleLevelMetric:
196197
"""Creates a language-aware extractive match metric that extracts answers from the model's output.
@@ -215,6 +216,10 @@ def multilingual_extractive_match_metric(
215216
How to perform extraction. Defaults to "first_match".
216217
- "no_fallback": Only use first successfully parsed matches
217218
- "first_match": Use the first successfully parsed match + first match irregardless the parsing success
219+
extraction_mode: Literal["first_match", "any_match"]
220+
- "first_match": Only tries to extract the first regex match if it fails no other matches are tried
221+
- "any_match": Tries to extract any regex match
222+
218223
precision: int
219224
Number of decimal places to use when comparing numerical values. Defaults to 6.
220225
@@ -240,9 +245,12 @@ def sample_level_fn(golds: list[str], predictions: list[str], formatted_doc: Doc
240245
pred_extraction_regexes = get_extraction_regexes(formatted_doc, pred_extraction_target, language)
241246

242247
extracted_predictions = [
243-
extract_target_from_pred(pred, pred_extraction_regexes, fallback_mode) for pred in predictions
248+
extract_target_from_pred(pred, pred_extraction_regexes, fallback_mode, extraction_mode)
249+
for pred in predictions
250+
]
251+
extracted_golds = [
252+
extract_target_from_pred(gold, gold_extraction_regexes, fallback_mode, extraction_mode) for gold in golds
244253
]
245-
extracted_golds = [extract_target_from_pred(gold, gold_extraction_regexes, fallback_mode) for gold in golds]
246254

247255
# Assert on empty gold and warn on empty pred
248256
if any(len(g) == 0 for g in extracted_golds):
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# MIT License
2+
3+
# Copyright (c) 2024 The HuggingFace Team
4+
5+
# Permission is hereby granted, free of charge, to any person obtaining a copy
6+
# of this software and associated documentation files (the "Software"), to deal
7+
# in the Software without restriction, including without limitation the rights
8+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9+
# copies of the Software, and to permit persons to whom the Software is
10+
# furnished to do so, subject to the following conditions:
11+
12+
# The above copyright notice and this permission notice shall be included in all
13+
# copies or substantial portions of the Software.
14+
15+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21+
# SOFTWARE.
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# MIT License
2+
3+
# Copyright (c) 2024 The HuggingFace Team
4+
5+
# Permission is hereby granted, free of charge, to any person obtaining a copy
6+
# of this software and associated documentation files (the "Software"), to deal
7+
# in the Software without restriction, including without limitation the rights
8+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9+
# copies of the Software, and to permit persons to whom the Software is
10+
# furnished to do so, subject to the following conditions:
11+
12+
# The above copyright notice and this permission notice shall be included in all
13+
# copies or substantial portions of the Software.
14+
15+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21+
# SOFTWARE.

src/lighteval/metrics/utils/extractive_match_utils.py

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,10 @@
2121
# SOFTWARE.
2222

2323
import re
24-
from dataclasses import dataclass
24+
from dataclasses import dataclass, field
2525
from functools import lru_cache
2626
from itertools import groupby
27-
from typing import Literal, Sequence
27+
from typing import Any, Literal, Sequence
2828

2929
import sympy
3030
from sympy import Basic, MatrixBase, Number
@@ -39,17 +39,33 @@
3939
from lighteval.utils.timeout import timeout
4040

4141

42+
@requires_latex2sympy2_extended
43+
def latex_normalization_config_default_factory():
44+
from latex2sympy2_extended.latex2sympy2 import NormalizationConfig
45+
46+
return NormalizationConfig(
47+
basic_latex=True,
48+
units=True,
49+
malformed_operators=True,
50+
nits=True,
51+
boxed=True,
52+
equations=True,
53+
)
54+
55+
4256
@dataclass(frozen=True)
4357
class LatexExtractionConfig:
4458
"""Config for extracting latex from the prediction.
4559
4660
Attributes:
4761
try_extract_without_anchor (bool): Whether to try extracting latex without requiring specific anchors like "answer:" or "final answer is"
48-
enforce_boxed_match (bool): Whether to also consider extracting from plain \boxed{...} expressions
62+
boxed_match_priority (int): Priority of the boxed match regex (-1 never, 0 first, 55 after final answer: anchor, etc...)
63+
normalization_config (latex2sympy2_extended.latex2sympy2.NormalizationConfig): Normalization config to use for latex extraction
4964
"""
5065

5166
try_extract_without_anchor: bool = True
52-
enforce_boxed_match: bool = True
67+
boxed_match_priority: int = 55
68+
normalization_config: Any = field(default_factory=latex_normalization_config_default_factory)
5369

5470

5571
@dataclass(frozen=True)
@@ -187,9 +203,8 @@ def lazy_latex_regex(latex_config: LatexExtractionConfig, language: Language) ->
187203
if latex_config.try_extract_without_anchor:
188204
regexes.append((latex_re, 300))
189205

190-
# This ensures that boxed is matched right after the final answer xxxx
191-
if latex_config.enforce_boxed_match:
192-
regexes.append((latex_boxed, 55))
206+
if latex_config.boxed_match_priority >= 0:
207+
regexes.append((latex_boxed, latex_config.boxed_match_priority))
193208

194209
return [(re.compile(pattern, re.DOTALL), priority) for pattern, priority in regexes]
195210

@@ -387,6 +402,7 @@ def extract_target_from_pred(
387402
pred: str,
388403
target_res: list[tuple[list[tuple[re.Pattern[str], int]], ExtractionTarget]],
389404
fallback_mode: Literal["no_fallback", "first_match"] = "no_fallback",
405+
extraction_mode: Literal["first_match", "any_match"] = "any_match",
390406
):
391407
"""Extracts targets from a prediction string using regex patterns.
392408
Returns first sucesffuly extracted match.
@@ -397,6 +413,9 @@ def extract_target_from_pred(
397413
fallback_mode (Literal["no_fallback", "first_match"], optional): How to handle extraction failures. Defaults to "no_fallback".
398414
- "no_fallback": Return only successfully parsed match
399415
- "first_match": Additionaly Include the first string match no matter how parsing finished
416+
extraction_mode (Literal["first_match", "any_match"], optional): How to handle extraction failures. Defaults to "any_match".
417+
- "first_match": Only tries to extract the first match
418+
- "any_match": Tries to extract any match
400419
401420
Returns:
402421
list: List of extracted predictions, with first fallbac string appended if fallback_mode is "first_match"
@@ -410,6 +429,7 @@ def extract_target_from_pred(
410429
for target_patterns, target_type in target_res
411430
for pattern, priority in target_patterns
412431
]
432+
match_found = False
413433

414434
# Group patterns by priority using itertools.groupby
415435
for _, patterns_group in groupby(sorted(all_patterns, key=lambda x: x[2]), key=lambda x: x[2]):
@@ -426,6 +446,7 @@ def extract_target_from_pred(
426446
# Try to extract from each match, starting from rightmost
427447
for match, _, _, target_type in matches_with_pos:
428448
extracted_match, str_fallback = extract_match(match, target_type)
449+
match_found = True
429450

430451
if str_fallback:
431452
fallbacks.append(str_fallback)
@@ -434,8 +455,11 @@ def extract_target_from_pred(
434455
extracted_predictions.append(extracted_match)
435456
break
436457

458+
if extraction_mode == "first_match":
459+
break
460+
437461
# If we found something and we're in first_match mode, stop processing other priorities
438-
if extracted_predictions:
462+
if extracted_predictions or (match_found and extraction_mode == "first_match"):
439463
break
440464

441465
if fallback_mode == "first_match" and fallbacks:

0 commit comments

Comments
 (0)