Skip to content

Commit

Permalink
update extraction match to reflect newest math-verify
Browse files Browse the repository at this point in the history
  • Loading branch information
hynky1999 committed Feb 4, 2025
1 parent d7a1f11 commit c2cb488
Show file tree
Hide file tree
Showing 5 changed files with 534 additions and 121 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ multilingual = [
"jieba", # for chinese tokenizer
"pyvi", # for vietnamese tokenizer
]
math = ["latex2sympy2_extended>=0.9.3"]
math = ["latex2sympy2_extended>=1.0.2"]

[project.urls]
Homepage = "https://github.com/huggingface/lighteval"
Expand Down
233 changes: 168 additions & 65 deletions src/lighteval/metrics/utils/extractive_match_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,13 @@
# SOFTWARE.

import re
from dataclasses import dataclass, field
from dataclasses import dataclass, field, replace
from functools import lru_cache
from itertools import groupby
from typing import Any, Literal, Sequence

import sympy
from sympy import Basic, MatrixBase, Number
from sympy import Basic, FiniteSet, MatrixBase, Number
from sympy.parsing import parse_expr

from lighteval.metrics.utils.math_comparison import should_treat_as_complex
Expand All @@ -48,7 +48,7 @@ def latex_normalization_config_default_factory():
units=True,
malformed_operators=True,
nits=True,
boxed=True,
boxed="all",
equations=True,
)

Expand Down Expand Up @@ -159,37 +159,91 @@ def lazy_expr_regex(expr_config: ExprExtractionConfig, language: Language) -> li
return [(re.compile(pattern), priority) for pattern, priority in regexes]


