@@ -1,8 +1,8 @@
 import os
 import os.path
 from abc import ABC
-from typing import Any, Callable, Dict, List, Optional
 from collections import Counter
+from typing import Any, Callable, Dict, List, Optional

 import torch
 import torch.distributed as dist
@@ -261,9 +261,8 @@ def fit(
                 else:
                     status = {}

-
                 ## log acc change
-                accuracy_ = torch.cat([experience.info["accuracy_reward"] for experience in experiences])
+                accuracy_ = torch.cat([experience.info["accuracy_rewards"] for experience in experiences])
                 accuracy_ = accuracy_.reshape(-1, args.n_samples_per_prompt).to(device="cuda")
                 accuracy_ = torch.mean(accuracy_, dim=-1)
                 accuracy_counts = sorted(Counter(accuracy_.tolist()).items())
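For readers of this hunk, the sketch below (not part of the diff) illustrates how the per-prompt accuracy distribution is tallied: rewards are grouped by prompt, averaged over n_samples_per_prompt rollouts, and counted with Counter. The tensor values and the group size of 4 are made-up assumptions, and the snippet runs on CPU rather than CUDA.

# Illustrative sketch only: bucketing per-prompt mean accuracy.
from collections import Counter

import torch

n_samples_per_prompt = 4  # hypothetical group size
# Flat per-response accuracy rewards, as produced by concatenating experience.info["accuracy_rewards"]
accuracy = torch.tensor([1, 0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1], dtype=torch.float32)

# One row per prompt, one column per sampled response, then the per-prompt mean accuracy.
per_prompt_acc = accuracy.reshape(-1, n_samples_per_prompt).mean(dim=-1)

# Sorted histogram of group accuracies, e.g. [(0.0, 1), (0.5, 1), (0.75, 1), (1.0, 1)]
accuracy_counts = sorted(Counter(per_prompt_acc.tolist()).items())
print("=== Accuracy distribution ===:", " ".join(f"{k:.2f}:{v}" for k, v in accuracy_counts))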
@@ -274,12 +273,13 @@ def fit(
                 status["easy_counts"] = easy_counts
                 status["mid_counts"] = mid_counts
                 print("=== Accuracy distribution ===:", " ".join(f"{k:.2f}:{v}" for k, v in accuracy_counts))
-
+
                 ## log the entropy for a group of responses
-                joint_action_log_probs_ = torch.cat([(experience.action_log_probs * experience.action_mask).sum(-1) for experience in experiences])
+                joint_action_log_probs_ = torch.cat(
+                    [(experience.action_log_probs * experience.action_mask).sum(-1) for experience in experiences]
+                )
                 status["entropy_per_prompt"] = -joint_action_log_probs_.mean().item()

-
                 status["accuracy_rewards_original"] = accuracy_rewards_original

                 if "kl" in status:
@@ -496,39 +496,35 @@ def training_step_actor(self, experience: Experience) -> Dict[str, float]:
         self.strategy.optimizer_step(self.actor_optim, self.actor, self.actor_scheduler, name="actor")
         if self.ema_model:
             self.strategy.moving_average(self.actor, self.ema_model, self.ema_beta, "cuda")
-
-
+
         ## compute the ratio and grad_norm to log
         with torch.no_grad():
             ratio = (action_log_probs - old_action_log_probs).exp().detach()
             ratio_mean = masked_mean(ratio, experience.action_mask, dim=-1).mean()
             eps = 0.2
             ratio_clip_upper = masked_mean((ratio > 1 + eps), experience.action_mask, dim=-1).mean()
             ratio_clip_lower = masked_mean((ratio < 1 - eps), experience.action_mask, dim=-1).mean()
-
-            grad_norm = nn.utils.clip_grad_norm_(
-                self.actor.parameters(),
-                max_norm=1e6,
-                norm_type=2
-            )
-
-            correct_response_length = (experience.info["response_length"] * experience.info["accuracy_reward"]).sum() / (experience.info["accuracy_reward"].sum()).clamp(min=1.0)
-            wrong_response_length = (experience.info["response_length"] * (1 - experience.info["accuracy_reward"])).sum() / ((1 - experience.info["accuracy_reward"]).sum()).clamp(min=1.0)

+            grad_norm = nn.utils.clip_grad_norm_(self.actor.parameters(), max_norm=1e6, norm_type=2)
+
+            correct_response_length = (
+                experience.info["response_length"] * experience.info["accuracy_rewards"]
+            ).sum() / (experience.info["accuracy_rewards"].sum()).clamp(min=1.0)
+            wrong_response_length = (
+                experience.info["response_length"] * (1 - experience.info["accuracy_rewards"])
+            ).sum() / ((1 - experience.info["accuracy_rewards"]).sum()).clamp(min=1.0)

         # status
         status = {"policy_loss": actor_loss.item(), "actor_lr": self.actor_scheduler.get_last_lr()[0]}

-
         status["ratio"] = ratio_mean.item()
         status["ratio_clip_upper"] = ratio_clip_upper.item()
         status["ratio_clip_lower"] = ratio_clip_lower.item()
         status["grad_norm"] = grad_norm.item()
-
+
         status["correct_response_length"] = correct_response_length.item()
         status["wrong_response_length"] = wrong_response_length.item()

-
         if self.pretrain_dataloader is not None:
             status["ptx_loss"] = ptx_loss.item()
         for k, v in experience.info.items():
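As a reference for the logging added in this hunk, here is a self-contained sketch (not part of the diff) of the ratio-clip fractions and the grad-norm measurement. masked_mean is a local stand-in for the trainer's helper, and the model, tensors, and eps value are illustrative assumptions. Note that with max_norm=1e6, clip_grad_norm_ effectively never clips and simply returns the total gradient norm for logging.

# Illustrative sketch only: PPO ratio statistics and gradient-norm logging.
import torch
import torch.nn as nn


def masked_mean(x, mask, dim=-1):
    # Mean over valid (mask == 1) positions along `dim`; stand-in for the trainer's helper.
    return (x * mask).sum(dim) / mask.sum(dim).clamp(min=1.0)


action_log_probs = torch.tensor([[-0.9, -1.1, -0.4]])
old_action_log_probs = torch.tensor([[-1.0, -1.0, -1.0]])
action_mask = torch.ones_like(action_log_probs)

ratio = (action_log_probs - old_action_log_probs).exp()
eps = 0.2
ratio_clip_upper = masked_mean((ratio > 1 + eps).float(), action_mask).mean()  # fraction above 1+eps
ratio_clip_lower = masked_mean((ratio < 1 - eps).float(), action_mask).mean()  # fraction below 1-eps

# With a huge max_norm, clip_grad_norm_ does not actually clip; it just returns the total
# gradient norm, which is what gets stored in status["grad_norm"].
model = nn.Linear(4, 1)
model(torch.randn(2, 4)).sum().backward()
grad_norm = nn.utils.clip_grad_norm_(model.parameters(), max_norm=1e6, norm_type=2)
print(ratio_clip_upper.item(), ratio_clip_lower.item(), grad_norm.item())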