Commit
Fix AttentionBias loading
sushraja-msft committed Nov 25, 2024
1 parent a814770 commit ab11009
Showing 1 changed file with 14 additions and 13 deletions.
27 changes: 14 additions & 13 deletions onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc
@@ -136,17 +136,18 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const {
<< "alias precision_t = q_element_t;\n"
<< "const MIN_VALUE : precision_t = precision_t(-6504.0h);\n";

- // Best to keep SHM usage per workgroup < 8KB. 4KB is the limit on a 48EU tigerlake
+ // Best to keep SHM usage per workgroup < 128KB, from intel docs for Intel Iris Xe GPU.
+ // "The SLM is a 128KB High Bandwidth Memory (HBM) accessible from the EUs in the subslice"
// GPU afterwhich workgroups will be unscheduled to make space for memory.
shader.AdditionalImplementation() << ""
<< "var<workgroup> q_tile : array<array<q_value_t, QKV_HEAD_VECTORIZED_SIZE>, TILE_SIZE>; // 96 * 8 * 2 = 1.5KB.\n"
<< "var<workgroup> k_tile : array<array<q_value_t, QKV_HEAD_VECTORIZED_SIZE>, TILE_SIZE>; // 96 * 8 * 2 = 1.5KB.\n"
<< "var<workgroup> v_tile : array<array<q_value_t, QKV_HEAD_VECTORIZED_SIZE>, TILE_SIZE>; // 96 * 8 * 2 = 1.5KB.\n"
<< "var<workgroup> o_tile : array<array<q_value_t, QKV_HEAD_VECTORIZED_SIZE>, TILE_SIZE>; // 96 * 8 * 2 = 1.5KB.\n"
<< "var<workgroup> qk_tile : array<array<precision_t, TILE_SIZE>, TILE_SIZE>; // 8 * 2 * 8 = 128\n"
<< "var<workgroup> max_tile : array<precision_t, TILE_SIZE>; // 2 * 8 = 16\n"
<< "var<workgroup> denom_tile : array<precision_t, TILE_SIZE>; // 2 * 8 = 16\n"
<< "var<workgroup> o_ratio : array<precision_t, TILE_SIZE>; // 2 * 8 = 16\n";
<< "var<workgroup> q_tile : array<array<q_value_t, QKV_HEAD_VECTORIZED_SIZE>, TILE_SIZE>; // 96 * 2 * 16 = 3KB.\n"
<< "var<workgroup> k_tile : array<array<q_value_t, QKV_HEAD_VECTORIZED_SIZE>, TILE_SIZE>; // 96 * 2 * 16 = 3KB.\n"
<< "var<workgroup> v_tile : array<array<q_value_t, QKV_HEAD_VECTORIZED_SIZE>, TILE_SIZE>; // 96 * 2 * 16 = 3KB.\n"
<< "var<workgroup> o_tile : array<array<q_value_t, QKV_HEAD_VECTORIZED_SIZE>, TILE_SIZE>; // 96 * 2 * 16 = 3KB.\n"
<< "var<workgroup> qk_tile : array<array<precision_t, TILE_SIZE>, TILE_SIZE>; // 16 * 2 * 16 = 512\n"
<< "var<workgroup> max_tile : array<precision_t, TILE_SIZE>; // 2 * 16 = 32\n"
<< "var<workgroup> denom_tile : array<precision_t, TILE_SIZE>; // 2 * 16 = 32\n"
<< "var<workgroup> o_ratio : array<precision_t, TILE_SIZE>; // 2 * 16 = 32\n";

shader.AdditionalImplementation() << R"HELPER_FN(
fn loadq(slot: u32, q_idx_global : u32, head_idx: u32, sg_id : u32, sg_size : u32)
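
Taking the updated size comments above at face value (TILE_SIZE = 16 and 96 two-byte f16 scalars per head row are read off those comments, not from the kernel source), the shared memory per workgroup comes to roughly 12.6 KB: four 3 KB q/k/v/o tiles, a 512 B qk tile, and three 32 B per-row arrays. That stays under WebGPU's default maxComputeWorkgroupStorageSize of 16 KB, and far under the 128 KB Xe SLM figure quoted in the comment. A minimal WGSL tally of the same arithmetic; every constant name below is invented for illustration:

// Illustration-only constants; names are made up here, values are read off the
// size comments above (TILE_SIZE = 16, 96 two-byte f16 scalars per head row).
const SKETCH_TILE_SIZE : u32 = 16u;
const SKETCH_ROW_HALFS : u32 = 96u;
const QKV_TILE_BYTES   : u32 = SKETCH_ROW_HALFS * 2u * SKETCH_TILE_SIZE;  // 3072 B per q/k/v/o tile
const QK_TILE_BYTES    : u32 = SKETCH_TILE_SIZE * 2u * SKETCH_TILE_SIZE;  // 512 B
const ROW_STATS_BYTES  : u32 = 3u * SKETCH_TILE_SIZE * 2u;                // max, denom, o_ratio: 96 B
const TOTAL_SHM_BYTES  : u32 = 4u * QKV_TILE_BYTES + QK_TILE_BYTES + ROW_STATS_BYTES;  // 12896 B, ~12.6 KB
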
@@ -313,11 +314,11 @@ for(var k_start = 0u; k_start < uniforms.present_sequence_length; k_start+=TILE_
// K/V values from k_start+wave_id.
loadk(wave_id, k_idx_global, head_idx, sg_id, sg_size);
loadv(wave_id, k_idx_global, head_idx, sg_id, sg_size);
- // Next, we want for every q row (wave_id) to populate bias for new sequence length
- // (k_start+sg_id). loadAttentionBias handles range checking q_idx_global,
- // and sg_id, (k_start+sg_id).
- loadAttentionBias(wave_id, q_idx_global, sg_id, k_start+sg_id, head_idx);
}
+ // Next, we want for every q row (wave_id) to populate bias for new sequence length
+ // (k_start+sg_id). loadAttentionBias handles range checking q_idx_global,
+ // and sg_id, (k_start+sg_id).
+ loadAttentionBias(wave_id, q_idx_global, sg_id, k_start+sg_id, head_idx);
// Insert barrier before workgroup starts reading the shared memory.
workgroupBarrier();
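
The functional change in this hunk is the move of loadAttentionBias from inside the guarded block to just after it: the bias tile is now populated on every k_start iteration, relying on loadAttentionBias's own range checks on q_idx_global and k_start + sg_id rather than on the enclosing block's guard. A minimal WGSL sketch of the resulting loop shape; the guard condition, the derivation of k_idx_global, and the work after the barrier are placeholders, not the actual kernel:

// Sketch of the corrected loop shape only; wave_id, sg_id, head_idx and the load
// helpers are the kernel's, but the guard condition and the post-barrier body
// are placeholders for illustration.
for (var k_start = 0u; k_start < uniforms.present_sequence_length; k_start += TILE_SIZE) {
  let k_idx_global = k_start + wave_id;                    // per the "K/V values from k_start+wave_id" comment
  if (k_idx_global < uniforms.present_sequence_length) {   // assumed guard
    // K and V rows for this wave stay behind the bound check.
    loadk(wave_id, k_idx_global, head_idx, sg_id, sg_size);
    loadv(wave_id, k_idx_global, head_idx, sg_id, sg_size);
  }
  // Unconditional after the fix: loadAttentionBias range-checks q_idx_global and
  // k_start + sg_id itself, so the bias tile is written on every iteration rather
  // than only when the guard above is taken.
  loadAttentionBias(wave_id, q_idx_global, sg_id, k_start + sg_id, head_idx);
  workgroupBarrier();  // make the freshly loaded tiles visible to the whole workgroup
  // ... q.k scores, running max/denominator update, output accumulation ...
}
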