11
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
-
15
14
import warnings
16
15
from dataclasses import dataclass , field
17
- from typing import Optional
16
+ from typing import Optional , Union
18
17
19
18
from transformers import TrainingArguments
20
19
@@ -34,7 +33,7 @@ class GRPOConfig(TrainingArguments):
34
33
Parameters:
35
34
> Parameters that control the model and reference model
36
35
37
- model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
36
+ model_init_kwargs (`str`, `dict[str, Any]` or `None`, *optional*, defaults to `None`):
38
37
Keyword arguments for [`~transformers.AutoModelForCausalLM.from_pretrained`], used when the `model`
39
38
argument of the [`GRPOTrainer`] is provided as a string.
40
39
@@ -137,9 +136,10 @@ class GRPOConfig(TrainingArguments):
137
136
num_completions_to_print (`int` or `None`, *optional*, defaults to `None`):
138
137
Number of completions to print with `rich`. If `None`, all completions are logged.
139
138
"""
139
+ _VALID_DICT_FIELDS = TrainingArguments ._VALID_DICT_FIELDS + ["model_init_kwargs" ]
140
140
141
141
# Parameters that control the model and reference model
142
- model_init_kwargs : Optional [dict ] = field (
142
+ model_init_kwargs : Optional [Union [ dict , str ] ] = field (
143
143
default = None ,
144
144
metadata = {
145
145
"help" : "Keyword arguments for `transformers.AutoModelForCausalLM.from_pretrained`, used when the `model` "
0 commit comments