feature(nyz&dcy): add LLM/VLM RLHF loss (PPO/GRPO/RLOO) #857

Open. Wants to merge 7 commits into base: main. Showing changes from 6 commits.
5 changes: 5 additions & 0 deletions .gitignore
@@ -1429,3 +1429,8 @@ collect_demo_data_config.py
events.*

evogym/*
ding/example/*
ding/framework/middleware/tests/wandb/
ding/.style.yapf
ding/format.sh
ding/framework/middleware_v3/
77 changes: 77 additions & 0 deletions ding/rl_utils/grpo.py
@@ -0,0 +1,77 @@
from typing import Tuple
from collections import namedtuple
import torch

grpo_policy_data = namedtuple('grpo_policy_data', ['logit_new', 'logit_old', 'logit_ref', 'action', 'adv', 'weight'])


def grpo_policy_error(
        data: namedtuple,
        clip_ratio: float = 0.2,
        beta: float = 0.1,  # Weight coefficient for KL divergence
) -> Tuple[namedtuple, namedtuple]:
"""Calculate the policy loss for GRPO
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add paper link

Args:
data (grpo_policy_data): Data containing the following fields:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

polish comment formats

- logit_new: Current policy logits [B, L, V]
- logit_old: Old policy logits [B, L, V]
- logit_ref: Reference policy logits [B, L, V]
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

B, S, N

- action: Actions taken [B, L]
- adv: Advantage values [B]
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[B, ]

- weight: Attention mask [B, L]
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use the extra Shapes part

clip_ratio (float): PPO clipping ratio, default 0.2
beta (float): Weight coefficient for KL divergence, default 0.1
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add a period to the end of each sentence.


Returns:
Tuple[namedtuple, namedtuple]:
- First namedtuple contains policy_loss
- Second namedtuple contains additional metrics
"""

    # Calculate log probabilities for each token
    log_prob_new = torch.log_softmax(data.logit_new, dim=-1)
    log_prob_old = torch.log_softmax(data.logit_old, dim=-1)
    log_prob_ref = torch.log_softmax(data.logit_ref, dim=-1)

    # Get log probabilities for selected actions
    action = data.action.unsqueeze(-1)  # [B, L, 1]
    per_token_logps = torch.gather(log_prob_new, -1, action).squeeze(-1)
    per_token_old_logps = torch.gather(log_prob_old, -1, action).squeeze(-1)
    per_token_ref_logps = torch.gather(log_prob_ref, -1, action).squeeze(-1)

    # Calculate KL divergence: exp(q - p) - (q - p) - 1,
    # where p is the current policy and q is the reference policy
    per_token_kl = (torch.exp(per_token_ref_logps - per_token_logps) - (per_token_ref_logps - per_token_logps) - 1)

    # Calculate policy ratio
    ratio = torch.exp(per_token_logps - per_token_old_logps)
    ratio_clipped = torch.clamp(ratio, 1 - clip_ratio, 1 + clip_ratio)

    # Calculate loss for each token
    advantages = data.adv.unsqueeze(1)  # [B, 1]
    per_token_loss_unclipped = ratio * advantages
    per_token_loss_clipped = ratio_clipped * advantages
    per_token_loss = -torch.min(per_token_loss_unclipped, per_token_loss_clipped)

    # Add KL divergence regularization term
    per_token_loss = per_token_loss + beta * per_token_kl

    # Calculate average loss using weight mask
    weight = data.weight if data.weight is not None else torch.ones_like(per_token_loss)
    loss = ((per_token_loss * weight).sum(dim=1) / weight.sum(dim=1)).mean()

    # Calculate additional metrics
    metrics = {
        'mean_kl': ((per_token_kl * weight).sum(dim=1) / weight.sum(dim=1)).mean().item(),
        'mean_ratio': ((ratio * weight).sum(dim=1) / weight.sum(dim=1)).mean().item(),
        'mean_clipped': (
            (ratio > (1 + clip_ratio)).float().mean().item() + (ratio < (1 - clip_ratio)).float().mean().item()
        ),
    }

    # Create return namedtuples
    loss_info = namedtuple('LossInfo', ['policy_loss'])(policy_loss=loss)
    metric_info = namedtuple('MetricInfo', list(metrics.keys()))(**metrics)

    return loss_info, metric_info
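
For reference, the per-token objective that `grpo_policy_error` computes can be summarized as follows. This is a sketch derived from the code above, not text from the PR: epsilon is `clip_ratio`, beta is `beta`, A_b is the sequence-level advantage `data.adv` broadcast over tokens, and w is the weight mask.

    r_{b,t} = \exp\big(\log \pi_\theta(a_{b,t}) - \log \pi_{\mathrm{old}}(a_{b,t})\big), \qquad
    d_{b,t} = \log \pi_{\mathrm{ref}}(a_{b,t}) - \log \pi_\theta(a_{b,t}), \qquad
    k_{b,t} = e^{d_{b,t}} - d_{b,t} - 1

    \mathcal{L} = \frac{1}{B} \sum_{b=1}^{B}
    \frac{\sum_t w_{b,t}\,\big[ -\min\big(r_{b,t} A_b,\ \mathrm{clip}(r_{b,t},\, 1-\varepsilon,\, 1+\varepsilon)\, A_b\big) + \beta\, k_{b,t} \big]}{\sum_t w_{b,t}}

The k term is the non-negative KL estimator exp(d) - d - 1 mentioned in the code comment; it is zero exactly when the current policy matches the reference policy at that token.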
43 changes: 29 additions & 14 deletions ding/rl_utils/ppo.py
@@ -104,17 +104,21 @@
    return ppo_loss(policy_output.policy_loss, value_loss, policy_output.entropy_loss), policy_info


def ppo_policy_error(data: namedtuple,
                     clip_ratio: float = 0.2,
                     dual_clip: Optional[float] = None) -> Tuple[namedtuple, namedtuple]:
    '''
def ppo_policy_error(
        data: namedtuple,
        clip_ratio: float = 0.2,
        dual_clip: Optional[float] = None,
        entropy_bonus: bool = True
) -> Tuple[namedtuple, namedtuple]:
    """
    Overview:
        Get PPO policy loss
        Get PPO policy loss (both for classical RL in control/video games and LLM/VLM RLHF).
    Arguments:
        - data (:obj:`namedtuple`): ppo input data with fieids shown in ``ppo_policy_data``
        - clip_ratio (:obj:`float`): clip value for ratio
        - dual_clip (:obj:`float`): a parameter c mentioned in arXiv:1912.09729 Equ. 5, shoule be in [1, inf),\
            defaults to 5.0, if you don't want to use it, set this parameter to None
        - data (:obj:`namedtuple`): PPO input data with fields shown in ``ppo_policy_data``.
        - clip_ratio (:obj:`float`): Clip value for ratio, defaults to 0.2.
        - dual_clip (:obj:`float`): A parameter c mentioned in arXiv:1912.09729 Eq. 5, should be in [1, inf), \
            defaults to 5.0. If you don't want to use it, set this parameter to None.
        - entropy_bonus (:obj:`bool`): Whether to use the entropy bonus, defaults to True. LLM RLHF usually does not use it.
    Returns:
        - ppo_policy_loss (:obj:`namedtuple`): the ppo policy loss item, all of them are the differentiable 0-dim tensor
        - ppo_info (:obj:`namedtuple`): the ppo optim information for monitoring, all of them are Python scalar
@@ -136,18 +140,29 @@
        >>> weight=torch.ones(3),
        >>> )
        >>> loss, info = ppo_policy_error(data)
    '''

    .. note::
        This function can be extended from `B` to more parallel dimensions, like `(B, S)`, where `S` is the
        sequence length in LLM/VLM.

    .. note::
        For the action mask often used in LLM/VLM, users can set the `weight` to the action mask.
    """
    logit_new, logit_old, action, adv, weight = data
    if weight is None:
        weight = torch.ones_like(adv)
    dist_new = torch.distributions.categorical.Categorical(logits=logit_new)
    dist_old = torch.distributions.categorical.Categorical(logits=logit_old)
    logp_new = dist_new.log_prob(action)
    logp_old = dist_old.log_prob(action)
    dist_new_entropy = dist_new.entropy()
    if dist_new_entropy.shape != weight.shape:
        dist_new_entropy = dist_new.entropy().mean(dim=1)
    entropy_loss = (dist_new_entropy * weight).mean()

    if entropy_bonus:
        dist_new_entropy = dist_new.entropy()
        if dist_new_entropy.shape != weight.shape:  # for the multi-agent rl case
            dist_new_entropy = dist_new.entropy().mean(dim=1)
        entropy_loss = (dist_new_entropy * weight).mean()
    else:
        entropy_loss = torch.tensor(0.0)
    # policy_loss
    ratio = torch.exp(logp_new - logp_old)
    if ratio.shape != adv.shape:
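
To make the new `entropy_bonus` flag and the `(B, S)` note above concrete, here is a minimal usage sketch for the LLM/VLM case. It is an illustration written against the fields unpacked from `ppo_policy_data` in this file; the shapes and values are invented for the example, not taken from the PR.

import torch
from ding.rl_utils.ppo import ppo_policy_data, ppo_policy_error

B, S, V = 2, 16, 100  # batch size, sequence length, vocabulary size (illustrative values)
data = ppo_policy_data(
    logit_new=torch.randn(B, S, V),      # current policy logits
    logit_old=torch.randn(B, S, V),      # rollout (old) policy logits
    action=torch.randint(0, V, (B, S)),  # sampled token ids
    adv=torch.randn(B, S),               # per-token advantages
    weight=torch.ones(B, S),             # action mask over response tokens
)
# LLM RLHF usually disables the entropy bonus, as the new argument's docstring notes.
loss, info = ppo_policy_error(data, clip_ratio=0.2, dual_clip=None, entropy_bonus=False)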
69 changes: 69 additions & 0 deletions ding/rl_utils/rloo.py
@@ -0,0 +1,69 @@
from typing import Tuple
from collections import namedtuple
import torch

rloo_policy_data = namedtuple('rloo_policy_data', ['logit_new', 'logit_old', 'action', 'reward', 'weight'])


def rloo_policy_error(
        data: namedtuple,
        clip_ratio: float = 0.2,
) -> Tuple[namedtuple, namedtuple]:
"""Calculate the policy loss for RLOO

Args:
data (rloo_policy_data): Data containing the following fields:
- logit_new: Current policy logits [B, L, V]
- logit_old: Old policy logits [B, L, V]
- action: Actions taken [B, L]
- reward: Advantage values [B]
- weight: Attention mask [B, L]
clip_ratio (float): PPO clipping ratio, default 0.2

Returns:
Tuple[namedtuple, namedtuple]:
- First namedtuple contains policy_loss
- Second namedtuple contains additional metrics
"""
    # Calculate advantage of each action
    rloo_k = data.reward.size(0)
    baseline = (data.reward.sum(0) - data.reward) / (rloo_k - 1)
    adv = data.reward - baseline
    adv = adv.flatten()

    # Calculate log probabilities for each token
    log_prob_new = torch.log_softmax(data.logit_new, dim=-1)
    log_prob_old = torch.log_softmax(data.logit_old, dim=-1)

    # Get log probabilities for selected actions
    action = data.action.unsqueeze(-1)  # [B, L, 1]
    per_token_logps = torch.gather(log_prob_new, -1, action).squeeze(-1)
    per_token_old_logps = torch.gather(log_prob_old, -1, action).squeeze(-1)

    # Calculate policy ratio
    ratio = torch.exp(per_token_logps - per_token_old_logps)
    ratio_clipped = torch.clamp(ratio, 1 - clip_ratio, 1 + clip_ratio)

    # Calculate loss for each token
    advantages = adv.unsqueeze(1)  # [B, 1]
    per_token_loss_unclipped = ratio * advantages
    per_token_loss_clipped = ratio_clipped * advantages
    per_token_loss = -torch.min(per_token_loss_unclipped, per_token_loss_clipped)

    # Calculate average loss using weight mask
    weight = data.weight if data.weight is not None else torch.ones_like(per_token_loss)
    loss = ((per_token_loss * weight).sum(dim=1) / weight.sum(dim=1)).mean()

    # Calculate additional metrics
    metrics = {
        'mean_ratio': ((ratio * weight).sum(dim=1) / weight.sum(dim=1)).mean().item(),
        'mean_clipped': (
            (ratio > (1 + clip_ratio)).float().mean().item() + (ratio < (1 - clip_ratio)).float().mean().item()
        ),
        'mean_advantage': advantages.mean().item(),
    }

    # Create return namedtuples
    loss_info = namedtuple('LossInfo', ['policy_loss'])(policy_loss=loss)
[Review comment, Member Author]: if there is just one field, you can directly return it rather than use namedtuple
    metric_info = namedtuple('MetricInfo', list(metrics.keys()))(**metrics)
[Review comment, Member Author]: you can define namedtuple at the beginning of this file

    return loss_info, metric_info
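
As a quick sanity check of the leave-one-out baseline computed at the top of `rloo_policy_error`, here is a small worked example; the reward values are invented for illustration and follow the `[B]` shape described in the docstring.

import torch

# Four sampled responses for the same prompt, with invented scalar rewards.
reward = torch.tensor([1.0, 2.0, 3.0, 6.0])
rloo_k = reward.size(0)
# Each sample's baseline is the mean reward of the other k - 1 samples.
baseline = (reward.sum(0) - reward) / (rloo_k - 1)  # [3.6667, 3.3333, 3.0000, 2.0000]
adv = reward - baseline                             # [-2.6667, -1.3333, 0.0000, 4.0000]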
112 changes: 112 additions & 0 deletions ding/rl_utils/tests/test_grpo_rlhf.py
@@ -0,0 +1,112 @@
import pytest
import numpy as np
import torch
# Import GRPO related functions
from ding.rl_utils.grpo import grpo_policy_data, grpo_policy_error


@pytest.fixture
def batch_size():
    return 4



@pytest.fixture
def seq_length():
    return 8



@pytest.fixture
def dictionary_num():
    return 1000



@pytest.mark.unittest
def test_grpo_policy_loss_with_mask(batch_size: int = 4, seq_length: int = 8, vocab_size: int = 1000):
"""Test GRPO policy loss calculation with mask"""
# 1. Create test data
logit_new = (torch.randn(batch_size, seq_length, vocab_size).requires_grad_(True))
logit_old = logit_new + torch.randn_like(logit_new) * 0.1
logit_ref = logit_new + torch.randn_like(logit_new) * 0.2
action = torch.randint(0, vocab_size, (batch_size, seq_length))
adv = torch.randn(batch_size)
weight = torch.ones(batch_size, seq_length)
weight[:, -2:] = 0

# 2. Create grpo_policy_data instance
data = grpo_policy_data(
logit_new=logit_new, # Current policy output
logit_old=logit_old, # Old policy output
logit_ref=logit_ref, # Reference policy output
action=action, # Sampled tokens
adv=adv, # Advantage values
weight=weight # Attention mask
)

# 3. Calculate GRPO loss
loss, info = grpo_policy_error(
data=data,
clip_ratio=0.2, # PPO clipping ratio
beta=0.1 # KL divergence weight
)

# 4. Verify outputs
assert isinstance(loss.policy_loss, torch.Tensor)
assert loss.policy_loss.shape == torch.Size([]) # Ensure scalar output
assert not torch.isnan(loss.policy_loss)
assert not torch.isinf(loss.policy_loss)

# 5. Test gradients
assert logit_new.grad is None
loss.policy_loss.backward()
assert isinstance(logit_new.grad, torch.Tensor)

# 6. Verify metrics
assert 'mean_kl' in info._asdict()
assert 'mean_ratio' in info._asdict()
assert 'mean_clipped' in info._asdict()
assert all([np.isscalar(v) for v in info._asdict().values()])


@pytest.mark.unittest
def test_grpo_policy_loss_without_mask(batch_size: int = 4, seq_length: int = 8, vocab_size: int = 1000):
"""Test GRPO policy loss calculation without mask"""
# 1. Create test data
logit_new = torch.randn(batch_size, seq_length, vocab_size).requires_grad_(True)
logit_old = logit_new + torch.randn_like(logit_new) * 0.1
logit_ref = logit_new + torch.randn_like(logit_new) * 0.2
action = torch.randint(0, vocab_size, (batch_size, seq_length))
adv = torch.randn(batch_size)

# 2. Create grpo_policy_data instance
data = grpo_policy_data(
logit_new=logit_new, # Current policy output
logit_old=logit_old, # Old policy output
logit_ref=logit_ref, # Reference policy output
action=action, # Sampled tokens
adv=adv, # Advantage values
weight=None # No mask
)

# 3. Calculate GRPO loss
loss, info = grpo_policy_error(
data=data,
clip_ratio=0.2, # PPO clipping ratio
beta=0.1 # KL divergence weight
)

# 4. Verify outputs
assert isinstance(loss.policy_loss, torch.Tensor)
assert loss.policy_loss.shape == torch.Size([]) # Ensure scalar output
assert not torch.isnan(loss.policy_loss)
assert not torch.isinf(loss.policy_loss)

# 5. Test gradients
assert logit_new.grad is None
loss.policy_loss.backward()
assert isinstance(logit_new.grad, torch.Tensor)

# 6. Verify metrics
assert 'mean_kl' in info._asdict()
assert 'mean_ratio' in info._asdict()
assert 'mean_clipped' in info._asdict()
assert all([np.isscalar(v) for v in info._asdict().values()])
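
The new tests exercise GRPO only; `rloo_policy_error` has no test in this diff. A minimal sketch of an analogous test, written against the `rloo_policy_data` fields shown above and following the same conventions (illustrative only, not code from the PR):

import pytest
import torch
from ding.rl_utils.rloo import rloo_policy_data, rloo_policy_error


@pytest.mark.unittest
def test_rloo_policy_loss_with_mask(batch_size: int = 4, seq_length: int = 8, vocab_size: int = 1000):
    """Sketch: RLOO policy loss with an attention mask."""
    logit_new = torch.randn(batch_size, seq_length, vocab_size).requires_grad_(True)
    logit_old = logit_new + torch.randn_like(logit_new) * 0.1
    action = torch.randint(0, vocab_size, (batch_size, seq_length))
    reward = torch.randn(batch_size)
    weight = torch.ones(batch_size, seq_length)
    weight[:, -2:] = 0

    data = rloo_policy_data(logit_new=logit_new, logit_old=logit_old, action=action, reward=reward, weight=weight)
    loss, info = rloo_policy_error(data, clip_ratio=0.2)

    assert loss.policy_loss.shape == torch.Size([])  # scalar loss
    loss.policy_loss.backward()
    assert isinstance(logit_new.grad, torch.Tensor)
    assert 'mean_ratio' in info._asdict()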