Skip to content

Commit 0743c5a

Browse files
committed
add cli dict extend
1 parent fc4dae2 commit 0743c5a

File tree

1 file changed

+27
-4
lines changed

1 file changed

+27
-4
lines changed

trl/trainer/grpo_config.py

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,21 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
14+
import json
1515
from dataclasses import dataclass, field
16-
from typing import Optional
16+
from typing import Optional, Union
1717

1818
from transformers import TrainingArguments
19+
from transformers.training_args import _convert_str_dict
20+
21+
22+
# Sometimes users will pass in a `str` repr of a dict in the CLI
23+
# We need to track what fields those can be. Each time a new arg
24+
# has a dict type, it must be added to this list.
25+
# Important: These should be typed with Optional[Union[dict,str,...]]
26+
_VALID_DICT_FIELDS = [
27+
"model_init_kwargs",
28+
]
1929

2030

2131
@dataclass
@@ -33,7 +43,7 @@ class GRPOConfig(TrainingArguments):
3343
Parameters:
3444
> Parameters that control the model and reference model
3545
36-
model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
46+
model_init_kwargs (`str`, `dict[str, Any]` or `None`, *optional*, defaults to `None`):
3747
Keyword arguments for [`~transformers.AutoModelForCausalLM.from_pretrained`], used when the `model`
3848
argument of the [`GRPOTrainer`] is provided as a string.
3949
@@ -140,7 +150,7 @@ class GRPOConfig(TrainingArguments):
140150
"""
141151

142152
# Parameters that control the model and reference model
143-
model_init_kwargs: Optional[dict] = field(
153+
model_init_kwargs: Optional[Union[dict, str]] = field(
144154
default=None,
145155
metadata={
146156
"help": "Keyword arguments for `transformers.AutoModelForCausalLM.from_pretrained`, used when the `model` "
@@ -338,3 +348,16 @@ class GRPOConfig(TrainingArguments):
338348
"installed, it prints the sample. If `wandb` logging is enabled, it logs it to `wandb`."
339349
},
340350
)
351+
352+
def __post_init__(self):
    """Decode dict-typed fields given as JSON strings, then run the parent hook.

    Every attribute named in `_VALID_DICT_FIELDS` is inspected. A value that
    is a `str` starting with "{" is parsed with `json.loads`, passed through
    `_convert_str_dict` to convert string values to types where applicable,
    and written back onto the instance. Any other value is left untouched.
    `super().__post_init__()` is called last.
    """
    for dict_field_name in _VALID_DICT_FIELDS:
        raw_value = getattr(self, dict_field_name)
        # Only a string that starts with an opening brace is taken to
        # indicate a `dict`; anything else is likely a filename (where
        # supported) or already a parsed value, so it is skipped.
        if not isinstance(raw_value, str) or not raw_value.startswith("{"):
            continue
        parsed_value = _convert_str_dict(json.loads(raw_value))
        setattr(self, dict_field_name, parsed_value)
    super().__post_init__()

0 commit comments

Comments
 (0)