You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
[JAX] Add support for Fused Attn MLA head_dim_qk != head_dim_v (NVIDIA#1851)
* Add support for Fused Attn MLA head_dim_qk != head_dim_v
Modify is_fused_attn_kernel_available() to accept different head_dims for qk and v
Modify FusedAttnHelper to accept different head_dims for qk and v and modify assert dims checks in parse_qkv_aval()
Modify FusedAttnFwdPrimitive and FusedAttnBwdPrimitive to accept different head_dims for qk and v
Modify Fused Attn related cpp and csrc extension API calls to accept different head_dims for qk and v
Modify DotProductAttention call() to extract head dims separately for qk and v
Modify the FusedAttn Tests to accommodate for API changes in FusedAttn API
Add test case for head_dim_qk != head_dim_v (failing)
Modify the baseline JAX appropriately to reshape the output vector based on v dims and not q dims
Signed-off-by: Kshitij Janardan Lakhani <[email protected]>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* Fix context dims in general DPA in test_fused_attn
Signed-off-by: Kshitij Janardan Lakhani <[email protected]>
* Fix dim for output tensor by replacing with v head dim rather than q head dim
Add test cases for jax fused attn where head_dim_qk != head_dim_v for a combination of data types and attention type
Signed-off-by: Kshitij Janardan Lakhani <[email protected]>
* Modify the fused attn jax unit test case for head dim qk != head dim v
Signed-off-by: Kshitij Janardan Lakhani <[email protected]>
* Use new FusedAttnRunner function signature for separate hidden dim for qk and v in Fused Attn distributed tests
Code clean up
Signed-off-by: Kshitij Janardan Lakhani <[email protected]>
* Fix usage of is_fused_attn signature in distributed tests
Signed-off-by: Kshitij Janardan Lakhani <[email protected]>
* Remove unnecessary assert
Signed-off-by: Kshitij Janardan Lakhani <[email protected]>
---------
Signed-off-by: Kshitij Janardan Lakhani <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
0 commit comments