webgpu: Fuse FlashAttention decode kernels and extend to any sequence length #28389

Open

qjia7 wants to merge 5 commits into main from webgpu-flash-attention-relax-subgroups

Conversation

@qjia7 (Contributor) commented May 7, 2026

Summary

  • Extend the FlashAttention decode path to work with any sequence length (not just seq_len=1), with causal masking and use_seqlen_k support for static KV cache
  • Add m_tile optimization to process multiple Q rows per workgroup (m_tile=1/2/4), amortizing K/V loads
  • Fuse the separate QKT and SplitVx shaders into a single QKV kernel using online softmax, eliminating the intermediate qk tensor (B×H×seq×present_seq) and reducing dispatch count from 3 to 2
  • Route between prefill (FlashAttentionProgram, subgroup-based) and split-reduce (fused QKV + VxReduce) paths based on sequence length, subgroup availability, and KV cache size
  • Remove the Subgroups requirement from CanApplyFlashAttention, enabling FlashAttention on devices without subgroup support

Resolved Issues

  1. FlashAttention now works without subgroup support. Previously, CanApplyFlashAttention required the Subgroups feature, blocking devices that lack it. The split-reduce path needs no subgroup intrinsics, so FlashAttention is now available on all WebGPU devices.

  2. Whisper decoding prefill improved from 4.68ms to 1.09ms. Whisper's decoder attention has a small sequence length but large total sequence length (seq_len=4, total_seq_len=1500). The default prefill shader (FlashAttentionProgram) has low parallelism in this case because each workgroup iterates serially over the full KV cache. The split-reduce path tiles the KV dimension across workgroups, achieving much higher GPU occupancy for this workload shape.
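
For a rough sense of the parallelism difference at that shape, here is a back-of-the-envelope workgroup count; tile_size = 64 and m_tile = 4 are illustrative assumptions, not necessarily the values the kernel selects:

#include <cstdint>

// Back-of-the-envelope workgroup count for the Whisper decode shape above
// (seq_len = 4, total_seq_len = 1500); tile_size and m_tile are assumed values.
constexpr uint32_t CeilDiv(uint32_t a, uint32_t b) { return (a + b - 1) / b; }

constexpr uint32_t kNumQTiles = CeilDiv(/*seq_len*/ 4u, /*m_tile*/ 4u);                // 1
constexpr uint32_t kNumKvTiles = CeilDiv(/*total_seq_len*/ 1500u, /*tile_size*/ 64u);  // 24
// Split-reduce dispatches batch_heads * kNumQTiles * kNumKvTiles workgroups, i.e. 24 per
// (batch, head) here, versus roughly one serially iterating workgroup per (batch, head)
// in the single-kernel prefill shader.
static_assert(kNumKvTiles == 24, "1500 keys tiled by 64");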

Details

Fused QKV kernel: Each workgroup computes QK^T dot products, applies attention bias and causal mask, computes local softmax (per-tile max and sum), normalizes, and multiplies by V — all in one kernel. Per-tile metadata (max, sum) is written for the VxReduce shader to rescale partial outputs using online softmax: output = Σ(partial_i × local_sum_i × exp(local_max_i - global_max)) / global_sum.
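
A minimal CPU sketch of that reduction for a single output element (illustrative names and signature; the real VxReduce shader operates on vectorized partials per head):

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Rescale per-tile partial outputs using their (local_max, local_sum) metadata, as in the
// formula above. Each partial[i] is assumed to already be normalized by local_sum[i].
// Assumes at least one tile.
float ReducePartials(const std::vector<float>& partial,
                     const std::vector<float>& local_max,
                     const std::vector<float>& local_sum) {
  const float global_max = *std::max_element(local_max.begin(), local_max.end());
  float global_sum = 0.0f;
  float acc = 0.0f;
  for (std::size_t i = 0; i < partial.size(); ++i) {
    const float scale = local_sum[i] * std::exp(local_max[i] - global_max);
    global_sum += scale;        // rebuild the global softmax denominator
    acc += partial[i] * scale;  // un-normalize the tile, then rescale to the global max
  }
  return acc / global_sum;
}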

Path routing (use_split_reduce): The split-reduce path is used when seq_len ≤ 4, subgroups are unavailable, or seq_len < 64 && total_seq > 1000. Otherwise the single-kernel FlashAttentionProgram prefill path is used.
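
A sketch of that predicate; the thresholds are quoted from this description, while the function name and signature are illustrative rather than the exact source:

#include <cstdint>

bool UseSplitReduce(uint32_t seq_len, uint32_t total_seq_len, bool has_subgroups) {
  // Decode-like short sequences, devices without subgroup support, and "short Q over a
  // long KV cache" shapes all take the split-reduce (fused QKV + VxReduce) path.
  return seq_len <= 4 || !has_subgroups || (seq_len < 64 && total_seq_len > 1000);
}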

Test plan

  • 30/30 MHA unit tests pass
  • phi4-graph-prune produces correct output
  • whisper-tiny-int4 produces correct transcription
  • clang-format clean

qjia7 added 4 commits May 7, 2026 10:21
The 3-kernel FlashAttention decode path (QKT, SplitVx, VxReduce) previously
only handled sequence_length=1. This extends it to work with any sequence
length, providing a fallback for devices without Subgroups support.

When Subgroups is available and seq_len>1, the subgroup-based prefill path
is still preferred. The extended decode path activates when Subgroups is
unavailable.

Changes:
- Workgroup layout now includes new_sequence_length dimension in all 3 kernels
- Q offset supports both BSNH and BNSH layouts via the q_BNSH template param (index forms sketched below)
- Causal masking (is_unidirectional) for self-attention with seq_len>1
- use_seqlen_k support for static KV cache (past_present_share_buffer)
- Relaxed CanApplyFlashAttention to allow seq_len>1 without Subgroups
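
A minimal sketch of the two Q-offset forms mentioned above, with assumed names (b, s, n index batch, query row, and head; S, N, H are sequence length, head count, and head size; the actual shader derives these from its uniforms):

#include <cstdint>

uint32_t QOffsetBSNH(uint32_t b, uint32_t s, uint32_t n,
                     uint32_t S, uint32_t N, uint32_t H) {
  return ((b * S + s) * N + n) * H;  // layout [batch, seq, num_heads, head_size]
}

uint32_t QOffsetBNSH(uint32_t b, uint32_t s, uint32_t n,
                     uint32_t S, uint32_t N, uint32_t H) {
  return ((b * N + n) * S + s) * H;  // layout [batch, num_heads, seq, head_size]
}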

Verified: whisper-tiny-int4 correct transcription, phi4-mini correct
generation, 16/16 MHA unit tests pass.
…FlashAttention decode path

Extend the 3-kernel decode path to process multiple Q rows per workgroup
(m_tile=1/2/4) to amortize K/V memory loads for larger sequence lengths.
Remove the Subgroups feature requirement from CanApplyFlashAttention so
the decode path works on all WebGPU devices. The subgroup-based prefill
path is replaced by the extended decode path with m_tile.

Fix causal masking to compute past_sequence_length dynamically from
total_sequence_length - new_sequence_length, which is correct for both
GQA (where past_sequence_length_ is the buffer size) and graph capture
(where total_sequence_length_ is on GPU).
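
In other words, the mask condition reduces to the following (a sketch with illustrative names, not the shader code):

#include <cstdint>

// A query row q_idx (0-based within the new tokens) may attend to key positions
// 0 .. past_seq_len + q_idx; any later key is masked out. past_seq_len is derived
// from the totals rather than read from past_sequence_length_.
bool IsCausallyMasked(uint32_t q_idx, uint32_t k_idx,
                      uint32_t new_seq_len, uint32_t total_seq_len) {
  const uint32_t past_seq_len = total_seq_len - new_seq_len;
  return k_idx > past_seq_len + q_idx;
}
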
…n decode

Merge the separate QK^T and SplitVx shaders into a fused QKV shader
that computes QK^T, local softmax, and V multiply in one kernel launch,
eliminating the intermediate qk tensor and reducing dispatch count from
3 to 2. The VxReduce shader now rescales partial outputs using per-tile
online softmax metadata (local_max, local_sum).
Use FlashAttentionProgram (single kernel, subgroup-based) for prefill
when sequence is long enough to benefit. Fall back to the split-reduce
path (fused QKV + VxReduce) for short sequences, when subgroups are
unavailable, or when total_sequence_length is large relative to
sequence_length.

@github-actions bot left a comment

You can commit the suggested changes from lintrunner.

Comment thread on onnxruntime/contrib_ops/webgpu/bert/flash_attention.h (outdated)
@qjia7 qjia7 changed the title webgpu: Extend FlashAttention decode path for any sequence length webgpu: Fuse FlashAttention decode kernels and extend to any sequence length May 12, 2026
@qjia7 qjia7 marked this pull request as ready for review May 12, 2026 10:05
@qjia7 qjia7 requested a review from Copilot May 12, 2026 10:06

Copilot AI left a comment

Pull request overview

This PR updates the WebGPU FlashAttention implementation to improve the decode/prefill path by fusing decode shaders (QKᵀ + softmax + V) into a single kernel, adding m_tile to process multiple query rows per workgroup, extending decode beyond seq_len=1, and routing between subgroup-based and non-subgroup paths to broaden device support.

Changes:

  • Replace the decode QKT + SplitVx pipeline with a fused QKV shader + VxReduce (online softmax) pipeline, supporting sequence_length > 1.
  • Add path routing between subgroup-based prefill and split-reduce decode to enable FlashAttention without subgroup support.
  • Extend decode uniforms/shader logic to support causal masking, optional seqlen_k, and m_tile optimization.

Reviewed changes

Copilot reviewed 6 out of 6 changed files in this pull request and generated 3 comments.

Summary of changes per file:

  • onnxruntime/contrib_ops/webgpu/bert/flash_attention.h - Updates program declarations/uniforms: replaces QKT/SplitVx with the fused QKV program; extends VxReduce options (seqlen_k/head_sink/m_tile).
  • onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc - Implements the fused QKV dispatch, split-reduce routing, new metadata/output shapes, and removes subgroup gating from CanApplyFlashAttention.
  • onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_vx_reduce.wgsl.template - Updates the reduction shader to rescale partial V outputs using online softmax metadata; adds m_tile, head_sink support, and sequence-length indexing.
  • onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_split_vx.wgsl.template - Deleted (functionality folded into the fused QKV shader).
  • onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_qkv.wgsl.template - New fused shader implementing QKᵀ + bias/mask + local softmax + V multiply, emitting per-tile metadata for the reduce pass.
  • onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_qkt.wgsl.template - Deleted (replaced by the fused QKV shader).


Comment on lines +510 to 516
// When use_seqlen_k is true, total_sequence_length_ may be 0 (actual value is in seqlen_k tensor).
// Use present_sequence_length for tile count calculations; shaders will read the actual value from seqlen_k.
const uint32_t effective_total_seq_len = use_seqlen_k ? present_sequence_length
                                                      : static_cast<uint32_t>(parameters.total_sequence_length_);

const uint32_t num_total_seq_length_tile = (effective_total_seq_len + tile_size - 1) / tile_size;
const uint32_t num_present_sequence_length_tile = (present_sequence_length + tile_size - 1) / tile_size;
Comment on lines +60 to +66
let num_total_seq_length_tile = (total_sequence_length + tile_size - 1) / tile_size;
let num_q_tiles = (uniforms.new_sequence_length + m_tile - 1) / m_tile;
// Workgroup layout: [batch_heads, num_q_tiles, num_total_seq_length_tile]
let total_seq_offset = (workgroup_idx % num_total_seq_length_tile) * tile_size;
let q_tile_idx = (workgroup_idx / num_total_seq_length_tile) % num_q_tiles;
let q_base = q_tile_idx * m_tile;
let batch_head_idx = u32(workgroup_idx / (num_total_seq_length_tile * num_q_tiles));
Comment on lines +33 to 38
#if use_seqlen_k
let total_sequence_length = u32(seqlens_k[0]) + 1u;
let num_total_seq_length_tile = (total_sequence_length + seq_tile_size - 1) / seq_tile_size;
#else
let num_total_seq_length_tile = uniforms.num_total_seq_length_tile;
#endif
@qjia7 qjia7 requested a review from xiaofeihan1 May 13, 2026 01:48