
[Splash Attention] Remove Unnecessary head_dim_v Constraint and Update Scratch Array Shapes #27427 #27461


Open
SwarnimShekhar wants to merge 5 commits into main
Conversation


@SwarnimShekhar commented Mar 26, 2025

This PR addresses issue #27427 by removing the unnecessary constraint that required head_dim_v to be a multiple of NUM_LANES. It makes the following changes:

Constraint Removal:
The check:

head_dim_v_repeats, rem = divmod(head_dim_v, NUM_LANES)
if rem != 0:
    raise NotImplementedError(
        f"{head_dim_v=} should be a multiple of {NUM_LANES}"
    )

has been removed. This constraint prevented small models (e.g. sLLMs) from using Splash Attention whenever head_dim_v was not a multiple of NUM_LANES.

Scratch Array Shape Update:
The scratch arrays m_scratch and l_scratch are now allocated with shape [bq, 1] instead of [bq, NUM_LANES], eliminating redundant scratch memory. Downstream operations rely on broadcasting to expand these values where needed, as illustrated in the sketch below.
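
A minimal sketch of the allocation change, assuming the kernel requests its scratch buffers via pltpu.VMEM entries in scratch_shapes (bq and NUM_LANES are names used by the kernel; the concrete values below are illustrative, not taken from this PR):

import jax.numpy as jnp
from jax.experimental.pallas import tpu as pltpu

bq = 512          # query block size (illustrative)
NUM_LANES = 128   # TPU lane width referenced by the old constraint

# Before: running max / running sum padded out to a full lane width.
scratch_shapes_before = [
    pltpu.VMEM((bq, NUM_LANES), jnp.float32),  # m_scratch
    pltpu.VMEM((bq, NUM_LANES), jnp.float32),  # l_scratch
]

# After: one statistic per query row; downstream ops broadcast over head_dim_v.
scratch_shapes_after = [
    pltpu.VMEM((bq, 1), jnp.float32),  # m_scratch
    pltpu.VMEM((bq, 1), jnp.float32),  # l_scratch
]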

Alpha Repetition Removal:
All occurrences of pltpu.repeat(alpha, head_dim_v_repeats, axis=1) have been replaced with a direct assignment (alpha_o = alpha), relying on broadcasting for correct behavior; see the sketch below.
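
A minimal sketch (not the kernel itself) of the online-softmax update with [bq, 1]-shaped running statistics, showing why the explicit repeat is no longer needed once alpha broadcasts against the [bq, head_dim_v] accumulator (all shapes and values are illustrative, including a head_dim_v that is not a multiple of 128):

import jax.numpy as jnp

bq, bkv, head_dim_v = 8, 16, 96        # illustrative block sizes
s = jnp.ones((bq, bkv))                # attention scores for one K/V block
v = jnp.ones((bkv, head_dim_v))        # value block
o_prev = jnp.zeros((bq, head_dim_v))   # output accumulator
m_prev = jnp.full((bq, 1), -jnp.inf)   # running max, now [bq, 1]
l_prev = jnp.zeros((bq, 1))            # running sum, now [bq, 1]

m_curr = s.max(axis=-1, keepdims=True)      # [bq, 1]
m_next = jnp.maximum(m_prev, m_curr)        # [bq, 1]
alpha = jnp.exp(m_prev - m_next)            # [bq, 1] correction factor
alpha_o = alpha                             # previously pltpu.repeat(alpha, head_dim_v_repeats, axis=1)

p = jnp.exp(s - m_next)                                    # [bq, bkv]
l_next = alpha * l_prev + p.sum(axis=-1, keepdims=True)    # [bq, 1]
o_next = alpha_o * o_prev + p @ v                          # [bq, 1] broadcasts over head_dim_v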

Testing:

All existing tests pass in both parallel and serial test runs.

@ds-hwang
Contributor

Lovely. Thank you for the quick action!

@SwarnimShekhar
Author

Glad to contribute!
