Gradient synchronization in data-parallel trainers #12

Open
cgarciae opened this issue Feb 28, 2024 · 1 comment
cgarciae commented Feb 28, 2024

Hey, great job with nanodl!

I was just looking through the code and noticed that in Lambda's Trainer the gradients are not averaged across devices here:

loss, grads = jax.value_and_grad(loss_fn)(state.params)
state = state.apply_gradients(grads=grads)

Not sure if this is happening elsewhere, but to keep the weights in sync you usually apply jax.lax.pmean over the gradients before passing them to apply_gradients, e.g.

grads = jax.lax.pmean(grads, axis_name='devices')
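
For reference, here is a minimal sketch of how the fix would look inside a pmapped data-parallel train step. The axis name 'devices', the loss function, and the TrainState setup are illustrative assumptions rather than nanodl's actual API:

import jax
import optax
from flax.training import train_state

def train_step(state: train_state.TrainState, batch):
    # Hypothetical per-device loss; the trainer's real loss_fn would go here.
    def loss_fn(params):
        logits = state.apply_fn({'params': params}, batch['inputs'])
        return optax.softmax_cross_entropy_with_integer_labels(
            logits, batch['labels']
        ).mean()

    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    # Average the gradients (and the loss, for logging) across devices so
    # every replica applies the same update and the weights stay in sync.
    grads = jax.lax.pmean(grads, axis_name='devices')
    loss = jax.lax.pmean(loss, axis_name='devices')
    state = state.apply_gradients(grads=grads)
    return state, loss

# The axis_name passed to pmap must match the one used in pmean above.
p_train_step = jax.pmap(train_step, axis_name='devices')

Without the pmean, each replica updates its parameters using only its local gradients, so the per-device copies of the weights drift apart after the first step.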
cgarciae changed the title from "Gradient synchronization in data-parallel traininers" to "Gradient synchronization in data-parallel trainers" Feb 28, 2024
HMUNACHI (Owner) commented Mar 2, 2024

Thanks for noticing this! It's often challenging to test these portions because I don't have a personal multi-GPU setup for development. However, I will have access to 2 GPUs around 10th March and will examine this immediately. You are more than welcome to make the correction from your end if convenient; I would in fact very much appreciate that.
