Implementation of flash attention for native webgpu ep #22932
onnxruntime/contrib_ops/webgpu/bert/flash_attention.h (new file)

@@ -0,0 +1,73 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "contrib_ops/webgpu/bert/attention_common.h"
#include "core/providers/webgpu/compute_context.h"
#include "core/providers/webgpu/program.h"
#include "core/providers/webgpu/shader_helper.h"
#include "core/providers/webgpu/webgpu_kernel.h"

namespace onnxruntime {
namespace contrib {
namespace webgpu {

using namespace onnxruntime::webgpu;

class CopyKVCacheProgram final : public Program<CopyKVCacheProgram> {
 public:
  CopyKVCacheProgram(const std::string& kernel_name, int components, bool has_past)
      : Program{kernel_name}, components_(components), has_past_(has_past) {
  }

  Status GenerateShaderCode(ShaderHelper& sh) const override;

  WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"past_sequence_length", ProgramUniformVariableDataType::Uint32},
                                          {"kv_sequence_length", ProgramUniformVariableDataType::Uint32},
                                          {"vectorized_head_size", ProgramUniformVariableDataType::Uint32});

 private:
  int components_;
  bool has_past_;
};

class FlashAttentionProgram final : public Program<FlashAttentionProgram> {
 public:
  FlashAttentionProgram(const std::string& kernel_name,
                        bool has_attention_bias,
                        uint32_t subgroup_size,
                        uint32_t tile_size,
                        int qkv_head_size,
                        int qkv_num_heads)
      : Program{kernel_name},
        has_attention_bias_(has_attention_bias),
        subgroup_size_(subgroup_size),
        tile_size_(tile_size),
        qkv_head_size_(qkv_head_size),
        qkv_num_heads_(qkv_num_heads) {
  }

  Status GenerateShaderCode(ShaderHelper& sh) const override;

  WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"new_sequence_length", ProgramUniformVariableDataType::Uint32},
                                          {"present_sequence_length", ProgramUniformVariableDataType::Uint32},
                                          {"alpha", ProgramUniformVariableDataType::Float32});

 private:
  bool has_attention_bias_;
  uint32_t subgroup_size_;
  uint32_t tile_size_;
  int qkv_head_size_;
  int qkv_num_heads_;
};

Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias,
                           Tensor* output, const Tensor* past_key, Tensor* present_key, const Tensor* past_value,
                           Tensor* present_value, const WebgpuAttentionParameters& parameters,
                           onnxruntime::webgpu::ComputeContext& context);

bool CanApplyFlashAttention(const Tensor* bias, const Tensor* present_key, const Tensor* present_value,
                            const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context);

}  // namespace webgpu
}  // namespace contrib
}  // namespace onnxruntime
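For orientation, here is a minimal sketch of how a kernel might route through the two entry points declared above. Only the two function signatures come from this header; the wrapper function, its name, and the fallback are hypothetical, and the snippet assumes the header's includes and namespaces.

// Hypothetical dispatch helper; everything except the two declared
// functions is illustrative.
Status ComputeAttentionWithOptionalFlash(const Tensor* Q, const Tensor* K, const Tensor* V,
                                         const Tensor* bias, const Tensor* attention_bias,
                                         Tensor* output,
                                         const Tensor* past_key, Tensor* present_key,
                                         const Tensor* past_value, Tensor* present_value,
                                         const WebgpuAttentionParameters& parameters,
                                         onnxruntime::webgpu::ComputeContext& context) {
  // Take the fused flash attention path only when the device and inputs qualify.
  if (CanApplyFlashAttention(bias, present_key, present_value, parameters, context)) {
    return ApplyFlashAttention(Q, K, V, attention_bias, output, past_key, present_key,
                               past_value, present_value, parameters, context);
  }
  // Otherwise fall back to the existing non-flash attention path (not shown).
  return Status::OK();
}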
@@ -118,8 +118,11 @@ void WebGpuContext::Initialize(const WebGpuExecutionProviderInfo& webgpu_ep_info
  ORT_ENFORCE(Adapter().GetInfo(&adapter_info_));
  // cache device limits
  wgpu::SupportedLimits device_supported_limits;
+ wgpu::DawnExperimentalSubgroupLimits subgroup_limits;
+ device_supported_limits.nextInChain = &subgroup_limits;
  ORT_ENFORCE(Device().GetLimits(&device_supported_limits));
  device_limits_ = device_supported_limits.limits;
+ min_subgroup_size_ = subgroup_limits.minSubgroupSize;
Comment on lines +121 to +125

Reviewer: What is the expected behavior when subgroups are not supported? I didn't test this code, but it looks like it will abort. It would be better to disable the features that use subgroups when they are not available.

Author: Hello, on devices without subgroup support, minSubgroupSize will be 0, AFAIK. This is because Dawn reads from the chained struct when we call GetLimits (https://github.com/google/dawn/blob/5d28e25927778b028473c4aa7af11fd5a5c9f76b/src/dawn/wire/client/LimitsAndFeatures.cpp#L53); if nothing ever set the limits for this feature, the field stays 0. Later, CanApplyFlashAttention in flash_attention.cc only enables FlashAttention when the minimum subgroup size is >= 16, so I expect no crashes; FlashAttention simply will not be activated on machines without subgroup support. I don't have a machine without subgroups to test this on, though.
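A minimal sketch of the gating behavior described in this thread (illustrative; the actual check lives in CanApplyFlashAttention and covers more conditions than the subgroup size):

#include <cstdint>

// Devices without subgroup support leave minSubgroupSize at its
// zero-initialized value, so a single threshold check both rejects
// those devices and enforces the minimum width flash attention needs.
bool SubgroupSizeUsableForFlashAttention(uint32_t min_subgroup_size) {
  constexpr uint32_t kMinRequiredSubgroupSize = 16;  // threshold cited in the thread
  return min_subgroup_size >= kMinRequiredSubgroupSize;
}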
  // create buffer manager
  buffer_mgr_ = BufferManagerFactory::Create(*this, webgpu_ep_info.storage_buffer_cache_mode, webgpu_ep_info.uniform_buffer_cache_mode, webgpu_ep_info.query_resolve_buffer_cache_mode);
@@ -135,6 +138,7 @@ void WebGpuContext::Initialize(const WebGpuExecutionProviderInfo& webgpu_ep_info
  } else {
    query_type_ = TimestampQueryType::None;
  }
+ is_flash_attention_enabled_ = webgpu_ep_info.enable_flash_attention;
});
}
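As a standalone illustration of the chained-struct query in the hunk above (assuming Dawn's C++ bindings; the struct and member names come from the diff, the wrapper function is hypothetical):

#include <cstdint>
#include <webgpu/webgpu_cpp.h>

// Link Dawn's experimental subgroup limits into the SupportedLimits chain
// so a single GetLimits call fills both the core limits and the extension.
uint32_t QueryMinSubgroupSize(const wgpu::Device& device) {
  wgpu::SupportedLimits supported_limits;
  wgpu::DawnExperimentalSubgroupLimits subgroup_limits;
  supported_limits.nextInChain = &subgroup_limits;  // chain the extension struct
  if (!device.GetLimits(&supported_limits)) {
    return 0;
  }
  // Remains 0 when the device never populated the subgroup limits,
  // i.e. when subgroups are unsupported.
  return subgroup_limits.minSubgroupSize;
}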