Skip to content

Commit

Permalink
Extractive Match metric (#495)
Browse files Browse the repository at this point in the history
* extract matching

* better docstring

* lazy imports

* bump up math

* Update src/lighteval/metrics/dynamic_metrics.py

Co-authored-by: Clémentine Fourrier <[email protected]>

* fix pr commnets

* Apply suggestions from code review

Co-authored-by: Clémentine Fourrier <[email protected]>

* rename comparisson -> comparison

---------

Co-authored-by: Clémentine Fourrier <[email protected]>
  • Loading branch information
hynky1999 and clefourrier authored Jan 15, 2025
1 parent a7aa6ed commit 59624c8
Show file tree
Hide file tree
Showing 8 changed files with 2,052 additions and 3 deletions.
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ tensorboardX = ["tensorboardX"]
vllm = ["vllm", "ray", "more_itertools"]
quality = ["ruff==v0.2.2","pre-commit"]
tests = ["pytest==7.4.0"]
dev = ["lighteval[accelerate,quality,tests,multilingual]"]
dev = ["lighteval[accelerate,quality,tests,multilingual,math]"]
docs = ["hf-doc-builder", "watchdog"]
extended_tasks = [
"langdetect", # ifeval
Expand All @@ -109,6 +109,7 @@ multilingual = [
"jieba", # for chinese tokenizer
"pyvi", # for vietnamese tokenizer
]
math = ["latex2sympy2_extended>=0.9.0"]

[project.urls]
Homepage = "https://github.com/huggingface/lighteval"
Expand Down
108 changes: 107 additions & 1 deletion src/lighteval/metrics/dynamic_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from typing import Callable, Literal
import logging
from typing import Callable, Literal, Sequence

import numpy as np

Expand All @@ -37,8 +38,22 @@
LogProbTokenNorm,
get_multilingual_normalizer,
)
from lighteval.metrics.utils.extractive_match_utils import ( # noqa: F401
ExprExtractionConfig,
ExtractionTarget,
IndicesExtractionConfig,
LatexExtractionConfig,
extract_target_from_pred,
get_extraction_regexes,
)
from lighteval.metrics.utils.math_comparison import compare_gold_target
from lighteval.metrics.utils.metric_utils import MetricCategory, MetricUseCase, SampleLevelMetric
from lighteval.tasks.requests import Doc
from lighteval.utils.language import Language
from lighteval.utils.timeout import timeout


logger = logging.getLogger(__name__)


def loglikelihood_acc_metric(normalization: LogProbNormalization | None = None) -> SampleLevelMetric:
Expand Down Expand Up @@ -168,3 +183,94 @@ def multilingual_quasi_exact_match_metric(
corpus_level_fn=np.mean,
higher_is_better=True,
)


def multilingual_extractive_match_metric(
language: Language = Language.ENGLISH,
gold_extraction_target: Sequence[ExtractionTarget] = (ExprExtractionConfig(),),
pred_extraction_target: Sequence[ExtractionTarget] = (ExprExtractionConfig(),),
aggregation_function: Callable[[list[float]], float] = max,
fallback_mode: Literal["no_fallback", "first_match"] = "first_match",
precision: int = 6,
) -> SampleLevelMetric:
"""Creates a language-aware extractive match metric that extracts answers from the model's output.
Known issues:
- If the task is to simplify an expression, the metric might overestimate the accuracy. This is because if the model doesn't output any anchor for the extraction (e.g final answer is..),
it's possible that the the extracted prediction will be the expression to simplify. Because we do simplifications ourselves, it can thus happen that sympy will correctly simplify the expression,
thus it will match gold, despite model not doing anything. PRs to fix this are welcome.
- There is currently no StringExtractionConfig, so if the gold is \boxed{\text{Friday}} and model outputs Friday it will not match, because nothing will be extracted.
Args:
language: Language
The language of the samples.
gold_extraction_target: Sequence[ExtractionTarget]
Extraction targets to use for gold answers. Defaults to extracting simple math expressions.
pred_extraction_target: Sequence[ExtractionTarget]
Extraction targets to use for predictions. Defaults to extracting simple math expressions.
aggregation_function: Callable[[list[float]], float]
Function to aggregate scores when multiple golds/predictions are present. Defaults to max.
fallback_mode: Literal["no_fallback", "first_match"]
How to perform extraction. Defaults to "first_match".
- "no_fallback": Only use first successfully parsed matches
- "first_match": Use the first successfully parsed match + first match irregardless the parsing success
precision: int
Number of decimal places to use when comparing numerical values. Defaults to 6.
Returns:
A sample level metric that extracts and compares mathematical expressions.
"""

@timeout(2)
def add_to_specifics_with_timeout(
formatted_doc: Doc, extracted_predictions: list[list[str]], extracted_golds: list[list[str]]
) -> None:
if formatted_doc.specific is None:
formatted_doc.specific = {}

formatted_doc.specific["extracted_predictions"] = [
str(pred) for preds in extracted_predictions for pred in preds
]
formatted_doc.specific["extracted_golds"] = [str(gold) for golds in extracted_golds for gold in golds]

def sample_level_fn(golds: list[str], predictions: list[str], formatted_doc: Doc) -> float:
gold_extraction_regexes = get_extraction_regexes(formatted_doc, gold_extraction_target, language)
pred_extraction_regexes = get_extraction_regexes(formatted_doc, pred_extraction_target, language)

extracted_predictions = [
extract_target_from_pred(pred, pred_extraction_regexes, fallback_mode) for pred in predictions
]
extracted_golds = [extract_target_from_pred(gold, gold_extraction_regexes, fallback_mode) for gold in golds]

# Assert on empty gold and warn on empty pred
if any(len(g) == 0 for g in extracted_golds):
raise ValueError(f"No gold targets found for at least one gold. Gold: {golds}, Pred: {predictions}")

if all(len(p) == 0 for p in extracted_predictions):
logger.warning(
f"We did not manage to extract a prediction in the correct format. Gold: {golds}, Pred: {predictions}"
)

# We have to use timeout because the sypmy to str conversion can be very slow
try:
add_to_specifics_with_timeout(formatted_doc, extracted_predictions, extracted_golds)
except: # noqa: E722
logger.warning("Timeout when adding extracted predictions and golds to specific")

return aggregation_function(
[
(1.0 if any(compare_gold_target(gold, pred, precision) for gold in extracted_golds) else 0.0)
for pred in extracted_predictions
]
)

return SampleLevelMetric(
metric_name="extractive_match",
sample_level_fn=sample_level_fn,
category=MetricCategory.GENERATIVE,
use_case=MetricUseCase.ACCURACY,
corpus_level_fn=np.mean,
higher_is_better=True,
)
Loading

0 comments on commit 59624c8

Please sign in to comment.