
Commit ca66367

Merge branch 'main' into standalon_nanotron_config
2 parents: 55ce4a9 + 8c787df

17 files changed (+939, -92 lines)

src/lighteval/logging/info_loggers.py

Lines changed: 5 additions & 0 deletions
@@ -374,6 +374,11 @@ def log(
             detail.choices = doc.choices
             detail.gold_index = as_list(doc.gold_index)
             pred_saved = True
+        if task.has_metric_category[MetricCategory.MULTICHOICE_PMI]:
+            detail.choices = doc.choices
+            detail.gold_index = as_list(doc.gold_index)
+            doc.specific = {**(doc.specific or {}), **{"unconditioned_query": doc.unconditioned_query}}
+            pred_saved = True
         if (
             task.has_metric_category[MetricCategory.LLM_AS_JUDGE_MULTI_TURN]
             or task.has_metric_category[MetricCategory.LLM_AS_JUDGE]
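
For PMI metrics the logger now also persists the unconditioned query through `doc.specific`. As an illustration of what the two prompts in a PMI setup typically look like (the example strings below are made up; only the `unconditioned_query` field name comes from this diff):

    # Illustrative PMI document: each choice is scored twice, once conditioned on the
    # full query and once on a context-free "unconditioned" prompt, so the two
    # logprobs can later be subtracted.
    query = "Question: What is the capital of France?\nAnswer:"
    unconditioned_query = "Answer:"
    choices = [" Paris", " London", " Berlin"]
    specific = {"unconditioned_query": unconditioned_query}  # what gets merged into doc.specific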

src/lighteval/metrics/__init__.py

Lines changed: 49 additions & 13 deletions
@@ -30,16 +30,22 @@
 
 def apply_target_perplexity_metric(results: list[ModelResponse], formatted_doc: Doc, metrics: list[Metric]):
     outputs = {}
-    # We only consider the best choice, to check if its logprobs are above 0.5
-    results = results[formatted_doc.gold_index]
-    target_logprob = results.result[0]
-    target_acc = results.result[1]
-    reference_text = formatted_doc.get_golds()[0]
+
+    target_golds = formatted_doc.get_golds()
+    assert len(results) == len(target_golds), "You should return as many results as there are golds"
+    target_logprobs = [res.result[0] for res in results]
+    argmax_logits_eq_gold_list = [res.result[1] for res in results]
+    target_tokens = [res.generated_tokens for res in results]
 
     for metric in metrics:
         if metric.category == MetricCategory.TARGET_PERPLEXITY:
             outputs.update(
-                metric.compute(logprobs=target_logprob, target_acc=target_acc, reference_text=reference_text)
+                metric.compute(
+                    logprobs=target_logprobs,
+                    argmax_logits_eq_gold_list=argmax_logits_eq_gold_list,
+                    reference_texts=target_golds,
+                    target_tokens=target_tokens,
+                )
             )
 
     return outputs
@@ -61,7 +67,7 @@ def apply_perplexity_metric(results: list[ModelResponse], formatted_doc: Doc, metrics: list[Metric]):
 
     for metric in metrics:
         if metric.category == MetricCategory.PERPLEXITY:
-            outputs.update(metric.compute(logprobs=results.result, reference_text=reference_text))
+            outputs.update(metric.compute(logprobs=[results.result], reference_texts=[reference_text]))
 
     return outputs
 
@@ -124,23 +130,44 @@ def apply_generative_metric(
 
 def apply_multichoice_metric(results: list[ModelResponse], formatted_doc: Doc, metrics: list[Metric]):
     outputs = {}
-    if len(formatted_doc.choices) <= 1:
+    n_choices = len(formatted_doc.choices)
+    is_pmi_category = all(metric.category == MetricCategory.MULTICHOICE_PMI for metric in metrics)
+
+    if n_choices <= 1:
         raise ValueError(
             "You can't use a multi choice metric with only one choice. Use `acc_golds_likelihood` instead."
         )
-    if len(results) != len(formatted_doc.choices):
+
+    if not is_pmi_category and len(results) != len(formatted_doc.choices):
         raise Exception(
             f"You shoud have returned as many model outputs as choices when using an multi choice metric. Returned {len(results)} instead of {len(formatted_doc.choices)}"
         )
 
+    if is_pmi_category and len(results) != n_choices * 2:
+        raise Exception(
+            f"You shoud have returned twice as many model outputs as choices when using an probability multi choice metric. Returned {len(results)} instead of {n_choices * 2} (conditioned and unconditioned)"
+        )
+
+    mc_results = results[:n_choices]
     # Todo: make better system with return_bool_score instead of taking first element
-    choices_logprob = [results[i].result[0] for i in range(len(formatted_doc.choices))]
+    conditioned_lp = [res.result[0] for res in mc_results]
+    unconditioned_lp = None
+    if is_pmi_category:
+        unconditioned_lp = [res.result[0] for res in results[n_choices : n_choices * 2]]
+
     gold_ixs = as_list(formatted_doc.gold_index)
+    choices_tokens = [res.generated_tokens for res in mc_results]
 
     for metric in metrics:
-        if metric.category == MetricCategory.MULTICHOICE:
+        if metric.category == MetricCategory.MULTICHOICE_PMI or metric.category == MetricCategory.MULTICHOICE:
             outputs.update(
-                metric.compute(choices_logprob=choices_logprob, gold_ixs=gold_ixs, formatted_doc=formatted_doc)
+                metric.compute(
+                    gold_ixs=gold_ixs,
+                    choices_logprob=conditioned_lp,
+                    unconditioned_logprob=unconditioned_lp,
+                    choices_tokens=choices_tokens,
+                    formatted_doc=formatted_doc,
+                )
             )
     return outputs
 
@@ -151,12 +178,21 @@ def apply_multichoice_metric_one_token(results: list[ModelResponse], formatted_doc: Doc, metrics: list[Metric]):
         raise Exception("You returned more than one result for a sample with a gmultichoice metric on only one token.")
     results = results[0]
     choices_logprob = results.result
+    choices_texts = formatted_doc.choices
     gold_ixs = as_list(formatted_doc.gold_index)
 
     for metric in metrics:
         if metric.category == MetricCategory.MULTICHOICE_ONE_TOKEN:
             outputs.update(
-                metric.compute(choices_logprob=choices_logprob, gold_ixs=gold_ixs, formatted_doc=formatted_doc)
+                metric.compute(
+                    choices_logprob=choices_logprob,
+                    # Neither token or PMI are supported for this metric
+                    unconditioned_logprob=None,
+                    choices_tokens=None,
+                    choices_texts=choices_texts,
+                    gold_ixs=gold_ixs,
+                    formatted_doc=formatted_doc,
+                )
             )
 
     return outputs
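
For reference, a minimal sketch of how PMI scoring is usually applied downstream of the conditioned/unconditioned logprobs that `apply_multichoice_metric` now passes along (illustrative only; the real normalization lives in `lighteval.metrics.normalizations` and is not shown in this commit):

    import numpy as np

    def pmi_accuracy(conditioned_lp: list[float], unconditioned_lp: list[float], gold_ixs: list[int]) -> int:
        # Score each choice by logP(choice | context) - logP(choice | unconditioned prompt),
        # then check whether the argmax lands on a gold index.
        pmi_scores = np.array(conditioned_lp) - np.array(unconditioned_lp)
        return int(np.argmax(pmi_scores) in gold_ixs)

    # Choice 1 wins on PMI even though choice 0 has the highest raw conditioned logprob.
    pmi_accuracy([-1.0, -1.2, -3.0], [-0.2, -2.0, -2.5], gold_ixs=[1])  # -> 1
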
Lines changed: 94 additions & 0 deletions
@@ -0,0 +1,94 @@
+# MIT License
+
+# Copyright (c) 2024 The HuggingFace Team
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+from typing import Callable
+
+import numpy as np
+
+from lighteval.metrics.metrics_sample import LoglikelihoodAcc, NormalizedMultiChoiceProbability, Probability
+from lighteval.metrics.normalizations import LogProbNormalization, LogProbPMINorm, LogProbTokenNorm
+from lighteval.metrics.utils import MetricCategory, MetricUseCase, SampleLevelMetric
+
+
+def loglikelihood_acc_metric(normalization: LogProbNormalization | None = None) -> SampleLevelMetric:
+    """
+    Creates a accuracy (loglikelihood) metric, which returns accuracy given normalization.
+    """
+
+    normalization_str = normalization.name if normalization else ""
+    metric_name = f"acc_{normalization_str}"
+    return SampleLevelMetric(
+        metric_name=metric_name,
+        sample_level_fn=LoglikelihoodAcc(logprob_normalization=normalization).compute,
+        category=MetricCategory.MULTICHOICE
+        if not normalization == LogProbPMINorm()
+        else MetricCategory.MULTICHOICE_PMI,
+        use_case=MetricUseCase.ACCURACY,
+        corpus_level_fn=np.mean,
+        higher_is_better=True,
+    )
+
+
+def normalized_multi_choice_prob_metric(
+    normalization: LogProbNormalization | None = None,
+    aggregation_function: Callable[[np.ndarray], float] = np.max,
+) -> SampleLevelMetric:
+    """
+    Creates a normalized multi-choice probability metric, which returns the probability of the gold choice / sum of probabilities of all choices (after logprobs are normalized).
+    """
+
+    normalization_str = normalization.name if normalization else ""
+    metric_name = "_".join(filter(None, ["normalized_mc_prob_", normalization_str]))
+
+    return SampleLevelMetric(
+        metric_name=metric_name,
+        sample_level_fn=NormalizedMultiChoiceProbability(
+            log_prob_normalization=normalization, aggregation_function=aggregation_function
+        ).compute,
+        category=MetricCategory.MULTICHOICE
+        if not normalization == LogProbPMINorm()
+        else MetricCategory.MULTICHOICE_PMI,
+        use_case=MetricUseCase.ACCURACY,
+        corpus_level_fn=np.mean,
+        higher_is_better=True,
+    )
+
+
+def probability_metric(
+    normalization: LogProbTokenNorm | None = None,
+    aggregation_function: Callable[[np.ndarray], float] = np.max,
+) -> SampleLevelMetric:
+    """
+    Creates a probability metric, which returns the probability of the gold choice given normalization.
+    """
+
+    normalization_str = normalization.name if normalization else ""
+    metric_name = "_".join(filter(None, ["prob", normalization_str]))
+
+    return SampleLevelMetric(
+        metric_name=metric_name,
+        sample_level_fn=Probability(normalization=normalization, aggregation_function=aggregation_function).compute,
+        category=MetricCategory.TARGET_PERPLEXITY,
+        use_case=MetricUseCase.PERPLEXITY,
+        corpus_level_fn=np.mean,
+        higher_is_better=True,
+    )
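
A quick sketch of how these new factories could be used (hedged: the module path of this new file is not shown in this excerpt, and the normalization classes are assumed to be constructible without arguments, as the rest of the diff suggests):

    from lighteval.metrics.normalizations import LogProbPMINorm, LogProbTokenNorm

    # Plain loglikelihood accuracy vs. a PMI-normalized variant; the PMI variant
    # switches the category to MULTICHOICE_PMI so that unconditioned requests are
    # also built and scored.
    acc = loglikelihood_acc_metric()
    acc_pmi = loglikelihood_acc_metric(normalization=LogProbPMINorm())

    # Probability of the gold continuation, normalized per token, reported under
    # the TARGET_PERPLEXITY category.
    prob_token = probability_metric(normalization=LogProbTokenNorm())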

src/lighteval/metrics/harness_compatibility/truthful_qa.py

Lines changed: 10 additions & 2 deletions
@@ -22,9 +22,17 @@
 
 import numpy as np
 
+from lighteval.tasks.requests import Doc
+
 
 # Comes from the harness
-def truthfulqa_mc_metrics(gold_ixs, choices_logprob, formatted_doc):
+def truthfulqa_mc_metrics(
+    gold_ixs: list[int],
+    choices_logprob: list[float],
+    unconditioned_logprob: list[float] | None,
+    choices_tokens: list[list[int]] | None,
+    formatted_doc: Doc,
+):
     def mc1(lls):
         # The gold answers in `mc1_targets` are always first (index = `0`).
         return np.argmax(lls) == 0
@@ -47,7 +55,7 @@ def mc2(lls, split_idx):
            last_harness_gold = g
        else:
            break
-
+    # TODO: This completely ignores any normalization, but keeping it as is
    mc2_last_gold_ix = last_harness_gold - len_mc1 + 1
    mc1_lls, mc2_lls = choices_logprob[:len_mc1], choices_logprob[len_mc1:]
    return {"truthfulqa_mc1": mc1(mc1_lls), "truthfulqa_mc2": mc2(mc2_lls, mc2_last_gold_ix)}
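
For context, the harness scores this function reproduces can be sketched roughly as follows (an illustration of the usual mc1/mc2 definitions, not the exact lighteval code; `split_idx` is taken to be the number of true answers in `mc2_targets`):

    import numpy as np

    def mc1_sketch(lls: list[float]) -> float:
        # mc1: is the single gold answer (always index 0) the most likely choice?
        return float(np.argmax(lls) == 0)

    def mc2_sketch(lls: list[float], split_idx: int) -> float:
        # mc2: probability mass assigned to the true answers, normalized over all answers.
        probs = np.exp(np.array(lls))
        p_true, p_false = probs[:split_idx], probs[split_idx:]
        return float(p_true.sum() / (p_true.sum() + p_false.sum()))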

src/lighteval/metrics/metrics.py

Lines changed: 6 additions & 5 deletions
@@ -51,6 +51,7 @@
     faithfulness,
 )
 from lighteval.metrics.normalizations import (
+    LogProbCharNorm,
     bigbench_normalizer,
     gsm8k_normalizer,
     harness_triviaqa_normalizer,
@@ -288,39 +289,39 @@ class Metrics(Enum):
     )
     loglikelihood_acc = SampleLevelMetric(
         metric_name="acc",
-        sample_level_fn=LoglikelihoodAcc().compute,
+        sample_level_fn=LoglikelihoodAcc(logprob_normalization=None).compute,
         category=MetricCategory.MULTICHOICE,
         use_case=MetricUseCase.ACCURACY,
         corpus_level_fn=np.mean,
         higher_is_better=True,
     )
     loglikelihood_acc_norm = SampleLevelMetric(
         metric_name="acc_norm",
-        sample_level_fn=LoglikelihoodAcc(length_normalization=True).compute,
+        sample_level_fn=LoglikelihoodAcc(logprob_normalization=LogProbCharNorm()).compute,
         category=MetricCategory.MULTICHOICE,
         use_case=MetricUseCase.ACCURACY,
         corpus_level_fn=np.mean,
         higher_is_better=True,
     )
     loglikelihood_acc_norm_nospace = SampleLevelMetric(
         metric_name="acc_norm",
-        sample_level_fn=LoglikelihoodAcc(length_normalization=True, ignore_first_space=True).compute,
+        sample_level_fn=LoglikelihoodAcc(logprob_normalization=LogProbCharNorm(ignore_first_space=True)).compute,
         category=MetricCategory.MULTICHOICE,
         use_case=MetricUseCase.ACCURACY,
         corpus_level_fn=np.mean,
         higher_is_better=True,
     )
     loglikelihood_acc_norm_single_token = SampleLevelMetric(
         metric_name="acc_norm",
-        sample_level_fn=LoglikelihoodAcc(length_normalization=True).compute,
+        sample_level_fn=LoglikelihoodAcc(logprob_normalization=LogProbCharNorm()).compute,
         category=MetricCategory.MULTICHOICE_ONE_TOKEN,
         use_case=MetricUseCase.ACCURACY,
         corpus_level_fn=np.mean,
         higher_is_better=True,
     )
     loglikelihood_acc_single_token = SampleLevelMetric(
         metric_name="acc",
-        sample_level_fn=LoglikelihoodAcc().compute,
+        sample_level_fn=LoglikelihoodAcc(logprob_normalization=None).compute,
         category=MetricCategory.MULTICHOICE_ONE_TOKEN,
         use_case=MetricUseCase.ACCURACY,
         corpus_level_fn=np.mean,
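
The boolean `length_normalization`/`ignore_first_space` flags are replaced here by normalization objects. As a rough sketch of what character-length normalization conventionally does for `acc_norm` (illustrative only, assuming `LogProbCharNorm` follows the usual acc_norm convention; the real implementation lives in `lighteval.metrics.normalizations`):

    def char_normalized_logprob(summed_logprob: float, choice_text: str, ignore_first_space: bool = False) -> float:
        # acc_norm-style normalization: divide the choice's summed logprob by its
        # character length, optionally skipping a leading space.
        if ignore_first_space and choice_text.startswith(" "):
            choice_text = choice_text[1:]
        return summed_logprob / max(len(choice_text), 1)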
