
Commit a585701

removed the use of a GenerationConfig object, as it has many params set by default which slow down generations
1 parent 3eb7d0f

2 files changed, +10 −10 lines changed

src/lighteval/models/model_input.py

Lines changed: 4 additions & 3 deletions
@@ -107,14 +107,17 @@ def to_transformers_dict(self) -> dict:
         """Selects relevant generation and sampling parameters for transformers models.
         Doc: https://huggingface.co/docs/transformers/v4.46.3/en/main_classes/text_generation#transformers.GenerationConfig
 
+        Note: We don't use the GenerationConfig object itself, because it automatically initializes
+        a huge number of parameters in the config, which slows down evals considerably.
+
         Returns:
             dict: The parameters to create a transformers.GenerationConfig in the model config.
         """
         # Task specific sampling params to set in model: do_sample, num_return_sequences, num_beams
         args = {
             "max_new_tokens": self.max_new_tokens,
             "min_new_tokens": self.min_new_tokens,
-            "early_stopping": self.early_stopping,
+            "early_stopping": self.early_stopping or False,
             "stop_strings": self.stop_tokens,
             "temperature": self.temperature,
             "top_k": self.top_k,
@@ -125,8 +128,6 @@ def to_transformers_dict(self) -> dict:
             "output_scores": True,
             "return_dict_in_generate": True,
         }
-        # Even though we only use the dict representation of the GenerationConfig
-        # we still create the object as it uses validation steps
         return {k: v for k, v in args.items() if v is not None}
 
     def to_tgi_inferenceendpoint_dict(self) -> dict:
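To see why the dict approach is faster, note that a default-constructed GenerationConfig populates every generation parameter, whereas the filtered dict above only carries what the user actually set. A minimal sketch of this difference (the values are illustrative, not from the commit):

```python
from transformers import GenerationConfig

# A freshly constructed GenerationConfig is fully populated with defaults
# (temperature, top_p, repetition_penalty, ...), and every one of them is
# forwarded to generate() even if the caller never set it.
print(len(GenerationConfig().to_dict()))  # several dozen keys

# The dict-based approach keeps only explicitly set parameters, so
# generate() falls back to the model's own config for everything else.
args = {"max_new_tokens": 256, "temperature": None, "top_k": None}
print({k: v for k, v in args.items() if v is not None})  # {'max_new_tokens': 256}
```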

src/lighteval/models/transformers/transformers_model.py

Lines changed: 6 additions & 7 deletions
@@ -36,7 +36,6 @@
     AutoModelForCausalLM,
     AutoTokenizer,
     BitsAndBytesConfig,
-    GenerationConfig,
     GPTQConfig,
     PretrainedConfig,
 )
@@ -656,7 +655,7 @@ def greedy_until_multi_turn(  # noqa: C901
             ]
         )
 
-        generation_config = GenerationConfig.from_dict(self.generation_config_dict)
+        generation_config = self.generation_config_dict.copy()
         generation_config.update(
             {
                 "max_new_tokens": max_generated_tokens,
@@ -669,7 +668,7 @@ def greedy_until_multi_turn(  # noqa: C901
         )
 
         model_outputs: GenerateOutput = self.model.generate(
-            **model_inputs, stopping_criteria=stopping_criteria, generation_config=generation_config
+            **model_inputs, stopping_criteria=stopping_criteria, **generation_config
         )
         model_outputs = model_outputs.sequences[0, model_inputs["input_ids"].size(1) :]
         model_generations = [model_outputs]
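The `**generation_config` call style works because transformers' generate() accepts generation parameters as plain keyword arguments and overlays them on the model's own generation_config, leaving every other default to the model. A rough sketch of the pattern, using gpt2 as a stand-in model:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# Only the parameters that were explicitly set; everything else stays
# at whatever model.generation_config already holds.
generation_config = {"max_new_tokens": 16, "temperature": 0.7, "do_sample": True}

inputs = tokenizer("Hello", return_tensors="pt")
outputs = model.generate(**inputs, **generation_config)
print(tokenizer.decode(outputs[0]))
```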
@@ -699,7 +698,7 @@ def greedy_until_multi_turn(  # noqa: C901
             ]
         )
 
-        generation_config = GenerationConfig.from_dict(self.generation_config_dict)
+        generation_config = self.generation_config_dict.copy()
         generation_config.update(
             {
                 "max_new_tokens": max_generated_tokens,
@@ -715,7 +714,7 @@ def greedy_until_multi_turn(  # noqa: C901
             input_ids=model_inputs["input_ids"],
             attention_mask=model_inputs["attention_mask"],
             stopping_criteria=stopping_criteria,
-            generation_config=generation_config,
+            **generation_config,
         )
         model_outputs = model_outputs.sequences[0, model_inputs["input_ids"].size(1) :]
         model_generations.append(model_outputs)
@@ -896,7 +895,7 @@ def _generate(
         stopping_criteria = stop_sequences_criteria(self.tokenizer, stop_sequences=stop_tokens, batch=batch)
         batch_size, _ = batch.input_ids.shape
 
-        generation_config = GenerationConfig.from_dict(self.generation_config_dict)
+        generation_config = self.generation_config_dict.copy()
         generation_config.update(
             max_new_tokens=max_new_tokens,
             pad_token_id=self.tokenizer.pad_token_id if self.tokenizer.pad_token_id else self.tokenizer.eos_token_id,
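One subtlety worth noting: after the change, generation_config is a plain dict, and dict.update() accepts either a mapping (as in greedy_until_multi_turn) or keyword arguments (as in _generate), so both call styles in this commit are equivalent. For example:

```python
generation_config = {"temperature": 0.7}

# Mapping form, as used in greedy_until_multi_turn:
generation_config.update({"max_new_tokens": 100})

# Keyword form, as used in _generate:
generation_config.update(max_new_tokens=100, do_sample=False)

print(generation_config)
# {'temperature': 0.7, 'max_new_tokens': 100, 'do_sample': False}
```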
@@ -912,7 +911,7 @@ def _generate(
             input_ids=batch.input_ids,
             attention_mask=batch.input_mask,
             stopping_criteria=stopping_criteria,
-            generation_config=generation_config,
+            **generation_config,
         )
         generations = outputs.sequences[:, batch.input_ids.size(1) :]
         generations = torch.reshape(generations, (batch_size, num_samples, -1))
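As background on the pad_token_id fallback kept in _generate: many causal LM tokenizers define no pad token, so generate() is handed the EOS id for padding batched outputs instead. A small sketch of the same fallback, using gpt2 as a stand-in:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # gpt2 defines no pad token

# Same conditional as in _generate: fall back to EOS when pad is unset.
pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id else tokenizer.eos_token_id
print(tokenizer.pad_token_id, pad_token_id)  # None 50256
```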
