
Gradient Accumulation API aliasing #8697

Open
rpsilva-aws opened this issue Feb 10, 2025 · 4 comments

rpsilva-aws (Collaborator) commented Feb 10, 2025

🐛 Bug

In Llama-3 8B with 4 layers, we have noticed that aliasing is missed for half of the outputs (70/140), particularly for the gradients and the second momentum optimizer state. This accounts for an additional 1.38 GB of tensor memory on the device from the second step onwards.

When minimally reproducing with a simple linear model, the issue does not show up: we see the expected behavior with both the traditional gradient accumulation loop and the optimized API counterpart, for the following program shape:

2025-02-10 23:43:20.151213: I torch_xla/csrc/xla_graph_executor.cpp:1445] Compiled program shape (p0: f32[], p1: f32[], p2: f32[3], p3: f32[3,3], p4: f32[3], p5: f32[3,8417], p6: f32[8417], p7: f32[8417,16834], p8: f32[3], p9: f32[3,3], p10: f32[3], p11: f32[3,8417], p12: f32[8417], p13: f32[8417,16834], p14: s64[8,128], p15: f32[8,128,16834]) -> (f32[], f32[], f32[8417,16834], f32[8417], f32[3,8417], /*index=5*/f32[3], f32[3,3], f32[3])
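
For context, a minimal sketch of the traditional accumulation baseline used in that comparison (hypothetical dimensions and module names, not the exact script behind the program shape above) is shown below; the API counterpart replaces the inner Python loop with the experimental gradient accumulation helper under torch_xla.experimental, which lowers the accumulation into an XLA while loop (check the helper's current signature in the source):

import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm

device = xm.xla_device()
model = nn.Sequential(nn.Linear(64, 32), nn.Linear(32, 3)).to(device)
optimizer = torch.optim.Adam(model.parameters())

accum_steps = 4
data = torch.randn(accum_steps, 8, 64, device=device)
target = torch.randn(accum_steps, 8, 3, device=device)

optimizer.zero_grad()
for i in range(accum_steps):
    # Traditional accumulation: gradients build up in-place on .grad.
    loss = nn.functional.mse_loss(model(data[i]), target[i]) / accum_steps
    loss.backward()
optimizer.step()
xm.mark_step()  # graph cut: aliasing/donation decisions are made here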

With the gradient accumulation XLA loop:

Post Compilation Analysis: ================================================================================
Post Compilation Analysis: Graph input size: 1.055937 GB
Post Compilation Analysis: Graph output size: 0.527969 GB
Post Compilation Analysis: Aliased Input size: 0.527968 GB
Post Compilation Analysis: Intermediate tensor size: 0.000000 GB
Post Compilation Analysis: Compiled program size: 0.000000 GB
Post Compilation Analysis: --------------------------------------------------------------------------------
Post Compilation Analysis: ================================================================================


Alias count: (2, 6.0, ((1739230998.1717062, 0.0), (1739230998.288421, 6.0)))

When not using the loop:

Post Compilation Analysis: ================================================================================
Post Compilation Analysis: Graph input size: 1.055937 GB
Post Compilation Analysis: Graph output size: 0.527969 GB
Post Compilation Analysis: Aliased Input size: 0.527968 GB
Post Compilation Analysis: Intermediate tensor size: 0.000000 GB
Post Compilation Analysis: Compiled program size: 0.000000 GB
Post Compilation Analysis: --------------------------------------------------------------------------------
Post Compilation Analysis: ================================================================================


Alias count: (2, 6.0, ((1739230863.2698746, 0.0), (1739230863.6376648, 6.0)))
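
For reference, the "Alias count" tuples above match the (count, accumulator, samples) format returned by torch_xla's metrics API; a sketch of how they can be read, assuming the InputOutputAliasCount metric name (verify against the installed torch_xla version):

import torch_xla.debug.metrics as met

# metric_data() returns (total samples, accumulator, list of (time, value))
# for a given metric, which is the tuple format pasted above.
print(met.metric_data('InputOutputAliasCount'))

# Full counters/metrics report, useful for diffing the two configurations.
print(met.metrics_report())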

I will be using Llama-3 to understand why the buffer donations are not kicking in for 70 of the input/output pairs.

rpsilva-aws self-assigned this Feb 10, 2025
ysiraichi (Collaborator) commented:

Thank you for filing this issue. We might be able to help you more if you give us a bit more information:

  1. Do you have a minimal reproducer?
  2. Is there a specific device where you observe this behavior?

cc @miladm @zpcore @tengyifei

rpsilva-aws (Collaborator, Author) commented:

Do you have a minimal reproducer?

I attempted it with an MLP, as mentioned above, to no avail. I have only encountered this with Llama3, so it might be specific to the implementation.

Is there a specific device where you observe this behavior?

The behavior should be device-agnostic, AFAIK. The MLP example doesn't reproduce on CPU/GPU/Neuron, but Llama3 does reproduce on Neuron (I haven't tried other devices, due to unrelated incompatibilities).

Another interesting question when dealing with aliasing is whether it reproduces with or without functionalization (e.g. whether it requires RCA'ing aten's propagation path), but it seems to reproduce with both.
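
For anyone reproducing both configurations: functionalization can be toggled with the XLA_DISABLE_FUNCTIONALIZATION environment variable, set before torch_xla is imported. A minimal sketch:

import os

# '1' disables functionalization to exercise the aten in-place propagation
# path; leaving it unset keeps the (default) functionalized path.
os.environ['XLA_DISABLE_FUNCTIONALIZATION'] = '1'

import torch_xla  # noqa: E402  (must come after setting the flag)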

I have assigned it to myself, so I'll be sharing insights as I RCA it. I have opened it already in case we see similar issues elsewhere.

rpsilva-aws (Collaborator, Author) commented:

Tracing the LM head grad on step 1 with TP=32 and functionalization enabled, with gradients preemptively zeroed:

TensorID: 848
Device: SPMD:0
XLA Shape: f32[128256,4096]
ShardingSpec: {devices=[32,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31}
IR: None
XLAShardedData: 
  Data Device: SPMD:0
  Data Shape: f32[128256,4096]
  OpSharding: {devices=[32,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31}
  NumShards: 32
Tensor on host: None

When tracing the XLA data propagation of the aliasing, I see the following identifiers while transforming an in-place op:

  • Output alias: 848
  • Input ID: 668
  • Input alias: 668

When we go over all the indices while syncing the graph, we see that 668 is one of the buffer donor indices that is added to the output buffer map (still need to verify that it indeed is).

Interestingly, prior to the mark step, the final tensor id for the grad is:

TensorID: 1446
Device: SPMD:0
XLA Shape: f32[128256,4096]
ShardingSpec: {devices=[32,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31}
IR: None
XLAShardedData: 
  Data Device: SPMD:0
  Data Shape: f32[128256,4096]
  OpSharding: {devices=[32,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31}
  NumShards: 32
Tensor on host: None
}

so somewhere in between, we lose the aliasing propagated from the input. The expectation is that the input grad (assumed to be initially 0) should propagate the resulting value inside the gradient accumulation API, which incrementally accumulates the gradients across the XLA loop iterations.

We'll need to print all the IRs when tracing the user computation, to see where the aliasing is being missed.
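
A sketch of the kind of per-tensor dumps used here, assuming the private _XLAC debug helpers (available in recent torch_xla builds, but not a stable API):

import torch
import torch_xla
import torch_xla.core.xla_model as xm

t = torch.zeros(4, 4, device=xm.xla_device())
t.add_(1)  # leave a pending in-place op so there is IR to inspect

# Per-tensor debug info (TensorID, XLA Shape, ShardingSpec, IR), i.e. the
# kind of dump pasted above.
print(torch_xla._XLAC._get_xla_tensor_debug_info(t))

# Lazy IR and lowered HLO text rooted at the given tensors, to see where
# the aliased value stops propagating.
print(torch_xla._XLAC._get_xla_tensors_text([t]))
print(torch_xla._XLAC._get_xla_tensors_hlo([t]))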

rpsilva-aws (Collaborator, Author) commented Feb 12, 2025

37/80 tensors to sync...
Parameter at index 9 (alias_id=2811, shape=f32[1,1,131072,128]) is not donated.
Parameter at index 10 (alias_id=2810, shape=f32[1,1,131072,128]) is not donated.
Parameter at index 14 (alias_id=2826, shape=f32[128256,4096]) is not donated.
Parameter at index 15 (alias_id=2825, shape=f32[4096]) is not donated.
Parameter at index 16 (alias_id=2824, shape=f32[4096]) is not donated.
Parameter at index 17 (alias_id=2823, shape=f32[4096]) is not donated.
Parameter at index 18 (alias_id=2822, shape=f32[4096,14336]) is not donated.
Parameter at index 19 (alias_id=2821, shape=f32[4096,2,14336]) is not donated.
Parameter at index 20 (alias_id=2820, shape=f32[4096,3,32,128]) is not donated.
Parameter at index 21 (alias_id=2819, shape=f32[4096,4096]) is not donated.
Parameter at index 22 (alias_id=2818, shape=f32[128256,4096]) is not donated.
Parameter at index 23 (alias_id=2800, shape=f32[128256,4096]) is not donated.
Parameter at index 24 (alias_id=2813, shape=f32[4096]) is not donated.
Parameter at index 25 (alias_id=2812, shape=f32[4096]) is not donated.
Parameter at index 26 (alias_id=2808, shape=f32[4096]) is not donated.
Parameter at index 27 (alias_id=2802, shape=f32[4096,14336]) is not donated.
Parameter at index 28 (alias_id=2803, shape=f32[4096,2,14336]) is not donated.
Parameter at index 29 (alias_id=2805, shape=f32[4096,3,32,128]) is not donated.
Parameter at index 30 (alias_id=2804, shape=f32[4096,4096]) is not donated.
Parameter at index 31 (alias_id=2807, shape=f32[128256,4096]) is not donated.
Parameter at index 32 (alias_id=25, shape=s32[4,1,4096]) is not donated.
Parameter at index 33 (alias_id=23, shape=s32[4,1,4096]) is not donated.
Parameter at index 35 (alias_id=651, shape=f32[128256,4096]) is donated.
Parameter at index 38 (alias_id=650, shape=f32[128256,4096]) is donated.
Parameter at index 41 (alias_id=653, shape=f32[4096,4096]) is donated.
Parameter at index 42 (alias_id=652, shape=f32[4096,4096]) is donated.
Parameter at index 43 (alias_id=655, shape=f32[4096,3,32,128]) is donated.
Parameter at index 44 (alias_id=654, shape=f32[4096,3,32,128]) is donated.
Parameter at index 45 (alias_id=659, shape=f32[4096,14336]) is donated.
Parameter at index 46 (alias_id=658, shape=f32[4096,14336]) is donated.
Parameter at index 47 (alias_id=657, shape=f32[4096,2,14336]) is donated.
Parameter at index 48 (alias_id=656, shape=f32[4096,2,14336]) is donated.
Parameter at index 49 (alias_id=661, shape=f32[4096]) is donated.
Parameter at index 50 (alias_id=660, shape=f32[4096]) is donated.
Parameter at index 51 (alias_id=663, shape=f32[4096]) is donated.
Parameter at index 52 (alias_id=662, shape=f32[4096]) is donated.
Parameter at index 53 (alias_id=665, shape=f32[4096]) is donated.
Parameter at index 54 (alias_id=664, shape=f32[4096]) is donated.
Parameter at index 55 (alias_id=667, shape=f32[128256,4096]) is donated.
Parameter at index 56 (alias_id=666, shape=f32[128256,4096]) is donated.
18/57 parameters are donated when sync'ing 37/80 tensors.

Looking at a random sample of donated tensors when syncing the graph at mark step, we can cross-check against the missing aliases identified from the HLOs:

  • First momentum optim state
TensorID: 651
AliasID: 651
Device: SPMD:0
XLA Shape: f32[128256,4096]
ShardingSpec: {devices=[32,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31}
IR: [] aten::addcmul, [email protected]:63, xla_shape=f32[128256,4096]{1,0}, dynamic_dims: ()
XLAData: None
Tensor on host: None
}
  • Second momentum optim state
TensorID: 653
AliasID: 653
Device: SPMD:0
XLA Shape: f32[4096,4096]
ShardingSpec: {devices=[1,32]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31}
IR: None
XLAShardedData: 
  Data Device: SPMD:0
  Data Shape: f32[4096,4096]
  OpSharding: {devices=[1,32]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31}
  NumShards: 32
Tensor on host: None
}
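
To cross-check individual parameters, the donation flag on a tensor's buffer can also be queried directly; a sketch, assuming the private _set_buffer_donation/_get_buffer_donation bindings (which may differ across torch_xla versions):

import torch
import torch_xla
import torch_xla.core.xla_model as xm

p = torch.zeros(4096, device=xm.xla_device())

# Flag the underlying buffer as donatable for the next execution, then read
# the flag back to confirm the donation state of this particular tensor.
torch_xla._XLAC._set_buffer_donation(p, True)
print(torch_xla._XLAC._get_buffer_donation(p))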
