(dcy) Polish style: Use selective log-softmax to reduce peak vram consumption
Berit-chengyi committed Feb 20, 2025
1 parent 17a7a71 commit 5358b8d
Showing 4 changed files with 258 additions and 103 deletions.
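
The core change swaps a full log-softmax over the vocabulary for a selective computation: gather the logit of the chosen token first, then subtract a per-row logsumexp, so no extra [B, L, V] log-probability tensor is materialized. A minimal sketch of the identity the new `efficient_method` relies on (shapes here are illustrative, not the ones used in DI-engine):

```python
import torch

B, L, V = 4, 8, 32000
logits = torch.randn(B, L, V)              # [B, L, V]
actions = torch.randint(0, V, (B, L))      # [B, L]

# Full log-softmax: allocates another [B, L, V] tensor before the gather.
full = torch.log_softmax(logits, dim=-1).gather(-1, actions.unsqueeze(-1)).squeeze(-1)

# Selective log-softmax: log_softmax(x)[i] = x[i] - logsumexp(x),
# so only [B, L]-sized intermediates are created beyond the input logits.
selective = torch.gather(logits, -1, actions.unsqueeze(-1)).squeeze(-1) - torch.logsumexp(logits, dim=-1)

assert torch.allclose(full, selective, atol=1e-4)
```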
109 changes: 69 additions & 40 deletions ding/rl_utils/grpo.py
@@ -3,41 +3,75 @@
import torch

grpo_policy_data = namedtuple('grpo_policy_data', ['logit_new', 'logit_old', 'logit_ref', 'action', 'adv', 'weight'])
MetricInfo = namedtuple('MetricInfo', ['mean_kl', 'mean_ratio', 'mean_clipped'])


def naive_method(logits, index):
# Calculate log probabilities for each token
log_prob_new = torch.log_softmax(logits, dim=-1)
# Get log probabilities for selected actions
index = index.unsqueeze(-1) # [B, L, 1]
per_token_logps = torch.gather(log_prob_new, -1, index).squeeze(-1)
return per_token_logps


def efficient_method(logits, index):
if logits.dtype in [torch.float32, torch.float64]:
selected_logits = torch.gather(logits, dim=-1, index=index.unsqueeze(-1)).squeeze(-1)
# loop to reduce peak mem consumption
logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
else:
# the logsumexp approach is numerically unstable with bfloat16, so fall back to a slightly less efficient approach
per_token_logps = []
for row_logits, row_labels in zip(logits, index): # loop to reduce peak mem consumption
row_logps = torch.log_softmax(row_logits, dim=-1)
row_per_token_logps = row_logps.gather(dim=-1, index=row_labels.unsqueeze(-1)).squeeze(-1)
per_token_logps.append(row_per_token_logps)
per_token_logps = torch.stack(per_token_logps)
return per_token_logps


def less_efficient_method(logits, action):
dist = torch.distributions.categorical.Categorical(logits=logits)
logp = dist.log_prob(action)
return logp
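
The three helpers above compute the same per-token log probabilities and differ only in their memory/speed trade-offs. A quick sanity check on toy float32 inputs (assuming the functions are importable from `ding.rl_utils.grpo`, as the tests below do) might look like:

```python
import torch
from ding.rl_utils.grpo import naive_method, efficient_method, less_efficient_method

logits = torch.randn(2, 5, 100)            # [B, L, V]
index = torch.randint(0, 100, (2, 5))      # [B, L]

a = naive_method(logits, index)
b = efficient_method(logits, index)
c = less_efficient_method(logits, index)

# All three return per-token log probabilities of shape [B, L] and agree numerically.
assert a.shape == b.shape == c.shape == (2, 5)
assert torch.allclose(a, b, atol=1e-5) and torch.allclose(a, c, atol=1e-5)
```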


def grpo_policy_error(
data: namedtuple,
logpro_cal=efficient_method, # Method to calculate the log probabilities
clip_ratio: float = 0.2,
beta: float = 0.1, # Weight coefficient for KL divergence
beta: float = 0.1 # Weight coefficient for KL divergence
) -> Tuple[namedtuple, namedtuple]:
"""Calculate the policy loss for GRPO
Args:
data (grpo_policy_data): Data containing the following fields:
- logit_new: Current policy logits [B, L, V]
- logit_old: Old policy logits [B, L, V]
- logit_ref: Reference policy logits [B, L, V]
- action: Actions taken [B, L]
- adv: Advantage values [B]
- weight: Attention mask [B, L]
clip_ratio (float): PPO clipping ratio, default 0.2
beta (float): Weight coefficient for KL divergence, default 0.1
Returns:
Tuple[namedtuple, namedtuple]:
- First namedtuple contains policy_loss
- Second namedtuple contains additional metrics
"""
Overview:
Implementation of Group Relative Policy Optimization (GRPO), as introduced in DeepSeekMath (arXiv:2402.03300).
Arguments:
- data (:obj:`namedtuple`): the grpo input data with fields shown in ``grpo_policy_data``.
- clip_ratio (:obj:`float`): the ppo clip ratio for the constraint of policy update, defaults to 0.2.
- beta (:obj:`float`): weight coefficient for KL divergence regularization, defaults to 0.1.
Returns:
- loss (:obj:`torch.FloatTensor`): the grpo policy loss, a differentiable 0-dim tensor.
- grpo_info (:obj:`namedtuple`): the grpo optim information for monitoring, all of them are Python scalars.
Shapes:
- logit_new (:obj:`torch.FloatTensor`): :math:`(B, S, V)`, where B is batch size, S is sequence length,
and V is vocabulary size.
- logit_old (:obj:`torch.FloatTensor`): :math:`(B, S, V)`.
- logit_ref (:obj:`torch.FloatTensor`): :math:`(B, S, V)`.
- action (:obj:`torch.LongTensor`): :math:`(B, S)`.
- adv (:obj:`torch.FloatTensor`): :math:`(B, )`.
- weight (:obj:`torch.FloatTensor` or :obj:`None`): :math:`(B, S)`.
- policy_loss (:obj:`torch.FloatTensor`): :math:`()`, 0-dim tensor.
- mean_kl (:obj:`float`): mean KL divergence between current and reference policy.
- mean_ratio (:obj:`float`): mean probability ratio.
- mean_clipped (:obj:`float`): proportion of clipped probability ratios.
"""

# Calculate log probabilities for each token
log_prob_new = torch.log_softmax(data.logit_new, dim=-1)
log_prob_old = torch.log_softmax(data.logit_old, dim=-1)
log_prob_ref = torch.log_softmax(data.logit_ref, dim=-1)

# Get log probabilities for selected actions
action = data.action.unsqueeze(-1) # [B, L, 1]
per_token_logps = torch.gather(log_prob_new, -1, action).squeeze(-1)
per_token_old_logps = torch.gather(log_prob_old, -1, action).squeeze(-1)
per_token_ref_logps = torch.gather(log_prob_ref, -1, action).squeeze(-1)
# Calculate log probabilities for selected token
per_token_logps = logpro_cal(data.logit_new, data.action)
per_token_ref_logps = logpro_cal(data.logit_ref, data.action)
per_token_old_logps = logpro_cal(data.logit_old, data.action)

# Calculate KL divergence: exp(q-p) - (q-p) - 1,
# where p is current policy and q is reference policy
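
Written out, the per-token estimator described in the comment above (a non-negative estimator of KL(π_θ ‖ π_ref), with p the current-policy log probability and q the reference-policy log probability of the sampled token) is:

```latex
\widehat{\mathrm{KL}}_t \;=\; \exp(q_t - p_t) \;-\; (q_t - p_t) \;-\; 1,
\qquad p_t = \log \pi_\theta(a_t \mid s_t), \quad q_t = \log \pi_{\mathrm{ref}}(a_t \mid s_t).
```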
@@ -62,16 +96,11 @@ def grpo_policy_error(
loss = ((per_token_loss * weight).sum(dim=1) / weight.sum(dim=1)).mean()

# Calculate additional metrics
metrics = {
'mean_kl': ((per_token_kl * weight).sum(dim=1) / weight.sum(dim=1)).mean().item(),
'mean_ratio': ((ratio * weight).sum(dim=1) / weight.sum(dim=1)).mean().item(),
'mean_clipped': (
(ratio > (1 + clip_ratio)).float().mean().item() + (ratio < (1 - clip_ratio)).float().mean().item()
),
}

# Create return namedtuples
loss_info = namedtuple('LossInfo', ['policy_loss'])(policy_loss=loss)
metric_info = namedtuple('MetricInfo', list(metrics.keys()))(**metrics)

return loss_info, metric_info
metric_info = MetricInfo(
mean_kl=((per_token_kl * weight).sum(dim=1) / weight.sum(dim=1)).mean().item(),
mean_ratio=((ratio * weight).sum(dim=1) / weight.sum(dim=1)).mean().item(),
mean_clipped=(ratio > (1 + clip_ratio)).float().mean().item() + (ratio <
(1 - clip_ratio)).float().mean().item(),
)

return loss, metric_info
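
The hunk elided above computes the clipped surrogate; combined with the KL penalty and the masked averaging visible at the end of the function, the objective is presumably the standard GRPO per-token loss (r the probability ratio, A the advantage, ε = clip_ratio, β = beta, w the weight mask):

```latex
L_{b,t} \;=\; -\Big[\min\!\big(r_{b,t} A_b,\ \operatorname{clip}(r_{b,t},\, 1-\epsilon,\, 1+\epsilon)\, A_b\big) \;-\; \beta\, \widehat{\mathrm{KL}}_{b,t}\Big],
\qquad r_{b,t} = \exp\!\big(\log\pi_\theta(a_{b,t}) - \log\pi_{\theta_{\mathrm{old}}}(a_{b,t})\big),
\qquad
\mathcal{L} \;=\; \frac{1}{B}\sum_{b=1}^{B} \frac{\sum_t w_{b,t}\, L_{b,t}}{\sum_t w_{b,t}}.
```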
101 changes: 66 additions & 35 deletions ding/rl_utils/rloo.py
@@ -3,42 +3,77 @@
import torch

rloo_policy_data = namedtuple('rloo_policy_data', ['logit_new', 'logit_old', 'action', 'reward', 'weight'])
MetricInfo = namedtuple('MetricInfo', ['mean_ratio', 'mean_clipped', 'mean_advantage'])


def naive_method(logits, index):
# Calculate log probabilities for each token
log_prob_new = torch.log_softmax(logits, dim=-1)
# Get log probabilities for selected actions
index = index.unsqueeze(-1) # [B, L, 1]
per_token_logps = torch.gather(log_prob_new, -1, index).squeeze(-1)
return per_token_logps


def efficient_method(logits, index):
if logits.dtype in [torch.float32, torch.float64]:
selected_logits = torch.gather(logits, dim=-1, index=index.unsqueeze(-1)).squeeze(-1)
# loop to reduce peak mem consumption
logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
else:
# the logsumexp approach is numerically unstable with bfloat16, so fall back to a slightly less efficient approach
per_token_logps = []
for row_logits, row_labels in zip(logits, index): # loop to reduce peak mem consumption
row_logps = torch.log_softmax(row_logits, dim=-1)
row_per_token_logps = row_logps.gather(dim=-1, index=row_labels.unsqueeze(-1)).squeeze(-1)
per_token_logps.append(row_per_token_logps)
per_token_logps = torch.stack(per_token_logps)
return per_token_logps


def less_efficient_method(logits, action):
dist = torch.distributions.categorical.Categorical(logits=logits)
logp = dist.log_prob(action)
return logp


def rloo_policy_error(
data: namedtuple,
logpro_cal=efficient_method, # Method to calculate the log probabilities
clip_ratio: float = 0.2,
) -> Tuple[namedtuple, namedtuple]:
"""Calculate the policy loss for RLOO
Args:
data (rloo_policy_data): Data containing the following fields:
- logit_new: Current policy logits [B, L, V]
- logit_old: Old policy logits [B, L, V]
- action: Actions taken [B, L]
- reward: Advantage values [B]
- weight: Attention mask [B, L]
clip_ratio (float): PPO clipping ratio, default 0.2
Returns:
Tuple[namedtuple, namedtuple]:
- First namedtuple contains policy_loss
- Second namedtuple contains additional metrics
"""
"""
Overview:
Implementation of REINFORCE Leave-One-Out (RLOO) for RLHF.
Arguments:
- data (:obj:`namedtuple`): the rloo input data with fields shown in ``rloo_policy_data``.
- clip_ratio (:obj:`float`): the ppo clip ratio for the constraint of policy update, defaults to 0.2.
Returns:
- loss (:obj:`torch.FloatTensor`): the rloo policy loss, a differentiable 0-dim tensor.
- rloo_info (:obj:`namedtuple`): the rloo optim information for monitoring, all of them are Python scalars.
Shapes:
- logit_new (:obj:`torch.FloatTensor`): :math:`(B, S, V)`, where B is batch size, S is sequence length,
and V is vocabulary size.
- logit_old (:obj:`torch.FloatTensor`): :math:`(B, S, V)`.
- action (:obj:`torch.LongTensor`): :math:`(B, S)`.
- reward (:obj:`torch.FloatTensor`): :math:`(K, B)`, where K is the number of samples per prompt.
- weight (:obj:`torch.FloatTensor` or :obj:`None`): :math:`(B, S)`.
- policy_loss (:obj:`torch.FloatTensor`): :math:`()`, 0-dim tensor.
- mean_ratio (:obj:`float`): mean probability ratio.
- mean_clipped (:obj:`float`): proportion of clipped probability ratios.
- mean_advantage (:obj:`float`): mean advantage value.
"""

# Calculate advantage of each action
rloo_k = data.reward.size(0)
baseline = (data.reward.sum(0) - data.reward) / (rloo_k - 1)
adv = data.reward - baseline
adv = adv.flatten()

# Calculate log probabilities for each token
log_prob_new = torch.log_softmax(data.logit_new, dim=-1)
log_prob_old = torch.log_softmax(data.logit_old, dim=-1)

# Get log probabilities for selected actions
action = data.action.unsqueeze(-1) # [B, L, 1]
per_token_logps = torch.gather(log_prob_new, -1, action).squeeze(-1)
per_token_old_logps = torch.gather(log_prob_old, -1, action).squeeze(-1)
per_token_logps = logpro_cal(data.logit_new, data.action)
per_token_old_logps = logpro_cal(data.logit_old, data.action)

# Calculate policy ratio
ratio = torch.exp(per_token_logps - per_token_old_logps)
@@ -55,15 +90,11 @@ def rloo_policy_error(
loss = ((per_token_loss * weight).sum(dim=1) / weight.sum(dim=1)).mean()

# Calculate additional metrics
metrics = {
'mean_ratio': ((ratio * weight).sum(dim=1) / weight.sum(dim=1)).mean().item(),
'mean_clipped': (ratio > (1 + clip_ratio)).float().mean().item() + (ratio <
(1 - clip_ratio)).float().mean().item(),
'mean_advantage': advantages.mean().item(),
}

# Create return namedtuples
loss_info = namedtuple('LossInfo', ['policy_loss'])(policy_loss=loss)
metric_info = namedtuple('MetricInfo', list(metrics.keys()))(**metrics)

return loss_info, metric_info
metric_info = MetricInfo(
mean_ratio=((ratio * weight).sum(dim=1) / weight.sum(dim=1)).mean().item(),
mean_clipped=(ratio > (1 + clip_ratio)).float().mean().item() + (ratio <
(1 - clip_ratio)).float().mean().item(),
mean_advantage=advantages.mean().item(),
)

return loss, metric_info
79 changes: 62 additions & 17 deletions ding/rl_utils/tests/test_grpo_rlhf.py
@@ -1,8 +1,9 @@
import pytest
import numpy as np
import torch
# Import GRPO related functions
from ding.rl_utils.grpo import grpo_policy_data, grpo_policy_error
from ding.rl_utils.grpo import (
grpo_policy_data, grpo_policy_error, naive_method, efficient_method, less_efficient_method
)


@pytest.fixture
@@ -43,21 +44,17 @@ def test_grpo_policy_loss_with_mask(batch_size: int = 4, seq_length: int = 8, vo
)

# 3. Calculate GRPO loss
loss, info = grpo_policy_error(
data=data,
clip_ratio=0.2, # PPO clipping ratio
beta=0.1 # KL divergence weight
)
loss, info = grpo_policy_error(data=data)

# 4. Verify outputs
assert isinstance(loss.policy_loss, torch.Tensor)
assert loss.policy_loss.shape == torch.Size([]) # Ensure scalar output
assert not torch.isnan(loss.policy_loss)
assert not torch.isinf(loss.policy_loss)
assert isinstance(loss, torch.Tensor)
assert loss.shape == torch.Size([]) # Ensure scalar output
assert not torch.isnan(loss)
assert not torch.isinf(loss)

# 5. Test gradients
assert logit_new.grad is None
loss.policy_loss.backward()
loss.backward()
assert isinstance(logit_new.grad, torch.Tensor)

# 6. Verify metrics
@@ -95,18 +92,66 @@ def test_grpo_policy_loss_without_mask(batch_size: int = 4, seq_length: int = 8,
)

# 4. Verify outputs
assert isinstance(loss.policy_loss, torch.Tensor)
assert loss.policy_loss.shape == torch.Size([]) # Ensure scalar output
assert not torch.isnan(loss.policy_loss)
assert not torch.isinf(loss.policy_loss)
assert isinstance(loss, torch.Tensor)
assert loss.shape == torch.Size([]) # Ensure scalar output
assert not torch.isnan(loss)
assert not torch.isinf(loss)

# 5. Test gradients
assert logit_new.grad is None
loss.policy_loss.backward()
loss.backward()
assert isinstance(logit_new.grad, torch.Tensor)

# 6. Verify metrics
assert 'mean_kl' in info._asdict()
assert 'mean_ratio' in info._asdict()
assert 'mean_clipped' in info._asdict()
assert all([np.isscalar(v) for v in info._asdict().values()])


@pytest.mark.benchmark
def test_log_prob_methods_benchmark():
"""Benchmark different methods for calculating log probabilities"""
# Set up benchmark parameters
vocab_size = 32768
seq_len = 1024
batch_size = 16
device = "cuda" if torch.cuda.is_available() else "cpu"

# Generate test data
logits = torch.randn(batch_size, seq_len, vocab_size, device=device, dtype=torch.float32)
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)

# Warm up the GPU (the timing code below assumes a CUDA device)
for _ in range(3):
_ = naive_method(logits[:2], input_ids[:2])
torch.cuda.synchronize()

# Benchmark each method
results = {}
for method, name in [(naive_method, "Naive"), (efficient_method, "Efficient"),
(less_efficient_method, "Less_Efficient")]:
# Run multiple times and record timings
times = []
for _ in range(10):
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
torch.cuda.synchronize()
start.record()
result = method(logits, input_ids)
end.record()
torch.cuda.synchronize()
times.append(start.elapsed_time(end))
if len(times) == 1:
results[name] = result

# Compute timing statistics
mean_time = np.mean(times)
std_time = np.std(times)
print(f"\n{name}: {mean_time:.2f} ± {std_time:.2f} ms")

# Verify that the results match
for name, result in results.items():
if name != "Naive":
diff = (results["Naive"] - result).abs().max().item()
assert diff < 1e-5, f"Results mismatch between Naive and {name}: {diff}"
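
The benchmark above only measures wall-clock time, while the commit's stated goal is lower peak VRAM. A minimal sketch for comparing peak allocations (assuming a CUDA device and reusing the same toy shapes; `peak_memory_mb` is a hypothetical helper, not part of the repo) could be:

```python
import torch
from ding.rl_utils.grpo import naive_method, efficient_method

def peak_memory_mb(fn, logits, index):
    """Peak CUDA memory (in MB) allocated while running fn(logits, index)."""
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()
    fn(logits, index)
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated() / 1024 ** 2

if torch.cuda.is_available():
    logits = torch.randn(16, 1024, 32768, device="cuda")
    index = torch.randint(0, 32768, (16, 1024), device="cuda")
    for fn in (naive_method, efficient_method):
        print(f"{fn.__name__}: {peak_memory_mb(fn, logits, index):.1f} MB")
```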