@lru_cache(maxsize=1)
def lazy_latex_regex(latex_config: LatexExtractionConfig, language: Language) -> list[tuple[re.Pattern[str], int]]:
# Only LaTeX expressions between delimiters
percent_re_group = r"(?P<percent>\s*(?:\\?%|[Pp]ercent|[Pp]ercentage|[Pp]ct))"
latex_envs_re = (
r"("
r"(?<!\\)\$\$(?P<latexDisplayDollar>[\s\S]+?)(?<!\\)\$\$|" # $$...$$ (display math, can be multiline)
r"(?<!\\)\\\[(?P<latexDisplayBracket>[\s\S]+?)(?<!\\)\\\]|" # \[...\] (display math, can be multiline)
r"(?<!\\|\d)\$(?P<latexInlineDollar>(?:\\[$]|[^\n$])+?)(?<!\\)\$|" # $...$ (inline math, single line, allows escaped $), we make sure it's not preceded by a digit to minimize false positives containing dollar as a unit
r"(?<!\\)\\\((?P<latexInlineParenthesis>[^\n]+?)(?<!\\)\\\)|" # \(...\) (inline math, single line)
r"(?<!\\)\[(?P<latexInlineBracket>[^\n$]+?)(?<!\\)\]" # [....] While this is not a valid display, math LLMs like to generate it. We allow it
rf"){percent_re_group}?"
)
def make_latex_env_pattern(prefix: str = "", context: Literal["boxed", "plain"] = "plain") -> str:
    r"""Creates a LaTeX environment pattern with uniquely prefixed group names.

    Args:
        prefix (str): Prefix to add to group names to make them unique, so several
            copies of this pattern can be OR-ed together inside one regex.
        context (Literal["boxed", "plain"]): Type of content to match inside the environments
            - "boxed": Match environments containing \boxed{...}
            - "plain": Match any LaTeX content

    Returns:
        str: Regex pattern for matching LaTeX environments with an optional percent suffix
    """
    # Optional trailing marker that the value is a percentage ("%", "percent", "pct", ...).
    percent_re_group = rf"(?P<{prefix}percent>(?:\\?%|[Pp]ercent|[Pp]ercentage|[Pp]ct))"

    # Define base content patterns: each matches ONE character allowed inside the
    # corresponding delimiter pair (repetition is added below).
    display_dollar_content = r"(?:[^$]|\$(?!\$))"
    # Either \ not followed by ] or everything but \
    display_content_bracket = r"(?:[^\\]|\\(?!\]))"
    inline_dollar_content = r"(?:\\[$]|[^\n$])"
    inline_content_parenthesis = r"(?:[^\\\n]|\\(?!\)))"
    inline_content_bracket = r"[^\n\]\[]"

    if context == "boxed":
        # Require a \boxed{...} somewhere inside the environment content
        # (lazy runs of allowed characters around the mandatory \boxed group).
        display_dollar_content = rf"{display_dollar_content}*?\\boxed{{{display_dollar_content}+?}}{display_dollar_content}*?"
        display_content_bracket = rf"{display_content_bracket}*?\\boxed{{{display_content_bracket}+?}}{display_content_bracket}*?"
        inline_dollar_content = rf"{inline_dollar_content}*?\\boxed{{{inline_dollar_content}+?}}{inline_dollar_content}*?"
        inline_content_parenthesis = rf"{inline_content_parenthesis}*?\\boxed{{{inline_content_parenthesis}+?}}{inline_content_parenthesis}*?"
        inline_content_bracket = rf"{inline_content_bracket}*?\\boxed{{{inline_content_bracket}+?}}{inline_content_bracket}*?"
    else:
        # Plain context: any non-empty (lazy) run of the allowed characters.
        display_dollar_content = rf"{display_dollar_content}+?"
        display_content_bracket = rf"{display_content_bracket}+?"
        inline_dollar_content = rf"{inline_dollar_content}+?"
        inline_content_parenthesis = rf"{inline_content_parenthesis}+?"
        inline_content_bracket = rf"{inline_content_bracket}+?"

    # Build list of regex patterns
    patterns = [
        # Display math environments (allow multiline)
        rf"(?<!\\)\$\$(?P<{prefix}latexDisplayDollar>{display_dollar_content})(?<!\\)\$\$",
        rf"(?<!\\)\\\[(?P<{prefix}latexDisplayBracket>{display_content_bracket})(?<!\\)\\\]",
        # Inline math environments (single line only)
        rf"(?<!\\|\d)\$(?P<{prefix}latexInlineDollar>{inline_dollar_content})(?<!\\)\$",
        rf"(?<!\\)\\\((?P<{prefix}latexInlineParenthesis>{inline_content_parenthesis})(?<!\\)\\\)",
        # [...] is not valid display math, but math LLMs like to emit it, so allow it
        # when surrounded by whitespace.
        rf"\s\[(?P<{prefix}latexInlineBracket>{inline_content_bracket})\]\s",
    ]

    if context == "boxed":
        # allow also matching a plain \boxed{...} with no surrounding math delimiters
        patterns.append(rf"(?P<{prefix}latexBoxed>\\boxed{{.+}})")
    elif context == "plain":
        # In plain context also accept a bare \frac with simple numeric arguments.
        simple_number = r"-?\d+(?:[.,]\d+)?"
        patterns.append(rf"(?P<{prefix}latexFraction>-?\\frac{{{simple_number}}}{{{simple_number}}})")

    # Join patterns with | and wrap in parentheses; the percent suffix is optional.
    latex_env_re = rf"(?:(?:{'|'.join(patterns)})\s*{percent_re_group}?)"

    return latex_env_re

# Match latex without environments
latex_boxed = rf"(?P<latexBoxed>\\boxed{{.+}})\$?{percent_re_group}?" # Boxed number, it's fine to be as greedy as possible as we will find the correct end afterwards
simple_number = r"-?\d+(?:[.,]\d+)?"
latex_fraction = rf"(?P<latexFraction>-?\\frac{{{simple_number}}}{{{simple_number}}})\$?{percent_re_group}?"

