
Implementation of flash attention for native webgpu ep #22932

Open

sushraja-msft wants to merge 21 commits into main

Conversation


@sushraja-msft commented Nov 24, 2024

Description

This change implements flash attention in the WebGPU EP to improve prefill speed.
Perf numbers from an Intel Alderlake device:

Baseline MHA

Batch size: 1, prompt tokens: 501, tokens to generate: 128
Prompt processing (time to first token):
        avg (us):       2.26746e+07
        avg (tokens/s): 22.0952              <<<
        p50 (us):       2.34637e+07
        stddev (us):    3.92912e+06
        n:              5 * 501 token(s)
Token generation:
        avg (us):       96519.8
        avg (tokens/s): 10.3606              <<<
        p50 (us):       98061.5
        stddev (us):    9220.87
        n:              635 * 1 token(s)

With FA

Batch size: 1, prompt tokens: 501, tokens to generate: 128
Prompt processing (time to first token):
        avg (us):       1.69236e+07
        avg (tokens/s): 29.6036             <<<
        p50 (us):       1.63162e+07
        stddev (us):    960417
        n:              5 * 501 token(s)
Token generation:
        avg (us):       91436.7
        avg (tokens/s): 10.9365             <<<
        p50 (us):       90397.1
        stddev (us):    5349.19
        n:              635 * 1 token(s)

Motivation and Context

On integrated GPUs, memory bandwidth is at a premium. Flash attention turns the softmax computation (and therefore the output attention vector computation) into a running operation instead of materializing the full QK^T attention score matrix in memory. As a result, prefill speed improves significantly: a 30% speedup was measured here.
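For reference, the idea can be summarized with a minimal scalar sketch of the streaming (online) softmax accumulation that flash attention relies on. This is not the PR's actual shader code: the function name and parameters (`q`, `k`, `v`, `scale`) are illustrative, and the real WebGPU shader tiles K/V through workgroup memory rather than looping row by row on a single thread.

```cpp
// Illustrative scalar sketch of online-softmax attention for one query row.
// Only a running max, a running denominator, and the partially accumulated
// output row are kept; the kv_len-sized score row is never stored.
#include <algorithm>
#include <cmath>
#include <vector>

std::vector<float> AttendOneQueryRow(const std::vector<float>& q,               // [head_dim]
                                     const std::vector<std::vector<float>>& k,  // [kv_len][head_dim]
                                     const std::vector<std::vector<float>>& v,  // [kv_len][head_dim]
                                     float scale) {
  const size_t head_dim = q.size();
  std::vector<float> out(head_dim, 0.0f);  // running weighted sum of V rows
  float running_max = -INFINITY;           // running max of scores seen so far
  float running_sum = 0.0f;                // running softmax denominator

  for (size_t j = 0; j < k.size(); ++j) {
    // Score against this K row only; earlier scores are never kept in memory.
    float score = 0.0f;
    for (size_t d = 0; d < head_dim; ++d) score += q[d] * k[j][d];
    score *= scale;

    // Rescale the running state when a new maximum appears (numerical stability).
    float new_max = std::max(running_max, score);
    float correction = std::exp(running_max - new_max);
    float p = std::exp(score - new_max);

    running_sum = running_sum * correction + p;
    for (size_t d = 0; d < head_dim; ++d) {
      out[d] = out[d] * correction + p * v[j][d];
    }
    running_max = new_max;
  }

  // Normalize once at the end by the accumulated denominator.
  for (size_t d = 0; d < head_dim; ++d) out[d] /= running_sum;
  return out;
}
```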

This implementation also uses the new WebGPU subgroups feature to further accelerate the attention computation.
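The subgroup usage amounts to cooperative reductions across lanes. The tiny C++ model below only illustrates that reduction pattern (per-lane partial sums and maxima combined in one step); the actual shader would use WGSL subgroup builtins instead, and `lane_vals` is purely an illustrative stand-in for the values held by the lanes of one subgroup.

```cpp
// Illustrative model of subgroup-style reductions used in the attention kernel.
#include <algorithm>
#include <numeric>
#include <vector>

// Combine per-lane partial dot-product contributions into one score.
float SubgroupSum(const std::vector<float>& lane_vals) {
  return std::accumulate(lane_vals.begin(), lane_vals.end(), 0.0f);
}

// Combine per-lane running maxima when updating the online-softmax state.
float SubgroupMax(const std::vector<float>& lane_vals) {
  return *std::max_element(lane_vals.begin(), lane_vals.end());
}
```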

  • Tested on Intel Alderlake (Subgroup Size 16) with Phi 3.5 mini.
  • Tested on Nvidia 2060 (Subgroup Size 32) with Phi 3.5 mini.
  • Tested with Llama 3.2 1B; FlashAttention does not activate because past/present keys are always null. Needs investigation into the model to understand why this is the case.

Remaining work

  • Investigate and fix an observed regression in effective context length. Although responses are coherent, with this implementation of FlashAttention the model seems to miss details from the prefill content. Responses to simple prompts without long context are identical to MHA.
  • Specialize the algorithm for the generation phase: the shared-memory tiles for K/V can be removed because each K/V value is used only once, freeing shared memory for a larger tile size.
  • Specialize the algorithm for the no-past-KV case (prefill). The CopyKVCache operation can likely be eliminated here: since there are no past KV values to copy over, the new KV values can be written to present KV as part of flash attention. PIX profiling shows CopyKVCache is almost as expensive as the FlashAttention shader itself. A static KV cache would also eliminate this copy and yield further performance wins.
