"""
Train a model using GRPO (Group Relative Policy Optimization).
"""
from unsloth import FastLanguageModel, is_bfloat16_supported
import src.UnslothGRPOTrainerTemp as UnslothGRPOTrainerTemp
from config import (
    MODEL_CONFIG,
    MODEL_NAME,
    OUTPUT_DIR,
    TRAINING_CONFIG,
    get_sampling_params,
    init_training_dirs,
    logger,
    update_log_path,
)
from src.agent import Agent

# Import reward functions
from src.rewards import (
    build_reward_correctness_fn,
    reward_em_chunk,
    reward_format,
    reward_retry,
    reward_search_diversity,
    reward_search_strategy,
)
from src.search_module import get_qa_dataset
from src.tokenizer_adapter import LlamaTokenizerAdapter, QwenTokenizerAdapter, R1DistilTokenizerAdapter
# Initialize training directories
paths = init_training_dirs()
# Update logger to use the training directory
update_log_path(paths["log_dir"])
logger.info(f"Training output directory: {paths['output_dir']}")
logger.info(f"Logs are being saved to both ./logs and {paths['log_dir']}")
# Initialize model and tokenizer
logger.info(f"Initializing model {MODEL_NAME}")
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=MODEL_NAME,
    max_seq_length=MODEL_CONFIG["max_seq_length"],
    load_in_4bit=True,  # False for LoRA 16bit
    fast_inference=True,  # Enable vLLM fast inference
    max_lora_rank=MODEL_CONFIG["lora_rank"],
    gpu_memory_utilization=MODEL_CONFIG["gpu_memory_utilization"],
)
# Setup LoRA
logger.info("Setting up LoRA adapter")
model = FastLanguageModel.get_peft_model(
    model,
    r=MODEL_CONFIG["lora_rank"],
    target_modules=MODEL_CONFIG["target_modules"],
    lora_alpha=MODEL_CONFIG["lora_rank"],
    use_gradient_checkpointing=True,  # Enable long context finetuning
    random_state=3407,
)
# Load datasets
logger.info("Loading datasets")
train_dataset, test_dataset = get_qa_dataset(randomize=False, test_size=0.1, seed=42)
logger.info(f"Loaded {len(train_dataset)} training examples and {len(test_dataset)} test examples")
# Setup training arguments
logger.info("Setting up training arguments")
training_args = UnslothGRPOTrainerTemp.UnslothGRPOConfig(
    use_vllm=True,  # use vLLM for fast inference!
    use_agentic_generate=True,  # use agentic generation
    **TRAINING_CONFIG,
    bf16=is_bfloat16_supported(),
    fp16=not is_bfloat16_supported(),
    output_dir=OUTPUT_DIR,
    reward_weights=[4.0, 2.0, 1.0, 1.0, 1.0, 1.0],
)
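# Note: the six reward_weights correspond positionally to the six reward_funcs
# passed to the trainer below (correctness, format, retry, em_chunk,
# search_strategy, search_diversity).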
# Setup model generation functions
def agentic_generate(
    prompts: list,
    generate_fn,
    max_generations: int = 32,
    max_new_tokens: int = 4096 * 2,
):
    # Create agent with appropriate adapter based on tokenizer
    tokenizer_name = tokenizer.name_or_path.lower()
    if "deepseek-ai/deepseek-r1-distill" in tokenizer_name:
        adapter = R1DistilTokenizerAdapter()
    elif "llama" in tokenizer_name:
        adapter = LlamaTokenizerAdapter()
    elif "qwen" in tokenizer_name:
        adapter = QwenTokenizerAdapter()
    else:
        raise ValueError(f"Unsupported tokenizer: {tokenizer_name}")
    agent = Agent(adapter)
    return agent.run_agent(generate_fn, tokenizer, prompts, max_generations, max_new_tokens=max_new_tokens)
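# Expose the agentic rollout as a method on the model; the patched trainer
# (UnslothGRPOTrainerTemp) is expected to use this hook when use_agentic_generate=True.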
model.agentic_generate = agentic_generate
# Setup verifier
logger.info("Setting up verifier")
verifier_sampling_params = get_sampling_params(temperature=0.1)
def verifier_generate_fn(inputs):
    return model.fast_generate(
        inputs,
        sampling_params=verifier_sampling_params,
    )
# Setup trainer
logger.info("Initializing trainer")
trainer = UnslothGRPOTrainerTemp.UnslothGRPOTrainer(
    model=model,
    processing_class=tokenizer,
    reward_funcs=[
        build_reward_correctness_fn(
            vllm_generate_func=verifier_generate_fn,
            tokenizer=tokenizer,
        ),
        reward_format,
        reward_retry,
        reward_em_chunk,
        reward_search_strategy,
        reward_search_diversity,
    ],
    args=training_args,
    train_dataset=train_dataset,
)
# Train the model
if __name__ == "__main__":
    logger.info("Starting training")
    trainer.train()
    logger.info("Training completed")
    logger.info(f"Model saved to {OUTPUT_DIR}")