@@ -188,19 +188,6 @@ class OptimizerNames(ExplicitEnum):
188
188
APOLLO_ADAMW_LAYERWISE = "apollo_adamw_layerwise"
189
189
190
190
191
- # Sometimes users will pass in a `str` repr of a dict in the CLI
192
- # We need to track what fields those can be. Each time a new arg
193
- # has a dict type, it must be added to this list.
194
- # Important: These should be typed with Optional[Union[dict,str,...]]
195
- _VALID_DICT_FIELDS = [
196
- "accelerator_config" ,
197
- "fsdp_config" ,
198
- "deepspeed" ,
199
- "gradient_checkpointing_kwargs" ,
200
- "lr_scheduler_kwargs" ,
201
- ]
202
-
203
-
204
191
def _convert_str_dict (passed_value : dict ):
205
192
"Safely checks that a passed value is a dictionary and converts any string values to their appropriate types."
206
193
for key , value in passed_value .items ():
@@ -814,6 +801,18 @@ class TrainingArguments:
814
801
https://github.com/huggingface/transformers/issues/34242
815
802
"""
816
803
804
+ # Sometimes users will pass in a `str` repr of a dict in the CLI
805
+ # We need to track what fields those can be. Each time a new arg
806
+ # has a dict type, it must be added to this list.
807
+ # Important: These should be typed with Optional[Union[dict,str,...]]
808
+ _VALID_DICT_FIELDS = [
809
+ "accelerator_config" ,
810
+ "fsdp_config" ,
811
+ "deepspeed" ,
812
+ "gradient_checkpointing_kwargs" ,
813
+ "lr_scheduler_kwargs" ,
814
+ ]
815
+
817
816
framework = "pt"
818
817
output_dir : Optional [str ] = field (
819
818
default = None ,
@@ -1561,7 +1560,7 @@ def __post_init__(self):
1561
1560
)
1562
1561
1563
1562
# Parse in args that could be `dict` sent in from the CLI as a string
1564
- for field in _VALID_DICT_FIELDS :
1563
+ for field in self . _VALID_DICT_FIELDS :
1565
1564
passed_value = getattr (self , field )
1566
1565
# We only want to do this if the str starts with a bracket to indicate a `dict`
1567
1566
# else its likely a filename if supported
0 commit comments