@@ -189,19 +189,6 @@ class OptimizerNames(ExplicitEnum):
189
189
APOLLO_ADAMW_LAYERWISE = "apollo_adamw_layerwise"
190
190
191
191
192
- # Sometimes users will pass in a `str` repr of a dict in the CLI
193
- # We need to track what fields those can be. Each time a new arg
194
- # has a dict type, it must be added to this list.
195
- # Important: These should be typed with Optional[Union[dict,str,...]]
196
- _VALID_DICT_FIELDS = [
197
- "accelerator_config" ,
198
- "fsdp_config" ,
199
- "deepspeed" ,
200
- "gradient_checkpointing_kwargs" ,
201
- "lr_scheduler_kwargs" ,
202
- ]
203
-
204
-
205
192
def _convert_str_dict (passed_value : dict ):
206
193
"Safely checks that a passed value is a dictionary and converts any string values to their appropriate types."
207
194
for key , value in passed_value .items ():
@@ -823,6 +810,18 @@ class TrainingArguments:
823
810
https://github.com/huggingface/transformers/issues/34242
824
811
"""
825
812
813
+ # Sometimes users will pass in a `str` repr of a dict in the CLI
814
+ # We need to track what fields those can be. Each time a new arg
815
+ # has a dict type, it must be added to this list.
816
+ # Important: These should be typed with Optional[Union[dict,str,...]]
817
+ _VALID_DICT_FIELDS = [
818
+ "accelerator_config" ,
819
+ "fsdp_config" ,
820
+ "deepspeed" ,
821
+ "gradient_checkpointing_kwargs" ,
822
+ "lr_scheduler_kwargs" ,
823
+ ]
824
+
826
825
framework = "pt"
827
826
output_dir : Optional [str ] = field (
828
827
default = None ,
@@ -1584,7 +1583,7 @@ def __post_init__(self):
1584
1583
)
1585
1584
1586
1585
# Parse in args that could be `dict` sent in from the CLI as a string
1587
- for field in _VALID_DICT_FIELDS :
1586
+ for field in self . _VALID_DICT_FIELDS :
1588
1587
passed_value = getattr (self , field )
1589
1588
# We only want to do this if the str starts with a bracket to indicate a `dict`
1590
1589
# else its likely a filename if supported
0 commit comments