feature(nyz&dcy): add LLM/VLM RLHF loss (PPO/GRPO/RLOO) #857
base: main
6965fd3 to 2e49437
Codecov Report
Attention: Patch coverage is
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #857      +/-   ##
==========================================
+ Coverage   75.44%   75.52%   +0.07%
==========================================
  Files         689      698       +9
  Lines       56360    56679     +319
==========================================
+ Hits        42523    42807     +284
- Misses      13837    13872      +35
d3f6f3f to 9cb6ca3
- Add test_grpo_rlhf.py for GRPO unit tests
- Add test_rloo_rlhf.py for RLOO unit tests
- Update GRPO implementation
- Update RLOO implementation
9cb6ca3 to 8d34eac
eba91a1 to 2cbd9fb
7bcd64d to 17a7a71
ding/rl_utils/grpo.py
Outdated
) -> Tuple[namedtuple, namedtuple]:
    """Calculate the policy loss for GRPO
    Args:
        data (grpo_policy_data): Data containing the following fields:
polish comment formats
ding/rl_utils/grpo.py
Outdated
    clip_ratio: float = 0.2,
    beta: float = 0.1,  # Weight coefficient for KL divergence
) -> Tuple[namedtuple, namedtuple]:
    """Calculate the policy loss for GRPO
add paper link
ding/rl_utils/rloo.py
Outdated
    }

    # Create return namedtuples
    loss_info = namedtuple('LossInfo', ['policy_loss'])(policy_loss=loss)
if there is just one field, you can directly return it rather than use namedtuple
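A minimal sketch of this suggestion, assuming a hypothetical helper named `rloo_policy_loss_sketch` (not the function in the PR) that takes plain Python floats: when the loss is the only field, the value can be returned as-is, with metrics kept in an ordinary dict.

```python
from typing import Dict, List, Tuple


def rloo_policy_loss_sketch(per_sample_loss: List[float]) -> Tuple[float, Dict[str, float]]:
    """Hypothetical stand-in for the loss function under review."""
    loss = sum(per_sample_loss) / len(per_sample_loss)
    metrics = {'max_loss': max(per_sample_loss), 'min_loss': min(per_sample_loss)}
    # Return the single loss value directly instead of a one-field namedtuple.
    return loss, metrics


loss, metrics = rloo_policy_loss_sketch([0.5, 1.5])
```

Callers then use `loss` directly rather than unpacking `loss_info.policy_loss`.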
ding/rl_utils/rloo.py
Outdated
    # Create return namedtuples
    loss_info = namedtuple('LossInfo', ['policy_loss'])(policy_loss=loss)
    metric_info = namedtuple('MetricInfo', list(metrics.keys()))(**metrics)
you can define namedtuple at the beginning of this file
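One way to read this comment is sketched below. The type names `rloo_loss` and `rloo_info`, the metric fields, and the helper `build_outputs` are illustrative assumptions, not names taken from the PR. The point is only that the namedtuple types are created once at import time instead of on every call.

```python
from collections import namedtuple

# Defined once at the top of the module; field names are illustrative assumptions.
rloo_loss = namedtuple('rloo_loss', ['policy_loss'])
rloo_info = namedtuple('rloo_info', ['approx_kl', 'clipfrac'])


def build_outputs(loss: float, approx_kl: float, clipfrac: float):
    # Reuse the module-level types instead of calling namedtuple() per invocation.
    return rloo_loss(policy_loss=loss), rloo_info(approx_kl=approx_kl, clipfrac=clipfrac)


loss_info, metric_info = build_outputs(0.25, 0.01, 0.1)
```

Fixing the field list up front also means the metric names no longer depend on whatever keys happen to be in a `metrics` dict at call time.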
ding/rl_utils/grpo.py
Outdated
        data (grpo_policy_data): Data containing the following fields:
            - logit_new: Current policy logits [B, L, V]
            - logit_old: Old policy logits [B, L, V]
            - logit_ref: Reference policy logits [B, L, V]
B, S, N
ding/rl_utils/grpo.py
Outdated
            - logit_ref: Reference policy logits [B, L, V]
            - action: Actions taken [B, L]
            - adv: Advantage values [B]
            - weight: Attention mask [B, L]
use the extra Shapes part for the tensor shapes
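A sketch of what the docstring could look like with the review comments on it applied: shapes moved to a separate Shapes part, `[B, S, N]` notation, `(B, )` for the advantage, and a period at the end of each sentence. The Overview/Arguments/Shapes layout, the reading of S as sequence length and N as vocabulary size, and the stub name `grpo_policy_error_stub` are assumptions made for illustration.

```python
def grpo_policy_error_stub(data, clip_ratio: float = 0.2, beta: float = 0.1):
    """
    Overview:
        Calculate the policy loss for GRPO.
    Arguments:
        - data (:obj:`namedtuple`): Data containing the fields listed under Shapes.
        - clip_ratio (:obj:`float`): PPO clipping ratio, default 0.2.
        - beta (:obj:`float`): Weight coefficient for KL divergence, default 0.1.
    Shapes:
        - logit_new: :math:`(B, S, N)`, current policy logits.
        - logit_old: :math:`(B, S, N)`, old policy logits.
        - logit_ref: :math:`(B, S, N)`, reference policy logits.
        - action: :math:`(B, S)`, actions taken.
        - adv: :math:`(B, )`, advantage values.
        - weight: :math:`(B, S)`, attention mask.
    """
    # Documentation-only stub; the real computation lives in the PR.
    raise NotImplementedError('Documentation-only stub.')
```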
ding/rl_utils/grpo.py
Outdated
            - adv: Advantage values [B]
            - weight: Attention mask [B, L]
        clip_ratio (float): PPO clipping ratio, default 0.2
        beta (float): Weight coefficient for KL divergence, default 0.1
add a period to the end of each sentence.
ding/rl_utils/grpo.py
Outdated
            - logit_old: Old policy logits [B, L, V]
            - logit_ref: Reference policy logits [B, L, V]
            - action: Actions taken [B, L]
            - adv: Advantage values [B]
[B, ]
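For readers following the docstring fields above, the sketch below shows how `clip_ratio`, `beta`, `adv` and the mask typically combine in a GRPO-style loss. It is not the PR's implementation. It uses plain Python instead of torch so that it runs anywhere, handles a single sequence, and assumes the logits have already been reduced to log-probabilities of the chosen tokens. The function name and argument names are hypothetical.

```python
import math
from typing import List


def grpo_token_loss_sketch(
        logp_new: List[float],
        logp_old: List[float],
        logp_ref: List[float],
        adv: float,
        mask: List[int],
        clip_ratio: float = 0.2,
        beta: float = 0.1,
) -> float:
    """Clipped surrogate plus KL penalty for one sequence (illustrative only)."""
    total, count = 0.0, 0
    for lp_new, lp_old, lp_ref, m in zip(logp_new, logp_old, logp_ref, mask):
        if not m:
            continue  # Skip padded positions.
        # Importance ratio between the current and the old policy.
        ratio = math.exp(lp_new - lp_old)
        clipped = min(max(ratio, 1.0 - clip_ratio), 1.0 + clip_ratio)
        surrogate = min(ratio * adv, clipped * adv)
        # Non-negative KL estimate against the reference policy.
        diff = lp_ref - lp_new
        kl = math.exp(diff) - diff - 1.0
        total += -(surrogate - beta * kl)
        count += 1
    # Average over unmasked tokens; an all-masked sequence yields 0.0.
    return total / max(count, 1)


same_policy_loss = grpo_token_loss_sketch([-1.0, -2.0], [-1.0, -2.0], [-1.0, -2.0], adv=0.5, mask=[1, 1])
```

When all three policies agree, the ratio is 1 and the KL term vanishes, so the loss reduces to the negated advantage.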
7a82a7b to 5358b8d
Description
Related Issue
TODO
Check List