[JAX] Add support for Fused Attn MLA head_dim_qk != head_dim_v #1851
Merged: KshitijLakhani merged 8 commits into NVIDIA:main from KshitijLakhani:klakhani/feature/add-mla-jax-fused-support on Jun 13, 2025
Conversation
cyanguwa reviewed on Jun 4, 2025
Branch force-pushed from 612ffdf to ed071aa
/te-ci jax
Add support for Fused Attn MLA head_dim_qk != head_dim_v

- Modify is_fused_attn_kernel_available() to accept different head_dims for qk and v
- Modify FusedAttnHelper to accept different head_dims for qk and v, and modify the assert dims checks in parse_qkv_aval()
- Modify FusedAttnFwdPrimitive and FusedAttnBwdPrimitive to accept different head_dims for qk and v
- Modify the Fused Attn related cpp and csrc extension API calls to accept different head_dims for qk and v
- Modify the DotProductAttention call() to extract head dims separately for qk and v
- Modify the FusedAttn tests to accommodate the API changes in the FusedAttn API
- Add a test case for head_dim_qk != head_dim_v (failing)
- Modify the baseline JAX appropriately to reshape the output tensor based on v dims, not q dims

Signed-off-by: Kshitij Janardan Lakhani <[email protected]>
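To make the test-related items in this commit concrete, here is a hedged pytest-style sketch of the kind of case being added. This is an illustrative reference check, not the actual TE-JAX fused-attention test suite; the shapes, dtypes, and the (192, 128) head-dim pair are assumptions.

```python
import pytest
import jax
import jax.numpy as jnp

@pytest.mark.parametrize("head_dim_qk,head_dim_v", [(128, 128), (192, 128)])
@pytest.mark.parametrize("dtype", [jnp.float16, jnp.bfloat16])
def test_output_head_dim_follows_v(head_dim_qk, head_dim_v, dtype):
    # q/k share head_dim_qk; v carries its own head_dim_v.
    kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
    q = jax.random.normal(kq, (2, 64, 8, head_dim_qk), dtype)
    k = jax.random.normal(kk, (2, 64, 8, head_dim_qk), dtype)
    v = jax.random.normal(kv, (2, 64, 8, head_dim_v), dtype)
    # Plain softmax(q k^T / sqrt(d_qk)) v baseline; logits are scaled by d_qk.
    logits = jnp.einsum("bqhd,bkhd->bhqk", q, k) / jnp.sqrt(jnp.asarray(head_dim_qk, dtype))
    out = jnp.einsum("bhqk,bkhd->bqhd", jax.nn.softmax(logits, axis=-1), v)
    # The output head dim must come from v, not q.
    assert out.shape == (2, 64, 8, head_dim_v)
```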
[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
Signed-off-by: Kshitij Janardan Lakhani <[email protected]>
Fix dim for output tensor by replacing with v head dim rather than q head dim

Add test cases for jax fused attn where head_dim_qk != head_dim_v for a combination of data types and attention type

Signed-off-by: Kshitij Janardan Lakhani <[email protected]>
Signed-off-by: Kshitij Janardan Lakhani <[email protected]>
Use new FusedAttnRunner function signature for separate hidden dim for qk and v in Fused Attn distributed tests

Code clean up

Signed-off-by: Kshitij Janardan Lakhani <[email protected]>
Signed-off-by: Kshitij Janardan Lakhani <[email protected]>
Branch force-pushed from 780c0d7 to 24788f1
/te-ci JAX
cyanguwa previously approved these changes on Jun 13, 2025
Signed-off-by: Kshitij Janardan Lakhani <[email protected]>
cyanguwa approved these changes on Jun 13, 2025
Successful pipeline: 30053940
phu0ngng pushed a commit to phu0ngng/TransformerEngine that referenced this pull request on Jun 16, 2025
[JAX] Add support for Fused Attn MLA head_dim_qk != head_dim_v (NVIDIA#1851)

* Add support for Fused Attn MLA head_dim_qk != head_dim_v

  Modify is_fused_attn_kernel_available() to accept different head_dims for qk and v
  Modify FusedAttnHelper to accept different head_dims for qk and v and modify assert dims checks in parse_qkv_aval()
  Modify FusedAttnFwdPrimitive and FusedAttnBwdPrimitive to accept different head_dims for qk and v
  Modify Fused Attn related cpp and csrc extension API calls to accept different head_dims for qk and v
  Modify DotProductAttention call() to extract head dims separately for qk and v
  Modify the FusedAttn Tests to accommodate for API changes in FusedAttn API
  Add test case for head_dim_qk != head_dim_v (failing)
  Modify the baseline JAX appropriately to reshape the output vector based on v dims and not q dims

  Signed-off-by: Kshitij Janardan Lakhani <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci

* Fix context dims in general DPA in test_fused_attn

  Signed-off-by: Kshitij Janardan Lakhani <[email protected]>

* Fix dim for output tensor by replacing with v head dim rather than q head dim

  Add test cases for jax fused attn where head_dim_qk != head_dim_v for a combination of data types and attention type

  Signed-off-by: Kshitij Janardan Lakhani <[email protected]>

* Modify the fused attn jax unit test case for head dim qk != head dim v

  Signed-off-by: Kshitij Janardan Lakhani <[email protected]>

* Use new FusedAttnRunner function signature for separate hidden dim for qk and v in Fused Attn distributed tests

  Code clean up

  Signed-off-by: Kshitij Janardan Lakhani <[email protected]>

* Fix usage of is_fused_attn signature in distributed tests

  Signed-off-by: Kshitij Janardan Lakhani <[email protected]>

* Remove unnecessary assert

  Signed-off-by: Kshitij Janardan Lakhani <[email protected]>

---------

Signed-off-by: Kshitij Janardan Lakhani <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Description
MLA (DS_v3) support is available on Hopper with cuDNN 9.10. However, this support is not yet exposed through the TE-JAX fused attention pathway. This PR adds it.
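For context, a minimal pure-JAX sketch (not the TE fused kernel or its API) of the behavior this enables: attention logits are computed and scaled with the qk head dim, while the output tensor inherits its head dim from v. The BSHD layout and the 192/128 head-dim split below are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

def reference_mla_attention(q, k, v):
    # q, k: [batch, seq, heads, head_dim_qk]; v: [batch, seq, heads, head_dim_v]
    head_dim_qk = q.shape[-1]
    logits = jnp.einsum("bqhd,bkhd->bhqk", q, k) / jnp.sqrt(head_dim_qk)
    probs = jax.nn.softmax(logits, axis=-1)
    # The contraction with v gives the output its head dim (head_dim_v, not head_dim_qk).
    return jnp.einsum("bhqk,bkhd->bqhd", probs, v)

kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (2, 128, 16, 192))  # head_dim_qk = 192
k = jax.random.normal(kk, (2, 128, 16, 192))
v = jax.random.normal(kv, (2, 128, 16, 128))  # head_dim_v = 128
out = reference_mla_attention(q, k, v)
assert out.shape == (2, 128, 16, 128)  # output follows head_dim_v
```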
Type of change
Changes
- Modify `is_fused_attn_kernel_available()` to accept different head_dims for qk and v
- Modify `FusedAttnHelper` to accept different head_dims for qk and v, and modify the assert dims checks in `parse_qkv_aval()`
- Modify `FusedAttnFwdPrimitive` and `FusedAttnBwdPrimitive` to accept different head_dims for qk and v
- Modify the `DotProductAttention` `call()` to extract head dims separately for qk and v (see the sketch after this list)
- Add test cases for `head_dim_qk != head_dim_v`
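As a rough sketch of the shape bookkeeping described above (an illustrative helper, not the actual `DotProductAttention` or `parse_qkv_aval()` code), the head dims can be read separately from q and v, with q and k still required to agree:

```python
from typing import Tuple
import jax.numpy as jnp

def extract_head_dims(q: jnp.ndarray, k: jnp.ndarray, v: jnp.ndarray) -> Tuple[int, int]:
    """Return (head_dim_qk, head_dim_v) for [..., num_heads, head_dim] tensors."""
    head_dim_qk, head_dim_v = q.shape[-1], v.shape[-1]
    # q and k must still share a head dim so the logits are well defined;
    # v may differ, and the attention output inherits head_dim_v.
    assert k.shape[-1] == head_dim_qk, "q and k head dims must match"
    return head_dim_qk, head_dim_v

# Example with MLA-style shapes where head_dim_qk (192) != head_dim_v (128).
q = jnp.zeros((2, 128, 16, 192))
k = jnp.zeros((2, 128, 16, 192))
v = jnp.zeros((2, 128, 16, 128))
print(extract_head_dims(q, k, v))  # (192, 128)
```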
Checklist: