
[Splash Attention] Remove unnecessary head_dim_v % NUM_LANES == 0 constraint #27427


Open
ds-hwang opened this issue Mar 25, 2025 · 1 comment
Labels: enhancement (New feature or request)

@ds-hwang (Contributor) commented Mar 25, 2025

Hi JAX team and @sharadmv,
I found an unnecessary constraint in the flash_attention_kernel implementation in the Splash Attention module that prevents small models from using it.

Currently, the kernel enforces that head_dim_v must be a multiple of NUM_LANES (128):

head_dim_v_repeats, rem = divmod(head_dim_v, NUM_LANES)
if rem != 0:
    raise NotImplementedError(
        f"{head_dim_v=} should be a multiple of {NUM_LANES}"
    )


However, this constraint is unnecessary. It arises from the way head_dim_v_repeats is used to widen alpha:

alpha_o = pltpu.repeat(alpha, head_dim_v_repeats, axis=1)

This repeat assumes alpha has shape [bq, NUM_LANES] and widens it to [bq, head_dim_v]. In fact, alpha is $(d_{i-1}' \cdot e^{m_{i-1} - m_i}) / d_i'$ in Algorithm FlashAttention (Tiling), so it should have shape [bq, 1], not [bq, NUM_LANES].
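
For reference, here is a minimal, self-contained sketch of the FlashAttention (Tiling) recurrence in plain JAX, outside the Pallas kernel (the function and variable names below are illustrative, not taken from the kernel). It shows that the rescale factor alpha only needs one value per query row, i.e. shape [bq, 1]:

import jax.numpy as jnp

def tiled_attention(q, k_blocks, v_blocks):
  # Computes softmax(q @ k.T) @ v one key/value block at a time.
  bq = q.shape[0]
  head_dim_v = v_blocks[0].shape[-1]
  m = jnp.full((bq, 1), -jnp.inf)   # running max, shape [bq, 1]
  d = jnp.zeros((bq, 1))            # running softmax denominator, shape [bq, 1]
  o = jnp.zeros((bq, head_dim_v))   # running output, shape [bq, head_dim_v]
  for k, v in zip(k_blocks, v_blocks):
    s = q @ k.T                     # scores for this block, [bq, bkv]
    m_new = jnp.maximum(m, s.max(axis=-1, keepdims=True))
    p = jnp.exp(s - m_new)          # [bq, bkv]
    d_new = d * jnp.exp(m - m_new) + p.sum(axis=-1, keepdims=True)
    # alpha = (d_{i-1}' * e^{m_{i-1} - m_i}) / d_i' -- one scalar per query row.
    alpha = d * jnp.exp(m - m_new) / d_new
    o = alpha * o + (p @ v) / d_new # the [bq, 1] alpha broadcasts over head_dim_v
    m, d = m_new, d_new
  return o

With head_dim_v = 80, for example, this works through ordinary broadcasting, while the kernel's current check raises NotImplementedError for any head_dim_v that is not a multiple of 128.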

In the code:

  • $m_i$ corresponds to m_scratch
  • $d_i'$ corresponds to l_scratch

Currently, both are unnecessarily allocated with shape [bq, NUM_LANES], causing all NUM_LANES elements to redundantly store the same value.
Source:

jax.ShapeDtypeStruct((bq, NUM_LANES), jnp.float32), # m_scratch
jax.ShapeDtypeStruct((bq, NUM_LANES), jnp.float32), # l_scratch
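
To illustrate the redundancy in plain jnp (outside the kernel, with illustrative sizes): when every lane of a [bq, NUM_LANES] alpha holds the same value, widening it to [bq, head_dim_v] with pltpu.repeat produces exactly what broadcasting a [bq, 1] alpha would produce. Since all lanes are equal, the exact ordering used by the repeat does not matter, so jnp.concatenate stands in for it here.

import jax.numpy as jnp

bq, NUM_LANES, head_dim_v = 8, 128, 256                    # illustrative sizes
alpha_1 = jnp.linspace(0.1, 0.9, bq).reshape(bq, 1)        # proposed layout: [bq, 1]
alpha_lanes = jnp.broadcast_to(alpha_1, (bq, NUM_LANES))   # current layout: NUM_LANES equal copies per row
o = jnp.arange(bq * head_dim_v, dtype=jnp.float32).reshape(bq, head_dim_v)

# Stand-in for pltpu.repeat(alpha, head_dim_v // NUM_LANES, axis=1).
alpha_o = jnp.concatenate([alpha_lanes] * (head_dim_v // NUM_LANES), axis=1)

assert jnp.allclose(alpha_o * o, alpha_1 * o)              # broadcasting gives the same result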

To be clear, let me annotate the shapes in the FlashAttention (Tiling) algorithm from https://courses.cs.washington.edu/courses/cse599m/23sp/notes/flashattn.pdf:

[Image: FlashAttention (Tiling) algorithm from the linked notes, annotated with the tensor shapes]


Why this matters

By changing the shape of m_scratch, l_scratch, and alpha to [bq, 1]:

  • We avoid unnecessary memory and compute overhead.
  • We eliminate the need to repeat(alpha, head_dim_v_repeats, ...).
  • Most importantly, we can remove the head_dim_v % NUM_LANES == 0 constraint entirely.

This would enable smaller models (e.g. small LLMs whose head_dim_v is below 128) to use Splash Attention.


Unfortunately, due to my company's internal policy, I’m unable to submit a PR myself. I hope the JAX maintainers can address this issue — fixing it should not require major changes and will benefit a wider range of users.

Thanks for your excellent work!

@ds-hwang (Contributor, Author) commented Apr 2, 2025

@jakevdp, @mattjj, @sharadmv Could you take a look when you have a moment? Thank you.

@Rifur13 self-assigned this Apr 8, 2025