[Splash Attention] Remove Unnecessary head_dim_v Constraint and Update Scratch Array Shapes #27427 #27461
This PR addresses issue #27427 by removing the unnecessary constraint that required head_dim_v to be a multiple of NUM_LANES. The following changes are made:

Constraint Removal:
The check that rejected configurations whose head_dim_v is not a multiple of NUM_LANES has been removed. This constraint prevented small models (e.g. sLLMs) with such head dimensions from using Splash Attention.
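For context, here is a minimal sketch of the kind of divisibility check being removed; the constant value, function name, and error message are illustrative rather than the kernel's actual code:

```python
NUM_LANES = 128  # illustrative TPU lane width

def check_head_dim_v(head_dim_v: int) -> None:
    # Before this PR, a splash attention configuration with e.g. head_dim_v=72
    # would be rejected by a check of roughly this form.
    if head_dim_v % NUM_LANES != 0:
        raise ValueError(
            f"head_dim_v={head_dim_v} must be a multiple of NUM_LANES={NUM_LANES}"
        )
```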
Scratch Array Shape Update:
The scratch arrays m_scratch and l_scratch are now allocated with shape [bq, 1] instead of [bq, NUM_LANES], reducing redundant memory allocation. Downstream operations rely on broadcasting to expand these values where needed, as sketched below.
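A small sketch of why the narrower scratch shape is sufficient; the sizes and the jnp computation below are illustrative, not the Pallas kernel code itself:

```python
import jax.numpy as jnp

bq, NUM_LANES, head_dim_v = 8, 128, 72  # illustrative block sizes

# Old allocation: one column per lane, every column carrying the same per-row value.
m_wide = jnp.zeros((bq, NUM_LANES), dtype=jnp.float32)

# New allocation: a single column per row.
m_narrow = jnp.zeros((bq, 1), dtype=jnp.float32)

# Downstream use: broadcasting [bq, 1] against a [bq, head_dim_v] operand
# produces the same result as materializing the repeated columns.
o = jnp.ones((bq, head_dim_v), dtype=jnp.float32)
scaled = o * jnp.exp(m_narrow)  # shape (bq, head_dim_v)
```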
Alpha Repetition Removal:
All uses of pltpu.repeat(alpha, head_dim_v_repeats, axis=1) have been replaced with a direct assignment (alpha_o = alpha), relying on broadcasting to produce the same result.
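The equivalence the direct assignment relies on can be sketched as follows; the shapes are simplified relative to the kernel (where alpha previously spanned NUM_LANES columns and was tiled head_dim_v_repeats times):

```python
import jax.numpy as jnp

bq, head_dim_v = 8, 72  # illustrative sizes
alpha = jnp.full((bq, 1), 0.5, dtype=jnp.float32)       # per-row rescaling factor
o_prev = jnp.ones((bq, head_dim_v), dtype=jnp.float32)  # running output block

# Old approach (illustrated with jnp.repeat): materialize alpha to full width.
alpha_explicit = jnp.repeat(alpha, head_dim_v, axis=1)

# New approach: keep alpha as [bq, 1] and let the multiply broadcast it.
alpha_o = alpha
assert jnp.allclose(o_prev * alpha_o, o_prev * alpha_explicit)
```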
Testing:
All existing tests pass in both parallel and serial test runs.