Skip to content

Commit bb2d303

Browse files
committed
make _VALID_DICT_FIELDS as a class attribute
1 parent c9d1e52 commit bb2d303

File tree

1 file changed

+13
-14
lines changed

1 file changed

+13
-14
lines changed

src/transformers/training_args.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -189,19 +189,6 @@ class OptimizerNames(ExplicitEnum):
189189
APOLLO_ADAMW_LAYERWISE = "apollo_adamw_layerwise"
190190

191191

192-
# Sometimes users will pass in a `str` repr of a dict in the CLI
193-
# We need to track what fields those can be. Each time a new arg
194-
# has a dict type, it must be added to this list.
195-
# Important: These should be typed with Optional[Union[dict,str,...]]
196-
_VALID_DICT_FIELDS = [
197-
"accelerator_config",
198-
"fsdp_config",
199-
"deepspeed",
200-
"gradient_checkpointing_kwargs",
201-
"lr_scheduler_kwargs",
202-
]
203-
204-
205192
def _convert_str_dict(passed_value: dict):
206193
"Safely checks that a passed value is a dictionary and converts any string values to their appropriate types."
207194
for key, value in passed_value.items():
@@ -823,6 +810,18 @@ class TrainingArguments:
823810
https://github.com/huggingface/transformers/issues/34242
824811
"""
825812

813+
# Sometimes users will pass in a `str` repr of a dict in the CLI
814+
# We need to track what fields those can be. Each time a new arg
815+
# has a dict type, it must be added to this list.
816+
# Important: These should be typed with Optional[Union[dict,str,...]]
817+
_VALID_DICT_FIELDS = [
818+
"accelerator_config",
819+
"fsdp_config",
820+
"deepspeed",
821+
"gradient_checkpointing_kwargs",
822+
"lr_scheduler_kwargs",
823+
]
824+
826825
framework = "pt"
827826
output_dir: Optional[str] = field(
828827
default=None,
@@ -1584,7 +1583,7 @@ def __post_init__(self):
15841583
)
15851584

15861585
# Parse in args that could be `dict` sent in from the CLI as a string
1587-
for field in _VALID_DICT_FIELDS:
1586+
for field in self._VALID_DICT_FIELDS:
15881587
passed_value = getattr(self, field)
15891588
# We only want to do this if the str starts with a bracket to indicate a `dict`
15901589
# else its likely a filename if supported

0 commit comments

Comments
 (0)