
Commit 306ee76

Merge branch 'huggingface:main' into add_swiss_legal_evals
2 parents cb6bfb4 + be18ae5 commit 306ee76

File tree

14 files changed, +2124 -32 lines changed

README.md

Lines changed: 1 addition & 1 deletion
@@ -21,7 +21,7 @@

 ---

-**Documentation**: <a href="https://github.com/huggingface/lighteval/wiki" target="_blank">Lighteval's Wiki</a>
+**Documentation**: <a href="https://huggingface.co/docs/lighteval/index" target="_blank">Lighteval's Wiki</a>

 ---


pyproject.toml

Lines changed: 2 additions & 1 deletion
@@ -95,7 +95,7 @@ tensorboardX = ["tensorboardX"]
 vllm = ["vllm", "ray", "more_itertools"]
 quality = ["ruff==v0.2.2","pre-commit"]
 tests = ["pytest==7.4.0"]
-dev = ["lighteval[accelerate,quality,tests,multilingual]"]
+dev = ["lighteval[accelerate,quality,tests,multilingual,math]"]
 docs = ["hf-doc-builder", "watchdog"]
 extended_tasks = [
     "langdetect", # ifeval
@@ -111,6 +111,7 @@ multilingual = [
     "jieba", # for chinese tokenizer
     "pyvi", # for vietnamese tokenizer
 ]
+math = ["latex2sympy2_extended>=0.9.0"]

 [project.urls]
 Homepage = "https://github.com/huggingface/lighteval"
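
The new `math` extra (also folded into the `dev` extra) pulls in `latex2sympy2_extended`, which the extractive math metric added below depends on. A minimal sketch, not part of the commit, of how downstream code might check that the optional dependency is actually installed:

```python
# Minimal sketch (assumption: only the package name comes from the diff above).
import importlib.util

if importlib.util.find_spec("latex2sympy2_extended") is None:
    raise ImportError(
        "Missing optional dependency for math metrics; "
        "install it with `pip install lighteval[math]`."
    )
```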

src/lighteval/data.py

Lines changed: 9 additions & 2 deletions
@@ -25,8 +25,15 @@
 from typing import Iterator, Tuple

 import torch
+from packaging import version
 from torch.utils.data import Dataset
-from torch.utils.data.distributed import DistributedSampler, T_co
+
+
+if version.parse(torch.__version__) >= version.parse("2.5.0"):
+    from torch.utils.data.distributed import DistributedSampler, _T_co
+else:
+    from torch.utils.data.distributed import DistributedSampler
+    from torch.utils.data.distributed import T_co as _T_co

 from lighteval.tasks.requests import (
     GreedyUntilRequest,
@@ -318,7 +325,7 @@ class GenDistributedSampler(DistributedSampler):
         as our samples are sorted by length.
         """

-    def __iter__(self) -> Iterator[T_co]:
+    def __iter__(self) -> Iterator[_T_co]:
        if self.shuffle:
            # deterministically shuffle based on epoch and seed
            g = torch.Generator()
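
The conditional import keeps `GenDistributedSampler`'s annotation working on both sides of torch 2.5, which exposes the sampler's covariant type variable as `_T_co` instead of `T_co`. A small illustration, not from the commit, of why `packaging.version` is used for the gate instead of comparing version strings directly:

```python
# PEP 440 ordering handles multi-digit components and pre-releases, which plain
# string comparison gets wrong. Both asserts pass.
from packaging import version

assert version.parse("2.10.0") > version.parse("2.5.0")      # as strings, "2.10.0" < "2.5.0"
assert version.parse("2.5.0.dev0") < version.parse("2.5.0")  # dev builds sort before the release
```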

src/lighteval/metrics/dynamic_metrics.py

Lines changed: 107 additions & 1 deletion
@@ -20,7 +20,8 @@
 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 # SOFTWARE.

-from typing import Callable, Literal
+import logging
+from typing import Callable, Literal, Sequence

 import numpy as np

@@ -37,8 +38,22 @@
     LogProbTokenNorm,
     get_multilingual_normalizer,
 )
+from lighteval.metrics.utils.extractive_match_utils import ( # noqa: F401
+    ExprExtractionConfig,
+    ExtractionTarget,
+    IndicesExtractionConfig,
+    LatexExtractionConfig,
+    extract_target_from_pred,
+    get_extraction_regexes,
+)
+from lighteval.metrics.utils.math_comparison import compare_gold_target
 from lighteval.metrics.utils.metric_utils import MetricCategory, MetricUseCase, SampleLevelMetric
+from lighteval.tasks.requests import Doc
 from lighteval.utils.language import Language
+from lighteval.utils.timeout import timeout
+
+
+logger = logging.getLogger(__name__)


 def loglikelihood_acc_metric(normalization: LogProbNormalization | None = None) -> SampleLevelMetric:
@@ -168,3 +183,94 @@ def multilingual_quasi_exact_match_metric(
         corpus_level_fn=np.mean,
         higher_is_better=True,
     )
+
+
+def multilingual_extractive_match_metric(
+    language: Language = Language.ENGLISH,
+    gold_extraction_target: Sequence[ExtractionTarget] = (ExprExtractionConfig(),),
+    pred_extraction_target: Sequence[ExtractionTarget] = (ExprExtractionConfig(),),
+    aggregation_function: Callable[[list[float]], float] = max,
+    fallback_mode: Literal["no_fallback", "first_match"] = "first_match",
+    precision: int = 6,
+) -> SampleLevelMetric:
+    """Creates a language-aware extractive match metric that extracts answers from the model's output.
+
+    Known issues:
+    - If the task is to simplify an expression, the metric might overestimate the accuracy. This is because if the model doesn't output any anchor for the extraction (e.g. "final answer is ..."),
+      it's possible that the extracted prediction will be the expression to simplify. Because we do simplifications ourselves, it can thus happen that sympy will correctly simplify the expression,
+      so it will match the gold despite the model not doing anything. PRs to fix this are welcome.
+
+    - There is currently no StringExtractionConfig, so if the gold is \boxed{\text{Friday}} and the model outputs Friday, it will not match, because nothing will be extracted.
+
+    Args:
+        language: Language
+            The language of the samples.
+        gold_extraction_target: Sequence[ExtractionTarget]
+            Extraction targets to use for gold answers. Defaults to extracting simple math expressions.
+        pred_extraction_target: Sequence[ExtractionTarget]
+            Extraction targets to use for predictions. Defaults to extracting simple math expressions.
+        aggregation_function: Callable[[list[float]], float]
+            Function to aggregate scores when multiple golds/predictions are present. Defaults to max.
+        fallback_mode: Literal["no_fallback", "first_match"]
+            How to perform extraction. Defaults to "first_match".
+            - "no_fallback": Only use the first successfully parsed matches
+            - "first_match": Use the first successfully parsed match + the first match regardless of parsing success
+        precision: int
+            Number of decimal places to use when comparing numerical values. Defaults to 6.
+
+    Returns:
+        A sample level metric that extracts and compares mathematical expressions.
+
+    """
+
+    @timeout(2)
+    def add_to_specifics_with_timeout(
+        formatted_doc: Doc, extracted_predictions: list[list[str]], extracted_golds: list[list[str]]
+    ) -> None:
+        if formatted_doc.specific is None:
+            formatted_doc.specific = {}
+
+        formatted_doc.specific["extracted_predictions"] = [
+            str(pred) for preds in extracted_predictions for pred in preds
+        ]
+        formatted_doc.specific["extracted_golds"] = [str(gold) for golds in extracted_golds for gold in golds]
+
+    def sample_level_fn(golds: list[str], predictions: list[str], formatted_doc: Doc) -> float:
+        gold_extraction_regexes = get_extraction_regexes(formatted_doc, gold_extraction_target, language)
+        pred_extraction_regexes = get_extraction_regexes(formatted_doc, pred_extraction_target, language)
+
+        extracted_predictions = [
+            extract_target_from_pred(pred, pred_extraction_regexes, fallback_mode) for pred in predictions
+        ]
+        extracted_golds = [extract_target_from_pred(gold, gold_extraction_regexes, fallback_mode) for gold in golds]
+
+        # Assert on empty gold and warn on empty pred
+        if any(len(g) == 0 for g in extracted_golds):
+            raise ValueError(f"No gold targets found for at least one gold. Gold: {golds}, Pred: {predictions}")
+
+        if all(len(p) == 0 for p in extracted_predictions):
+            logger.warning(
+                f"We did not manage to extract a prediction in the correct format. Gold: {golds}, Pred: {predictions}"
+            )
+
+        # We have to use a timeout because the sympy-to-str conversion can be very slow
+        try:
+            add_to_specifics_with_timeout(formatted_doc, extracted_predictions, extracted_golds)
+        except:  # noqa: E722
+            logger.warning("Timeout when adding extracted predictions and golds to specific")
+
+        return aggregation_function(
+            [
+                (1.0 if any(compare_gold_target(gold, pred, precision) for gold in extracted_golds) else 0.0)
+                for pred in extracted_predictions
+            ]
+        )
+
+    return SampleLevelMetric(
+        metric_name="extractive_match",
+        sample_level_fn=sample_level_fn,
+        category=MetricCategory.GENERATIVE,
+        use_case=MetricUseCase.ACCURACY,
+        corpus_level_fn=np.mean,
+        higher_is_better=True,
+    )
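
A hedged usage sketch for the new metric: the call below mirrors the defaults visible in the diff (the `# noqa: F401` re-exports make the config classes importable from `dynamic_metrics`), but whether `LatexExtractionConfig` accepts extra arguments, and how the returned `SampleLevelMetric` is wired into a task, live outside this commit:

```python
# Hedged sketch: an English math metric whose golds are LaTeX (e.g. \boxed{...})
# while predictions may be plain expressions or LaTeX.
from lighteval.metrics.dynamic_metrics import (
    ExprExtractionConfig,
    LatexExtractionConfig,
    multilingual_extractive_match_metric,
)
from lighteval.utils.language import Language

latex_gold_metric = multilingual_extractive_match_metric(
    language=Language.ENGLISH,
    gold_extraction_target=(LatexExtractionConfig(),),
    pred_extraction_target=(ExprExtractionConfig(), LatexExtractionConfig()),
    aggregation_function=max,  # score the best extracted prediction
    precision=6,               # decimal places for numeric comparison
)
```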

src/lighteval/metrics/llm_as_judge.py

Lines changed: 19 additions & 9 deletions
@@ -193,22 +193,32 @@ def __call_litellm(self, prompts):
         import litellm

         def __call_api(prompt):
+            error_message = "ERROR: Failed to get response from the API."
             for _ in range(self.API_MAX_RETRY):
                 try:
-                    response = litellm.completion(
-                        model=self.model,
-                        messages=prompt,
-                        response_format={"type": "text"},
-                        max_tokens=512,
-                        n=1,
-                        caching=True,
-                    )
+                    kwargs = {
+                        "model": self.model,
+                        "messages": prompt,
+                        "response_format": {"type": "text"},
+                        "max_tokens": 512,
+                        "n": 1,
+                        "caching": True,
+                    }
+                    response = litellm.completion(**kwargs)
                     text = response.choices[0].message.content
+                    if not text or text == error_message:
+                        kwargs["caching"] = False
+                        response = litellm.completion(**kwargs)
+                        text = response.choices[0].message.content
+                        if not text or text == error_message:
+                            # Just return an error response if the second attempt fails too
+                            logger.error(f"Failed to get response from the API for prompt: {prompt}")
+                            return error_message
                     return text
                 except Exception as e:
                     logger.warning(f"{type(e), e}")
                     time.sleep(self.API_RETRY_SLEEP)
-            raise Exception("Failed to get response from the API")
+            return error_message

         results = []
         with ThreadPoolExecutor(100) as executor:
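
The judge call now degrades gracefully: if litellm's cache returns an empty or previously failed answer, the request is retried once with caching disabled, and a sentinel error string is returned instead of raising and aborting the whole evaluation. A hedged sketch of that pattern in isolation, with a stand-in completion function rather than litellm itself:

```python
# Sketch of the cache-fallback logic above; `completion_fn` is hypothetical.
ERROR_MESSAGE = "ERROR: Failed to get response from the API."

def call_with_cache_fallback(completion_fn, **kwargs):
    kwargs["caching"] = True
    text = completion_fn(**kwargs)
    if not text or text == ERROR_MESSAGE:
        # Cached answer is unusable: retry once, bypassing the cache.
        kwargs["caching"] = False
        text = completion_fn(**kwargs)
    return text if text and text != ERROR_MESSAGE else ERROR_MESSAGE

# A fake judge that only answers when the cache is bypassed:
print(call_with_cache_fallback(lambda **kw: "" if kw["caching"] else "Score: 7/10"))
```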

src/lighteval/metrics/metrics_corpus.py

Lines changed: 16 additions & 10 deletions
@@ -26,6 +26,7 @@
 """
 import logging
 import math
+from typing import Literal

 import numpy as np
 import sacrebleu
@@ -89,33 +90,38 @@ def compute(self, items: list[LogprobCorpusMetricInput]):


 class CorpusLevelTranslationMetric:
-    def __init__(self, metric_type: str):
+    def __init__(self, metric_type: str, lang: Literal["zh", "ja", "ko", ""] = ""):
         """Stores the relevant parameters for a corpus level translation metric.

         Args:
             metric_type (str): Can be any of bleu, chrf, or ter depending on the metric to use.
         """
-        if metric_type == "bleu":
-            self.metric = sacrebleu.corpus_bleu
-        elif metric_type == "chrf":
-            self.metric = sacrebleu.corpus_chrf
-        elif metric_type == "ter":
-            self.metric = sacrebleu.corpus_ter
+        self.metric_type = metric_type
+        self.lang = lang
+
+    def get_metric(self):
+        if self.metric_type == "bleu":
+            return sacrebleu.BLEU(trg_lang=self.lang)
+        elif self.metric_type == "chrf":
+            return sacrebleu.CHRF()
+        elif self.metric_type == "ter":
+            return sacrebleu.TER(asian_support=True if self.lang != "" else False)
         else:
-            raise ValueError(f"Unknown corpus level translation metric type : {metric_type}")
+            raise ValueError(f"Unknown corpus level translation metric type : {self.metric_type}")

     def compute(self, items: list[GenerativeCorpusMetricInput]) -> float:
         """Computes the metric score over all the corpus generated items, by using the sacrebleu implementation."""
+        metric = self.get_metric()
         golds = [i.golds for i in items]
         preds = []
         for i in items:
             pred = as_list(i.preds)
             if len(pred) > 1:
                 logger.info(
-                    f"Multiple predictions present, keeping only the first prediction (when computing sacrebleu.{self.metric.__name__})."
+                    f"Multiple predictions present, keeping only the first prediction (when computing sacrebleu.{type(metric).__name__})."
                 )
             preds.append(pred[0])
-        return float(self.metric(hypotheses=preds, references=golds).score)
+        return float(metric.corpus_score(hypotheses=preds, references=golds).score)


 class CorpusLevelPerplexityMetric:
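
The refactor constructs the sacrebleu scorer lazily in `get_metric()` and threads through a new `lang` argument: BLEU uses it as `trg_lang` to select the target-language tokenizer, TER enables `asian_support` whenever it is non-empty, and chrF ignores it. A hedged illustration, not from the commit, of what `trg_lang` changes in sacrebleu itself; the example sentences are made up:

```python
# BLEU with a Chinese target tokenizer; without trg_lang="zh" the unsegmented
# hypothesis/reference would effectively be scored as single tokens.
import sacrebleu

bleu_zh = sacrebleu.BLEU(trg_lang="zh")
score = bleu_zh.corpus_score(
    hypotheses=["今天天气很好"],
    references=[["今天天气不错"]],  # one reference stream, parallel to hypotheses
)
print(round(score.score, 2))
```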