@lru_cache(maxsize=1)
def lazy_latex_regex(
latex_config: LatexExtractionConfig,
language: Language
) -> list[tuple[re.Pattern[str], int]]:
translation_literal = TRANSLATION_LITERALS[language]
# Pattern for multiple latex environments connected by and/or
# Create patterns for up to 5 connected expressions
first_latex_group = make_latex_env_pattern('first_')
and_word = translation_literal.and_word
or_word = translation_literal.or_word
next_groups = ''.join([rf"(?:\s*(?:{and_word}|{or_word})\s*{make_latex_env_pattern(f'next{i}_')})?" for i in range(1, 6)])

latex_envs_re = rf"(?:{first_latex_group}{next_groups})"
colon_re = rf"[{re.escape(translation_literal.colon)}\:]"

answer_prefix_re = rf"(?i:{translation_literal.answer})"

# We first match boxed env, for some reason that's the most common case of output
# Then we match the latex with environments, then we try to match the fraction
regexes: list[tuple[str, int]] = []
for latex_re in [latex_envs_re, latex_fraction]:
for latex_re in [latex_envs_re]:
if language == Language.ENGLISH:
final_answer_prefixed_re = rf"(?i:final answer is)\:?\s*{latex_re}\.?\s?I hope"
final_answer_prefixed_just_is = rf"(?i:final answer.{{0,100}}?)\s+is\:?\s*{latex_re}"
final_answer_prefixed_just_is = (
rf"(?i:final answer.{{0,100}}?)\s+is\:?\s*{latex_re}"
)
regexes.append((final_answer_prefixed_re, 0))
regexes.append((final_answer_prefixed_just_is, 50))

Expand All @@ -203,8 +257,15 @@ def lazy_latex_regex(latex_config: LatexExtractionConfig, language: Language) ->
if latex_config.try_extract_without_anchor:
regexes.append((latex_re, 300))

# This ensures that boxed is matched right after the final answer xxxx
if latex_config.boxed_match_priority >= 0:
regexes.append((latex_boxed, latex_config.boxed_match_priority))
latex_re_boxed = make_latex_env_pattern(prefix='first_', context='boxed')
next_groups = ''.join([rf"(?:\s*(?:{and_word}|{or_word})\s*{make_latex_env_pattern(f'next{i}_', context='boxed')})?" for i in range(1, 6)])
latex_re_boxed = rf"{latex_re_boxed}{next_groups}"
regexes.append((latex_re_boxed, latex_config.boxed_match_priority))
# Also match a plain \boxed{...} with no math delimiters. Its end cannot be determined reliably
# by regex, so we match greedily up to the last "}" and do the actual extraction in the normalization step.
regexes.append((rf"(?P<first_latexBoxed>\\boxed{{.+}})", latex_config.boxed_match_priority))

return [(re.compile(pattern, re.DOTALL), priority) for pattern, priority in regexes]

