
[JAX] Add support for Fused Attn MLA head_dim_qk != head_dim_v #1851


Conversation

@KshitijLakhani (Collaborator) commented on Jun 4, 2025

Description

MLA (DeepSeek-V3-style) attention, where head_dim_qk != head_dim_v, is supported on Hopper with cuDNN 9.10. However, this support is not yet exposed through the TE-JAX fused attention pathway. This PR adds that support.
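
For context, MLA pairs a larger head dimension for queries/keys than for values (DeepSeek-V3 uses a 192-dim qk head and a 128-dim v head), so the fused kernel must accept q/k and v tensors whose trailing dims differ. A minimal shape sketch, with illustrative sizes rather than TE API calls:

```python
import jax.numpy as jnp

# Illustrative MLA-style shapes; 192/128 matches DeepSeek-V3's qk/v head dims.
batch, seqlen, num_heads = 2, 128, 16
head_dim_qk, head_dim_v = 192, 128

q = jnp.zeros((batch, seqlen, num_heads, head_dim_qk), jnp.bfloat16)
k = jnp.zeros((batch, seqlen, num_heads, head_dim_qk), jnp.bfloat16)
v = jnp.zeros((batch, seqlen, num_heads, head_dim_v), jnp.bfloat16)
# softmax(q @ k^T) contracts away head_dim_qk; multiplying by v makes the
# output's trailing dim head_dim_v: (batch, seqlen, num_heads, head_dim_v)
```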

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Modify is_fused_attn_kernel_available() to accept different head_dims for qk and v
  • Modify FusedAttnHelper to accept different head_dims for qk and v, and update the dimension assertions in parse_qkv_aval()
  • Modify FusedAttnFwdPrimitive and FusedAttnBwdPrimitive to accept different head_dims for qk and v
  • Modify the Fused Attn related cpp and csrc extension API calls to accept different head_dims for qk and v
  • Modify DotProductAttention call() to extract the head dims separately for qk and v
  • Modify the FusedAttn tests to accommodate the API changes
  • Add a test case for head_dim_qk != head_dim_v
  • Modify the baseline JAX implementation to reshape the output based on the v dims rather than the q dims (see the reference sketch after this list)
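
A minimal pure-JAX reference of that last point (a sketch assuming BSHD layout; reference_attention is an illustrative name, not the TE baseline function):

```python
import jax
import jax.numpy as jnp

def reference_attention(q, k, v):
    # q, k: (batch, seqlen, heads, head_dim_qk); v: (batch, seqlen, heads, head_dim_v).
    # head_dim_qk is contracted away in the logits, so the output's trailing
    # dim (and any final reshape) must come from v, not q.
    scale = q.shape[-1] ** -0.5
    logits = jnp.einsum("bqhd,bkhd->bhqk", q, k) * scale
    probs = jax.nn.softmax(logits.astype(jnp.float32), axis=-1).astype(q.dtype)
    return jnp.einsum("bhqk,bkhd->bqhd", probs, v)

kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (2, 128, 16, 192), jnp.bfloat16)
k = jax.random.normal(kk, (2, 128, 16, 192), jnp.bfloat16)
v = jax.random.normal(kv, (2, 128, 16, 128), jnp.bfloat16)
out = reference_attention(q, k, v)
assert out.shape == (2, 128, 16, 128)  # trailing dim is head_dim_v, not head_dim_qk
```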

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@KshitijLakhani requested a review from @cyanguwa on June 4, 2025 20:37
@KshitijLakhani self-assigned this on Jun 4, 2025
@KshitijLakhani changed the title from "Add support for Fused Attn MLA head_dim_qk != head_dim_v" to "[JAX] Add support for Fused Attn MLA head_dim_qk != head_dim_v" on Jun 12, 2025
@KshitijLakhani force-pushed the klakhani/feature/add-mla-jax-fused-support branch from 612ffdf to ed071aa on June 12, 2025 19:05
@KshitijLakhani marked this pull request as ready for review on June 12, 2025 19:05
@KshitijLakhani (Collaborator, Author) commented:

/te-ci jax

KshitijLakhani and others added 7 commits June 13, 2025 12:17
	Modify is_fused_attn_kernel_available() to accept different head_dims for qk and v
	Modify FusedAttnHelper to accept different head_dims for qk and v and update the dimension assertions in parse_qkv_aval()
	Modify FusedAttnFwdPrimitive and FusedAttnBwdPrimitive to accept different head_dims for qk and v
	Modify Fused Attn related cpp and csrc extension API calls to accept different head_dims for qk and v
	Modify DotProductAttention call() to extract head dims separately for qk and v
	Modify the FusedAttn tests to accommodate the API changes
	Add test case for head_dim_qk != head_dim_v (failing)
	Modify the baseline JAX appropriately to reshape the output tensor based on v dims and not q dims

Signed-off-by: Kshitij Janardan Lakhani <[email protected]>
Fix dim for output tensor by replacing with v head dim rather than q head dim

Add test cases for jax fused attn where head_dim_qk != head_dim_v for a combination of data types and attention types (a shape-level test sketch follows this commit list)

Signed-off-by: Kshitij Janardan Lakhani <[email protected]>
Use new FusedAttnRunner function signature for separate hidden dim for qk and v in Fused Attn distributed tests

Code clean up

Signed-off-by: Kshitij Janardan Lakhani <[email protected]>
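
As a hedged illustration of that dtype/head-dim parametrization (pytest names here are hypothetical; the actual TE-JAX tests drive FusedAttnRunner against the fused kernel rather than this pure-JAX check):

```python
import jax
import jax.numpy as jnp
import pytest

# Hypothetical sketch: parametrize dtype and (head_dim_qk, head_dim_v) and
# check that the attention output takes v's head dim.
@pytest.mark.parametrize("dtype", [jnp.bfloat16, jnp.float16])
@pytest.mark.parametrize("head_dim_qk,head_dim_v", [(192, 128), (64, 64)])
def test_output_follows_v_head_dim(dtype, head_dim_qk, head_dim_v):
    b, s, h = 2, 32, 4
    kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
    q = jax.random.normal(kq, (b, s, h, head_dim_qk), dtype)
    k = jax.random.normal(kk, (b, s, h, head_dim_qk), dtype)
    v = jax.random.normal(kv, (b, s, h, head_dim_v), dtype)
    probs = jax.nn.softmax(jnp.einsum("bqhd,bkhd->bhqk", q, k), axis=-1)
    out = jnp.einsum("bhqk,bkhd->bqhd", probs, v)
    assert out.shape == (b, s, h, head_dim_v)
```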
@KshitijLakhani force-pushed the klakhani/feature/add-mla-jax-fused-support branch from 780c0d7 to 24788f1 on June 13, 2025 19:17
@KshitijLakhani (Collaborator, Author) commented:

/te-ci JAX

@cyanguwa previously approved these changes on Jun 13, 2025
Signed-off-by: Kshitij Janardan Lakhani <[email protected]>
@KshitijLakhani merged commit 1ddfa0c into NVIDIA:main on Jun 13, 2025
12 checks passed
@KshitijLakhani (Collaborator, Author) commented:

Successful pipeline: 30053940

phu0ngng pushed a commit to phu0ngng/TransformerEngine that referenced this pull request Jun 16, 2025
[JAX] Add support for Fused Attn MLA head_dim_qk != head_dim_v (NVIDIA#1851)

* Add support for Fused Attn MLA head_dim_qk != head_dim_v
	Modify is_fused_attn_kernel_available() to accept different head_dims for qk and v
	Modify FusedAttnHelper to accept different head_dims for qk and v and update the dimension assertions in parse_qkv_aval()
	Modify FusedAttnFwdPrimitive and FusedAttnBwdPrimitive to accept different head_dims for qk and v
	Modify Fused Attn related cpp and csrc extension API calls to accept different head_dims for qk and v
	Modify DotProductAttention call() to extract head dims separately for qk and v
	Modify the FusedAttn tests to accommodate the API changes
	Add test case for head_dim_qk != head_dim_v (failing)
	Modify the baseline JAX appropriately to reshape the output tensor based on v dims and not q dims

Signed-off-by: Kshitij Janardan Lakhani <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix context dims in general DPA in test_fused_attn

Signed-off-by: Kshitij Janardan Lakhani <[email protected]>

* Fix dim for output tensor by replacing with v head dim rather than q head dim
Add test cases for jax fused attn where head_dim_qk != head_dim_v for a combination of data types and attention type

Signed-off-by: Kshitij Janardan Lakhani <[email protected]>

* Modify the fused attn jax unit test case for head dim qk != head dim v

Signed-off-by: Kshitij Janardan Lakhani <[email protected]>

* Use new FusedAttnRunner function signature for separate hidden dim for qk and v in Fused Attn distributed tests
Code clean up

Signed-off-by: Kshitij Janardan Lakhani <[email protected]>

* Fix usage of is_fused_attn signature in distributed tests

Signed-off-by: Kshitij Janardan Lakhani <[email protected]>

* Remove unnecessary assert

Signed-off-by: Kshitij Janardan Lakhani <[email protected]>

---------

Signed-off-by: Kshitij Janardan Lakhani <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>