Skip to content

Commit 3e44f00

Browse files
committed
add cli dict extend
1 parent 9f3702f commit 3e44f00

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

trl/trainer/grpo_config.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -11,10 +11,9 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
1514
import warnings
1615
from dataclasses import dataclass, field
17-
from typing import Optional
16+
from typing import Optional, Union
1817

1918
from transformers import TrainingArguments
2019

@@ -34,7 +33,7 @@ class GRPOConfig(TrainingArguments):
3433
Parameters:
3534
> Parameters that control the model and reference model
3635
37-
model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
36+
model_init_kwargs (`str`, `dict[str, Any]` or `None`, *optional*, defaults to `None`):
3837
Keyword arguments for [`~transformers.AutoModelForCausalLM.from_pretrained`], used when the `model`
3938
argument of the [`GRPOTrainer`] is provided as a string.
4039
@@ -137,9 +136,10 @@ class GRPOConfig(TrainingArguments):
137136
num_completions_to_print (`int` or `None`, *optional*, defaults to `None`):
138137
Number of completions to print with `rich`. If `None`, all completions are logged.
139138
"""
139+
_VALID_DICT_FIELDS = TrainingArguments._VALID_DICT_FIELDS + ["model_init_kwargs"]
140140

141141
# Parameters that control the model and reference model
142-
model_init_kwargs: Optional[dict] = field(
142+
model_init_kwargs: Optional[Union[dict, str]] = field(
143143
default=None,
144144
metadata={
145145
"help": "Keyword arguments for `transformers.AutoModelForCausalLM.from_pretrained`, used when the `model` "

0 commit comments

Comments
 (0)