[Splash Attention] Remove unnecessary `head_dim_v % NUM_LANES == 0` constraint
#27427
Labels: enhancement (New feature or request)
Hi JAX team and @sharadmv,
I found an unnecessary constraint in the `flash_attention_kernel` implementation in the Splash Attention module that prevents small models from using it. Currently, the kernel enforces that `head_dim_v` must be a multiple of `NUM_LANES` (128):

Source: `jax/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py`, lines 714 to 718 at 49aad1b
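In other words, the kernel performs a check along these lines (a paraphrase, not the verbatim source; the helper name and error message here are my own):

```python
# Paraphrase of the cited validity check, not the verbatim source.
NUM_LANES = 128  # TPU lane count used by the kernel

def check_head_dim_v(head_dim_v: int) -> None:  # hypothetical helper name
    if head_dim_v % NUM_LANES != 0:
        raise ValueError(
            f"head_dim_v={head_dim_v} must be a multiple of NUM_LANES={NUM_LANES}"
        )

try:
    check_head_dim_v(64)  # small models, e.g. head_dim_v = 64, are rejected
except ValueError as e:
    print(e)
```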
However, this constraint is unnecessary. It arises from the way `head_dim_v_repeats` is used to reshape `alpha`: the reshape assumes `alpha` has shape `[bq, NUM_LANES]` and changes it to `[bq, head_dim_v]`. In fact, `alpha` is $(d'_{i-1} \cdot e^{m_{i-1} - m_i}) / d'_i$ in Algorithm FlashAttention (Tiling), so it should have shape `[bq, 1]`, not `[bq, NUM_LANES]`.
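To illustrate the shape arithmetic with plain `jax.numpy` (not kernel code; `bq = 8` and the other values are just examples): tiling a `[bq, NUM_LANES]` buffer along the lane axis can only produce widths that are multiples of `NUM_LANES`, while a `[bq, 1]` `alpha` broadcasts against any `head_dim_v`.

```python
import jax.numpy as jnp

NUM_LANES = 128
bq = 8

# Current scheme: alpha is materialized as [bq, NUM_LANES] and tiled along
# the lane axis, so the target width must be a whole multiple of NUM_LANES.
alpha_wide = jnp.full((bq, NUM_LANES), 0.5)
print(jnp.tile(alpha_wide, (1, 256 // NUM_LANES)).shape)  # (8, 256): works
print(64 // NUM_LANES)  # 0 -> no repeat count can turn [8, 128] into [8, 64]

# If alpha stays the per-row scalar it represents, shape [bq, 1],
# broadcasting covers any head_dim_v with no repeat at all.
alpha = jnp.full((bq, 1), 0.5)
o_prev = jnp.ones((bq, 64))
print((alpha * o_prev).shape)  # (8, 64)
```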
In the code, `m_scratch` and `l_scratch` are currently both unnecessarily allocated with shape `[bq, NUM_LANES]`, causing all `NUM_LANES` elements to redundantly store the same per-row value.

Source: `jax/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py`, lines 1056 to 1057 at 49aad1b
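For concreteness, here is how the scratch shapes compare under my proposal, written as the `pltpu.VMEM` declarations that `pallas_call` accepts in `scratch_shapes` (illustrative block size; not the verbatim source):

```python
import jax.numpy as jnp
from jax.experimental.pallas import tpu as pltpu

NUM_LANES = 128
bq = 256  # query block size, illustrative

# Current: every one of the 128 lanes stores the same per-row statistic.
m_scratch = pltpu.VMEM((bq, NUM_LANES), jnp.float32)
l_scratch = pltpu.VMEM((bq, NUM_LANES), jnp.float32)

# Proposed: one running max and one running sum per query row is sufficient.
m_scratch_proposed = pltpu.VMEM((bq, 1), jnp.float32)
l_scratch_proposed = pltpu.VMEM((bq, 1), jnp.float32)
```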
To be clear, let me annotate the shapes in the FlashAttention (Tiling) algorithm from https://courses.cs.washington.edu/courses/cse599m/23sp/notes/flashattn.pdf:
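Per query block of size $b_q$, with $x_i$ the score block against the $i$-th key/value block $V_i$ (my shape annotation of the notes' algorithm, not a quote):

```math
\begin{aligned}
m_i &= \max\!\big(m_{i-1},\ \operatorname{rowmax}(x_i)\big) && \text{shape } [b_q, 1] \\
d'_i &= d'_{i-1}\, e^{m_{i-1} - m_i} + \operatorname{rowsum}\!\big(e^{x_i - m_i}\big) && \text{shape } [b_q, 1] \\
\alpha_i &= \frac{d'_{i-1}\, e^{m_{i-1} - m_i}}{d'_i} && \text{shape } [b_q, 1] \\
o'_i &= \alpha_i \odot o'_{i-1} + \frac{e^{x_i - m_i}}{d'_i}\, V_i && \text{shape } [b_q,\ \text{head\_dim\_v}]
\end{aligned}
```

Every running statistic ($m_i$, $d'_i$, $\alpha_i$) is one scalar per query row, so `[bq, 1]` is the natural shape for `m_scratch`, `l_scratch`, and `alpha`.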
**Why this matters**

By changing the shape of `m_scratch`, `l_scratch`, and `alpha` to `[bq, 1]`:

- the `repeat(alpha, head_dim_v_repeats, ...)` call becomes unnecessary, and
- the `head_dim_v % NUM_LANES == 0` constraint can be removed entirely (see the sketch below).

This enables smaller models (e.g. sLLMs) to use Splash Attention.
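As a sanity check that a broadcasted rescale is all that is needed, here is a minimal standalone Pallas kernel in interpreter mode (illustrative names and shapes, not the splash-attention kernel itself) rescaling a `[bq, head_dim_v]` accumulator by a `[bq, 1]` `alpha` with `head_dim_v = 64`:

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

bq, head_dim_v = 8, 64  # 64 is not a multiple of NUM_LANES == 128

def rescale_kernel(alpha_ref, o_ref, out_ref):
    # alpha_ref: [bq, 1] per-row scale; o_ref: [bq, head_dim_v] accumulator.
    # A plain broadcasted multiply replaces repeat(alpha, head_dim_v_repeats).
    out_ref[...] = alpha_ref[...] * o_ref[...]

alpha = jnp.full((bq, 1), 0.5, jnp.float32)
o_prev = jnp.ones((bq, head_dim_v), jnp.float32)

out = pl.pallas_call(
    rescale_kernel,
    out_shape=jax.ShapeDtypeStruct((bq, head_dim_v), jnp.float32),
    interpret=True,  # run in the Pallas interpreter so this works off-TPU
)(alpha, o_prev)
print(out.shape)  # (8, 64)
```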
Unfortunately, due to my company's internal policy, I’m unable to submit a PR myself. I hope the JAX maintainers can address this issue — fixing it should not require major changes and will benefit a wider range of users.
Thanks for your excellent work!