Skip to content

Commit

Permalink
revert symbols, improve sets handling
Browse files Browse the repository at this point in the history
  • Loading branch information
hynky1999 committed Feb 5, 2025
1 parent c2cb488 commit c536de0
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 19 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ multilingual = [
"jieba", # for chinese tokenizer
"pyvi", # for vietnamese tokenizer
]
math = ["latex2sympy2_extended>=1.0.2"]
math = ["latex2sympy2_extended==1.0.3"]

[project.urls]
Homepage = "https://github.com/huggingface/lighteval"
Expand Down
7 changes: 4 additions & 3 deletions src/lighteval/metrics/dynamic_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ def multilingual_extractive_match_metric(
fallback_mode: Literal["no_fallback", "first_match"] = "first_match",
extraction_mode: Literal["first_match", "any_match"] = "any_match",
precision: int = 6,
timeout_seconds: int = 5,
) -> SampleLevelMetric:
"""Creates a language-aware extractive match metric that extracts answers from the model's output.
Expand Down Expand Up @@ -245,11 +246,11 @@ def sample_level_fn(golds: list[str], predictions: list[str], formatted_doc: Doc
pred_extraction_regexes = get_extraction_regexes(formatted_doc, pred_extraction_target, language)

extracted_predictions = [
extract_target_from_pred(pred, pred_extraction_regexes, fallback_mode, extraction_mode)
extract_target_from_pred(pred, pred_extraction_regexes, fallback_mode, extraction_mode, timeout_seconds)
for pred in predictions
]
extracted_golds = [
extract_target_from_pred(gold, gold_extraction_regexes, fallback_mode, extraction_mode) for gold in golds
extract_target_from_pred(gold, gold_extraction_regexes, fallback_mode, extraction_mode, timeout_seconds) for gold in golds
]

# Assert on empty gold and warn on empty pred
Expand All @@ -270,7 +271,7 @@ def sample_level_fn(golds: list[str], predictions: list[str], formatted_doc: Doc

return aggregation_function(
[
(1.0 if any(compare_gold_target(gold, pred, precision) for gold in extracted_golds) else 0.0)
(1.0 if any(compare_gold_target(gold, pred, precision, timeout_seconds=timeout_seconds) for gold in extracted_golds) else 0.0)
for pred in extracted_predictions
]
)
Expand Down
4 changes: 2 additions & 2 deletions src/lighteval/metrics/utils/extractive_match_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,7 @@ def convert_to_pct(number: Number):
@requires_latex2sympy2_extended
@lru_cache(maxsize=20)
def extract_latex(match: re.Match, latex_config: LatexExtractionConfig, timeout_seconds: int) -> tuple[sympy.Expr | str | None, str]:
from latex2sympy2_extended.latex2sympy2 import normalize_latex
from latex2sympy2_extended.latex2sympy2 import normalize_latex, FiniteSet as L2SFiniteSet
latex_exprs = []
latex_strs = []

Expand Down Expand Up @@ -472,7 +472,7 @@ def extract_latex(match: re.Match, latex_config: LatexExtractionConfig, timeout_
all_elements.extend(expr.args)
else:
all_elements.append(expr)
return FiniteSet(*all_elements), " and ".join(latex_strs)
return L2SFiniteSet(*all_elements), " and ".join(latex_strs)

# Otherwise return the single expression
return latex_exprs[0], latex_strs[0]
Expand Down
28 changes: 22 additions & 6 deletions src/lighteval/metrics/utils/math_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
from sympy.core.relational import Relational

from lighteval.utils.timeout import timeout
from latex2sympy2_extended.sets import FiniteSet as L2SFiniteSet


def safe_sympy_doit(a: Basic | MatrixBase):
Expand Down Expand Up @@ -181,19 +182,34 @@ def sympy_deep_compare_set_and_tuple(gold: FiniteSet | Tuple, pred: FiniteSet |
Note: in order to fully support finite sets, we should ideally do kartesian product comparison
but this is not implemented yet. We kinda hope sympy will order the elements.
"""
from latex2sympy2_extended.sets import FiniteSet as L2SFiniteSet
def unwrap_eq(s):
if is_assignment_relation(s):
return take_last_relation(s).rhs
return s

def sort_key(x):
try:
return default_sort_key(unwrap_eq(x).evalf())
except TimeoutError:
raise
except:
return default_sort_key(unwrap_eq(x))


