Skip to content

Commit

Permalink
bump up latex2sympy + adjust latex target
Browse files Browse the repository at this point in the history
  • Loading branch information
hynky1999 committed Jan 28, 2025
1 parent 571937c commit 8a161c1
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 11 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ multilingual = [
"jieba", # for chinese tokenizer
"pyvi", # for vietnamese tokenizer
]
math = ["latex2sympy2_extended>=0.9.1"]
math = ["latex2sympy2_extended>=0.9.3"]

[project.urls]
Homepage = "https://github.com/huggingface/lighteval"
Expand Down
9 changes: 7 additions & 2 deletions src/lighteval/metrics/dynamic_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ def multilingual_extractive_match_metric(
pred_extraction_target: Sequence[ExtractionTarget] = (ExprExtractionConfig(),),
aggregation_function: Callable[[list[float]], float] = max,
fallback_mode: Literal["no_fallback", "first_match"] = "first_match",
extraction_mode: Literal["first_match", "any_match"] = "any_match",
precision: int = 6,
) -> SampleLevelMetric:
"""Creates a language-aware extractive match metric that extracts answers from the model's output.
Expand All @@ -215,6 +216,7 @@ def multilingual_extractive_match_metric(
How to perform extraction. Defaults to "first_match".
- "no_fallback": Only use first successfully parsed matches
- "first_match": Use the first successfully parsed match + the first match regardless of parsing success
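extraction_mode: Literal["first_match", "any_match"]
Which matches to try to extract. Defaults to "any_match".
- "first_match": Only tries to extract the first match
- "any_match": Tries to extract any match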
precision: int
Number of decimal places to use when comparing numerical values. Defaults to 6.
Expand All @@ -240,9 +242,12 @@ def sample_level_fn(golds: list[str], predictions: list[str], formatted_doc: Doc
pred_extraction_regexes = get_extraction_regexes(formatted_doc, pred_extraction_target, language)

extracted_predictions = [
extract_target_from_pred(pred, pred_extraction_regexes, fallback_mode) for pred in predictions
extract_target_from_pred(pred, pred_extraction_regexes, fallback_mode, extraction_mode)
for pred in predictions
]
extracted_golds = [
extract_target_from_pred(gold, gold_extraction_regexes, fallback_mode, extraction_mode) for gold in golds
]
extracted_golds = [extract_target_from_pred(gold, gold_extraction_regexes, fallback_mode) for gold in golds]

# Assert on empty gold and warn on empty pred
if any(len(g) == 0 for g in extracted_golds):
Expand Down
39 changes: 32 additions & 7 deletions src/lighteval/metrics/utils/extractive_match_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@
# SOFTWARE.

import re
from dataclasses import dataclass
from dataclasses import dataclass, field
from functools import lru_cache
from itertools import groupby
from typing import Literal, Sequence
from typing import Any, Literal, Sequence

import sympy
from sympy import Basic, MatrixBase, Number
Expand All @@ -39,17 +39,33 @@
from lighteval.utils.timeout import timeout


@requires_latex2sympy2_extended
def latex_normalization_config_default_factory():
    """Build the default latex normalization config.

    Used as the ``default_factory`` of ``LatexExtractionConfig.normalization_config``.
    ``NormalizationConfig`` is imported inside the function body rather than at
    module level (presumably so the module stays importable when the optional
    ``latex2sympy2_extended`` package is missing -- confirm against
    ``requires_latex2sympy2_extended``).

    Returns:
        NormalizationConfig: Config with every normalization step set here switched on.
    """
    from latex2sympy2_extended.latex2sympy2 import NormalizationConfig

    default_config = NormalizationConfig(
        equations=True,
        boxed=True,
        nits=True,
        malformed_operators=True,
        units=True,
        basic_latex=True,
    )
    return default_config


@dataclass(frozen=True)
class LatexExtractionConfig:
    r"""Config for extracting latex from the prediction.

    The docstring is a raw string so that the backslash in \boxed is kept
    literally instead of being read as the ``\b`` (backspace) escape.

    Attributes:
        try_extract_without_anchor (bool): Whether to try extracting latex without requiring specific anchors like "answer:" or "final answer is"
        enforce_boxed_match (bool): Whether to also consider extracting from plain \boxed{...} expressions
        boxed_match_priority (int): Priority of the boxed match regex (-1 never, 0 first, 55 after final answer: anchor, etc...)
        normalization_config (latex2sympy2_extended.latex2sympy2.NormalizationConfig): Normalization config to use for latex extraction
    """

    try_extract_without_anchor: bool = True
    enforce_boxed_match: bool = True
    boxed_match_priority: int = 55
    # Built per instance through the factory. Typed as Any, presumably because
    # NormalizationConfig is only imported lazily inside the factory -- confirm.
    normalization_config: Any = field(default_factory=latex_normalization_config_default_factory)


@dataclass(frozen=True)
Expand Down Expand Up @@ -188,8 +204,8 @@ def lazy_latex_regex(latex_config: LatexExtractionConfig, language: Language) ->
regexes.append((latex_re, 300))

# This ensures that boxed is matched right after the final answer xxxx
if latex_config.enforce_boxed_match:
regexes.append((latex_boxed, 55))
if latex_config.boxed_match_priority >= 0:
regexes.append((latex_boxed, latex_config.boxed_match_priority))

return [(re.compile(pattern, re.DOTALL), priority) for pattern, priority in regexes]

Expand Down Expand Up @@ -387,6 +403,7 @@ def extract_target_from_pred(
pred: str,
target_res: list[tuple[list[tuple[re.Pattern[str], int]], ExtractionTarget]],
fallback_mode: Literal["no_fallback", "first_match"] = "no_fallback",
extraction_mode: Literal["first_match", "any_match"] = "any_match",
):
"""Extracts targets from a prediction string using regex patterns.
Returns the first successfully extracted match.
Expand All @@ -397,6 +414,9 @@ def extract_target_from_pred(
fallback_mode (Literal["no_fallback", "first_match"], optional): How to handle extraction failures. Defaults to "no_fallback".
- "no_fallback": Return only successfully parsed match
- "first_match": Additionally include the first string match no matter how parsing finished
extraction_mode (Literal["first_match", "any_match"], optional): Which matches to try to extract. Defaults to "any_match".
- "first_match": Only tries to extract the first match
- "any_match": Tries to extract any match
Returns:
list: List of extracted predictions, with first fallback string appended if fallback_mode is "first_match"
Expand All @@ -410,6 +430,7 @@ def extract_target_from_pred(
for target_patterns, target_type in target_res
for pattern, priority in target_patterns
]
match_found = False

# Group patterns by priority using itertools.groupby
for _, patterns_group in groupby(sorted(all_patterns, key=lambda x: x[2]), key=lambda x: x[2]):
Expand All @@ -426,6 +447,7 @@ def extract_target_from_pred(
# Try to extract from each match, starting from rightmost
for match, _, _, target_type in matches_with_pos:
extracted_match, str_fallback = extract_match(match, target_type)
match_found = True

if str_fallback:
fallbacks.append(str_fallback)
Expand All @@ -434,8 +456,11 @@ def extract_target_from_pred(
extracted_predictions.append(extracted_match)
break

if extraction_mode == "first_match":
break

# If we extracted something, or found any match while in first_match mode, stop processing other priorities
if extracted_predictions:
if extracted_predictions or (match_found and extraction_mode == "first_match"):
break

if fallback_mode == "first_match" and fallbacks:
Expand Down
2 changes: 1 addition & 1 deletion src/lighteval/metrics/utils/math_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,7 @@ def should_treat_as_complex(latex_str: str) -> bool:
def compare_gold_target(
gold: list[Basic | MatrixBase | str], target: list[Basic | MatrixBase | str], precision: int
) -> bool:
@timeout(timeout_seconds=10)
@timeout(timeout_seconds=100)
def compare_single_extraction(gold: Basic | MatrixBase | str, target: Basic | MatrixBase | str) -> bool:
# If both are sympy expressions, we can use sympy to compare them
if isinstance(gold, (Basic, MatrixBase)) and isinstance(target, (Basic, MatrixBase)):
Expand Down
1 change: 1 addition & 0 deletions tests/metrics/test_extractive_match.py
Original file line number Diff line number Diff line change
Expand Up @@ -969,6 +969,7 @@ def test_math_extraction_edge_cases(gold, pred, expected):
r"Thus, the answer is $x \in \boxed{(2,12) \cup (12,102)}$",
1,
),
(r"$204$", r"$24+108=204$", 1),
],
)
def test_math_extraction_additional_cases(gold, pred, expected):
Expand Down

0 comments on commit 8a161c1

Please sign in to comment.