
Minimal-GRPO

Implementation of Group Relative Policy Optimization (GRPO) to fine-tune open language models such as Llama-3.2 and Qwen2.5 on math tasks.

Overview

This project implements a simple, customizable pipeline for fine-tuning language models with the GRPO strategy introduced by DeepSeekMath. Instead of training a separate value network, GRPO scores each sampled response relative to the other responses generated for the same prompt, and uses these group-based relative advantages to guide optimization. The training data consists of synthetic math expressions and their solutions. The framework is flexible and can be extended to other tasks.
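As a concrete illustration of the group-based advantage estimate, here is a minimal sketch (not the repo's exact code): each prompt gets a group of sampled completions, and every completion's reward is normalized against its own group.

```python
import torch

def group_relative_advantages(rewards: torch.Tensor, eps: float = 1e-4) -> torch.Tensor:
    """GRPO-style advantage estimate: normalize each completion's reward
    against the other completions sampled for the same prompt.

    rewards: shape (num_prompts, group_size), one scalar reward per
    sampled completion. No value network is needed: the group mean is
    the baseline, and the group std rescales the signal.
    """
    mean = rewards.mean(dim=1, keepdim=True)
    std = rewards.std(dim=1, keepdim=True)
    return (rewards - mean) / (std + eps)

# Example: 2 prompts, 4 sampled completions each.
rewards = torch.tensor([[1.0, 0.0, 0.0, 1.0],
                        [0.5, 0.5, 0.5, 0.5]])
adv = group_relative_advantages(rewards)
# The second group has zero spread, so all its advantages are 0.
```

These per-completion advantages then weight the token log-probabilities in the clipped policy-gradient loss, typically alongside a KL penalty toward the reference model.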

Requirements

pytorch · transformers · peft · flash-attn · pyyaml · tensorboard

Usage

  1. Edit the config.yml file with the required hyperparameters and the language model path.
  2. Then, run the following command to fine-tune the model:
python3 grpo_train.py
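For orientation, a config.yml for this kind of pipeline might look like the sketch below. The key names are illustrative assumptions, not the repo's actual schema; check the shipped config.yml for the real fields.

```yaml
# Illustrative only -- key names are assumptions, not the repo's actual schema.
model_path: meta-llama/Llama-3.2-1B-Instruct
group_size: 8            # completions sampled per prompt
learning_rate: 1.0e-6
kl_coeff: 0.04           # weight of the KL penalty against the reference model
max_new_tokens: 256
```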

If needed, replace the existing dataset with your custom data.
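Extending the pipeline to a custom task mostly means swapping the dataset and the reward. A minimal sketch of a rule-based reward matching the `<think>`/`<answer>` response format shown under Results (this exact function is an illustration, not the repo's reward):

```python
import re

def math_reward(response: str, target: str) -> float:
    """Rule-based reward: small bonuses for following the <think>/<answer>
    format, plus a larger bonus when the extracted answer matches the target."""
    reward = 0.0
    if re.search(r"<think>.*?</think>", response, re.DOTALL):
        reward += 0.25                      # followed the reasoning format
    match = re.search(r"<answer>(.*?)</answer>", response, re.DOTALL)
    if match:
        reward += 0.25                      # produced a parseable answer
        if match.group(1).strip() == target.strip():
            reward += 1.0                   # answer is correct
    return reward

good = "<think>1. 47-2=45 2. 3*45=135 3. 135+2=137 4. 78-137=-59</think><answer> -59 </answer>"
math_reward(good, "-59")   # -> 1.5
```

A reward shaped like this (format bonus plus correctness bonus) gives the group-relative advantages something to discriminate on even before the model produces correct answers.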

Results

The examples below compare model responses before and after GRPO fine-tuning:

Model: Llama-3.2-1B-Instruct

Prompt: Expand 3*(x+1)*(x+2)

Response before fine-tuning:
<think>
1. Multiply the first two terms: 3*(x+1) = 3x + 3
2. Multiply the result by the third term: (3x + 3)*(x+2)
</think>
<answer> 3x^2 + 3x + 6x + 6 </answer>

Response after GRPO fine-tuning:
<think>
1. To expand the expression 3*(x+1)*(x+2), we need to apply the distributive property to expand the expression.
2. The distributive property states that a(b+c) = ab + ac.
3. So, we can expand the expression as 3*(x+1)*(x+2) = 3*(x*x + x*2 + 1*x + 1*2).
4. Now, we can simplify the expression by multiplying the terms.
5. 3*(x*x + x*2 + 1*x + 1*2) = 3x^2 + 6x + 3x + 2.
6. Combine like terms.
7. 3x^2 + 6x + 3x + 2 = 3x^2 + 9x + 2.
8. The final expanded expression is 3x^2 + 9x + 2.
</think>
<answer> 3x^2 + 9x + 2 </answer>

Prompt: Solve 78-(3*(47-2)+2)

Response before fine-tuning:
<think>
1. First, we need to evaluate the expression inside the parentheses: 47-2
2. 47 - 2 = 45
3. Now, we multiply 3 by the result: 3 * 45 = 135
4. Finally, we subtract 135 from 78: 78 - 135 = -57
</think>
<answer> The final answer is -57 </answer>

Response after GRPO fine-tuning:
<think>
1. Start by evaluating the expression inside the parentheses: 47-2 = 45
2. Now, multiply 3 by the result: 3 * 45 = 135
3. Add 2 to the result: 135 + 2 = 137
4. Finally, subtract 137 from 78: 78 - 137 = -59
</think>
<answer> -59 </answer>

[Figure] Rewards accumulated per step during training

Contributing

Open an issue or submit a pull request if you have any suggestions or improvements.

Acknowledgments

  • This project utilizes models and tools made available by Hugging Face.
