Implementation of flash attention for native WebGPU EP #22932

Open
wants to merge 25 commits into base: main

Commits (25)
52656bf
FA Base - Does Not Work
sushraja-msft Nov 11, 2024
ed8bf5d
The new Copy KV Cache works.
sushraja-msft Nov 11, 2024
75aa49d
Add flash attention
sushraja-msft Nov 16, 2024
58157c5
Integrate FA
sushraja-msft Nov 19, 2024
80296aa
Try fix the divide by zero issue
sushraja-msft Nov 20, 2024
c281f84
FA works on intel (TILE_SIZE == SUBGROUP_SIZE) for seq length of 1.
sushraja-msft Nov 21, 2024
3d25852
Update subgroup_size and tile_size to be actual intel values
sushraja-msft Nov 23, 2024
b19070a
Commit temporarily
sushraja-msft Nov 23, 2024
228b840
Works so far.
sushraja-msft Nov 23, 2024
122e5f9
This works.
sushraja-msft Nov 23, 2024
8b5fcc7
Matches past and should work.
sushraja-msft Nov 23, 2024
ed310de
Multi Q also works
sushraja-msft Nov 23, 2024
ee1051f
Switch back to safer algorithm that is optimized for prefill.
sushraja-msft Nov 23, 2024
d1d8175
Refactor into separate file
sushraja-msft Nov 23, 2024
6febf6c
Get subgroup size from device limits
sushraja-msft Nov 23, 2024
ec59ba5
Switch to the dot product optimized for 1 token length, since we are …
sushraja-msft Nov 23, 2024
0dd5b70
Add an alias to control math precision of flash attention.
sushraja-msft Nov 24, 2024
8bcb47a
Improve comments.
sushraja-msft Nov 24, 2024
832c323
Improve comments
sushraja-msft Nov 24, 2024
a814770
Fix comment about intel subgroup size.
sushraja-msft Nov 24, 2024
ab11009
Fix AttentionBias loading
sushraja-msft Nov 25, 2024
563e662
Add comment explaining Flash Attention
sushraja-msft Nov 26, 2024
ce2031e
Fix lint errors
sushraja-msft Nov 26, 2024
bf1b146
Add option to turn on flash attention via config
sushraja-msft Nov 26, 2024
d1e442e
Fix min value
sushraja-msft Nov 27, 2024
424 changes: 424 additions & 0 deletions onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc

Large diffs are not rendered by default.

72 changes: 72 additions & 0 deletions onnxruntime/contrib_ops/webgpu/bert/flash_attention.h
@@ -0,0 +1,72 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/providers/webgpu/compute_context.h"
#include "core/providers/webgpu/program.h"
#include "core/providers/webgpu/shader_helper.h"
#include "core/providers/webgpu/webgpu_kernel.h"

namespace onnxruntime {
namespace contrib {
namespace webgpu {

using namespace onnxruntime::webgpu;

[cpplint] warning on line 15 (GitHub Actions / Optional Lint C++): Do not use namespace using-directives. Use using-declarations instead. [build/namespaces] [5]

class CopyKVCacheProgram final : public Program<CopyKVCacheProgram> {
public:
CopyKVCacheProgram(const std::string& kernel_name, int components, bool has_past)
: Program{kernel_name}, components_(components), has_past_(has_past) {
}

Status GenerateShaderCode(ShaderHelper& sh) const override;

WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"past_sequence_length", ProgramUniformVariableDataType::Uint32},
{"kv_sequence_length", ProgramUniformVariableDataType::Uint32},
{"vectorized_head_size", ProgramUniformVariableDataType::Uint32});

private:
int components_;
bool has_past_;
};

class FlashAttentionProgram final : public Program<FlashAttentionProgram> {
public:
FlashAttentionProgram(const std::string& kernel_name,

[cpplint] warning on line 36 (GitHub Actions / Optional Lint C++): Add #include <string> for string [build/include_what_you_use] [4]
bool has_attention_bias,
uint32_t subgroup_size,
uint32_t tile_size,
int qkv_head_size,
int qkv_num_heads)
: Program{kernel_name},
has_attention_bias_(has_attention_bias),
subgroup_size_(subgroup_size),
tile_size_(tile_size),
qkv_head_size_(qkv_head_size),
qkv_num_heads_(qkv_num_heads) {
}

Status GenerateShaderCode(ShaderHelper& sh) const override;

WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"new_sequence_length", ProgramUniformVariableDataType::Uint32},
{"present_sequence_length", ProgramUniformVariableDataType::Uint32},
{"alpha", ProgramUniformVariableDataType::Float32});

private:
bool has_attention_bias_;
uint32_t subgroup_size_;
uint32_t tile_size_;
int qkv_head_size_;
int qkv_num_heads_;
};

Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias,
Tensor* output, const Tensor* past_key, Tensor* present_key, const Tensor* past_value, Tensor* present_value,
const AttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context);

bool CanApplyFlashAttention(const Tensor* bias, const Tensor* present_key, const Tensor* present_value,
const AttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context);
} // namespace webgpu
} // namespace contrib
} // namespace onnxruntime
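
The 424-line flash_attention.cc that implements these programs is not rendered above. For orientation only, here is a minimal scalar sketch of the online-softmax recurrence that flash attention is built on, with alpha playing its usual 1/sqrt(head_size) scaling role. Everything in it is illustrative: the function name and data layout are assumptions, there is no tiling, attention bias, or KV cache, and the real WGSL kernel instead streams TILE_SIZE-wide K/V tiles through workgroup memory with subgroup operations and a configurable math-precision alias.

// Illustrative scalar reference of the online-softmax recurrence behind
// flash attention; not the PR's kernel.
#include <algorithm>
#include <cmath>
#include <vector>

// q: [head_size]; k, v: [seq_len][head_size]; alpha: score scale (typically 1/sqrt(head_size)).
std::vector<float> AttendOneQuery(const std::vector<float>& q,
                                  const std::vector<std::vector<float>>& k,
                                  const std::vector<std::vector<float>>& v,
                                  float alpha) {
  const size_t head_size = q.size();
  std::vector<float> o(head_size, 0.0f);  // running, unnormalized output
  float m = -INFINITY;                    // running max of the scores seen so far
  float l = 0.0f;                         // running sum of exp(score - m)
  for (size_t t = 0; t < k.size(); ++t) {
    float s = 0.0f;
    for (size_t d = 0; d < head_size; ++d) s += q[d] * k[t][d];
    s *= alpha;
    const float m_new = std::max(m, s);
    const float correction = std::exp(m - m_new);  // rescale previously accumulated state
    const float p = std::exp(s - m_new);
    for (size_t d = 0; d < head_size; ++d) o[d] = o[d] * correction + p * v[t][d];
    l = l * correction + p;
    m = m_new;
  }
  for (size_t d = 0; d < head_size; ++d) o[d] /= l;  // final softmax normalization
  return o;
}

Because each step needs only the current K/V row plus the small running state (m, l, o), the full score row is never materialized, which is what lets the kernel keep K/V tiles in workgroup memory instead of writing out an intermediate attention matrix.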
6 changes: 6 additions & 0 deletions onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc
@@ -2,6 +2,7 @@
// Licensed under the MIT License.

#include "contrib_ops/cpu/bert/multihead_attention_helper.h"
#include "contrib_ops/webgpu/bert/flash_attention.h"
#include "contrib_ops/webgpu/bert/multihead_attention.h"
#include "contrib_ops/webgpu/webgpu_contrib_kernels.h"

@@ -457,6 +458,11 @@ Status MultiHeadAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext&
Tensor* present_key = context.Output(1, present_shape);
Tensor* present_value = context.Output(2, present_shape);

if (CanApplyFlashAttention(bias, present_key, present_value, parameters, context)) {
return ApplyFlashAttention(query, key, value, attention_bias, output, past_key, present_key, past_value,
present_value, parameters, context);
}

TensorShapeVector q_new_dims({parameters.batch_size, parameters.num_heads,
parameters.sequence_length, parameters.head_size});
TensorShape q_new_shape(q_new_dims);
6 changes: 6 additions & 0 deletions onnxruntime/core/providers/webgpu/compute_context.h
@@ -41,6 +41,12 @@ class ComputeContext {
inline const wgpu::Limits& DeviceLimits() const {
return webgpu_context_.DeviceLimits();
}
inline const uint32_t MinSubgroupSize() const {
return webgpu_context_.MinSubgroupSize();
}
inline const bool IsFlashAttentionEnabled() const {
return webgpu_context_.IsFlashAttentionEnabled();
}

//
// Get the kernel context.
4 changes: 3 additions & 1 deletion onnxruntime/core/providers/webgpu/shader_helper.cc
@@ -62,7 +62,9 @@ Status ShaderHelper::Init() {
"fn main(@builtin(global_invocation_id) global_id : vec3<u32>,\n"
" @builtin(workgroup_id) workgroup_id : vec3<u32>,\n"
" @builtin(local_invocation_index) local_idx : u32,\n"
" @builtin(local_invocation_id) local_id : vec3<u32>";
" @builtin(local_invocation_id) local_id : vec3<u32>,\n"
" @builtin(subgroup_invocation_id) sg_id : u32,\n"
" @builtin(subgroup_size) sg_size : u32";
if (!is_1d_dispatch) {
body_ss_ << ",\n"
" @builtin(num_workgroups) num_workgroups : vec3<u32>";
4 changes: 4 additions & 0 deletions onnxruntime/core/providers/webgpu/webgpu_context.cc
@@ -118,8 +118,11 @@ void WebGpuContext::Initialize(const WebGpuExecutionProviderInfo& webgpu_ep_info
ORT_ENFORCE(Adapter().GetInfo(&adapter_info_));
// cache device limits
wgpu::SupportedLimits device_supported_limits;
wgpu::DawnExperimentalSubgroupLimits subgroup_limits;
device_supported_limits.nextInChain = &subgroup_limits;
ORT_ENFORCE(Device().GetLimits(&device_supported_limits));
device_limits_ = device_supported_limits.limits;
min_subgroup_size_ = subgroup_limits.minSubgroupSize;

// create buffer manager
buffer_mgr_ = BufferManagerFactory::Create(*this, webgpu_ep_info.storage_buffer_cache_mode, webgpu_ep_info.uniform_buffer_cache_mode, webgpu_ep_info.query_resolve_buffer_cache_mode);
@@ -135,6 +138,7 @@ void WebGpuContext::Initialize(const WebGpuExecutionProviderInfo& webgpu_ep_info
} else {
query_type_ = TimestampQueryType::None;
}
is_flash_attention_enabled_ = webgpu_ep_info.enable_flash_attention;
});
}

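The chained-struct query above is what feeds MinSubgroupSize(). As a standalone sketch of the same pattern, assuming Dawn's webgpu_cpp.h bindings used in this diff (the helper name and the boolean treatment of GetLimits are illustrative, mirroring the ORT_ENFORCE above):

#include <webgpu/webgpu_cpp.h>

// Query the experimental subgroup limits by chaining Dawn's extension struct
// into the regular SupportedLimits query.
uint32_t QueryMinSubgroupSize(const wgpu::Device& device) {
  wgpu::SupportedLimits limits{};
  wgpu::DawnExperimentalSubgroupLimits subgroup_limits{};
  limits.nextInChain = &subgroup_limits;  // chain the extension struct before querying
  if (!device.GetLimits(&limits)) {
    return 0;  // query failed; caller can fall back to a conservative default
  }
  return subgroup_limits.minSubgroupSize;
}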
4 changes: 4 additions & 0 deletions onnxruntime/core/providers/webgpu/webgpu_context.h
@@ -55,6 +55,8 @@ class WebGpuContext final {

const wgpu::AdapterInfo& AdapterInfo() const { return adapter_info_; }
const wgpu::Limits& DeviceLimits() const { return device_limits_; }
uint32_t MinSubgroupSize() const { return min_subgroup_size_; }
const bool IsFlashAttentionEnabled() const { return is_flash_attention_enabled_; }

const wgpu::CommandEncoder& GetCommandEncoder() {
if (!current_command_encoder_) {
@@ -161,6 +163,7 @@

wgpu::AdapterInfo adapter_info_;
wgpu::Limits device_limits_;
uint32_t min_subgroup_size_ = 0;

wgpu::CommandEncoder current_command_encoder_;
wgpu::ComputePassEncoder current_compute_pass_encoder_;
@@ -183,6 +186,7 @@

uint64_t gpu_timestamp_offset_ = 0;
bool is_profiling_ = false;
bool is_flash_attention_enabled_ = false;
};

} // namespace webgpu
@@ -23,9 +23,10 @@ class WebGpuProfiler;
} // namespace webgpu

struct WebGpuExecutionProviderInfo {
WebGpuExecutionProviderInfo(DataLayout data_layout, bool enable_graph_capture)
WebGpuExecutionProviderInfo(DataLayout data_layout, bool enable_graph_capture, bool enable_flash_attention)
: data_layout{data_layout},
enable_graph_capture{enable_graph_capture},
enable_flash_attention{enable_flash_attention},
storage_buffer_cache_mode{},
uniform_buffer_cache_mode{},
query_resolve_buffer_cache_mode{},
@@ -36,6 +37,7 @@

DataLayout data_layout;
bool enable_graph_capture;
bool enable_flash_attention;
webgpu::BufferCacheMode storage_buffer_cache_mode;
webgpu::BufferCacheMode uniform_buffer_cache_mode;
webgpu::BufferCacheMode query_resolve_buffer_cache_mode;
14 changes: 14 additions & 0 deletions onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc
@@ -40,6 +40,8 @@ std::shared_ptr<IExecutionProviderFactory> WebGpuProviderFactoryCreator::Create(
DataLayout::NHWC,
// graph capture feature is disabled by default
false,
// flash attention feature is disabled by default
false,
};

std::string preferred_layout_str;
@@ -67,6 +69,18 @@ std::shared_ptr<IExecutionProviderFactory> WebGpuProviderFactoryCreator::Create(
}
LOGS_DEFAULT(VERBOSE) << "WebGPU EP graph capture enable: " << webgpu_ep_info.enable_graph_capture;

std::string enable_flash_attention_str;
if (config_options.TryGetConfigEntry(kEnableFlashAttention, enable_flash_attention_str)) {
if (enable_flash_attention_str == kEnableFlashAttention_ON) {
webgpu_ep_info.enable_flash_attention = true;
} else if (enable_flash_attention_str == kEnableFlashAttention_OFF) {
webgpu_ep_info.enable_flash_attention = false;
} else {
ORT_THROW("Invalid enable flash attention: ", enable_flash_attention_str);
}
}
LOGS_DEFAULT(VERBOSE) << "WebGPU EP flash attention enabled: " << webgpu_ep_info.enable_flash_attention;

auto parse_buffer_cache_mode = [&config_options](const std::string& config_entry_str,
webgpu::BufferCacheMode default_value) -> webgpu::BufferCacheMode {
std::string buffer_cache_mode_str;
4 changes: 4 additions & 0 deletions onnxruntime/core/providers/webgpu/webgpu_provider_options.h
@@ -11,6 +11,7 @@ namespace options {

constexpr const char* kPreferredLayout = "WebGPU:preferredLayout";
constexpr const char* kEnableGraphCapture = "WebGPU:enableGraphCapture";
constexpr const char* kEnableFlashAttention = "WebGPU:enableFlashAttention";

constexpr const char* kDawnProcTable = "WebGPU:dawnProcTable";

@@ -36,6 +37,9 @@ constexpr const char* kPreferredLayout_NHWC = "NHWC";
constexpr const char* kEnableGraphCapture_ON = "1";
constexpr const char* kEnableGraphCapture_OFF = "0";

constexpr const char* kEnableFlashAttention_ON = "1";
constexpr const char* kEnableFlashAttention_OFF = "0";

constexpr const char* kBufferCacheMode_Disabled = "disabled";
constexpr const char* kBufferCacheMode_LazyRelease = "lazyRelease";
constexpr const char* kBufferCacheMode_Simple = "simple";
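
Together with the factory change above, the feature is driven by the "WebGPU:enableFlashAttention" config key, with "1" meaning enabled and "0" disabled (off by default). A hypothetical usage sketch from the C++ API, assuming the key reaches the EP factory through the session's config options; the exact WebGPU EP registration call is omitted:

#include <onnxruntime_cxx_api.h>

int main() {
  Ort::Env env;
  Ort::SessionOptions session_options;
  // Key and values taken from webgpu_provider_options.h in this PR:
  // kEnableFlashAttention = "WebGPU:enableFlashAttention", kEnableFlashAttention_ON = "1".
  session_options.AddConfigEntry("WebGPU:enableFlashAttention", "1");
  // ... append the WebGPU execution provider and create the session as usual ...
  return 0;
}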