/*
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "decoder_masked_multihead_attention_template.hpp"
#include "src/fastertransformer/kernels/decoder_masked_multihead_attention.h"
#include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h"
#include "src/fastertransformer/utils/cuda_bf16_wrapper.h"
#include <assert.h>
#include <float.h>
#include <type_traits>

////////////////////////////////////////////////////////////////////////////////////////////////////

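// Helper macro: computes the dynamic shared-memory size, builds a (num_heads, batch_size) launch
// grid, and launches the masked MHA kernel with the given tiling parameters. Note that it declares
// smem_sz and grid directly in the enclosing scope, which is why each call site below sits inside
// its own braces.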
#define MMHA_LAUNCH_KERNEL(                                                                        \
    T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, DO_CROSS_ATTENTION, HAS_BEAMS, stream) \
    size_t smem_sz = mmha::smem_size_in_bytes<T, DO_CROSS_ATTENTION>(params, THDS_PER_VALUE, THDS_PER_BLOCK); \
    dim3 grid(params.num_heads, params.batch_size);                                                \
    mmha::masked_multihead_attention_kernel<T,                                                     \
                                            Dh,                                                    \
                                            Dh_MAX,                                                \
                                            THDS_PER_KEY,                                          \
                                            THDS_PER_VALUE,                                        \
                                            THDS_PER_BLOCK,                                        \
                                            DO_CROSS_ATTENTION,                                    \
                                            HAS_BEAMS><<<grid, THDS_PER_BLOCK, smem_sz, stream>>>(params)

////////////////////////////////////////////////////////////////////////////////////////////////////

// Specialize the launcher for cross attention: whether the kernel runs as self or cross attention
// is decided at compile time from KERNEL_PARAMS_TYPE.
template<typename T, int Dh, int Dh_MAX, typename KERNEL_PARAMS_TYPE>
void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream)
{
    constexpr int THREADS_PER_VALUE = threads_per_value_t<T, Dh_MAX>::value;
    constexpr bool DO_CROSS_ATTENTION = std::is_same<KERNEL_PARAMS_TYPE, Cross_multihead_attention_params<T>>::value;
    // Cross attention attends over the whole encoder memory; self attention only over the steps
    // decoded so far.
    int tlength = (DO_CROSS_ATTENTION) ? params.memory_max_len : params.timestep;
    // A null cache_indir means no beam search, so the kernel can skip the beam indirection
    // lookups. In both branches, shorter sequences get more threads per key and smaller blocks.
    if (params.cache_indir == nullptr) {
        if (tlength < 32) {
            MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, DO_CROSS_ATTENTION, false, stream);
        }
        else if (tlength < 2048) {
            MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, DO_CROSS_ATTENTION, false, stream);
        }
        else {
            MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, DO_CROSS_ATTENTION, false, stream);
        }
    }
    else {
        if (tlength < 32) {
            MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, DO_CROSS_ATTENTION, true, stream);
        }
        else if (tlength < 2048) {
            MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, DO_CROSS_ATTENTION, true, stream);
        }
        else {
            MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, DO_CROSS_ATTENTION, true, stream);
        }
    }
}
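
// A minimal usage sketch (comments only; the real call sites live elsewhere in FasterTransformer).
// The field names below are the ones this file already reads; everything else on the struct, and
// the concrete values, are illustrative assumptions.
//
//   Masked_multihead_attention_params<float> params;  // or Cross_multihead_attention_params<float>
//   params.num_heads   = 16;
//   params.batch_size  = 8;
//   params.timestep    = 42;       // current decoding step (self attention path)
//   params.cache_indir = nullptr;  // nullptr => no beam search
//   mmha_launch_kernel<float, 112, 128>(params, stream);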

////////////////////////////////////////////////////////////////////////////////////////////////////

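// Explicit instantiations for head size Dh = 112, padded to Dh_MAX = 128 (the next power of two),
// for each supported activation type. The BF16 and FP8 variants are compiled in only when the
// corresponding macros are defined at build time.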
template void mmha_launch_kernel<float, 112, 128, Masked_multihead_attention_params<float>>(
    const Masked_multihead_attention_params<float>& params, const cudaStream_t& stream);
template void mmha_launch_kernel<uint16_t, 112, 128, Masked_multihead_attention_params<uint16_t>>(
    const Masked_multihead_attention_params<uint16_t>& params, const cudaStream_t& stream);
#ifdef ENABLE_BF16
template void mmha_launch_kernel<__nv_bfloat16, 112, 128, Masked_multihead_attention_params<__nv_bfloat16>>(
    const Masked_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream);
#endif
#ifdef ENABLE_FP8
template void mmha_launch_kernel<__nv_fp8_e4m3, 112, 128, Masked_multihead_attention_params<__nv_fp8_e4m3>>(
    const Masked_multihead_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream);
#endif

template void mmha_launch_kernel<float, 112, 128, Cross_multihead_attention_params<float>>(
    const Cross_multihead_attention_params<float>& params, const cudaStream_t& stream);
template void mmha_launch_kernel<uint16_t, 112, 128, Cross_multihead_attention_params<uint16_t>>(
    const Cross_multihead_attention_params<uint16_t>& params, const cudaStream_t& stream);
#ifdef ENABLE_BF16
template void mmha_launch_kernel<__nv_bfloat16, 112, 128, Cross_multihead_attention_params<__nv_bfloat16>>(
    const Cross_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream);
#endif
#ifdef ENABLE_FP8
template void mmha_launch_kernel<__nv_fp8_e4m3, 112, 128, Cross_multihead_attention_params<__nv_fp8_e4m3>>(
    const Cross_multihead_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream);
#endif

#undef MMHA_LAUNCH_KERNEL