# This ensures it works for {1/3} and {0.333333}
if len(gold) == len(pred):
if isinstance(gold, FiniteSet):
gold_args = list(ordered(gold.args, keys=lambda x: default_sort_key(unwrap_eq(x)), default=False))
pred_args = list(ordered(pred.args, keys=lambda x: default_sort_key(unwrap_eq(x)), default=False))
gold_args = list(ordered(gold.args, keys=sort_key, default=False))
pred_args = list(ordered(pred.args, keys=sort_key, default=False))

elif isinstance(gold, Tuple) and isinstance(pred, L2SFiniteSet):
# We treat the pred as tuple too
pred_args = pred._unsorted_args
gold_args = gold.args

elif isinstance(pred, FiniteSet):
pred_args = list(ordered(pred.args, keys=lambda x: default_sort_key(unwrap_eq(x)), default=False))
pred_args = list(ordered(pred.args, keys=sort_key, default=False))
gold_args = gold.args
else:
gold_args = gold.args
Expand Down Expand Up @@ -286,19 +302,19 @@ def is_equation(expr: Basic | MatrixBase) -> bool:

@requires_latex2sympy2_extended
def is_assignment_relation(expr: Basic | MatrixBase) -> bool:
from latex2sympy2_extended.latex2sympy2 import is_assignment_symbol
from latex2sympy2_extended.latex2sympy2 import is_expr_of_only_symbols
"""Check if an expression is an assignment relation. E.g a=1
Args:
expr: The expression to check
Returns:
bool: True if expr is a relational expression or And of relations, False otherwise
"""
if isinstance(expr, Eq) and is_assignment_symbol(expr.lhs):
if isinstance(expr, Eq) and is_expr_of_only_symbols(expr.lhs):
return True

if isinstance(expr, And) and len(expr.args) > 0:
return all(isinstance(arg, Eq) for arg in expr.args) and is_assignment_symbol(expr.args[0].lhs)
return all(isinstance(arg, Eq) for arg in expr.args) and is_expr_of_only_symbols(expr.args[0].lhs)

return False

Expand Down
29 changes: 22 additions & 7 deletions tests/metrics/test_extractive_match.py
Original file line number Diff line number Diff line change
Expand Up @@ -604,12 +604,6 @@ def test_latex_notation_math(gold, pred, expected):
"$-x >= -1$",
1,
),
# Test incomplete equation
(
"$a +z = 0$",
"$0$",
0,
),
],
)
def test_relations_math(gold, pred, expected):
Expand Down Expand Up @@ -1132,12 +1126,33 @@ def test_math_extraction_additional_cases(gold, pred, expected):
r"$\text{Even}$",
r"$Even$",
1
)
),
# (
# r"$f(x)$",
# r"$f(y)$",
# 1
# )
(
r"$x_{1}=10^{\frac{-5+\sqrt{13}}{6}},\quadx_{2}=10^{\frac{-5-\sqrt{13}}{6}}$",
r"$\boxed{10^{\frac{\sqrt{13} - 5}{6}}} \quad \text{and} \quad \boxed{10^{-\frac{5 + \sqrt{13}}{6}}}$",
1,
),
(
r"$y_{1}=-2 x^{2}+4 x+3, y_{2}=3 x^{2}+12 x+10$",
r"\($y_1 = \boxed{-2(x - 1)^2 + 5} \) and \( y_2 = \boxed{3(x + 2)^2 - 2} \) ",
1,
),
(
r"$x_{1}=\frac{1}{2}+\frac{31\sqrt{5}}{216},\quadx_{2}=\frac{1}{2}-\frac{31\sqrt{5}}{216}$",
r"$\boxed{\dfrac{108 + 31\sqrt{5}}{216}} \quad \text{and} \quad \boxed{\dfrac{108 - 31\sqrt{5}}{216}}$",
1,
),
(
r"$x_{1}=10^{\frac{-5+\sqrt{13}}{6}},\quadx_{2}=10^{\frac{-5-\sqrt{13}}{6}}$",
r"$\boxed{10^{\frac{\sqrt{13} - 5}{6}}} \quad \text{and} \quad \boxed{10^{-\frac{5 + \sqrt{13}}{6}}}$",
1,
),
])
def test_math_numina_cases(gold, pred, expected):
assert compare_strings(gold, pred, match_types=["latex", "expr"]) == expected

0 comments on commit c536de0

Please sign in to comment.