Commit d9b262f

fix suite 1
1 parent e3cbe24

File tree: 3 files changed, +55 -20 lines changed


src/lighteval/data.py

Lines changed: 3 additions & 4 deletions
```diff
@@ -174,9 +174,8 @@ def _sorting_criteria(self, request: LoglikelihoodSingleTokenRequest) -> int:
         automatic adaptive batches much much easier to implement
         - any OOMs will happen right away rather than near the end
         """
-        toks = (
-            request.tokenized_context
-        )  # We take only the prompt, no need for the continuation (since it's a list of single tokens)
+        # We take only the prompt, no need for the continuation (since it's a list of single tokens)
+        toks = request.tokenized_context
         return -len(toks)
```

```diff
@@ -191,7 +190,7 @@ def _sorting_criteria(self, request: GreedyUntilRequest | GreedyUntilWithLogitsRequest) -> int:
         Returns:
             Any: The collated data.
         """
-        toks = (request.context,)
+        toks = request.context
         gen_length = request.generation_size
         return -(len(toks) + gen_length)
```

src/lighteval/models/base_model.py

Lines changed: 8 additions & 8 deletions
```diff
@@ -39,10 +39,6 @@


 class BaseModel(LightevalModel):
-    # Default max sequence length setting for when no `max_length` is provided
-    # or no max length config setting is found in the model or tokenizer.
-    _DEFAULT_MAX_LENGTH: int = 2048
-
     def __init__(
         self,
         config: BaseModelConfig,
```

```diff
@@ -239,7 +235,9 @@ def _init_max_length(self, max_length) -> int:

         if hasattr(self.tokenizer, "model_max_length"):
             return self.tokenizer.model_max_length
-        return self._DEFAULT_MAX_LENGTH
+        # Default max sequence length setting for when no `max_length` is provided
+        # or no max length config setting is found in the model or tokenizer.
+        return 2048

     @property
     def batch_size(self) -> int:
```
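
The former `_DEFAULT_MAX_LENGTH` class attribute is inlined as the last fallback. A rough sketch of the resolution order this implies, assuming a hypothetical `max_position_embeddings` config field (only the tokenizer check and the hard-coded 2048 are visible in the hunk):

```python
# Hedged sketch of the max-length fallback chain; the config-field name is
# an assumption, only the last two branches appear in the diff above.
def init_max_length(max_length, model_config, tokenizer):
    if max_length is not None:  # explicit user setting wins
        return max_length
    if hasattr(model_config, "max_position_embeddings"):  # assumed field
        return model_config.max_position_embeddings
    if hasattr(tokenizer, "model_max_length"):
        return tokenizer.model_max_length
    # Default max sequence length when nothing else is configured.
    return 2048
```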

```diff
@@ -696,7 +694,8 @@ def prepare_batch(
                 raise ValueError("Negative padding")

             padded.append(padding_length - sequence_len)
-            tokens = F.pad(tokens, (padding_length - sequence_len), value=0)
+            # Right padding - it likely would be better to do left padding
+            tokens = F.pad(tokens, (0, padding_length - sequence_len), value=0)

             # We create the attention mask to ignore padding
             mask = tokens == 0
```
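
`torch.nn.functional.pad` takes a flat sequence of (left, right) pad sizes per trailing dimension, so the old bare-integer argument was not a valid pad spec; `(0, n)` appends `n` pad tokens on the right, matching the new comment. A quick toy check (pad id 0, as in the surrounding code):

```python
import torch
import torch.nn.functional as F

tokens = torch.tensor([5, 6, 7])  # toy token sequence
padding_length, sequence_len = 6, tokens.shape[0]

# (left, right) sizes for the last dimension: right padding only.
padded = F.pad(tokens, (0, padding_length - sequence_len), value=0)
print(padded)  # tensor([5, 6, 7, 0, 0, 0])
```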

```diff
@@ -782,10 +781,11 @@ def _loglikelihood_single_token(
             dataloader = self.accelerator.prepare(dataloader)

         for batch in tqdm(dataloader, disable=self.disable_tqdm, position=1):
-            prepared_batch = self.prepare_batch(batch, padding_length=max_context, max_context=max_context)
+            prepared_batch = self.prepare_batch(
+                batch, padding_length=max_context, max_context=max_context, single_token=True
+            )

             out = self._model_call(prepared_batch.input_ids)  # [batch, padding_length, vocab]
-
             out = F.log_softmax(out, dim=-1)  # we do a softmax over the options, no the vocab

             batch_probs = []
```
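
The `single_token=True` flag is now forwarded to `prepare_batch`, and the log-softmax normalizes over the vocabulary before each single-token option is scored. A toy sketch of that scoring step, with made-up shapes and token ids:

```python
import torch
import torch.nn.functional as F

batch, seq_len, vocab = 2, 4, 10             # hypothetical shapes
logits = torch.randn(batch, seq_len, vocab)  # stand-in for _model_call output

log_probs = F.log_softmax(logits, dim=-1)    # normalize over the vocab axis
last = log_probs[:, -1, :]                   # [batch, vocab] at the final position
option_ids = torch.tensor([[1, 3], [1, 3]])  # hypothetical single-token options
scores = last.gather(1, option_ids)          # log-prob of each option per example
print(scores.shape)                          # torch.Size([2, 2])
```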

tests/test_unit_reorder.py

Lines changed: 44 additions & 8 deletions
```diff
@@ -1,13 +1,49 @@
 from lighteval.data import GenerativeTaskDataset
+from lighteval.tasks.requests import GreedyUntilRequest


 # test data that will need to be sorted by length of the string
 data = [
-    ("1 The quick brown fox jumps over the lazy dog", ([":", "stop"], 10)),
-    ("2 The quick brown fox jumps over the lazy dog njsa", ([":", "stop"], 10)),
-    ("Some text", ([":", "stop"], 10)),
-    ("some more text", ([":", "stop"], 10)),
-    ("not sure what to write here", ([":", "stop"], 10)),
+    GreedyUntilRequest(
+        task_name="test",
+        example_index=0,
+        request_index=0,
+        context="1 The quick brown fox jumps over the lazy dog",
+        stop_sequence=[":", "stop"],
+        generation_size=10,
+    ),
+    GreedyUntilRequest(
+        task_name="test",
+        example_index=2,
+        request_index=0,
+        context="2 The quick brown fox jumps over the lazy dog njsa",
+        stop_sequence=[":", "stop"],
+        generation_size=10,
+    ),
+    GreedyUntilRequest(
+        task_name="test",
+        example_index=5,
+        request_index=0,
+        context="Some text",
+        stop_sequence=[":", "stop"],
+        generation_size=10,
+    ),
+    GreedyUntilRequest(
+        task_name="test",
+        example_index=21,
+        request_index=0,
+        context="some more text",
+        stop_sequence=[":", "stop"],
+        generation_size=10,
+    ),
+    GreedyUntilRequest(
+        task_name="test",
+        example_index=1,
+        request_index=0,
+        context="not sure what to write here",
+        stop_sequence=[":", "stop"],
+        generation_size=10,
+    ),
 ]

 DATASET_SPLITS = 1
```

```diff
@@ -21,9 +57,9 @@ def test_reorder_dataset(self):
         original_data = dataset.get_original_order(sorted_data)

         for i in range(len(sorted_data) - 1):
-            assert len(sorted_data[i][0]) >= len(
-                sorted_data[i + 1][0]
-            ), f"dataset[{i}][0] = {sorted_data[i][0]} is shorter than dataset[{i+1}][0] = {sorted_data[i+1][0]}"
+            assert (
+                len(sorted_data[i].context) >= len(sorted_data[i + 1].context)
+            ), f"dataset[{i}][0] = {sorted_data[i].context} is shorter than dataset[{i+1}][0] = {sorted_data[i+1].context}"

         assert len(sorted_data) == len(
             original_data
```
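
The assertion now reads `.context` off the `GreedyUntilRequest` objects and mirrors the fixed `_sorting_criteria` in data.py: requests come back in descending context-length order. A hand-check of the expected order on the test contexts above (standalone sketch):

```python
contexts = [
    "1 The quick brown fox jumps over the lazy dog",
    "2 The quick brown fox jumps over the lazy dog njsa",
    "Some text",
    "some more text",
    "not sure what to write here",
]
generation_size = 10

# Descending by len(context) + generation_size, since sorting ascending on
# -(len(toks) + gen_length) puts the longest requests first.
for c in sorted(contexts, key=lambda c: -(len(c) + generation_size)):
    print(len(c), repr(c))
```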
