From c536de044220d002c7e6957684d412a5fa380be8 Mon Sep 17 00:00:00 2001 From: Hynek Kydlicek Date: Wed, 5 Feb 2025 01:22:28 +0100 Subject: [PATCH] revert symbols, improve sets handling --- pyproject.toml | 2 +- src/lighteval/metrics/dynamic_metrics.py | 7 +++-- .../metrics/utils/extractive_match_utils.py | 4 +-- .../metrics/utils/math_comparison.py | 28 ++++++++++++++---- tests/metrics/test_extractive_match.py | 29 ++++++++++++++----- 5 files changed, 51 insertions(+), 19 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index a986faae..405e4f83 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -109,7 +109,7 @@ multilingual = [ "jieba", # for chinese tokenizer "pyvi", # for vietnamese tokenizer ] -math = ["latex2sympy2_extended>=1.0.2"] +math = ["latex2sympy2_extended==1.0.3"] [project.urls] Homepage = "https://github.com/huggingface/lighteval" diff --git a/src/lighteval/metrics/dynamic_metrics.py b/src/lighteval/metrics/dynamic_metrics.py index 51f749d0..3ac3f915 100644 --- a/src/lighteval/metrics/dynamic_metrics.py +++ b/src/lighteval/metrics/dynamic_metrics.py @@ -193,6 +193,7 @@ def multilingual_extractive_match_metric( fallback_mode: Literal["no_fallback", "first_match"] = "first_match", extraction_mode: Literal["first_match", "any_match"] = "any_match", precision: int = 6, + timeout_seconds: int = 5, ) -> SampleLevelMetric: """Creates a language-aware extractive match metric that extracts answers from the model's output. @@ -245,11 +246,11 @@ def sample_level_fn(golds: list[str], predictions: list[str], formatted_doc: Doc pred_extraction_regexes = get_extraction_regexes(formatted_doc, pred_extraction_target, language) extracted_predictions = [ - extract_target_from_pred(pred, pred_extraction_regexes, fallback_mode, extraction_mode) + extract_target_from_pred(pred, pred_extraction_regexes, fallback_mode, extraction_mode, timeout_seconds) for pred in predictions ] extracted_golds = [ - extract_target_from_pred(gold, gold_extraction_regexes, fallback_mode, extraction_mode) for gold in golds + extract_target_from_pred(gold, gold_extraction_regexes, fallback_mode, extraction_mode, timeout_seconds) for gold in golds ] # Assert on empty gold and warn on empty pred @@ -270,7 +271,7 @@ def sample_level_fn(golds: list[str], predictions: list[str], formatted_doc: Doc return aggregation_function( [ - (1.0 if any(compare_gold_target(gold, pred, precision) for gold in extracted_golds) else 0.0) + (1.0 if any(compare_gold_target(gold, pred, precision, timeout_seconds=timeout_seconds) for gold in extracted_golds) else 0.0) for pred in extracted_predictions ] ) diff --git a/src/lighteval/metrics/utils/extractive_match_utils.py b/src/lighteval/metrics/utils/extractive_match_utils.py index 2d2b7677..07f4213e 100644 --- a/src/lighteval/metrics/utils/extractive_match_utils.py +++ b/src/lighteval/metrics/utils/extractive_match_utils.py @@ -414,7 +414,7 @@ def convert_to_pct(number: Number): @requires_latex2sympy2_extended @lru_cache(maxsize=20) def extract_latex(match: re.Match, latex_config: LatexExtractionConfig, timeout_seconds: int) -> tuple[sympy.Expr | str | None, str]: - from latex2sympy2_extended.latex2sympy2 import normalize_latex + from latex2sympy2_extended.latex2sympy2 import normalize_latex, FiniteSet as L2SFiniteSet latex_exprs = [] latex_strs = [] @@ -472,7 +472,7 @@ def extract_latex(match: re.Match, latex_config: LatexExtractionConfig, timeout_ all_elements.extend(expr.args) else: all_elements.append(expr) - return FiniteSet(*all_elements), " and ".join(latex_strs) + return L2SFiniteSet(*all_elements), " and ".join(latex_strs) # Otherwise return the single expression return latex_exprs[0], latex_strs[0] diff --git a/src/lighteval/metrics/utils/math_comparison.py b/src/lighteval/metrics/utils/math_comparison.py index d278f5b6..4108f59a 100644 --- a/src/lighteval/metrics/utils/math_comparison.py +++ b/src/lighteval/metrics/utils/math_comparison.py @@ -54,6 +54,7 @@ from sympy.core.relational import Relational from lighteval.utils.timeout import timeout +from latex2sympy2_extended.sets import FiniteSet as L2SFiniteSet def safe_sympy_doit(a: Basic | MatrixBase): @@ -181,19 +182,34 @@ def sympy_deep_compare_set_and_tuple(gold: FiniteSet | Tuple, pred: FiniteSet | Note: in order to fully support finite sets, we should ideally do kartesian product comparison but this is not implemented yet. We kinda hope sympy will order the elements. """ + from latex2sympy2_extended.sets import FiniteSet as L2SFiniteSet def unwrap_eq(s): if is_assignment_relation(s): return take_last_relation(s).rhs return s + def sort_key(x): + try: + return default_sort_key(unwrap_eq(x).evalf()) + except TimeoutError: + raise + except: + return default_sort_key(unwrap_eq(x)) + # This ensures it works for {1/3} and {0.333333} if len(gold) == len(pred): if isinstance(gold, FiniteSet): - gold_args = list(ordered(gold.args, keys=lambda x: default_sort_key(unwrap_eq(x)), default=False)) - pred_args = list(ordered(pred.args, keys=lambda x: default_sort_key(unwrap_eq(x)), default=False)) + gold_args = list(ordered(gold.args, keys=sort_key, default=False)) + pred_args = list(ordered(pred.args, keys=sort_key, default=False)) + + elif isinstance(gold, Tuple) and isinstance(pred, L2SFiniteSet): + # We treat the pred as tuple too + pred_args = pred._unsorted_args + gold_args = gold.args + elif isinstance(pred, FiniteSet): - pred_args = list(ordered(pred.args, keys=lambda x: default_sort_key(unwrap_eq(x)), default=False)) + pred_args = list(ordered(pred.args, keys=sort_key, default=False)) gold_args = gold.args else: gold_args = gold.args @@ -286,7 +302,7 @@ def is_equation(expr: Basic | MatrixBase) -> bool: @requires_latex2sympy2_extended def is_assignment_relation(expr: Basic | MatrixBase) -> bool: - from latex2sympy2_extended.latex2sympy2 import is_assignment_symbol + from latex2sympy2_extended.latex2sympy2 import is_expr_of_only_symbols """Check if an expression is an assignment relation. E.g a=1 Args: @@ -294,11 +310,11 @@ def is_assignment_relation(expr: Basic | MatrixBase) -> bool: Returns: bool: True if expr is a relational expression or And of relations, False otherwise """ - if isinstance(expr, Eq) and is_assignment_symbol(expr.lhs): + if isinstance(expr, Eq) and is_expr_of_only_symbols(expr.lhs): return True if isinstance(expr, And) and len(expr.args) > 0: - return all(isinstance(arg, Eq) for arg in expr.args) and is_assignment_symbol(expr.args[0].lhs) + return all(isinstance(arg, Eq) for arg in expr.args) and is_expr_of_only_symbols(expr.args[0].lhs) return False diff --git a/tests/metrics/test_extractive_match.py b/tests/metrics/test_extractive_match.py index 9ff34f5c..d7ad4c78 100644 --- a/tests/metrics/test_extractive_match.py +++ b/tests/metrics/test_extractive_match.py @@ -604,12 +604,6 @@ def test_latex_notation_math(gold, pred, expected): "$-x >= -1$", 1, ), - # Test incomplete equation - ( - "$a +z = 0$", - "$0$", - 0, - ), ], ) def test_relations_math(gold, pred, expected): @@ -1132,12 +1126,33 @@ def test_math_extraction_additional_cases(gold, pred, expected): r"$\text{Even}$", r"$Even$", 1 - ) + ), # ( # r"$f(x)$", # r"$f(y)$", # 1 # ) + + ( + r"$x_{1}=10^{\frac{-5+\sqrt{13}}{6}},\quadx_{2}=10^{\frac{-5-\sqrt{13}}{6}}$", + r"$\boxed{10^{\frac{\sqrt{13} - 5}{6}}} \quad \text{and} \quad \boxed{10^{-\frac{5 + \sqrt{13}}{6}}}$", + 1, + ), + ( + r"$y_{1}=-2 x^{2}+4 x+3, y_{2}=3 x^{2}+12 x+10$", + r"\($y_1 = \boxed{-2(x - 1)^2 + 5} \) and \( y_2 = \boxed{3(x + 2)^2 - 2} \) ", + 1, + ), + ( + r"$x_{1}=\frac{1}{2}+\frac{31\sqrt{5}}{216},\quadx_{2}=\frac{1}{2}-\frac{31\sqrt{5}}{216}$", + r"$\boxed{\dfrac{108 + 31\sqrt{5}}{216}} \quad \text{and} \quad \boxed{\dfrac{108 - 31\sqrt{5}}{216}}$", + 1, + ), + ( + r"$x_{1}=10^{\frac{-5+\sqrt{13}}{6}},\quadx_{2}=10^{\frac{-5-\sqrt{13}}{6}}$", + r"$\boxed{10^{\frac{\sqrt{13} - 5}{6}}} \quad \text{and} \quad \boxed{10^{-\frac{5 + \sqrt{13}}{6}}}$", + 1, + ), ]) def test_math_numina_cases(gold, pred, expected): assert compare_strings(gold, pred, match_types=["latex", "expr"]) == expected \ No newline at end of file