
Commit ff5026b

Authored by clefourrier, albertvillanova, and NathanHB

Apply suggestions from code review

Co-authored-by: Albert Villanova del Moral <[email protected]>
Co-authored-by: Nathan Habib <[email protected]>

1 parent 90593a9 · commit ff5026b

2 files changed: 6 additions & 39 deletions


src/lighteval/models/model_input.py

Lines changed: 4 additions & 37 deletions
```diff
@@ -20,7 +20,7 @@
 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 # SOFTWARE.
 
-from dataclasses import dataclass
+from dataclasses import dataclass, asdict
 from typing import Optional
 
 
```
```diff
@@ -57,24 +57,7 @@ def from_dict(cls, config_dict: dict):
             }
         }
         """
-        if "generation" not in config_dict:
-            return GenerationParameters()
-        return GenerationParameters(
-            early_stopping=config_dict["generation"].get("early_stopping", None),
-            repetition_penalty=config_dict["generation"].get("repetition_penalty", None),
-            frequency_penalty=config_dict["generation"].get("frequency_penalty", None),
-            length_penalty=config_dict["generation"].get("length_penalty", None),
-            presence_penalty=config_dict["generation"].get("presence_penalty", None),
-            max_new_tokens=config_dict["generation"].get("max_new_tokens", None),
-            min_new_tokens=config_dict["generation"].get("min_new_tokens", None),
-            seed=config_dict["generation"].get("seed", None),
-            stop_tokens=config_dict["generation"].get("stop_tokens", None),
-            temperature=config_dict["generation"].get("temperature", None),
-            top_k=config_dict["generation"].get("top_k", None),
-            min_p=config_dict["generation"].get("min_p", None),
-            top_p=config_dict["generation"].get("top_p", None),
-            truncate_prompt=config_dict["generation"].get("truncate_prompt", None),
-        )
+        return GenerationParameters(**config_dict.get("generation", {}))
 
     def to_vllm_openai_dict(self) -> dict:
         """Selects relevant generation and sampling parameters for vllm and openai models.
```
```diff
@@ -85,23 +68,7 @@ def to_vllm_openai_dict(self) -> dict:
         """
         # Task specific sampling params to set in model: n, best_of, use_beam_search
         # Generation specific params to set in model: logprobs, prompt_logprobs
-        args = {
-            "presence_penalty": self.presence_penalty,
-            "frequency_penalty": self.frequency_penalty,
-            "repetition_penalty": self.repetition_penalty,
-            "temperature": self.temperature,
-            "top_p": self.top_p,
-            "top_k": self.top_k,
-            "min_p": self.min_p,
-            "seed": self.seed,
-            "length_penalty": self.length_penalty,
-            "early_stopping": self.early_stopping,
-            "stop": self.stop_tokens,
-            "max_tokens": self.max_new_tokens,
-            "min_tokens": self.min_new_tokens,
-            "truncate_prompt_tokens": self.truncate_prompt,
-        }
-        return {k: v for k, v in args.items() if v is not None}
+        return {k: v for k, v in asdict(self).items() if v is not None}
 
     def to_transformers_dict(self) -> dict:
         """Selects relevant generation and sampling parameters for transformers models.
```
```diff
@@ -117,7 +84,7 @@ def to_transformers_dict(self) -> dict:
         args = {
             "max_new_tokens": self.max_new_tokens,
             "min_new_tokens": self.min_new_tokens,
-            "early_stopping": self.early_stopping or False,
+            "early_stopping": self.early_stopping,
             "stop_strings": self.stop_tokens,
             "temperature": self.temperature,
             "top_k": self.top_k,
```

src/lighteval/models/transformers/transformers_model.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -1350,9 +1350,9 @@ def _loglikelihood_single_token(
 
 class BaseModel(TransformersModel):
     def __post_init__(self):
-        super()
+        super().__post_init__()
 
-        logger.warning(
+        warnings.warn(
             "Careful, the BaseModel name is deprecated and will be removed, you should use TransformersModel instead!"
         )
 
```