Description
Lines 598 to 613 in 133434b
```python
rl_global_batch = args.rl_global_batch
if args.filter_trajectory:
    _world_size = actor_dp_mesh.size()
    _data_size = len(trajectory_dataset)
    # train_global_batch is divisible by world_size
    rl_global_batch = _data_size // _world_size * _world_size
rl_loader = DataLoader(
    trajectory_dataset,
    batch_size=args.rl_mirco_batch,
    num_workers=0,
    collate_fn=TrajectoryCollator(pack_batch=True),
    shuffle=False,
    sampler=RLParallelSampler(trajectory_dataset, actor_dp_mesh, rl_global_batch, shuffle=False),
    persistent_workers=False,
)
```
When training large models (especially 32B models) with a large data-parallel world size, `rl_global_batch` can end up as zero: if the filtered trajectory set is smaller than `_world_size`, then `_data_size // _world_size * _world_size` evaluates to 0 (for example, 48 trajectories on 64 ranks gives `48 // 64 * 64 = 0`), which later triggers a ZeroDivisionError. Is there a reasonable way to fix this?
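One possible mitigation, sketched below with hypothetical names (`compute_rl_global_batch` is not part of the codebase, and the skip-versus-pad policy is an assumption), is to guard against the degenerate case before building the sampler: either skip the RL update when there are fewer filtered trajectories than data-parallel ranks, or fall back so that every rank still gets at least one sample.

```python
def compute_rl_global_batch(data_size: int, world_size: int,
                            default_global_batch: int,
                            filter_trajectory: bool) -> int:
    """Return a global batch divisible by the DP world size, never zero.

    Hypothetical helper; argument names mirror the snippet above.
    """
    if not filter_trajectory:
        return default_global_batch
    # Round down to a multiple of the DP world size so every rank
    # receives the same number of samples.
    global_batch = data_size // world_size * world_size
    if global_batch == 0:
        # Fewer filtered trajectories than DP ranks (e.g. 48 samples
        # on 64 ranks). Signal the caller to skip this RL step rather
        # than hitting a zero batch size downstream.
        raise ValueError(
            f"Only {data_size} trajectories for {world_size} DP ranks; "
            "skip this RL step or relax trajectory filtering."
        )
    return global_batch
```

An alternative would be to pad or repeat trajectories up to `_world_size` so that `rl_global_batch` never drops below one sample per rank, at the cost of training on duplicated data.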