
Commit 1925742

Merge pull request #10 from huggingface/fix-target-perplexity
Fixing target perplexity bug
2 parents 0cf83ce + ae08474 commit 1925742

11 files changed: +94 −67 lines changed


README.md

Lines changed: 5 additions & 5 deletions
```diff
@@ -8,7 +8,7 @@ LightEval is an evaluation suite which gathers a selection of features from wide
 
 It is still an early, internal version - it should be nice to use but don't expect 100% stability!
 
-In case of problems or question, feel free to open an issue!
+In case of problems or question, feel free to open an issue!
 
 ## How to install and use
 ### Requirements
@@ -50,11 +50,11 @@ Lastly, create a **line summary** of your evaluation, in `metadata_table.json`.
 - `suite` (list), the suite(s) to which your evaluation should belong. This field allows us to compare different tasks implementation, and is used a task selection to differentiate the versions to launch. At the moment, you'll find the keywords ["helm", "bigbench", "original", "lighteval"]; you can add also add new ones (for test, we recommend using "custom").
 - `prompt_function` (str), the name of the prompt function you defined in the step above
 - `hf_repo` (str), the path to your evaluation dataset on the hub
-- `hf_subset` (str), the specific subset you want to use for your evaluation (note: when the dataset has no subset, fill this field with `"default"`, not with `None` or `""`)
+- `hf_subset` (str), the specific subset you want to use for your evaluation (note: when the dataset has no subset, fill this field with `"default"`, not with `None` or `""`)
 - `hf_avail_splits` (list), all the splits available for your dataset (train, valid or validation, test, other...)
 - `evaluation_splits` (list), the splits you want to use for evaluation
 - `few_shots_split` (str, can be `null`), the specific split from which you want to select samples for your few-shot examples. It should be different from the sets included in `evaluation_splits`
-- `few_shots_select` (str, can be `null`), the method that you will use to select items for your few-shot examples. Can be `null`, or one of:
+- `few_shots_select` (str, can be `null`), the method that you will use to select items for your few-shot examples. Can be `null`, or one of:
     - `balanced` selects examples from the `few_shots_split` with balanced labels, to avoid skewing the few shot examples (hence the model generations) towards one specific label
     - `random` selects examples at random from the `few_shots_split`
     - `random_sampling` selects new examples at random from the `few_shots_split` for every new item, but if a sampled item is equal to the current one, it is removed from the available samples
@@ -102,7 +102,7 @@ These metrics need the model to generate an output. They are therefore slower.
 - `exact_match_indicator`: Exact match with some preceding context (before an indicator) removed
 - `f1_score_quasi` (HELM): Average F1 score in terms of word overlap between the model output and gold, with both being normalized first
 - `f1_score`: Average F1 score in terms of word overlap between the model output and gold without normalisation
-- `f1_score_macro`: Corpus level macro F1 score
+- `f1_score_macro`: Corpus level macro F1 score
 - `f1_score_macro`: Corpus level micro F1 score
 - Summarization:
 - `rouge` (Harness): Average ROUGE score [(Lin, 2004)](https://aclanthology.org/W04-1013/)
@@ -141,7 +141,7 @@ These metrics need both the generation and its logprob. They are not working at
 - `prediction_perplexity` (HELM): Measure of the logprob of a given input.
 
 ## Adding a new metric
-If you want to add a new metric, first check if you can use one of the parametrized functions in `src.lighteval.metrics.metrics_corpus` or `metrics_sample`. If not, add it to either of these files depending on the level at which it is applied. Then, follow the example in `src.lighteval.metrics.metrics` to register your metric.
+If you want to add a new metric, first check if you can use one of the parametrized functions in `src.lighteval.metrics.metrics_corpus` or `metrics_sample`. If not, add it to either of these files depending on the level at which it is applied. Then, follow the example in `src.lighteval.metrics.metrics` to register your metric.
 
 ## Examples of scripts to launch lighteval on the cluster
 ### Evaluate a whole suite on one node, 8 GPUs
```
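The five README lines touched here are textually identical before and after in this view (most likely a trailing-whitespace cleanup), so the documented task fields themselves are unchanged. For readers following the `metadata_table.json` instructions quoted above, a minimal line summary might look like the sketch below; only the field names come from the README excerpt, and every value is hypothetical.

```python
# Hypothetical metadata_table.json line summary, expressed as a Python dict.
# Field names come from the README excerpt above; the values are illustrative only.
example_line_summary = {
    "suite": ["custom"],                  # "custom" is suggested for test tasks
    "prompt_function": "my_prompt_fn",    # hypothetical prompt function name
    "hf_repo": "my-org/my-eval-dataset",  # hypothetical dataset path on the hub
    "hf_subset": "default",               # use "default" when the dataset has no subset
    "hf_avail_splits": ["train", "validation", "test"],
    "evaluation_splits": ["test"],
    "few_shots_split": "validation",      # must differ from evaluation_splits
    "few_shots_select": "balanced",       # or "random", "random_sampling", or null
}
```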

pyproject.toml

Lines changed: 2 additions & 2 deletions
```diff
@@ -82,8 +82,8 @@ optimum = ["optimum==1.12.0"]
 quantization = ["bitsandbytes>=0.41.0", "auto-gptq>=0.4.2"]
 adapters = ["peft==0.3.0"]
 nanotron = [
-  "nanotron@git+https://github.com/huggingface/nanotron@8c1a49588d0745a6404644a86547c2dd6a63640e",
-  "brrr@git+https://github.com/huggingface/brrr@e8a503e2ec08b34eed7522d331aec3bee8cdd29b",
+  "nanotron@git+https://github.com/huggingface/nanotron@main",
+  "brrr@git+https://github.com/huggingface/brrr@fix-lighteval",
   "tensorboardX"
 ]
 
```

src/lighteval/data.py

Lines changed: 5 additions & 3 deletions
```diff
@@ -72,7 +72,7 @@ def get_original_order(self, new_arr: list) -> list:
 
         return original_order
 
-    def get_split_start_end(self, split_id: int) -> tuple[int, int]:
+    def get_set_split_start_end(self, split_id: int) -> tuple[int, int]:
         """
         Get the start and end indices of a dataset split.
 
@@ -96,7 +96,7 @@ def splits_start_end_iterator(self) -> tuple[int, int]:
             tuple: A tuple containing the start and end indices of a split.
         """
         for split_id in range(self.dataset_splits):
-            yield self.get_split_start_end(split_id)
+            yield self.get_set_split_start_end(split_id)
 
     def __getitem__(self, index) -> Request:
         """
@@ -189,7 +189,9 @@ def _sorting_criteria(self, x) -> int:
         Returns:
             Any: The collated data.
         """
-        toks, (stop_tokens, gen_length) = x
+        toks = x[0]
+        meta_data = x[1]
+        stop_tokens, gen_length = meta_data[0], meta_data[1]
         return -(len(toks) + gen_length)
 
 
```
195197

src/lighteval/metrics/__init__.py

Lines changed: 15 additions & 4 deletions
```diff
@@ -7,12 +7,21 @@
 
 
 def apply_target_perplexity_metric(results: list[ModelReturn], formatted_doc: Doc, metrics: list[str]):
+    if len(formatted_doc.get_golds()) != 1:
+        raise ValueError("Target perplexity metric can only be used with one gold reference")
     outputs = {}
-    current_results = [results.pop(0) for _ in range(len(formatted_doc.get_golds()))]
+    reference_text = formatted_doc.get_golds()[0]
+    current_result = results.pop(0)
+    target_logprob = current_result.result[0]
+    target_acc = current_result.result[1]
 
     for metric in metrics:
-        if Metrics[metric].value.category == MetricCategory.PERPLEXITY:
-            outputs.update(Metrics[metric].value.compute(results=current_results))
+        if Metrics[metric].value.category == MetricCategory.TARGET_PERPLEXITY:
+            outputs.update(
+                Metrics[metric].value.compute(
+                    logprobs=target_logprob, target_acc=target_acc, reference_text=reference_text
+                )
+            )
 
     return results, outputs
 
@@ -30,7 +39,9 @@ def apply_perplexity_metric(results: list[ModelReturn], formatted_doc: Doc, metr
 
     for metric in metrics:
         if Metrics[metric].value.category == MetricCategory.PERPLEXITY:
-            outputs.update(Metrics[metric].value.compute(results=current_result, reference_text=reference_text))
+            outputs.update(
+                Metrics[metric].value.compute(logprobs=current_result.result, reference_text=reference_text)
+            )
 
     return results, outputs
 
```
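The target-perplexity path now assumes exactly one gold reference per document and passes the per-target logprob and accuracy flag to `compute` as keyword arguments, instead of handing over raw `ModelReturn` objects. A rough sketch of that unpacking contract, with toy stand-ins for the real lighteval classes:

```python
from dataclasses import dataclass


@dataclass
class FakeModelReturn:
    # Stand-in for lighteval's ModelReturn: result is assumed to hold
    # (logprob of the gold target, flag for "logprob above 0.5" accuracy).
    result: tuple[float, int]


def unpack_target_perplexity(results: list, golds: list[str]) -> dict:
    # Mirrors the checks and unpacking done in apply_target_perplexity_metric above.
    if len(golds) != 1:
        raise ValueError("Target perplexity metric can only be used with one gold reference")
    current_result = results.pop(0)
    return {
        "logprobs": current_result.result[0],
        "target_acc": current_result.result[1],
        "reference_text": golds[0],
    }


print(unpack_target_perplexity([FakeModelReturn(result=(-1.7, 1))], ["the gold answer"]))
# {'logprobs': -1.7, 'target_acc': 1, 'reference_text': 'the gold answer'}
```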

src/lighteval/metrics/metrics_sample.py

Lines changed: 3 additions & 4 deletions
```diff
@@ -275,17 +275,16 @@ def compute(self, choices_logprob: list[float], gold_ixs: list[float], formatted
         return 1.0 / (min(ranked_choices) + 1)
 
 
-def acc_golds_likelihood(results: list[tuple[float, int]], **kwargs) -> int:
+def acc_golds_likelihood(target_acc: list[int] | int, **kwargs) -> int:
     """Tests if at least one of predicted gold targets' log-likelihood is above 0.5.
 
     Args:
-        results (list[int]): List of tuples containing, for each gold, the predictions log-probabilities associated with whether they are above 0.5 aggregated.
-        formatted_doc (Doc): _description_
+        target_acc (list[int]): List of scores indicating whether the predictions log-probabilities are above 0.5 aggregated.
 
     Returns:
         int: 1 if at least one of the possible golds had a log-likelihood above 0.5.
     """
-    return max([int(acc_ppl) for _, acc_ppl in results])
+    return max([int(acc_ppl) for acc_ppl in as_list(target_acc)])
 
 
 class ROUGE:
```
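`acc_golds_likelihood` now receives the precomputed accuracy flags directly (one 0/1 value per gold, or a single value) rather than `(logprob, flag)` tuples. A small usage sketch, with a minimal `as_list` stand-in for `lighteval.utils.as_list`:

```python
def as_list(item):
    # Minimal stand-in for lighteval.utils.as_list: wrap scalars in a list.
    return item if isinstance(item, (list, tuple)) else [item]


def acc_golds_likelihood(target_acc, **kwargs) -> int:
    # Same body as the new version in the diff above.
    return max([int(acc_ppl) for acc_ppl in as_list(target_acc)])


print(acc_golds_likelihood(target_acc=[0, 1, 0]))  # 1: at least one gold was scored above 0.5
print(acc_golds_likelihood(target_acc=0))          # 0: the single gold was not
```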

src/lighteval/metrics/sample_preparator.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -106,14 +106,14 @@ def count_units(self, text: str) -> int:
         if self.units_type == "bytes":
             return len(text.encode("utf-8"))
 
-    def prepare(self, results, reference_text, **kwargs):
+    def prepare(self, logprobs: list[float] | float, reference_text: str, **kwargs):
         """Prepares an individual perplexity example to the format expected by metrics computed at the corpus level (aggregated).
 
         Args:
-            results (list[float]): List of the logprobabilities computed for each item
+            logprobs (list[float]): List of the logprobabilities computed for each item of the sequence or single aggregated logprob over the sequence
             reference_text (str): Current reference text for which to compute the length in self.units_type
 
         Returns:
             PerplexityCorpusMetricInput: Stores the measured logprobs and associated text lengths, counted in the reference unit.
         """
-        return PerplexityCorpusMetricInput(logprobs=results.result, weights=self.count_units(reference_text))
+        return PerplexityCorpusMetricInput(logprobs=logprobs, weights=self.count_units(reference_text))
```
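`prepare` now takes the logprobs themselves (plus the reference text) instead of a full model return, matching the new call sites in `metrics/__init__.py`. The weight attached to each example is still the reference length from `count_units`; the sketch below replays the `"bytes"` branch shown in the hunk, with a `"words"` branch added purely as an assumption for illustration.

```python
def count_units(text: str, units_type: str) -> int:
    # "bytes" mirrors the branch visible in the hunk above; "words" is an assumed
    # alternative unit added here only for illustration.
    if units_type == "bytes":
        return len(text.encode("utf-8"))
    if units_type == "words":
        return len(text.split())
    raise ValueError(f"Unknown units_type: {units_type}")


print(count_units("hello world", "bytes"))  # 11
print(count_units("hello world", "words"))  # 2
```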

src/lighteval/models/brrr_models.py

Lines changed: 56 additions & 42 deletions
```diff
@@ -1,7 +1,8 @@
-# flake8: noqa: C901
+# flake8: noqa: C901,E1120
 import os
 import time
-from typing import List, Optional, Tuple, Union
+from dataclasses import dataclass
+from typing import List, Optional, Tuple, Union, Type
 
 import torch
 import torch.nn.functional as F
@@ -28,9 +29,22 @@
 from tqdm import tqdm
 from transformers import AutoTokenizer, BatchEncoding
 
-from lighteval.data import GenDataset, GenDistributedSampler, LoglikelihoodDataset, LoglikelihoodSingleTokenDataset
+from lighteval.tasks.requests import (
+    GreedyUntilRequest,
+    LoglikelihoodRequest,
+    LoglikelihoodRollingRequest,
+    LoglikelihoodSingleTokenRequest,
+)
+from lighteval.data import (
+    GenDistributedSampler,
+    GenerativeTaskDataset,
+    LoglikelihoodDataset,
+    LoglikelihoodSingleTokenDataset,
+)
 from lighteval.models.model_output import Batch, GenerateReturn, LoglikelihoodReturn, LoglikelihoodSingleTokenReturn
-from lighteval.utils import as_list, find_executable_batch_size
+from lighteval.tasks.requests import GreedyUntilRequest
+from lighteval.utils import as_list
+from lighteval.utils_parallelism import find_executable_batch_size
 
 
 # from .brrr_generation import GenerationConfig, GenerationInputs, SamplerType, greedy_search_tokenized
@@ -41,8 +55,7 @@
 
 TokenSequence = Union[List[int], torch.LongTensor, torch.Tensor, BatchEncoding]
 
-# _DeviceMapping = NewType("DeviceMapping", Mapping[str, Union[int, str, torch.device]])
-
+STARTING_BATCH_SIZE = 512
 
 class BRRRModel:
     # Default max sequence length setting for when no `max_length` is provided
@@ -68,6 +81,7 @@ def __init__(
         s5cmd_numworkers: int = 64,
         s5cmd_concurrency: int = 10,
         s5cmd_path: str = "/admin/home/thomwolf/miniconda/envs/b4r/bin/s5cmd",
+        model_class: Optional[Type] = None,
     ):
         """Initializes a brrr model for evaluation.
         Args:
@@ -120,6 +134,9 @@ def __init__(
         self.tokenizer.model_max_length = self.max_length
 
         model_config_cls = self.model_config.__class__.__name__
+        if model_class is not None:
+            CONFIG_TO_MODEL_CLASS[self.model_config.__class__.__name__] = model_class
+
         if model_config_cls not in CONFIG_TO_MODEL_CLASS:
             raise ValueError(
                 f"Unsupported model config {model_config_cls}. Only {CONFIG_TO_MODEL_CLASS.keys()} are supported"
@@ -394,7 +411,7 @@ def _encode_pair(self, context, continuation):
         continuation_enc = whole_enc[context_enc_len:]
         return context_enc, continuation_enc
 
-    def homogeneize_ending_conditions(self, ending_condition: tuple | dict | list | str) -> tuple[list, int]:
+    def homogeneize_ending_conditions(self, ending_condition: Union[tuple, dict, list, str]) -> tuple[list, int]:
         """Ending conditions are submitted in several possible formats.
         By default in lighteval we pass them as tuples (stop sequence, max number of items).
         In the harness they sometimes are passed as dicts {"until": .., "max_length": ...} or
@@ -489,7 +506,7 @@ def loglikelihood_single_token(
             disable_tqdm=bool(dist.get_rank(self.parallel_context.world_pg) != 0),
         )
 
-    def loglikelihood(self, requests: List[Tuple[str, str]], override_bs=None) -> List[LoglikelihoodReturn]:
+    def loglikelihood(self, requests: List[LoglikelihoodRequest], override_bs=None) -> List[LoglikelihoodReturn]:
         """Tokenize the context and continuation and compute the log likelihood of those
         tokenized sequences.
 
@@ -518,7 +535,7 @@ def loglikelihood(self, requests: List[Tuple[str, str]], override_bs=None) -> Li
             disable_tqdm=bool(dist.get_rank(self.parallel_context.world_pg) != 0),
         )
 
-    def loglikelihood_rolling(self, requests: List[Tuple[str, str]], override_bs=None) -> List[LoglikelihoodReturn]:
+    def loglikelihood_rolling(self, requests: List[LoglikelihoodRollingRequest], override_bs=None) -> List[LoglikelihoodReturn]:
         """This function is used to compute the log likelihood of the context for perplexity metrics."""
         tokenized_reqs = []
 
@@ -608,7 +625,7 @@ def prepare_batch(
 
             # when too long to fit in context, truncate from the left
             inp = torch.tensor(
-                (tokens)[-max_context:],  # [:-1],
+                tokens[-max_context:],  # [:-1],
                 dtype=torch.long,
             )
 
@@ -699,7 +716,7 @@ def _get_subsets(self, dataset, dataset_splits):
 
     @torch.inference_mode()
    def _loglikelihood_single_token(
-        self, requests, disable_tqdm: bool = False, override_bs: int = -1, dataset_splits: int = 1
+        self, requests: List[LoglikelihoodSingleTokenRequest], disable_tqdm: bool = False, override_bs: int = -1, dataset_splits: int = 1
     ) -> List[LoglikelihoodSingleTokenReturn]:
         dataset = LoglikelihoodSingleTokenDataset(requests=requests)
         res = []
@@ -921,7 +938,7 @@ def _loglikelihood_single_token(
             # We are in a process which return no output (beginning/middle of the PP group)
             return []
 
-        return dataset.ordered.get_original(res)
+        return dataset.get_original_order(res)
 
     @torch.inference_mode()
     def _loglikelihood_tokens(
@@ -932,26 +949,14 @@ def _loglikelihood_tokens(
         dataset_splits: int = 1,
         return_bool_score: bool = True,
     ) -> List[LoglikelihoodReturn]:
-        dataset = LoglikelihoodDataset(requests=requests)
+        dataset = LoglikelihoodDataset(requests=requests, dataset_splits=dataset_splits)
         res = []
 
         # Dataset is sorted in descending size.
         # every 20-25% of the dataset we try to double the batch size for speed up
-        starting_batch_size = 512
-
-        total_length, subset_length = self._get_subsets(dataset, dataset_splits)
-
-        for s, subset_start in enumerate(
-            tqdm(
-                range(0, total_length, subset_length),
-                disable=disable_tqdm,
-                position=0,
-                desc=f"loglikelihood -- Node {dist.get_rank(self.parallel_context.world_pg)}",
-            )
-        ):
-            dataset.split_start = subset_start
-            dataset.split_end = min(subset_start + subset_length, total_length)
+        starting_batch_size = STARTING_BATCH_SIZE
 
+        for s, (split_start, split_end) in tqdm(enumerate(dataset.splits_start_end_iterator())):
            # automatic (variable) batch size detection for vectorization
            # pull longest context sample from request
            _, context_enc, continuation_enc = dataset[0]
@@ -1155,18 +1160,18 @@ def _loglikelihood_tokens(
                # print(f"i {i} padded: {r.padded}")
 
            if dist.get_rank(self.parallel_context.pp_pg) == self.output_pp_rank:
-                assert len(res) == total_length, "we didn't cover all the data"
+                assert len(res) == (split_end-split_start), "we didn't cover all the data"
 
        if len(res) == 0:
            # We are in a process which return no output (beginning/middle of the PP group)
            return []
 
-        return dataset.ordered.get_original(res)
+        return dataset.get_original_order(res)
 
    @torch.inference_mode()
    def greedy_until(
        self,
-        requests: List[Tuple[str, dict]],
+        requests: List[GreedyUntilRequest],
        task_names: Optional[List[str]] = None,
        returns_logits=False,
        disable_tqdm: bool = False,
@@ -1178,15 +1183,24 @@ def greedy_until(
        # pull longest context sample from request
        if task_names:
            enc_inputs = [
-                (self.tok_encode(req[0]), self.homogeneize_ending_conditions(req[1]), task_name)
-                for req, task_name in zip(requests, task_names)
+                (index, (
+                    self.tok_encode(req.context),
+                    self.homogeneize_ending_conditions((req.stop_sequence, req.generation_size)),
+                    task_name,
+                ))
+                for index, (req, task_name) in enumerate(zip(requests, task_names))
            ]
        else:
            enc_inputs = [
-                (self.tok_encode(req[0]), self.homogeneize_ending_conditions(req[1]), None) for req in requests
+                (index, (
+                    self.tok_encode(req.context),
+                    self.homogeneize_ending_conditions((req.stop_sequence, req.generation_size)),
+                    None,
+                ))
+                for index, req in enumerate(requests)
            ]
 
-        dataset = GenDataset(requests=enc_inputs)
+        dataset = GenerativeTaskDataset(requests=enc_inputs, dataset_splits=dataset_splits)
        res = []
 
        # Dataset is sorted in descending size.
@@ -1195,20 +1209,20 @@ def greedy_until(
 
        total_length, subset_length = self._get_subsets(dataset, dataset_splits)
 
-        for s, subset_start in enumerate(
+        for s, _ in enumerate(
            tqdm(
-                range(0, total_length, subset_length),
-                disable=disable_tqdm,
-                position=0,
+                dataset.splits_start_end_iterator(),
+                total=dataset_splits,
                desc=f"greedy -- Node {dist.get_rank(self.parallel_context.world_pg)}",
+                position=0,
+                disable=disable_tqdm,
            )
        ):
-            dataset.split_start = subset_start
-            dataset.split_end = min(subset_start + subset_length, total_length)
-
+            # print(dataset[0])
            _, (context_enc, _, _) = dataset[0]
            max_gen = max(d[1][1][1] for d in dataset)
            max_input_length = min(len(context_enc) + max_gen, self.max_length)
+            # max_input_length = len(context_enc)
            batch_size = self._get_batch_size(
                override_bs=override_bs, max_input_length=max_input_length, starting_batch_size=starting_batch_size
            )
@@ -1360,7 +1374,7 @@ def greedy_until(
            # We are in a process which return no output (beginning/middle of the PP group)
            return []
 
-        return dataset.ordered.get_original(res)
+        return dataset.get_original_order(res)
 
 
 class MultiTokenEOSCriteria(transformers.StoppingCriteria):
```
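The bulk of this file's changes swap manual `split_start` / `split_end` bookkeeping for the dataset's own `splits_start_end_iterator()`, type the request lists (`LoglikelihoodRequest`, `GreedyUntilRequest`, ...), and check per-split coverage against `split_end - split_start` instead of the full dataset length. A condensed, hypothetical sketch of the resulting loop shape (tqdm, batching and pipeline-parallel details omitted):

```python
def run_over_splits(dataset, process_split):
    # Hypothetical, simplified outline of the pattern used after this change:
    # iterate the (start, end) pairs the dataset provides, process each split,
    # and assert that each split produced exactly split_end - split_start results.
    all_results = []
    for split_start, split_end in dataset.splits_start_end_iterator():
        split_results = process_split(split_start, split_end)
        assert len(split_results) == (split_end - split_start), "we didn't cover all the data"
        all_results.extend(split_results)
    # The real code finally restores the original request order,
    # e.g. dataset.get_original_order(all_results).
    return all_results
```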
