
Commit 0a769fc

update
1 parent 40c7aa3 commit 0a769fc

File tree

1 file changed: +16 -20 lines changed


openrlhf/trainer/ppo_trainer.py

+16 -20
@@ -1,8 +1,8 @@
 import os
 import os.path
 from abc import ABC
-from typing import Any, Callable, Dict, List, Optional
 from collections import Counter
+from typing import Any, Callable, Dict, List, Optional
 
 import torch
 import torch.distributed as dist
@@ -261,9 +261,8 @@ def fit(
                 else:
                     status = {}
 
-
                 ## log acc change
-                accuracy_ = torch.cat([experience.info["accuracy_reward"] for experience in experiences])
+                accuracy_ = torch.cat([experience.info["accuracy_rewards"] for experience in experiences])
                 accuracy_ = accuracy_.reshape(-1, args.n_samples_per_prompt).to(device="cuda")
                 accuracy_ = torch.mean(accuracy_, dim=-1)
                 accuracy_counts = sorted(Counter(accuracy_.tolist()).items())
@@ -274,12 +273,13 @@ def fit(
                 status["easy_counts"] = easy_counts
                 status["mid_counts"] = mid_counts
                 print("=== Accuracy distribution ===:", " ".join(f"{k:.2f}:{v}" for k, v in accuracy_counts))
-
+
                 ## log the entropy for a group of responses
-                joint_action_log_probs_ = torch.cat([(experience.action_log_probs * experience.action_mask).sum(-1) for experience in experiences])
+                joint_action_log_probs_ = torch.cat(
+                    [(experience.action_log_probs * experience.action_mask).sum(-1) for experience in experiences]
+                )
                 status["entropy_per_prompt"] = -joint_action_log_probs_.mean().item()
 
-
                 status["accuracy_rewards_original"] = accuracy_rewards_original
 
                 if "kl" in status:
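Note: the entropy_per_prompt metric in this hunk is the negative mean of the joint (sequence-level) log-probability of each sampled response, i.e. a Monte Carlo estimate of the policy's sequence entropy. A small sketch with made-up per-token log-probs and masks:

import torch

# Hypothetical per-token log-probs and masks for two responses (padding masked out).
action_log_probs = torch.tensor([[-0.5, -1.0, -0.2], [-0.3, -0.7, -2.0]])
action_mask = torch.tensor([[1.0, 1.0, 0.0], [1.0, 1.0, 1.0]])

# Joint log-prob of each sampled response = sum of its masked token log-probs.
joint_log_probs = (action_log_probs * action_mask).sum(-1)  # tensor([-1.5, -3.0])

# Averaging -log pi(response) over the sampled responses gives a Monte Carlo
# estimate of the policy's sequence-level entropy, logged as entropy_per_prompt.
entropy_per_prompt = -joint_log_probs.mean().item()  # 2.25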
@@ -496,39 +496,35 @@ def training_step_actor(self, experience: Experience) -> Dict[str, float]:
         self.strategy.optimizer_step(self.actor_optim, self.actor, self.actor_scheduler, name="actor")
         if self.ema_model:
             self.strategy.moving_average(self.actor, self.ema_model, self.ema_beta, "cuda")
-
-
+
         ## compute the ratio and grad_norm to log
         with torch.no_grad():
             ratio = (action_log_probs - old_action_log_probs).exp().detach()
             ratio_mean = masked_mean(ratio, experience.action_mask, dim=-1).mean()
             eps = 0.2
             ratio_clip_upper = masked_mean((ratio > 1 + eps), experience.action_mask, dim=-1).mean()
             ratio_clip_lower = masked_mean((ratio < 1 - eps), experience.action_mask, dim=-1).mean()
-
-        grad_norm = nn.utils.clip_grad_norm_(
-            self.actor.parameters(),
-            max_norm=1e6,
-            norm_type=2
-        )
-
-        correct_response_length = (experience.info["response_length"] * experience.info["accuracy_reward"]).sum() / (experience.info["accuracy_reward"].sum()).clamp(min=1.0)
-        wrong_response_length = (experience.info["response_length"] * (1 - experience.info["accuracy_reward"])).sum() / ((1 - experience.info["accuracy_reward"]).sum()).clamp(min=1.0)
 
+        grad_norm = nn.utils.clip_grad_norm_(self.actor.parameters(), max_norm=1e6, norm_type=2)
+
+        correct_response_length = (
+            experience.info["response_length"] * experience.info["accuracy_rewards"]
+        ).sum() / (experience.info["accuracy_rewards"].sum()).clamp(min=1.0)
+        wrong_response_length = (
+            experience.info["response_length"] * (1 - experience.info["accuracy_rewards"])
+        ).sum() / ((1 - experience.info["accuracy_rewards"]).sum()).clamp(min=1.0)
 
         # status
         status = {"policy_loss": actor_loss.item(), "actor_lr": self.actor_scheduler.get_last_lr()[0]}
 
-
        status["ratio"] = ratio_mean.item()
        status["ratio_clip_upper"] = ratio_clip_upper.item()
        status["ratio_clip_lower"] = ratio_clip_lower.item()
        status["grad_norm"] = grad_norm.item()
-
+
        status["correct_response_length"] = correct_response_length.item()
        status["wrong_response_length"] = wrong_response_length.item()
 
-
        if self.pretrain_dataloader is not None:
            status["ptx_loss"] = ptx_loss.item()
        for k, v in experience.info.items():
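Note: as a reference for the logging added in this hunk, a self-contained sketch of the ratio / clip-fraction bookkeeping and the correct- vs. wrong-response length split. The local masked_mean below assumes the usual mask-weighted mean; the real helper is imported elsewhere in ppo_trainer.py and may differ in detail, and all tensors here are made up:

import torch


def masked_mean(tensor, mask, dim=-1):
    # Assumed behaviour: average over valid (mask == 1) positions only.
    return (tensor * mask).sum(dim) / mask.sum(dim).clamp(min=1.0)


# Hypothetical per-token log-probs for two responses under the new and old policies.
action_log_probs = torch.tensor([[-0.4, -1.1, -0.3], [-0.8, -0.5, -1.2]])
old_action_log_probs = torch.tensor([[-0.5, -1.0, -0.2], [-0.7, -0.6, -1.0]])
action_mask = torch.tensor([[1.0, 1.0, 1.0], [1.0, 1.0, 0.0]])

with torch.no_grad():
    ratio = (action_log_probs - old_action_log_probs).exp()
    ratio_mean = masked_mean(ratio, action_mask, dim=-1).mean()
    eps = 0.2
    # Fraction of valid tokens whose importance ratio falls outside the PPO clip range.
    ratio_clip_upper = masked_mean((ratio > 1 + eps).float(), action_mask, dim=-1).mean()
    ratio_clip_lower = masked_mean((ratio < 1 - eps).float(), action_mask, dim=-1).mean()

# Mean length of correct vs. wrong responses; the clamp keeps the division safe
# when a batch contains no correct (or no wrong) responses.
response_length = torch.tensor([120.0, 80.0])
accuracy_rewards = torch.tensor([1.0, 0.0])  # 1.0 = response judged correct
correct_response_length = (response_length * accuracy_rewards).sum() / accuracy_rewards.sum().clamp(min=1.0)
wrong_response_length = (response_length * (1 - accuracy_rewards)).sum() / (1 - accuracy_rewards).sum().clamp(min=1.0)

print(ratio_mean.item(), ratio_clip_upper.item(), ratio_clip_lower.item())
print(correct_response_length.item(), wrong_response_length.item())  # 120.0 80.0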
