Add new flash attn features to cuDNN SDPA API and remove fused attn #21228

Open

Cjkkkk wants to merge 1 commit into main from Cjkkkk:sdpa_new_cudnn_frontend

Conversation

@Cjkkkk (Contributor) commented May 14, 2024

  • Add variable sequence length support: the API accepts two additional tensors, seqlen_q and seqlen_kv, indicating the non-padded lengths so computation on padding can be skipped (see the sketch after this list).
  • Add MQA/GQA support.
  • Add broadcast bias: the bias can be broadcast along the batch/head dimensions.
  • Add dbias calculation.
  • Remove fused attn and default to flash attn.
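
To make the semantics concrete, here is a pure-JAX reference sketch of what these features compute. This is not this PR's API and not the cuDNN kernel: the function name, argument names, and the BTNH layout choice are illustrative assumptions, and the math follows the standard SDPA definition.

```python
import jax
import jax.numpy as jnp

def reference_sdpa(q, k, v, bias, seqlen_q, seqlen_kv):
  # Assumed shapes (BTNH layout, as in the diff below):
  #   q: [B, T, N, H]; k, v: [B, S, K, H] with N % K == 0
  #   (K == 1 is MQA, 1 < K < N is GQA);
  #   bias: broadcastable to [B, N, T, S], e.g. [1, N, T, S] or [B, 1, T, S];
  #   seqlen_q, seqlen_kv: [B] int32 non-padded lengths.
  B, T, N, H = q.shape
  _, S, K, _ = k.shape
  # MQA/GQA: each of the K kv heads serves N // K query heads.
  k = jnp.repeat(k, N // K, axis=2)
  v = jnp.repeat(v, N // K, axis=2)
  logits = jnp.einsum("btnh,bsnh->bnts", q, k) / (H ** 0.5)
  # Broadcast bias: broadcasting handles the batch/head dims; autodiff
  # w.r.t. bias then yields dbias, summed over the broadcast dims.
  logits = logits + bias
  # Variable sequence length: positions at or beyond seqlen_* are padding.
  q_valid = jnp.arange(T)[None, :] < seqlen_q[:, None]    # [B, T]
  kv_valid = jnp.arange(S)[None, :] < seqlen_kv[:, None]  # [B, S]
  valid = q_valid[:, None, :, None] & kv_valid[:, None, None, :]
  logits = jnp.where(valid, logits, -1e9)  # finite fill avoids NaN rows
  probs = jax.nn.softmax(logits, axis=-1)
  return jnp.einsum("bnts,bsnh->btnh", probs, v)
```

Under this reference, dbias is simply `jax.grad` of a scalar loss with respect to bias; roughly, the point of the flash-attention path is to produce the same results without materializing the full [B, N, T, S] logits.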

@Cjkkkk (Contributor Author) commented May 14, 2024

@superbobry Hi Sergei, could you help review this PR?

@Cjkkkk (Contributor Author) commented May 16, 2024

@superbobry Hi Sergei, any updates on this?

@superbobry (Member) commented:

No updates just yet, sorry. I will review some time tomorrow.

@superbobry (Member) left a review:


I did my best to read through, but these large diffs are really hard to get through. Please send smaller PRs for any follow-up changes.

I would also recommend asking someone from NVIDIA to review the cuDNN API usage, etc.

Outdated review threads (resolved):
  jax/_src/cudnn/fused_attention_stablehlo.py (3 threads)
  tests/fused_attention_stablehlo_test.py (2 threads)
@@ -41,10 +41,42 @@ class AttentionLayout(Enum):
  BTNH = 0
  BNTH = 1

class MaskType(Enum):
  NO_MASK = 0
@superbobry (Member) commented on the diff:

Out of curiosity, why not use None instead when a mask is not specified?

@Cjkkkk (Contributor Author) replied:

Just a choice to make it more explicit.
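
For concreteness, the trade-off under discussion looks roughly like this. Only NO_MASK is visible in the hunk above; the extra enum member and the call-site signatures are illustrative assumptions, not the PR's actual code.

```python
from enum import Enum
from typing import Optional

class MaskType(Enum):
  NO_MASK = 0  # from the diff above
  CAUSAL = 1   # illustrative member; the PR's actual members may differ

# Explicit style (the PR's choice): the "no mask" case is spelled out
# at every call site.
def attention_explicit(q, k, v, mask_type: MaskType = MaskType.NO_MASK):
  ...

# Alternative the reviewer raises: None means "no mask".
def attention_optional(q, k, v, mask_type: Optional[MaskType] = None):
  ...
```

The enum keeps every call site explicit and leaves room for more mask variants, at the cost of a slightly noisier default.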

@google-ml-butler bot added the kokoro:force-run and pull ready (Ready for copybara import and testing) labels on May 17, 2024
@Cjkkkk (Contributor Author) commented May 17, 2024

> I did my best to read through, but these large diffs are really hard to get through. Please send smaller PRs for any follow-up changes.
>
> I would also recommend asking someone from NVIDIA to review the cuDNN API usage, etc.

Understood, sorry for the large PR; I will create smaller ones next time. I think people from NVIDIA don't have access to approve and merge the PR?

@superbobry (Member) commented:

Please address the comments, and then we can merge.

@Cjkkkk (Contributor Author) commented May 20, 2024

> Please address the comments, and then we can merge.

Comments addressed, sorry about the delay.

@superbobry (Member) commented:

Can you squash the PR, please?

copybara-service bot pushed a commit that referenced this pull request May 20, 2024
--
f625317 by cjkkkk <[email protected]>:

init

COPYBARA_INTEGRATE_REVIEW=#21228 from Cjkkkk:sdpa_new_cudnn_frontend f625317
PiperOrigin-RevId: 635518631
Labels: pull ready (Ready for copybara import and testing)

3 participants