Gradient synchronization in data-parallel trainers #12
Hey, great job with nanodl!

I was just looking through the code and noticed that in the LaMDA Trainer the gradients are not being averaged across devices here:

nanodl/nanodl/__src/models/lamda.py
Lines 564 to 565 in 18c7f8e

Not sure if this is happening elsewhere, but usually, to keep the weights in sync, you apply a jax.lax.pmean over the gradients before passing them to apply_gradients, e.g.:
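Something along these lines; this is a minimal sketch where train_step, loss_fn, the batch layout, and the axis name are illustrative rather than nanodl's actual code:

```py
import functools

import jax
import optax

# Illustrative pmap-based train step: `state` is a Flax TrainState replicated
# across devices, and `batch` holds per-device shards of the inputs/labels.
@functools.partial(jax.pmap, axis_name="devices")
def train_step(state, batch):
    def loss_fn(params):
        logits = state.apply_fn({"params": params}, batch["inputs"])
        return optax.softmax_cross_entropy_with_integer_labels(
            logits, batch["labels"]
        ).mean()

    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    # Average the gradients (and the reported loss) across all devices so
    # every replica applies the same update and the weights stay in sync.
    grads = jax.lax.pmean(grads, axis_name="devices")
    loss = jax.lax.pmean(loss, axis_name="devices")
    return state.apply_gradients(grads=grads), loss
```

Without the pmean, each replica steps with its own local gradients, so the parameter copies drift apart after the first update.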
cgarciae changed the title from "Gradient synchronization in data-parallel traininers" to "Gradient synchronization in data-parallel trainers" on Feb 28, 2024.

Thanks for noticing this! It's often challenging to test these portions because I don't have a personal multi-GPU setup for development. However, I will have access to 2 GPUs around 10th March and will examine this immediately. You are more than welcome to make corrections from your end if convenient; in fact, I would very much appreciate that.