feature(nyz&dcy): add LLM/VLM RLHF loss (PPO/GRPO/RLOO) #857
base: main
Changes from 6 commits: 2a51392, 2e49437, 8d34eac, 71190d4, 2cbd9fb, 17a7a71, 5358b8d
@@ -0,0 +1,77 @@
from typing import Tuple
from collections import namedtuple
import torch

grpo_policy_data = namedtuple('grpo_policy_data', ['logit_new', 'logit_old', 'logit_ref', 'action', 'adv', 'weight'])


def grpo_policy_error(
        data: namedtuple,
        clip_ratio: float = 0.2,
        beta: float = 0.1,  # Weight coefficient for KL divergence
) -> Tuple[namedtuple, namedtuple]:
    """Calculate the policy loss for GRPO
    Args:
        data (grpo_policy_data): Data containing the following fields:

[Review comment] polish comment formats

            - logit_new: Current policy logits [B, L, V]
            - logit_old: Old policy logits [B, L, V]
            - logit_ref: Reference policy logits [B, L, V]

[Review comment] B, S, N

            - action: Actions taken [B, L]
            - adv: Advantage values [B]

[Review comment] [B, ]

            - weight: Attention mask [B, L]

[Review comment] use the extra

        clip_ratio (float): PPO clipping ratio, default 0.2
        beta (float): Weight coefficient for KL divergence, default 0.1

[Review comment] add a period to the end of each sentence.

    Returns:
        Tuple[namedtuple, namedtuple]:
            - First namedtuple contains policy_loss
            - Second namedtuple contains additional metrics
    """
    # Calculate log probabilities for each token
    log_prob_new = torch.log_softmax(data.logit_new, dim=-1)
    log_prob_old = torch.log_softmax(data.logit_old, dim=-1)
    log_prob_ref = torch.log_softmax(data.logit_ref, dim=-1)

    # Get log probabilities for selected actions
    action = data.action.unsqueeze(-1)  # [B, L, 1]
    per_token_logps = torch.gather(log_prob_new, -1, action).squeeze(-1)
    per_token_old_logps = torch.gather(log_prob_old, -1, action).squeeze(-1)
    per_token_ref_logps = torch.gather(log_prob_ref, -1, action).squeeze(-1)

    # Calculate KL divergence: exp(q-p) - (q-p) - 1,
    # where p is current policy and q is reference policy
    per_token_kl = (torch.exp(per_token_ref_logps - per_token_logps) - (per_token_ref_logps - per_token_logps) - 1)

    # Calculate policy ratio
    ratio = torch.exp(per_token_logps - per_token_old_logps)
    ratio_clipped = torch.clamp(ratio, 1 - clip_ratio, 1 + clip_ratio)

    # Calculate loss for each token
    advantages = data.adv.unsqueeze(1)  # [B, 1]
    per_token_loss_unclipped = ratio * advantages
    per_token_loss_clipped = ratio_clipped * advantages
    per_token_loss = -torch.min(per_token_loss_unclipped, per_token_loss_clipped)

    # Add KL divergence regularization term
    per_token_loss = per_token_loss + beta * per_token_kl

    # Calculate average loss using weight mask
    weight = data.weight if data.weight is not None else torch.ones_like(per_token_loss)
    loss = ((per_token_loss * weight).sum(dim=1) / weight.sum(dim=1)).mean()

    # Calculate additional metrics
    metrics = {
        'mean_kl': ((per_token_kl * weight).sum(dim=1) / weight.sum(dim=1)).mean().item(),
        'mean_ratio': ((ratio * weight).sum(dim=1) / weight.sum(dim=1)).mean().item(),
        'mean_clipped': (
            (ratio > (1 + clip_ratio)).float().mean().item() + (ratio < (1 - clip_ratio)).float().mean().item()
        ),
    }

    # Create return namedtuples
    loss_info = namedtuple('LossInfo', ['policy_loss'])(policy_loss=loss)
    metric_info = namedtuple('MetricInfo', list(metrics.keys()))(**metrics)

    return loss_info, metric_info
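
As a reading aid, the objective that grpo_policy_error computes can be restated compactly as below. The notation is chosen here, not taken from the PR: ε stands for clip_ratio, β for beta, A_b for data.adv, and w_{b,t} for the weight mask.

    \mathcal{L} = \frac{1}{B} \sum_{b=1}^{B} \frac{\sum_{t=1}^{L} w_{b,t}\, \ell_{b,t}}{\sum_{t=1}^{L} w_{b,t}}, \qquad
    \ell_{b,t} = -\min\!\big(r_{b,t} A_b,\ \operatorname{clip}(r_{b,t},\, 1-\epsilon,\, 1+\epsilon)\, A_b\big) + \beta\, \hat{d}_{b,t},

where r_{b,t} = exp(log π_θ(a_{b,t}) − log π_old(a_{b,t})) is the per-token importance ratio and

    \hat{d}_{b,t} = e^{\,q_{b,t} - p_{b,t}} - (q_{b,t} - p_{b,t}) - 1, \qquad
    p_{b,t} = \log \pi_\theta(a_{b,t}), \quad q_{b,t} = \log \pi_{\mathrm{ref}}(a_{b,t}),

is the non-negative per-token KL estimator implemented by the "exp(q-p) - (q-p) - 1" line above.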
@@ -0,0 +1,69 @@
from typing import Tuple
from collections import namedtuple
import torch

rloo_policy_data = namedtuple('rloo_policy_data', ['logit_new', 'logit_old', 'action', 'reward', 'weight'])


def rloo_policy_error(
        data: namedtuple,
        clip_ratio: float = 0.2,
) -> Tuple[namedtuple, namedtuple]:
    """Calculate the policy loss for RLOO

    Args:
        data (rloo_policy_data): Data containing the following fields:
            - logit_new: Current policy logits [B, L, V]
            - logit_old: Old policy logits [B, L, V]
            - action: Actions taken [B, L]
            - reward: Reward values [B]
            - weight: Attention mask [B, L]
        clip_ratio (float): PPO clipping ratio, default 0.2

    Returns:
        Tuple[namedtuple, namedtuple]:
            - First namedtuple contains policy_loss
            - Second namedtuple contains additional metrics
    """
    # Calculate advantage of each action
    rloo_k = data.reward.size(0)
    baseline = (data.reward.sum(0) - data.reward) / (rloo_k - 1)
    adv = data.reward - baseline
    adv = adv.flatten()

    # Calculate log probabilities for each token
    log_prob_new = torch.log_softmax(data.logit_new, dim=-1)
    log_prob_old = torch.log_softmax(data.logit_old, dim=-1)

    # Get log probabilities for selected actions
    action = data.action.unsqueeze(-1)  # [B, L, 1]
    per_token_logps = torch.gather(log_prob_new, -1, action).squeeze(-1)
    per_token_old_logps = torch.gather(log_prob_old, -1, action).squeeze(-1)

    # Calculate policy ratio
    ratio = torch.exp(per_token_logps - per_token_old_logps)
    ratio_clipped = torch.clamp(ratio, 1 - clip_ratio, 1 + clip_ratio)

    # Calculate loss for each token
    advantages = adv.unsqueeze(1)  # [B, 1]
    per_token_loss_unclipped = ratio * advantages
    per_token_loss_clipped = ratio_clipped * advantages
    per_token_loss = -torch.min(per_token_loss_unclipped, per_token_loss_clipped)

    # Calculate average loss using weight mask
    weight = data.weight if data.weight is not None else torch.ones_like(per_token_loss)
    loss = ((per_token_loss * weight).sum(dim=1) / weight.sum(dim=1)).mean()

    # Calculate additional metrics
    metrics = {
        'mean_ratio': ((ratio * weight).sum(dim=1) / weight.sum(dim=1)).mean().item(),
        'mean_clipped': (
            (ratio > (1 + clip_ratio)).float().mean().item() + (ratio < (1 - clip_ratio)).float().mean().item()
        ),
        'mean_advantage': advantages.mean().item(),
    }

    # Create return namedtuples
    loss_info = namedtuple('LossInfo', ['policy_loss'])(policy_loss=loss)

[Review comment] if there is just one field, you can directly return it rather than use namedtuple

    metric_info = namedtuple('MetricInfo', list(metrics.keys()))(**metrics)

[Review comment] you can define namedtuple at the beginning of this file

    return loss_info, metric_info
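
The leave-one-out baseline at the top of rloo_policy_error corresponds to the following advantage. This is only a restatement of the code, with k = rloo_k (the batch size) and r_i = data.reward[i]:

    A_i = r_i - \frac{1}{k-1} \sum_{j \neq i} r_j = r_i - \frac{\big(\sum_{j} r_j\big) - r_i}{k-1},

i.e. each sample's reward is compared against the mean reward of the other k−1 samples in the same group. The rest of the function is the same clipped surrogate as in the GRPO file, without the β-weighted KL term.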
@@ -0,0 +1,112 @@
import pytest
import numpy as np
import torch
# Import GRPO related functions
from ding.rl_utils.grpo import grpo_policy_data, grpo_policy_error


@pytest.fixture
def batch_size():
    return 4


@pytest.fixture
def seq_length():
    return 8


@pytest.fixture
def dictionary_num():
    return 1000


@pytest.mark.unittest
def test_grpo_policy_loss_with_mask(batch_size: int = 4, seq_length: int = 8, vocab_size: int = 1000):
    """Test GRPO policy loss calculation with mask"""
    # 1. Create test data
    logit_new = torch.randn(batch_size, seq_length, vocab_size).requires_grad_(True)
    logit_old = logit_new + torch.randn_like(logit_new) * 0.1
    logit_ref = logit_new + torch.randn_like(logit_new) * 0.2
    action = torch.randint(0, vocab_size, (batch_size, seq_length))
    adv = torch.randn(batch_size)
    weight = torch.ones(batch_size, seq_length)
    weight[:, -2:] = 0

    # 2. Create grpo_policy_data instance
    data = grpo_policy_data(
        logit_new=logit_new,  # Current policy output
        logit_old=logit_old,  # Old policy output
        logit_ref=logit_ref,  # Reference policy output
        action=action,  # Sampled tokens
        adv=adv,  # Advantage values
        weight=weight  # Attention mask
    )

    # 3. Calculate GRPO loss
    loss, info = grpo_policy_error(
        data=data,
        clip_ratio=0.2,  # PPO clipping ratio
        beta=0.1  # KL divergence weight
    )

    # 4. Verify outputs
    assert isinstance(loss.policy_loss, torch.Tensor)
    assert loss.policy_loss.shape == torch.Size([])  # Ensure scalar output
    assert not torch.isnan(loss.policy_loss)
    assert not torch.isinf(loss.policy_loss)

    # 5. Test gradients
    assert logit_new.grad is None
    loss.policy_loss.backward()
    assert isinstance(logit_new.grad, torch.Tensor)

    # 6. Verify metrics
    assert 'mean_kl' in info._asdict()
    assert 'mean_ratio' in info._asdict()
    assert 'mean_clipped' in info._asdict()
    assert all([np.isscalar(v) for v in info._asdict().values()])


@pytest.mark.unittest
def test_grpo_policy_loss_without_mask(batch_size: int = 4, seq_length: int = 8, vocab_size: int = 1000):
    """Test GRPO policy loss calculation without mask"""
    # 1. Create test data
    logit_new = torch.randn(batch_size, seq_length, vocab_size).requires_grad_(True)
    logit_old = logit_new + torch.randn_like(logit_new) * 0.1
    logit_ref = logit_new + torch.randn_like(logit_new) * 0.2
    action = torch.randint(0, vocab_size, (batch_size, seq_length))
    adv = torch.randn(batch_size)

    # 2. Create grpo_policy_data instance
    data = grpo_policy_data(
        logit_new=logit_new,  # Current policy output
        logit_old=logit_old,  # Old policy output
        logit_ref=logit_ref,  # Reference policy output
        action=action,  # Sampled tokens
        adv=adv,  # Advantage values
        weight=None  # No mask
    )

    # 3. Calculate GRPO loss
    loss, info = grpo_policy_error(
        data=data,
        clip_ratio=0.2,  # PPO clipping ratio
        beta=0.1  # KL divergence weight
    )

    # 4. Verify outputs
    assert isinstance(loss.policy_loss, torch.Tensor)
    assert loss.policy_loss.shape == torch.Size([])  # Ensure scalar output
    assert not torch.isnan(loss.policy_loss)
    assert not torch.isinf(loss.policy_loss)

    # 5. Test gradients
    assert logit_new.grad is None
    loss.policy_loss.backward()
    assert isinstance(logit_new.grad, torch.Tensor)

    # 6. Verify metrics
    assert 'mean_kl' in info._asdict()
    assert 'mean_ratio' in info._asdict()
    assert 'mean_clipped' in info._asdict()
    assert all([np.isscalar(v) for v in info._asdict().values()])
[Review comment] add paper link
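
At this commit the unit tests only cover grpo_policy_error. For illustration, a companion test for rloo_policy_error could look like the sketch below; it is not part of the PR, and the import path ding.rl_utils.rloo is an assumption made by analogy with the GRPO import above.

# Hypothetical test sketch, not part of the PR; import path assumed by analogy.
import numpy as np
import pytest
import torch

from ding.rl_utils.rloo import rloo_policy_data, rloo_policy_error  # assumed module path


@pytest.mark.unittest
def test_rloo_policy_loss_with_mask(batch_size: int = 4, seq_length: int = 8, vocab_size: int = 1000):
    """Test RLOO policy loss calculation with mask (sketch)."""
    # 1. Create test data; the batch dimension doubles as the RLOO group size k
    logit_new = torch.randn(batch_size, seq_length, vocab_size).requires_grad_(True)
    logit_old = logit_new + torch.randn_like(logit_new) * 0.1
    action = torch.randint(0, vocab_size, (batch_size, seq_length))
    reward = torch.randn(batch_size)
    weight = torch.ones(batch_size, seq_length)
    weight[:, -2:] = 0  # mask out the last two tokens

    # 2. Compute the loss
    data = rloo_policy_data(logit_new=logit_new, logit_old=logit_old, action=action, reward=reward, weight=weight)
    loss, info = rloo_policy_error(data, clip_ratio=0.2)

    # 3. Verify outputs and gradients
    assert loss.policy_loss.shape == torch.Size([])  # scalar output
    assert not torch.isnan(loss.policy_loss) and not torch.isinf(loss.policy_loss)
    loss.policy_loss.backward()
    assert isinstance(logit_new.grad, torch.Tensor)

    # 4. Verify metrics are scalars
    assert all(np.isscalar(v) for v in info._asdict().values())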