
Commit ba3e2ee

Revert "Merge pull request #10 from huggingface/fix-target-perplexity"

This reverts commit 1925742, reversing changes made to 0cf83ce.

1 parent 1925742 · commit ba3e2ee

File tree: 11 files changed, +67 -94 lines changed


README.md

Lines changed: 5 additions & 5 deletions
@@ -8,7 +8,7 @@ LightEval is an evaluation suite which gathers a selection of features from wide
 
 It is still an early, internal version - it should be nice to use but don't expect 100% stability!
 
-In case of problems or question, feel free to open an issue!
+In case of problems or question, feel free to open an issue!
 
 ## How to install and use
 ### Requirements
@@ -50,11 +50,11 @@ Lastly, create a **line summary** of your evaluation, in `metadata_table.json`.
 - `suite` (list), the suite(s) to which your evaluation should belong. This field allows us to compare different tasks implementation, and is used a task selection to differentiate the versions to launch. At the moment, you'll find the keywords ["helm", "bigbench", "original", "lighteval"]; you can add also add new ones (for test, we recommend using "custom").
 - `prompt_function` (str), the name of the prompt function you defined in the step above
 - `hf_repo` (str), the path to your evaluation dataset on the hub
-- `hf_subset` (str), the specific subset you want to use for your evaluation (note: when the dataset has no subset, fill this field with `"default"`, not with `None` or `""`)
+- `hf_subset` (str), the specific subset you want to use for your evaluation (note: when the dataset has no subset, fill this field with `"default"`, not with `None` or `""`)
 - `hf_avail_splits` (list), all the splits available for your dataset (train, valid or validation, test, other...)
 - `evaluation_splits` (list), the splits you want to use for evaluation
 - `few_shots_split` (str, can be `null`), the specific split from which you want to select samples for your few-shot examples. It should be different from the sets included in `evaluation_splits`
-- `few_shots_select` (str, can be `null`), the method that you will use to select items for your few-shot examples. Can be `null`, or one of:
+- `few_shots_select` (str, can be `null`), the method that you will use to select items for your few-shot examples. Can be `null`, or one of:
     - `balanced` selects examples from the `few_shots_split` with balanced labels, to avoid skewing the few shot examples (hence the model generations) towards one specific label
     - `random` selects examples at random from the `few_shots_split`
     - `random_sampling` selects new examples at random from the `few_shots_split` for every new item, but if a sampled item is equal to the current one, it is removed from the available samples
@@ -102,7 +102,7 @@ These metrics need the model to generate an output. They are therefore slower.
 - `exact_match_indicator`: Exact match with some preceding context (before an indicator) removed
 - `f1_score_quasi` (HELM): Average F1 score in terms of word overlap between the model output and gold, with both being normalized first
 - `f1_score`: Average F1 score in terms of word overlap between the model output and gold without normalisation
-- `f1_score_macro`: Corpus level macro F1 score
+- `f1_score_macro`: Corpus level macro F1 score
 - `f1_score_macro`: Corpus level micro F1 score
 - Summarization:
 - `rouge` (Harness): Average ROUGE score [(Lin, 2004)](https://aclanthology.org/W04-1013/)
@@ -141,7 +141,7 @@ These metrics need both the generation and its logprob. They are not working at
 - `prediction_perplexity` (HELM): Measure of the logprob of a given input.
 
 ## Adding a new metric
-If you want to add a new metric, first check if you can use one of the parametrized functions in `src.lighteval.metrics.metrics_corpus` or `metrics_sample`. If not, add it to either of these files depending on the level at which it is applied. Then, follow the example in `src.lighteval.metrics.metrics` to register your metric.
+If you want to add a new metric, first check if you can use one of the parametrized functions in `src.lighteval.metrics.metrics_corpus` or `metrics_sample`. If not, add it to either of these files depending on the level at which it is applied. Then, follow the example in `src.lighteval.metrics.metrics` to register your metric.
 
 ## Examples of scripts to launch lighteval on the cluster
 ### Evaluate a whole suite on one node, 8 GPUs
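Aside: as a quick illustration of the line-summary fields listed in the README hunk above, here is a minimal, hypothetical entry rendered as a Python dict (the repo path, prompt function name, and split names are made up, and real entries carry additional fields not visible in this hunk):

    line_summary = {
        "suite": ["custom"],                      # e.g. "custom" for tests
        "prompt_function": "my_prompt_fn",        # hypothetical prompt function name
        "hf_repo": "my_org/my_eval_dataset",      # hypothetical dataset path on the hub
        "hf_subset": "default",                   # "default" when the dataset has no subset
        "hf_avail_splits": ["train", "validation", "test"],
        "evaluation_splits": ["test"],
        "few_shots_split": "train",               # must differ from evaluation_splits
        "few_shots_select": "balanced",           # or "random", "random_sampling", None
    }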

pyproject.toml

Lines changed: 2 additions & 2 deletions
@@ -82,8 +82,8 @@ optimum = ["optimum==1.12.0"]
 quantization = ["bitsandbytes>=0.41.0", "auto-gptq>=0.4.2"]
 adapters = ["peft==0.3.0"]
 nanotron = [
-  "nanotron@git+https://github.com/huggingface/nanotron@main",
-  "brrr@git+https://github.com/huggingface/brrr@fix-lighteval",
+  "nanotron@git+https://github.com/huggingface/nanotron@8c1a49588d0745a6404644a86547c2dd6a63640e",
+  "brrr@git+https://github.com/huggingface/brrr@e8a503e2ec08b34eed7522d331aec3bee8cdd29b",
   "tensorboardX"
 ]
 

src/lighteval/data.py

Lines changed: 3 additions & 5 deletions
@@ -72,7 +72,7 @@ def get_original_order(self, new_arr: list) -> list:
 
         return original_order
 
-    def get_set_split_start_end(self, split_id: int) -> tuple[int, int]:
+    def get_split_start_end(self, split_id: int) -> tuple[int, int]:
         """
         Get the start and end indices of a dataset split.
 
@@ -96,7 +96,7 @@ def splits_start_end_iterator(self) -> tuple[int, int]:
             tuple: A tuple containing the start and end indices of a split.
         """
         for split_id in range(self.dataset_splits):
-            yield self.get_set_split_start_end(split_id)
+            yield self.get_split_start_end(split_id)
 
     def __getitem__(self, index) -> Request:
         """
@@ -189,9 +189,7 @@ def _sorting_criteria(self, x) -> int:
         Returns:
             Any: The collated data.
         """
-        toks = x[0]
-        meta_data = x[1]
-        stop_tokens, gen_length = meta_data[0], meta_data[1]
+        toks, (stop_tokens, gen_length) = x
         return -(len(toks) + gen_length)
 
 
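Aside: the one-liner restored in `_sorting_criteria` unpacks the same nested tuple as the three deleted lines; a small standalone sketch with a made-up tokenized request, showing the two forms are equivalent:

    # x is (tokens, (stop_tokens, generation_length)) in both versions
    x = ([101, 2023, 2003], (["\n"], 32))  # hypothetical request

    # deleted form
    toks = x[0]
    meta_data = x[1]
    stop_tokens, gen_length = meta_data[0], meta_data[1]

    # restored form, same bindings
    toks, (stop_tokens, gen_length) = x
    assert -(len(toks) + gen_length) == -35  # sort key: longest expected sequences first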

src/lighteval/metrics/__init__.py

Lines changed: 4 additions & 15 deletions
@@ -7,21 +7,12 @@
 
 
 def apply_target_perplexity_metric(results: list[ModelReturn], formatted_doc: Doc, metrics: list[str]):
-    if len(formatted_doc.get_golds()) != 1:
-        raise ValueError("Target perplexity metric can only be used with one gold reference")
     outputs = {}
-    reference_text = formatted_doc.get_golds()[0]
-    current_result = results.pop(0)
-    target_logprob = current_result.result[0]
-    target_acc = current_result.result[1]
+    current_results = [results.pop(0) for _ in range(len(formatted_doc.get_golds()))]
 
     for metric in metrics:
-        if Metrics[metric].value.category == MetricCategory.TARGET_PERPLEXITY:
-            outputs.update(
-                Metrics[metric].value.compute(
-                    logprobs=target_logprob, target_acc=target_acc, reference_text=reference_text
-                )
-            )
+        if Metrics[metric].value.category == MetricCategory.PERPLEXITY:
+            outputs.update(Metrics[metric].value.compute(results=current_results))
 
     return results, outputs
 
@@ -39,9 +30,7 @@ def apply_perplexity_metric(results: list[ModelReturn], formatted_doc: Doc, metr
 
     for metric in metrics:
        if Metrics[metric].value.category == MetricCategory.PERPLEXITY:
-            outputs.update(
-                Metrics[metric].value.compute(logprobs=current_result.result, reference_text=reference_text)
-            )
+            outputs.update(Metrics[metric].value.compute(results=current_result, reference_text=reference_text))
 
     return results, outputs
 
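Aside: the restored `apply_target_perplexity_metric` drops the single-gold restriction and instead pops one model result per gold. A rough sketch of just that consumption pattern, with plain strings standing in for `ModelReturn` objects (not the lighteval API):

    golds = ["Paris", "paris"]               # hypothetical gold references
    results = ["res_a", "res_b", "res_c"]    # stand-ins for ModelReturn objects

    current_results = [results.pop(0) for _ in range(len(golds))]
    assert current_results == ["res_a", "res_b"]
    assert results == ["res_c"]  # leftover results stay queued for later metric groups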

src/lighteval/metrics/metrics_sample.py

Lines changed: 4 additions & 3 deletions
@@ -275,16 +275,17 @@ def compute(self, choices_logprob: list[float], gold_ixs: list[float], formatted
         return 1.0 / (min(ranked_choices) + 1)
 
 
-def acc_golds_likelihood(target_acc: list[int] | int, **kwargs) -> int:
+def acc_golds_likelihood(results: list[tuple[float, int]], **kwargs) -> int:
     """Tests if at least one of predicted gold targets' log-likelihood is above 0.5.
 
     Args:
-        target_acc (list[int]): List of scores indicating whether the predictions log-probabilities are above 0.5 aggregated.
+        results (list[int]): List of tuples containing, for each gold, the predictions log-probabilities associated with whether they are above 0.5 aggregated.
+        formatted_doc (Doc): _description_
 
     Returns:
         int: 1 if at least one of the possible golds had a log-likelihood above 0.5.
     """
-    return max([int(acc_ppl) for acc_ppl in as_list(target_acc)])
+    return max([int(acc_ppl) for _, acc_ppl in results])
 
 
 class ROUGE:
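Aside: with the restored signature, each entry of `results` is a `(logprob, acc)` pair and only the accuracy flag is used; a small worked example with made-up numbers:

    def acc_golds_likelihood(results, **kwargs):
        # same body as the restored version above
        return max([int(acc_ppl) for _, acc_ppl in results])

    # one gold whose aggregated log-probability clears the 0.5 threshold is enough
    assert acc_golds_likelihood([(-2.3, 0), (-0.1, 1)]) == 1
    assert acc_golds_likelihood([(-2.3, 0), (-1.7, 0)]) == 0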

src/lighteval/metrics/sample_preparator.py

Lines changed: 3 additions & 3 deletions
@@ -106,14 +106,14 @@ def count_units(self, text: str) -> int:
         if self.units_type == "bytes":
             return len(text.encode("utf-8"))
 
-    def prepare(self, logprobs: list[float] | float, reference_text: str, **kwargs):
+    def prepare(self, results, reference_text, **kwargs):
         """Prepares an individual perplexity example to the format expected by metrics computed at the corpus level (aggregated).
 
         Args:
-            logprobs (list[float]): List of the logprobabilities computed for each item of the sequence or single aggregated logprob over the sequence
+            results (list[float]): List of the logprobabilities computed for each item
             reference_text (str): Current reference text for which to compute the length in self.units_type
 
         Returns:
             PerplexityCorpusMetricInput: Stores the measured logprobs and associated text lengths, counted in the reference unit.
         """
-        return PerplexityCorpusMetricInput(logprobs=logprobs, weights=self.count_units(reference_text))
+        return PerplexityCorpusMetricInput(logprobs=results.result, weights=self.count_units(reference_text))
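Aside: the restored `prepare` reads the logprobs off the result object (`results.result`) and weights them by the length of the reference text. A minimal sketch of the byte-counting branch visible in the context lines above (the sample string is arbitrary):

    reference_text = "héllo world"
    # with units_type == "bytes", the weight is the UTF-8 byte length, not the character count
    weight = len(reference_text.encode("utf-8"))
    assert weight == 12  # 11 characters, but "é" encodes to two bytes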

src/lighteval/models/brrr_models.py

Lines changed: 42 additions & 56 deletions
@@ -1,8 +1,7 @@
-# flake8: noqa: C901,E1120
+# flake8: noqa: C901
 import os
 import time
-from dataclasses import dataclass
-from typing import List, Optional, Tuple, Union, Type
+from typing import List, Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F
@@ -29,22 +28,9 @@
 from tqdm import tqdm
 from transformers import AutoTokenizer, BatchEncoding
 
-from lighteval.tasks.requests import (
-    GreedyUntilRequest,
-    LoglikelihoodRequest,
-    LoglikelihoodRollingRequest,
-    LoglikelihoodSingleTokenRequest,
-)
-from lighteval.data import (
-    GenDistributedSampler,
-    GenerativeTaskDataset,
-    LoglikelihoodDataset,
-    LoglikelihoodSingleTokenDataset,
-)
+from lighteval.data import GenDataset, GenDistributedSampler, LoglikelihoodDataset, LoglikelihoodSingleTokenDataset
 from lighteval.models.model_output import Batch, GenerateReturn, LoglikelihoodReturn, LoglikelihoodSingleTokenReturn
-from lighteval.tasks.requests import GreedyUntilRequest
-from lighteval.utils import as_list
-from lighteval.utils_parallelism import find_executable_batch_size
+from lighteval.utils import as_list, find_executable_batch_size
 
 
 # from .brrr_generation import GenerationConfig, GenerationInputs, SamplerType, greedy_search_tokenized
@@ -55,7 +41,8 @@
 
 TokenSequence = Union[List[int], torch.LongTensor, torch.Tensor, BatchEncoding]
 
-STARTING_BATCH_SIZE = 512
+# _DeviceMapping = NewType("DeviceMapping", Mapping[str, Union[int, str, torch.device]])
+
 
 class BRRRModel:
     # Default max sequence length setting for when no `max_length` is provided
@@ -81,7 +68,6 @@ def __init__(
         s5cmd_numworkers: int = 64,
         s5cmd_concurrency: int = 10,
         s5cmd_path: str = "/admin/home/thomwolf/miniconda/envs/b4r/bin/s5cmd",
-        model_class: Optional[Type] = None,
     ):
         """Initializes a brrr model for evaluation.
         Args:
@@ -134,9 +120,6 @@ def __init__(
         self.tokenizer.model_max_length = self.max_length
 
         model_config_cls = self.model_config.__class__.__name__
-        if model_class is not None:
-            CONFIG_TO_MODEL_CLASS[self.model_config.__class__.__name__] = model_class
-
         if model_config_cls not in CONFIG_TO_MODEL_CLASS:
             raise ValueError(
                 f"Unsupported model config {model_config_cls}. Only {CONFIG_TO_MODEL_CLASS.keys()} are supported"
@@ -411,7 +394,7 @@ def _encode_pair(self, context, continuation):
         continuation_enc = whole_enc[context_enc_len:]
         return context_enc, continuation_enc
 
-    def homogeneize_ending_conditions(self, ending_condition: Union[tuple, dict, list, str]) -> tuple[list, int]:
+    def homogeneize_ending_conditions(self, ending_condition: tuple | dict | list | str) -> tuple[list, int]:
         """Ending conditions are submitted in several possible formats.
         By default in lighteval we pass them as tuples (stop sequence, max number of items).
         In the harness they sometimes are passed as dicts {"until": .., "max_length": ...} or
@@ -506,7 +489,7 @@ def loglikelihood_single_token(
             disable_tqdm=bool(dist.get_rank(self.parallel_context.world_pg) != 0),
         )
 
-    def loglikelihood(self, requests: List[LoglikelihoodRequest], override_bs=None) -> List[LoglikelihoodReturn]:
+    def loglikelihood(self, requests: List[Tuple[str, str]], override_bs=None) -> List[LoglikelihoodReturn]:
         """Tokenize the context and continuation and compute the log likelihood of those
         tokenized sequences.
 
@@ -535,7 +518,7 @@ def loglikelihood(self, requests: List[LoglikelihoodRequest], override_bs=None)
             disable_tqdm=bool(dist.get_rank(self.parallel_context.world_pg) != 0),
         )
 
-    def loglikelihood_rolling(self, requests: List[LoglikelihoodRollingRequest], override_bs=None) -> List[LoglikelihoodReturn]:
+    def loglikelihood_rolling(self, requests: List[Tuple[str, str]], override_bs=None) -> List[LoglikelihoodReturn]:
         """This function is used to compute the log likelihood of the context for perplexity metrics."""
         tokenized_reqs = []
 
@@ -625,7 +608,7 @@ def prepare_batch(
 
         # when too long to fit in context, truncate from the left
         inp = torch.tensor(
-            tokens[-max_context:], # [:-1],
+            (tokens)[-max_context:], # [:-1],
            dtype=torch.long,
        )
 
@@ -716,7 +699,7 @@ def _get_subsets(self, dataset, dataset_splits):
 
    @torch.inference_mode()
    def _loglikelihood_single_token(
-        self, requests: List[LoglikelihoodSingleTokenRequest], disable_tqdm: bool = False, override_bs: int = -1, dataset_splits: int = 1
+        self, requests, disable_tqdm: bool = False, override_bs: int = -1, dataset_splits: int = 1
    ) -> List[LoglikelihoodSingleTokenReturn]:
        dataset = LoglikelihoodSingleTokenDataset(requests=requests)
        res = []
@@ -938,7 +921,7 @@ def _loglikelihood_single_token(
            # We are in a process which return no output (beginning/middle of the PP group)
            return []
 
-        return dataset.get_original_order(res)
+        return dataset.ordered.get_original(res)
 
    @torch.inference_mode()
    def _loglikelihood_tokens(
@@ -949,14 +932,26 @@ def _loglikelihood_tokens(
        dataset_splits: int = 1,
        return_bool_score: bool = True,
    ) -> List[LoglikelihoodReturn]:
-        dataset = LoglikelihoodDataset(requests=requests, dataset_splits=dataset_splits)
+        dataset = LoglikelihoodDataset(requests=requests)
        res = []
 
        # Dataset is sorted in descending size.
        # every 20-25% of the dataset we try to double the batch size for speed up
-        starting_batch_size = STARTING_BATCH_SIZE
+        starting_batch_size = 512
+
+        total_length, subset_length = self._get_subsets(dataset, dataset_splits)
+
+        for s, subset_start in enumerate(
+            tqdm(
+                range(0, total_length, subset_length),
+                disable=disable_tqdm,
+                position=0,
+                desc=f"loglikelihood -- Node {dist.get_rank(self.parallel_context.world_pg)}",
+            )
+        ):
+            dataset.split_start = subset_start
+            dataset.split_end = min(subset_start + subset_length, total_length)
 
-        for s, (split_start, split_end) in tqdm(enumerate(dataset.splits_start_end_iterator())):
            # automatic (variable) batch size detection for vectorization
            # pull longest context sample from request
            _, context_enc, continuation_enc = dataset[0]
@@ -1160,18 +1155,18 @@ def _loglikelihood_tokens(
                # print(f"i {i} padded: {r.padded}")
 
        if dist.get_rank(self.parallel_context.pp_pg) == self.output_pp_rank:
-            assert len(res) == (split_end-split_start), "we didn't cover all the data"
+            assert len(res) == total_length, "we didn't cover all the data"
 
        if len(res) == 0:
            # We are in a process which return no output (beginning/middle of the PP group)
            return []
 
-        return dataset.get_original_order(res)
+        return dataset.ordered.get_original(res)
 
    @torch.inference_mode()
    def greedy_until(
        self,
-        requests: List[GreedyUntilRequest],
+        requests: List[Tuple[str, dict]],
        task_names: Optional[List[str]] = None,
        returns_logits=False,
        disable_tqdm: bool = False,
@@ -1183,24 +1178,15 @@ def greedy_until(
        # pull longest context sample from request
        if task_names:
            enc_inputs = [
-                (index, (
-                    self.tok_encode(req.context),
-                    self.homogeneize_ending_conditions((req.stop_sequence, req.generation_size)),
-                    task_name,
-                ))
-                for index, (req, task_name) in enumerate(zip(requests, task_names))
+                (self.tok_encode(req[0]), self.homogeneize_ending_conditions(req[1]), task_name)
+                for req, task_name in zip(requests, task_names)
            ]
        else:
            enc_inputs = [
-                (index, (
-                    self.tok_encode(req.context),
-                    self.homogeneize_ending_conditions((req.stop_sequence, req.generation_size)),
-                    None,
-                ))
-                for index, req in enumerate(requests)
+                (self.tok_encode(req[0]), self.homogeneize_ending_conditions(req[1]), None) for req in requests
            ]
 
-        dataset = GenerativeTaskDataset(requests=enc_inputs, dataset_splits=dataset_splits)
+        dataset = GenDataset(requests=enc_inputs)
        res = []
 
        # Dataset is sorted in descending size.
@@ -1209,20 +1195,20 @@ def greedy_until(
 
        total_length, subset_length = self._get_subsets(dataset, dataset_splits)
 
-        for s, _ in enumerate(
+        for s, subset_start in enumerate(
            tqdm(
-                dataset.splits_start_end_iterator(),
-                total=dataset_splits,
-                desc=f"greedy -- Node {dist.get_rank(self.parallel_context.world_pg)}",
-                position=0,
+                range(0, total_length, subset_length),
                disable=disable_tqdm,
+                position=0,
+                desc=f"greedy -- Node {dist.get_rank(self.parallel_context.world_pg)}",
            )
        ):
-            # print(dataset[0])
+            dataset.split_start = subset_start
+            dataset.split_end = min(subset_start + subset_length, total_length)
+
            _, (context_enc, _, _) = dataset[0]
            max_gen = max(d[1][1][1] for d in dataset)
            max_input_length = min(len(context_enc) + max_gen, self.max_length)
-            # max_input_length = len(context_enc)
            batch_size = self._get_batch_size(
                override_bs=override_bs, max_input_length=max_input_length, starting_batch_size=starting_batch_size
            )
@@ -1374,7 +1360,7 @@ def greedy_until(
            # We are in a process which return no output (beginning/middle of the PP group)
            return []
 
-        return dataset.get_original_order(res)
+        return dataset.ordered.get_original(res)
 
 
 class MultiTokenEOSCriteria(transformers.StoppingCriteria):
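Aside: the restored loops in `_loglikelihood_tokens` and `greedy_until` set `dataset.split_start` / `dataset.split_end` by hand rather than calling the reverted `splits_start_end_iterator()`. A standalone sketch of that boundary arithmetic, under the assumption (not shown in this diff) that `_get_subsets` sizes each split as roughly `total_length / dataset_splits`:

    total_length = 10       # arbitrary dataset size
    dataset_splits = 3
    subset_length = total_length // dataset_splits + 1  # assumed split sizing

    bounds = [
        (start, min(start + subset_length, total_length))
        for start in range(0, total_length, subset_length)
    ]
    assert bounds == [(0, 4), (4, 8), (8, 10)]  # every index covered exactly once, no overlap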
