Gradient Accumulation API aliasing #8697
Thank you for filing this issue. We might be able to help you more if you give us a bit more information:
I attempted to reproduce it with an MLP, as mentioned above, to no avail. I have only encountered this with Llama3, so it might be specific to that implementation.
The behavior should be device-agnostic AFAIK. The MLP example doesn't reproduce on CPU/GPU/Neuron, but Llama3 does reproduce on Neuron (I haven't tried other devices, for unrelated incompatibility reasons). Another interesting question when dealing with aliasing is whether it reproduces with functionalization or not (e.g. whether it requires RCA'ing aten's propagation path), but it seems to reproduce with both. I have assigned the issue to myself, so I'll be sharing insights as I RCA it. I have opened it already in case we see similar issues elsewhere.
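For reference, a minimal sketch of how the two functionalization modes can be compared via the `XLA_DISABLE_FUNCTIONALIZATION` environment variable; the model and training step below are placeholders, not the actual repro:

```python
import os

# Functionalization is on by default; set this to "1" (before importing
# torch_xla) to rerun the same repro with functionalization disabled.
os.environ.setdefault("XLA_DISABLE_FUNCTIONALIZATION", "0")

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
model = torch.nn.Linear(16, 16).to(device)   # placeholder model

out = model(torch.randn(4, 16, device=device)).sum()
out.backward()
xm.mark_step()  # sync the traced graph; aliasing decisions happen here
```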
Tracing the LM head grad on step 1 with TP=32 and functionalization enabled, with gradients preemptively zeroed:
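A sketch of how such a trace can be pulled for a single gradient; `lm_head`, `model`, and `batch` are placeholder names, and `torch_xla._XLAC._get_xla_tensors_text` dumps the pending lazy IR feeding the given tensors:

```python
import torch
import torch_xla

param = model.lm_head.weight              # placeholder: the LM head weight
param.grad = torch.zeros_like(param)      # preemptively zero the grad

loss = model(batch).sum()                 # placeholder forward + loss
loss.backward()

# Dump the pending IR for this gradient before the graph is synced.
print(torch_xla._XLAC._get_xla_tensors_text([param.grad]))
```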
When tracing the XLA data propagation of the aliasing, I see that when transforming an in-place op, we have the following identifiers:
When we go over all the indices when syncing the graph, we see the following:
Interestingly, prior to the mark step, the final tensor ID for the grad is:
So somewhere in between, we lose the aliasing propagating from the input. The expectation is that the input grad (assumed to be initially 0) should propagate the resulting value inside the gradient accumulation API, which should incrementally accumulate the gradients across the XLA loop iterations. We'll need to print all the IRs when tracing the user computation to see where the aliasing is being missed.
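One way to print every traced IR graph (rather than a single tensor's) is the save-tensors debug facility; a sketch, assuming the variables are set before `torch_xla` is imported and that the chosen path is writable:

```python
import os

# Persist every traced graph so the aliasing chain can be followed end to end.
os.environ["XLA_SAVE_TENSORS_FILE"] = "/tmp/xla_ir_dump.txt"   # assumed path
os.environ["XLA_SAVE_TENSORS_FMT"] = "text"                    # or "hlo"

import torch_xla  # imported after the env vars on purpose
```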
Looking at a random sample of donated tensors when syncing the graph at mark step, we can cross-check the identified missing aliases against the HLOs:
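A sketch of one way to do that cross-check, assuming the standard `--xla_dump_to` flag and that the alias map shows up as an `input_output_alias` attribute on the dumped `HloModule` header line:

```python
import glob
import os

# Dump the compiled HLO modules; must be set before the first compilation.
os.environ["XLA_FLAGS"] = os.environ.get("XLA_FLAGS", "") + " --xla_dump_to=/tmp/hlo_dump"

# ... run the training step(s) and mark_step() here ...

# Print the header of each dumped module that carries an alias map, so the
# donated tensors reported at sync time can be checked against it by hand.
for path in sorted(glob.glob("/tmp/hlo_dump/*.txt")):
    with open(path) as f:
        header = f.readline().strip()
    if "input_output_alias" in header:
        print(path)
        print(header)
```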
🐛 Bug
In Llama-3 8B with 4 layers, we have noticed that the aliasing is missed for half of the outputs (70/140), particularly for the gradients and the second-momentum optimizer state. This ends up accounting for an additional 1.38GB of tensor memory on the device from the second step onwards.
When attempting a minimal reproduction with a simple linear model, the issue does not reproduce, and we see the expected program shape both with the traditional gradient accumulation loop and with the optimized API counterpart:
With the gradient accumulation XLA loop:
When not using the loop:
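For context, a minimal sketch of the kind of linear-model repro used for the comparison above, with the traditional accumulation loop; the layer sizes, batch size, and step counts are illustrative assumptions:

```python
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
model = torch.nn.Linear(1024, 1024).to(device)            # assumed sizes
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

accumulation_steps = 4                                    # assumed
for step in range(2):                                     # check step 2 onwards
    optimizer.zero_grad()
    for _ in range(accumulation_steps):
        x = torch.randn(8, 1024, device=device)
        loss = model(x).sum()
        loss.backward()          # grads accumulate in place into param.grad
    optimizer.step()
    xm.mark_step()               # graph sync; donation/aliasing decided here
```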
I will be using Llama-3 to understand why the buffer donations are not kicking in for 70 of the input/outputs.