🐛 Bug
In Llama-3 8B with 4 layers, we have noticed that aliasing is missed for half of the outputs (70/140), particularly for the gradients and the second-moment optimizer state. This accounts for an additional 1.38GB of tensor memory on the device from the second step onwards.
The issue does not reproduce in a minimal setup with a simple linear model: there, both traditional gradient accumulation and the optimized API counterpart show the expected behavior for the same program shape (a rough sketch of such a reproduction follows the program shape below):
```
2025-02-10 23:43:20.151213: I torch_xla/csrc/xla_graph_executor.cpp:1445] Compiled program shape (p0: f32[], p1: f32[], p2: f32[3], p3: f32[3,3], p4: f32[3], p5: f32[3,8417], p6: f32[8417], p7: f32[8417,16834], p8: f32[3], p9: f32[3,3], p10: f32[3], p11: f32[3,8417], p12: f32[8417], p13: f32[8417,16834], p14: s64[8,128], p15: f32[8,128,16834]) -> (f32[], f32[], f32[8417,16834], f32[8417], f32[3,8417], /*index=5*/f32[3], f32[3,3], f32[3])
```
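For reference, here is a minimal sketch of the kind of linear-model reproduction described above. This is my own approximation, not the original script: the layer sizes, batch shape, step counts, and optimizer are assumptions loosely inferred from the compiled program shape.

```python
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm

device = xm.xla_device()

# Parameter shapes roughly mirror the compiled program shape above; the exact
# model, optimizer, and data used in the original reproduction are not shown
# in this issue, so everything here is an approximation.
model = nn.Sequential(
    nn.Linear(16834, 8417),
    nn.Linear(8417, 3),
    nn.Linear(3, 3),
).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)
loss_fn = nn.CrossEntropyLoss()

accumulation_steps = 4
for step in range(2):  # run at least two steps; the extra memory shows up from step 2 onwards
    optimizer.zero_grad()
    for _ in range(accumulation_steps):
        data = torch.randn(8, 128, 16834, device=device)
        target = torch.randint(0, 3, (8, 128), device=device)
        loss = loss_fn(model(data).view(-1, 3), target.view(-1))
        (loss / accumulation_steps).backward()
    optimizer.step()
    xm.mark_step()  # graph boundary where input/output aliasing (buffer donation) is decided

# The "gradient accumulation XLA loop" variant replaces the inner Python loop
# with torch_xla's gradient accumulation API; it is omitted here because the
# exact call used in the reproduction is not shown in this issue.
```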
With the gradient accumulation XLA loop:
```
Post Compilation Analysis: ================================================================================
Post Compilation Analysis: Graph input size: 1.055937 GB
Post Compilation Analysis: Graph output size: 0.527969 GB
Post Compilation Analysis: Aliased Input size: 0.527968 GB
Post Compilation Analysis: Intermediate tensor size: 0.000000 GB
Post Compilation Analysis: Compiled program size: 0.000000 GB
Post Compilation Analysis: --------------------------------------------------------------------------------
Post Compilation Analysis: ================================================================================
Alias count: (2, 6.0, ((1739230998.1717062, 0.0), (1739230998.288421, 6.0)))
```
When not using the loop:
```
Post Compilation Analysis: ================================================================================
Post Compilation Analysis: Graph input size: 1.055937 GB
Post Compilation Analysis: Graph output size: 0.527969 GB
Post Compilation Analysis: Aliased Input size: 0.527968 GB
Post Compilation Analysis: Intermediate tensor size: 0.000000 GB
Post Compilation Analysis: Compiled program size: 0.000000 GB
Post Compilation Analysis: --------------------------------------------------------------------------------
Post Compilation Analysis: ================================================================================
Alias count: (2, 6.0, ((1739230863.2698746, 0.0), (1739230863.6376648, 6.0)))
```
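For completeness, here is a sketch of how numbers like the ones above can be collected. The environment variable and metric name are my assumptions about the standard torch_xla debugging hooks, not something stated in the logs.

```python
import os

# Assumption: PT_XLA_DEBUG=1 (set before importing torch_xla, or exported in
# the shell) is what produces the "Post Compilation Analysis" blocks above.
os.environ["PT_XLA_DEBUG"] = "1"

import torch_xla.debug.metrics as met

# ... run the training steps here ...

# metric_data() returns (num_samples, accumulator, [(timestamp, value), ...]),
# which matches the "Alias count" tuples above; "InputOutputAliasCount" is my
# assumption for the metric backing those numbers.
print("Alias count:", met.metric_data("InputOutputAliasCount"))
```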
I will continue using Llama-3 to investigate why buffer donation is not kicking in for 70 of the inputs/outputs.
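As a starting point for that investigation, one way to see where donation is (not) applied is to dump the graphs/HLO and inspect the entry computation's input/output alias configuration. This is only a sketch using generally available torch_xla/XLA debugging environment variables; the paths are placeholders of my choosing.

```python
import os

# Both variables must be set before torch_xla initializes the XLA client.
os.environ["XLA_SAVE_TENSORS_FILE"] = "/tmp/llama3_graphs.hlo"  # dump lazy-tensor graphs
os.environ["XLA_SAVE_TENSORS_FMT"] = "hlo"

# Alternatively, have XLA dump the compiled HLO modules; the entry computation
# there should show which input/output alias pairs were actually applied.
os.environ["XLA_FLAGS"] = "--xla_dump_to=/tmp/llama3_xla_dump"
```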