Splash Attention is Broken on TPU Pods and does not follow keras.config.disable_flash_attention() #21116

Open
chaosmaster142857 opened this issue Apr 2, 2025 · 1 comment

@chaosmaster142857

Describe the bug

  • Splash Attention breaks many small Keras Hub models on TPU v4 pods when using DataParallel. For instance, I tested a few SigLIP models and the CLIP model, and they all fail the same way.

To Reproduce
Colab Link

Expected behavior
The model should run without crashing and produce the expected SigLIP output:
Keras:
20.6% that image 0 is 'This is a photo of 2 cats'
0.0% that image 1 is 'This is a photo of 2 dogs'
Loss: 6.132134437561035

Additional context

Would you like to help us fix it?

  • Frankly, I'm not sure how to fix the Splash Attention implementation in JAX's Pallas, so I think it would help to have a way to disable it on TPUs and fall back to normal attention.
  • I was able to get it running by using normal attention in JAX:

    return jax.nn.dot_product_attention(
        query,
        key,
        value,
        bias=bias,
        mask=mask,
        scale=scale,
        is_causal=is_causal,
        implementation="xla",
    )

before


This returns early using the standard XLA attention path (implementation="xla"), so the Splash Attention code is never reached.

@chaosmaster142857
Author

To fix this, the backend could also just check the value of use_flash_attention in Keras's config before choosing the Splash Attention path.
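
A minimal sketch of that check, in plain Python so it runs without JAX or Keras installed. The helper name `select_attention_path` and its parameters are hypothetical and not part of Keras. In a real backend the flag would presumably come from something like `keras.config.is_flash_attention_enabled()`, and the returned string would be passed as the `implementation` argument of `jax.nn.dot_product_attention` (or would route to the Pallas Splash Attention kernel). This is not the actual Keras dispatch code.

```python
def select_attention_path(flash_attention_enabled, platform):
    """Pick an attention path (hypothetical helper, not Keras API).

    flash_attention_enabled: the user's config flag; assumed to mirror
        what keras.config.disable_flash_attention() toggles.
    platform: "tpu", "gpu", or "cpu".
    """
    # Respect the user's opt-out first, on every platform. The issue
    # reports that this check is effectively skipped on TPU.
    if not flash_attention_enabled:
        return "xla"
    # Only reach for the fused kernels when the flag allows it.
    if platform == "tpu":
        return "splash"
    if platform == "gpu":
        return "cudnn"
    # Anything else falls back to plain XLA attention.
    return "xla"


if __name__ == "__main__":
    # With flash attention disabled, a TPU should use plain XLA attention.
    print(select_attention_path(False, "tpu"))
    # With it enabled, a TPU would take the Splash Attention path.
    print(select_attention_path(True, "tpu"))
```

Putting the flag check first means disabling flash attention takes effect the same way on TPU, GPU, and CPU, which is the behavior the issue title asks for.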
