Describe the bug
Splash Attention breaks many small KerasHub models on TPU v4 pods when using DataParallel. For instance, I have tested a few SigLIP models and the CLIP model, and the same crash happens with each of them.
To Reproduce
Colab Link
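Since the Colab holds the authoritative reproduction, the following is only a rough sketch of the kind of setup that hits the splash-attention path; the preset name and the KerasHub entry point used here are assumptions, not the exact code from the Colab:

import os
os.environ["KERAS_BACKEND"] = "jax"

import keras
import keras_hub

# Shard the data batch across all visible TPU devices (DataParallel).
keras.distribution.set_distribution(keras.distribution.DataParallel())

# Loading a small image-text model from KerasHub; the preset name is an
# assumption. Running a forward pass on TPU then triggers the splash-attention
# kernel and the crash described above.
model = keras_hub.models.Backbone.from_preset("siglip_base_patch16_224")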
Expected behavior
The model should not crash; it should run the SigLIP example correctly and print output like:
Keras:
20.6% that image 0 is 'This is a photo of 2 cats'
0.0% that image 1 is 'This is a photo of 2 dogs'
Loss: 6.132134437561035
Additional context
This is reproducible on tpuv4-64 (and similar sizes) and on tpuv2-8 (Colab).
Frankly, I'm not sure how to fix the splash attention implementation in JAX's Pallas, so having a way to disable it on TPUs and fall back to normal attention would be helpful.
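As a possible user-level stop-gap (rather than a fix to the Pallas kernel itself), recent Keras releases appear to expose a global switch that turns fused/flash attention off; whether your Keras build has this switch, and whether it also bypasses the TPU splash kernel, are assumptions worth verifying:

import keras

# Assumption: keras.config.disable_flash_attention() exists in your Keras build
# and causes the JAX backend to fall back to plain attention math on TPU.
keras.config.disable_flash_attention()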
Would you like to help us fix it?
I was able to get it running by using normal (XLA) attention in JAX:
return jax.nn.dot_product_attention(
    query,
    key,
    value,
    bias=bias,
    mask=mask,
    scale=scale,
    is_causal=is_causal,
    implementation="xla",
)
I added this before line 1181 of keras/keras/src/backend/jax/nn.py (commit 6d26efb). It just returns early using the standard XLA attention implementation instead of the splash attention kernel.
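For anyone who wants to try the same change without editing the installed package, here is a rough monkey-patch sketch. The wrapper name, the kwargs it swallows, and the module Keras actually resolves the backend function from are assumptions and may differ across Keras versions:

import jax
from keras.src.backend.jax import nn as jax_nn

def _force_xla_attention(query, key, value, bias=None, mask=None,
                         scale=None, is_causal=False, **kwargs):
    # Ignore backend-specific kwargs (e.g. flash_attention) and always take the
    # XLA path so the Pallas splash-attention kernel is never built.
    return jax.nn.dot_product_attention(
        query,
        key,
        value,
        bias=bias,
        mask=mask,
        scale=scale,
        is_causal=is_causal,
        implementation="xla",
    )

# Apply the patch before constructing the model; depending on the Keras version,
# the symbol may also need to be patched wherever it was re-exported.
jax_nn.dot_product_attention = _force_xla_attention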