[rl] refactor grader and trainer generator actor #2244
base: gh/wwwjn/7/base
Conversation
@dataclass
class TrajectoryData:
I thought we deprecated the name "trajectory" since it is intrinsically ambiguous, but I don't remember what we replaced it with. Episode?
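If we do rename it, a minimal sketch of what the renamed container could look like (only `rewards` is visible in this diff; the other field names are my guesses, not from this PR):

from dataclasses import dataclass

import torch


@dataclass
class Episode:
    # Hypothetical fields; only `rewards` appears in this diff.
    prompt_ids: torch.Tensor    # tokenized prompt
    response_ids: torch.Tensor  # tokens produced by the Generator
    rewards: torch.Tensor       # filled in by the Scorer/Grader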
    rewards: torch.Tensor

class Scorer(Actor):
I thought you chose to use Grader. Not sure what the difference is, but let's keep the naming aligned.
def _load_initial_weights(self, model: torch.nn.Module, model_path: str) -> None:
    """Load initial weights from HuggingFace checkpoint."""
    from torchtitan.experiments.rl.vllm_compat.weights.converter import (
        vllm_to_torchtitan,
Why use this function instead of our utils like from_hf?
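For context, a checkpoint converter like this typically just remaps parameter names between the source and target conventions; a minimal sketch of that pattern (not the actual vllm_to_torchtitan or from_hf code):

import torch

def convert_state_dict(
    src_sd: dict[str, torch.Tensor], name_map: dict[str, str]
) -> dict[str, torch.Tensor]:
    """Rename source checkpoint keys to the target model's naming scheme."""
    return {name_map.get(key, key): tensor for key, tensor in src_sd.items()}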
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
# vLLM attention expects bfloat16 / inputs
I don't think this should happen only for attention.
In torchtitan the default dtype is fp32, and mixed precision is handled by FSDP, so under pure TP the forward dtype is fp32.
If vLLM defaults to bf16 end to end, we should match it; otherwise this is another place where the torchtitan-native vLLM forward would be slow.
Yes, the dtype difference could be a reason we are 40% slower when TP is not enabled.
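To make the suggestion concrete, a minimal sketch of running the whole forward in bf16 to match vLLM, instead of casting only inside attention (illustrative only; not torchtitan's actual config surface):

import torch

def forward_in_bf16(model: torch.nn.Module, input_ids: torch.Tensor) -> torch.Tensor:
    # Option 1: keep fp32 master weights and let autocast run compute in bf16.
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        return model(input_ids)

def cast_model_to_bf16(model: torch.nn.Module) -> torch.nn.Module:
    # Option 2: hold the generator-side copy of the model entirely in bf16,
    # matching vLLM's default, so no per-op casts are needed.
    return model.to(dtype=torch.bfloat16)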
This demonstrates:
- 1. Distributed actor architecture with Generator (vLLM) and Trainer (TorchTitan) components
+ 1. Distributed actor architecture with Generator (vLLM), Scorer, and Trainer (TorchTitan) components
2. File based weight synchronization between trainer and generator
is this still true?
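For reference, "file based weight synchronization" generally means something like the pattern below; this is a sketch of that pattern under assumed paths, not this PR's implementation:

import os
import torch

WEIGHTS_DIR = "/tmp/rl_weight_sync"  # assumed shared filesystem path

def trainer_publish(model: torch.nn.Module, step: int) -> None:
    path = os.path.join(WEIGHTS_DIR, f"policy_step_{step}.pt")
    torch.save(model.state_dict(), path + ".tmp")
    os.replace(path + ".tmp", path)  # atomic rename: generator never reads a partial file

def generator_load(model: torch.nn.Module, step: int) -> None:
    path = os.path.join(WEIGHTS_DIR, f"policy_step_{step}.pt")
    model.load_state_dict(torch.load(path, map_location="cpu"))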
    job_config,  # Pass full job_config
)

# Spawn scorer on trainer mesh (can share resources with trainer)
I would like to learn more about how the Scorer/Grader works with the trainer/generator.
Naively I would think it should be placed on the generator mesh, not trainer_mesh, although the two meshes may be the same and you are only using gpus=0 right now.
https://github.com/meta-pytorch/monarch/blob/main/docs/source/examples/grpo_actor.py#L505
I followed the practice there: the scorer is spawned on the trainer mesh. My intuition is that the main bottleneck is the generator (generation takes longer), so we want to put more work (e.g., computing rewards and advantages) on the trainer side instead.
If we only think about the algorithm, the scorer could live on either the trainer or the generator. If we put it on the generator, the generated "episode" arrives already scored; if we put it on the trainer, the generator can just pass an "unscored" episode to the trainer.
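To make the placement argument concrete, a plain-Python sketch of the intended data flow (object names and method signatures are illustrative, not the Monarch API or this PR's code):

def rl_step(generator, scorer, trainer, prompts):
    # Generator (vLLM) produces unscored episodes; it stays busy with sampling.
    episodes = generator.generate(prompts)
    # Scorer/Grader runs on the trainer mesh: rewards + advantages are computed
    # on the less-loaded side, so the generator is never blocked on scoring.
    scored = scorer.score(episodes)
    # Trainer (TorchTitan) updates the policy on the scored batch.
    trainer.step(scored)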
Stack from ghstack (oldest at bottom):