webgpu: Fuse FlashAttention decode kernels and extend to any sequence length#28389
Open
qjia7 wants to merge 5 commits into
Conversation
The 3-kernel FlashAttention decode path (QKT, SplitVx, VxReduce) previously only handled sequence_length=1. This PR extends it to work with any sequence length, providing a fallback for devices without Subgroups support. When Subgroups is available and seq_len>1, the subgroup-based prefill path is still preferred; the extended decode path activates when Subgroups is unavailable.

Changes:
- Workgroup layout now includes a new_sequence_length dimension in all 3 kernels
- Q offset supports both BSNH and BNSH layouts via the q_BNSH template param (see the index sketch below)
- Causal masking (is_unidirectional) for self-attention with seq_len>1
- use_seqlen_k support for static KV cache (past_present_share_buffer)
- Relaxed CanApplyFlashAttention to allow seq_len>1 without Subgroups

Verified: whisper-tiny-int4 correct transcription, phi4-mini correct generation, 16/16 MHA unit tests pass.
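As a hedged illustration of the layout handling mentioned above, here is the standard row-major index math for the two Q layouts; the function and parameter names are mine, not the shader's:

```cpp
#include <cstdint>

// Hypothetical sketch of the two Q layouts the q_BNSH template param selects
// between. Returns the offset of the first element of query row `s` for
// (batch b, head n); N = num_heads, S = sequence_length, H = head_size.
uint32_t q_offset(bool q_BNSH, uint32_t b, uint32_t n, uint32_t s,
                  uint32_t N, uint32_t S, uint32_t H) {
  if (q_BNSH) {
    return ((b * N + n) * S + s) * H;  // BNSH: [batch, num_heads, seq, head_size]
  }
  return ((b * S + s) * N + n) * H;    // BSNH: [batch, seq, num_heads, head_size]
}
```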
…FlashAttention decode path

Extend the 3-kernel decode path to process multiple Q rows per workgroup (m_tile=1/2/4) to amortize K/V memory loads for larger sequence lengths. Remove the Subgroups feature requirement from CanApplyFlashAttention so the decode path works on all WebGPU devices. The subgroup-based prefill path is replaced by the extended decode path with m_tile.

Fix causal masking to compute past_sequence_length dynamically from total_sequence_length - new_sequence_length, which is correct both for GQA (where past_sequence_length_ is the buffer size) and for graph capture (where total_sequence_length_ is on GPU).
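A minimal scalar sketch of the dynamic causal-mask rule this commit describes; the real shader applies it per tile, and `is_visible`, `key_pos`, and `q_row` are illustrative names:

```cpp
#include <cstdint>

// Causal (is_unidirectional) visibility for query row q_row, 0-based within
// the new tokens: all past tokens are visible; among new tokens, only
// positions up to past_sequence_length + q_row are. past_sequence_length is
// derived on the fly, so the rule holds for GQA (past_sequence_length_ is the
// buffer size, not the valid length) and for graph capture
// (total_sequence_length_ lives on the GPU).
bool is_visible(uint32_t key_pos, uint32_t q_row,
                uint32_t total_sequence_length, uint32_t new_sequence_length) {
  uint32_t past_sequence_length = total_sequence_length - new_sequence_length;
  return key_pos <= past_sequence_length + q_row;
}
```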
…n decode

Merge the separate QK^T and SplitVx shaders into a fused QKV shader that computes QK^T, local softmax, and V multiply in one kernel launch, eliminating the intermediate qk tensor and reducing dispatch count from 3 to 2. The VxReduce shader now rescales partial outputs using per-tile online softmax metadata (local_max, local_sum).
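A scalar model of what the fused pass computes per KV tile, assuming the shader writes locally normalized partials plus (local_max, local_sum) metadata as described; `TilePartial` and `fused_qkv_tile` are illustrative names, not the actual code:

```cpp
#include <cmath>
#include <vector>

// One query row against one KV tile: QK^T, tile-local softmax, V multiply.
// local_max/local_sum are the per-tile metadata the VxReduce pass consumes.
struct TilePartial {
  float local_max;             // max score within this KV tile
  float local_sum;             // sum of exp(score - local_max) over the tile
  std::vector<float> partial;  // locally normalized softmax(QK^T) x V
};

TilePartial fused_qkv_tile(const float* q, const float* k_tile, const float* v_tile,
                           int tile_rows, int head_size, float scale) {
  std::vector<float> scores(tile_rows);
  float local_max = -INFINITY;
  for (int j = 0; j < tile_rows; ++j) {  // QK^T for this tile
    float s = 0.0f;
    for (int d = 0; d < head_size; ++d) s += q[d] * k_tile[j * head_size + d];
    scores[j] = s * scale;
    local_max = std::fmax(local_max, scores[j]);
  }
  TilePartial out{local_max, 0.0f, std::vector<float>(head_size, 0.0f)};
  for (int j = 0; j < tile_rows; ++j) {  // local softmax + V multiply
    float p = std::exp(scores[j] - local_max);
    out.local_sum += p;
    for (int d = 0; d < head_size; ++d) out.partial[d] += p * v_tile[j * head_size + d];
  }
  for (int d = 0; d < head_size; ++d) out.partial[d] /= out.local_sum;  // local normalize
  return out;
}
```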
Use FlashAttentionProgram (single kernel, subgroup-based) for prefill when sequence is long enough to benefit. Fall back to the split-reduce path (fused QKV + VxReduce) for short sequences, when subgroups are unavailable, or when total_sequence_length is large relative to sequence_length.
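A hedged sketch of the resulting routing predicate, encoding the thresholds quoted later in the PR description; the actual decision also lives partly in `CanApplyFlashAttention` and may consider more device state:

```cpp
#include <cstdint>

// Choose between the split-reduce (fused QKV + VxReduce) path and the
// single-kernel subgroup prefill path. Thresholds follow the PR text.
bool use_split_reduce(uint32_t seq_len, uint32_t total_seq_len, bool has_subgroups) {
  if (!has_subgroups) return true;                        // no subgroup intrinsics
  if (seq_len <= 4) return true;                          // decode-like short sequences
  if (seq_len < 64 && total_seq_len > 1000) return true;  // small Q, large KV (Whisper-like)
  return false;                                           // long prefill: subgroup kernel
}
```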
Pull request overview
This PR updates the WebGPU FlashAttention implementation to improve the decode/prefill path by fusing decode shaders (QKᵀ + softmax + V) into a single kernel, adding m_tile to process multiple query rows per workgroup, extending decode beyond seq_len=1, and routing between subgroup-based and non-subgroup paths to broaden device support.
Changes:
- Replace the decode QKT + SplitVx pipeline with a fused QKV shader + VxReduce (online softmax) pipeline, supporting `sequence_length > 1`.
- Add path routing between subgroup-based prefill and split-reduce decode to enable FlashAttention without subgroup support.
- Extend decode uniforms/shader logic to support causal masking, optional `seqlen_k`, and the `m_tile` optimization.
Reviewed changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| onnxruntime/contrib_ops/webgpu/bert/flash_attention.h | Updates program declarations/uniforms: replaces QKT/SplitVx with fused QKV; extends VxReduce options (seqlen_k/head_sink/m_tile). |
| onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc | Implements fused QKV dispatch, split-reduce routing, new metadata/output shapes, and removes subgroup gating from CanApplyFlashAttention. |
| onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_vx_reduce.wgsl.template | Updates reduction shader to rescale partial V outputs using online softmax metadata; adds m_tile, head_sink support, and sequence-length indexing. |
| onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_split_vx.wgsl.template | Deleted (functionality folded into fused QKV shader). |
| onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_qkv.wgsl.template | New fused shader implementing QKᵀ + bias/mask + local softmax + V multiply, emitting per-tile metadata for reduce. |
| onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_qkt.wgsl.template | Deleted (replaced by fused QKV shader). |
Comment on lines +510 to 516
```cpp
// When use_seqlen_k is true, total_sequence_length_ may be 0 (actual value is in seqlen_k tensor).
// Use present_sequence_length for tile count calculations; shaders will read the actual value from seqlen_k.
const uint32_t effective_total_seq_len = use_seqlen_k ? present_sequence_length
                                                      : static_cast<uint32_t>(parameters.total_sequence_length_);

const uint32_t num_total_seq_length_tile = (effective_total_seq_len + tile_size - 1) / tile_size;
const uint32_t num_present_sequence_length_tile = (present_sequence_length + tile_size - 1) / tile_size;
```
Comment on lines +60 to +66
```wgsl
let num_total_seq_length_tile = (total_sequence_length + tile_size - 1) / tile_size;
let num_q_tiles = (uniforms.new_sequence_length + m_tile - 1) / m_tile;
// Workgroup layout: [batch_heads, num_q_tiles, num_total_seq_length_tile]
let total_seq_offset = (workgroup_idx % num_total_seq_length_tile) * tile_size;
let q_tile_idx = (workgroup_idx / num_total_seq_length_tile) % num_q_tiles;
let q_base = q_tile_idx * m_tile;
let batch_head_idx = u32(workgroup_idx / (num_total_seq_length_tile * num_q_tiles));
```
Comment on lines +33 to 38
```wgsl
#if use_seqlen_k
let total_sequence_length = u32(seqlens_k[0]) + 1u;
let num_total_seq_length_tile = (total_sequence_length + seq_tile_size - 1) / seq_tile_size;
#else
let num_total_seq_length_tile = uniforms.num_total_seq_length_tile;
#endif
```
Summary
- `use_seqlen_k` support for static KV cache (past_present_share_buffer)
- Fused QKV shader eliminating the intermediate `qk` tensor (B×H×seq×present_seq) and reducing dispatch count from 3 to 2
- Relaxed `CanApplyFlashAttention`, enabling FlashAttention on devices without subgroup support

Resolved Issues
FlashAttention now works without subgroup support. Previously `CanApplyFlashAttention` required the `Subgroups` feature, blocking devices that lack it. The split-reduce path needs no subgroup intrinsics, so FlashAttention is now available on all WebGPU devices.

Whisper decoding prefill improved from 4.68ms to 1.09ms. Whisper's decoder attention has a small sequence length but a large total sequence length (seq_len=4, total_seq_len=1500). The default prefill shader (FlashAttentionProgram) has low parallelism in this case because each workgroup iterates serially over the full KV cache. The split-reduce path tiles the KV dimension across workgroups, achieving much higher GPU occupancy for this workload shape.
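To make the occupancy argument concrete, a back-of-the-envelope count under an assumed KV tile size of 64 (the PR does not state the actual tile size here):

```cpp
#include <cstdint>

// Workgroups per (batch, head) for Whisper's decoder shape (total_seq_len=1500).
// Single-kernel prefill: 1 workgroup loops over all 1500 keys serially.
// Split-reduce: the KV dimension is tiled across workgroups, reduced afterwards.
constexpr uint32_t total_seq_len = 1500;  // from the PR description
constexpr uint32_t tile_size = 64;        // assumed; illustrative only
constexpr uint32_t kv_tiles = (total_seq_len + tile_size - 1) / tile_size;  // = 24
// => roughly 24x more workgroups per (batch, head) for this workload shape.
```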
Details
Fused QKV kernel: Each workgroup computes QK^T dot products, applies attention bias and causal mask, computes a local softmax (per-tile max and sum), normalizes, and multiplies by V, all in one kernel. Per-tile metadata (max, sum) is written for the VxReduce shader to rescale partial outputs using online softmax:

`output = Σ_i (partial_i × local_sum_i × exp(local_max_i − global_max)) / global_sum`
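A scalar model of that rescaling, under the assumption that each tile's partial output was normalized by its own local_sum; struct and function names are illustrative:

```cpp
#include <cmath>
#include <vector>

struct TilePartial {           // as emitted per KV tile by the fused QKV pass
  float local_max;
  float local_sum;
  std::vector<float> partial;  // locally normalized partial output
};

// Convert each tile's locally normalized partial back to the global softmax
// base via the online-softmax identity above, then renormalize.
std::vector<float> vx_reduce(const std::vector<TilePartial>& tiles, int head_size) {
  float global_max = -INFINITY;
  for (const auto& t : tiles) global_max = std::fmax(global_max, t.local_max);

  float global_sum = 0.0f;
  std::vector<float> out(head_size, 0.0f);
  for (const auto& t : tiles) {
    float w = t.local_sum * std::exp(t.local_max - global_max);
    global_sum += w;
    for (int d = 0; d < head_size; ++d) out[d] += t.partial[d] * w;
  }
  for (int d = 0; d < head_size; ++d) out[d] /= global_sum;
  return out;
}
```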
Path routing (`use_split_reduce`): The split-reduce path is used when `seq_len <= 4`, when subgroups are unavailable, or when `seq_len < 64 && total_seq > 1000`. Otherwise the single-kernel FlashAttentionProgram prefill path is used.

Test plan