Skip to content

Commit

Permalink
Math extraction - allow only trying the first match, more customizabl…
Browse files Browse the repository at this point in the history
…e latex extraction + bump deps (#522)

* extract matching

* better docstring

* lazy imports

* bump up math

* Update src/lighteval/metrics/dynamic_metrics.py

Co-authored-by: Clémentine Fourrier <[email protected]>

* fix pr commnets

* Apply suggestions from code review

Co-authored-by: Clémentine Fourrier <[email protected]>

* rename comparisson -> comparison

* fix expr numbers extraction with currency or units

* add test for correct extraction of failed answer

* bump of latex2sympy2 version, add new tests for extract metric

* bump up latex2sympy + adjust latex target

* revert gold target timoeut

* remove dead comment 💀

* add doc

---------

Co-authored-by: Clémentine Fourrier <[email protected]>
  • Loading branch information
hynky1999 and clefourrier authored Jan 28, 2025
1 parent cb075a5 commit 0e46269
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 11 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ multilingual = [
"jieba", # for chinese tokenizer
"pyvi", # for vietnamese tokenizer
]
math = ["latex2sympy2_extended>=0.9.1"]
math = ["latex2sympy2_extended>=0.9.3"]

[project.urls]
Homepage = "https://github.com/huggingface/lighteval"
Expand Down
12 changes: 10 additions & 2 deletions src/lighteval/metrics/dynamic_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ def multilingual_extractive_match_metric(
pred_extraction_target: Sequence[ExtractionTarget] = (ExprExtractionConfig(),),
aggregation_function: Callable[[list[float]], float] = max,
fallback_mode: Literal["no_fallback", "first_match"] = "first_match",
extraction_mode: Literal["first_match", "any_match"] = "any_match",
precision: int = 6,
) -> SampleLevelMetric:
"""Creates a language-aware extractive match metric that extracts answers from the model's output.
Expand All @@ -215,6 +216,10 @@ def multilingual_extractive_match_metric(
How to perform extraction. Defaults to "first_match".
- "no_fallback": Only use first successfully parsed matches
- "first_match": Use the first successfully parsed match + first match irregardless the parsing success
extraction_mode: Literal["first_match", "any_match"]
- "first_match": Only tries to extract the first regex match if it fails no other matches are tried
- "any_match": Tries to extract any regex match
precision: int
Number of decimal places to use when comparing numerical values. Defaults to 6.
Expand All @@ -240,9 +245,12 @@ def sample_level_fn(golds: list[str], predictions: list[str], formatted_doc: Doc
pred_extraction_regexes = get_extraction_regexes(formatted_doc, pred_extraction_target, language)

extracted_predictions = [
extract_target_from_pred(pred, pred_extraction_regexes, fallback_mode) for pred in predictions
extract_target_from_pred(pred, pred_extraction_regexes, fallback_mode, extraction_mode)
for pred in predictions
]
extracted_golds = [
extract_target_from_pred(gold, gold_extraction_regexes, fallback_mode, extraction_mode) for gold in golds
]
extracted_golds = [extract_target_from_pred(gold, gold_extraction_regexes, fallback_mode) for gold in golds]

# Assert on empty gold and warn on empty pred
if any(len(g) == 0 for g in extracted_golds):
Expand Down
40 changes: 32 additions & 8 deletions src/lighteval/metrics/utils/extractive_match_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@
# SOFTWARE.

import re
from dataclasses import dataclass
from dataclasses import dataclass, field
from functools import lru_cache
from itertools import groupby
from typing import Literal, Sequence
from typing import Any, Literal, Sequence

import sympy
from sympy import Basic, MatrixBase, Number
Expand All @@ -39,17 +39,33 @@
from lighteval.utils.timeout import timeout


@requires_latex2sympy2_extended
def latex_normalization_config_default_factory():
from latex2sympy2_extended.latex2sympy2 import NormalizationConfig

return NormalizationConfig(
basic_latex=True,
units=True,
malformed_operators=True,
nits=True,
boxed=True,
equations=True,
)


@dataclass(frozen=True)
class LatexExtractionConfig:
"""Config for extracting latex from the prediction.
Attributes:
try_extract_without_anchor (bool): Whether to try extracting latex without requiring specific anchors like "answer:" or "final answer is"
enforce_boxed_match (bool): Whether to also consider extracting from plain \boxed{...} expressions
boxed_match_priority (int): Priority of the boxed match regex (-1 never, 0 first, 55 after final answer: anchor, etc...)
normalization_config (latex2sympy2_extended.latex2sympy2.NormalizationConfig): Normalization config to use for latex extraction
"""

try_extract_without_anchor: bool = True
enforce_boxed_match: bool = True
boxed_match_priority: int = 55
normalization_config: Any = field(default_factory=latex_normalization_config_default_factory)


@dataclass(frozen=True)
Expand Down Expand Up @@ -187,9 +203,8 @@ def lazy_latex_regex(latex_config: LatexExtractionConfig, language: Language) ->
if latex_config.try_extract_without_anchor:
regexes.append((latex_re, 300))

# This ensures that boxed is matched right after the final answer xxxx
if latex_config.enforce_boxed_match:
regexes.append((latex_boxed, 55))
if latex_config.boxed_match_priority >= 0:
regexes.append((latex_boxed, latex_config.boxed_match_priority))

return [(re.compile(pattern, re.DOTALL), priority) for pattern, priority in regexes]

Expand Down Expand Up @@ -387,6 +402,7 @@ def extract_target_from_pred(
pred: str,
target_res: list[tuple[list[tuple[re.Pattern[str], int]], ExtractionTarget]],
fallback_mode: Literal["no_fallback", "first_match"] = "no_fallback",
extraction_mode: Literal["first_match", "any_match"] = "any_match",
):
"""Extracts targets from a prediction string using regex patterns.
Returns first sucesffuly extracted match.
Expand All @@ -397,6 +413,9 @@ def extract_target_from_pred(
fallback_mode (Literal["no_fallback", "first_match"], optional): How to handle extraction failures. Defaults to "no_fallback".
- "no_fallback": Return only successfully parsed match
- "first_match": Additionaly Include the first string match no matter how parsing finished
extraction_mode (Literal["first_match", "any_match"], optional): How to handle extraction failures. Defaults to "any_match".
- "first_match": Only tries to extract the first match
- "any_match": Tries to extract any match
Returns:
list: List of extracted predictions, with first fallbac string appended if fallback_mode is "first_match"
Expand All @@ -410,6 +429,7 @@ def extract_target_from_pred(
for target_patterns, target_type in target_res
for pattern, priority in target_patterns
]
match_found = False

# Group patterns by priority using itertools.groupby
for _, patterns_group in groupby(sorted(all_patterns, key=lambda x: x[2]), key=lambda x: x[2]):
Expand All @@ -426,6 +446,7 @@ def extract_target_from_pred(
# Try to extract from each match, starting from rightmost
for match, _, _, target_type in matches_with_pos:
extracted_match, str_fallback = extract_match(match, target_type)
match_found = True

if str_fallback:
fallbacks.append(str_fallback)
Expand All @@ -434,8 +455,11 @@ def extract_target_from_pred(
extracted_predictions.append(extracted_match)
break

if extraction_mode == "first_match":
break

# If we found something and we're in first_match mode, stop processing other priorities
if extracted_predictions:
if extracted_predictions or (match_found and extraction_mode == "first_match"):
break

if fallback_mode == "first_match" and fallbacks:
Expand Down

0 comments on commit 0e46269

Please sign in to comment.