feature(nyz&dcy): add LLM/VLM RLHF loss (PPO/GRPO/RLOO) #857

Open
wants to merge 7 commits into main
Conversation

@PaParaZz1 (Member) commented Feb 13, 2025

Description

  • polish and test original PPO
  • implement GRPO and RLOO (see the loss sketch below)
  • optimize efficiency in real LLM cases
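
As a rough reference for the second item, here is a minimal sketch of a GRPO-style policy loss: a clipped PPO surrogate on per-token log-prob ratios plus a beta-weighted KL penalty to a frozen reference policy. The function name, the k3-style KL estimator, and the masked aggregation are assumptions for illustration, not this PR's exact implementation; the field shapes follow the grpo_policy_data docstring discussed in the review below.

import torch
import torch.nn.functional as F


def grpo_policy_loss_sketch(logit_new, logit_old, logit_ref, action, adv,
                            weight=None, clip_ratio=0.2, beta=0.1):
    # Shapes: logit_*: [B, S, N]; action: [B, S]; adv: [B, ]; weight: [B, S] mask.
    def token_logp(logit):
        # Log-probability of the taken action at every token position.
        logp = F.log_softmax(logit, dim=-1)
        return torch.gather(logp, -1, action.unsqueeze(-1)).squeeze(-1)  # [B, S]

    logp_new, logp_old, logp_ref = map(token_logp, (logit_new, logit_old, logit_ref))
    ratio = torch.exp(logp_new - logp_old)  # per-token importance ratio [B, S]
    adv = adv.unsqueeze(-1)  # broadcast the per-sample advantage over tokens
    surrogate = torch.min(
        ratio * adv,
        torch.clamp(ratio, 1 - clip_ratio, 1 + clip_ratio) * adv,
    )
    # k3-style estimator of KL(pi_new || pi_ref), non-negative per token.
    kl = torch.exp(logp_ref - logp_new) - (logp_ref - logp_new) - 1
    per_token_loss = -(surrogate - beta * kl)
    if weight is None:
        weight = torch.ones_like(per_token_loss)
    return (per_token_loss * weight).sum() / weight.sum()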

Related Issue

TODO

Check List

  • merge the latest version of the source branch/repo, and resolve all the conflicts
  • pass style check
  • pass all the tests

@PaParaZz1 added the enhancement (New feature or request) label Feb 13, 2025
@PaParaZz1 changed the title from feature(nyz): add LLM/VLM RLHF loss (PPO/GRPO/RLOO) to feature(nyz&dcy): add LLM/VLM RLHF loss (PPO/GRPO/RLOO) Feb 13, 2025

codecov bot commented Feb 13, 2025

Codecov Report

Attention: Patch coverage is 98.22222% with 4 lines in your changes missing coverage. Please review.

Project coverage is 75.52%. Comparing base (64efcb3) to head (17a7a71).
Report is 1 commit behind head on main.

Files with missing lines               Patch %   Lines
ding/rl_utils/tests/test_grpo_rlhf.py  94.54%    3 Missing ⚠️
ding/rl_utils/ppo.py                   85.71%    1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #857      +/-   ##
==========================================
+ Coverage   75.44%   75.52%   +0.07%     
==========================================
  Files         689      698       +9     
  Lines       56360    56679     +319     
==========================================
+ Hits        42523    42807     +284     
- Misses      13837    13872      +35     
Flag        Coverage Δ
unittests   75.52% <98.22%> (+0.07%) ⬆️

Flags with carried forward coverage won't be shown.


@PaParaZz1 added the algo (Add new algorithm or improve old one) label Feb 13, 2025
- Add test_grpo_rlhf.py for GRPO unit tests
- Add test_rloo_rlhf.py for RLOO unit tests
- Update GRPO implementation
- Update RLOO implementation
) -> Tuple[namedtuple, namedtuple]:
    """Calculate the policy loss for GRPO
    Args:
        data (grpo_policy_data): Data containing the following fields:

PaParaZz1 (Member Author):
Polish the comment formats.

    clip_ratio: float = 0.2,
    beta: float = 0.1,  # Weight coefficient for KL divergence
) -> Tuple[namedtuple, namedtuple]:
    """Calculate the policy loss for GRPO

PaParaZz1 (Member Author):
Add the paper link.

    }

    # Create return namedtuples
    loss_info = namedtuple('LossInfo', ['policy_loss'])(policy_loss=loss)

PaParaZz1 (Member Author):
If there is just one field, you can directly return it rather than using a namedtuple.
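
A small illustration of this suggestion; the function name is made up and the loss value is a placeholder, only the return convention matters:

from collections import namedtuple

import torch


def loss_with_plain_return(metrics: dict):
    # Illustrative only: with a single loss field, return the tensor itself
    # instead of wrapping it in a one-field LossInfo namedtuple.
    loss = torch.tensor(0.0)  # placeholder for the computed GRPO policy loss
    metric_info = namedtuple('MetricInfo', list(metrics.keys()))(**metrics)
    return loss, metric_info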


    # Create return namedtuples
    loss_info = namedtuple('LossInfo', ['policy_loss'])(policy_loss=loss)
    metric_info = namedtuple('MetricInfo', list(metrics.keys()))(**metrics)

PaParaZz1 (Member Author):
You can define the namedtuple at the beginning of this file.
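
A sketch of this suggestion, assuming the metric fields are known up front (the field names here are illustrative, not the PR's actual metrics):

from collections import namedtuple

# Defined once at module level instead of rebuilding the namedtuple class
# on every call to the loss function.
GRPOMetricInfo = namedtuple('GRPOMetricInfo', ['mean_kl', 'mean_ratio', 'mean_clipped'])

# ...then, inside the loss function:
# return loss, GRPOMetricInfo(**metrics)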

        data (grpo_policy_data): Data containing the following fields:
            - logit_new: Current policy logits [B, L, V]
            - logit_old: Old policy logits [B, L, V]
            - logit_ref: Reference policy logits [B, L, V]

PaParaZz1 (Member Author):
Use B, S, N for the shape notation here.

            - logit_ref: Reference policy logits [B, L, V]
            - action: Actions taken [B, L]
            - adv: Advantage values [B]
            - weight: Attention mask [B, L]

PaParaZz1 (Member Author):
Use the extra Shapes part of the docstring.

            - adv: Advantage values [B]
            - weight: Attention mask [B, L]
        clip_ratio (float): PPO clipping ratio, default 0.2
        beta (float): Weight coefficient for KL divergence, default 0.1

PaParaZz1 (Member Author):
Add a period to the end of each sentence.

            - logit_old: Old policy logits [B, L, V]
            - logit_ref: Reference policy logits [B, L, V]
            - action: Actions taken [B, L]
            - adv: Advantage values [B]

PaParaZz1 (Member Author):
Use [B, ] for the adv shape.
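
Taken together, the docstring comments above (B, S, N shape symbols, a separate Shapes part, trailing periods, and [B, ] for adv) might yield something like the sketch below. It assumes the Overview/Arguments/Returns/Shapes docstring layout used elsewhere in ding/rl_utils; the exact wording and the function name are illustrative.

def grpo_policy_error_sketch(data, clip_ratio: float = 0.2, beta: float = 0.1):
    """
    Overview:
        Calculate the policy loss for GRPO.
    Arguments:
        - data (:obj:`namedtuple`): The input data, see the Shapes part below.
        - clip_ratio (:obj:`float`): The PPO clipping ratio, defaults to 0.2.
        - beta (:obj:`float`): The weight coefficient for KL divergence, defaults to 0.1.
    Returns:
        - loss (:obj:`torch.FloatTensor`): The GRPO policy loss.
        - metric_info (:obj:`namedtuple`): Auxiliary metrics for logging.
    Shapes:
        - logit_new (:obj:`torch.FloatTensor`): :math:`(B, S, N)`, where B is batch size, S is sequence length and N is vocabulary size.
        - logit_old (:obj:`torch.FloatTensor`): :math:`(B, S, N)`.
        - logit_ref (:obj:`torch.FloatTensor`): :math:`(B, S, N)`.
        - action (:obj:`torch.LongTensor`): :math:`(B, S)`.
        - adv (:obj:`torch.FloatTensor`): :math:`(B, )`.
        - weight (:obj:`torch.FloatTensor` or None): :math:`(B, S)`.
    """
    ...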
