From 52656bffb3094a6bc22f66d8a4d35560c741f01d Mon Sep 17 00:00:00 2001
From: Sushanth Rajasankar
Date: Mon, 11 Nov 2024 10:30:03 -0800
Subject: [PATCH 01/25] FA Base - Does Not Work

---
 .../webgpu/bert/multihead_attention.cc        | 291 +++++++++++++++++-
 .../webgpu/bert/multihead_attention.h         |  55 +++-
 .../core/providers/webgpu/shader_helper.cc    |   4 +-
 3 files changed, 331 insertions(+), 19 deletions(-)

diff --git a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc
index 5583f296fae42..d3007e7ddd889 100644
--- a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc
+++ b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc
@@ -1,6 +1,11 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 
+#include
+#include
+#include
+#include
+
 #include "contrib_ops/cpu/bert/multihead_attention_helper.h"
 #include "contrib_ops/webgpu/bert/multihead_attention.h"
 #include "contrib_ops/webgpu/webgpu_contrib_kernels.h"
@@ -89,7 +94,7 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const {
   }
 
   shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
-  if (has_present_key_) {
+  if (has_present_key_ && !fa_variant_) {
     shader.AddOutput("present_key", ShaderUsage::UseUniform);
   }
 
@@ -132,7 +137,7 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const {
       shader.MainFunctionBody() << "    tileK[idx] = key[kOffset + local_id.y * uniforms.K + w + local_id.x];\n";
     }
 
-    if (has_present_key_) {
+    if (has_present_key_ && !fa_variant_) {
       shader.MainFunctionBody() << "    present_key[presentKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x] = tileK[idx];\n";
     }
 
@@ -159,9 +164,75 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const {
   return Status::OK();
 }
 
+Status CopyKVCacheProgram::GenerateShaderCode(ShaderHelper& shader) const {
+  // This shader works only for a limited case of current_seq_len = 1, that is, the generation phase of an LLM.
+  // Expectations are
+  //  qkv have same number of heads and hidden dimension (head size).
+  //  qkv are in BSNH format.
+  //            B - batch size but shader only supports batch_size 1.
+  //            S - current sequence length but shader supports only S = 1.
+  //            N - number of heads.
+  //            H - head size or hidden dimension for each qkv head.
+  //  KV cache is stored as BN(total_sequence_length)H
+  //  Attention bias is in BN(total_sequence_length)
+  //
+  //  Expectation is that present_key and present_value contain the past keys and values, since
+  //  we are out of storage buffers a shader can have and both past/present can't be passed.
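The BN(total_sequence_length)H bookkeeping described in the comment above is easy to sanity-check off-GPU. Below is a standalone C++ sketch of the same indexing with made-up sizes; none of these values come from the patch:

    #include <cstdint>
    #include <cstdio>

    int main() {
      // Hypothetical sizes: N heads, head size H, past length P, total length T = P + 1.
      const uint32_t N = 8, H = 64, P = 7, T = P + 1;
      // BN(total_sequence_length)H layout: token t of head n in the present buffer.
      auto present_offset = [&](uint32_t n, uint32_t t) { return n * T * H + t * H; };
      // The past buffer uses the same layout, but with the past length P as the stride.
      auto past_offset = [&](uint32_t n, uint32_t t) { return n * P * H + t * H; };
      // Copying the past forward preserves (n, t); only the per-head stride changes.
      std::printf("present=%u past=%u\n", present_offset(2, 3), past_offset(2, 3));
      return 0;
    }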
+ + shader.AddInput("k", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + shader.AddInput("past_k", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + shader.AddInput("v", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + shader.AddInput("past_v", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + shader.AddOutput("present_key", ShaderUsage::UseUniform); + shader.AddOutput("present_value", ShaderUsage::UseUniform); + + shader.MainFunctionBody() << "let head_index = local_id.x;\n" + << "let k_index = workgroup_id.x;\n" + << "if (k_index < uniforms.past_sequence_length) {\n" + << " let present_offset = head_index * uniforms.total_sequence_length * uniforms.qkv_hidden_size + k_index * uniforms.qkv_hidden_size;\n" + << " let past_offset = head_index * uniforms.past_sequence_length * uniforms.qkv_hidden_size + k_index * uniforms.qkv_hidden_size;\n" + << " for (var i=0u; i(parameters.head_size)}, + {static_cast(past_sequence_length)}, + {static_cast(total_sequence_length)}}); + + return context.RunProgram(program); +} + Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int output_count, const Tensor* Q, const Tensor* K, const Tensor* past_key, const Tensor* attention_bias, Tensor* probs, Tensor* present_key, - AttentionParameters& parameters, int past_sequence_length, int total_sequence_length) { + AttentionParameters& parameters, int past_sequence_length, int total_sequence_length, bool fa_variant = false) { const float alpha = parameters.scale == 0.0f ? 1.f / sqrt(static_cast(parameters.head_size)) : parameters.scale; @@ -172,7 +243,7 @@ Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int o const int components = parameters.head_size % 4 == 0 ? 4 : (parameters.head_size % 2 == 0 ? 
2 : 1); AttentionProbsProgram program{"AttentionProbs", feed_past_key, has_present_key, has_attention_bias, tile_size, - components}; + components, fa_variant}; program.AddInputs({{Q, ProgramTensorMetadataDependency::TypeAndRank, components}, {K, ProgramTensorMetadataDependency::TypeAndRank, components}}); if (feed_past_key) { @@ -182,7 +253,7 @@ Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int o program.AddInput({attention_bias, ProgramTensorMetadataDependency::TypeAndRank}); } program.AddOutputs({{probs, ProgramTensorMetadataDependency::Rank}}); - if (has_present_key) { + if (has_present_key && !fa_variant) { program.AddOutput({present_key, ProgramTensorMetadataDependency::Rank, components}); } @@ -191,7 +262,7 @@ Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int o (parameters.sequence_length + tile_size - 1) / tile_size, parameters.batch_size * parameters.num_heads) .SetWorkgroupSize(tile_size, tile_size) - .CacheHint(std::to_string(tile_size)) + .CacheHint(std::to_string(tile_size)+std::to_string(fa_variant)) .AddUniformVariables({{static_cast(parameters.sequence_length)}, {static_cast(vectorized_head_size)}, {static_cast(total_sequence_length)}, @@ -274,7 +345,7 @@ Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const { } shader.AddOutput("output", ShaderUsage::UseUniform); - if (has_present_value_) { + if (has_present_value_ && !fa_variant_) { shader.AddOutput("present_value", ShaderUsage::UseUniform); } @@ -315,7 +386,7 @@ Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const { shader.MainFunctionBody() << " tileK[idx] = v[offsetB + (w + local_id.y) * uniforms.N];\n"; } - if (has_present_value_) { + if (has_present_value_ && !fa_variant_) { shader.MainFunctionBody() << " present_value[presentValueOffset + (w + local_id.y) * uniforms.N] = tileK[idx];\n"; } @@ -347,19 +418,20 @@ Status ComputeVxAttentionScore(onnxruntime::webgpu::ComputeContext& context, int Tensor* present_value, AttentionParameters& parameters, int past_sequence_length, - int total_sequence_length) { + int total_sequence_length, + bool fa_variant) { const bool feed_past_value = present_value != nullptr && past_value != nullptr && past_value->SizeInBytes() > 0; const bool has_present_value = output_count > 1 && past_value != nullptr; constexpr int tile_size = 12; - VxAttentionScoreProgram program{"VxAttentionScore", feed_past_value, has_present_value, tile_size}; + VxAttentionScoreProgram program{"VxAttentionScore", feed_past_value, has_present_value, tile_size, fa_variant}; program.AddInputs({{probs, ProgramTensorMetadataDependency::TypeAndRank}, {V, ProgramTensorMetadataDependency::TypeAndRank}}); if (feed_past_value) { program.AddInput({past_value, ProgramTensorMetadataDependency::TypeAndRank}); } program.AddOutputs({{output, ProgramTensorMetadataDependency::TypeAndRank}}); - if (has_present_value) { + if (has_present_value && !fa_variant) { program.AddOutput({present_value, ProgramTensorMetadataDependency::TypeAndRank}); } @@ -382,7 +454,7 @@ Status ComputeVxAttentionScore(onnxruntime::webgpu::ComputeContext& context, int Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias, const Tensor* past_key, const Tensor* past_value, Tensor* output, Tensor* present_key, Tensor* present_value, - AttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context) { + AttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, bool 
fa_variant = false) {
   const int output_count = std::min({context.OutputCount(), 1 + (past_key != nullptr ? 1 : 0) + (past_value != nullptr ? 1 : 0)});
   const int past_sequence_length = output_count > 1 ? parameters.past_sequence_length : 0;
   const int total_sequence_length = past_sequence_length + parameters.kv_sequence_length;
 
@@ -392,17 +464,201 @@ Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const T
   const TensorShape probs_shape(probs_dims);
   Tensor probs = context.CreateGPUTensor(Q->DataType(), probs_shape);
   ORT_RETURN_IF_ERROR(ComputeAttentionProbs(context, output_count, Q, K, past_key, attention_bias, &probs, present_key,
-                                            parameters, past_sequence_length, total_sequence_length));
+                                            parameters, past_sequence_length, total_sequence_length, fa_variant));
 
   ORT_RETURN_IF_ERROR(ComputeInPlaceSoftmax(context, &probs,
                                             parameters.batch_size * parameters.num_heads * parameters.sequence_length, total_sequence_length));
 
   ORT_RETURN_IF_ERROR(ComputeVxAttentionScore(context, output_count, &probs, V, past_value, output, present_value,
-                                              parameters, past_sequence_length, total_sequence_length));
+                                              parameters, past_sequence_length, total_sequence_length, fa_variant));
+
+  return Status::OK();
+}
+
+Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const {
+  // This shader works only for a limited case of current_seq_len = 1, that is, the generation phase of an LLM.
+  // Expectations are
+  //  qkv have same number of heads and hidden dimension (head size).
+  //  qkv are in BSNH format.
+  //            B - batch size but shader only supports batch_size 1.
+  //            S - current sequence length but shader supports only S = 1.
+  //            N - number of heads.
+  //            H - head size or hidden dimension for each qkv head.
+  //  KV cache is stored as BN(total_sequence_length)H
+  //  Attention bias is in BN(total_sequence_length)
+  //
+  //  Expectation is that present_key and present_value contain the past keys and values, since
+  //  we are out of storage buffers a shader can have and both past/present can't be passed.
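The shader assembled below implements the streaming ("online") softmax recurrence over tiles of keys. As a scalar C++ reference for that bookkeeping (a sketch with hypothetical scores; the shader folds the normalization into every tile via denom_ratio, while this reference normalizes once at the end, with the same result):

    #include <cmath>
    #include <cstdio>
    #include <vector>

    int main() {
      // Hypothetical attention scores for one query row and the matching V values.
      std::vector<float> scores = {0.3f, 1.2f, -0.7f, 2.1f};
      std::vector<float> v = {1.0f, 2.0f, 3.0f, 4.0f};
      float running_max = -INFINITY, denom = 0.0f, out = 0.0f;
      for (size_t i = 0; i < scores.size(); ++i) {           // visit scores tile by tile
        float new_max = std::fmax(running_max, scores[i]);
        float correction = std::exp(running_max - new_max);  // rescales earlier terms
        float p = std::exp(scores[i] - new_max);
        denom = denom * correction + p;                      // the "denom_x" update below
        out = out * correction + p * v[i];                   // unnormalized running output
        running_max = new_max;                               // the "prev_max_x" update
      }
      std::printf("attention output = %f\n", out / denom);   // normalize once at the end
      return 0;
    }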
+ shader.AddInput("q", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + shader.AddInput("k", ShaderUsage::UseUniform); + shader.AddInput("v", ShaderUsage::UseUniform); + if (has_attention_bias_) { + shader.AddInput("attention_bias", ShaderUsage::UseUniform); + } + shader.AddOutput("present_key", ShaderUsage::UseUniform); + shader.AddOutput("present_value", ShaderUsage::UseUniform); + shader.AddOutput("output", ShaderUsage::UseUniform); + + uint32_t qkv_hidden_size_in_components = qkv_hidden_size_ / components_; + shader.AdditionalImplementation() << "var x: array;\n" + << "var q_shm: array;\n" + << "var subgroup_scratch: array;\n" + << "var local_max_x: q_element_t;\n" + << "var prev_max_x: q_element_t;\n" + << "var denom_x: q_element_t;\n" + << "var denom_ratio: q_element_t;\n" + << "var attention_scale: q_element_t;\n"; + + // First fill the shared memory with the q values + shader.MainFunctionBody() << "let qkv_hidden_size_in_components:u32 = " << qkv_hidden_size_in_components << ";\n" + << "let n_qkv_hidden_size_in_components:u32 = " << qkv_num_heads_ * qkv_hidden_size_in_components << ";\n" + << "let total_sequence_length:u32 = uniforms.past_sequence_length +1;\n" + << "let tile_size:u32 = TILE_SIZE;\n" + << "var current_head:u32 = workgroup_id.x;\n" + << "var q_index:u32 = local_id.x;\n" + << "var head_offset:u32 = current_head*qkv_hidden_size_in_components;\n" + << "attention_scale = q_element_t(uniforms.attention_scale);\n" + << "let key_offset = uniforms.past_sequence_length * n_qkv_hidden_size_in_components + current_head * qkv_hidden_size_in_components;\n" + << "let value_offset = uniforms.past_sequence_length * "<< qkv_num_heads_ * qkv_hidden_size_ + << " + current_head * " << qkv_hidden_size_ << " ;\n" + << "while (q_index < qkv_hidden_size_in_components) {\n" + << " q_shm[q_index] = q[head_offset + q_index];\n" + << " output[head_offset + q_index] = 0;\n" + << " present_key[key_offset + q_index] = k[q_index];\n" + << " present_value[value_offset + q_index] = v[q_index];\n"; + + if (components_ == 4) { + shader.MainFunctionBody() << " present_value[value_offset + q_index + 1] = v[q_index+1];\n" + << " present_value[value_offset + q_index + 2] = v[q_index+2];\n" + << " present_value[value_offset + q_index + 3] = v[q_index+3];\n"; + } + if (components_ == 2) { + shader.MainFunctionBody() << " present_value[value_offset + q_index + 1] = v[q_index+1];\n"; + } + + shader.MainFunctionBody() << " q_index += tile_size;\n" + << "}\n" + << "if (local_id.x == 0) {\n" + << " prev_max_x = q_element_t(-65504h);\n" + << "}\n" + << "workgroupBarrier();\n"; + + // Go through all the tiles + shader.MainFunctionBody() << "for (var seq_begin = 0u; seq_begin <= uniforms.past_sequence_length; seq_begin += tile_size) {\n"; + + // Compute QKt, assumption is that K is stored as transposed. + shader.MainFunctionBody() << " var k_index:u32 = seq_begin + local_id.x;\n" + << " if (sg_id == 0) {\n" + << " subgroup_scratch[u32(local_id.x/sg_size)] = q_element_t(-65504h);\n" + << " }\n" + << " var x_value_single:q_element_t = 0;" + << " if (k_index < total_sequence_length) {\n" + << " var x_value: q_value_t = q_value_t(0);\n" + << " let key_offset = k_index * n_qkv_hidden_size_in_components + current_head * qkv_hidden_size_in_components;\n" + << " for (var i = 0u; i < qkv_hidden_size_in_components; i++) {\n" + << " x_value += q_shm[i] * present_key[i+key_offset];\n" + << " }\n" + << " x_value_single = " << (components_ == 4 ? 
"x_value.x + x_value.y + x_value.z + x_value.w" : (components_ == 2 ? "x_value.x + x_value.y" : "x_value")) << ";\n" + << " x_value_single = x_value_single * attention_scale;\n"; + if (has_attention_bias_) { + shader.MainFunctionBody() << " let attention_bias_index = k_index + current_head * (uniforms.past_sequence_length+1);\n" + << " x_value_single = x_value_single + attention_bias[attention_bias_index];\n"; + } + shader.MainFunctionBody() << " var sub_group_max_value = subgroupMax(x_value_single);\n" + << " if (sg_id == 0) {\n" + << " subgroup_scratch[u32(local_id.x/sg_size)] = sub_group_max_value;\n" + << " }\n" + << " }\n"; + + // Update phase to merge with previous tile results. + shader.MainFunctionBody() << " workgroupBarrier();\n" + << " if (local_id.x == 0) {\n" + << " local_max_x = prev_max_x;\n" + << " for(var i = 0u; i < u32(TILE_SIZE/sg_size); i++) {\n" + << " local_max_x = max(local_max_x, subgroup_scratch[i]);\n" + << " }\n" + << " }\n" + << " workgroupBarrier();\n" + << " if (k_index < total_sequence_length) {\n" + << " x_value_single = exp(x_value_single - local_max_x);\n" + << " }\n" + << " if (sg_id == 0) {\n" + << " subgroup_scratch[u32(local_id.x/sg_size)] = 0;\n" + << " }\n" + << " workgroupBarrier();\n" + << " var sub_group_sum_value = subgroupAdd(x_value_single);\n" + << " if (sg_id == 0) {\n" + << " subgroup_scratch[u32(local_id.x/sg_size)] = sub_group_sum_value;\n" + << " }\n" + << " workgroupBarrier();\n" + << " if (local_id.x == 0) {\n" + << " var sum_x:q_element_t = 0;\n" + << " for(var i = 0u; i < u32(TILE_SIZE/sg_size); i++) {\n" + << " sum_x = sum_x + subgroup_scratch[i];\n" + << " }\n" + << " var new_denom_x_first_term:q_element_t = denom_x * exp(prev_max_x - local_max_x);\n" + << " denom_x = new_denom_x_first_term + sum_x;\n" + << " denom_ratio = new_denom_x_first_term / denom_x;\n" + << " prev_max_x = local_max_x;\n" + << " }\n" + << " workgroupBarrier();\n" + << " x[local_id.x] = x_value_single / denom_x;\n"; + + // Update O, we are going to switch to parallalism in v_hidden_size dimension, each thread is going to be + // responsible for a single hidden dimension of v. 
+ shader.MainFunctionBody() << " workgroupBarrier();\n" + << " let v_start = seq_begin *" << qkv_hidden_size_ << " + current_head * total_sequence_length *" << qkv_hidden_size_ << ";\n" + << " for (var i = local_id.x; i < " << qkv_hidden_size_ << "; i+=tile_size) {\n" + << " var sum:q_element_t = 0.0;\n" + << " for (var t = 0u; t < tile_size ; t++) {\n" + << " if (seq_begin + t < total_sequence_length) {\n" + << " var v_index:u32 = v_start + t * " << qkv_hidden_size_ << " + i;\n" + << " sum += (x[t] * present_value[v_index]);\n" + << " }\n" + << " }\n" + << " output[i] = output[i] * denom_ratio + sum;\n" + << " }\n"; + + shader.MainFunctionBody() << "}\n"; return Status::OK(); } +Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias, + Tensor* output, const Tensor* past_key, Tensor* present_key, const Tensor* past_value, Tensor* present_value, + AttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context) { + ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, parameters.past_sequence_length, parameters.total_sequence_length)); + return ApplyAttention(Q, K, V, attention_bias, past_key, past_value, output, present_key, + present_value, parameters, context, true); + // constexpr int flash_attention_tile_length = 64; + // const float attention_scale = parameters.scale == 0.0f ? 1.f / sqrt(static_cast(parameters.head_size)) + // : parameters.scale; + // const bool has_attention_bias = attention_bias != nullptr; + // const int components = parameters.head_size % 4 == 0 ? 4 : (parameters.head_size % 2 == 0 ? 2 : 1); + // FlashAttentionProgram program{"FlashAttentionProgram", has_attention_bias, flash_attention_tile_length, + // parameters.v_hidden_size, parameters.head_size, components}; + + // program.AddInputs({{Q, ProgramTensorMetadataDependency::Type, components}, + // {K, ProgramTensorMetadataDependency::Type, components}, + // {V, ProgramTensorMetadataDependency::Type}}); + // if (has_attention_bias) { + // program.AddInput({attention_bias, ProgramTensorMetadataDependency::Type}); + // } + + // program.AddOutput({present_key, ProgramTensorMetadataDependency::Type, components}); + // program.AddOutput({present_value, ProgramTensorMetadataDependency::Type}); + // program.AddOutput({output, ProgramTensorMetadataDependency::Type}); + + // program.SetDispatchGroupSize(parameters.num_heads) + // .SetWorkgroupSize(flash_attention_tile_length) + // .CacheHint(std::to_string(has_attention_bias)+std::to_string(flash_attention_tile_length)) + // .AddUniformVariables({{static_cast(attention_scale)}, + // {static_cast(parameters.past_sequence_length)}}) + // .SetOverridableConstants({{static_cast(flash_attention_tile_length)}}); + + // return context.RunProgram(program); +} + MultiHeadAttention::MultiHeadAttention(const OpKernelInfo& info) : WebGpuKernel(info) { int64_t num_heads = 0; @@ -457,6 +713,13 @@ Status MultiHeadAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& Tensor* present_key = context.Output(1, present_shape); Tensor* present_value = context.Output(2, present_shape); + if (parameters.sequence_length == 1 && bias == nullptr && parameters.kv_sequence_length == 1 && + present_key != nullptr && present_value != nullptr && present_key->SizeInBytes() > 0 && + present_value->SizeInBytes() > 0) { + return ApplyFlashAttention(query, key, value, attention_bias, output, past_key, present_key, past_value, + present_value, parameters, context); + } + 
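The gate added above routes only the token-generation case to the flash path: one new query token, one new KV token, no QKV bias, and a non-empty present KV cache (patch 03 later tightens this further with parameters.head_size % 4 == 0). A hypothetical standalone predicate expressing the same condition, with invented names, purely for illustration:

    // Sketch: mirrors the dispatch condition above for a decode-style call.
    bool UseFlashDecodePath(int sequence_length, int kv_sequence_length, bool has_qkv_bias,
                            bool has_present_key, bool has_present_value) {
      return sequence_length == 1 && kv_sequence_length == 1 && !has_qkv_bias &&
             has_present_key && has_present_value;
    }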
TensorShapeVector q_new_dims({parameters.batch_size, parameters.num_heads, parameters.sequence_length, parameters.head_size}); TensorShape q_new_shape(q_new_dims); diff --git a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.h b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.h index 36803e3027b4c..56d23b00badec 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.h +++ b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.h @@ -33,8 +33,8 @@ class TransferBSDToBNSHProgram final : public Program class AttentionProbsProgram final : public Program { public: AttentionProbsProgram(const std::string& kernel_name, bool feed_past_key, bool has_present_key, - bool has_attention_bias, int tile_size, int components) - : Program{kernel_name}, feed_past_key_(feed_past_key), has_present_key_(has_present_key), has_attention_bias_(has_attention_bias), tile_size_(tile_size), components_(components) { + bool has_attention_bias, int tile_size, int components, bool fa_variant = false) + : Program{kernel_name}, feed_past_key_(feed_past_key), has_present_key_(has_present_key), has_attention_bias_(has_attention_bias), tile_size_(tile_size), components_(components), fa_variant_(fa_variant) { } Status GenerateShaderCode(ShaderHelper& sh) const override; @@ -55,6 +55,7 @@ class AttentionProbsProgram final : public Program { bool has_attention_bias_; int tile_size_; int components_; + bool fa_variant_; }; class InPlaceSoftmaxProgram final : public Program { @@ -74,10 +75,27 @@ class InPlaceSoftmaxProgram final : public Program { int components_; }; +class CopyKVCacheProgram final : public Program { + public: + CopyKVCacheProgram(const std::string& kernel_name) + : Program{kernel_name} { + } + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"qkv_hidden_size", ProgramUniformVariableDataType::Uint32}, + {"past_sequence_length", ProgramUniformVariableDataType::Uint32}, + {"total_sequence_length", ProgramUniformVariableDataType::Uint32}); + + private: + int work_group_size_; + int components_; +}; + class VxAttentionScoreProgram final : public Program { public: - VxAttentionScoreProgram(const std::string& kernel_name, bool feed_past_value, bool has_present_value, int tile_size) - : Program{kernel_name}, feed_past_value_(feed_past_value), has_present_value_(has_present_value), tile_size_(tile_size) { + VxAttentionScoreProgram(const std::string& kernel_name, bool feed_past_value, bool has_present_value, int tile_size, bool fa_variant = false) + : Program{kernel_name}, feed_past_value_(feed_past_value), has_present_value_(has_present_value), tile_size_(tile_size), fa_variant_(fa_variant) { } Status GenerateShaderCode(ShaderHelper& sh) const override; @@ -96,6 +114,35 @@ class VxAttentionScoreProgram final : public Program { bool feed_past_value_; bool has_present_value_; int tile_size_; + bool fa_variant_; +}; + +class FlashAttentionProgram final : public Program { + public: + FlashAttentionProgram(const std::string& kernel_name, bool has_attention_bias, + int tile_size, int qkv_hidden_size, int qkv_num_heads, int components) + : Program{kernel_name}, + has_attention_bias_(has_attention_bias), + tile_size_(tile_size), + qkv_hidden_size_(qkv_hidden_size), + qkv_num_heads_(qkv_num_heads), + components_(components) { + } + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"attention_scale", ProgramUniformVariableDataType::Float32}, + {"past_sequence_length", 
ProgramUniformVariableDataType::Uint32},
+                                          {"component_indexes_per_token", ProgramUniformVariableDataType::Uint32});
+
+  WEBGPU_PROGRAM_DEFINE_OVERRIDABLE_CONSTANTS({"TILE_SIZE", ProgramConstantDataType::Uint32});
+
+ private:
+  bool has_attention_bias_;
+  int tile_size_;
+  int qkv_hidden_size_;
+  int qkv_num_heads_;
+  int components_;
 };
 
 class MultiHeadAttention final : public WebGpuKernel {
diff --git a/onnxruntime/core/providers/webgpu/shader_helper.cc b/onnxruntime/core/providers/webgpu/shader_helper.cc
index 5685494556248..2215cc2a7cd20 100644
--- a/onnxruntime/core/providers/webgpu/shader_helper.cc
+++ b/onnxruntime/core/providers/webgpu/shader_helper.cc
@@ -62,7 +62,9 @@ Status ShaderHelper::Init() {
                "fn main(@builtin(global_invocation_id) global_id : vec3<u32>,\n"
                "        @builtin(workgroup_id) workgroup_id : vec3<u32>,\n"
                "        @builtin(local_invocation_index) local_idx : u32,\n"
-               "        @builtin(local_invocation_id) local_id : vec3<u32>";
+               "        @builtin(local_invocation_id) local_id : vec3<u32>,\n"
+               "        @builtin(subgroup_invocation_id) sg_id : u32,\n"
+               "        @builtin(subgroup_size) sg_size : u32";
   if (!is_1d_dispatch) {
     body_ss_ << ",\n"
                 "        @builtin(num_workgroups) num_workgroups : vec3<u32>";
 
From ed8bf5ddc02dce438c386eb959dc04d2459e665e Mon Sep 17 00:00:00 2001
From: Sushanth Rajasankar
Date: Mon, 11 Nov 2024 13:32:19 -0800
Subject: [PATCH 02/25] The new Copy KV Cache works.

---
 .../webgpu/bert/multihead_attention.cc        | 158 +++++++-----------
 .../webgpu/bert/multihead_attention.h         |  34 ++--
 2 files changed, 81 insertions(+), 111 deletions(-)

diff --git a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc
index d3007e7ddd889..1c3f5835093bb 100644
--- a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc
+++ b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc
@@ -164,72 +164,6 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const {
   return Status::OK();
 }
 
-Status CopyKVCacheProgram::GenerateShaderCode(ShaderHelper& shader) const {
-  // This shader works only for a limited case of current_seq_len = 1, that is, the generation phase of an LLM.
-  // Expectations are
-  //  qkv have same number of heads and hidden dimension (head size).
-  //  qkv are in BSNH format.
-  //            B - batch size but shader only supports batch_size 1.
-  //            S - current sequence length but shader supports only S = 1.
-  //            N - number of heads.
-  //            H - head size or hidden dimension for each qkv head.
-  //  KV cache is stored as BN(total_sequence_length)H
-  //  Attention bias is in BN(total_sequence_length)
-  //
-  //  Expectation is that present_key and present_value contain the past keys and values, since
-  //  we are out of storage buffers a shader can have and both past/present can't be passed.
- - shader.AddInput("k", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); - shader.AddInput("past_k", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); - shader.AddInput("v", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); - shader.AddInput("past_v", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); - shader.AddOutput("present_key", ShaderUsage::UseUniform); - shader.AddOutput("present_value", ShaderUsage::UseUniform); - - shader.MainFunctionBody() << "let head_index = local_id.x;\n" - << "let k_index = workgroup_id.x;\n" - << "if (k_index < uniforms.past_sequence_length) {\n" - << " let present_offset = head_index * uniforms.total_sequence_length * uniforms.qkv_hidden_size + k_index * uniforms.qkv_hidden_size;\n" - << " let past_offset = head_index * uniforms.past_sequence_length * uniforms.qkv_hidden_size + k_index * uniforms.qkv_hidden_size;\n" - << " for (var i=0u; i(parameters.head_size)}, - {static_cast(past_sequence_length)}, - {static_cast(total_sequence_length)}}); - - return context.RunProgram(program); -} - Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int output_count, const Tensor* Q, const Tensor* K, const Tensor* past_key, const Tensor* attention_bias, Tensor* probs, Tensor* present_key, AttentionParameters& parameters, int past_sequence_length, int total_sequence_length, bool fa_variant = false) { @@ -475,6 +409,71 @@ Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const T return Status::OK(); } +Status CopyKVCacheProgram::GenerateShaderCode(ShaderHelper& shader) const { + // Expectations are + // qkv have same number of heads and hidden dimension (head size). + // qkv are in BSNH format. + // B - batch size but shader only supports batch_size 1. + // S - current sequence length but shader supports only S = 1. + // N - number of heads. + // H - head size or hidden dimension for each qkv head. 
+ // KV cache is stored as BN(total_sequence_length)H + // Attention bias is in BN(total_sequence_length) + shader.AddInput("key", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + shader.AddInput("value", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + shader.AddInput("past_key", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + shader.AddInput("past_value", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + shader.AddOutput("present_key", ShaderUsage::UseUniform); + shader.AddOutput("present_value", ShaderUsage::UseUniform); + + shader.MainFunctionBody() << "let headIdx = workgroup_id.z;\n" + << "let kIdx = workgroup_id.x;\n" + << "let presentKeyOffset = headIdx * num_workgroups.x * uniforms.vectorized_head_size + (kIdx)*uniforms.vectorized_head_size;\n" + << "if (kIdx < uniforms.past_sequence_length) {\n" + << " let pastKeyOffset = headIdx * uniforms.past_sequence_length * uniforms.vectorized_head_size + (kIdx)*uniforms.vectorized_head_size;\n" + << " for (var w: u32 = 0u; w < uniforms.vectorized_head_size; w ++) {\n" + << " present_key[presentKeyOffset+w] = past_key[pastKeyOffset+w];\n" + << " present_value[presentKeyOffset+w] = past_value[pastKeyOffset+w];\n" + << " }\n" + << "}\n" + << "else if (kIdx >= uniforms.past_sequence_length) {\n" + << " let nkIdx = kIdx - uniforms.past_sequence_length;\n" + << " // Assumes kv have BSNH layout. num_workgroups.z is the num_head as per the dispatch requirement.\n" + << " let nOffset = nkIdx * uniforms.vectorized_head_size * num_workgroups.z + headIdx*uniforms.vectorized_head_size;\n" + << " // Assumes kv have BNSH layout.\n" + << " // let nOffset = headIdx * uniforms.kv_sequence_length * uniforms.vectorized_head_size + nkIdx * uniforms.vectorized_head_size;\n" + << " for (var w: u32 = 0u; w < uniforms.vectorized_head_size; w ++) {\n" + << " present_key[presentKeyOffset+w] = key[nOffset+w];\n" + << " present_value[presentKeyOffset+w] = value[nOffset+w];\n" + << " }\n" + << "}\n"; + + return Status::OK(); +} + +Status CopyKVCache(onnxruntime::webgpu::ComputeContext& context, AttentionParameters& parameters, + const Tensor* K, const Tensor* past_key, Tensor* present_key, + const Tensor* V, const Tensor* past_value, Tensor* present_value, + int past_sequence_length, int total_sequence_length) { + + const int components = parameters.head_size % 4 == 0 ? 4 : (parameters.head_size % 2 == 0 ? 
2 : 1); + CopyKVCacheProgram program{"CopyKVCache", components}; + program.AddInputs({{K, ProgramTensorMetadataDependency::TypeAndRank, components}, + {V, ProgramTensorMetadataDependency::TypeAndRank, components}, + {past_key, ProgramTensorMetadataDependency::TypeAndRank, components}, + {past_value, ProgramTensorMetadataDependency::TypeAndRank, components}}); + program.AddOutputs({{present_key, ProgramTensorMetadataDependency::Rank, components}, + {present_value, ProgramTensorMetadataDependency::Rank, components}}); + + program.SetDispatchGroupSize(total_sequence_length, 1, parameters.num_heads) + .SetWorkgroupSize(1) + .AddUniformVariables({{static_cast(past_sequence_length)}, + {static_cast(parameters.kv_sequence_length)}, + {static_cast(parameters.head_size/ components)},}); + + return context.RunProgram(program); +} + Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const { // This shader works only for a limited case of current_seq_len = 1, that is the generation phase of an LLM. // Expectations are @@ -630,33 +629,6 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, parameters.past_sequence_length, parameters.total_sequence_length)); return ApplyAttention(Q, K, V, attention_bias, past_key, past_value, output, present_key, present_value, parameters, context, true); - // constexpr int flash_attention_tile_length = 64; - // const float attention_scale = parameters.scale == 0.0f ? 1.f / sqrt(static_cast(parameters.head_size)) - // : parameters.scale; - // const bool has_attention_bias = attention_bias != nullptr; - // const int components = parameters.head_size % 4 == 0 ? 4 : (parameters.head_size % 2 == 0 ? 
2 : 1); - // FlashAttentionProgram program{"FlashAttentionProgram", has_attention_bias, flash_attention_tile_length, - // parameters.v_hidden_size, parameters.head_size, components}; - - // program.AddInputs({{Q, ProgramTensorMetadataDependency::Type, components}, - // {K, ProgramTensorMetadataDependency::Type, components}, - // {V, ProgramTensorMetadataDependency::Type}}); - // if (has_attention_bias) { - // program.AddInput({attention_bias, ProgramTensorMetadataDependency::Type}); - // } - - // program.AddOutput({present_key, ProgramTensorMetadataDependency::Type, components}); - // program.AddOutput({present_value, ProgramTensorMetadataDependency::Type}); - // program.AddOutput({output, ProgramTensorMetadataDependency::Type}); - - // program.SetDispatchGroupSize(parameters.num_heads) - // .SetWorkgroupSize(flash_attention_tile_length) - // .CacheHint(std::to_string(has_attention_bias)+std::to_string(flash_attention_tile_length)) - // .AddUniformVariables({{static_cast(attention_scale)}, - // {static_cast(parameters.past_sequence_length)}}) - // .SetOverridableConstants({{static_cast(flash_attention_tile_length)}}); - - // return context.RunProgram(program); } MultiHeadAttention::MultiHeadAttention(const OpKernelInfo& info) diff --git a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.h b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.h index 56d23b00badec..5dc567b57f424 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.h +++ b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.h @@ -74,24 +74,6 @@ class InPlaceSoftmaxProgram final : public Program { int work_group_size_; int components_; }; - -class CopyKVCacheProgram final : public Program { - public: - CopyKVCacheProgram(const std::string& kernel_name) - : Program{kernel_name} { - } - - Status GenerateShaderCode(ShaderHelper& sh) const override; - - WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"qkv_hidden_size", ProgramUniformVariableDataType::Uint32}, - {"past_sequence_length", ProgramUniformVariableDataType::Uint32}, - {"total_sequence_length", ProgramUniformVariableDataType::Uint32}); - - private: - int work_group_size_; - int components_; -}; - class VxAttentionScoreProgram final : public Program { public: VxAttentionScoreProgram(const std::string& kernel_name, bool feed_past_value, bool has_present_value, int tile_size, bool fa_variant = false) @@ -117,6 +99,22 @@ class VxAttentionScoreProgram final : public Program { bool fa_variant_; }; +class CopyKVCacheProgram final : public Program { + public: + CopyKVCacheProgram(const std::string& kernel_name, int components) + : Program{kernel_name}, components_(components) { + } + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"past_sequence_length", ProgramUniformVariableDataType::Uint32}, + {"kv_sequence_length", ProgramUniformVariableDataType::Uint32}, + {"vectorized_head_size", ProgramUniformVariableDataType::Uint32}); + + private: + int components_; +}; + class FlashAttentionProgram final : public Program { public: FlashAttentionProgram(const std::string& kernel_name, bool has_attention_bias, From 75aa49d120b6d81deeac4b74c63336ee43bbeb9a Mon Sep 17 00:00:00 2001 From: Sushanth Rajasankar Date: Fri, 15 Nov 2024 16:52:56 -0800 Subject: [PATCH 03/25] Add flash attention --- .../webgpu/bert/multihead_attention.cc | 361 +++++++++++------- .../webgpu/bert/multihead_attention.h | 26 +- 2 files changed, 247 insertions(+), 140 deletions(-) diff --git 
a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc index 1c3f5835093bb..324b14f69a0a0 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc @@ -467,15 +467,15 @@ Status CopyKVCache(onnxruntime::webgpu::ComputeContext& context, AttentionParame program.SetDispatchGroupSize(total_sequence_length, 1, parameters.num_heads) .SetWorkgroupSize(1) + .CacheHint(std::to_string(components)) .AddUniformVariables({{static_cast(past_sequence_length)}, {static_cast(parameters.kv_sequence_length)}, - {static_cast(parameters.head_size/ components)},}); + {static_cast(parameters.head_size/ components)}}); return context.RunProgram(program); } Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const { - // This shader works only for a limited case of current_seq_len = 1, that is the generation phase of an LLM. // Expectations are // qkv have same number of heads and hidden dimension (head size). // qkv are in BSNH format. @@ -484,141 +484,222 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const { // N - number of heads. // H - head size or hidden dimension for each qkv head. // KV cache is stored as BN(total_sequence_length)H - // Attention bias is in BN(total_sequence_length) + // Attention bias is in BN(new_sequence_length)(total_sequence_length) // // Expectation is that present_key, and present_value contain past key and values since // we are out of storage buffers a shader can have and both past/present cant be passed. + // The hidden size of each q head should be a multiple of 4 because shader uses vectorized loads. + constexpr int vectorization_size = 4; shader.AddInput("q", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); - shader.AddInput("k", ShaderUsage::UseUniform); - shader.AddInput("v", ShaderUsage::UseUniform); + shader.AddInput("present_key", ShaderUsage::UseUniform); + shader.AddInput("present_value", ShaderUsage::UseUniform); if (has_attention_bias_) { shader.AddInput("attention_bias", ShaderUsage::UseUniform); } - shader.AddOutput("present_key", ShaderUsage::UseUniform); - shader.AddOutput("present_value", ShaderUsage::UseUniform); shader.AddOutput("output", ShaderUsage::UseUniform); - uint32_t qkv_hidden_size_in_components = qkv_hidden_size_ / components_; - shader.AdditionalImplementation() << "var x: array;\n" - << "var q_shm: array;\n" - << "var subgroup_scratch: array;\n" - << "var local_max_x: q_element_t;\n" - << "var prev_max_x: q_element_t;\n" - << "var denom_x: q_element_t;\n" - << "var denom_ratio: q_element_t;\n" - << "var attention_scale: q_element_t;\n"; - - // First fill the shared memory with the q values - shader.MainFunctionBody() << "let qkv_hidden_size_in_components:u32 = " << qkv_hidden_size_in_components << ";\n" - << "let n_qkv_hidden_size_in_components:u32 = " << qkv_num_heads_ * qkv_hidden_size_in_components << ";\n" - << "let total_sequence_length:u32 = uniforms.past_sequence_length +1;\n" - << "let tile_size:u32 = TILE_SIZE;\n" - << "var current_head:u32 = workgroup_id.x;\n" - << "var q_index:u32 = local_id.x;\n" - << "var head_offset:u32 = current_head*qkv_hidden_size_in_components;\n" - << "attention_scale = q_element_t(uniforms.attention_scale);\n" - << "let key_offset = uniforms.past_sequence_length * n_qkv_hidden_size_in_components + current_head * qkv_hidden_size_in_components;\n" - << "let value_offset = 
uniforms.past_sequence_length * "<< qkv_num_heads_ * qkv_hidden_size_ - << " + current_head * " << qkv_hidden_size_ << " ;\n" - << "while (q_index < qkv_hidden_size_in_components) {\n" - << " q_shm[q_index] = q[head_offset + q_index];\n" - << " output[head_offset + q_index] = 0;\n" - << " present_key[key_offset + q_index] = k[q_index];\n" - << " present_value[value_offset + q_index] = v[q_index];\n"; - - if (components_ == 4) { - shader.MainFunctionBody() << " present_value[value_offset + q_index + 1] = v[q_index+1];\n" - << " present_value[value_offset + q_index + 2] = v[q_index+2];\n" - << " present_value[value_offset + q_index + 3] = v[q_index+3];\n"; - } - if (components_ == 2) { - shader.MainFunctionBody() << " present_value[value_offset + q_index + 1] = v[q_index+1];\n"; - } + // SUBGROUP_SIZE has to be the same as sg_size. For intel this will be 8. + // TILE_SIZE is the number of groups sharing the k_tile. + // TILE_SIZE has to be <= SUBGROUP_SIZE. Ideal perf of computeSoftMax is when + // TILE_SIZE == SUBGROUP_SIZE. This is a sperate constant from SUBGROUP_SIZE + // because SUBGROUP_SIZE * TILE_SIZE has to be <= 256 as per webgpu + // gpu limits. For Intel this TILE_SIZE will be 8. + shader.AdditionalImplementation() << "const SUBGROUP_SIZE: u32 = " << subgroup_size_ << ";" + << "const TILE_SIZE: u32 = " << tile_size_ << ";" + << "const VECTOR_SIZE: u32 = " << vectorization_size << ";" + << "const QKV_HEAD_SIZE: u32 = " << qkv_head_size_ << ";" + << "const QKV_HEAD_VECTORIZED_SIZE: u32 = QKV_HEAD_SIZE / VECTOR_SIZE;" + << "const NUM_HEADS: u32 = " << qkv_num_heads_ << ";" + << "const MIN_VALUE : q_element_t = -6504.0h;"; + + // Best to keep SHM usage per workgroup < 8KB. 4KB is the limit on a 48EU tigerlake + // GPU afterwhich workgroups will be unscheduled to make space for memory. + shader.AdditionalImplementation() << "var q_tile : array, TILE_SIZE>; // 96 * 8 * 2 = 1.5KB." + << "var k_tile : array, TILE_SIZE>; // 96 * 8 * 2 = 1.5KB." + << "var v_tile : array, TILE_SIZE>; // 96 * 8 * 2 = 1.5KB." + << "var o_tile : array, TILE_SIZE>; // 96 * 8 * 2 = 1.5KB." + << "var qk_tile : array, TILE_SIZE>; // 8 * 2 * 8 = 128" + << "var max_tile : array; // 2 * 8 = 16" + << "var denom_tile : array; // 2 * 8 = 16" + << "var o_ratio : array; // 2 * 8 = 16"; + + shader.AdditionalImplementation() << R"HELPER_FN( +fn loadq(slot: u32, q_idx_global : u32, head_idx: u32, sg_id : u32, sg_size : u32) +{ + if (q_idx_global >= uniforms.new_sequence_length) { + return; + } + // Stored as float16[batch_size,sequence_length,3072] the inputs as per onnx MHA + // This is the layout if TransferBSDToBNSH has not been run. + // let offset = q_idx_global * (QKV_HEAD_VECTORIZED_SIZE) * NUM_HEADS + QKV_HEAD_VECTORIZED_SIZE * head_idx; + // Stored as BNSH - which is what webgpu uses after TransferBSDToBNSH has been run. 
+ let offset = head_idx * uniforms.new_sequence_length * QKV_HEAD_VECTORIZED_SIZE + q_idx_global * QKV_HEAD_VECTORIZED_SIZE; + for (var idx:u32 = sg_id; idx < QKV_HEAD_VECTORIZED_SIZE; idx+= sg_size) + { + var value = q[idx+offset]; + q_tile[slot][idx] = value; + } +} - shader.MainFunctionBody() << " q_index += tile_size;\n" - << "}\n" - << "if (local_id.x == 0) {\n" - << " prev_max_x = q_element_t(-65504h);\n" - << "}\n" - << "workgroupBarrier();\n"; +fn debugKTile() -> q_value_t +{ + var sum_value = q_value_t(0); + for (var qidx:u32 = 0; qidx < TILE_SIZE; qidx++) + { + for (var idx:u32 = 0; idx < QKV_HEAD_VECTORIZED_SIZE; idx++) + { + var value = k_tile[qidx][idx]; + sum_value += value; + } + } + return sum_value; +} - // Go through all the tiles - shader.MainFunctionBody() << "for (var seq_begin = 0u; seq_begin <= uniforms.past_sequence_length; seq_begin += tile_size) {\n"; +fn loadk(slot: u32, k_idx_global : u32, head_idx: u32) +{ + if (k_idx_global >= uniforms.present_sequence_length) { + return; + } + + // Stored as float16[batch_size,num_heads,present_sequence_length,96] + let offset = head_idx * uniforms.present_sequence_length * QKV_HEAD_VECTORIZED_SIZE + k_idx_global * QKV_HEAD_VECTORIZED_SIZE; + for (var idx:u32 = 0; idx < QKV_HEAD_VECTORIZED_SIZE; idx++) + { + var value = present_key[idx+offset]; + k_tile[slot][idx] = value; + } +} - // Compute QKt, assumption is that K is stored as transposed. - shader.MainFunctionBody() << " var k_index:u32 = seq_begin + local_id.x;\n" - << " if (sg_id == 0) {\n" - << " subgroup_scratch[u32(local_id.x/sg_size)] = q_element_t(-65504h);\n" - << " }\n" - << " var x_value_single:q_element_t = 0;" - << " if (k_index < total_sequence_length) {\n" - << " var x_value: q_value_t = q_value_t(0);\n" - << " let key_offset = k_index * n_qkv_hidden_size_in_components + current_head * qkv_hidden_size_in_components;\n" - << " for (var i = 0u; i < qkv_hidden_size_in_components; i++) {\n" - << " x_value += q_shm[i] * present_key[i+key_offset];\n" - << " }\n" - << " x_value_single = " << (components_ == 4 ? "x_value.x + x_value.y + x_value.z + x_value.w" : (components_ == 2 ? "x_value.x + x_value.y" : "x_value")) << ";\n" - << " x_value_single = x_value_single * attention_scale;\n"; - if (has_attention_bias_) { - shader.MainFunctionBody() << " let attention_bias_index = k_index + current_head * (uniforms.past_sequence_length+1);\n" - << " x_value_single = x_value_single + attention_bias[attention_bias_index];\n"; - } - shader.MainFunctionBody() << " var sub_group_max_value = subgroupMax(x_value_single);\n" - << " if (sg_id == 0) {\n" - << " subgroup_scratch[u32(local_id.x/sg_size)] = sub_group_max_value;\n" - << " }\n" - << " }\n"; - - // Update phase to merge with previous tile results. 
- shader.MainFunctionBody() << " workgroupBarrier();\n" - << " if (local_id.x == 0) {\n" - << " local_max_x = prev_max_x;\n" - << " for(var i = 0u; i < u32(TILE_SIZE/sg_size); i++) {\n" - << " local_max_x = max(local_max_x, subgroup_scratch[i]);\n" - << " }\n" - << " }\n" - << " workgroupBarrier();\n" - << " if (k_index < total_sequence_length) {\n" - << " x_value_single = exp(x_value_single - local_max_x);\n" - << " }\n" - << " if (sg_id == 0) {\n" - << " subgroup_scratch[u32(local_id.x/sg_size)] = 0;\n" - << " }\n" - << " workgroupBarrier();\n" - << " var sub_group_sum_value = subgroupAdd(x_value_single);\n" - << " if (sg_id == 0) {\n" - << " subgroup_scratch[u32(local_id.x/sg_size)] = sub_group_sum_value;\n" - << " }\n" - << " workgroupBarrier();\n" - << " if (local_id.x == 0) {\n" - << " var sum_x:q_element_t = 0;\n" - << " for(var i = 0u; i < u32(TILE_SIZE/sg_size); i++) {\n" - << " sum_x = sum_x + subgroup_scratch[i];\n" - << " }\n" - << " var new_denom_x_first_term:q_element_t = denom_x * exp(prev_max_x - local_max_x);\n" - << " denom_x = new_denom_x_first_term + sum_x;\n" - << " denom_ratio = new_denom_x_first_term / denom_x;\n" - << " prev_max_x = local_max_x;\n" - << " }\n" - << " workgroupBarrier();\n" - << " x[local_id.x] = x_value_single / denom_x;\n"; - - // Update O, we are going to switch to parallalism in v_hidden_size dimension, each thread is going to be - // responsible for a single hidden dimension of v. - shader.MainFunctionBody() << " workgroupBarrier();\n" - << " let v_start = seq_begin *" << qkv_hidden_size_ << " + current_head * total_sequence_length *" << qkv_hidden_size_ << ";\n" - << " for (var i = local_id.x; i < " << qkv_hidden_size_ << "; i+=tile_size) {\n" - << " var sum:q_element_t = 0.0;\n" - << " for (var t = 0u; t < tile_size ; t++) {\n" - << " if (seq_begin + t < total_sequence_length) {\n" - << " var v_index:u32 = v_start + t * " << qkv_hidden_size_ << " + i;\n" - << " sum += (x[t] * present_value[v_index]);\n" - << " }\n" - << " }\n" - << " output[i] = output[i] * denom_ratio + sum;\n" - << " }\n"; - - shader.MainFunctionBody() << "}\n"; +fn loadv(slot: u32, v_idx_global : u32, head_idx: u32) +{ + if (v_idx_global >= uniforms.present_sequence_length) { + return; + } + + // Stored as float16[batch_size,num_heads,present_sequence_length,96] + let offset = head_idx * uniforms.present_sequence_length * QKV_HEAD_VECTORIZED_SIZE + v_idx_global * QKV_HEAD_VECTORIZED_SIZE; + for (var idx:u32 = 0; idx < QKV_HEAD_VECTORIZED_SIZE; idx ++) + { + v_tile[slot][idx] = present_value[idx+offset]; + } +} + +fn loadAttentionBias(qtile_row: u32, q_idx_global : u32, k_col: u32, k_idx_global : u32, head_idx: u32) +{ + // Stored as float16[batch_size,num_heads,new_seq_length,total_sequence_length] + if (q_idx_global >= uniforms.new_sequence_length || k_idx_global >= uniforms.present_sequence_length) { + qk_tile[qtile_row][k_col] = 0.0; + return; + } + let offset = head_idx * uniforms.new_sequence_length * uniforms.present_sequence_length + q_idx_global * uniforms.present_sequence_length + k_idx_global; + qk_tile[qtile_row][k_col] = attention_bias[offset]; +} + +fn writeo(slot: u32, o_idx_global : u32, head_idx: u32, sg_id : u32, sg_size : u32) +{ + if (o_idx_global >= uniforms.new_sequence_length) { + return; + } + // Stored as float16[batch_size,sequence_length,3072] + let offset = o_idx_global * NUM_HEADS * QKV_HEAD_VECTORIZED_SIZE + head_idx * QKV_HEAD_VECTORIZED_SIZE; + for (var idx:u32 = sg_id; idx < QKV_HEAD_VECTORIZED_SIZE; idx += sg_size) + { + let value = 
o_tile[slot][idx]; + output[offset+idx] = value; + } +} + +fn computeDotProduct(q_idx: u32, k_idx: u32, sg_id: u32, sg_size : u32) +{ + var sum:vec4 = q_value_t(0, 0, 0, 0); + for (var idx = sg_id; idx < QKV_HEAD_VECTORIZED_SIZE; idx+= sg_size) + { + var result = q_tile[q_idx][idx]*k_tile[k_idx][idx]; + sum += subgroupAdd(result); + } + if (sg_id == 0) + { + let single_sum : q_element_t = sum.x + sum.y + sum.z + sum.w; + let sqrt_dk = q_element_t(uniforms.alpha); + let value = single_sum * sqrt_dk; + qk_tile[q_idx][k_idx] += value; + } +} + +fn computeSoftMax(q_idx: u32, sg_id:u32) +{ + let x = qk_tile[q_idx][sg_id]; + var max_value = subgroupMax(x); + max_value = max(max_tile[q_idx], max_value); + let sub = x - max_value; + let value = exp(sub); + let sum = subgroupAdd(value); + + // Compute lhs term of update di prime and the compute di prime. + let dleft = denom_tile[q_idx] * exp(max_tile[q_idx]-max_value); + let d = dleft + sum; + qk_tile[q_idx][sg_id] = value / d; + if (sg_id == 0) + { + max_tile[q_idx] = max_value; + denom_tile[q_idx] = d; + o_ratio[q_idx] = dleft / d; + } +} + +fn computeO(q_idx: u32, sg_id:u32) +{ + for (var i:u32 = 0; i < QKV_HEAD_VECTORIZED_SIZE; i++) + { + let attn = qk_tile[q_idx][sg_id]; + let val = v_tile[sg_id][i]; + let intermediate = attn * val; + let sum = subgroupAdd(intermediate); + if (sg_id == 0) + { + let o_ratio = o_ratio[q_idx]; + let old_o = o_tile[q_idx][i]; + let new_o = ( o_ratio * old_o) + sum; + o_tile[q_idx][i] = new_o; + } + } +})HELPER_FN"; + +// Shader is designed to be dispatched as Dispatch(num_heads, present_seq_length / TILE_SIZE, 1) +// QKV_HEAD_VECTORIZED_SIZE % sg_id == 0 for loadq, loadk and computeDotProduct to work right. + + shader.MainFunctionBody() << R"MAIN_FN( +let head_idx = workgroup_id.x; +// Split the composite workgroup id into actual y and subgroup id. +let q_tile_row = u32(local_idx / sg_size); + +let q_idx_global = workgroup_id.y * TILE_SIZE + q_tile_row; +// Each invocation (q_tile_row) gets x threads (subgroup threads) and is responsible for 1 query. +loadq(q_tile_row, q_idx_global, head_idx, sg_id, sg_size); +max_tile[sg_id] = MIN_VALUE; + +for(var k_start = 0u; k_start < uniforms.present_sequence_length; k_start+=TILE_SIZE) +{ + if (k_start+sg_id < uniforms.present_sequence_length) { + loadk(sg_id, k_start+sg_id, head_idx); + loadv(sg_id, k_start+sg_id, head_idx); + loadAttentionBias(q_tile_row, q_idx_global, sg_id, k_start+sg_id, head_idx); + } + workgroupBarrier(); + // Do k_idx + k_start <= q_idx_global if we want only look past. 
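+    // A causal variant of the loop below (not in this patch) would bound k by the query position,
+    //   for (var k_idx = 0u; k_idx < TILE_SIZE && k_idx + k_start <= q_idx_global; k_idx++)
+    // so that each query only attends to keys at or before its own position.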
+ for (var k_idx = 0u; k_idx < TILE_SIZE && k_idx + k_start <= uniforms.present_sequence_length; k_idx++) + { + computeDotProduct(q_tile_row, k_idx, sg_id, sg_size); + } + if (sg_id < TILE_SIZE && sg_id + k_start < uniforms.present_sequence_length) + { + computeSoftMax(q_tile_row, sg_id); + computeO(q_tile_row, sg_id); + } +} +workgroupBarrier(); +writeo(q_tile_row, q_idx_global, head_idx, sg_id, sg_size); +)MAIN_FN"; return Status::OK(); } @@ -627,8 +708,32 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co Tensor* output, const Tensor* past_key, Tensor* present_key, const Tensor* past_value, Tensor* present_value, AttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context) { ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, parameters.past_sequence_length, parameters.total_sequence_length)); - return ApplyAttention(Q, K, V, attention_bias, past_key, past_value, output, present_key, - present_value, parameters, context, true); + // return ApplyAttention(Q, K, V, attention_bias, past_key, past_value, output, present_key, + // present_value, parameters, context, true); + constexpr int subgroup_size = 32; + constexpr int tile_size = 8; + bool has_attention_bias = attention_bias != nullptr; + FlashAttentionProgram program{"FlashAttention", has_attention_bias, subgroup_size, tile_size, parameters.head_size, parameters.num_heads}; + program.AddInputs({{Q, ProgramTensorMetadataDependency::TypeAndRank, 4}, + {present_key, ProgramTensorMetadataDependency::TypeAndRank, 4}, + {present_value, ProgramTensorMetadataDependency::TypeAndRank, 4}, + {attention_bias, ProgramTensorMetadataDependency::TypeAndRank}}); + program.AddOutputs({{output, ProgramTensorMetadataDependency::TypeAndRank}}); + + std::string cache_hint = std::to_string(has_attention_bias) + + std::to_string(subgroup_size) + + std::to_string(tile_size) + + std::to_string(parameters.head_size) + + std::to_string(parameters.num_heads); + const uint32_t new_seq_length = parameters.sequence_length - parameters.past_sequence_length; + program.SetDispatchGroupSize(parameters.num_heads, (new_seq_length + tile_size - 1) / tile_size, 1) + .SetWorkgroupSize(subgroup_size*tile_size) + .CacheHint(cache_hint) + .AddUniformVariables({{static_cast(new_seq_length)}, + {static_cast(parameters.sequence_length)}, + {static_cast(1.0f / sqrt(parameters.head_size))}}); + + return context.RunProgram(program); } MultiHeadAttention::MultiHeadAttention(const OpKernelInfo& info) @@ -687,7 +792,7 @@ Status MultiHeadAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& if (parameters.sequence_length == 1 && bias == nullptr && parameters.kv_sequence_length == 1 && present_key != nullptr && present_value != nullptr && present_key->SizeInBytes() > 0 && - present_value->SizeInBytes() > 0) { + present_value->SizeInBytes() > 0 && parameters.head_size % 4 == 0) { return ApplyFlashAttention(query, key, value, attention_bias, output, past_key, present_key, past_value, present_value, parameters, context); } diff --git a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.h b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.h index 5dc567b57f424..8176997c1483a 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.h +++ b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.h @@ -117,30 +117,32 @@ class CopyKVCacheProgram final : public Program { class FlashAttentionProgram final : public Program { public: - 
FlashAttentionProgram(const std::string& kernel_name, bool has_attention_bias, - int tile_size, int qkv_hidden_size, int qkv_num_heads, int components) + FlashAttentionProgram(const std::string& kernel_name, + bool has_attention_bias, + int subgroup_size, + int tile_size, + int qkv_head_size, + int qkv_num_heads) : Program{kernel_name}, has_attention_bias_(has_attention_bias), + subgroup_size_(subgroup_size), tile_size_(tile_size), - qkv_hidden_size_(qkv_hidden_size), - qkv_num_heads_(qkv_num_heads), - components_(components) { + qkv_head_size_(qkv_head_size), + qkv_num_heads_(qkv_num_heads) { } Status GenerateShaderCode(ShaderHelper& sh) const override; - WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"attention_scale", ProgramUniformVariableDataType::Float32}, - {"past_sequence_length", ProgramUniformVariableDataType::Uint32}, - {"component_indexes_per_token", ProgramUniformVariableDataType::Uint32}); - - WEBGPU_PROGRAM_DEFINE_OVERRIDABLE_CONSTANTS({"TILE_SIZE", ProgramConstantDataType::Uint32}); + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"new_sequence_length", ProgramUniformVariableDataType::Uint32}, + {"present_sequence_length", ProgramUniformVariableDataType::Uint32}, + {"alpha", ProgramUniformVariableDataType::Float32}); private: bool has_attention_bias_; + int subgroup_size_; int tile_size_; - int qkv_hidden_size_; + int qkv_head_size_; int qkv_num_heads_; - int components_; }; class MultiHeadAttention final : public WebGpuKernel { From 58157c54fb596d1f8922f51ab9502360076d1247 Mon Sep 17 00:00:00 2001 From: Sushanth Rajasankar Date: Tue, 19 Nov 2024 10:32:58 -0800 Subject: [PATCH 04/25] Integrate FA --- .../webgpu/bert/multihead_attention.cc | 111 +++++++++++------- 1 file changed, 67 insertions(+), 44 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc index 324b14f69a0a0..ff506dec7ba4f 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc @@ -504,36 +504,37 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const { // TILE_SIZE == SUBGROUP_SIZE. This is a sperate constant from SUBGROUP_SIZE // because SUBGROUP_SIZE * TILE_SIZE has to be <= 256 as per webgpu // gpu limits. For Intel this TILE_SIZE will be 8. - shader.AdditionalImplementation() << "const SUBGROUP_SIZE: u32 = " << subgroup_size_ << ";" - << "const TILE_SIZE: u32 = " << tile_size_ << ";" - << "const VECTOR_SIZE: u32 = " << vectorization_size << ";" - << "const QKV_HEAD_SIZE: u32 = " << qkv_head_size_ << ";" - << "const QKV_HEAD_VECTORIZED_SIZE: u32 = QKV_HEAD_SIZE / VECTOR_SIZE;" - << "const NUM_HEADS: u32 = " << qkv_num_heads_ << ";" - << "const MIN_VALUE : q_element_t = -6504.0h;"; + shader.AdditionalImplementation() << "const SUBGROUP_SIZE: u32 = " << subgroup_size_ << ";\n" + << "const TILE_SIZE: u32 = " << tile_size_ << ";\n" + << "const VECTOR_SIZE: u32 = " << vectorization_size << ";\n" + << "const QKV_HEAD_SIZE: u32 = " << qkv_head_size_ << ";\n" + << "const QKV_HEAD_VECTORIZED_SIZE: u32 = QKV_HEAD_SIZE / VECTOR_SIZE;\n" + << "const NUM_HEADS: u32 = " << qkv_num_heads_ << ";\n" + << "const MIN_VALUE : q_element_t = -6504.0h;\n"; // Best to keep SHM usage per workgroup < 8KB. 4KB is the limit on a 48EU tigerlake // GPU afterwhich workgroups will be unscheduled to make space for memory. - shader.AdditionalImplementation() << "var q_tile : array, TILE_SIZE>; // 96 * 8 * 2 = 1.5KB." 
- << "var k_tile : array, TILE_SIZE>; // 96 * 8 * 2 = 1.5KB." - << "var v_tile : array, TILE_SIZE>; // 96 * 8 * 2 = 1.5KB." - << "var o_tile : array, TILE_SIZE>; // 96 * 8 * 2 = 1.5KB." - << "var qk_tile : array, TILE_SIZE>; // 8 * 2 * 8 = 128" - << "var max_tile : array; // 2 * 8 = 16" - << "var denom_tile : array; // 2 * 8 = 16" - << "var o_ratio : array; // 2 * 8 = 16"; + shader.AdditionalImplementation() << "var q_tile : array, TILE_SIZE>; // 96 * 8 * 2 = 1.5KB.\n" + << "var k_tile : array, TILE_SIZE>; // 96 * 8 * 2 = 1.5KB.\n" + << "var v_tile : array, TILE_SIZE>; // 96 * 8 * 2 = 1.5KB.\n" + << "var o_tile : array, TILE_SIZE>; // 96 * 8 * 2 = 1.5KB.\n" + << "var qk_tile : array, TILE_SIZE>; // 8 * 2 * 8 = 128\n" + << "var max_tile : array; // 2 * 8 = 16\n" + << "var denom_tile : array; // 2 * 8 = 16\n" + << "var o_ratio : array; // 2 * 8 = 16\n"; shader.AdditionalImplementation() << R"HELPER_FN( + fn loadq(slot: u32, q_idx_global : u32, head_idx: u32, sg_id : u32, sg_size : u32) { if (q_idx_global >= uniforms.new_sequence_length) { return; } // Stored as float16[batch_size,sequence_length,3072] the inputs as per onnx MHA - // This is the layout if TransferBSDToBNSH has not been run. - // let offset = q_idx_global * (QKV_HEAD_VECTORIZED_SIZE) * NUM_HEADS + QKV_HEAD_VECTORIZED_SIZE * head_idx; + // This is the layout if TransferBSDToBNSH has not been run. + let offset = q_idx_global * (QKV_HEAD_VECTORIZED_SIZE) * NUM_HEADS + QKV_HEAD_VECTORIZED_SIZE * head_idx; // Stored as BNSH - which is what webgpu uses after TransferBSDToBNSH has been run. - let offset = head_idx * uniforms.new_sequence_length * QKV_HEAD_VECTORIZED_SIZE + q_idx_global * QKV_HEAD_VECTORIZED_SIZE; + // let offset = head_idx * uniforms.new_sequence_length * QKV_HEAD_VECTORIZED_SIZE + q_idx_global * QKV_HEAD_VECTORIZED_SIZE; for (var idx:u32 = sg_id; idx < QKV_HEAD_VECTORIZED_SIZE; idx+= sg_size) { var value = q[idx+offset]; @@ -612,27 +613,42 @@ fn writeo(slot: u32, o_idx_global : u32, head_idx: u32, sg_id : u32, sg_size : u fn computeDotProduct(q_idx: u32, k_idx: u32, sg_id: u32, sg_size : u32) { var sum:vec4 = q_value_t(0, 0, 0, 0); - for (var idx = sg_id; idx < QKV_HEAD_VECTORIZED_SIZE; idx+= sg_size) + for (var idx:u32 = 0; idx < QKV_HEAD_VECTORIZED_SIZE; idx+= sg_size) { - var result = q_tile[q_idx][idx]*k_tile[k_idx][idx]; + var result = q_value_t(0); + let sg_idx = idx+sg_id; + // QKV_HEAD_VECTORIZED_SIZE is divisible by the subgroup size this if check is not + // required. Hopefully the compiler sees the first half of this if statement and + // removes this if instruction. + if (QKV_HEAD_VECTORIZED_SIZE % sg_size == 0 || sg_idx < QKV_HEAD_VECTORIZED_SIZE) + { + result = q_tile[q_idx][sg_idx]*k_tile[k_idx][sg_idx]; + } sum += subgroupAdd(result); } + if (sg_id == 0) { let single_sum : q_element_t = sum.x + sum.y + sum.z + sum.w; let sqrt_dk = q_element_t(uniforms.alpha); let value = single_sum * sqrt_dk; qk_tile[q_idx][k_idx] += value; + } } -} -fn computeSoftMax(q_idx: u32, sg_id:u32) +fn computeSoftMax(q_idx: u32, sg_id:u32, enabled:bool) { - let x = qk_tile[q_idx][sg_id]; + var x = MIN_VALUE; + if (enabled){ + x = qk_tile[q_idx][sg_id]; + } var max_value = subgroupMax(x); max_value = max(max_tile[q_idx], max_value); let sub = x - max_value; - let value = exp(sub); + var value:q_element_t = 0; + if (enabled) { + value = exp(sub); + } let sum = subgroupAdd(value); // Compute lhs term of update di prime and the compute di prime. 
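   // With the new running max m' = max(m_old, subgroupMax(x)), the flash attention
   // recurrence updates the denominator as d' = d_old * exp(m_old - m') + sum, and the
   // previously accumulated o row is rescaled by o_ratio = d_old * exp(m_old - m') / d'.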
@@ -647,13 +663,17 @@ fn computeSoftMax(q_idx: u32, sg_id:u32) } } -fn computeO(q_idx: u32, sg_id:u32) +fn computeO(q_idx: u32, sg_id:u32, enabled:bool) { for (var i:u32 = 0; i < QKV_HEAD_VECTORIZED_SIZE; i++) { let attn = qk_tile[q_idx][sg_id]; let val = v_tile[sg_id][i]; - let intermediate = attn * val; + var intermediate = q_value_t(0); + if (enabled) + { + intermediate = attn * val; + } let sum = subgroupAdd(intermediate); if (sg_id == 0) { @@ -663,7 +683,9 @@ fn computeO(q_idx: u32, sg_id:u32) o_tile[q_idx][i] = new_o; } } -})HELPER_FN"; +} + +)HELPER_FN"; // Shader is designed to be dispatched as Dispatch(num_heads, present_seq_length / TILE_SIZE, 1) // QKV_HEAD_VECTORIZED_SIZE % sg_id == 0 for loadq, loadk and computeDotProduct to work right. @@ -680,25 +702,24 @@ max_tile[sg_id] = MIN_VALUE; for(var k_start = 0u; k_start < uniforms.present_sequence_length; k_start+=TILE_SIZE) { - if (k_start+sg_id < uniforms.present_sequence_length) { + if (sg_id < TILE_SIZE && k_start+sg_id < uniforms.present_sequence_length) { loadk(sg_id, k_start+sg_id, head_idx); loadv(sg_id, k_start+sg_id, head_idx); loadAttentionBias(q_tile_row, q_idx_global, sg_id, k_start+sg_id, head_idx); } workgroupBarrier(); // Do k_idx + k_start <= q_idx_global if we want only look past. - for (var k_idx = 0u; k_idx < TILE_SIZE && k_idx + k_start <= uniforms.present_sequence_length; k_idx++) + for (var k_idx = 0u; k_idx < TILE_SIZE && k_idx + k_start < uniforms.present_sequence_length; k_idx++) { computeDotProduct(q_tile_row, k_idx, sg_id, sg_size); } - if (sg_id < TILE_SIZE && sg_id + k_start < uniforms.present_sequence_length) - { - computeSoftMax(q_tile_row, sg_id); - computeO(q_tile_row, sg_id); - } + let enabled:bool = sg_id < TILE_SIZE && sg_id + k_start < uniforms.present_sequence_length; + computeSoftMax(q_tile_row, sg_id, enabled); + computeO(q_tile_row, sg_id, enabled); } workgroupBarrier(); writeo(q_tile_row, q_idx_global, head_idx, sg_id, sg_size); + )MAIN_FN"; return Status::OK(); @@ -708,8 +729,9 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co Tensor* output, const Tensor* past_key, Tensor* present_key, const Tensor* past_value, Tensor* present_value, AttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context) { ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, parameters.past_sequence_length, parameters.total_sequence_length)); - // return ApplyAttention(Q, K, V, attention_bias, past_key, past_value, output, present_key, - // present_value, parameters, context, true); + //return ApplyAttention(Q, K, V, attention_bias, past_key, past_value, output, present_key, + // present_value, parameters, context, true); + constexpr int subgroup_size = 32; constexpr int tile_size = 8; bool has_attention_bias = attention_bias != nullptr; @@ -718,20 +740,20 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co {present_key, ProgramTensorMetadataDependency::TypeAndRank, 4}, {present_value, ProgramTensorMetadataDependency::TypeAndRank, 4}, {attention_bias, ProgramTensorMetadataDependency::TypeAndRank}}); - program.AddOutputs({{output, ProgramTensorMetadataDependency::TypeAndRank}}); - + program.AddOutputs({{output, ProgramTensorMetadataDependency::TypeAndRank, 4}}); + const float alpha = parameters.scale == 0.0f ? 
1.f / sqrt(static_cast(parameters.head_size)) + : parameters.scale; std::string cache_hint = std::to_string(has_attention_bias) + std::to_string(subgroup_size) + std::to_string(tile_size) + std::to_string(parameters.head_size) + std::to_string(parameters.num_heads); - const uint32_t new_seq_length = parameters.sequence_length - parameters.past_sequence_length; - program.SetDispatchGroupSize(parameters.num_heads, (new_seq_length + tile_size - 1) / tile_size, 1) + program.SetDispatchGroupSize(parameters.num_heads, (parameters.sequence_length + tile_size - 1) / tile_size, 1) .SetWorkgroupSize(subgroup_size*tile_size) .CacheHint(cache_hint) - .AddUniformVariables({{static_cast(new_seq_length)}, - {static_cast(parameters.sequence_length)}, - {static_cast(1.0f / sqrt(parameters.head_size))}}); + .AddUniformVariables({{static_cast(parameters.sequence_length)}, + {static_cast(parameters.total_sequence_length)}, + {alpha}}); return context.RunProgram(program); } @@ -790,7 +812,8 @@ Status MultiHeadAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& Tensor* present_key = context.Output(1, present_shape); Tensor* present_value = context.Output(2, present_shape); - if (parameters.sequence_length == 1 && bias == nullptr && parameters.kv_sequence_length == 1 && + if (bias == nullptr && + past_key != nullptr && past_value != nullptr && past_key->SizeInBytes() > 0 && past_value->SizeInBytes() > 0 && present_key != nullptr && present_value != nullptr && present_key->SizeInBytes() > 0 && present_value->SizeInBytes() > 0 && parameters.head_size % 4 == 0) { return ApplyFlashAttention(query, key, value, attention_bias, output, past_key, present_key, past_value, From 80296aa877ea1284e665a08fd776622cc898eef1 Mon Sep 17 00:00:00 2001 From: Sushanth Rajasankar Date: Tue, 19 Nov 2024 17:02:00 -0800 Subject: [PATCH 05/25] Try fix the divide by zero issue --- .../contrib_ops/webgpu/bert/multihead_attention.cc | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc index ff506dec7ba4f..e5c68bc14b83f 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc @@ -618,7 +618,7 @@ fn computeDotProduct(q_idx: u32, k_idx: u32, sg_id: u32, sg_size : u32) var result = q_value_t(0); let sg_idx = idx+sg_id; // QKV_HEAD_VECTORIZED_SIZE is divisible by the subgroup size this if check is not - // required. Hopefully the compiler sees the first half of this if statement and + // required. Hopefully the compiler sees the first half of this if statement and // removes this if instruction. if (QKV_HEAD_VECTORIZED_SIZE % sg_size == 0 || sg_idx < QKV_HEAD_VECTORIZED_SIZE) { @@ -633,8 +633,8 @@ fn computeDotProduct(q_idx: u32, k_idx: u32, sg_id: u32, sg_size : u32) let sqrt_dk = q_element_t(uniforms.alpha); let value = single_sum * sqrt_dk; qk_tile[q_idx][k_idx] += value; - } } +} fn computeSoftMax(q_idx: u32, sg_id:u32, enabled:bool) { @@ -653,7 +653,11 @@ fn computeSoftMax(q_idx: u32, sg_id:u32, enabled:bool) // Compute lhs term of update di prime and the compute di prime. 
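   // In f16 both dleft and sum can flush to zero when every score sits far below the
   // running max, which is what lets the divide below hit zero without a guard.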
let dleft = denom_tile[q_idx] * exp(max_tile[q_idx]-max_value); - let d = dleft + sum; + var d = dleft + sum; + if (d == 0) + { + d = 0.0000001h; + } qk_tile[q_idx][sg_id] = value / d; if (sg_id == 0) { @@ -812,7 +816,8 @@ Status MultiHeadAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& Tensor* present_key = context.Output(1, present_shape); Tensor* present_value = context.Output(2, present_shape); - if (bias == nullptr && + if (parameters.batch_size == 1 && + bias == nullptr && past_key != nullptr && past_value != nullptr && past_key->SizeInBytes() > 0 && past_value->SizeInBytes() > 0 && present_key != nullptr && present_value != nullptr && present_key->SizeInBytes() > 0 && present_value->SizeInBytes() > 0 && parameters.head_size % 4 == 0) { From c281f84d2bd1e2c2b99292886a4c5c0c045a3f8a Mon Sep 17 00:00:00 2001 From: Sushanth Rajasankar Date: Thu, 21 Nov 2024 11:10:24 -0800 Subject: [PATCH 06/25] FA works onn intel (TILE_SIZE == SUBGROUP_SIZE) for seq length of 1. --- .../webgpu/bert/multihead_attention.cc | 102 ++++++++++++------ .../webgpu/bert/multihead_attention.h | 5 +- 2 files changed, 72 insertions(+), 35 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc index e5c68bc14b83f..dfe188a23f5a5 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc @@ -421,23 +421,29 @@ Status CopyKVCacheProgram::GenerateShaderCode(ShaderHelper& shader) const { // Attention bias is in BN(total_sequence_length) shader.AddInput("key", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); shader.AddInput("value", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); - shader.AddInput("past_key", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); - shader.AddInput("past_value", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + if (has_past_) { + shader.AddInput("past_key", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + shader.AddInput("past_value", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + } shader.AddOutput("present_key", ShaderUsage::UseUniform); shader.AddOutput("present_value", ShaderUsage::UseUniform); shader.MainFunctionBody() << "let headIdx = workgroup_id.z;\n" << "let kIdx = workgroup_id.x;\n" - << "let presentKeyOffset = headIdx * num_workgroups.x * uniforms.vectorized_head_size + (kIdx)*uniforms.vectorized_head_size;\n" - << "if (kIdx < uniforms.past_sequence_length) {\n" - << " let pastKeyOffset = headIdx * uniforms.past_sequence_length * uniforms.vectorized_head_size + (kIdx)*uniforms.vectorized_head_size;\n" - << " for (var w: u32 = 0u; w < uniforms.vectorized_head_size; w ++) {\n" - << " present_key[presentKeyOffset+w] = past_key[pastKeyOffset+w];\n" - << " present_value[presentKeyOffset+w] = past_value[pastKeyOffset+w];\n" - << " }\n" - << "}\n" - << "else if (kIdx >= uniforms.past_sequence_length) {\n" - << " let nkIdx = kIdx - uniforms.past_sequence_length;\n" + << "let presentKeyOffset = headIdx * num_workgroups.x * uniforms.vectorized_head_size + (kIdx)*uniforms.vectorized_head_size;\n"; + if (has_past_) { + shader.MainFunctionBody() << "if (kIdx < uniforms.past_sequence_length) {\n" + << " let pastKeyOffset = headIdx * 
uniforms.past_sequence_length * uniforms.vectorized_head_size + (kIdx)*uniforms.vectorized_head_size;\n" + << " for (var w: u32 = 0u; w < uniforms.vectorized_head_size; w ++) {\n" + << " present_key[presentKeyOffset+w] = past_key[pastKeyOffset+w];\n" + << " present_value[presentKeyOffset+w] = past_value[pastKeyOffset+w];\n" + << " }\n" + << "}\n" + << "else if (kIdx >= uniforms.past_sequence_length) {\n"; + } else { + shader.MainFunctionBody() << "if (kIdx >= uniforms.past_sequence_length) {\n"; + } + shader.MainFunctionBody() << " let nkIdx = kIdx - uniforms.past_sequence_length;\n" << " // Assumes kv have BSNH layout. num_workgroups.z is the num_head as per the dispatch requirement.\n" << " let nOffset = nkIdx * uniforms.vectorized_head_size * num_workgroups.z + headIdx*uniforms.vectorized_head_size;\n" << " // Assumes kv have BNSH layout.\n" @@ -457,17 +463,24 @@ Status CopyKVCache(onnxruntime::webgpu::ComputeContext& context, AttentionParame int past_sequence_length, int total_sequence_length) { const int components = parameters.head_size % 4 == 0 ? 4 : (parameters.head_size % 2 == 0 ? 2 : 1); - CopyKVCacheProgram program{"CopyKVCache", components}; - program.AddInputs({{K, ProgramTensorMetadataDependency::TypeAndRank, components}, - {V, ProgramTensorMetadataDependency::TypeAndRank, components}, - {past_key, ProgramTensorMetadataDependency::TypeAndRank, components}, - {past_value, ProgramTensorMetadataDependency::TypeAndRank, components}}); + bool has_past = (past_sequence_length != 0); + CopyKVCacheProgram program{"CopyKVCache", components, has_past}; + if (has_past) { + program.AddInputs({{K, ProgramTensorMetadataDependency::TypeAndRank, components}, + {V, ProgramTensorMetadataDependency::TypeAndRank, components}, + {past_key, ProgramTensorMetadataDependency::TypeAndRank, components}, + {past_value, ProgramTensorMetadataDependency::TypeAndRank, components}}); + } else { + program.AddInputs({{K, ProgramTensorMetadataDependency::TypeAndRank, components}, + {V, ProgramTensorMetadataDependency::TypeAndRank, components}}); + } + program.AddOutputs({{present_key, ProgramTensorMetadataDependency::Rank, components}, {present_value, ProgramTensorMetadataDependency::Rank, components}}); program.SetDispatchGroupSize(total_sequence_length, 1, parameters.num_heads) .SetWorkgroupSize(1) - .CacheHint(std::to_string(components)) + .CacheHint(std::to_string(components) + std::to_string(has_past)) .AddUniformVariables({{static_cast(past_sequence_length)}, {static_cast(parameters.kv_sequence_length)}, {static_cast(parameters.head_size/ components)}}); @@ -669,15 +682,15 @@ fn computeSoftMax(q_idx: u32, sg_id:u32, enabled:bool) fn computeO(q_idx: u32, sg_id:u32, enabled:bool) { + var attn = q_element_t(0); + if (enabled) + { + attn = qk_tile[q_idx][sg_id]; + } for (var i:u32 = 0; i < QKV_HEAD_VECTORIZED_SIZE; i++) { - let attn = qk_tile[q_idx][sg_id]; let val = v_tile[sg_id][i]; - var intermediate = q_value_t(0); - if (enabled) - { - intermediate = attn * val; - } + var intermediate = attn * val; let sum = subgroupAdd(intermediate); if (sg_id == 0) { @@ -733,10 +746,33 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co Tensor* output, const Tensor* past_key, Tensor* present_key, const Tensor* past_value, Tensor* present_value, AttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context) { ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, parameters.past_sequence_length, 
parameters.total_sequence_length)); - //return ApplyAttention(Q, K, V, attention_bias, past_key, past_value, output, present_key, - // present_value, parameters, context, true); - constexpr int subgroup_size = 32; + // // Uncomment to test CopyKVCache independent of FlashAttentionProgram. + // TensorShapeVector q_new_dims({parameters.batch_size, parameters.num_heads, + // parameters.sequence_length, parameters.head_size}); + // TensorShape q_new_shape(q_new_dims); + // Tensor Qn = context.CreateGPUTensor(Q->DataType(), q_new_shape); + // ORT_RETURN_IF_ERROR(TransferBSDToBNSH( + // context, parameters.num_heads, parameters.sequence_length, parameters.head_size, Q, nullptr, 0, &Qn)); + + // TensorShapeVector k_new_dims({parameters.batch_size, parameters.num_heads, + // parameters.kv_sequence_length, parameters.head_size}); + // TensorShape k_new_shape(k_new_dims); + // Tensor Kn = context.CreateGPUTensor(K->DataType(), k_new_shape); + // ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.num_heads, parameters.kv_sequence_length, + // parameters.head_size, K, nullptr, parameters.hidden_size, &Kn)); + + // TensorShapeVector v_new_dims({parameters.batch_size, parameters.num_heads, + // parameters.kv_sequence_length, parameters.v_head_size}); + // TensorShape v_new_shape(v_new_dims); + // Tensor Vn = context.CreateGPUTensor(V->DataType(), v_new_shape); + // ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.num_heads, parameters.kv_sequence_length, + // parameters.v_head_size, V, nullptr, 2 * parameters.hidden_size, &Vn)); + + // return ApplyAttention(&Qn, &Kn, &Vn, attention_bias, past_key, past_value, output, present_key, + // present_value, parameters, context, true); + + constexpr int subgroup_size = 8; constexpr int tile_size = 8; bool has_attention_bias = attention_bias != nullptr; FlashAttentionProgram program{"FlashAttention", has_attention_bias, subgroup_size, tile_size, parameters.head_size, parameters.num_heads}; @@ -817,12 +853,12 @@ Status MultiHeadAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& Tensor* present_value = context.Output(2, present_shape); if (parameters.batch_size == 1 && - bias == nullptr && - past_key != nullptr && past_value != nullptr && past_key->SizeInBytes() > 0 && past_value->SizeInBytes() > 0 && - present_key != nullptr && present_value != nullptr && present_key->SizeInBytes() > 0 && - present_value->SizeInBytes() > 0 && parameters.head_size % 4 == 0) { - return ApplyFlashAttention(query, key, value, attention_bias, output, past_key, present_key, past_value, - present_value, parameters, context); + bias == nullptr && + past_key != nullptr && past_value != nullptr && past_key->SizeInBytes() > 0 && past_value->SizeInBytes() > 0 && + present_key != nullptr && present_value != nullptr && present_key->SizeInBytes() > 0 && + present_value->SizeInBytes() > 0 && parameters.head_size % 4 == 0) { + return ApplyFlashAttention(query, key, value, attention_bias, output, past_key, present_key, past_value, + present_value, parameters, context); } TensorShapeVector q_new_dims({parameters.batch_size, parameters.num_heads, diff --git a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.h b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.h index 8176997c1483a..b6d7aa3811672 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.h +++ b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.h @@ -101,8 +101,8 @@ class VxAttentionScoreProgram final : public Program { class CopyKVCacheProgram final : public Program { public: 
- CopyKVCacheProgram(const std::string& kernel_name, int components) - : Program{kernel_name}, components_(components) { + CopyKVCacheProgram(const std::string& kernel_name, int components, bool has_past) + : Program{kernel_name}, components_(components), has_past_(has_past) { } Status GenerateShaderCode(ShaderHelper& sh) const override; @@ -113,6 +113,7 @@ class CopyKVCacheProgram final : public Program { private: int components_; + bool has_past_; }; class FlashAttentionProgram final : public Program { From 3d258522217746fb8ed423b700ff4177d47b37ba Mon Sep 17 00:00:00 2001 From: Sushanth Rajasankar Date: Fri, 22 Nov 2024 19:44:18 -0800 Subject: [PATCH 07/25] Update subgroup_size and tile_size to be actual intel values --- .../webgpu/bert/multihead_attention.cc | 49 +++++-------------- 1 file changed, 12 insertions(+), 37 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc index dfe188a23f5a5..bb2999fb44a87 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc @@ -710,11 +710,11 @@ fn computeO(q_idx: u32, sg_id:u32, enabled:bool) shader.MainFunctionBody() << R"MAIN_FN( let head_idx = workgroup_id.x; // Split the composite workgroup id into actual y and subgroup id. -let q_tile_row = u32(local_idx / sg_size); +let wave_id = u32(local_idx / sg_size); -let q_idx_global = workgroup_id.y * TILE_SIZE + q_tile_row; -// Each invocation (q_tile_row) gets x threads (subgroup threads) and is responsible for 1 query. -loadq(q_tile_row, q_idx_global, head_idx, sg_id, sg_size); +let q_idx_global = workgroup_id.y * TILE_SIZE + wave_id; +// Each invocation (wave_id) gets sg_size lanes (subgroup threads) and is responsible for 1 query. +loadq(wave_id, q_idx_global, head_idx, sg_id, sg_size); max_tile[sg_id] = MIN_VALUE; for(var k_start = 0u; k_start < uniforms.present_sequence_length; k_start+=TILE_SIZE) @@ -722,20 +722,20 @@ for(var k_start = 0u; k_start < uniforms.present_sequence_length; k_start+=TILE_ if (sg_id < TILE_SIZE && k_start+sg_id < uniforms.present_sequence_length) { loadk(sg_id, k_start+sg_id, head_idx); loadv(sg_id, k_start+sg_id, head_idx); - loadAttentionBias(q_tile_row, q_idx_global, sg_id, k_start+sg_id, head_idx); + loadAttentionBias(wave_id, q_idx_global, sg_id, k_start+sg_id, head_idx); } workgroupBarrier(); // Do k_idx + k_start <= q_idx_global if we want only look past. 
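   // Bounding the loop below by k_idx + k_start <= q_idx_global would restrict each
   // query to keys at or before its own position (causal masking); as written, every
   // key in the cache is visited.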
for (var k_idx = 0u; k_idx < TILE_SIZE && k_idx + k_start < uniforms.present_sequence_length; k_idx++) { - computeDotProduct(q_tile_row, k_idx, sg_id, sg_size); + computeDotProduct(wave_id, k_idx, sg_id, sg_size); } let enabled:bool = sg_id < TILE_SIZE && sg_id + k_start < uniforms.present_sequence_length; - computeSoftMax(q_tile_row, sg_id, enabled); - computeO(q_tile_row, sg_id, enabled); + computeSoftMax(wave_id, sg_id, enabled); + computeO(wave_id, sg_id, enabled); } workgroupBarrier(); -writeo(q_tile_row, q_idx_global, head_idx, sg_id, sg_size); +writeo(wave_id, q_idx_global, head_idx, sg_id, sg_size); )MAIN_FN"; @@ -747,33 +747,8 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co AttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context) { ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, parameters.past_sequence_length, parameters.total_sequence_length)); - // // Uncomment to test CopyKVCache independent of FlashAttentionProgram. - // TensorShapeVector q_new_dims({parameters.batch_size, parameters.num_heads, - // parameters.sequence_length, parameters.head_size}); - // TensorShape q_new_shape(q_new_dims); - // Tensor Qn = context.CreateGPUTensor(Q->DataType(), q_new_shape); - // ORT_RETURN_IF_ERROR(TransferBSDToBNSH( - // context, parameters.num_heads, parameters.sequence_length, parameters.head_size, Q, nullptr, 0, &Qn)); - - // TensorShapeVector k_new_dims({parameters.batch_size, parameters.num_heads, - // parameters.kv_sequence_length, parameters.head_size}); - // TensorShape k_new_shape(k_new_dims); - // Tensor Kn = context.CreateGPUTensor(K->DataType(), k_new_shape); - // ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.num_heads, parameters.kv_sequence_length, - // parameters.head_size, K, nullptr, parameters.hidden_size, &Kn)); - - // TensorShapeVector v_new_dims({parameters.batch_size, parameters.num_heads, - // parameters.kv_sequence_length, parameters.v_head_size}); - // TensorShape v_new_shape(v_new_dims); - // Tensor Vn = context.CreateGPUTensor(V->DataType(), v_new_shape); - // ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.num_heads, parameters.kv_sequence_length, - // parameters.v_head_size, V, nullptr, 2 * parameters.hidden_size, &Vn)); - - // return ApplyAttention(&Qn, &Kn, &Vn, attention_bias, past_key, past_value, output, present_key, - // present_value, parameters, context, true); - - constexpr int subgroup_size = 8; - constexpr int tile_size = 8; + constexpr int subgroup_size = 16; + constexpr int tile_size = 16; bool has_attention_bias = attention_bias != nullptr; FlashAttentionProgram program{"FlashAttention", has_attention_bias, subgroup_size, tile_size, parameters.head_size, parameters.num_heads}; program.AddInputs({{Q, ProgramTensorMetadataDependency::TypeAndRank, 4}, @@ -789,7 +764,7 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co std::to_string(parameters.head_size) + std::to_string(parameters.num_heads); program.SetDispatchGroupSize(parameters.num_heads, (parameters.sequence_length + tile_size - 1) / tile_size, 1) - .SetWorkgroupSize(subgroup_size*tile_size) + .SetWorkgroupSize(subgroup_size*subgroup_size) .CacheHint(cache_hint) .AddUniformVariables({{static_cast(parameters.sequence_length)}, {static_cast(parameters.total_sequence_length)}, From b19070a34888b4ab6cf9ae2968bedc4228e149bb Mon Sep 17 00:00:00 2001 From: Sushanth Rajasankar Date: Fri, 22 Nov 2024 19:48:47 -0800 Subject: [PATCH 
08/25] Commit temporarily --- .../webgpu/bert/multihead_attention.cc | 21 ++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc index bb2999fb44a87..d6da6e6c14f8a 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc @@ -704,18 +704,25 @@ fn computeO(q_idx: u32, sg_id:u32, enabled:bool) )HELPER_FN"; -// Shader is designed to be dispatched as Dispatch(num_heads, present_seq_length / TILE_SIZE, 1) -// QKV_HEAD_VECTORIZED_SIZE % sg_id == 0 for loadq, loadk and computeDotProduct to work right. +// Shader is designed to be dispatched as Dispatch(num_heads, new_seq_length / TILE_SIZE, 1) shader.MainFunctionBody() << R"MAIN_FN( let head_idx = workgroup_id.x; -// Split the composite workgroup id into actual y and subgroup id. +// It is always the case that 0 <= wave_id < TILE_SIZE +// Each wave has sg_size lanes (subgroup threads). let wave_id = u32(local_idx / sg_size); -let q_idx_global = workgroup_id.y * TILE_SIZE + wave_id; -// Each invocation (wave_id) gets sg_size lanes (subgroup threads) and is responsible for 1 query. -loadq(wave_id, q_idx_global, head_idx, sg_id, sg_size); -max_tile[sg_id] = MIN_VALUE; +let q_idx_start = workgroup_id.y * TILE_SIZE; +let q_idx_global = q_idx_start + wave_id; +let q_idx_global_using_wave_valid = q_idx_global < uniforms.new_sequence_length; +if (q_idx_global_using_wave_valid) +{ + loadq(wave_id, q_idx_global, head_idx, sg_id, sg_size); +} +if (sg_id == 0) +{ + max_tile[wave_id] = MIN_VALUE; +} for(var k_start = 0u; k_start < uniforms.present_sequence_length; k_start+=TILE_SIZE) { From 228b840bdcd43326b3253b04f6a54b6ce1a7d080 Mon Sep 17 00:00:00 2001 From: Sushanth Rajasankar Date: Fri, 22 Nov 2024 20:29:38 -0800 Subject: [PATCH 09/25] Works so far. --- .../webgpu/bert/multihead_attention.cc | 60 +++++++------------ 1 file changed, 21 insertions(+), 39 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc index d6da6e6c14f8a..4764d594811f7 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc @@ -537,12 +537,8 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const { << "var o_ratio : array; // 2 * 8 = 16\n"; shader.AdditionalImplementation() << R"HELPER_FN( - fn loadq(slot: u32, q_idx_global : u32, head_idx: u32, sg_id : u32, sg_size : u32) { - if (q_idx_global >= uniforms.new_sequence_length) { - return; - } // Stored as float16[batch_size,sequence_length,3072] the inputs as per onnx MHA // This is the layout if TransferBSDToBNSH has not been run. 
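   // e.g. with 32 heads of head size 96, each token row holds 32 * 96 = 3072 f16
   // values, i.e. 96 / 4 = 24 vec4 loads per head.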
let offset = q_idx_global * (QKV_HEAD_VECTORIZED_SIZE) * NUM_HEADS + QKV_HEAD_VECTORIZED_SIZE * head_idx; @@ -555,58 +551,36 @@ fn loadq(slot: u32, q_idx_global : u32, head_idx: u32, sg_id : u32, sg_size : u3 } } -fn debugKTile() -> q_value_t +fn loadk(slot: u32, k_idx_global : u32, head_idx: u32, sg_id: u32, sg_size: u32) { - var sum_value = q_value_t(0); - for (var qidx:u32 = 0; qidx < TILE_SIZE; qidx++) - { - for (var idx:u32 = 0; idx < QKV_HEAD_VECTORIZED_SIZE; idx++) - { - var value = k_tile[qidx][idx]; - sum_value += value; - } - } - return sum_value; -} - -fn loadk(slot: u32, k_idx_global : u32, head_idx: u32) -{ - if (k_idx_global >= uniforms.present_sequence_length) { - return; - } - // Stored as float16[batch_size,num_heads,present_sequence_length,96] let offset = head_idx * uniforms.present_sequence_length * QKV_HEAD_VECTORIZED_SIZE + k_idx_global * QKV_HEAD_VECTORIZED_SIZE; - for (var idx:u32 = 0; idx < QKV_HEAD_VECTORIZED_SIZE; idx++) + for (var idx:u32 = sg_id; idx < QKV_HEAD_VECTORIZED_SIZE; idx+=sg_size) { var value = present_key[idx+offset]; k_tile[slot][idx] = value; } } -fn loadv(slot: u32, v_idx_global : u32, head_idx: u32) +fn loadv(slot: u32, v_idx_global : u32, head_idx: u32, sg_id: u32, sg_size: u32) { - if (v_idx_global >= uniforms.present_sequence_length) { - return; - } - // Stored as float16[batch_size,num_heads,present_sequence_length,96] let offset = head_idx * uniforms.present_sequence_length * QKV_HEAD_VECTORIZED_SIZE + v_idx_global * QKV_HEAD_VECTORIZED_SIZE; - for (var idx:u32 = 0; idx < QKV_HEAD_VECTORIZED_SIZE; idx ++) + for (var idx:u32 = sg_id; idx < QKV_HEAD_VECTORIZED_SIZE; idx+=sg_size) { v_tile[slot][idx] = present_value[idx+offset]; } } -fn loadAttentionBias(qtile_row: u32, q_idx_global : u32, k_col: u32, k_idx_global : u32, head_idx: u32) +fn loadAttentionBias(q_row: u32, q_idx_global : u32, k_col: u32, k_idx_global : u32, head_idx: u32) { // Stored as float16[batch_size,num_heads,new_seq_length,total_sequence_length] - if (q_idx_global >= uniforms.new_sequence_length || k_idx_global >= uniforms.present_sequence_length) { - qk_tile[qtile_row][k_col] = 0.0; + if (q_idx_global >= uniforms.new_sequence_length || k_idx_global >= uniforms.present_sequence_length || k_col >= TILE_SIZE) { + qk_tile[q_row][k_col] = 0.0; return; } let offset = head_idx * uniforms.new_sequence_length * uniforms.present_sequence_length + q_idx_global * uniforms.present_sequence_length + k_idx_global; - qk_tile[qtile_row][k_col] = attention_bias[offset]; + qk_tile[q_row][k_col] = attention_bias[offset]; } fn writeo(slot: u32, o_idx_global : u32, head_idx: u32, sg_id : u32, sg_size : u32) @@ -704,8 +678,8 @@ fn computeO(q_idx: u32, sg_id:u32, enabled:bool) )HELPER_FN"; -// Shader is designed to be dispatched as Dispatch(num_heads, new_seq_length / TILE_SIZE, 1) - +// Shader is designed to be dispatched as Dispatch(num_heads, new_sequence_length / TILE_SIZE, 1) +// Each workgroup is responsible for a range of q values (TILE_SIZE) and visits all Ks for those q's. shader.MainFunctionBody() << R"MAIN_FN( let head_idx = workgroup_id.x; // It is always the case that 0 <= wave_id < TILE_SIZE @@ -717,6 +691,7 @@ let q_idx_global = q_idx_start + wave_id; let q_idx_global_using_wave_valid = q_idx_global < uniforms.new_sequence_length; if (q_idx_global_using_wave_valid) { + // Each invocation (wave_id) gets lane threads (subgroup threads) and is responsible for 1 query. 
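+  // The workgroup is laid out as TILE_SIZE waves of sg_size lanes: wave_id selects
+  // the query row of the tile and sg_id strides the vectorized head dimension
+  // inside the load/compute helpers.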
loadq(wave_id, q_idx_global, head_idx, sg_id, sg_size); } if (sg_id == 0) @@ -726,9 +701,16 @@ if (sg_id == 0) for(var k_start = 0u; k_start < uniforms.present_sequence_length; k_start+=TILE_SIZE) { - if (sg_id < TILE_SIZE && k_start+sg_id < uniforms.present_sequence_length) { - loadk(sg_id, k_start+sg_id, head_idx); - loadv(sg_id, k_start+sg_id, head_idx); + let k_idx_global = k_start+wave_id; + let k_idx_global_using_wave_valid = k_idx_global < uniforms.present_sequence_length; + if (k_idx_global_using_wave_valid) { + // Leveraging the subgroup lanes for parallelism, load into slot wave_id + // K/V values from k_start+wave_id. + loadk(wave_id, k_idx_global, head_idx, sg_id, sg_size); + loadv(wave_id, k_idx_global, head_idx, sg_id, sg_size); + // Next, we want for every q row (wave_id) to populate bias for new sequence length + // (k_start+sg_id). loadAttentionBias handles range checking q_idx_global, + // and sg_id, (k_start+sg_id). loadAttentionBias(wave_id, q_idx_global, sg_id, k_start+sg_id, head_idx); } workgroupBarrier(); From 122e5f985ae503503978d9ba4a971b5302b313f6 Mon Sep 17 00:00:00 2001 From: Sushanth Rajasankar Date: Fri, 22 Nov 2024 20:57:47 -0800 Subject: [PATCH 10/25] This works. --- .../webgpu/bert/multihead_attention.cc | 37 ++++++++++++------- 1 file changed, 24 insertions(+), 13 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc index 4764d594811f7..b3c00ab69100e 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc @@ -585,9 +585,6 @@ fn loadAttentionBias(q_row: u32, q_idx_global : u32, k_col: u32, k_idx_global : fn writeo(slot: u32, o_idx_global : u32, head_idx: u32, sg_id : u32, sg_size : u32) { - if (o_idx_global >= uniforms.new_sequence_length) { - return; - } // Stored as float16[batch_size,sequence_length,3072] let offset = o_idx_global * NUM_HEADS * QKV_HEAD_VECTORIZED_SIZE + head_idx * QKV_HEAD_VECTORIZED_SIZE; for (var idx:u32 = sg_id; idx < QKV_HEAD_VECTORIZED_SIZE; idx += sg_size) @@ -600,6 +597,8 @@ fn writeo(slot: u32, o_idx_global : u32, head_idx: u32, sg_id : u32, sg_size : u fn computeDotProduct(q_idx: u32, k_idx: u32, sg_id: u32, sg_size : u32) { var sum:vec4 = q_value_t(0, 0, 0, 0); + // idx is not initialized to sg_id to ensure uniformity because the loop uses + // subgroupAdd and unused lanes need to be initialized with 0 for correctness. for (var idx:u32 = 0; idx < QKV_HEAD_VECTORIZED_SIZE; idx+= sg_size) { var result = q_value_t(0); @@ -613,7 +612,6 @@ fn computeDotProduct(q_idx: u32, k_idx: u32, sg_id: u32, sg_size : u32) } sum += subgroupAdd(result); } - if (sg_id == 0) { let single_sum : q_element_t = sum.x + sum.y + sum.z + sum.w; @@ -637,12 +635,12 @@ fn computeSoftMax(q_idx: u32, sg_id:u32, enabled:bool) value = exp(sub); } let sum = subgroupAdd(value); - // Compute lhs term of update di prime and the compute di prime. let dleft = denom_tile[q_idx] * exp(max_tile[q_idx]-max_value); var d = dleft + sum; if (d == 0) { + // Avoid division by zero by setting d to a really small value. d = 0.0000001h; } qk_tile[q_idx][sg_id] = value / d; @@ -714,18 +712,31 @@ for(var k_start = 0u; k_start < uniforms.present_sequence_length; k_start+=TILE_ loadAttentionBias(wave_id, q_idx_global, sg_id, k_start+sg_id, head_idx); } workgroupBarrier(); - // Do k_idx + k_start <= q_idx_global if we want only look past. 
- for (var k_idx = 0u; k_idx < TILE_SIZE && k_idx + k_start < uniforms.present_sequence_length; k_idx++) + + if (k_idx_global_using_wave_valid) { - computeDotProduct(wave_id, k_idx, sg_id, sg_size); + for (var q_idx = 0u; q_idx < TILE_SIZE && q_idx_start + q_idx < uniforms.new_sequence_length; q_idx++) + { + // Leveraging the subgroups for parallelism, compute dot product of QK. + // Because for the case of new_seq 1, there is a single query and context length of K + // we iterate over q and use the waves for K so that this step can use all the waves in + // in the workgroup. + // We validate q_idx,wave_id to be less than TILE_SIZE, computeDotProduct only needs to + // validate sg_id as being less than QKV_HEAD_VECTORIZED_SIZE. + computeDotProduct(q_idx, wave_id, sg_id, sg_size); + } } - let enabled:bool = sg_id < TILE_SIZE && sg_id + k_start < uniforms.present_sequence_length; - computeSoftMax(wave_id, sg_id, enabled); - computeO(wave_id, sg_id, enabled); + workgroupBarrier(); + + let wave_lane_valid:bool = q_idx_global_using_wave_valid && sg_id < TILE_SIZE && sg_id + k_start < uniforms.present_sequence_length; + computeSoftMax(wave_id, sg_id, wave_lane_valid); + computeO(wave_id, sg_id, wave_lane_valid); } workgroupBarrier(); -writeo(wave_id, q_idx_global, head_idx, sg_id, sg_size); - +if (q_idx_global_using_wave_valid) +{ + writeo(wave_id, q_idx_global, head_idx, sg_id, sg_size); +} )MAIN_FN"; return Status::OK(); From 8b5fcc719d12588df343e2bf3ec9f959d5f672fc Mon Sep 17 00:00:00 2001 From: Sushanth Rajasankar Date: Fri, 22 Nov 2024 21:02:46 -0800 Subject: [PATCH 11/25] Matches past and should work. --- .../contrib_ops/webgpu/bert/multihead_attention.cc | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc index b3c00ab69100e..042b085c49452 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc @@ -550,7 +550,6 @@ fn loadq(slot: u32, q_idx_global : u32, head_idx: u32, sg_id : u32, sg_size : u3 q_tile[slot][idx] = value; } } - fn loadk(slot: u32, k_idx_global : u32, head_idx: u32, sg_id: u32, sg_size: u32) { // Stored as float16[batch_size,num_heads,present_sequence_length,96] @@ -561,7 +560,6 @@ fn loadk(slot: u32, k_idx_global : u32, head_idx: u32, sg_id: u32, sg_size: u32) k_tile[slot][idx] = value; } } - fn loadv(slot: u32, v_idx_global : u32, head_idx: u32, sg_id: u32, sg_size: u32) { // Stored as float16[batch_size,num_heads,present_sequence_length,96] @@ -571,7 +569,6 @@ fn loadv(slot: u32, v_idx_global : u32, head_idx: u32, sg_id: u32, sg_size: u32) v_tile[slot][idx] = present_value[idx+offset]; } } - fn loadAttentionBias(q_row: u32, q_idx_global : u32, k_col: u32, k_idx_global : u32, head_idx: u32) { // Stored as float16[batch_size,num_heads,new_seq_length,total_sequence_length] @@ -582,7 +579,6 @@ fn loadAttentionBias(q_row: u32, q_idx_global : u32, k_col: u32, k_idx_global : let offset = head_idx * uniforms.new_sequence_length * uniforms.present_sequence_length + q_idx_global * uniforms.present_sequence_length + k_idx_global; qk_tile[q_row][k_col] = attention_bias[offset]; } - fn writeo(slot: u32, o_idx_global : u32, head_idx: u32, sg_id : u32, sg_size : u32) { // Stored as float16[batch_size,sequence_length,3072] @@ -593,7 +589,6 @@ fn writeo(slot: u32, o_idx_global : u32, head_idx: u32, sg_id : u32, sg_size : u output[offset+idx] = value; } } - fn 
computeDotProduct(q_idx: u32, k_idx: u32, sg_id: u32, sg_size : u32) { var sum:vec4 = q_value_t(0, 0, 0, 0); @@ -620,7 +615,6 @@ fn computeDotProduct(q_idx: u32, k_idx: u32, sg_id: u32, sg_size : u32) qk_tile[q_idx][k_idx] += value; } } - fn computeSoftMax(q_idx: u32, sg_id:u32, enabled:bool) { var x = MIN_VALUE; @@ -651,7 +645,6 @@ fn computeSoftMax(q_idx: u32, sg_id:u32, enabled:bool) o_ratio[q_idx] = dleft / d; } } - fn computeO(q_idx: u32, sg_id:u32, enabled:bool) { var attn = q_element_t(0); @@ -673,7 +666,6 @@ fn computeO(q_idx: u32, sg_id:u32, enabled:bool) } } } - )HELPER_FN"; // Shader is designed to be dispatched as Dispatch(num_heads, new_sequence_length / TILE_SIZE, 1) @@ -696,7 +688,6 @@ if (sg_id == 0) { max_tile[wave_id] = MIN_VALUE; } - for(var k_start = 0u; k_start < uniforms.present_sequence_length; k_start+=TILE_SIZE) { let k_idx_global = k_start+wave_id; @@ -712,7 +703,6 @@ for(var k_start = 0u; k_start < uniforms.present_sequence_length; k_start+=TILE_ loadAttentionBias(wave_id, q_idx_global, sg_id, k_start+sg_id, head_idx); } workgroupBarrier(); - if (k_idx_global_using_wave_valid) { for (var q_idx = 0u; q_idx < TILE_SIZE && q_idx_start + q_idx < uniforms.new_sequence_length; q_idx++) From ed310dedad36ee50b92485b4208e2cc073ac6739 Mon Sep 17 00:00:00 2001 From: Sushanth Rajasankar Date: Fri, 22 Nov 2024 21:23:12 -0800 Subject: [PATCH 12/25] Multi Q also works --- onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc index 042b085c49452..dddeb0fc80ba0 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc @@ -819,7 +819,6 @@ Status MultiHeadAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& if (parameters.batch_size == 1 && bias == nullptr && - past_key != nullptr && past_value != nullptr && past_key->SizeInBytes() > 0 && past_value->SizeInBytes() > 0 && present_key != nullptr && present_value != nullptr && present_key->SizeInBytes() > 0 && present_value->SizeInBytes() > 0 && parameters.head_size % 4 == 0) { return ApplyFlashAttention(query, key, value, attention_bias, output, past_key, present_key, past_value, From ee1051f914aa2a15edfeaf336ab46c6bb4d37395 Mon Sep 17 00:00:00 2001 From: Sushanth Rajasankar Date: Fri, 22 Nov 2024 22:22:09 -0800 Subject: [PATCH 13/25] Switch back to safer algorithm that is optimized for prefill. --- .../webgpu/bert/multihead_attention.cc | 20 ++++++------------- 1 file changed, 6 insertions(+), 14 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc index dddeb0fc80ba0..1f8d725992091 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc @@ -598,10 +598,7 @@ fn computeDotProduct(q_idx: u32, k_idx: u32, sg_id: u32, sg_size : u32) { var result = q_value_t(0); let sg_idx = idx+sg_id; - // QKV_HEAD_VECTORIZED_SIZE is divisible by the subgroup size this if check is not - // required. Hopefully the compiler sees the first half of this if statement and - // removes this if instruction. 
- if (QKV_HEAD_VECTORIZED_SIZE % sg_size == 0 || sg_idx < QKV_HEAD_VECTORIZED_SIZE) + if (sg_idx < QKV_HEAD_VECTORIZED_SIZE) { result = q_tile[q_idx][sg_idx]*k_tile[k_idx][sg_idx]; } @@ -690,6 +687,7 @@ if (sg_id == 0) } for(var k_start = 0u; k_start < uniforms.present_sequence_length; k_start+=TILE_SIZE) { + workgroupBarrier(); let k_idx_global = k_start+wave_id; let k_idx_global_using_wave_valid = k_idx_global < uniforms.present_sequence_length; if (k_idx_global_using_wave_valid) { @@ -703,20 +701,14 @@ for(var k_start = 0u; k_start < uniforms.present_sequence_length; k_start+=TILE_ loadAttentionBias(wave_id, q_idx_global, sg_id, k_start+sg_id, head_idx); } workgroupBarrier(); - if (k_idx_global_using_wave_valid) + + if (q_idx_global_using_wave_valid) { - for (var q_idx = 0u; q_idx < TILE_SIZE && q_idx_start + q_idx < uniforms.new_sequence_length; q_idx++) + for (var k_idx = 0u; k_idx < TILE_SIZE && k_idx + k_start < uniforms.present_sequence_length; k_idx++) { - // Leveraging the subgroups for parallelism, compute dot product of QK. - // Because for the case of new_seq 1, there is a single query and context length of K - // we iterate over q and use the waves for K so that this step can use all the waves in - // in the workgroup. - // We validate q_idx,wave_id to be less than TILE_SIZE, computeDotProduct only needs to - // validate sg_id as being less than QKV_HEAD_VECTORIZED_SIZE. - computeDotProduct(q_idx, wave_id, sg_id, sg_size); + computeDotProduct(wave_id, k_idx, sg_id, sg_size); } } - workgroupBarrier(); let wave_lane_valid:bool = q_idx_global_using_wave_valid && sg_id < TILE_SIZE && sg_id + k_start < uniforms.present_sequence_length; computeSoftMax(wave_id, sg_id, wave_lane_valid); From d1d81753bdfd080fd7bac7969c01b77f7bb77b55 Mon Sep 17 00:00:00 2001 From: Sushanth Rajasankar Date: Sat, 23 Nov 2024 02:31:10 -0800 Subject: [PATCH 14/25] Refactor into separate file --- .../webgpu/bert/flash_attention.cc | 372 +++++++++++++++++ .../contrib_ops/webgpu/bert/flash_attention.h | 69 ++++ .../webgpu/bert/multihead_attention.cc | 376 +----------------- .../webgpu/bert/multihead_attention.h | 52 +-- 4 files changed, 458 insertions(+), 411 deletions(-) create mode 100644 onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc create mode 100644 onnxruntime/contrib_ops/webgpu/bert/flash_attention.h diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc new file mode 100644 index 0000000000000..885f9e4d2f198 --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc @@ -0,0 +1,372 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include +#include +#include + +#include "contrib_ops/cpu/bert/multihead_attention_helper.h" +#include "contrib_ops/webgpu/bert/flash_attention.h" +#include "contrib_ops/webgpu/webgpu_contrib_kernels.h" + +#include "core/providers/webgpu/webgpu_supported_types.h" + +using namespace onnxruntime::webgpu; +using namespace ::onnxruntime::common; +using namespace ONNX_NAMESPACE; +using namespace onnxruntime::contrib::multihead_attention_helper; + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +Status CopyKVCacheProgram::GenerateShaderCode(ShaderHelper& shader) const { + // Expectations are + // qkv have same number of heads and hidden dimension (head size). + // qkv are in BSNH format. + // B - batch size but shader only supports batch_size 1. 
+ // S - current sequence length but shader supports only S = 1. + // N - number of heads. + // H - head size or hidden dimension for each qkv head. + // KV cache is stored as BN(total_sequence_length)H + // Attention bias is in BN(total_sequence_length) + shader.AddInput("key", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + shader.AddInput("value", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + if (has_past_) { + shader.AddInput("past_key", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + shader.AddInput("past_value", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + } + shader.AddOutput("present_key", ShaderUsage::UseUniform); + shader.AddOutput("present_value", ShaderUsage::UseUniform); + + shader.MainFunctionBody() << "let headIdx = workgroup_id.z;\n" + << "let kIdx = workgroup_id.x;\n" + << "let presentKeyOffset = headIdx * num_workgroups.x * uniforms.vectorized_head_size + (kIdx)*uniforms.vectorized_head_size;\n"; + if (has_past_) { + shader.MainFunctionBody() << "if (kIdx < uniforms.past_sequence_length) {\n" + << " let pastKeyOffset = headIdx * uniforms.past_sequence_length * uniforms.vectorized_head_size + (kIdx)*uniforms.vectorized_head_size;\n" + << " for (var w: u32 = 0u; w < uniforms.vectorized_head_size; w ++) {\n" + << " present_key[presentKeyOffset+w] = past_key[pastKeyOffset+w];\n" + << " present_value[presentKeyOffset+w] = past_value[pastKeyOffset+w];\n" + << " }\n" + << "}\n" + << "else if (kIdx >= uniforms.past_sequence_length) {\n"; + } else { + shader.MainFunctionBody() << "if (kIdx >= uniforms.past_sequence_length) {\n"; + } + shader.MainFunctionBody() << " let nkIdx = kIdx - uniforms.past_sequence_length;\n" + << " // Assumes kv have BSNH layout. num_workgroups.z is the num_head as per the dispatch requirement.\n" + << " let nOffset = nkIdx * uniforms.vectorized_head_size * num_workgroups.z + headIdx*uniforms.vectorized_head_size;\n" + << " // Assumes kv have BNSH layout.\n" + << " // let nOffset = headIdx * uniforms.kv_sequence_length * uniforms.vectorized_head_size + nkIdx * uniforms.vectorized_head_size;\n" + << " for (var w: u32 = 0u; w < uniforms.vectorized_head_size; w ++) {\n" + << " present_key[presentKeyOffset+w] = key[nOffset+w];\n" + << " present_value[presentKeyOffset+w] = value[nOffset+w];\n" + << " }\n" + << "}\n"; + + return Status::OK(); +} + +Status CopyKVCache(onnxruntime::webgpu::ComputeContext& context, AttentionParameters& parameters, + const Tensor* K, const Tensor* past_key, Tensor* present_key, + const Tensor* V, const Tensor* past_value, Tensor* present_value, + int past_sequence_length, int total_sequence_length) { + + const int components = parameters.head_size % 4 == 0 ? 4 : (parameters.head_size % 2 == 0 ? 
2 : 1);
+  bool has_past = (past_sequence_length != 0);
+  CopyKVCacheProgram program{"CopyKVCache", components, has_past};
+  if (has_past) {
+    program.AddInputs({{K, ProgramTensorMetadataDependency::TypeAndRank, components},
+                       {V, ProgramTensorMetadataDependency::TypeAndRank, components},
+                       {past_key, ProgramTensorMetadataDependency::TypeAndRank, components},
+                       {past_value, ProgramTensorMetadataDependency::TypeAndRank, components}});
+  } else {
+    program.AddInputs({{K, ProgramTensorMetadataDependency::TypeAndRank, components},
+                       {V, ProgramTensorMetadataDependency::TypeAndRank, components}});
+  }
+
+  program.AddOutputs({{present_key, ProgramTensorMetadataDependency::Rank, components},
+                      {present_value, ProgramTensorMetadataDependency::Rank, components}});
+
+  program.SetDispatchGroupSize(total_sequence_length, 1, parameters.num_heads)
+      .SetWorkgroupSize(1)
+      .CacheHint(std::to_string(components) + std::to_string(has_past))
+      .AddUniformVariables({{static_cast<uint32_t>(past_sequence_length)},
+                            {static_cast<uint32_t>(parameters.kv_sequence_length)},
+                            {static_cast<uint32_t>(parameters.head_size / components)}});
+
+  return context.RunProgram(program);
+}
+
+Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const {
+  // Expectations are
+  // qkv have the same number of heads and hidden dimension (head size).
+  // qkv are in BSNH format.
+  // B - batch size, but the shader only supports batch_size 1.
+  // S - current sequence length, but the shader supports only S = 1.
+  // N - number of heads.
+  // H - head size or hidden dimension for each qkv head.
+  // KV cache is stored as BN(total_sequence_length)H
+  // Attention bias is in BN(new_sequence_length)(total_sequence_length)
+  //
+  // Expectation is that present_key and present_value contain past key and values, since
+  // we are out of storage buffers a shader can have and both past/present can't be passed.
+  // The hidden size of each q head should be a multiple of 4 because the shader uses vectorized loads.
+  constexpr int vectorization_size = 4;
+  shader.AddInput("q", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
+  shader.AddInput("present_key", ShaderUsage::UseUniform);
+  shader.AddInput("present_value", ShaderUsage::UseUniform);
+  if (has_attention_bias_) {
+    shader.AddInput("attention_bias", ShaderUsage::UseUniform);
+  }
+  shader.AddOutput("output", ShaderUsage::UseUniform);
+
+  // SUBGROUP_SIZE has to be the same as sg_size. For Intel this will be 8.
+  // TILE_SIZE is the number of groups sharing the k_tile.
+  // TILE_SIZE has to be <= SUBGROUP_SIZE. Ideal perf of computeSoftMax is when
+  // TILE_SIZE == SUBGROUP_SIZE. This is a separate constant from SUBGROUP_SIZE
+  // because SUBGROUP_SIZE * TILE_SIZE has to be <= 256 as per WebGPU
+  // limits. For Intel this TILE_SIZE will be 8.
+  shader.AdditionalImplementation() << "const SUBGROUP_SIZE: u32 = " << subgroup_size_ << ";\n"
+                                    << "const TILE_SIZE: u32 = " << tile_size_ << ";\n"
+                                    << "const VECTOR_SIZE: u32 = " << vectorization_size << ";\n"
+                                    << "const QKV_HEAD_SIZE: u32 = " << qkv_head_size_ << ";\n"
+                                    << "const QKV_HEAD_VECTORIZED_SIZE: u32 = QKV_HEAD_SIZE / VECTOR_SIZE;\n"
+                                    << "const NUM_HEADS: u32 = " << qkv_num_heads_ << ";\n"
+                                    << "const MIN_VALUE : q_element_t = -6504.0h;\n";
+
+  // Best to keep SHM usage per workgroup < 8KB. 4KB is the limit on a 48EU Tigerlake
+  // GPU, after which workgroups will be unscheduled to make space for memory.
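+  // As a rough budget, the four f16 data tiles below together take
+  // 4 * TILE_SIZE * QKV_HEAD_SIZE * 2 bytes (1.5KB each at TILE_SIZE 8 and head
+  // size 96); qk_tile and the per-row arrays add only a few hundred bytes more.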
+  shader.AdditionalImplementation() << "var<workgroup> q_tile : array<array<q_value_t, QKV_HEAD_VECTORIZED_SIZE>, TILE_SIZE>; // 96 * 8 * 2 = 1.5KB.\n"
+                                    << "var<workgroup> k_tile : array<array<q_value_t, QKV_HEAD_VECTORIZED_SIZE>, TILE_SIZE>; // 96 * 8 * 2 = 1.5KB.\n"
+                                    << "var<workgroup> v_tile : array<array<q_value_t, QKV_HEAD_VECTORIZED_SIZE>, TILE_SIZE>; // 96 * 8 * 2 = 1.5KB.\n"
+                                    << "var<workgroup> o_tile : array<array<q_value_t, QKV_HEAD_VECTORIZED_SIZE>, TILE_SIZE>; // 96 * 8 * 2 = 1.5KB.\n"
+                                    << "var<workgroup> qk_tile : array<array<q_element_t, TILE_SIZE>, TILE_SIZE>; // 8 * 2 * 8 = 128\n"
+                                    << "var<workgroup> max_tile : array<q_element_t, TILE_SIZE>; // 2 * 8 = 16\n"
+                                    << "var<workgroup> denom_tile : array<q_element_t, TILE_SIZE>; // 2 * 8 = 16\n"
+                                    << "var<workgroup> o_ratio : array<q_element_t, TILE_SIZE>; // 2 * 8 = 16\n";
+
+  shader.AdditionalImplementation() << R"HELPER_FN(
+fn loadq(slot: u32, q_idx_global : u32, head_idx: u32, sg_id : u32, sg_size : u32)
+{
+  // Stored as float16[batch_size,sequence_length,3072] the inputs as per onnx MHA
+  // This is the layout if TransferBSDToBNSH has not been run.
+  let offset = q_idx_global * (QKV_HEAD_VECTORIZED_SIZE) * NUM_HEADS + QKV_HEAD_VECTORIZED_SIZE * head_idx;
+  // Stored as BNSH - which is what webgpu uses after TransferBSDToBNSH has been run.
+  // let offset = head_idx * uniforms.new_sequence_length * QKV_HEAD_VECTORIZED_SIZE + q_idx_global * QKV_HEAD_VECTORIZED_SIZE;
+  for (var idx:u32 = sg_id; idx < QKV_HEAD_VECTORIZED_SIZE; idx+= sg_size)
+  {
+    var value = q[idx+offset];
+    q_tile[slot][idx] = value;
+  }
+}
+fn loadk(slot: u32, k_idx_global : u32, head_idx: u32, sg_id: u32, sg_size: u32)
+{
+  // Stored as float16[batch_size,num_heads,present_sequence_length,96]
+  let offset = head_idx * uniforms.present_sequence_length * QKV_HEAD_VECTORIZED_SIZE + k_idx_global * QKV_HEAD_VECTORIZED_SIZE;
+  for (var idx:u32 = sg_id; idx < QKV_HEAD_VECTORIZED_SIZE; idx+=sg_size)
+  {
+    var value = present_key[idx+offset];
+    k_tile[slot][idx] = value;
+  }
+}
+fn loadv(slot: u32, v_idx_global : u32, head_idx: u32, sg_id: u32, sg_size: u32)
+{
+  // Stored as float16[batch_size,num_heads,present_sequence_length,96]
+  let offset = head_idx * uniforms.present_sequence_length * QKV_HEAD_VECTORIZED_SIZE + v_idx_global * QKV_HEAD_VECTORIZED_SIZE;
+  for (var idx:u32 = sg_id; idx < QKV_HEAD_VECTORIZED_SIZE; idx+=sg_size)
+  {
+    v_tile[slot][idx] = present_value[idx+offset];
+  }
+}
+fn loadAttentionBias(q_row: u32, q_idx_global : u32, k_col: u32, k_idx_global : u32, head_idx: u32)
+{
+  // Stored as float16[batch_size,num_heads,new_seq_length,total_sequence_length]
+  if (q_idx_global >= uniforms.new_sequence_length || k_idx_global >= uniforms.present_sequence_length || k_col >= TILE_SIZE) {
+    qk_tile[q_row][k_col] = 0.0;
+    return;
+  }
+  let offset = head_idx * uniforms.new_sequence_length * uniforms.present_sequence_length + q_idx_global * uniforms.present_sequence_length + k_idx_global;
+  qk_tile[q_row][k_col] = attention_bias[offset];
+}
+fn writeo(slot: u32, o_idx_global : u32, head_idx: u32, sg_id : u32, sg_size : u32)
+{
+  // Stored as float16[batch_size,sequence_length,3072]
+  let offset = o_idx_global * NUM_HEADS * QKV_HEAD_VECTORIZED_SIZE + head_idx * QKV_HEAD_VECTORIZED_SIZE;
+  for (var idx:u32 = sg_id; idx < QKV_HEAD_VECTORIZED_SIZE; idx += sg_size)
+  {
+    let value = o_tile[slot][idx];
+    output[offset+idx] = value;
+  }
+}
+fn computeDotProduct(q_idx: u32, k_idx: u32, sg_id: u32, sg_size : u32)
+{
+  var sum : vec4<q_element_t> = q_value_t(0, 0, 0, 0);
+  // idx is not initialized to sg_id to ensure uniformity because the loop uses
+  // subgroupAdd and unused lanes need to be initialized with 0 for correctness.
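+  // Starting every lane at 0 keeps the iteration count identical across the
+  // subgroup, so subgroupAdd stays in uniform control flow; lanes whose sg_idx
+  // runs past the head simply contribute a zero term.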
+ for (var idx:u32 = 0; idx < QKV_HEAD_VECTORIZED_SIZE; idx+= sg_size) + { + var result = q_value_t(0); + let sg_idx = idx+sg_id; + if (sg_idx < QKV_HEAD_VECTORIZED_SIZE) + { + result = q_tile[q_idx][sg_idx]*k_tile[k_idx][sg_idx]; + } + sum += subgroupAdd(result); + } + if (sg_id == 0) + { + let single_sum : q_element_t = sum.x + sum.y + sum.z + sum.w; + let sqrt_dk = q_element_t(uniforms.alpha); + let value = single_sum * sqrt_dk; + qk_tile[q_idx][k_idx] += value; + } +} +fn computeSoftMax(q_idx: u32, sg_id:u32, enabled:bool) +{ + var x = MIN_VALUE; + if (enabled){ + x = qk_tile[q_idx][sg_id]; + } + var max_value = subgroupMax(x); + max_value = max(max_tile[q_idx], max_value); + let sub = x - max_value; + var value:q_element_t = 0; + if (enabled) { + value = exp(sub); + } + let sum = subgroupAdd(value); + // Compute lhs term of update di prime and the compute di prime. + let dleft = denom_tile[q_idx] * exp(max_tile[q_idx]-max_value); + var d = dleft + sum; + if (d == 0) + { + // Avoid division by zero by setting d to a really small value. + d = 0.0000001h; + } + qk_tile[q_idx][sg_id] = value / d; + if (sg_id == 0) + { + max_tile[q_idx] = max_value; + denom_tile[q_idx] = d; + o_ratio[q_idx] = dleft / d; + } +} +fn computeO(q_idx: u32, sg_id:u32, enabled:bool) +{ + var attn = q_element_t(0); + if (enabled) + { + attn = qk_tile[q_idx][sg_id]; + } + for (var i:u32 = 0; i < QKV_HEAD_VECTORIZED_SIZE; i++) + { + let val = v_tile[sg_id][i]; + var intermediate = attn * val; + let sum = subgroupAdd(intermediate); + if (sg_id == 0) + { + let o_ratio = o_ratio[q_idx]; + let old_o = o_tile[q_idx][i]; + let new_o = ( o_ratio * old_o) + sum; + o_tile[q_idx][i] = new_o; + } + } +} +)HELPER_FN"; + +// Shader is designed to be dispatched as Dispatch(num_heads, new_sequence_length / TILE_SIZE, 1) +// Each workgroup is responsible for a range of q values (TILE_SIZE) and visits all Ks for those q's. + shader.MainFunctionBody() << R"MAIN_FN( +let head_idx = workgroup_id.x; +// It is always the case that 0 <= wave_id < TILE_SIZE +// Each wave has sg_size lanes (subgroup threads). +let wave_id = u32(local_idx / sg_size); + +let q_idx_start = workgroup_id.y * TILE_SIZE; +let q_idx_global = q_idx_start + wave_id; +let q_idx_global_using_wave_valid = q_idx_global < uniforms.new_sequence_length; +if (q_idx_global_using_wave_valid) +{ + // Each invocation (wave_id) gets lane threads (subgroup threads) and is responsible for 1 query. + loadq(wave_id, q_idx_global, head_idx, sg_id, sg_size); +} +if (sg_id == 0) +{ + max_tile[wave_id] = MIN_VALUE; +} +for(var k_start = 0u; k_start < uniforms.present_sequence_length; k_start+=TILE_SIZE) +{ + workgroupBarrier(); + let k_idx_global = k_start+wave_id; + let k_idx_global_using_wave_valid = k_idx_global < uniforms.present_sequence_length; + if (k_idx_global_using_wave_valid) { + // Leveraging the subgroup lanes for parallelism, load into slot wave_id + // K/V values from k_start+wave_id. + loadk(wave_id, k_idx_global, head_idx, sg_id, sg_size); + loadv(wave_id, k_idx_global, head_idx, sg_id, sg_size); + // Next, we want for every q row (wave_id) to populate bias for new sequence length + // (k_start+sg_id). loadAttentionBias handles range checking q_idx_global, + // and sg_id, (k_start+sg_id). 
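+    // Each lane fills column sg_id of qk_tile row wave_id, so one call per wave
+    // covers the full TILE_SIZE-wide bias row for this k tile.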
+ loadAttentionBias(wave_id, q_idx_global, sg_id, k_start+sg_id, head_idx); + } + workgroupBarrier(); + + if (q_idx_global_using_wave_valid) + { + for (var k_idx = 0u; k_idx < TILE_SIZE && k_idx + k_start < uniforms.present_sequence_length; k_idx++) + { + computeDotProduct(wave_id, k_idx, sg_id, sg_size); + } + } + + let wave_lane_valid:bool = q_idx_global_using_wave_valid && sg_id < TILE_SIZE && sg_id + k_start < uniforms.present_sequence_length; + computeSoftMax(wave_id, sg_id, wave_lane_valid); + computeO(wave_id, sg_id, wave_lane_valid); +} +workgroupBarrier(); +if (q_idx_global_using_wave_valid) +{ + writeo(wave_id, q_idx_global, head_idx, sg_id, sg_size); +} +)MAIN_FN"; + + return Status::OK(); +} + +Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias, + Tensor* output, const Tensor* past_key, Tensor* present_key, const Tensor* past_value, Tensor* present_value, + AttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context) { + ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, parameters.past_sequence_length, parameters.total_sequence_length)); + + constexpr int subgroup_size = 16; + constexpr int tile_size = 16; + bool has_attention_bias = attention_bias != nullptr; + FlashAttentionProgram program{"FlashAttention", has_attention_bias, subgroup_size, tile_size, parameters.head_size, parameters.num_heads}; + program.AddInputs({{Q, ProgramTensorMetadataDependency::TypeAndRank, 4}, + {present_key, ProgramTensorMetadataDependency::TypeAndRank, 4}, + {present_value, ProgramTensorMetadataDependency::TypeAndRank, 4}, + {attention_bias, ProgramTensorMetadataDependency::TypeAndRank}}); + program.AddOutputs({{output, ProgramTensorMetadataDependency::TypeAndRank, 4}}); + const float alpha = parameters.scale == 0.0f ? 1.f / sqrt(static_cast(parameters.head_size)) + : parameters.scale; + std::string cache_hint = std::to_string(has_attention_bias) + + std::to_string(subgroup_size) + + std::to_string(tile_size) + + std::to_string(parameters.head_size) + + std::to_string(parameters.num_heads); + program.SetDispatchGroupSize(parameters.num_heads, (parameters.sequence_length + tile_size - 1) / tile_size, 1) + .SetWorkgroupSize(subgroup_size*subgroup_size) + .CacheHint(cache_hint) + .AddUniformVariables({{static_cast(parameters.sequence_length)}, + {static_cast(parameters.total_sequence_length)}, + {alpha}}); + + return context.RunProgram(program); +} + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h new file mode 100644 index 0000000000000..95cbd94cf503b --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h @@ -0,0 +1,69 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once + +#include "core/providers/webgpu/compute_context.h" +#include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +using namespace onnxruntime::webgpu; + +class CopyKVCacheProgram final : public Program { + public: + CopyKVCacheProgram(const std::string& kernel_name, int components, bool has_past) + : Program{kernel_name}, components_(components), has_past_(has_past) { + } + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"past_sequence_length", ProgramUniformVariableDataType::Uint32}, + {"kv_sequence_length", ProgramUniformVariableDataType::Uint32}, + {"vectorized_head_size", ProgramUniformVariableDataType::Uint32}); + + private: + int components_; + bool has_past_; +}; + +class FlashAttentionProgram final : public Program { + public: + FlashAttentionProgram(const std::string& kernel_name, + bool has_attention_bias, + int subgroup_size, + int tile_size, + int qkv_head_size, + int qkv_num_heads) + : Program{kernel_name}, + has_attention_bias_(has_attention_bias), + subgroup_size_(subgroup_size), + tile_size_(tile_size), + qkv_head_size_(qkv_head_size), + qkv_num_heads_(qkv_num_heads) { + } + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"new_sequence_length", ProgramUniformVariableDataType::Uint32}, + {"present_sequence_length", ProgramUniformVariableDataType::Uint32}, + {"alpha", ProgramUniformVariableDataType::Float32}); + + private: + bool has_attention_bias_; + int subgroup_size_; + int tile_size_; + int qkv_head_size_; + int qkv_num_heads_; +}; + +Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias, + Tensor* output, const Tensor* past_key, Tensor* present_key, const Tensor* past_value, Tensor* present_value, + AttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context); +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc index 1f8d725992091..1264ebb05dfdb 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc @@ -7,6 +7,7 @@ #include #include "contrib_ops/cpu/bert/multihead_attention_helper.h" +#include "contrib_ops/webgpu/bert/flash_attention.h" #include "contrib_ops/webgpu/bert/multihead_attention.h" #include "contrib_ops/webgpu/webgpu_contrib_kernels.h" @@ -94,7 +95,7 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const { } shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); - if (has_present_key_ && !fa_variant_) { + if (has_present_key_) { shader.AddOutput("present_key", ShaderUsage::UseUniform); } @@ -137,7 +138,7 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const { shader.MainFunctionBody() << " tileK[idx] = key[kOffset + local_id.y * uniforms.K + w + local_id.x];\n"; } - if (has_present_key_ && !fa_variant_) { + if (has_present_key_) { shader.MainFunctionBody() << " present_key[presentKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x] = tileK[idx];\n"; } @@ -166,7 +167,7 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const { Status 
ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int output_count, const Tensor* Q, const Tensor* K, const Tensor* past_key, const Tensor* attention_bias, Tensor* probs, Tensor* present_key, - AttentionParameters& parameters, int past_sequence_length, int total_sequence_length, bool fa_variant = false) { + AttentionParameters& parameters, int past_sequence_length, int total_sequence_length) { const float alpha = parameters.scale == 0.0f ? 1.f / sqrt(static_cast(parameters.head_size)) : parameters.scale; @@ -177,7 +178,7 @@ Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int o const int components = parameters.head_size % 4 == 0 ? 4 : (parameters.head_size % 2 == 0 ? 2 : 1); AttentionProbsProgram program{"AttentionProbs", feed_past_key, has_present_key, has_attention_bias, tile_size, - components, fa_variant}; + components}; program.AddInputs({{Q, ProgramTensorMetadataDependency::TypeAndRank, components}, {K, ProgramTensorMetadataDependency::TypeAndRank, components}}); if (feed_past_key) { @@ -187,7 +188,7 @@ Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int o program.AddInput({attention_bias, ProgramTensorMetadataDependency::TypeAndRank}); } program.AddOutputs({{probs, ProgramTensorMetadataDependency::Rank}}); - if (has_present_key && !fa_variant) { + if (has_present_key) { program.AddOutput({present_key, ProgramTensorMetadataDependency::Rank, components}); } @@ -196,7 +197,7 @@ Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int o (parameters.sequence_length + tile_size - 1) / tile_size, parameters.batch_size * parameters.num_heads) .SetWorkgroupSize(tile_size, tile_size) - .CacheHint(std::to_string(tile_size)+std::to_string(fa_variant)) + .CacheHint(std::to_string(tile_size)) .AddUniformVariables({{static_cast(parameters.sequence_length)}, {static_cast(vectorized_head_size)}, {static_cast(total_sequence_length)}, @@ -279,7 +280,7 @@ Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const { } shader.AddOutput("output", ShaderUsage::UseUniform); - if (has_present_value_ && !fa_variant_) { + if (has_present_value_) { shader.AddOutput("present_value", ShaderUsage::UseUniform); } @@ -320,7 +321,7 @@ Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const { shader.MainFunctionBody() << " tileK[idx] = v[offsetB + (w + local_id.y) * uniforms.N];\n"; } - if (has_present_value_ && !fa_variant_) { + if (has_present_value_) { shader.MainFunctionBody() << " present_value[presentValueOffset + (w + local_id.y) * uniforms.N] = tileK[idx];\n"; } @@ -352,20 +353,19 @@ Status ComputeVxAttentionScore(onnxruntime::webgpu::ComputeContext& context, int Tensor* present_value, AttentionParameters& parameters, int past_sequence_length, - int total_sequence_length, - bool fa_variant) { + int total_sequence_length) { const bool feed_past_value = present_value != nullptr && past_value != nullptr && past_value->SizeInBytes() > 0; const bool has_present_value = output_count > 1 && past_value != nullptr; constexpr int tile_size = 12; - VxAttentionScoreProgram program{"VxAttentionScore", feed_past_value, has_present_value, tile_size, fa_variant}; + VxAttentionScoreProgram program{"VxAttentionScore", feed_past_value, has_present_value, tile_size}; program.AddInputs({{probs, ProgramTensorMetadataDependency::TypeAndRank}, {V, ProgramTensorMetadataDependency::TypeAndRank}}); if (feed_past_value) { program.AddInput({past_value, ProgramTensorMetadataDependency::TypeAndRank}); } 
program.AddOutputs({{output, ProgramTensorMetadataDependency::TypeAndRank}}); - if (has_present_value && !fa_variant) { + if (has_present_value) { program.AddOutput({present_value, ProgramTensorMetadataDependency::TypeAndRank}); } @@ -388,7 +388,7 @@ Status ComputeVxAttentionScore(onnxruntime::webgpu::ComputeContext& context, int Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias, const Tensor* past_key, const Tensor* past_value, Tensor* output, Tensor* present_key, Tensor* present_value, - AttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, bool fa_variant = false) { + AttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context) { const int output_count = std::min({context.OutputCount(), 1 + (past_key != nullptr ? 1 : 0) + (past_value != nullptr ? 1 : 0)}); const int past_sequence_length = output_count > 1 ? parameters.past_sequence_length : 0; const int total_sequence_length = past_sequence_length + parameters.kv_sequence_length; @@ -398,363 +398,17 @@ Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const T const TensorShape probs_shape(probs_dims); Tensor probs = context.CreateGPUTensor(Q->DataType(), probs_shape); ORT_RETURN_IF_ERROR(ComputeAttentionProbs(context, output_count, Q, K, past_key, attention_bias, &probs, present_key, - parameters, past_sequence_length, total_sequence_length, fa_variant)); + parameters, past_sequence_length, total_sequence_length)); ORT_RETURN_IF_ERROR(ComputeInPlaceSoftmax(context, &probs, parameters.batch_size * parameters.num_heads * parameters.sequence_length, total_sequence_length)); ORT_RETURN_IF_ERROR(ComputeVxAttentionScore(context, output_count, &probs, V, past_value, output, present_value, - parameters, past_sequence_length, total_sequence_length, fa_variant)); + parameters, past_sequence_length, total_sequence_length)); return Status::OK(); } -Status CopyKVCacheProgram::GenerateShaderCode(ShaderHelper& shader) const { - // Expectations are - // qkv have same number of heads and hidden dimension (head size). - // qkv are in BSNH format. - // B - batch size but shader only supports batch_size 1. - // S - current sequence length but shader supports only S = 1. - // N - number of heads. - // H - head size or hidden dimension for each qkv head. 
- // KV cache is stored as BN(total_sequence_length)H - // Attention bias is in BN(total_sequence_length) - shader.AddInput("key", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); - shader.AddInput("value", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); - if (has_past_) { - shader.AddInput("past_key", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); - shader.AddInput("past_value", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); - } - shader.AddOutput("present_key", ShaderUsage::UseUniform); - shader.AddOutput("present_value", ShaderUsage::UseUniform); - - shader.MainFunctionBody() << "let headIdx = workgroup_id.z;\n" - << "let kIdx = workgroup_id.x;\n" - << "let presentKeyOffset = headIdx * num_workgroups.x * uniforms.vectorized_head_size + (kIdx)*uniforms.vectorized_head_size;\n"; - if (has_past_) { - shader.MainFunctionBody() << "if (kIdx < uniforms.past_sequence_length) {\n" - << " let pastKeyOffset = headIdx * uniforms.past_sequence_length * uniforms.vectorized_head_size + (kIdx)*uniforms.vectorized_head_size;\n" - << " for (var w: u32 = 0u; w < uniforms.vectorized_head_size; w ++) {\n" - << " present_key[presentKeyOffset+w] = past_key[pastKeyOffset+w];\n" - << " present_value[presentKeyOffset+w] = past_value[pastKeyOffset+w];\n" - << " }\n" - << "}\n" - << "else if (kIdx >= uniforms.past_sequence_length) {\n"; - } else { - shader.MainFunctionBody() << "if (kIdx >= uniforms.past_sequence_length) {\n"; - } - shader.MainFunctionBody() << " let nkIdx = kIdx - uniforms.past_sequence_length;\n" - << " // Assumes kv have BSNH layout. num_workgroups.z is the num_head as per the dispatch requirement.\n" - << " let nOffset = nkIdx * uniforms.vectorized_head_size * num_workgroups.z + headIdx*uniforms.vectorized_head_size;\n" - << " // Assumes kv have BNSH layout.\n" - << " // let nOffset = headIdx * uniforms.kv_sequence_length * uniforms.vectorized_head_size + nkIdx * uniforms.vectorized_head_size;\n" - << " for (var w: u32 = 0u; w < uniforms.vectorized_head_size; w ++) {\n" - << " present_key[presentKeyOffset+w] = key[nOffset+w];\n" - << " present_value[presentKeyOffset+w] = value[nOffset+w];\n" - << " }\n" - << "}\n"; - - return Status::OK(); -} - -Status CopyKVCache(onnxruntime::webgpu::ComputeContext& context, AttentionParameters& parameters, - const Tensor* K, const Tensor* past_key, Tensor* present_key, - const Tensor* V, const Tensor* past_value, Tensor* present_value, - int past_sequence_length, int total_sequence_length) { - - const int components = parameters.head_size % 4 == 0 ? 4 : (parameters.head_size % 2 == 0 ? 
2 : 1); - bool has_past = (past_sequence_length != 0); - CopyKVCacheProgram program{"CopyKVCache", components, has_past}; - if (has_past) { - program.AddInputs({{K, ProgramTensorMetadataDependency::TypeAndRank, components}, - {V, ProgramTensorMetadataDependency::TypeAndRank, components}, - {past_key, ProgramTensorMetadataDependency::TypeAndRank, components}, - {past_value, ProgramTensorMetadataDependency::TypeAndRank, components}}); - } else { - program.AddInputs({{K, ProgramTensorMetadataDependency::TypeAndRank, components}, - {V, ProgramTensorMetadataDependency::TypeAndRank, components}}); - } - - program.AddOutputs({{present_key, ProgramTensorMetadataDependency::Rank, components}, - {present_value, ProgramTensorMetadataDependency::Rank, components}}); - - program.SetDispatchGroupSize(total_sequence_length, 1, parameters.num_heads) - .SetWorkgroupSize(1) - .CacheHint(std::to_string(components) + std::to_string(has_past)) - .AddUniformVariables({{static_cast(past_sequence_length)}, - {static_cast(parameters.kv_sequence_length)}, - {static_cast(parameters.head_size/ components)}}); - - return context.RunProgram(program); -} - -Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const { - // Expectations are - // qkv have same number of heads and hidden dimension (head size). - // qkv are in BSNH format. - // B - batch size but shader only supports batch_size 1. - // S - current sequence length but shader supports only S = 1. - // N - number of heads. - // H - head size or hidden dimension for each qkv head. - // KV cache is stored as BN(total_sequence_length)H - // Attention bias is in BN(new_sequence_length)(total_sequence_length) - // - // Expectation is that present_key, and present_value contain past key and values since - // we are out of storage buffers a shader can have and both past/present cant be passed. - // The hidden size of each q head should be a multiple of 4 because shader uses vectorized loads. - constexpr int vectorization_size = 4; - shader.AddInput("q", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); - shader.AddInput("present_key", ShaderUsage::UseUniform); - shader.AddInput("present_value", ShaderUsage::UseUniform); - if (has_attention_bias_) { - shader.AddInput("attention_bias", ShaderUsage::UseUniform); - } - shader.AddOutput("output", ShaderUsage::UseUniform); - - // SUBGROUP_SIZE has to be the same as sg_size. For intel this will be 8. - // TILE_SIZE is the number of groups sharing the k_tile. - // TILE_SIZE has to be <= SUBGROUP_SIZE. Ideal perf of computeSoftMax is when - // TILE_SIZE == SUBGROUP_SIZE. This is a sperate constant from SUBGROUP_SIZE - // because SUBGROUP_SIZE * TILE_SIZE has to be <= 256 as per webgpu - // gpu limits. For Intel this TILE_SIZE will be 8. - shader.AdditionalImplementation() << "const SUBGROUP_SIZE: u32 = " << subgroup_size_ << ";\n" - << "const TILE_SIZE: u32 = " << tile_size_ << ";\n" - << "const VECTOR_SIZE: u32 = " << vectorization_size << ";\n" - << "const QKV_HEAD_SIZE: u32 = " << qkv_head_size_ << ";\n" - << "const QKV_HEAD_VECTORIZED_SIZE: u32 = QKV_HEAD_SIZE / VECTOR_SIZE;\n" - << "const NUM_HEADS: u32 = " << qkv_num_heads_ << ";\n" - << "const MIN_VALUE : q_element_t = -6504.0h;\n"; - - // Best to keep SHM usage per workgroup < 8KB. 4KB is the limit on a 48EU tigerlake - // GPU afterwhich workgroups will be unscheduled to make space for memory. 
- shader.AdditionalImplementation() << "var q_tile : array, TILE_SIZE>; // 96 * 8 * 2 = 1.5KB.\n" - << "var k_tile : array, TILE_SIZE>; // 96 * 8 * 2 = 1.5KB.\n" - << "var v_tile : array, TILE_SIZE>; // 96 * 8 * 2 = 1.5KB.\n" - << "var o_tile : array, TILE_SIZE>; // 96 * 8 * 2 = 1.5KB.\n" - << "var qk_tile : array, TILE_SIZE>; // 8 * 2 * 8 = 128\n" - << "var max_tile : array; // 2 * 8 = 16\n" - << "var denom_tile : array; // 2 * 8 = 16\n" - << "var o_ratio : array; // 2 * 8 = 16\n"; - - shader.AdditionalImplementation() << R"HELPER_FN( -fn loadq(slot: u32, q_idx_global : u32, head_idx: u32, sg_id : u32, sg_size : u32) -{ - // Stored as float16[batch_size,sequence_length,3072] the inputs as per onnx MHA - // This is the layout if TransferBSDToBNSH has not been run. - let offset = q_idx_global * (QKV_HEAD_VECTORIZED_SIZE) * NUM_HEADS + QKV_HEAD_VECTORIZED_SIZE * head_idx; - // Stored as BNSH - which is what webgpu uses after TransferBSDToBNSH has been run. - // let offset = head_idx * uniforms.new_sequence_length * QKV_HEAD_VECTORIZED_SIZE + q_idx_global * QKV_HEAD_VECTORIZED_SIZE; - for (var idx:u32 = sg_id; idx < QKV_HEAD_VECTORIZED_SIZE; idx+= sg_size) - { - var value = q[idx+offset]; - q_tile[slot][idx] = value; - } -} -fn loadk(slot: u32, k_idx_global : u32, head_idx: u32, sg_id: u32, sg_size: u32) -{ - // Stored as float16[batch_size,num_heads,present_sequence_length,96] - let offset = head_idx * uniforms.present_sequence_length * QKV_HEAD_VECTORIZED_SIZE + k_idx_global * QKV_HEAD_VECTORIZED_SIZE; - for (var idx:u32 = sg_id; idx < QKV_HEAD_VECTORIZED_SIZE; idx+=sg_size) - { - var value = present_key[idx+offset]; - k_tile[slot][idx] = value; - } -} -fn loadv(slot: u32, v_idx_global : u32, head_idx: u32, sg_id: u32, sg_size: u32) -{ - // Stored as float16[batch_size,num_heads,present_sequence_length,96] - let offset = head_idx * uniforms.present_sequence_length * QKV_HEAD_VECTORIZED_SIZE + v_idx_global * QKV_HEAD_VECTORIZED_SIZE; - for (var idx:u32 = sg_id; idx < QKV_HEAD_VECTORIZED_SIZE; idx+=sg_size) - { - v_tile[slot][idx] = present_value[idx+offset]; - } -} -fn loadAttentionBias(q_row: u32, q_idx_global : u32, k_col: u32, k_idx_global : u32, head_idx: u32) -{ - // Stored as float16[batch_size,num_heads,new_seq_length,total_sequence_length] - if (q_idx_global >= uniforms.new_sequence_length || k_idx_global >= uniforms.present_sequence_length || k_col >= TILE_SIZE) { - qk_tile[q_row][k_col] = 0.0; - return; - } - let offset = head_idx * uniforms.new_sequence_length * uniforms.present_sequence_length + q_idx_global * uniforms.present_sequence_length + k_idx_global; - qk_tile[q_row][k_col] = attention_bias[offset]; -} -fn writeo(slot: u32, o_idx_global : u32, head_idx: u32, sg_id : u32, sg_size : u32) -{ - // Stored as float16[batch_size,sequence_length,3072] - let offset = o_idx_global * NUM_HEADS * QKV_HEAD_VECTORIZED_SIZE + head_idx * QKV_HEAD_VECTORIZED_SIZE; - for (var idx:u32 = sg_id; idx < QKV_HEAD_VECTORIZED_SIZE; idx += sg_size) - { - let value = o_tile[slot][idx]; - output[offset+idx] = value; - } -} -fn computeDotProduct(q_idx: u32, k_idx: u32, sg_id: u32, sg_size : u32) -{ - var sum:vec4 = q_value_t(0, 0, 0, 0); - // idx is not initialized to sg_id to ensure uniformity because the loop uses - // subgroupAdd and unused lanes need to be initialized with 0 for correctness. 
- for (var idx:u32 = 0; idx < QKV_HEAD_VECTORIZED_SIZE; idx+= sg_size) - { - var result = q_value_t(0); - let sg_idx = idx+sg_id; - if (sg_idx < QKV_HEAD_VECTORIZED_SIZE) - { - result = q_tile[q_idx][sg_idx]*k_tile[k_idx][sg_idx]; - } - sum += subgroupAdd(result); - } - if (sg_id == 0) - { - let single_sum : q_element_t = sum.x + sum.y + sum.z + sum.w; - let sqrt_dk = q_element_t(uniforms.alpha); - let value = single_sum * sqrt_dk; - qk_tile[q_idx][k_idx] += value; - } -} -fn computeSoftMax(q_idx: u32, sg_id:u32, enabled:bool) -{ - var x = MIN_VALUE; - if (enabled){ - x = qk_tile[q_idx][sg_id]; - } - var max_value = subgroupMax(x); - max_value = max(max_tile[q_idx], max_value); - let sub = x - max_value; - var value:q_element_t = 0; - if (enabled) { - value = exp(sub); - } - let sum = subgroupAdd(value); - // Compute lhs term of update di prime and the compute di prime. - let dleft = denom_tile[q_idx] * exp(max_tile[q_idx]-max_value); - var d = dleft + sum; - if (d == 0) - { - // Avoid division by zero by setting d to a really small value. - d = 0.0000001h; - } - qk_tile[q_idx][sg_id] = value / d; - if (sg_id == 0) - { - max_tile[q_idx] = max_value; - denom_tile[q_idx] = d; - o_ratio[q_idx] = dleft / d; - } -} -fn computeO(q_idx: u32, sg_id:u32, enabled:bool) -{ - var attn = q_element_t(0); - if (enabled) - { - attn = qk_tile[q_idx][sg_id]; - } - for (var i:u32 = 0; i < QKV_HEAD_VECTORIZED_SIZE; i++) - { - let val = v_tile[sg_id][i]; - var intermediate = attn * val; - let sum = subgroupAdd(intermediate); - if (sg_id == 0) - { - let o_ratio = o_ratio[q_idx]; - let old_o = o_tile[q_idx][i]; - let new_o = ( o_ratio * old_o) + sum; - o_tile[q_idx][i] = new_o; - } - } -} -)HELPER_FN"; - -// Shader is designed to be dispatched as Dispatch(num_heads, new_sequence_length / TILE_SIZE, 1) -// Each workgroup is responsible for a range of q values (TILE_SIZE) and visits all Ks for those q's. - shader.MainFunctionBody() << R"MAIN_FN( -let head_idx = workgroup_id.x; -// It is always the case that 0 <= wave_id < TILE_SIZE -// Each wave has sg_size lanes (subgroup threads). -let wave_id = u32(local_idx / sg_size); - -let q_idx_start = workgroup_id.y * TILE_SIZE; -let q_idx_global = q_idx_start + wave_id; -let q_idx_global_using_wave_valid = q_idx_global < uniforms.new_sequence_length; -if (q_idx_global_using_wave_valid) -{ - // Each invocation (wave_id) gets lane threads (subgroup threads) and is responsible for 1 query. - loadq(wave_id, q_idx_global, head_idx, sg_id, sg_size); -} -if (sg_id == 0) -{ - max_tile[wave_id] = MIN_VALUE; -} -for(var k_start = 0u; k_start < uniforms.present_sequence_length; k_start+=TILE_SIZE) -{ - workgroupBarrier(); - let k_idx_global = k_start+wave_id; - let k_idx_global_using_wave_valid = k_idx_global < uniforms.present_sequence_length; - if (k_idx_global_using_wave_valid) { - // Leveraging the subgroup lanes for parallelism, load into slot wave_id - // K/V values from k_start+wave_id. - loadk(wave_id, k_idx_global, head_idx, sg_id, sg_size); - loadv(wave_id, k_idx_global, head_idx, sg_id, sg_size); - // Next, we want for every q row (wave_id) to populate bias for new sequence length - // (k_start+sg_id). loadAttentionBias handles range checking q_idx_global, - // and sg_id, (k_start+sg_id). 
- loadAttentionBias(wave_id, q_idx_global, sg_id, k_start+sg_id, head_idx); - } - workgroupBarrier(); - - if (q_idx_global_using_wave_valid) - { - for (var k_idx = 0u; k_idx < TILE_SIZE && k_idx + k_start < uniforms.present_sequence_length; k_idx++) - { - computeDotProduct(wave_id, k_idx, sg_id, sg_size); - } - } - - let wave_lane_valid:bool = q_idx_global_using_wave_valid && sg_id < TILE_SIZE && sg_id + k_start < uniforms.present_sequence_length; - computeSoftMax(wave_id, sg_id, wave_lane_valid); - computeO(wave_id, sg_id, wave_lane_valid); -} -workgroupBarrier(); -if (q_idx_global_using_wave_valid) -{ - writeo(wave_id, q_idx_global, head_idx, sg_id, sg_size); -} -)MAIN_FN"; - - return Status::OK(); -} - -Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias, - Tensor* output, const Tensor* past_key, Tensor* present_key, const Tensor* past_value, Tensor* present_value, - AttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context) { - ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, parameters.past_sequence_length, parameters.total_sequence_length)); - - constexpr int subgroup_size = 16; - constexpr int tile_size = 16; - bool has_attention_bias = attention_bias != nullptr; - FlashAttentionProgram program{"FlashAttention", has_attention_bias, subgroup_size, tile_size, parameters.head_size, parameters.num_heads}; - program.AddInputs({{Q, ProgramTensorMetadataDependency::TypeAndRank, 4}, - {present_key, ProgramTensorMetadataDependency::TypeAndRank, 4}, - {present_value, ProgramTensorMetadataDependency::TypeAndRank, 4}, - {attention_bias, ProgramTensorMetadataDependency::TypeAndRank}}); - program.AddOutputs({{output, ProgramTensorMetadataDependency::TypeAndRank, 4}}); - const float alpha = parameters.scale == 0.0f ? 
1.f / sqrt(static_cast(parameters.head_size)) - : parameters.scale; - std::string cache_hint = std::to_string(has_attention_bias) + - std::to_string(subgroup_size) + - std::to_string(tile_size) + - std::to_string(parameters.head_size) + - std::to_string(parameters.num_heads); - program.SetDispatchGroupSize(parameters.num_heads, (parameters.sequence_length + tile_size - 1) / tile_size, 1) - .SetWorkgroupSize(subgroup_size*subgroup_size) - .CacheHint(cache_hint) - .AddUniformVariables({{static_cast(parameters.sequence_length)}, - {static_cast(parameters.total_sequence_length)}, - {alpha}}); - - return context.RunProgram(program); -} - MultiHeadAttention::MultiHeadAttention(const OpKernelInfo& info) : WebGpuKernel(info) { int64_t num_heads = 0; diff --git a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.h b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.h index b6d7aa3811672..71cf2b5e9208f 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.h +++ b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.h @@ -33,8 +33,8 @@ class TransferBSDToBNSHProgram final : public Program class AttentionProbsProgram final : public Program { public: AttentionProbsProgram(const std::string& kernel_name, bool feed_past_key, bool has_present_key, - bool has_attention_bias, int tile_size, int components, bool fa_variant = false) - : Program{kernel_name}, feed_past_key_(feed_past_key), has_present_key_(has_present_key), has_attention_bias_(has_attention_bias), tile_size_(tile_size), components_(components), fa_variant_(fa_variant) { + bool has_attention_bias, int tile_size, int components) + : Program{kernel_name}, feed_past_key_(feed_past_key), has_present_key_(has_present_key), has_attention_bias_(has_attention_bias), tile_size_(tile_size), components_(components) { } Status GenerateShaderCode(ShaderHelper& sh) const override; @@ -55,7 +55,6 @@ class AttentionProbsProgram final : public Program { bool has_attention_bias_; int tile_size_; int components_; - bool fa_variant_; }; class InPlaceSoftmaxProgram final : public Program { @@ -99,53 +98,6 @@ class VxAttentionScoreProgram final : public Program { bool fa_variant_; }; -class CopyKVCacheProgram final : public Program { - public: - CopyKVCacheProgram(const std::string& kernel_name, int components, bool has_past) - : Program{kernel_name}, components_(components), has_past_(has_past) { - } - - Status GenerateShaderCode(ShaderHelper& sh) const override; - - WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"past_sequence_length", ProgramUniformVariableDataType::Uint32}, - {"kv_sequence_length", ProgramUniformVariableDataType::Uint32}, - {"vectorized_head_size", ProgramUniformVariableDataType::Uint32}); - - private: - int components_; - bool has_past_; -}; - -class FlashAttentionProgram final : public Program { - public: - FlashAttentionProgram(const std::string& kernel_name, - bool has_attention_bias, - int subgroup_size, - int tile_size, - int qkv_head_size, - int qkv_num_heads) - : Program{kernel_name}, - has_attention_bias_(has_attention_bias), - subgroup_size_(subgroup_size), - tile_size_(tile_size), - qkv_head_size_(qkv_head_size), - qkv_num_heads_(qkv_num_heads) { - } - - Status GenerateShaderCode(ShaderHelper& sh) const override; - - WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"new_sequence_length", ProgramUniformVariableDataType::Uint32}, - {"present_sequence_length", ProgramUniformVariableDataType::Uint32}, - {"alpha", ProgramUniformVariableDataType::Float32}); - - private: - bool has_attention_bias_; - int subgroup_size_; - int 
tile_size_; - int qkv_head_size_; - int qkv_num_heads_; -}; - class MultiHeadAttention final : public WebGpuKernel { public: MultiHeadAttention(const OpKernelInfo& info); From 6febf6c1013366f74535efa74fc3d1ec36ab8162 Mon Sep 17 00:00:00 2001 From: Sushanth Rajasankar Date: Sat, 23 Nov 2024 11:42:56 -0800 Subject: [PATCH 15/25] Get subgroup size from device limits --- onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc | 4 ++-- onnxruntime/contrib_ops/webgpu/bert/flash_attention.h | 8 ++++---- onnxruntime/core/providers/webgpu/compute_context.h | 3 +++ onnxruntime/core/providers/webgpu/webgpu_context.cc | 3 +++ onnxruntime/core/providers/webgpu/webgpu_context.h | 2 ++ 5 files changed, 14 insertions(+), 6 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc index 885f9e4d2f198..f219beb1eb1ae 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc @@ -341,8 +341,8 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co AttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context) { ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, parameters.past_sequence_length, parameters.total_sequence_length)); - constexpr int subgroup_size = 16; - constexpr int tile_size = 16; + const uint32_t subgroup_size = context.MinSubgroupSize(); + const uint32_t tile_size = subgroup_size; bool has_attention_bias = attention_bias != nullptr; FlashAttentionProgram program{"FlashAttention", has_attention_bias, subgroup_size, tile_size, parameters.head_size, parameters.num_heads}; program.AddInputs({{Q, ProgramTensorMetadataDependency::TypeAndRank, 4}, diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h index 95cbd94cf503b..d9188496821a4 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h @@ -35,8 +35,8 @@ class FlashAttentionProgram final : public Program { public: FlashAttentionProgram(const std::string& kernel_name, bool has_attention_bias, - int subgroup_size, - int tile_size, + uint32_t subgroup_size, + uint32_t tile_size, int qkv_head_size, int qkv_num_heads) : Program{kernel_name}, @@ -55,8 +55,8 @@ class FlashAttentionProgram final : public Program { private: bool has_attention_bias_; - int subgroup_size_; - int tile_size_; + uint32_t subgroup_size_; + uint32_t tile_size_; int qkv_head_size_; int qkv_num_heads_; }; diff --git a/onnxruntime/core/providers/webgpu/compute_context.h b/onnxruntime/core/providers/webgpu/compute_context.h index b7ea8a58e232b..22cdf389bdf7c 100644 --- a/onnxruntime/core/providers/webgpu/compute_context.h +++ b/onnxruntime/core/providers/webgpu/compute_context.h @@ -41,6 +41,9 @@ class ComputeContext { inline const wgpu::Limits& DeviceLimits() const { return webgpu_context_.DeviceLimits(); } + inline const uint32_t MinSubgroupSize() const { + return webgpu_context_.MinSubgroupSize(); + } // // Get the kernel context. 
diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.cc b/onnxruntime/core/providers/webgpu/webgpu_context.cc
index 36aab2e628a16..ef0d4a987e483 100644
--- a/onnxruntime/core/providers/webgpu/webgpu_context.cc
+++ b/onnxruntime/core/providers/webgpu/webgpu_context.cc
@@ -118,8 +118,11 @@ void WebGpuContext::Initialize(const WebGpuExecutionProviderInfo& webgpu_ep_info
   ORT_ENFORCE(Adapter().GetInfo(&adapter_info_));
   // cache device limits
   wgpu::SupportedLimits device_supported_limits;
+  wgpu::DawnExperimentalSubgroupLimits subgroup_limits;
+  device_supported_limits.nextInChain = &subgroup_limits;
   ORT_ENFORCE(Device().GetLimits(&device_supported_limits));
   device_limits_ = device_supported_limits.limits;
+  min_subgroup_size_ = subgroup_limits.minSubgroupSize;
 
   // create buffer manager
   buffer_mgr_ = BufferManagerFactory::Create(*this, webgpu_ep_info.storage_buffer_cache_mode, webgpu_ep_info.uniform_buffer_cache_mode, webgpu_ep_info.query_resolve_buffer_cache_mode);
diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.h b/onnxruntime/core/providers/webgpu/webgpu_context.h
index be05b06523b9c..0b8b526b2c381 100644
--- a/onnxruntime/core/providers/webgpu/webgpu_context.h
+++ b/onnxruntime/core/providers/webgpu/webgpu_context.h
@@ -55,6 +55,7 @@ class WebGpuContext final {
   const wgpu::AdapterInfo& AdapterInfo() const { return adapter_info_; }
   const wgpu::Limits& DeviceLimits() const { return device_limits_; }
+  uint32_t MinSubgroupSize() const { return min_subgroup_size_; }
 
   const wgpu::CommandEncoder& GetCommandEncoder() {
     if (!current_command_encoder_) {
@@ -161,6 +162,7 @@ class WebGpuContext final {
   wgpu::AdapterInfo adapter_info_;
   wgpu::Limits device_limits_;
+  uint32_t min_subgroup_size_ = 0;
 
   wgpu::CommandEncoder current_command_encoder_;
   wgpu::ComputePassEncoder current_compute_pass_encoder_;

From ec59ba5f487d3c5c596b54f02bdac61e2df56bff Mon Sep 17 00:00:00 2001
From: Sushanth Rajasankar
Date: Sat, 23 Nov 2024 15:51:23 -0800
Subject: [PATCH 16/25] Switch to the dot product optimized for 1 token length,
 since we are not creating a separate shader for new seq length == 1 case.

---
 .../contrib_ops/webgpu/bert/flash_attention.cc | 13 ++++++++++---
 1 file changed, 10 insertions(+), 3 deletions(-)

diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc
index f219beb1eb1ae..d48349a546e59 100644
--- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc
+++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc
@@ -314,13 +314,20 @@ for(var k_start = 0u; k_start < uniforms.present_sequence_length; k_start+=TILE_
   }
   workgroupBarrier();
 
-  if (q_idx_global_using_wave_valid)
+  if (k_idx_global_using_wave_valid)
   {
-    for (var k_idx = 0u; k_idx < TILE_SIZE && k_idx + k_start < uniforms.present_sequence_length; k_idx++)
+    for (var q_idx = 0u; q_idx < TILE_SIZE && q_idx_start + q_idx < uniforms.new_sequence_length; q_idx++)
     {
-      computeDotProduct(wave_id, k_idx, sg_id, sg_size);
+      // Leveraging the subgroups for parallelism, compute dot product of QK.
+      // Because for the case of new_seq 1, there is a single query and context length of K,
+      // we iterate over q and use the waves for K so that this step can use all the waves
+      // in the workgroup.
+      // We validate q_idx,wave_id to be less than TILE_SIZE, computeDotProduct only needs to
+      // validate sg_id as being less than QKV_HEAD_VECTORIZED_SIZE.
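+      // In effect, each wave w computes dot(q_tile[q_idx], k_tile[w]) across its
+      // sg_size lanes for every q_idx in the tile, instead of one wave per query.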
+ computeDotProduct(q_idx, wave_id, sg_id, sg_size); } } + workgroupBarrier(); let wave_lane_valid:bool = q_idx_global_using_wave_valid && sg_id < TILE_SIZE && sg_id + k_start < uniforms.present_sequence_length; computeSoftMax(wave_id, sg_id, wave_lane_valid); From 0dd5b706daa61acadbeb0a1f21266b002f08efa4 Mon Sep 17 00:00:00 2001 From: Sushanth Rajasankar Date: Sat, 23 Nov 2024 18:22:19 -0800 Subject: [PATCH 17/25] Add an alias to control math precision of flash attention. --- .../webgpu/bert/flash_attention.cc | 38 ++++++++++--------- 1 file changed, 20 insertions(+), 18 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc index d48349a546e59..faaa821ae3fbe 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc @@ -129,13 +129,15 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const { // TILE_SIZE == SUBGROUP_SIZE. This is a sperate constant from SUBGROUP_SIZE // because SUBGROUP_SIZE * TILE_SIZE has to be <= 256 as per webgpu // gpu limits. For Intel this TILE_SIZE will be 8. + // Change precision_t to be f32 below to run dotproduct/ softmax in fp32 precision. shader.AdditionalImplementation() << "const SUBGROUP_SIZE: u32 = " << subgroup_size_ << ";\n" << "const TILE_SIZE: u32 = " << tile_size_ << ";\n" << "const VECTOR_SIZE: u32 = " << vectorization_size << ";\n" << "const QKV_HEAD_SIZE: u32 = " << qkv_head_size_ << ";\n" << "const QKV_HEAD_VECTORIZED_SIZE: u32 = QKV_HEAD_SIZE / VECTOR_SIZE;\n" << "const NUM_HEADS: u32 = " << qkv_num_heads_ << ";\n" - << "const MIN_VALUE : q_element_t = -6504.0h;\n"; + << "alias precision_t = q_element_t;\n" + << "const MIN_VALUE : precision_t = precision_t(-6504.0h);\n"; // Best to keep SHM usage per workgroup < 8KB. 4KB is the limit on a 48EU tigerlake // GPU afterwhich workgroups will be unscheduled to make space for memory. 
@@ -143,10 +145,10 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const { << "var k_tile : array, TILE_SIZE>; // 96 * 8 * 2 = 1.5KB.\n" << "var v_tile : array, TILE_SIZE>; // 96 * 8 * 2 = 1.5KB.\n" << "var o_tile : array, TILE_SIZE>; // 96 * 8 * 2 = 1.5KB.\n" - << "var qk_tile : array, TILE_SIZE>; // 8 * 2 * 8 = 128\n" - << "var max_tile : array; // 2 * 8 = 16\n" - << "var denom_tile : array; // 2 * 8 = 16\n" - << "var o_ratio : array; // 2 * 8 = 16\n"; + << "var qk_tile : array, TILE_SIZE>; // 8 * 2 * 8 = 128\n" + << "var max_tile : array; // 2 * 8 = 16\n" + << "var denom_tile : array; // 2 * 8 = 16\n" + << "var o_ratio : array; // 2 * 8 = 16\n"; shader.AdditionalImplementation() << R"HELPER_FN( fn loadq(slot: u32, q_idx_global : u32, head_idx: u32, sg_id : u32, sg_size : u32) @@ -189,7 +191,7 @@ fn loadAttentionBias(q_row: u32, q_idx_global : u32, k_col: u32, k_idx_global : return; } let offset = head_idx * uniforms.new_sequence_length * uniforms.present_sequence_length + q_idx_global * uniforms.present_sequence_length + k_idx_global; - qk_tile[q_row][k_col] = attention_bias[offset]; + qk_tile[q_row][k_col] = precision_t(attention_bias[offset]); } fn writeo(slot: u32, o_idx_global : u32, head_idx: u32, sg_id : u32, sg_size : u32) { @@ -203,37 +205,37 @@ fn writeo(slot: u32, o_idx_global : u32, head_idx: u32, sg_id : u32, sg_size : u } fn computeDotProduct(q_idx: u32, k_idx: u32, sg_id: u32, sg_size : u32) { - var sum:vec4 = q_value_t(0, 0, 0, 0); + var sum:vec4 = vec4(0, 0, 0, 0); // idx is not initialized to sg_id to ensure uniformity because the loop uses // subgroupAdd and unused lanes need to be initialized with 0 for correctness. for (var idx:u32 = 0; idx < QKV_HEAD_VECTORIZED_SIZE; idx+= sg_size) { - var result = q_value_t(0); + var result = vec4(0); let sg_idx = idx+sg_id; if (sg_idx < QKV_HEAD_VECTORIZED_SIZE) { - result = q_tile[q_idx][sg_idx]*k_tile[k_idx][sg_idx]; + result = vec4(q_tile[q_idx][sg_idx])*vec4(k_tile[k_idx][sg_idx]); } sum += subgroupAdd(result); } if (sg_id == 0) { - let single_sum : q_element_t = sum.x + sum.y + sum.z + sum.w; - let sqrt_dk = q_element_t(uniforms.alpha); + let single_sum : precision_t = sum.x + sum.y + sum.z + sum.w; + let sqrt_dk = precision_t(uniforms.alpha); let value = single_sum * sqrt_dk; qk_tile[q_idx][k_idx] += value; } } fn computeSoftMax(q_idx: u32, sg_id:u32, enabled:bool) { - var x = MIN_VALUE; + var x : precision_t = MIN_VALUE; if (enabled){ x = qk_tile[q_idx][sg_id]; } var max_value = subgroupMax(x); max_value = max(max_tile[q_idx], max_value); let sub = x - max_value; - var value:q_element_t = 0; + var value:precision_t = 0; if (enabled) { value = exp(sub); } @@ -244,7 +246,7 @@ fn computeSoftMax(q_idx: u32, sg_id:u32, enabled:bool) if (d == 0) { // Avoid division by zero by setting d to a really small value. 
- d = 0.0000001h; + d = precision_t(0.0000001h); } qk_tile[q_idx][sg_id] = value / d; if (sg_id == 0) @@ -256,22 +258,22 @@ fn computeSoftMax(q_idx: u32, sg_id:u32, enabled:bool) } fn computeO(q_idx: u32, sg_id:u32, enabled:bool) { - var attn = q_element_t(0); + var attn = precision_t(0); if (enabled) { attn = qk_tile[q_idx][sg_id]; } for (var i:u32 = 0; i < QKV_HEAD_VECTORIZED_SIZE; i++) { - let val = v_tile[sg_id][i]; + let val = vec4(v_tile[sg_id][i]); var intermediate = attn * val; let sum = subgroupAdd(intermediate); if (sg_id == 0) { let o_ratio = o_ratio[q_idx]; - let old_o = o_tile[q_idx][i]; + let old_o = vec4(o_tile[q_idx][i]); let new_o = ( o_ratio * old_o) + sum; - o_tile[q_idx][i] = new_o; + o_tile[q_idx][i] = q_value_t(new_o); } } } From 8bcb47af438e5f2148b01061955b306df0d2e01d Mon Sep 17 00:00:00 2001 From: Sushanth Rajasankar Date: Sat, 23 Nov 2024 19:17:11 -0800 Subject: [PATCH 18/25] Improve comments. --- onnxruntime/contrib_ops/webgpu/bert/multihead_attention.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.h b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.h index 71cf2b5e9208f..36803e3027b4c 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.h +++ b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.h @@ -73,10 +73,11 @@ class InPlaceSoftmaxProgram final : public Program { int work_group_size_; int components_; }; + class VxAttentionScoreProgram final : public Program { public: - VxAttentionScoreProgram(const std::string& kernel_name, bool feed_past_value, bool has_present_value, int tile_size, bool fa_variant = false) - : Program{kernel_name}, feed_past_value_(feed_past_value), has_present_value_(has_present_value), tile_size_(tile_size), fa_variant_(fa_variant) { + VxAttentionScoreProgram(const std::string& kernel_name, bool feed_past_value, bool has_present_value, int tile_size) + : Program{kernel_name}, feed_past_value_(feed_past_value), has_present_value_(has_present_value), tile_size_(tile_size) { } Status GenerateShaderCode(ShaderHelper& sh) const override; @@ -95,7 +96,6 @@ class VxAttentionScoreProgram final : public Program { bool feed_past_value_; bool has_present_value_; int tile_size_; - bool fa_variant_; }; class MultiHeadAttention final : public WebGpuKernel { From 832c3239c965c989f05e53d4b02a710ff75365e8 Mon Sep 17 00:00:00 2001 From: Sushanth Rajasankar Date: Sat, 23 Nov 2024 19:17:26 -0800 Subject: [PATCH 19/25] Improve comments --- .../webgpu/bert/flash_attention.cc | 74 ++++++++++--------- .../webgpu/bert/multihead_attention.cc | 7 +- 2 files changed, 41 insertions(+), 40 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc index faaa821ae3fbe..1ed7f8ad010fb 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc @@ -1,11 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include -#include -#include -#include - #include "contrib_ops/cpu/bert/multihead_attention_helper.h" #include "contrib_ops/webgpu/bert/flash_attention.h" #include "contrib_ops/webgpu/webgpu_contrib_kernels.h" @@ -22,23 +17,23 @@ namespace contrib { namespace webgpu { Status CopyKVCacheProgram::GenerateShaderCode(ShaderHelper& shader) const { - // Expectations are - // qkv have same number of heads and hidden dimension (head size). - // qkv are in BSNH format. 
- // B - batch size but shader only supports batch_size 1. - // S - current sequence length but shader supports only S = 1. - // N - number of heads. - // H - head size or hidden dimension for each qkv head. - // KV cache is stored as BN(total_sequence_length)H - // Attention bias is in BN(total_sequence_length) - shader.AddInput("key", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); - shader.AddInput("value", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); - if (has_past_) { - shader.AddInput("past_key", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); - shader.AddInput("past_value", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); - } - shader.AddOutput("present_key", ShaderUsage::UseUniform); - shader.AddOutput("present_value", ShaderUsage::UseUniform); + // Expectations are + // qkv have same number of heads and hidden dimension (head size). + // qkv are in BSNH format. + // B - batch size but shader only supports batch_size 1. + // S - current sequence length but shader supports only S = 1. + // N - number of heads. + // H - head size or hidden dimension for each qkv head. + // KV cache is stored as BN(total_sequence_length)H + // Attention bias is in BN(total_sequence_length) + shader.AddInput("key", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + shader.AddInput("value", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + if (has_past_) { + shader.AddInput("past_key", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + shader.AddInput("past_value", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + } + shader.AddOutput("present_key", ShaderUsage::UseUniform); + shader.AddOutput("present_value", ShaderUsage::UseUniform); shader.MainFunctionBody() << "let headIdx = workgroup_id.z;\n" << "let kIdx = workgroup_id.x;\n" @@ -73,7 +68,9 @@ Status CopyKVCache(onnxruntime::webgpu::ComputeContext& context, AttentionParame const Tensor* K, const Tensor* past_key, Tensor* present_key, const Tensor* V, const Tensor* past_value, Tensor* present_value, int past_sequence_length, int total_sequence_length) { - + // CopyKVCache takes past key/value and current key/value and copies them to present key and value. + // This makes it so that FlashAttention only needs to look at present key and value, and saves + // number of input buffers in the shader, which we run out of (<=8) without this optimization. const int components = parameters.head_size % 4 == 0 ? 4 : (parameters.head_size % 2 == 0 ? 2 : 1); bool has_past = (past_sequence_length != 0); CopyKVCacheProgram program{"CopyKVCache", components, has_past}; @@ -141,14 +138,15 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const { // Best to keep SHM usage per workgroup < 8KB. 4KB is the limit on a 48EU tigerlake // GPU afterwhich workgroups will be unscheduled to make space for memory. 
- shader.AdditionalImplementation() << "var q_tile : array, TILE_SIZE>; // 96 * 8 * 2 = 1.5KB.\n" - << "var k_tile : array, TILE_SIZE>; // 96 * 8 * 2 = 1.5KB.\n" - << "var v_tile : array, TILE_SIZE>; // 96 * 8 * 2 = 1.5KB.\n" - << "var o_tile : array, TILE_SIZE>; // 96 * 8 * 2 = 1.5KB.\n" - << "var qk_tile : array, TILE_SIZE>; // 8 * 2 * 8 = 128\n" - << "var max_tile : array; // 2 * 8 = 16\n" - << "var denom_tile : array; // 2 * 8 = 16\n" - << "var o_ratio : array; // 2 * 8 = 16\n"; + shader.AdditionalImplementation() << "" + << "var q_tile : array, TILE_SIZE>; // 96 * 8 * 2 = 1.5KB.\n" + << "var k_tile : array, TILE_SIZE>; // 96 * 8 * 2 = 1.5KB.\n" + << "var v_tile : array, TILE_SIZE>; // 96 * 8 * 2 = 1.5KB.\n" + << "var o_tile : array, TILE_SIZE>; // 96 * 8 * 2 = 1.5KB.\n" + << "var qk_tile : array, TILE_SIZE>; // 8 * 2 * 8 = 128\n" + << "var max_tile : array; // 2 * 8 = 16\n" + << "var denom_tile : array; // 2 * 8 = 16\n" + << "var o_ratio : array; // 2 * 8 = 16\n"; shader.AdditionalImplementation() << R"HELPER_FN( fn loadq(slot: u32, q_idx_global : u32, head_idx: u32, sg_id : u32, sg_size : u32) @@ -246,6 +244,8 @@ fn computeSoftMax(q_idx: u32, sg_id:u32, enabled:bool) if (d == 0) { // Avoid division by zero by setting d to a really small value. + // Removing this protection has had no negative effect on any + // of the prompts tried so far. This is a safety net. d = precision_t(0.0000001h); } qk_tile[q_idx][sg_id] = value / d; @@ -281,6 +281,9 @@ fn computeO(q_idx: u32, sg_id:u32, enabled:bool) // Shader is designed to be dispatched as Dispatch(num_heads, new_sequence_length / TILE_SIZE, 1) // Each workgroup is responsible for a range of q values (TILE_SIZE) and visits all Ks for those q's. +// Each workgroup has TILE_SIZE waves, with each wave having subgroup size number of lanes (threads). +// Synchronization between lanes in a wave is free, with various subgroup* functions, and this shader +// uses that. Synchronization beween waves requires calling workgroupBarrier. shader.MainFunctionBody() << R"MAIN_FN( let head_idx = workgroup_id.x; // It is always the case that 0 <= wave_id < TILE_SIZE @@ -301,6 +304,7 @@ if (sg_id == 0) } for(var k_start = 0u; k_start < uniforms.present_sequence_length; k_start+=TILE_SIZE) { + // Insert barrier before updating shared memory the workgroup shares. workgroupBarrier(); let k_idx_global = k_start+wave_id; let k_idx_global_using_wave_valid = k_idx_global < uniforms.present_sequence_length; @@ -314,21 +318,23 @@ for(var k_start = 0u; k_start < uniforms.present_sequence_length; k_start+=TILE_ // and sg_id, (k_start+sg_id). loadAttentionBias(wave_id, q_idx_global, sg_id, k_start+sg_id, head_idx); } + // Insert barrier before workgroup starts reading the shared memory. workgroupBarrier(); if (k_idx_global_using_wave_valid) { + // Iterate over Q rather than K because for the case of new_seq 1, there is a single query + // and context length of K by iterating over Q using the waves for K, this step can use all + // the waves in the workgroup, instead of leaving them idle. for (var q_idx = 0u; q_idx < TILE_SIZE && q_idx_start + q_idx < uniforms.new_sequence_length; q_idx++) { // Leveraging the subgroups for parallelism, compute dot product of QK. - // Because for the case of new_seq 1, there is a single query and context length of K - // we iterate over q and use the waves for K so that this step can use all the waves in - // in the workgroup. 
// We validate q_idx,wave_id to be less than TILE_SIZE, computeDotProduct only needs to // validate sg_id as being less than QKV_HEAD_VECTORIZED_SIZE. computeDotProduct(q_idx, wave_id, sg_id, sg_size); } } + // Insert barrier before SoftMax reads the dot product values across K. workgroupBarrier(); let wave_lane_valid:bool = q_idx_global_using_wave_valid && sg_id < TILE_SIZE && sg_id + k_start < uniforms.present_sequence_length; diff --git a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc index 1264ebb05dfdb..8d979d5adb83c 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc @@ -1,11 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include -#include -#include -#include - #include "contrib_ops/cpu/bert/multihead_attention_helper.h" #include "contrib_ops/webgpu/bert/flash_attention.h" #include "contrib_ops/webgpu/bert/multihead_attention.h" @@ -178,7 +173,7 @@ Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int o const int components = parameters.head_size % 4 == 0 ? 4 : (parameters.head_size % 2 == 0 ? 2 : 1); AttentionProbsProgram program{"AttentionProbs", feed_past_key, has_present_key, has_attention_bias, tile_size, - components}; + components}; program.AddInputs({{Q, ProgramTensorMetadataDependency::TypeAndRank, components}, {K, ProgramTensorMetadataDependency::TypeAndRank, components}}); if (feed_past_key) { From a814770c6b619e4135acc319c4e3799a67e07719 Mon Sep 17 00:00:00 2001 From: Sushanth Rajasankar Date: Sat, 23 Nov 2024 19:19:28 -0800 Subject: [PATCH 20/25] Fix comment about intel subgroup size. --- onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc index 1ed7f8ad010fb..5fb5e74510d9c 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc @@ -120,12 +120,12 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const { } shader.AddOutput("output", ShaderUsage::UseUniform); - // SUBGROUP_SIZE has to be the same as sg_size. For intel this will be 8. + // SUBGROUP_SIZE has to be the same as sg_size. For intel this will be 16. // TILE_SIZE is the number of groups sharing the k_tile. // TILE_SIZE has to be <= SUBGROUP_SIZE. Ideal perf of computeSoftMax is when // TILE_SIZE == SUBGROUP_SIZE. This is a sperate constant from SUBGROUP_SIZE // because SUBGROUP_SIZE * TILE_SIZE has to be <= 256 as per webgpu - // gpu limits. For Intel this TILE_SIZE will be 8. + // gpu limits. For Intel this TILE_SIZE will be 16. // Change precision_t to be f32 below to run dotproduct/ softmax in fp32 precision. 
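+  // Note (assumption, not measured in this change): switching precision_t to f32 doubles
+  // the qk/max/denom/o_ratio tiles in shared memory and widens the softmax math, trading
+  // memory and bandwidth for accuracy.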
shader.AdditionalImplementation() << "const SUBGROUP_SIZE: u32 = " << subgroup_size_ << ";\n" << "const TILE_SIZE: u32 = " << tile_size_ << ";\n" From ab110096fcd3b6d20a0c5f5f3929fbc32523c190 Mon Sep 17 00:00:00 2001 From: Sushanth Rajasankar Date: Sun, 24 Nov 2024 17:39:44 -0800 Subject: [PATCH 21/25] Fix AttentionBias loading --- .../webgpu/bert/flash_attention.cc | 27 ++++++++++--------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc index 5fb5e74510d9c..1c23c3c49cb9f 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc @@ -136,17 +136,18 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const { << "alias precision_t = q_element_t;\n" << "const MIN_VALUE : precision_t = precision_t(-6504.0h);\n"; - // Best to keep SHM usage per workgroup < 8KB. 4KB is the limit on a 48EU tigerlake + // Best to keep SHM usage per workgroup < 128KB, from intel docs for Intel Iris Xe GPU. + // "The SLM is a 128KB High Bandwidth Memory (HBM) accessible from the EUs in the subslice" // GPU afterwhich workgroups will be unscheduled to make space for memory. shader.AdditionalImplementation() << "" - << "var q_tile : array, TILE_SIZE>; // 96 * 8 * 2 = 1.5KB.\n" - << "var k_tile : array, TILE_SIZE>; // 96 * 8 * 2 = 1.5KB.\n" - << "var v_tile : array, TILE_SIZE>; // 96 * 8 * 2 = 1.5KB.\n" - << "var o_tile : array, TILE_SIZE>; // 96 * 8 * 2 = 1.5KB.\n" - << "var qk_tile : array, TILE_SIZE>; // 8 * 2 * 8 = 128\n" - << "var max_tile : array; // 2 * 8 = 16\n" - << "var denom_tile : array; // 2 * 8 = 16\n" - << "var o_ratio : array; // 2 * 8 = 16\n"; + << "var q_tile : array, TILE_SIZE>; // 96 * 2 * 16 = 3KB.\n" + << "var k_tile : array, TILE_SIZE>; // 96 * 2 * 16 = 3KB.\n" + << "var v_tile : array, TILE_SIZE>; // 96 * 2 * 16 = 3KB.\n" + << "var o_tile : array, TILE_SIZE>; // 96 * 2 * 16 = 3KB.\n" + << "var qk_tile : array, TILE_SIZE>; // 16 * 2 * 16 = 512\n" + << "var max_tile : array; // 2 * 16 = 32\n" + << "var denom_tile : array; // 2 * 16 = 32\n" + << "var o_ratio : array; // 2 * 16 = 32\n"; shader.AdditionalImplementation() << R"HELPER_FN( fn loadq(slot: u32, q_idx_global : u32, head_idx: u32, sg_id : u32, sg_size : u32) @@ -313,11 +314,11 @@ for(var k_start = 0u; k_start < uniforms.present_sequence_length; k_start+=TILE_ // K/V values from k_start+wave_id. loadk(wave_id, k_idx_global, head_idx, sg_id, sg_size); loadv(wave_id, k_idx_global, head_idx, sg_id, sg_size); - // Next, we want for every q row (wave_id) to populate bias for new sequence length - // (k_start+sg_id). loadAttentionBias handles range checking q_idx_global, - // and sg_id, (k_start+sg_id). - loadAttentionBias(wave_id, q_idx_global, sg_id, k_start+sg_id, head_idx); } + // Next, we want for every q row (wave_id) to populate bias for new sequence length + // (k_start+sg_id). loadAttentionBias handles range checking q_idx_global, + // and sg_id, (k_start+sg_id). + loadAttentionBias(wave_id, q_idx_global, sg_id, k_start+sg_id, head_idx); // Insert barrier before workgroup starts reading the shared memory. 
  workgroupBarrier();

From 563e662d01b7616bc5fca41c025a069609ab26ea Mon Sep 17 00:00:00 2001
From: Sushanth Rajasankar
Date: Mon, 25 Nov 2024 16:23:27 -0800
Subject: [PATCH 22/25] Add comment explaining Flash Attention

---
 .../contrib_ops/webgpu/bert/flash_attention.cc | 18 +++++++++++++++++-
 1 file changed, 17 insertions(+), 1 deletion(-)

diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc
index 1c23c3c49cb9f..5d4b29f46f4c1 100644
--- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc
+++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc
@@ -225,6 +225,22 @@ fn computeDotProduct(q_idx: u32, k_idx: u32, sg_id: u32, sg_size : u32)
     qk_tile[q_idx][k_idx] += value;
   }
 }
+//
+// The crux of Flash Attention is here: it allows partial softmax computation,
+// direct update of the output, and merging with previous results.
+// https://courses.cs.washington.edu/courses/cse599m/23sp/notes/flashattn.pdf
+// Where b is the block size of the tile. Xi stores QKtranspose for the ith tile.
+// mi_local is the max of Xi. Note: _ in this notation means what follows is a
+// subscript. max_j=1:b (Xi[j]) is the max of Xi[j] for j=1 to b.
+//
+// for i = 1, #tiles do
+//  Xi = Q[k,:] Kt[:, (i-1) b : i b]
+//  mi_local = max_j=1:b (Xi[j])
+//  Mi = max(M_(i-1), mi_local)
+//  d'_i = d'_(i-1) * e^(M_(i-1)-M_i) + Σ_j=1:b e^(Xi[j]-Mi)
+//  o'_i = o'_(i-1) * d'_(i-1) * e^(M_(i-1)-M_i) / d'_i + Σ_j=1:b (e^(Xi[j]-Mi) / d'_i) V[j + (i - 1)b,:]
+// end
+//
 fn computeSoftMax(q_idx: u32, sg_id:u32, enabled:bool)
 {
   var x : precision_t = MIN_VALUE;
@@ -245,7 +261,7 @@ fn computeSoftMax(q_idx: u32, sg_id:u32, enabled:bool)
   if (d == 0)
   {
     // Avoid division by zero by setting d to a really small value.
-    // Removing this protection has had no negative effect on any
+    // Note: Removing this protection has had no negative effect on any
     // of the prompts tried so far. This is a safety net.
d = precision_t(0.0000001h); } From ce2031e4b6435f743f6e5f41c97aacb1bcc32cd0 Mon Sep 17 00:00:00 2001 From: Sushanth Rajasankar Date: Mon, 25 Nov 2024 18:10:32 -0800 Subject: [PATCH 23/25] Fix lint errors --- .../webgpu/bert/flash_attention.cc | 80 +++++++++---------- .../contrib_ops/webgpu/bert/flash_attention.h | 14 ++-- .../webgpu/bert/multihead_attention.cc | 8 +- 3 files changed, 51 insertions(+), 51 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc index 5d4b29f46f4c1..fb8338bf87407 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc @@ -36,8 +36,8 @@ Status CopyKVCacheProgram::GenerateShaderCode(ShaderHelper& shader) const { shader.AddOutput("present_value", ShaderUsage::UseUniform); shader.MainFunctionBody() << "let headIdx = workgroup_id.z;\n" - << "let kIdx = workgroup_id.x;\n" - << "let presentKeyOffset = headIdx * num_workgroups.x * uniforms.vectorized_head_size + (kIdx)*uniforms.vectorized_head_size;\n"; + << "let kIdx = workgroup_id.x;\n" + << "let presentKeyOffset = headIdx * num_workgroups.x * uniforms.vectorized_head_size + (kIdx)*uniforms.vectorized_head_size;\n"; if (has_past_) { shader.MainFunctionBody() << "if (kIdx < uniforms.past_sequence_length) {\n" << " let pastKeyOffset = headIdx * uniforms.past_sequence_length * uniforms.vectorized_head_size + (kIdx)*uniforms.vectorized_head_size;\n" @@ -51,23 +51,23 @@ Status CopyKVCacheProgram::GenerateShaderCode(ShaderHelper& shader) const { shader.MainFunctionBody() << "if (kIdx >= uniforms.past_sequence_length) {\n"; } shader.MainFunctionBody() << " let nkIdx = kIdx - uniforms.past_sequence_length;\n" - << " // Assumes kv have BSNH layout. num_workgroups.z is the num_head as per the dispatch requirement.\n" - << " let nOffset = nkIdx * uniforms.vectorized_head_size * num_workgroups.z + headIdx*uniforms.vectorized_head_size;\n" - << " // Assumes kv have BNSH layout.\n" - << " // let nOffset = headIdx * uniforms.kv_sequence_length * uniforms.vectorized_head_size + nkIdx * uniforms.vectorized_head_size;\n" - << " for (var w: u32 = 0u; w < uniforms.vectorized_head_size; w ++) {\n" - << " present_key[presentKeyOffset+w] = key[nOffset+w];\n" - << " present_value[presentKeyOffset+w] = value[nOffset+w];\n" - << " }\n" - << "}\n"; + << " // Assumes kv have BSNH layout. 
num_workgroups.z is the num_head as per the dispatch requirement.\n"
+                            << " let nOffset = nkIdx * uniforms.vectorized_head_size * num_workgroups.z + headIdx*uniforms.vectorized_head_size;\n"
+                            << " // Assumes kv have BNSH layout.\n"
+                            << " // let nOffset = headIdx * uniforms.kv_sequence_length * uniforms.vectorized_head_size + nkIdx * uniforms.vectorized_head_size;\n"
+                            << " for (var w: u32 = 0u; w < uniforms.vectorized_head_size; w ++) {\n"
+                            << " present_key[presentKeyOffset+w] = key[nOffset+w];\n"
+                            << " present_value[presentKeyOffset+w] = value[nOffset+w];\n"
+                            << " }\n"
+                            << "}\n";
 
- return Status::OK();
+  return Status::OK();
 }
 
 Status CopyKVCache(onnxruntime::webgpu::ComputeContext& context, AttentionParameters& parameters,
-                  const Tensor* K, const Tensor* past_key, Tensor* present_key,
-                  const Tensor* V, const Tensor* past_value, Tensor* present_value,
-                  int past_sequence_length, int total_sequence_length) {
+                   const Tensor* K, const Tensor* past_key, Tensor* present_key,
+                   const Tensor* V, const Tensor* past_value, Tensor* present_value,
+                   int past_sequence_length, int total_sequence_length) {
   // CopyKVCache takes past key/value and current key/value and copies them to present key and value.
   // This makes it so that FlashAttention only needs to look at present key and value, and saves on the
   // number of input buffers in the shader, which we run out of (<=8) without this optimization.
@@ -92,7 +92,7 @@ Status CopyKVCache(onnxruntime::webgpu::ComputeContext& context, AttentionParame
                  .CacheHint(std::to_string(components) + std::to_string(has_past))
                  .AddUniformVariables({{static_cast<uint32_t>(past_sequence_length)},
                                        {static_cast<uint32_t>(parameters.kv_sequence_length)},
-                                       {static_cast<uint32_t>(parameters.head_size/ components)}});
+                                       {static_cast<uint32_t>(parameters.head_size / components)}});
 
   return context.RunProgram(program);
 }
@@ -140,14 +140,14 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const {
   // "The SLM is a 128KB High Bandwidth Memory (HBM) accessible from the EUs in the subslice"
   // GPU, after which workgroups will be unscheduled to make space for memory.
   shader.AdditionalImplementation() << ""
-      << "var<workgroup> q_tile : array<array<q_value_t, QKV_HEAD_VECTORIZED_SIZE>, TILE_SIZE>; // 96 * 2 * 16 = 3KB.\n"
-      << "var<workgroup> k_tile : array<array<q_value_t, QKV_HEAD_VECTORIZED_SIZE>, TILE_SIZE>; // 96 * 2 * 16 = 3KB.\n"
-      << "var<workgroup> v_tile : array<array<q_value_t, QKV_HEAD_VECTORIZED_SIZE>, TILE_SIZE>; // 96 * 2 * 16 = 3KB.\n"
-      << "var<workgroup> o_tile : array<array<q_value_t, QKV_HEAD_VECTORIZED_SIZE>, TILE_SIZE>; // 96 * 2 * 16 = 3KB.\n"
-      << "var<workgroup> qk_tile : array<array<precision_t, TILE_SIZE>, TILE_SIZE>; // 16 * 2 * 16 = 512\n"
-      << "var<workgroup> max_tile : array<precision_t, TILE_SIZE>; // 2 * 16 = 32\n"
-      << "var<workgroup> denom_tile : array<precision_t, TILE_SIZE>; // 2 * 16 = 32\n"
-      << "var<workgroup> o_ratio : array<precision_t, TILE_SIZE>; // 2 * 16 = 32\n";
+                                    << "var<workgroup> q_tile : array<array<q_value_t, QKV_HEAD_VECTORIZED_SIZE>, TILE_SIZE>; // 96 * 2 * 16 = 3KB.\n"
+                                    << "var<workgroup> k_tile : array<array<q_value_t, QKV_HEAD_VECTORIZED_SIZE>, TILE_SIZE>; // 96 * 2 * 16 = 3KB.\n"
+                                    << "var<workgroup> v_tile : array<array<q_value_t, QKV_HEAD_VECTORIZED_SIZE>, TILE_SIZE>; // 96 * 2 * 16 = 3KB.\n"
+                                    << "var<workgroup> o_tile : array<array<q_value_t, QKV_HEAD_VECTORIZED_SIZE>, TILE_SIZE>; // 96 * 2 * 16 = 3KB.\n"
+                                    << "var<workgroup> qk_tile : array<array<precision_t, TILE_SIZE>, TILE_SIZE>; // 16 * 2 * 16 = 512\n"
+                                    << "var<workgroup> max_tile : array<precision_t, TILE_SIZE>; // 2 * 16 = 32\n"
+                                    << "var<workgroup> denom_tile : array<precision_t, TILE_SIZE>; // 2 * 16 = 32\n"
+                                    << "var<workgroup> o_ratio : array<precision_t, TILE_SIZE>; // 2 * 16 = 32\n";
 
   shader.AdditionalImplementation() << R"HELPER_FN(
 fn loadq(slot: u32, q_idx_global : u32, head_idx: u32, sg_id : u32, sg_size : u32)
@@ -296,11 +296,11 @@ fn computeO(q_idx: u32, sg_id:u32, enabled:bool)
 }
 )HELPER_FN";
 
-// Shader is designed to be dispatched as Dispatch(num_heads, new_sequence_length / TILE_SIZE, 1)
-// Each workgroup is responsible for a range of q values (TILE_SIZE) and visits all Ks for those q's.
-// Each workgroup has TILE_SIZE waves, with each wave having subgroup size number of lanes (threads). -// Synchronization between lanes in a wave is free, with various subgroup* functions, and this shader -// uses that. Synchronization beween waves requires calling workgroupBarrier. + // Shader is designed to be dispatched as Dispatch(num_heads, new_sequence_length / TILE_SIZE, 1) + // Each workgroup is responsible for a range of q values (TILE_SIZE) and visits all Ks for those q's. + // Each workgroup has TILE_SIZE waves, with each wave having subgroup size number of lanes (threads). + // Synchronization between lanes in a wave is free, with various subgroup* functions, and this shader + // uses that. Synchronization beween waves requires calling workgroupBarrier. shader.MainFunctionBody() << R"MAIN_FN( let head_idx = workgroup_id.x; // It is always the case that 0 <= wave_id < TILE_SIZE @@ -369,8 +369,8 @@ if (q_idx_global_using_wave_valid) } Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias, - Tensor* output, const Tensor* past_key, Tensor* present_key, const Tensor* past_value, Tensor* present_value, - AttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context) { + Tensor* output, const Tensor* past_key, Tensor* present_key, const Tensor* past_value, Tensor* present_value, + AttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context) { ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, parameters.past_sequence_length, parameters.total_sequence_length)); const uint32_t subgroup_size = context.MinSubgroupSize(); @@ -385,16 +385,16 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co const float alpha = parameters.scale == 0.0f ? 
                         1.f / sqrt(static_cast<float>(parameters.head_size)) : parameters.scale;
   std::string cache_hint = std::to_string(has_attention_bias) +
-                          std::to_string(subgroup_size) +
-                          std::to_string(tile_size) +
-                          std::to_string(parameters.head_size) +
-                          std::to_string(parameters.num_heads);
+                           std::to_string(subgroup_size) +
+                           std::to_string(tile_size) +
+                           std::to_string(parameters.head_size) +
+                           std::to_string(parameters.num_heads);
   program.SetDispatchGroupSize(parameters.num_heads, (parameters.sequence_length + tile_size - 1) / tile_size, 1)
-        .SetWorkgroupSize(subgroup_size*subgroup_size)
-        .CacheHint(cache_hint)
-        .AddUniformVariables({{static_cast<uint32_t>(parameters.sequence_length)},
-                              {static_cast<uint32_t>(parameters.total_sequence_length)},
-                              {alpha}});
+      .SetWorkgroupSize(subgroup_size * subgroup_size)
+      .CacheHint(cache_hint)
+      .AddUniformVariables({{static_cast<uint32_t>(parameters.sequence_length)},
+                            {static_cast<uint32_t>(parameters.total_sequence_length)},
+                            {alpha}});
 
   return context.RunProgram(program);
 }
diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h
index d9188496821a4..7c546a58a6ad3 100644
--- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h
+++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h
@@ -40,11 +40,11 @@ class FlashAttentionProgram final : public Program<FlashAttentionProgram> {
                         int qkv_head_size,
                         int qkv_num_heads)
       : Program{kernel_name},
-       has_attention_bias_(has_attention_bias),
-       subgroup_size_(subgroup_size),
-       tile_size_(tile_size),
-       qkv_head_size_(qkv_head_size),
-       qkv_num_heads_(qkv_num_heads) {
+        has_attention_bias_(has_attention_bias),
+        subgroup_size_(subgroup_size),
+        tile_size_(tile_size),
+        qkv_head_size_(qkv_head_size),
+        qkv_num_heads_(qkv_num_heads) {
   }
 
   Status GenerateShaderCode(ShaderHelper& sh) const override;
@@ -62,8 +62,8 @@ class FlashAttentionProgram final : public Program<FlashAttentionProgram> {
 };
 
 Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias,
-                          Tensor* output, const Tensor* past_key, Tensor* present_key, const Tensor* past_value, Tensor* present_value,
-                          AttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context);
+                           Tensor* output, const Tensor* past_key, Tensor* present_key, const Tensor* past_value, Tensor* present_value,
+                           AttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context);
 }  // namespace webgpu
 }  // namespace contrib
 }  // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc
index 8d979d5adb83c..5cf295cdd2023 100644
--- a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc
+++ b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc
@@ -459,11 +459,11 @@ Status MultiHeadAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext&
   Tensor* present_value = context.Output(2, present_shape);
 
   if (parameters.batch_size == 1 &&
-     bias == nullptr &&
-     present_key != nullptr && present_value != nullptr && present_key->SizeInBytes() > 0 &&
-     present_value->SizeInBytes() > 0 && parameters.head_size % 4 == 0) {
+      bias == nullptr &&
+      present_key != nullptr && present_value != nullptr && present_key->SizeInBytes() > 0 &&
+      present_value->SizeInBytes() > 0 && parameters.head_size % 4 == 0) {
     return ApplyFlashAttention(query, key, value, attention_bias, output, past_key, present_key, past_value,
-                              present_value, parameters, context);
+                               present_value, parameters, context);
   }
 
   TensorShapeVector q_new_dims({parameters.batch_size,
                                parameters.num_heads,

From bf1b14658ebca0f8d0a58ed2cda774f8cbca9556 Mon Sep 17 00:00:00 2001
From: Sushanth Rajasankar
Date: Mon, 25 Nov 2024 19:34:53 -0800
Subject: [PATCH 24/25] Add option to turn on flash attention via config

---
 .../webgpu/bert/flash_attention.cc            | 26 ++++++++++++++++---
 .../contrib_ops/webgpu/bert/flash_attention.h |  5 +++-
 .../webgpu/bert/multihead_attention.cc        |  5 +---
 .../core/providers/webgpu/compute_context.h   |  3 +++
 .../core/providers/webgpu/webgpu_context.cc   |  1 +
 .../core/providers/webgpu/webgpu_context.h    |  2 ++
 .../webgpu/webgpu_execution_provider.h        |  4 ++-
 .../webgpu/webgpu_provider_factory.cc         | 14 ++++++++++
 .../webgpu/webgpu_provider_options.h          |  4 +++
 9 files changed, 55 insertions(+), 9 deletions(-)

diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc
index fb8338bf87407..4375b51d73b1e 100644
--- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc
+++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc
@@ -64,7 +64,7 @@ Status CopyKVCacheProgram::GenerateShaderCode(ShaderHelper& shader) const {
   return Status::OK();
 }
 
-Status CopyKVCache(onnxruntime::webgpu::ComputeContext& context, AttentionParameters& parameters,
+Status CopyKVCache(onnxruntime::webgpu::ComputeContext& context, const AttentionParameters& parameters,
                    const Tensor* K, const Tensor* past_key, Tensor* present_key,
                    const Tensor* V, const Tensor* past_value, Tensor* present_value,
                    int past_sequence_length, int total_sequence_length) {
@@ -300,7 +300,7 @@ fn computeO(q_idx: u32, sg_id:u32, enabled:bool)
   // Each workgroup is responsible for a range of q values (TILE_SIZE) and visits all Ks for those q's.
   // Each workgroup has TILE_SIZE waves, with each wave having subgroup size number of lanes (threads).
   // Synchronization between lanes in a wave is free, with various subgroup* functions, and this shader
-  // uses that. Synchronization beween waves requires calling workgroupBarrier.
+  // uses that. Synchronization between waves requires calling workgroupBarrier.
   shader.MainFunctionBody() << R"MAIN_FN(
 let head_idx = workgroup_id.x;
 // It is always the case that 0 <= wave_id < TILE_SIZE
@@ -369,8 +369,8 @@ if (q_idx_global_using_wave_valid)
 }
 
 Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias,
                            Tensor* output, const Tensor* past_key, Tensor* present_key, const Tensor* past_value, Tensor* present_value,
-                           AttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context) {
+                           const AttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context) {
   ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value,
                                   parameters.past_sequence_length, parameters.total_sequence_length));
   const uint32_t subgroup_size = context.MinSubgroupSize();
@@ -399,6 +399,26 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co
   return context.RunProgram(program);
 }
 
+bool CanApplyFlashAttention(const Tensor* bias, const Tensor* present_key, const Tensor* present_value,
+                            const AttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context) {
+  // The min subgroup size affects the block size while going through the sequence length.
+  // 16 is the smallest size tested; smaller sizes would impact performance.
+  // Checking for this also ensures that we don't run flash attention where subgroups are not supported.
+  constexpr int kMinSupportedSubgroupSize = 16;
+  // Workgroup size is set to be (subgroup_size * subgroup_size); check that it is allowed.
+  // Flash attention is written to support only a batch_size of 1; the algorithm can be extended to
+  // support batch_size > 1. The purpose of the bias input is unclear, so it is not implemented in the shader.
+  // The Flash attention implementation is vectorized; to keep things simple, only vec4 is implemented -
+  // this implies that head_size has to be a multiple of 4.
+  return context.IsFlashAttentionEnabled() &&
+         context.MinSubgroupSize() >= kMinSupportedSubgroupSize &&
+         context.DeviceLimits().maxComputeWorkgroupSizeX >= (context.MinSubgroupSize() * context.MinSubgroupSize()) &&
+         parameters.batch_size == 1 &&
+         bias == nullptr &&
+         present_key != nullptr && present_value != nullptr && present_key->SizeInBytes() > 0 &&
+         present_value->SizeInBytes() > 0 && parameters.head_size % 4 == 0;
+}
+
 }  // namespace webgpu
 }  // namespace contrib
 }  // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h
index 7c546a58a6ad3..04d1fdbc75572 100644
--- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h
+++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h
@@ -63,7 +63,10 @@ class FlashAttentionProgram final : public Program<FlashAttentionProgram> {
 
 Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias,
                            Tensor* output, const Tensor* past_key, Tensor* present_key, const Tensor* past_value, Tensor* present_value,
-                           AttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context);
+                           const AttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context);
+
+bool CanApplyFlashAttention(const Tensor* bias, const Tensor* present_key, const Tensor* present_value,
+                            const AttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context);
 }  // namespace webgpu
 }  // namespace contrib
 }  // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc
index 5cf295cdd2023..ddcf54bfdab32 100644
--- a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc
+++ b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc
@@ -458,10 +458,7 @@ Status MultiHeadAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext&
   Tensor* present_key = context.Output(1, present_shape);
   Tensor* present_value = context.Output(2, present_shape);
 
-  if (parameters.batch_size == 1 &&
-      bias == nullptr &&
-      present_key != nullptr && present_value != nullptr && present_key->SizeInBytes() > 0 &&
-      present_value->SizeInBytes() > 0 && parameters.head_size % 4 == 0) {
+  if (CanApplyFlashAttention(bias, present_key, present_value, parameters, context)) {
     return ApplyFlashAttention(query, key, value, attention_bias, output, past_key, present_key, past_value,
                                present_value, parameters, context);
   }
diff --git a/onnxruntime/core/providers/webgpu/compute_context.h b/onnxruntime/core/providers/webgpu/compute_context.h
index 22cdf389bdf7c..420b5653c953b 100644
--- a/onnxruntime/core/providers/webgpu/compute_context.h
+++ b/onnxruntime/core/providers/webgpu/compute_context.h
@@ -44,6 +44,9 @@ class ComputeContext {
   inline const uint32_t MinSubgroupSize() const {
     return webgpu_context_.MinSubgroupSize();
   }
+  inline const bool IsFlashAttentionEnabled() const {
+    return webgpu_context_.IsFlashAttentionEnabled();
+  }
 
   //
   // Get the kernel context.
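A side note on the CanApplyFlashAttention gate added above: the subgroup-size and workgroup-limit checks interact in a way that is easy to verify in isolation. A minimal standalone C++ sketch (not part of the patch); the 16 and 256 values come from the comments in this series and WebGPU's default maxComputeWorkgroupSizeX limit:

    #include <cassert>
    #include <cstdint>

    // The shader runs TILE_SIZE waves of SUBGROUP_SIZE lanes each, with
    // TILE_SIZE == SUBGROUP_SIZE, so the requested workgroup size is
    // subgroup_size * subgroup_size (see SetWorkgroupSize in ApplyFlashAttention).
    uint32_t FlashAttentionWorkgroupSize(uint32_t subgroup_size) {
      return subgroup_size * subgroup_size;
    }

    int main() {
      constexpr uint32_t kMinSupportedSubgroupSize = 16;  // smallest size tested
      uint32_t subgroup_size = 16;                        // e.g. Intel, per patch 20
      uint32_t max_workgroup_size_x = 256;                // default WebGPU limit
      assert(subgroup_size >= kMinSupportedSubgroupSize);
      // 16 * 16 == 256 sits exactly at the limit, so larger subgroups require a
      // larger maxComputeWorkgroupSizeX, which is what the gate checks.
      assert(FlashAttentionWorkgroupSize(subgroup_size) <= max_workgroup_size_x);
      return 0;
    }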
diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.cc b/onnxruntime/core/providers/webgpu/webgpu_context.cc
index ef0d4a987e483..070cb4e7007ea 100644
--- a/onnxruntime/core/providers/webgpu/webgpu_context.cc
+++ b/onnxruntime/core/providers/webgpu/webgpu_context.cc
@@ -138,6 +138,7 @@ void WebGpuContext::Initialize(const WebGpuExecutionProviderInfo& webgpu_ep_info
     } else {
       query_type_ = TimestampQueryType::None;
     }
+    is_flash_attention_enabled_ = webgpu_ep_info.enable_flash_attention;
   });
 }
 
diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.h b/onnxruntime/core/providers/webgpu/webgpu_context.h
index 0b8b526b2c381..7964246b8ace9 100644
--- a/onnxruntime/core/providers/webgpu/webgpu_context.h
+++ b/onnxruntime/core/providers/webgpu/webgpu_context.h
@@ -56,6 +56,7 @@ class WebGpuContext final {
   const wgpu::AdapterInfo& AdapterInfo() const { return adapter_info_; }
   const wgpu::Limits& DeviceLimits() const { return device_limits_; }
   uint32_t MinSubgroupSize() const { return min_subgroup_size_; }
+  const bool IsFlashAttentionEnabled() const { return is_flash_attention_enabled_; }
 
   const wgpu::CommandEncoder& GetCommandEncoder() {
     if (!current_command_encoder_) {
@@ -185,6 +186,7 @@ class WebGpuContext final {
   uint64_t gpu_timestamp_offset_ = 0;
   bool is_profiling_ = false;
+  bool is_flash_attention_enabled_ = false;
 };
 
 }  // namespace webgpu
diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h
index 336395a1dd0dd..2366c2d69845b 100644
--- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h
+++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h
@@ -23,9 +23,10 @@ class WebGpuProfiler;
 }  // namespace webgpu
 
 struct WebGpuExecutionProviderInfo {
-  WebGpuExecutionProviderInfo(DataLayout data_layout, bool enable_graph_capture)
+  WebGpuExecutionProviderInfo(DataLayout data_layout, bool enable_graph_capture, bool enable_flash_attention)
       : data_layout{data_layout},
         enable_graph_capture{enable_graph_capture},
+        enable_flash_attention{enable_flash_attention},
         storage_buffer_cache_mode{},
         uniform_buffer_cache_mode{},
         query_resolve_buffer_cache_mode{},
@@ -36,6 +37,7 @@ struct WebGpuExecutionProviderInfo {
 
   DataLayout data_layout;
   bool enable_graph_capture;
+  bool enable_flash_attention;
   webgpu::BufferCacheMode storage_buffer_cache_mode;
   webgpu::BufferCacheMode uniform_buffer_cache_mode;
   webgpu::BufferCacheMode query_resolve_buffer_cache_mode;
diff --git a/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc b/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc
index 803c12274c08f..50039233676b5 100644
--- a/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc
+++ b/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc
@@ -40,6 +40,8 @@ std::shared_ptr<IExecutionProviderFactory> WebGpuProviderFactoryCreator::Create(
       DataLayout::NHWC,
       // graph capture feature is disabled by default
       false,
+      // flash attention feature is disabled by default
+      false,
   };
 
   std::string preferred_layout_str;
@@ -67,6 +69,18 @@ std::shared_ptr<IExecutionProviderFactory> WebGpuProviderFactoryCreator::Create(
   }
   LOGS_DEFAULT(VERBOSE) << "WebGPU EP graph capture enable: " << webgpu_ep_info.enable_graph_capture;
 
+  std::string enable_flash_attention_str;
+  if (config_options.TryGetConfigEntry(kEnableFlashAttention, enable_flash_attention_str)) {
+    if (enable_flash_attention_str == kEnableFlashAttention_ON) {
+      webgpu_ep_info.enable_flash_attention = true;
+    } else if (enable_flash_attention_str == kEnableFlashAttention_OFF) {
+      webgpu_ep_info.enable_flash_attention = false;
+    } else {
+      ORT_THROW("Invalid enable flash attention: ", enable_flash_attention_str);
+    }
+  }
+  LOGS_DEFAULT(VERBOSE) << "WebGPU EP flash attention enabled: " << webgpu_ep_info.enable_flash_attention;
+
   auto parse_buffer_cache_mode = [&config_options](const std::string& config_entry_str,
                                                    webgpu::BufferCacheMode default_value) -> webgpu::BufferCacheMode {
     std::string buffer_cache_mode_str;
diff --git a/onnxruntime/core/providers/webgpu/webgpu_provider_options.h b/onnxruntime/core/providers/webgpu/webgpu_provider_options.h
index 63befedffea84..fb81b16e4fdd4 100644
--- a/onnxruntime/core/providers/webgpu/webgpu_provider_options.h
+++ b/onnxruntime/core/providers/webgpu/webgpu_provider_options.h
@@ -11,6 +11,7 @@ namespace options {
 
 constexpr const char* kPreferredLayout = "WebGPU:preferredLayout";
 constexpr const char* kEnableGraphCapture = "WebGPU:enableGraphCapture";
+constexpr const char* kEnableFlashAttention = "WebGPU:enableFlashAttention";
 
 constexpr const char* kDawnProcTable = "WebGPU:dawnProcTable";
 
@@ -36,6 +37,9 @@ constexpr const char* kPreferredLayout_NHWC = "NHWC";
 constexpr const char* kEnableGraphCapture_ON = "1";
 constexpr const char* kEnableGraphCapture_OFF = "0";
 
+constexpr const char* kEnableFlashAttention_ON = "1";
+constexpr const char* kEnableFlashAttention_OFF = "0";
+
 constexpr const char* kBufferCacheMode_Disabled = "disabled";
 constexpr const char* kBufferCacheMode_LazyRelease = "lazyRelease";
 constexpr const char* kBufferCacheMode_Simple = "simple";

From d1e442e6f77db692064f6bd7cf25446dbeceb9d3 Mon Sep 17 00:00:00 2001
From: Sushanth Rajasankar
Date: Tue, 26 Nov 2024 16:57:35 -0800
Subject: [PATCH 25/25] Fix min value

---
 onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc
index 4375b51d73b1e..40c49a0d8ca39 100644
--- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc
+++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc
@@ -134,7 +134,7 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const {
                                     << "const QKV_HEAD_VECTORIZED_SIZE: u32 = QKV_HEAD_SIZE / VECTOR_SIZE;\n"
                                     << "const NUM_HEADS: u32 = " << qkv_num_heads_ << ";\n"
                                     << "alias precision_t = q_element_t;\n"
-                                    << "const MIN_VALUE : precision_t = precision_t(-6504.0h);\n";
+                                    << "const MIN_VALUE : precision_t = precision_t(-65504.0h);\n";
 
   // Best to keep SHM usage per workgroup < 128KB, from intel docs for Intel Iris Xe GPU.
   // "The SLM is a 128KB High Bandwidth Memory (HBM) accessible from the EUs in the subslice"
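A closing note on the recurrence that patch 22 documents above computeSoftMax: the same update, typeset in LaTeX using the comment's own notation (b is the tile size, X_i holds the Q·K^T scores for tile i):

    \begin{aligned}
    X_i &= Q[k,:]\; K^{\mathsf{T}}[:,\,(i-1)b : ib] \\
    m_i^{\mathrm{local}} &= \max_{j=1..b} X_i[j], \qquad
    M_i = \max\!\left(M_{i-1},\, m_i^{\mathrm{local}}\right) \\
    d'_i &= d'_{i-1}\, e^{M_{i-1}-M_i} + \sum_{j=1}^{b} e^{X_i[j]-M_i} \\
    o'_i &= o'_{i-1}\,\frac{d'_{i-1}\, e^{M_{i-1}-M_i}}{d'_i}
            + \sum_{j=1}^{b} \frac{e^{X_i[j]-M_i}}{d'_i}\; V[(i-1)b+j,\,:]
    \end{aligned}

After the final tile, d'_i is the full softmax denominator and o'_i the finished attention output row, which is why the shader can merge per-tile results using only max_tile, denom_tile, and o_ratio in shared memory.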
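Finally, a hypothetical application-side sketch of turning the new option on. Only the "WebGPU:enableFlashAttention" key and its "1"/"0" values come from webgpu_provider_options.h above; the header name, the EP registration call, and the model path are illustrative assumptions, not verified against this branch:

    #include <onnxruntime_cxx_api.h>

    int main() {
      Ort::Env env;
      Ort::SessionOptions so;
      // Matches kEnableFlashAttention / kEnableFlashAttention_ON from the patch;
      // session config entries are what WebGpuProviderFactoryCreator::Create reads.
      so.AddConfigEntry("WebGPU:enableFlashAttention", "1");
      // Assumed registration path for the WebGPU EP; adjust for your build.
      so.AppendExecutionProvider("WebGPU");
      Ort::Session session(env, ORT_TSTR("model.onnx"), so);
      return 0;
    }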