Implementation of Group Relative Policy Optimization (GRPO) for fine-tuning open language models such as LlaMa-3.2 and Qwen2.5 on math tasks.
This project implements a simple, customizable pipeline for fine-tuning language models with the GRPO strategy introduced in DeepSeekMath. GRPO samples a group of responses for each prompt and scores each response relative to the others in its group; these group-relative advantage estimates guide the policy update. The training data consists of synthetic math expressions and their solutions, and the framework can be extended to other tasks.
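As a rough illustration of the group-relative advantage idea, the sketch below normalizes the rewards of one group of responses by the group mean and standard deviation. The function name `group_relative_advantages`, the `eps` term, and the example rewards are assumptions made for this illustration; they are not taken from this project's code.

```python
from statistics import mean, pstdev


def group_relative_advantages(rewards, eps=1e-6):
    """Normalize the rewards of one group of sampled responses.

    Each advantage is the response's reward minus the group mean,
    divided by the group standard deviation. `eps` (an assumption of
    this sketch) avoids division by zero when all rewards are equal.
    """
    mu = mean(rewards)
    sigma = pstdev(rewards)
    return [(r - mu) / (sigma + eps) for r in rewards]


# Hypothetical rewards for four responses sampled from the same prompt.
rewards = [1.0, 0.0, 0.0, 1.0]
advantages = group_relative_advantages(rewards)
```

Responses that score above the group average receive a positive advantage and the rest a negative one, so no separate value model is needed to provide a baseline.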
pytorch · transformers · peft · flash-attn · pyyaml · tensorboard
- Edit the `config.yml` file with the required hyperparameters and language model path.
- Then run the following command to fine-tune the model:
python3 grpo_train.py
If needed, replace the existing dataset with your custom data.
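For readers preparing their own data, here is a minimal sketch of how synthetic arithmetic problems with known solutions could be generated. The record fields `prompt` and `answer`, the helper names, and the number ranges are hypothetical; the dataset format this project actually expects may differ.

```python
import random


def random_expression(rng, depth=2):
    """Return (text, value) for a random integer arithmetic expression."""
    if depth == 0:
        n = rng.randint(1, 99)
        return str(n), n
    left_text, left_val = random_expression(rng, depth - 1)
    right_text, right_val = random_expression(rng, depth - 1)
    op = rng.choice(["+", "-", "*"])
    if op == "+":
        value = left_val + right_val
    elif op == "-":
        value = left_val - right_val
    else:
        value = left_val * right_val
    # Parenthesize every sub-expression so the text is unambiguous.
    return f"({left_text}{op}{right_text})", value


def make_example(rng):
    """Build one record; the field names are assumptions of this sketch."""
    text, value = random_expression(rng)
    return {"prompt": f"Solve {text}", "answer": str(value)}


rng = random.Random(0)  # fixed seed so the data is reproducible
dataset = [make_example(rng) for _ in range(3)]
```

Computing the value while building the expression keeps every solution correct by construction, which matters when the solution is later used to score model responses.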
The table below compares model responses before and after GRPO fine-tuning:
Model: LlaMa-3.2-1B-Instruct
| Prompt | Before Fine-tuning | After GRPO Fine-tuning |
|---|---|---|
| `Expand 3*(x+1)*(x+2)` | `<think> 1. Multiply the first two terms: 3*(x+1) = 3x + 3 2. Multiply the result by the third term: (3x + 3)*(x+2) </think> <answer> 3x^2 + 3x + 6x + 6 </answer>` | `<think> 1. To expand the expression 3*(x+1)(x+2), we need to apply the distributive property to expand the expression. 2. The distributive property states that a(b+c) = ab + ac. 3. So, we can expand the expression as 3(x+1)(x+2) = 3(x*x + x*2 + 1*x + 1*2). 4. Now, we can simplify the expression by multiplying the terms. 5. 3*(x*x + x*2 + 1*x + 1*2) = 3x^2 + 6x + 3x + 2. 6. Combine like terms. 7. 3x^2 + 6x + 3x + 2 = 3x^2 + 9x + 2. 8. The final expanded expression is 3x^2 + 9x + 2. </think> <answer> 3x^2 + 9x + 2 </answer>` |
| `Solve 78-(3*(47-2)+2)` | `<think> 1. First, we need to evaluate the expression inside the parentheses: 47-2 2. 47 - 2 = 45 3. Now, we multiply 3 by the result: 3 * 45 = 135 4. Finally, we subtract 135 from 78: 78 - 135 = -57 </think> <answer> The final answer is -57 </answer>` | `<think> 1. Start by evaluating the expression inside the parentheses: 47-2 = 45 2. Now, multiply 3 by the result: 3 * 45 = 135 3. Add 2 to the result: 135 + 2 = 137 4. Finally, subtract 137 from 78: 78 - 137 = -59 </think> <answer> -59 </answer>` |
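The responses above follow a `<think>`/`<answer>` layout, which makes them easy to score automatically. The sketch below shows one way such a response could be turned into a scalar reward; the function `score_response`, the regular expression, and the 0.5/1.0 weights are illustrative assumptions and do not describe the reward actually used by this project.

```python
import re

# Matches a reasoning block followed by an answer block.
ANSWER_RE = re.compile(
    r"<think>.*?</think>\s*<answer>(.*?)</answer>", re.DOTALL
)


def score_response(response, expected):
    """Return a scalar reward for one response (weights are illustrative)."""
    match = ANSWER_RE.search(response)
    if match is None:
        return 0.0  # tags missing or malformed
    reward = 0.5  # the response follows the expected format
    answer = match.group(1).strip()
    # Compare with spaces removed so "3x^2+9x" equals "3x^2 + 9x".
    if answer.replace(" ", "") == expected.replace(" ", ""):
        reward += 1.0  # the answer matches the reference exactly
    return reward


good = "<think> 78 - 137 = -59 </think> <answer> -59 </answer>"
wordy = "<think> 78 - 135 = -57 </think> <answer> The final answer is -57 </answer>"
scores = [score_response(good, "-59"), score_response(wordy, "-59")]
```

Under this scheme a well-formed but incorrect response still earns partial credit, so the model is rewarded for adopting the format before it gets answers right.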
Open an issue or submit a pull request if you have any suggestions or improvements.
- This project utilizes models and tools made available by Hugging Face.