11
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
-
14
+ import json
15
15
from dataclasses import dataclass , field
16
- from typing import Optional
16
+ from typing import Optional , Union
17
17
18
18
from transformers import TrainingArguments
19
+ from transformers .training_args import _convert_str_dict
20
+
21
+
22
+ # Users sometimes pass a dict-valued argument on the CLI as the `str` repr of a dict.
23
+ # We need to track which fields this can apply to. Each time a new argument
24
+ # has a dict type, its name must be added to this list.
25
+ # Important: these fields should be typed with Optional[Union[dict, str, ...]].
26
+ _VALID_DICT_FIELDS = [
27
+ "model_init_kwargs" ,
28
+ ]
19
29
20
30
21
31
@dataclass
@@ -33,7 +43,7 @@ class GRPOConfig(TrainingArguments):
33
43
Parameters:
34
44
> Parameters that control the model and reference model
35
45
36
- model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
46
+ model_init_kwargs (`str, dict[str, Any]` or `None`, *optional*, defaults to `None`):
37
47
Keyword arguments for [`~transformers.AutoModelForCausalLM.from_pretrained`], used when the `model`
38
48
argument of the [`GRPOTrainer`] is provided as a string.
39
49
@@ -140,7 +150,7 @@ class GRPOConfig(TrainingArguments):
140
150
"""
141
151
142
152
# Parameters that control the model and reference model
143
- model_init_kwargs : Optional [dict ] = field (
153
+ model_init_kwargs : Optional [Union [ dict , str ] ] = field (
144
154
default = None ,
145
155
metadata = {
146
156
"help" : "Keyword arguments for `transformers.AutoModelForCausalLM.from_pretrained`, used when the `model` "
@@ -338,3 +348,16 @@ class GRPOConfig(TrainingArguments):
338
348
"installed, it prints the sample. If `wandb` logging is enabled, it logs it to `wandb`."
339
349
},
340
350
)
351
+
352
def __post_init__(self):
    """Decode dict-valued options supplied as JSON strings, then run the parent hook.

    Every attribute named in `_VALID_DICT_FIELDS` is inspected. When its value
    is a `str` that begins with ``{``, the string is parsed with `json.loads`,
    the result is passed through `_convert_str_dict`, and the attribute is
    replaced by the resulting dict. Any other value is left untouched.
    `super().__post_init__()` is called afterwards.

    Raises:
        json.JSONDecodeError: If a string starting with ``{`` is not valid JSON.
    """
    for attr_name in _VALID_DICT_FIELDS:
        raw_value = getattr(self, attr_name)
        # Only a string that opens with a brace is treated as a serialized
        # `dict`; any other string is kept as-is (it is likely a filename,
        # where the field supports one).
        if not (isinstance(raw_value, str) and raw_value.startswith("{")):
            continue
        # `_convert_str_dict` converts str values to types where applicable.
        setattr(self, attr_name, _convert_str_dict(json.loads(raw_value)))
    super().__post_init__()
0 commit comments