Skip to content

Commit 5e703eb

Browse files
committed
make _VALID_DICT_FIELDS as a class attribute
1 parent 348f328 commit 5e703eb

File tree

1 file changed

+13
-14
lines changed

1 file changed

+13
-14
lines changed

src/transformers/training_args.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -188,19 +188,6 @@ class OptimizerNames(ExplicitEnum):
188188
APOLLO_ADAMW_LAYERWISE = "apollo_adamw_layerwise"
189189

190190

191-
# Sometimes users will pass in a `str` repr of a dict in the CLI
192-
# We need to track what fields those can be. Each time a new arg
193-
# has a dict type, it must be added to this list.
194-
# Important: These should be typed with Optional[Union[dict,str,...]]
195-
_VALID_DICT_FIELDS = [
196-
"accelerator_config",
197-
"fsdp_config",
198-
"deepspeed",
199-
"gradient_checkpointing_kwargs",
200-
"lr_scheduler_kwargs",
201-
]
202-
203-
204191
def _convert_str_dict(passed_value: dict):
205192
"Safely checks that a passed value is a dictionary and converts any string values to their appropriate types."
206193
for key, value in passed_value.items():
@@ -814,6 +801,18 @@ class TrainingArguments:
814801
https://github.com/huggingface/transformers/issues/34242
815802
"""
816803

804+
# Sometimes users will pass in a `str` repr of a dict in the CLI
805+
# We need to track what fields those can be. Each time a new arg
806+
# has a dict type, it must be added to this list.
807+
# Important: These should be typed with Optional[Union[dict,str,...]]
808+
_VALID_DICT_FIELDS = [
809+
"accelerator_config",
810+
"fsdp_config",
811+
"deepspeed",
812+
"gradient_checkpointing_kwargs",
813+
"lr_scheduler_kwargs",
814+
]
815+
817816
framework = "pt"
818817
output_dir: Optional[str] = field(
819818
default=None,
@@ -1561,7 +1560,7 @@ def __post_init__(self):
15611560
)
15621561

15631562
# Parse in args that could be `dict` sent in from the CLI as a string
1564-
for field in _VALID_DICT_FIELDS:
1563+
for field in self._VALID_DICT_FIELDS:
15651564
passed_value = getattr(self, field)
15661565
# We only want to do this if the str starts with a bracket to indicate a `dict`
15671566
# else its likely a filename if supported

0 commit comments

Comments
 (0)