20
20
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
21
# SOFTWARE.
22
22
23
- from dataclasses import dataclass
23
+ from dataclasses import dataclass , asdict
24
24
from typing import Optional
25
25
26
26
@@ -57,24 +57,7 @@ def from_dict(cls, config_dict: dict):
57
57
}
58
58
}
59
59
"""
60
- if "generation" not in config_dict :
61
- return GenerationParameters ()
62
- return GenerationParameters (
63
- early_stopping = config_dict ["generation" ].get ("early_stopping" , None ),
64
- repetition_penalty = config_dict ["generation" ].get ("repetition_penalty" , None ),
65
- frequency_penalty = config_dict ["generation" ].get ("frequency_penalty" , None ),
66
- length_penalty = config_dict ["generation" ].get ("length_penalty" , None ),
67
- presence_penalty = config_dict ["generation" ].get ("presence_penalty" , None ),
68
- max_new_tokens = config_dict ["generation" ].get ("max_new_tokens" , None ),
69
- min_new_tokens = config_dict ["generation" ].get ("min_new_tokens" , None ),
70
- seed = config_dict ["generation" ].get ("seed" , None ),
71
- stop_tokens = config_dict ["generation" ].get ("stop_tokens" , None ),
72
- temperature = config_dict ["generation" ].get ("temperature" , None ),
73
- top_k = config_dict ["generation" ].get ("top_k" , None ),
74
- min_p = config_dict ["generation" ].get ("min_p" , None ),
75
- top_p = config_dict ["generation" ].get ("top_p" , None ),
76
- truncate_prompt = config_dict ["generation" ].get ("truncate_prompt" , None ),
77
- )
60
+ return GenerationParameters (** config_dict .get ("generation" , {}))
78
61
79
62
def to_vllm_openai_dict (self ) -> dict :
80
63
"""Selects relevant generation and sampling parameters for vllm and openai models.
@@ -85,23 +68,7 @@ def to_vllm_openai_dict(self) -> dict:
85
68
"""
86
69
# Task specific sampling params to set in model: n, best_of, use_beam_search
87
70
# Generation specific params to set in model: logprobs, prompt_logprobs
88
- args = {
89
- "presence_penalty" : self .presence_penalty ,
90
- "frequency_penalty" : self .frequency_penalty ,
91
- "repetition_penalty" : self .repetition_penalty ,
92
- "temperature" : self .temperature ,
93
- "top_p" : self .top_p ,
94
- "top_k" : self .top_k ,
95
- "min_p" : self .min_p ,
96
- "seed" : self .seed ,
97
- "length_penalty" : self .length_penalty ,
98
- "early_stopping" : self .early_stopping ,
99
- "stop" : self .stop_tokens ,
100
- "max_tokens" : self .max_new_tokens ,
101
- "min_tokens" : self .min_new_tokens ,
102
- "truncate_prompt_tokens" : self .truncate_prompt ,
103
- }
104
- return {k : v for k , v in args .items () if v is not None }
71
+ return {k : v for k , v in asdict (self ).items () if v is not None }
105
72
106
73
def to_transformers_dict (self ) -> dict :
107
74
"""Selects relevant generation and sampling parameters for transformers models.
@@ -117,7 +84,7 @@ def to_transformers_dict(self) -> dict:
117
84
args = {
118
85
"max_new_tokens" : self .max_new_tokens ,
119
86
"min_new_tokens" : self .min_new_tokens ,
120
- "early_stopping" : self .early_stopping or False ,
87
+ "early_stopping" : self .early_stopping ,
121
88
"stop_strings" : self .stop_tokens ,
122
89
"temperature" : self .temperature ,
123
90
"top_k" : self .top_k ,
0 commit comments