Expand Down Expand Up @@ -268,7 +329,9 @@ def lazy_indices_regex(


def get_extraction_regexes(
formatted_doc: Doc, target_types: Sequence[ExtractionTarget], language: Language
formatted_doc: Doc,
target_types: Sequence[ExtractionTarget],
language: Language
) -> list[tuple[list[tuple[re.Pattern[str], int]], ExtractionTarget]]:
extraction_regexes: list[tuple[list[tuple[re.Pattern[str], int]], ExtractionTarget]] = [
(lazy_latex_regex(target_type, language), target_type)
Expand Down Expand Up @@ -296,21 +359,21 @@ def get_target_type_order(target_type: ExtractionTarget) -> int:

# Small cache, to catch repeated calls with invalid parsing
@lru_cache(maxsize=20)
@requires_latex2sympy2_extended
def parse_latex_with_timeout(latex: str, timeout_seconds: int):
    """Parse a LaTeX string into a sympy expression, aborting after `timeout_seconds`.

    Args:
        latex (str): Normalized LaTeX source to parse.
        timeout_seconds (int): Maximum time in seconds allowed for the parse.

    Returns:
        The sympy expression produced by latex2sympy.
    """
    from latex2sympy2_extended.latex2sympy2 import latex2sympy

    # Wrap only the parser call in the timeout, so the timeout budget is part of
    # the call (and therefore of the lru_cache key) instead of a fixed decorator value.
    return timeout(timeout_seconds)(latex2sympy)(
        latex, is_real=not should_treat_as_complex(latex), convert_degrees=False, normalization_config=None
    )


@lru_cache(maxsize=20)
def parse_expr_with_timeout(expr: str, timeout_seconds: int):
    """Parse a plain (non-LaTeX) expression with sympy, aborting after `timeout_seconds`.

    Args:
        expr (str): Expression source, e.g. "1/2" or "3**2".
        timeout_seconds (int): Maximum time in seconds allowed for the parse.

    Returns:
        The sympy expression, parsed without evaluation (evaluate=False) so the
        written form is preserved for later comparison logic.
    """
    return timeout(timeout_seconds)(parse_expr)(expr, evaluate=False)


def extract_expr(match: re.Match) -> tuple[str | sympy.Expr | None, str]:
def extract_expr(match: re.Match, timeout_seconds: int) -> tuple[str | sympy.Expr | None, str]:
# First combine the number
groups = match.groupdict()
# Expr group will always exist because every regex has it
Expand Down Expand Up @@ -338,7 +401,7 @@ def extract_expr(match: re.Match) -> tuple[str | sympy.Expr | None, str]:
# Remove new lines and spaces
if expr:
try:
return parse_expr_with_timeout(expr.replace("\n", " ").replace("^", "**")), expr
return parse_expr_with_timeout(expr.replace("\n", " ").replace("^", "**"), timeout_seconds), expr
except: # noqa: E722
pass
return None, expr
def convert_to_pct(number: Number):
    """Interpret `number` as a percentage.

    Returns the unevaluated product number * (1/100), keeping the structure so
    downstream comparison can treat 50% and 1/2 as equal without premature
    simplification.
    """
    return sympy.Mul(number, sympy.Rational(1, 100), evaluate=False)


@requires_latex2sympy2_extended
@lru_cache(maxsize=20)
def extract_latex(
    match: re.Match, latex_config: LatexExtractionConfig, timeout_seconds: int
) -> tuple[sympy.Expr | str | None, str]:
    """Extract and parse every LaTeX group captured by `match`.

    The extraction regexes capture one "first_" group and up to five "nextN_"
    groups (expressions joined by and/or words). Each captured LaTeX string is
    normalized and parsed; when several expressions all parse successfully they
    are merged into a single FiniteSet (to handle answers like "1, 2 and 3").

    Args:
        match (re.Match): Match produced by one of the lazy LaTeX regexes.
        latex_config (LatexExtractionConfig): Provides the normalization config.
        timeout_seconds (int): Time budget for each individual sympy parse.

    Returns:
        tuple[sympy.Expr | str | None, str]: (parsed expression or None,
        normalized LaTeX string usable as a fallback).
    """
    from latex2sympy2_extended.latex2sympy2 import normalize_latex

    latex_exprs = []
    latex_strs = []

    # Get the first latex group (prefix "first_")
    first_latex_group = next(
        ((val, name) for name, val in match.groupdict().items() if name.startswith("first_latex") and val),
        None,
    )

    # Get all nextN_ groups (expressions chained with and/or)
    next_latex_groups = [
        next(
            ((val, name) for name, val in match.groupdict().items() if name.startswith(f"next{i}_latex") and val),
            None,
        )
        for i in range(1, 6)
    ]

    all_latex = [group for group in [first_latex_group] + next_latex_groups if group is not None]

    for latex, name in all_latex:
        # Group names look like "first_latexBoxed" / "next2_latexInlineDollar".
        name_without_prefix = name.split("_")[0]
        group_name = name.split("_")[1] if len(name.split("_")) > 1 else None
        is_percentage = bool(match.groupdict().get(f"{name_without_prefix}_percent"))

        # A bare \boxed{...} group is matched greedily up to the last "}", so
        # normalization must extract only the last boxed content.
        config = latex_config.normalization_config
        if group_name == "latexBoxed":
            config = replace(config, boxed="last")  # replace() modifies a single field

        normalized_latex = normalize_latex(
            latex,
            config=config,
        )
        latex_strs.append(normalized_latex)

        try:
            parsed_latex = parse_latex_with_timeout(normalized_latex, timeout_seconds=timeout_seconds)
            if is_percentage:
                parsed_latex = convert_to_pct(parsed_latex)
            latex_exprs.append(parsed_latex)
        except:  # noqa: E722
            # Keep a placeholder so indices stay aligned with latex_strs.
            latex_exprs.append(None)

    if not latex_exprs:
        return None, ""

    # If we have multiple expressions and all of them parsed, wrap them in a FiniteSet.
    if len(latex_exprs) > 1 and all(expr is not None for expr in latex_exprs):
        all_elements = []
        for expr in latex_exprs:
            if isinstance(expr, FiniteSet):
                # Flatten nested sets: "{1,2} and 3" -> {1, 2, 3}
                all_elements.extend(expr.args)
            else:
                all_elements.append(expr)
        return FiniteSet(*all_elements), " and ".join(latex_strs)

    # Otherwise return the single (possibly failed -> None) expression
    return latex_exprs[0], latex_strs[0]


def extract_match(match: re.Match, target_type: ExtractionTarget, timeout_seconds: int) -> tuple[Basic | MatrixBase | str | None, str]:
    """Extracts the match from the regex match.

    Args:
        match (re.Match): The regex match object containing the extracted text
        target_type (ExtractionTarget): The type of extraction to perform (latex, expression, or indices)
        timeout_seconds (int): Maximum time in seconds to spend parsing expressions

    Returns:
        tuple[Basic | MatrixBase | str | None, str]: A tuple containing:
            - The extracted and parsed value (if successful) or None (if parsing failed)
            - The string representation of the extracted text
    """
    if isinstance(target_type, LatexExtractionConfig):
        return extract_latex(match, target_type, timeout_seconds=timeout_seconds)
    elif isinstance(target_type, ExprExtractionConfig):
        return extract_expr(match, timeout_seconds=timeout_seconds)
    elif isinstance(target_type, IndicesExtractionConfig):
        # Indices need no parsing; the raw group doubles as the fallback string.
        return match.group("indices"), match.group("indices")

Expand All @@ -403,6 +504,7 @@ def extract_target_from_pred(
target_res: list[tuple[list[tuple[re.Pattern[str], int]], ExtractionTarget]],
fallback_mode: Literal["no_fallback", "first_match"] = "no_fallback",
extraction_mode: Literal["first_match", "any_match"] = "any_match",
timeout_seconds: int = 5,
):
"""Extracts targets from a prediction string using regex patterns.
    Returns the first successfully extracted match.
Expand All @@ -416,6 +518,7 @@ def extract_target_from_pred(
extraction_mode (Literal["first_match", "any_match"], optional): How to handle extraction failures. Defaults to "any_match".
- "first_match": Only tries to extract the first match
- "any_match": Tries to extract any match
timeout_seconds (int, optional): Maximum time in seconds to spend parsing each expression. Defaults to 5.
Returns:
        list: List of extracted predictions, with the first fallback string appended if fallback_mode is "first_match"
Expand Down Expand Up @@ -445,7 +548,7 @@ def extract_target_from_pred(

# Try to extract from each match, starting from rightmost
for match, _, _, target_type in matches_with_pos:
extracted_match, str_fallback = extract_match(match, target_type)
extracted_match, str_fallback = extract_match(match, target_type, timeout_seconds)
match_found = True

if str_fallback:
Expand Down
Loading

0 comments on commit c2cb488

Please sign in to comment.