-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy path: train_trl.py
45 lines (35 loc) · 1.13 KB
/
train_trl.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
#!/usr/bin/env python
from transformers import AutoTokenizer
from trl import GRPOConfig, GRPOTrainer, ModelConfig, TrlParser
from prep_data import get_gsm8k_questions
from reward import (
correctness_reward_func,
int_reward_func,
strict_format_reward_func,
soft_format_reward_func,
xmlcount_reward_func,
)
def main(training_args, model_args):
    """Run GRPO training on the GSM8K dataset and save the resulting model.

    Args:
        training_args: A ``GRPOConfig`` holding trainer hyperparameters,
            including ``output_dir`` where the final model is written.
        model_args: A ``ModelConfig``; only ``model_name_or_path`` is read,
            to load the tokenizer and to tell ``GRPOTrainer`` which model
            to instantiate.
    """
    data = get_gsm8k_questions()
    # Reward functions are combined by GRPOTrainer; order matches the
    # reward columns reported in training logs.
    reward_funcs = [
        correctness_reward_func,
        int_reward_func,
        strict_format_reward_func,
        soft_format_reward_func,
        xmlcount_reward_func,
    ]
    tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
    # Only fall back to EOS when the tokenizer has no pad token; overwriting
    # an existing pad token with EOS could mask genuine EOS tokens as padding.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    trainer = GRPOTrainer(
        # Passing the name lets GRPOTrainer handle model loading itself.
        model=model_args.model_name_or_path,
        processing_class=tokenizer,
        reward_funcs=reward_funcs,
        args=training_args,
        train_dataset=data,
    )
    trainer.train()
    trainer.save_model(training_args.output_dir)
if __name__ == "__main__":
    # Parse CLI flags and/or a YAML config into the two dataclasses,
    # then hand them straight to the training entry point.
    cli = TrlParser((GRPOConfig, ModelConfig))
    grpo_cfg, model_cfg = cli.parse_args_and_config()
    main(grpo_cfg, model_cfg)