add qa generative tasks + add some metrics
hynky1999 committed Jan 28, 2025
1 parent cb075a5 commit b6c50a2
Showing 3 changed files with 208 additions and 17 deletions.
2 changes: 1 addition & 1 deletion src/lighteval/metrics/dynamic_metrics.py
@@ -188,7 +188,7 @@ def multilingual_quasi_exact_match_metric(
def multilingual_extractive_match_metric(
language: Language = Language.ENGLISH,
gold_extraction_target: Sequence[ExtractionTarget] = (ExprExtractionConfig(),),
pred_extraction_target: Sequence[ExtractionTarget] = (ExprExtractionConfig(),),
pred_extraction_target: Sequence[ExtractionTarget] = (ExprExtractionConfig(), LatexExtractionConfig()),
aggregation_function: Callable[[list[float]], float] = max,
fallback_mode: Literal["no_fallback", "first_match"] = "first_match",
precision: int = 6,
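
This change widens the default prediction extraction targets of multilingual_extractive_match_metric: model outputs are now matched both as plain expressions and as LaTeX, while gold extraction still defaults to expressions only. A minimal sketch of calling the factory with the targets pinned explicitly, using only names added or shown in this commit (the surrounding task wiring is omitted):

from lighteval.metrics.dynamic_metrics import multilingual_extractive_match_metric
from lighteval.metrics.utils.extractive_match_utils import (
    ExprExtractionConfig,
    LatexExtractionConfig,
)
from lighteval.utils.language import Language

# Gold answers are parsed as plain expressions; predictions are tried both as
# expressions and as LaTeX, mirroring the new default above.
math_verifier_metric = multilingual_extractive_match_metric(
    language=Language.ENGLISH,
    gold_extraction_target=[ExprExtractionConfig()],
    pred_extraction_target=[ExprExtractionConfig(), LatexExtractionConfig()],
    fallback_mode="first_match",
    precision=6,
)
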
105 changes: 89 additions & 16 deletions src/lighteval/metrics/metrics.py
@@ -24,6 +24,7 @@
import numpy as np
from aenum import Enum

from lighteval.metrics.dynamic_metrics import multilingual_extractive_match_metric
from lighteval.metrics.harness_compatibility.drop import drop_metrics
from lighteval.metrics.harness_compatibility.truthful_qa import truthfulqa_mc_metrics
from lighteval.metrics.metrics_corpus import (
@@ -58,7 +59,15 @@
remove_braces,
remove_braces_and_strip,
)
from lighteval.metrics.sample_preparator import GenerativePreparator, LoglikelihoodPreparator, PerplexityPreparator
from lighteval.metrics.sample_preparator import (
GenerativePreparator,
LoglikelihoodPreparator,
PerplexityPreparator,
)
from lighteval.metrics.utils.extractive_match_utils import (
ExprExtractionConfig,
LatexExtractionConfig,
)
from lighteval.metrics.utils.metric_utils import (
CorpusLevelMetric,
CorpusLevelMetricGrouping,
@@ -69,6 +78,7 @@
SampleLevelMetric,
SampleLevelMetricGrouping,
)
from lighteval.utils.language import Language
from lighteval.utils.utils import as_list


@@ -86,8 +96,16 @@ class Metrics(Enum):
sample_level_fn=BertScore(normalize_gold=remove_braces, normalize_pred=remove_braces_and_strip).compute,
category=MetricCategory.GENERATIVE,
use_case=MetricUseCase.SUMMARIZATION,
corpus_level_fn={"BERTScore-P": np.mean, "BERTScore-R": np.mean, "BERTScore-F": np.mean},
higher_is_better={"BERTScore-P": True, "BERTScore-R": True, "BERTScore-F": True},
corpus_level_fn={
"BERTScore-P": np.mean,
"BERTScore-R": np.mean,
"BERTScore-F": np.mean,
},
higher_is_better={
"BERTScore-P": True,
"BERTScore-R": True,
"BERTScore-F": True,
},
)
bits_per_byte = CorpusLevelMetric(
metric_name="bits_per_byte",
@@ -147,14 +165,31 @@ class Metrics(Enum):
higher_is_better=True,
)
copyright = SampleLevelMetricGrouping(
metric_name=["longest_common_prefix_length", "edit_distance", "edit_similarity"],
metric_name=[
"longest_common_prefix_length",
"edit_distance",
"edit_similarity",
],
sample_level_fn=StringDistance(
metric_types=["longest_common_prefix_length", "edit_distance", "edit_similarity"], strip_prediction=True
metric_types=[
"longest_common_prefix_length",
"edit_distance",
"edit_similarity",
],
strip_prediction=True,
).compute,
category=MetricCategory.GENERATIVE,
use_case=MetricUseCase.SOCIAL_IMPACTS,
corpus_level_fn={"longest_common_prefix_length": max, "edit_distance": min, "edit_similarity": max},
higher_is_better={"longest_common_prefix_length": True, "edit_distance": False, "edit_similarity": True},
corpus_level_fn={
"longest_common_prefix_length": max,
"edit_distance": min,
"edit_similarity": max,
},
higher_is_better={
"longest_common_prefix_length": True,
"edit_distance": False,
"edit_similarity": True,
},
)
drop = SampleLevelMetricGrouping(
metric_name=["qem", "f1"],
@@ -173,9 +208,15 @@ class Metrics(Enum):
higher_is_better=True,
)
extractiveness = SampleLevelMetricGrouping(
metric_name=["summarization_coverage", "summarization_density", "summarization_compression"],
metric_name=[
"summarization_coverage",
"summarization_density",
"summarization_compression",
],
sample_level_fn=Extractiveness(
normalize_input=remove_braces, normalize_pred=remove_braces_and_strip, input_column="text"
normalize_input=remove_braces,
normalize_pred=remove_braces_and_strip,
input_column="text",
).compute,
category=MetricCategory.GENERATIVE,
use_case=MetricUseCase.SUMMARIZATION,
@@ -225,7 +266,9 @@ class Metrics(Enum):
faithfulness = SampleLevelMetric(
metric_name="summac",
sample_level_fn=Faithfulness(
normalize_input=remove_braces, normalize_pred=remove_braces_and_strip, input_column="text"
normalize_input=remove_braces,
normalize_pred=remove_braces_and_strip,
input_column="text",
).compute,
category=MetricCategory.GENERATIVE,
use_case=MetricUseCase.SUMMARIZATION,
@@ -307,7 +350,10 @@ class Metrics(Enum):
maj_at_4_math = SampleLevelMetric(
metric_name="maj@4",
sample_level_fn=MajAtK(
k=4, strip_strings=True, normalize_pred=math_normalizer, normalize_gold=math_normalizer
k=4,
strip_strings=True,
normalize_pred=math_normalizer,
normalize_gold=math_normalizer,
).compute,
category=MetricCategory.GENERATIVE_SAMPLING,
use_case=MetricUseCase.MATH,
@@ -333,7 +379,10 @@ class Metrics(Enum):
maj_at_8_gsm8k = SampleLevelMetric(
metric_name="maj@8",
sample_level_fn=MajAtK(
k=8, strip_strings=True, normalize_pred=gsm8k_normalizer, normalize_gold=gsm8k_normalizer
k=8,
strip_strings=True,
normalize_pred=gsm8k_normalizer,
normalize_gold=gsm8k_normalizer,
).compute,
category=MetricCategory.GENERATIVE_SAMPLING,
use_case=MetricUseCase.MATH,
@@ -415,7 +464,9 @@ class Metrics(Enum):
quasi_exact_match_math = SampleLevelMetric(
metric_name="qem",
sample_level_fn=ExactMatches(
strip_strings=True, normalize_pred=math_normalizer, normalize_gold=math_normalizer
strip_strings=True,
normalize_pred=math_normalizer,
normalize_gold=math_normalizer,
).compute,
category=MetricCategory.GENERATIVE,
use_case=MetricUseCase.MATH,
@@ -433,7 +484,9 @@ class Metrics(Enum):
quasi_exact_match_gsm8k = SampleLevelMetric(
metric_name="qem",
sample_level_fn=ExactMatches(
strip_strings=True, normalize_pred=gsm8k_normalizer, normalize_gold=gsm8k_normalizer
strip_strings=True,
normalize_pred=gsm8k_normalizer,
normalize_gold=gsm8k_normalizer,
).compute,
category=MetricCategory.GENERATIVE,
use_case=MetricUseCase.MATH,
@@ -482,8 +535,18 @@ class Metrics(Enum):
).compute,
category=MetricCategory.GENERATIVE,
use_case=MetricUseCase.ACCURACY,
corpus_level_fn={"rouge1": np.mean, "rouge2": np.mean, "rougeL": np.mean, "rougeLsum": np.mean},
higher_is_better={"rouge1": True, "rouge2": True, "rougeL": True, "rougeLsum": True},
corpus_level_fn={
"rouge1": np.mean,
"rouge2": np.mean,
"rougeL": np.mean,
"rougeLsum": np.mean,
},
higher_is_better={
"rouge1": True,
"rouge2": True,
"rougeL": True,
"rougeLsum": True,
},
)
rouge1 = SampleLevelMetric(
metric_name="rouge1",
@@ -541,6 +604,16 @@ class Metrics(Enum):
corpus_level_fn={"truthfulqa_mc1": np.mean, "truthfulqa_mc2": np.mean},
higher_is_better={"truthfulqa_mc1": True, "truthfulqa_mc2": True},
)
math_gold_as_latex_verifier = multilingual_extractive_match_metric(
Language.ENGLISH,
gold_extraction_target=[LatexExtractionConfig()],
fallback_mode="first_match",
)
math_gold_as_expr_verifier = multilingual_extractive_match_metric(
Language.ENGLISH,
gold_extraction_target=[ExprExtractionConfig()],
fallback_mode="first_match",
)
word_perplexity = CorpusLevelMetric(
metric_name="word_perplexity",
sample_level_fn=PerplexityPreparator(units_type="words").prepare,
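
The two new enum members, math_gold_as_latex_verifier and math_gold_as_expr_verifier, expose this extractive verifier with the gold answer parsed either as LaTeX or as a plain expression. A task definition can reference them like any other Metrics member; the sketch below is hypothetical (task name, suite, dataset repo, and column names are placeholders, not part of this commit):

from lighteval.metrics.metrics import Metrics
from lighteval.tasks.lighteval_task import LightevalTaskConfig
from lighteval.tasks.templates.qa import get_qa_prompt_function
from lighteval.utils.language import Language

my_math_word_problems = LightevalTaskConfig(
    name="my_math_word_problems",  # hypothetical task
    prompt_function=get_qa_prompt_function(
        Language.ENGLISH,
        # Placeholder columns: "problem" holds the question, "solution" the gold answer.
        lambda line: {"question": line["problem"], "choices": [line["solution"]]},
    ),
    suite=("community",),  # placeholder suite
    hf_repo="my-org/math-word-problems",  # placeholder dataset
    hf_subset="default",
    evaluation_splits=("test",),
    few_shots_split="train",
    generation_size=512,
    stop_sequence=("\n",),
    metric=(
        Metrics.math_gold_as_expr_verifier,  # gold parsed as a plain expression
        Metrics.math_gold_as_latex_verifier,  # gold parsed as LaTeX
    ),
)
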
118 changes: 118 additions & 0 deletions src/lighteval/tasks/default_tasks.py
@@ -22,6 +22,8 @@
import lighteval.tasks.default_prompts as prompt
from lighteval.metrics.metrics import Metrics
from lighteval.tasks.lighteval_task import LightevalTaskConfig
from lighteval.tasks.templates.qa import get_qa_prompt_function
from lighteval.utils.language import Language


abstract_narrative_understanding_bigbench = LightevalTaskConfig(
@@ -5503,6 +5505,30 @@
trust_dataset=True,
version=0,
)

natural_questions = LightevalTaskConfig(
name="natural_questions",
prompt_function=get_qa_prompt_function(
Language.ENGLISH,
lambda line: {
"question": line["question"],
"context": line["context"],
"choices": [ans for ans in line["answers"]["text"] if len(ans) > 0],
},
),
suite=("lighteval",),
hf_repo="lighteval/SimpleQA",
hf_subset="default",
evaluation_splits=("validation",),
few_shots_split="train",
generation_size=200,
stop_sequence=("\n",),
metric=(
Metrics.prefix_quasi_exact_match,
Metrics.f1_score_quasi,
),
)

blimp_wh_questions_subject_gap_helm = LightevalTaskConfig(
name="blimp:wh_questions_subject_gap",
suite=["helm", "blimp"],
@@ -6557,6 +6583,30 @@
trust_dataset=True,
version=0,
)

coqa_first_question = LightevalTaskConfig(
name="coqa_first_question",
prompt_function=get_qa_prompt_function(
Language.ENGLISH,
lambda line: {
"question": line["question"][0],
"context": line["story"],
"choices": [line["answers"]["input_text"][0]],
},
),
suite=("lighteval",),
hf_repo="stadford/coqa",
hf_subset="default",
hf_avail_splits=["train", "validation"],
evaluation_splits=["validation"],
generation_size=150,
stop_sequence=("\n",),
metric=(
Metrics.prefix_quasi_exact_match,
Metrics.f1_score_quasi,
),
)

coqa_bb_lighteval = LightevalTaskConfig(
name="coqa_bb",
suite=["lighteval", "bigbench_programmatic", "bigbench"],
@@ -15180,6 +15230,74 @@
trust_dataset=True,
version=0,
)
squad_v2 = LightevalTaskConfig(
name="squad_v2",
prompt_function=get_qa_prompt_function(
Language.ENGLISH,
lambda line: {
"question": line["question"],
"context": line["context"],
"choices": [ans for ans in line["answers"]["text"] if len(ans) > 0],
},
),
suite=("lighteval",),
hf_repo="rajpurkar/squad_v2",
hf_subset="default",
evaluation_splits=("validation",),
few_shots_split="train",
generation_size=200,
stop_sequence=("\n",),
metric=(
Metrics.prefix_quasi_exact_match,
Metrics.f1_score_quasi,
),
)

jeopardy = LightevalTaskConfig(
name="jeopardy",
prompt_function=get_qa_prompt_function(
Language.ENGLISH,
lambda line: {
"question": line["question"],
"choices": [line["answer"]],
},
),
suite=("lighteval",),
hf_repo="openaccess-ai-collective/jeopardy",
hf_subset="default",
evaluation_splits=("train",),
few_shots_split="train",
generation_size=50,
stop_sequence=("\n",),
metric=(
Metrics.prefix_quasi_exact_match,
Metrics.f1_score_quasi,
),
)

simple_qa = LightevalTaskConfig(
name="squad_v2",
prompt_function=get_qa_prompt_function(
Language.ENGLISH,
lambda line: {
"question": line["question"],
"context": line["context"],
"choices": [ans for ans in line["answers"]["text"] if len(ans) > 0],
},
),
suite=("lighteval",),
hf_repo="lighteval/SimpleQA",
hf_subset="default",
evaluation_splits=("test",),
few_shots_split="few_shot",
generation_size=250,
stop_sequence=("\n",),
metric=(
Metrics.prefix_quasi_exact_match,
Metrics.f1_score_quasi,
),
)

swag_lighteval = LightevalTaskConfig(
name="swag",
suite=["lighteval"],
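
All of the new QA tasks above share one pattern: get_qa_prompt_function takes a target language plus an adapter that maps a dataset row to a question, an optional context, and a list of gold choices, and the SQuAD-style adapters drop empty answer strings so unanswerable questions do not contribute empty golds. A small sketch of what such an adapter produces on a SQuAD-shaped row (the row contents are invented for illustration):

# Illustrative only: mirrors the squad_v2 adapter above on a made-up row.
row = {
    "question": "Where was Ada Lovelace born?",
    "context": "Ada Lovelace was born in London in 1815.",
    "answers": {"text": ["London", ""], "answer_start": [25, -1]},
}

adapted = {
    "question": row["question"],
    "context": row["context"],
    # Empty strings are dropped, as in the list comprehension used by the configs.
    "choices": [ans for ans in row["answers"]["text"] if len(ans) > 0],
}
assert adapted["choices"] == ["London"]
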
