diff --git a/src/lighteval/metrics/dynamic_metrics.py b/src/lighteval/metrics/dynamic_metrics.py
index 577934e9..b827df1e 100644
--- a/src/lighteval/metrics/dynamic_metrics.py
+++ b/src/lighteval/metrics/dynamic_metrics.py
@@ -188,7 +188,7 @@ def multilingual_quasi_exact_match_metric(
 def multilingual_extractive_match_metric(
     language: Language = Language.ENGLISH,
     gold_extraction_target: Sequence[ExtractionTarget] = (ExprExtractionConfig(),),
-    pred_extraction_target: Sequence[ExtractionTarget] = (ExprExtractionConfig(),),
+    pred_extraction_target: Sequence[ExtractionTarget] = (ExprExtractionConfig(), LatexExtractionConfig()),
     aggregation_function: Callable[[list[float]], float] = max,
     fallback_mode: Literal["no_fallback", "first_match"] = "first_match",
     precision: int = 6,
diff --git a/src/lighteval/metrics/metrics.py b/src/lighteval/metrics/metrics.py
index ff4b6b05..58224481 100644
--- a/src/lighteval/metrics/metrics.py
+++ b/src/lighteval/metrics/metrics.py
@@ -24,6 +24,7 @@
 import numpy as np
 from aenum import Enum
 
+from lighteval.metrics.dynamic_metrics import multilingual_extractive_match_metric
 from lighteval.metrics.harness_compatibility.drop import drop_metrics
 from lighteval.metrics.harness_compatibility.truthful_qa import truthfulqa_mc_metrics
 from lighteval.metrics.metrics_corpus import (
@@ -58,7 +59,15 @@
     remove_braces,
     remove_braces_and_strip,
 )
-from lighteval.metrics.sample_preparator import GenerativePreparator, LoglikelihoodPreparator, PerplexityPreparator
+from lighteval.metrics.sample_preparator import (
+    GenerativePreparator,
+    LoglikelihoodPreparator,
+    PerplexityPreparator,
+)
+from lighteval.metrics.utils.extractive_match_utils import (
+    ExprExtractionConfig,
+    LatexExtractionConfig,
+)
 from lighteval.metrics.utils.metric_utils import (
     CorpusLevelMetric,
     CorpusLevelMetricGrouping,
@@ -69,6 +78,7 @@
     SampleLevelMetric,
     SampleLevelMetricGrouping,
 )
+from lighteval.utils.language import Language
 from lighteval.utils.utils import as_list
 
 
@@ -86,8 +96,16 @@ class Metrics(Enum):
         sample_level_fn=BertScore(normalize_gold=remove_braces, normalize_pred=remove_braces_and_strip).compute,
         category=MetricCategory.GENERATIVE,
         use_case=MetricUseCase.SUMMARIZATION,
-        corpus_level_fn={"BERTScore-P": np.mean, "BERTScore-R": np.mean, "BERTScore-F": np.mean},
-        higher_is_better={"BERTScore-P": True, "BERTScore-R": True, "BERTScore-F": True},
+        corpus_level_fn={
+            "BERTScore-P": np.mean,
+            "BERTScore-R": np.mean,
+            "BERTScore-F": np.mean,
+        },
+        higher_is_better={
+            "BERTScore-P": True,
+            "BERTScore-R": True,
+            "BERTScore-F": True,
+        },
     )
     bits_per_byte = CorpusLevelMetric(
         metric_name="bits_per_byte",
@@ -147,14 +165,31 @@ class Metrics(Enum):
         higher_is_better=True,
     )
     copyright = SampleLevelMetricGrouping(
-        metric_name=["longest_common_prefix_length", "edit_distance", "edit_similarity"],
+        metric_name=[
+            "longest_common_prefix_length",
+            "edit_distance",
+            "edit_similarity",
+        ],
         sample_level_fn=StringDistance(
-            metric_types=["longest_common_prefix_length", "edit_distance", "edit_similarity"], strip_prediction=True
+            metric_types=[
+                "longest_common_prefix_length",
+                "edit_distance",
+                "edit_similarity",
+            ],
+            strip_prediction=True,
         ).compute,
         category=MetricCategory.GENERATIVE,
         use_case=MetricUseCase.SOCIAL_IMPACTS,
-        corpus_level_fn={"longest_common_prefix_length": max, "edit_distance": min, "edit_similarity": max},
-        higher_is_better={"longest_common_prefix_length": True, "edit_distance": False, "edit_similarity": True},
+        corpus_level_fn={
+            "longest_common_prefix_length": max,
+            "edit_distance": min,
+            "edit_similarity": max,
+        },
+        higher_is_better={
+            "longest_common_prefix_length": True,
+            "edit_distance": False,
+            "edit_similarity": True,
+        },
     )
     drop = SampleLevelMetricGrouping(
         metric_name=["qem", "f1"],
@@ -173,9 +208,15 @@ class Metrics(Enum):
         higher_is_better=True,
     )
     extractiveness = SampleLevelMetricGrouping(
-        metric_name=["summarization_coverage", "summarization_density", "summarization_compression"],
+        metric_name=[
+            "summarization_coverage",
+            "summarization_density",
+            "summarization_compression",
+        ],
         sample_level_fn=Extractiveness(
-            normalize_input=remove_braces, normalize_pred=remove_braces_and_strip, input_column="text"
+            normalize_input=remove_braces,
+            normalize_pred=remove_braces_and_strip,
+            input_column="text",
         ).compute,
         category=MetricCategory.GENERATIVE,
         use_case=MetricUseCase.SUMMARIZATION,
@@ -225,7 +266,9 @@ class Metrics(Enum):
     faithfulness = SampleLevelMetric(
         metric_name="summac",
         sample_level_fn=Faithfulness(
-            normalize_input=remove_braces, normalize_pred=remove_braces_and_strip, input_column="text"
+            normalize_input=remove_braces,
+            normalize_pred=remove_braces_and_strip,
+            input_column="text",
         ).compute,
         category=MetricCategory.GENERATIVE,
         use_case=MetricUseCase.SUMMARIZATION,
@@ -307,7 +350,10 @@ class Metrics(Enum):
     maj_at_4_math = SampleLevelMetric(
         metric_name="maj@4",
         sample_level_fn=MajAtK(
-            k=4, strip_strings=True, normalize_pred=math_normalizer, normalize_gold=math_normalizer
+            k=4,
+            strip_strings=True,
+            normalize_pred=math_normalizer,
+            normalize_gold=math_normalizer,
         ).compute,
         category=MetricCategory.GENERATIVE_SAMPLING,
         use_case=MetricUseCase.MATH,
@@ -333,7 +379,10 @@ class Metrics(Enum):
     maj_at_8_gsm8k = SampleLevelMetric(
         metric_name="maj@8",
         sample_level_fn=MajAtK(
-            k=8, strip_strings=True, normalize_pred=gsm8k_normalizer, normalize_gold=gsm8k_normalizer
+            k=8,
+            strip_strings=True,
+            normalize_pred=gsm8k_normalizer,
+            normalize_gold=gsm8k_normalizer,
         ).compute,
         category=MetricCategory.GENERATIVE_SAMPLING,
         use_case=MetricUseCase.MATH,
@@ -415,7 +464,9 @@ class Metrics(Enum):
     quasi_exact_match_math = SampleLevelMetric(
         metric_name="qem",
         sample_level_fn=ExactMatches(
-            strip_strings=True, normalize_pred=math_normalizer, normalize_gold=math_normalizer
+            strip_strings=True,
+            normalize_pred=math_normalizer,
+            normalize_gold=math_normalizer,
         ).compute,
         category=MetricCategory.GENERATIVE,
         use_case=MetricUseCase.MATH,
@@ -433,7 +484,9 @@ class Metrics(Enum):
     quasi_exact_match_gsm8k = SampleLevelMetric(
         metric_name="qem",
         sample_level_fn=ExactMatches(
-            strip_strings=True, normalize_pred=gsm8k_normalizer, normalize_gold=gsm8k_normalizer
+            strip_strings=True,
+            normalize_pred=gsm8k_normalizer,
+            normalize_gold=gsm8k_normalizer,
         ).compute,
         category=MetricCategory.GENERATIVE,
         use_case=MetricUseCase.MATH,
@@ -482,8 +535,18 @@ class Metrics(Enum):
         ).compute,
         category=MetricCategory.GENERATIVE,
         use_case=MetricUseCase.ACCURACY,
-        corpus_level_fn={"rouge1": np.mean, "rouge2": np.mean, "rougeL": np.mean, "rougeLsum": np.mean},
-        higher_is_better={"rouge1": True, "rouge2": True, "rougeL": True, "rougeLsum": True},
+        corpus_level_fn={
+            "rouge1": np.mean,
+            "rouge2": np.mean,
+            "rougeL": np.mean,
+            "rougeLsum": np.mean,
+        },
+        higher_is_better={
+            "rouge1": True,
+            "rouge2": True,
+            "rougeL": True,
+            "rougeLsum": True,
+        },
     )
     rouge1 = SampleLevelMetric(
         metric_name="rouge1",
@@ -541,6 +604,16 @@ class Metrics(Enum):
         corpus_level_fn={"truthfulqa_mc1": np.mean, "truthfulqa_mc2": np.mean},
         higher_is_better={"truthfulqa_mc1": True, "truthfulqa_mc2": True},
     )
+    math_gold_as_latex_verifier = multilingual_extractive_match_metric(
+        Language.ENGLISH,
+        gold_extraction_target=[LatexExtractionConfig()],
+        fallback_mode="first_match",
+    )
+    math_gold_as_expr_verifier = multilingual_extractive_match_metric(
+        Language.ENGLISH,
+        gold_extraction_target=[ExprExtractionConfig()],
+        fallback_mode="first_match",
+    )
     word_perplexity = CorpusLevelMetric(
         metric_name="word_perplexity",
         sample_level_fn=PerplexityPreparator(units_type="words").prepare,
diff --git a/src/lighteval/tasks/default_tasks.py b/src/lighteval/tasks/default_tasks.py
index d6a7ec49..595b79ba 100644
--- a/src/lighteval/tasks/default_tasks.py
+++ b/src/lighteval/tasks/default_tasks.py
@@ -22,6 +22,8 @@
 import lighteval.tasks.default_prompts as prompt
 from lighteval.metrics.metrics import Metrics
 from lighteval.tasks.lighteval_task import LightevalTaskConfig
+from lighteval.tasks.templates.qa import get_qa_prompt_function
+from lighteval.utils.language import Language
 
 
 abstract_narrative_understanding_bigbench = LightevalTaskConfig(
@@ -5503,6 +5505,30 @@
     trust_dataset=True,
     version=0,
 )
+
+natural_questions = LightevalTaskConfig(
+    name="natural_questions",
+    prompt_function=get_qa_prompt_function(
+        Language.ENGLISH,
+        lambda line: {
+            "question": line["question"],
+            "context": line["context"],
+            "choices": [ans for ans in line["answers"]["text"] if len(ans) > 0],
+        },
+    ),
+    suite=("lighteval",),
+    hf_repo="lighteval/SimpleQA",
+    hf_subset="default",
+    evaluation_splits=("validation",),
+    few_shots_split="train",
+    generation_size=200,
+    stop_sequence=("\n",),
+    metric=(
+        Metrics.prefix_quasi_exact_match,
+        Metrics.f1_score_quasi,
+    ),
+)
+
 blimp_wh_questions_subject_gap_helm = LightevalTaskConfig(
     name="blimp:wh_questions_subject_gap",
     suite=["helm", "blimp"],
@@ -6557,6 +6583,30 @@
     trust_dataset=True,
     version=0,
 )
+
+coqa_first_question = LightevalTaskConfig(
+    name="coqa_first_question",
+    prompt_function=get_qa_prompt_function(
+        Language.ENGLISH,
+        lambda line: {
+            "question": line["questions"][0],
+            "context": line["story"],
+            "choices": [line["answers"]["input_text"][0]],
+        },
+    ),
+    suite=("lighteval",),
+    hf_repo="stanfordnlp/coqa",
+    hf_subset="default",
+    hf_avail_splits=["train", "validation"],
+    evaluation_splits=["validation"],
+    generation_size=150,
+    stop_sequence=("\n",),
+    metric=(
+        Metrics.prefix_quasi_exact_match,
+        Metrics.f1_score_quasi,
+    ),
+)
+
 coqa_bb_lighteval = LightevalTaskConfig(
     name="coqa_bb",
     suite=["lighteval", "bigbench_programmatic", "bigbench"],
@@ -15180,6 +15230,74 @@
     trust_dataset=True,
     version=0,
 )
+squad_v2 = LightevalTaskConfig(
+    name="squad_v2",
+    prompt_function=get_qa_prompt_function(
+        Language.ENGLISH,
+        lambda line: {
+            "question": line["question"],
+            "context": line["context"],
+            "choices": [ans for ans in line["answers"]["text"] if len(ans) > 0],
+        },
+    ),
+    suite=("lighteval",),
+    hf_repo="rajpurkar/squad_v2",
+    hf_subset="default",
+    evaluation_splits=("validation",),
+    few_shots_split="train",
+    generation_size=200,
+    stop_sequence=("\n",),
+    metric=(
+        Metrics.prefix_quasi_exact_match,
+        Metrics.f1_score_quasi,
+    ),
+)
+
+jeopardy = LightevalTaskConfig(
+    name="jeopardy",
+    prompt_function=get_qa_prompt_function(
+        Language.ENGLISH,
+        lambda line: {
+            "question": line["question"],
+            "choices": [line["answer"]],
+        },
+    ),
+    suite=("lighteval",),
+    hf_repo="openaccess-ai-collective/jeopardy",
+    hf_subset="default",
+    evaluation_splits=("train",),
+    few_shots_split="train",
+    generation_size=50,
+    stop_sequence=("\n",),
+    metric=(
+        Metrics.prefix_quasi_exact_match,
+        Metrics.f1_score_quasi,
+    ),
+)
+
+simple_qa = LightevalTaskConfig(
+    name="simple_qa",
+    prompt_function=get_qa_prompt_function(
+        Language.ENGLISH,
+        lambda line: {
+            "question": line["question"],
+            "context": line["context"],
+            "choices": [ans for ans in line["answers"]["text"] if len(ans) > 0],
+        },
+    ),
+    suite=("lighteval",),
+    hf_repo="lighteval/SimpleQA",
+    hf_subset="default",
+    evaluation_splits=("test",),
+    few_shots_split="few_shot",
+    generation_size=250,
+    stop_sequence=("\n",),
+    metric=(
+        Metrics.prefix_quasi_exact_match,
+        Metrics.f1_score_quasi,
+    ),
+)
+
 swag_lighteval = LightevalTaskConfig(
     name="swag",
     suite=["lighteval"],
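
Note (not part of the patch): the two math_gold_as_*_verifier entries above are built by the multilingual_extractive_match_metric factory, so a task config can reference them exactly like the other Metrics members, and predictions now fall back from plain expressions to LaTeX because the default pred_extraction_target gains LatexExtractionConfig(). A minimal sketch of a downstream task wiring one of them up follows; the task name, dataset repo, and prompt function are hypothetical placeholders assumed for illustration, not part of this diff.

# Hypothetical usage sketch: a community task that scores free-form math answers
# with the new gold-as-LaTeX verifier metric.
import lighteval.tasks.default_prompts as prompt  # prompt.math is assumed to exist
from lighteval.metrics.metrics import Metrics
from lighteval.tasks.lighteval_task import LightevalTaskConfig

my_math_verifier_task = LightevalTaskConfig(
    name="my_math_verifier_task",          # hypothetical task name
    prompt_function=prompt.math,
    suite=("community",),
    hf_repo="my-org/my-math-dataset",      # hypothetical dataset repo
    hf_subset="default",
    evaluation_splits=("test",),
    generation_size=1024,
    stop_sequence=("\n",),
    metric=(
        # Golds are parsed as LaTeX; predictions are matched through the new default
        # (ExprExtractionConfig, LatexExtractionConfig) fallback chain.
        Metrics.math_gold_as_latex_verifier,
    ),
)

The math_gold_as_expr_verifier variant would be used the same way for datasets whose gold answers are plain numbers or expressions rather than LaTeX.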