424 changes: 424 additions & 0 deletions onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc

Large diffs are not rendered by default.

73 changes: 73 additions & 0 deletions onnxruntime/contrib_ops/webgpu/bert/flash_attention.h
@@ -0,0 +1,73 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "contrib_ops/webgpu/bert/attention_common.h"
#include "core/providers/webgpu/compute_context.h"
#include "core/providers/webgpu/program.h"
#include "core/providers/webgpu/shader_helper.h"
#include "core/providers/webgpu/webgpu_kernel.h"

namespace onnxruntime {
namespace contrib {
namespace webgpu {

using namespace onnxruntime::webgpu;

[cpplint via reviewdog] flash_attention.h:16: Do not use namespace using-directives. Use using-declarations instead. [build/namespaces] [5]

class CopyKVCacheProgram final : public Program<CopyKVCacheProgram> {
 public:
  CopyKVCacheProgram(const std::string& kernel_name, int components, bool has_past)
      : Program{kernel_name}, components_(components), has_past_(has_past) {
  }

  Status GenerateShaderCode(ShaderHelper& sh) const override;

  WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"past_sequence_length", ProgramUniformVariableDataType::Uint32},
                                          {"kv_sequence_length", ProgramUniformVariableDataType::Uint32},
                                          {"vectorized_head_size", ProgramUniformVariableDataType::Uint32});

 private:
  int components_;
  bool has_past_;
};

[cpplint via reviewdog] flash_attention.h:37: Add #include <string> for string [build/include_what_you_use] [4]

class FlashAttentionProgram final : public Program<FlashAttentionProgram> {
 public:
  FlashAttentionProgram(const std::string& kernel_name,
                        bool has_attention_bias,
                        uint32_t subgroup_size,
                        uint32_t tile_size,
                        int qkv_head_size,
                        int qkv_num_heads)
      : Program{kernel_name},
        has_attention_bias_(has_attention_bias),
        subgroup_size_(subgroup_size),
        tile_size_(tile_size),
        qkv_head_size_(qkv_head_size),
        qkv_num_heads_(qkv_num_heads) {
  }

  Status GenerateShaderCode(ShaderHelper& sh) const override;

  WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"new_sequence_length", ProgramUniformVariableDataType::Uint32},
                                          {"present_sequence_length", ProgramUniformVariableDataType::Uint32},
                                          {"alpha", ProgramUniformVariableDataType::Float32});

 private:
  bool has_attention_bias_;
  uint32_t subgroup_size_;
  uint32_t tile_size_;
  int qkv_head_size_;
  int qkv_num_heads_;
};

Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias,
                           Tensor* output, const Tensor* past_key, Tensor* present_key, const Tensor* past_value,
                           Tensor* present_value, const WebgpuAttentionParameters& parameters,
                           onnxruntime::webgpu::ComputeContext& context);

bool CanApplyFlashAttention(const Tensor* bias, const Tensor* present_key, const Tensor* present_value,
                            const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context);
} // namespace webgpu
} // namespace contrib
} // namespace onnxruntime
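
For orientation, this is roughly how such a Program subclass is driven elsewhere in the WebGPU EP. The sketch below is hypothetical: the real dispatch logic lives in flash_attention.cc (whose diff is collapsed above), and the helper name `RunCopyKVCache`, the dispatch math, the assumed workgroup size of 64, and the single-tensor input are all illustrative assumptions.

```cpp
// Hypothetical driver for CopyKVCacheProgram; the real one is in flash_attention.cc.
Status RunCopyKVCache(onnxruntime::webgpu::ComputeContext& context, const Tensor* key, Tensor* present_key,
                      uint32_t past_sequence_length, uint32_t kv_sequence_length,
                      uint32_t vectorized_head_size, int components, bool has_past) {
  CopyKVCacheProgram program{"CopyKVCache", components, has_past};
  program.AddInput({key, ProgramTensorMetadataDependency::TypeAndRank, components})
      .AddOutput({present_key, ProgramTensorMetadataDependency::Rank, components})
      // One invocation per vectorized element of the new KV slice; workgroup size 64 is assumed.
      .SetDispatchGroupSize((kv_sequence_length * vectorized_head_size + 63) / 64)
      // Uniform order matches WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES above.
      .AddUniformVariables({{past_sequence_length}, {kv_sequence_length}, {vectorized_head_size}});
  return context.RunProgram(program);
}
```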
8 changes: 7 additions & 1 deletion onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc
@@ -3,6 +3,7 @@

 #include "contrib_ops/cpu/bert/multihead_attention_helper.h"
 #include "contrib_ops/webgpu/bert/attention_common.h"
+#include "contrib_ops/webgpu/bert/flash_attention.h"
 #include "contrib_ops/webgpu/bert/multihead_attention.h"
 #include "contrib_ops/webgpu/webgpu_contrib_kernels.h"

@@ -74,8 +75,13 @@ Status MultiHeadAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext&
   Tensor* present_key = context.Output(1, present_shape);
   Tensor* present_value = context.Output(2, present_shape);

+  if (CanApplyFlashAttention(bias, present_key, present_value, parameters, context)) {
+    return ApplyFlashAttention(query, key, value, attention_bias, output, past_key, present_key, past_value,
+                               present_value, parameters, context);
+  }
+
   TensorShapeVector q_new_dims({parameters.batch_size_, parameters.num_heads_,
-                                parameters.sequence_length_, parameters.head_size_});
+                                parameters.kv_sequence_length_, parameters.head_size_});
   TensorShape q_new_shape(q_new_dims);
   Tensor Q = context.CreateGPUTensor(query->DataType(), q_new_shape);
   ORT_RETURN_IF_ERROR(TransferBSDToBNSH(
6 changes: 6 additions & 0 deletions onnxruntime/core/providers/webgpu/compute_context.h
@@ -41,6 +41,12 @@ class ComputeContext {
   inline const wgpu::Limits& DeviceLimits() const {
     return webgpu_context_.DeviceLimits();
   }
+  inline uint32_t MinSubgroupSize() const {
+    return webgpu_context_.MinSubgroupSize();
+  }
+  inline bool IsFlashAttentionEnabled() const {
+    return webgpu_context_.IsFlashAttentionEnabled();
+  }

   //
   // Get the kernel context.
4 changes: 3 additions & 1 deletion onnxruntime/core/providers/webgpu/shader_helper.cc
@@ -62,7 +62,9 @@ Status ShaderHelper::Init() {
           "fn main(@builtin(global_invocation_id) global_id : vec3<u32>,\n"
           "        @builtin(workgroup_id) workgroup_id : vec3<u32>,\n"
           "        @builtin(local_invocation_index) local_idx : u32,\n"
-          "        @builtin(local_invocation_id) local_id : vec3<u32>";
+          "        @builtin(local_invocation_id) local_id : vec3<u32>,\n"
+          "        @builtin(subgroup_invocation_id) sg_id : u32,\n"
+          "        @builtin(subgroup_size) sg_size : u32";
   if (!is_1d_dispatch) {
     body_ss_ << ",\n"
                 "        @builtin(num_workgroups) num_workgroups : vec3<u32>";
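
With the two subgroup built-ins added to the generated entry point, any kernel's GenerateShaderCode can reference sg_id and sg_size in the WGSL it emits. A minimal hypothetical sketch follows; `MyProgram` and the emitted WGSL are illustrative, not taken from this PR:

```cpp
// Hypothetical kernel using the new built-ins; ShaderHelper::Init() above now
// declares sg_id and sg_size as parameters of the generated fn main(...).
Status MyProgram::GenerateShaderCode(ShaderHelper& shader) const {
  shader.MainFunctionBody() << "  let lanes = sg_size;  // subgroup width on this device\n"
                               "  if (sg_id == 0u) {\n"
                               "    // lane 0 of each subgroup could publish a per-subgroup result here\n"
                               "  }\n";
  return Status::OK();
}
```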
4 changes: 4 additions & 0 deletions onnxruntime/core/providers/webgpu/webgpu_context.cc
@@ -118,8 +118,11 @@ void WebGpuContext::Initialize(const WebGpuExecutionProviderInfo& webgpu_ep_info
   ORT_ENFORCE(Adapter().GetInfo(&adapter_info_));
   // cache device limits
   wgpu::SupportedLimits device_supported_limits;
+  wgpu::DawnExperimentalSubgroupLimits subgroup_limits;
+  device_supported_limits.nextInChain = &subgroup_limits;
   ORT_ENFORCE(Device().GetLimits(&device_supported_limits));
   device_limits_ = device_supported_limits.limits;
+  min_subgroup_size_ = subgroup_limits.minSubgroupSize;
Comment on lines +121 to +125

fs-eire (Contributor) commented on Nov 26, 2024:

What is the expected behavior when subgroup is not supported?

I didn't test this code, but it looks like it will abort. It would be better to disable features that use subgroups when they are not available.

The PR author (Contributor) replied:

Hello, on devices where subgroups are not supported, minSubgroupSize will be 0, AFAIK. This is because:

https://github.com/google/dawn/blob/5d28e25927778b028473c4aa7af11fd5a5c9f76b/src/dawn/wire/server/ServerAdapter.cpp#L129 initializes the struct to 0.

https://github.com/google/dawn/blob/5d28e25927778b028473c4aa7af11fd5a5c9f76b/src/dawn/wire/client/LimitsAndFeatures.cpp#L53 reads from the struct when we call GetLimits. If no one has set the limits for this feature, it will remain 0.

Later, in CanApplyFlashAttention in flash_attention.cc, I only enable FlashAttention if the minimum subgroup size is >= 16. So I expect no crashes; FlashAttention simply will not be activated on machines without subgroup support.

I don't have a machine without subgroups to test this, though.
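
To make the behavior discussed in this thread concrete, here is a hypothetical reconstruction of the gate. The real CanApplyFlashAttention lives in flash_attention.cc (collapsed above); everything beyond the enable flag and the subgroup-size check is an assumption:

```cpp
// Hypothetical sketch of the gating described in this thread.
bool CanApplyFlashAttention(const Tensor* bias, const Tensor* present_key, const Tensor* present_value,
                            const WebgpuAttentionParameters& parameters,
                            onnxruntime::webgpu::ComputeContext& context) {
  ORT_UNUSED_PARAMETER(parameters);  // the real check presumably also validates shapes/head sizes
  // MinSubgroupSize() stays 0 on devices without subgroup support (see the Dawn
  // links above), so the >= 16 requirement also fails safely there.
  return context.IsFlashAttentionEnabled() &&
         context.MinSubgroupSize() >= 16 &&
         bias == nullptr &&                                   // assumption: fused bias unsupported
         present_key != nullptr && present_value != nullptr;  // assumption: KV cache outputs required
}
```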


// create buffer manager
buffer_mgr_ = BufferManagerFactory::Create(*this, webgpu_ep_info.storage_buffer_cache_mode, webgpu_ep_info.uniform_buffer_cache_mode, webgpu_ep_info.query_resolve_buffer_cache_mode);
@@ -135,6 +138,7 @@ void WebGpuContext::Initialize(const WebGpuExecutionProviderInfo& webgpu_ep_info
   } else {
     query_type_ = TimestampQueryType::None;
   }
+  is_flash_attention_enabled_ = webgpu_ep_info.enable_flash_attention;
   });
}

4 changes: 4 additions & 0 deletions onnxruntime/core/providers/webgpu/webgpu_context.h
@@ -55,6 +55,8 @@ class WebGpuContext final {

   const wgpu::AdapterInfo& AdapterInfo() const { return adapter_info_; }
   const wgpu::Limits& DeviceLimits() const { return device_limits_; }
+  uint32_t MinSubgroupSize() const { return min_subgroup_size_; }
+  bool IsFlashAttentionEnabled() const { return is_flash_attention_enabled_; }

const wgpu::CommandEncoder& GetCommandEncoder() {
if (!current_command_encoder_) {
@@ -161,6 +163,7 @@

   wgpu::AdapterInfo adapter_info_;
   wgpu::Limits device_limits_;
+  uint32_t min_subgroup_size_ = 0;

wgpu::CommandEncoder current_command_encoder_;
wgpu::ComputePassEncoder current_compute_pass_encoder_;
@@ -183,6 +186,7 @@

   uint64_t gpu_timestamp_offset_ = 0;
   bool is_profiling_ = false;
+  bool is_flash_attention_enabled_ = false;
 };

} // namespace webgpu
onnxruntime/core/providers/webgpu/webgpu_execution_provider.h
@@ -23,9 +23,10 @@ class WebGpuProfiler;
 }  // namespace webgpu

 struct WebGpuExecutionProviderInfo {
-  WebGpuExecutionProviderInfo(DataLayout data_layout, bool enable_graph_capture)
+  WebGpuExecutionProviderInfo(DataLayout data_layout, bool enable_graph_capture, bool enable_flash_attention)
       : data_layout{data_layout},
         enable_graph_capture{enable_graph_capture},
+        enable_flash_attention{enable_flash_attention},
         storage_buffer_cache_mode{},
         uniform_buffer_cache_mode{},
         query_resolve_buffer_cache_mode{},
@@ -36,6 +37,7 @@

   DataLayout data_layout;
   bool enable_graph_capture;
+  bool enable_flash_attention;
   webgpu::BufferCacheMode storage_buffer_cache_mode;
   webgpu::BufferCacheMode uniform_buffer_cache_mode;
   webgpu::BufferCacheMode query_resolve_buffer_cache_mode;
14 changes: 14 additions & 0 deletions onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc
@@ -40,6 +40,8 @@ std::shared_ptr<IExecutionProviderFactory> WebGpuProviderFactoryCreator::Create(
       DataLayout::NHWC,
       // graph capture feature is disabled by default
       false,
+      // flash attention feature is disabled by default
+      false,
   };

std::string preferred_layout_str;
@@ -67,6 +69,18 @@ std::shared_ptr<IExecutionProviderFactory> WebGpuProviderFactoryCreator::Create(
   }
   LOGS_DEFAULT(VERBOSE) << "WebGPU EP graph capture enable: " << webgpu_ep_info.enable_graph_capture;

+  std::string enable_flash_attention_str;
+  if (config_options.TryGetConfigEntry(kEnableFlashAttention, enable_flash_attention_str)) {
+    if (enable_flash_attention_str == kEnableFlashAttention_ON) {
+      webgpu_ep_info.enable_flash_attention = true;
+    } else if (enable_flash_attention_str == kEnableFlashAttention_OFF) {
+      webgpu_ep_info.enable_flash_attention = false;
+    } else {
+      ORT_THROW("Invalid enable flash attention: ", enable_flash_attention_str);
+    }
+  }
+  LOGS_DEFAULT(VERBOSE) << "WebGPU EP flash attention enabled: " << webgpu_ep_info.enable_flash_attention;
+
   auto parse_buffer_cache_mode = [&config_options](const std::string& config_entry_str,
                                                    webgpu::BufferCacheMode default_value) -> webgpu::BufferCacheMode {
     std::string buffer_cache_mode_str;
4 changes: 4 additions & 0 deletions onnxruntime/core/providers/webgpu/webgpu_provider_options.h
@@ -11,6 +11,7 @@ namespace options {

 constexpr const char* kPreferredLayout = "WebGPU:preferredLayout";
 constexpr const char* kEnableGraphCapture = "WebGPU:enableGraphCapture";
+constexpr const char* kEnableFlashAttention = "WebGPU:enableFlashAttention";

 constexpr const char* kDawnProcTable = "WebGPU:dawnProcTable";
@@ -36,6 +37,9 @@ constexpr const char* kPreferredLayout_NHWC = "NHWC";
 constexpr const char* kEnableGraphCapture_ON = "1";
 constexpr const char* kEnableGraphCapture_OFF = "0";

+constexpr const char* kEnableFlashAttention_ON = "1";
+constexpr const char* kEnableFlashAttention_OFF = "0";
+
 constexpr const char* kBufferCacheMode_Disabled = "disabled";
 constexpr const char* kBufferCacheMode_LazyRelease = "lazyRelease";
 constexpr const char* kBufferCacheMode_Simple = "simple";
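
End to end, the new option can be exercised from the ONNX Runtime C++ API. A minimal sketch, assuming the WebGPU EP reads these keys from the session's ConfigOptions as the factory code above suggests; the model path and the AppendExecutionProvider call shape are illustrative:

```cpp
#include "onnxruntime_cxx_api.h"

int main() {
  Ort::Env env;
  Ort::SessionOptions session_options;
  // Key and value strings come from webgpu_provider_options.h above;
  // any value other than "1"/"0" makes the provider factory throw.
  session_options.AddConfigEntry("WebGPU:enableFlashAttention", "1");
  session_options.AppendExecutionProvider("WebGPU", {});  // assumed registration call
  Ort::Session session{env, ORT_TSTR("model.onnx"), session_options};  // hypothetical model
  return 0;
}
```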