From 8a161c19eb363cd58a83318fc11e502f5323bda0 Mon Sep 17 00:00:00 2001 From: Hynek Kydlicek Date: Tue, 28 Jan 2025 12:45:09 +0100 Subject: [PATCH] bump up latex2sympy + adjust latex target --- pyproject.toml | 2 +- src/lighteval/metrics/dynamic_metrics.py | 9 ++++- .../metrics/utils/extractive_match_utils.py | 39 +++++++++++++++---- .../metrics/utils/math_comparison.py | 2 +- tests/metrics/test_extractive_match.py | 1 + 5 files changed, 42 insertions(+), 11 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 126d66244..df4ff39da 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -109,7 +109,7 @@ multilingual = [ "jieba", # for chinese tokenizer "pyvi", # for vietnamese tokenizer ] -math = ["latex2sympy2_extended>=0.9.1"] +math = ["latex2sympy2_extended>=0.9.3"] [project.urls] Homepage = "https://github.com/huggingface/lighteval" diff --git a/src/lighteval/metrics/dynamic_metrics.py b/src/lighteval/metrics/dynamic_metrics.py index 577934e9d..26fdabd4a 100644 --- a/src/lighteval/metrics/dynamic_metrics.py +++ b/src/lighteval/metrics/dynamic_metrics.py @@ -191,6 +191,7 @@ def multilingual_extractive_match_metric( pred_extraction_target: Sequence[ExtractionTarget] = (ExprExtractionConfig(),), aggregation_function: Callable[[list[float]], float] = max, fallback_mode: Literal["no_fallback", "first_match"] = "first_match", + extraction_mode: Literal["first_match", "any_match"] = "any_match", precision: int = 6, ) -> SampleLevelMetric: """Creates a language-aware extractive match metric that extracts answers from the model's output. @@ -215,6 +216,7 @@ def multilingual_extractive_match_metric( How to perform extraction. Defaults to "first_match". - "no_fallback": Only use first successfully parsed matches - "first_match": Use the first successfully parsed match + first match irregardless the parsing success + extraction_mode: Literal["first_match", "any_match"] precision: int Number of decimal places to use when comparing numerical values. Defaults to 6. @@ -240,9 +242,12 @@ def sample_level_fn(golds: list[str], predictions: list[str], formatted_doc: Doc pred_extraction_regexes = get_extraction_regexes(formatted_doc, pred_extraction_target, language) extracted_predictions = [ - extract_target_from_pred(pred, pred_extraction_regexes, fallback_mode) for pred in predictions + extract_target_from_pred(pred, pred_extraction_regexes, fallback_mode, extraction_mode) + for pred in predictions + ] + extracted_golds = [ + extract_target_from_pred(gold, gold_extraction_regexes, fallback_mode, extraction_mode) for gold in golds ] - extracted_golds = [extract_target_from_pred(gold, gold_extraction_regexes, fallback_mode) for gold in golds] # Assert on empty gold and warn on empty pred if any(len(g) == 0 for g in extracted_golds): diff --git a/src/lighteval/metrics/utils/extractive_match_utils.py b/src/lighteval/metrics/utils/extractive_match_utils.py index 01d2fc102..66bdbe006 100644 --- a/src/lighteval/metrics/utils/extractive_match_utils.py +++ b/src/lighteval/metrics/utils/extractive_match_utils.py @@ -21,10 +21,10 @@ # SOFTWARE. import re -from dataclasses import dataclass +from dataclasses import dataclass, field from functools import lru_cache from itertools import groupby -from typing import Literal, Sequence +from typing import Any, Literal, Sequence import sympy from sympy import Basic, MatrixBase, Number @@ -39,17 +39,33 @@ from lighteval.utils.timeout import timeout +@requires_latex2sympy2_extended +def latex_normalization_config_default_factory(): + from latex2sympy2_extended.latex2sympy2 import NormalizationConfig + + return NormalizationConfig( + basic_latex=True, + units=True, + malformed_operators=True, + nits=True, + boxed=True, + equations=True, + ) + + @dataclass(frozen=True) class LatexExtractionConfig: """Config for extracting latex from the prediction. Attributes: try_extract_without_anchor (bool): Whether to try extracting latex without requiring specific anchors like "answer:" or "final answer is" - enforce_boxed_match (bool): Whether to also consider extracting from plain \boxed{...} expressions + boxed_match_priority (int): Priority of the boxed match regex (-1 never, 0 first, 55 after final answer: anchor, etc...) + normalization_config (latex2sympy2_extended.latex2sympy2.NormalizationConfig): Normalization config to use for latex extraction """ try_extract_without_anchor: bool = True - enforce_boxed_match: bool = True + boxed_match_priority: int = 55 + normalization_config: Any = field(default_factory=latex_normalization_config_default_factory) @dataclass(frozen=True) @@ -188,8 +204,8 @@ def lazy_latex_regex(latex_config: LatexExtractionConfig, language: Language) -> regexes.append((latex_re, 300)) # This ensures that boxed is matched right after the final answer xxxx - if latex_config.enforce_boxed_match: - regexes.append((latex_boxed, 55)) + if latex_config.boxed_match_priority >= 0: + regexes.append((latex_boxed, latex_config.boxed_match_priority)) return [(re.compile(pattern, re.DOTALL), priority) for pattern, priority in regexes] @@ -387,6 +403,7 @@ def extract_target_from_pred( pred: str, target_res: list[tuple[list[tuple[re.Pattern[str], int]], ExtractionTarget]], fallback_mode: Literal["no_fallback", "first_match"] = "no_fallback", + extraction_mode: Literal["first_match", "any_match"] = "any_match", ): """Extracts targets from a prediction string using regex patterns. Returns first sucesffuly extracted match. @@ -397,6 +414,9 @@ def extract_target_from_pred( fallback_mode (Literal["no_fallback", "first_match"], optional): How to handle extraction failures. Defaults to "no_fallback". - "no_fallback": Return only successfully parsed match - "first_match": Additionaly Include the first string match no matter how parsing finished + extraction_mode (Literal["first_match", "any_match"], optional): How to handle extraction failures. Defaults to "any_match". + - "first_match": Only tries to extract the first match + - "any_match": Tries to extract any match Returns: list: List of extracted predictions, with first fallbac string appended if fallback_mode is "first_match" @@ -410,6 +430,7 @@ def extract_target_from_pred( for target_patterns, target_type in target_res for pattern, priority in target_patterns ] + match_found = False # Group patterns by priority using itertools.groupby for _, patterns_group in groupby(sorted(all_patterns, key=lambda x: x[2]), key=lambda x: x[2]): @@ -426,6 +447,7 @@ def extract_target_from_pred( # Try to extract from each match, starting from rightmost for match, _, _, target_type in matches_with_pos: extracted_match, str_fallback = extract_match(match, target_type) + match_found = True if str_fallback: fallbacks.append(str_fallback) @@ -434,8 +456,11 @@ def extract_target_from_pred( extracted_predictions.append(extracted_match) break + if extraction_mode == "first_match": + break + # If we found something and we're in first_match mode, stop processing other priorities - if extracted_predictions: + if extracted_predictions or (match_found and extraction_mode == "first_match"): break if fallback_mode == "first_match" and fallbacks: diff --git a/src/lighteval/metrics/utils/math_comparison.py b/src/lighteval/metrics/utils/math_comparison.py index 483d1d450..fc5b3dc4c 100644 --- a/src/lighteval/metrics/utils/math_comparison.py +++ b/src/lighteval/metrics/utils/math_comparison.py @@ -413,7 +413,7 @@ def should_treat_as_complex(latex_str: str) -> bool: def compare_gold_target( gold: list[Basic | MatrixBase | str], target: list[Basic | MatrixBase | str], precision: int ) -> bool: - @timeout(timeout_seconds=10) + @timeout(timeout_seconds=100) def compare_single_extraction(gold: Basic | MatrixBase | str, target: Basic | MatrixBase | str) -> bool: # If both are sympy expressions, we can use sympy to compare them if isinstance(gold, (Basic, MatrixBase)) and isinstance(target, (Basic, MatrixBase)): diff --git a/tests/metrics/test_extractive_match.py b/tests/metrics/test_extractive_match.py index 78e7fdae2..3ca1c2510 100644 --- a/tests/metrics/test_extractive_match.py +++ b/tests/metrics/test_extractive_match.py @@ -969,6 +969,7 @@ def test_math_extraction_edge_cases(gold, pred, expected): r"Thus, the answer is $x \in \boxed{(2,12) \cup (12,102)}$", 1, ), + (r"$204$", r"$24+108=204$", 1), ], ) def test_math_extraction_additional_cases(gold, pred, expected):