
Commit 81b1284

dskhudia authored and azahed98 committed
Support size_per_head=112 (NVIDIA#660)
* fix multi-gpu build
* add support for size_per_head=112 for gpt decoder
1 parent 42465ca commit 81b1284

File tree: 6 files changed (+113 −9 lines)


src/fastertransformer/kernels/decoder_masked_multihead_attention.cu

Lines changed: 3 additions & 0 deletions
@@ -41,6 +41,9 @@ void multihead_attention_(const KERNEL_PARAMS_TYPE& params, const cudaStream_t&
         case 96:
             mmha_launch_kernel<T, 96, 128, KERNEL_PARAMS_TYPE>(params, stream);
             break;
+        case 112:
+            mmha_launch_kernel<T, 112, 128, KERNEL_PARAMS_TYPE>(params, stream);
+            break;
         case 128:
             mmha_launch_kernel<T, 128, 128, KERNEL_PARAMS_TYPE>(params, stream);
             break;
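This switch is where a new head size gets wired into the decoder's fused attention path: the runtime size_per_head selects a kernel compiled for that exact Dh, while Dh_MAX is padded up to 128, just as the existing 96-wide case does. A minimal, self-contained sketch of that dispatch pattern follows; launch_for_head_size and dispatch_head_size are hypothetical names used only for illustration, not FasterTransformer APIs.

// Illustrative sketch only: maps a runtime head size onto a compile-time Dh,
// padding Dh_MAX to 128, in the same style as the switch above.
// launch_for_head_size / dispatch_head_size are hypothetical names.
#include <cstdio>
#include <stdexcept>

template<int Dh, int Dh_MAX>
void launch_for_head_size()
{
    // In FasterTransformer this is where mmha_launch_kernel<T, Dh, Dh_MAX, ...>
    // would be instantiated and launched.
    std::printf("Dh=%d, Dh_MAX=%d\n", Dh, Dh_MAX);
}

void dispatch_head_size(int size_per_head)
{
    switch (size_per_head) {
        case 96:
            launch_for_head_size<96, 128>();
            break;
        case 112:  // the case added by this commit
            launch_for_head_size<112, 128>();
            break;
        case 128:
            launch_for_head_size<128, 128>();
            break;
        default:
            throw std::runtime_error("unsupported size_per_head");
    }
}

int main()
{
    dispatch_head_size(112);
    return 0;
}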
Lines changed: 101 additions & 0 deletions
@@ -0,0 +1,101 @@
/*
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "decoder_masked_multihead_attention_template.hpp"
#include "src/fastertransformer/kernels/decoder_masked_multihead_attention.h"
#include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h"
#include "src/fastertransformer/utils/cuda_bf16_wrapper.h"
#include <assert.h>
#include <float.h>
#include <type_traits>

////////////////////////////////////////////////////////////////////////////////////////////////////

#define MMHA_LAUNCH_KERNEL(                                                                              \
    T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, DO_CROSS_ATTENTION, HAS_BEAMS, stream)  \
    size_t smem_sz = mmha::smem_size_in_bytes<T, DO_CROSS_ATTENTION>(params, THDS_PER_VALUE, THDS_PER_BLOCK); \
    dim3 grid(params.num_heads, params.batch_size);                                                      \
    mmha::masked_multihead_attention_kernel<T,                                                            \
                                            Dh,                                                           \
                                            Dh_MAX,                                                       \
                                            THDS_PER_KEY,                                                 \
                                            THDS_PER_VALUE,                                               \
                                            THDS_PER_BLOCK,                                               \
                                            DO_CROSS_ATTENTION,                                           \
                                            HAS_BEAMS><<<grid, THDS_PER_BLOCK, smem_sz, stream>>>(params)

////////////////////////////////////////////////////////////////////////////////////////////////////

// !!! Specialize the launcher for Cross attention
template<typename T, int Dh, int Dh_MAX, typename KERNEL_PARAMS_TYPE>
void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream)
{
    constexpr int  THREADS_PER_VALUE  = threads_per_value_t<T, Dh_MAX>::value;
    constexpr bool DO_CROSS_ATTENTION = std::is_same<KERNEL_PARAMS_TYPE, Cross_multihead_attention_params<T>>::value;
    int            tlength            = (DO_CROSS_ATTENTION) ? params.memory_max_len : params.timestep;
    if (params.cache_indir == nullptr) {
        if (tlength < 32) {
            MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, DO_CROSS_ATTENTION, false, stream);
        }
        else if (tlength < 2048) {
            MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, DO_CROSS_ATTENTION, false, stream);
        }
        else {
            MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, DO_CROSS_ATTENTION, false, stream);
        }
    }
    else {
        if (tlength < 32) {
            MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, DO_CROSS_ATTENTION, true, stream);
        }
        else if (tlength < 2048) {
            MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, DO_CROSS_ATTENTION, true, stream);
        }
        else {
            MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, DO_CROSS_ATTENTION, true, stream);
        }
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template void mmha_launch_kernel<float, 112, 128, Masked_multihead_attention_params<float>>(
    const Masked_multihead_attention_params<float>& params, const cudaStream_t& stream);
template void mmha_launch_kernel<uint16_t, 112, 128, Masked_multihead_attention_params<uint16_t>>(
    const Masked_multihead_attention_params<uint16_t>& params, const cudaStream_t& stream);
#ifdef ENABLE_BF16
template void mmha_launch_kernel<__nv_bfloat16, 112, 128, Masked_multihead_attention_params<__nv_bfloat16>>(
    const Masked_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream);
#endif
#ifdef ENABLE_FP8
template void mmha_launch_kernel<__nv_fp8_e4m3, 112, 128, Masked_multihead_attention_params<__nv_fp8_e4m3>>(
    const Masked_multihead_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream);
#endif

template void mmha_launch_kernel<float, 112, 128, Cross_multihead_attention_params<float>>(
    const Cross_multihead_attention_params<float>& params, const cudaStream_t& stream);
template void mmha_launch_kernel<uint16_t, 112, 128, Cross_multihead_attention_params<uint16_t>>(
    const Cross_multihead_attention_params<uint16_t>& params, const cudaStream_t& stream);
#ifdef ENABLE_BF16
template void mmha_launch_kernel<__nv_bfloat16, 112, 128, Cross_multihead_attention_params<__nv_bfloat16>>(
    const Cross_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream);
#endif
#ifdef ENABLE_FP8
template void mmha_launch_kernel<__nv_fp8_e4m3, 112, 128, Cross_multihead_attention_params<__nv_fp8_e4m3>>(
    const Cross_multihead_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream);
#endif

#undef MMHA_LAUNCH_KERNEL
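This new translation unit instantiates the Dh = 112 launcher for fp32, fp16 (uint16_t), and optionally bf16 and fp8, for both masked self-attention and cross attention. The launcher picks its block size from the sequence length (tlength < 32 uses 64 threads with 4 threads per key, < 2048 uses 128 threads with 2 per key, otherwise 256 threads with 1 per key) and derives THREADS_PER_VALUE from threads_per_value_t<T, Dh_MAX>. That trait is not part of this diff; the sketch below is a hedged reconstruction under the assumption that each thread owns a 16-byte slice of one value row, with threads_per_value_sketch as a hypothetical stand-in name.

// Hedged sketch of a threads-per-value trait (assumption: one thread per
// 16-byte chunk of a value row of Dh_MAX elements). The actual definition
// lives in the FasterTransformer headers, not in this commit.
#include <cstdint>
#include <cstdio>

template<typename T, int Dh_MAX>
struct threads_per_value_sketch {
    static constexpr int value = Dh_MAX * sizeof(T) / 16;
};

int main()
{
    // With Dh_MAX padded to 128 for size_per_head = 112:
    std::printf("fp32: %d threads per value\n", threads_per_value_sketch<float, 128>::value);     // 32
    std::printf("fp16: %d threads per value\n", threads_per_value_sketch<uint16_t, 128>::value);  // 16
    return 0;
}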

src/fastertransformer/layers/attention_layers/DecoderCrossAttentionLayer.cu

Lines changed: 3 additions & 3 deletions
@@ -796,8 +796,8 @@ DecoderCrossAttentionLayer<T>::DecoderCrossAttentionLayer(size_t max_b
     q_scaling_(q_scaling)
 {
     FT_CHECK(size_per_head_ == 32 || size_per_head_ == 48 || size_per_head_ == 64 || size_per_head_ == 80
-             || size_per_head_ == 96 || size_per_head_ == 128 || size_per_head_ == 144 || size_per_head_ == 160
-             || size_per_head_ == 192 || size_per_head_ == 224 || size_per_head_ == 256);
+             || size_per_head_ == 96 || size_per_head_ == 112 || size_per_head_ == 128 || size_per_head_ == 144
+             || size_per_head_ == 160 || size_per_head_ == 192 || size_per_head_ == 224 || size_per_head_ == 256);
 }

 template<typename T>
@@ -1030,4 +1030,4 @@ template class DecoderCrossAttentionLayer<half>;
 template class DecoderCrossAttentionLayer<__nv_bfloat16>;
 #endif

-}  // namespace fastertransformer
+}  // namespace fastertransformer

src/fastertransformer/layers/attention_layers/DecoderSelfAttentionLayer.cc

Lines changed: 2 additions & 2 deletions
@@ -278,8 +278,8 @@ DecoderSelfAttentionLayer<T>::DecoderSelfAttentionLayer(size_t max_bat
     int8_mode_(int8_mode)
 {
     FT_CHECK(size_per_head_ == 32 || size_per_head_ == 48 || size_per_head_ == 64 || size_per_head_ == 80
-             || size_per_head_ == 96 || size_per_head_ == 128 || size_per_head_ == 144 || size_per_head_ == 160
-             || size_per_head_ == 192 || size_per_head_ == 224 || size_per_head_ == 256);
+             || size_per_head_ == 96 || size_per_head_ == 112 || size_per_head_ == 128 || size_per_head_ == 144
+             || size_per_head_ == 160 || size_per_head_ == 192 || size_per_head_ == 224 || size_per_head_ == 256);
     if (int8_mode_ == 1) {
         FT_CHECK_WITH_INFO(!(std::is_same<T, float>::value), "Weight only quant not supported for fp32.");
         weight_only_int8_fc_runner_ = std::make_shared<CutlassFpAIntBGemmRunner<T, uint8_t>>();
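Both decoder attention layers guard construction with the same whitelist of head sizes, so 112 has to be added to each FT_CHECK by hand. As a hedged alternative sketch (not what FasterTransformer does), the same check could be expressed as a single table plus a lookup, which would make the next addition a one-line change; is_supported_size_per_head below is an illustrative name only.

// Illustrative alternative only: the head-size whitelist as a table lookup
// rather than a chained || expression.
#include <algorithm>
#include <array>
#include <cstddef>

inline bool is_supported_size_per_head(std::size_t size_per_head)
{
    static constexpr std::array<std::size_t, 12> kSupported = {
        32, 48, 64, 80, 96, 112, 128, 144, 160, 192, 224, 256};
    return std::find(kSupported.begin(), kSupported.end(), size_per_head) != kSupported.end();
}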

src/fastertransformer/utils/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
@@ -57,7 +57,7 @@ add_library(mpi_utils STATIC mpi_utils.cc)
 set_property(TARGET mpi_utils PROPERTY POSITION_INDEPENDENT_CODE ON)
 set_property(TARGET mpi_utils PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
 if (BUILD_MULTI_GPU)
-    target_link_libraries(mpi_utils PUBLIC -lmpi logger)
+    target_link_libraries(mpi_utils PUBLIC -lmpi -lmpi_cxx logger)
 endif()

 add_library(nccl_utils STATIC nccl_utils.cc)
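The added -lmpi_cxx is presumably the "fix multi-gpu build" part of the commit message: with Open MPI builds that ship the C++ bindings as a separate libmpi_cxx, C++ translation units that include <mpi.h> can leave references to those bindings unresolved when only -lmpi is passed, so listing both libraries lets mpi_utils link cleanly when BUILD_MULTI_GPU is enabled.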

tests/decoding/tf_fused_self_multihead_attention_unit_test.py

Lines changed: 3 additions & 3 deletions
@@ -56,12 +56,12 @@ def test_attn_head_fp16(self):
             self.run_attn(4, 128, head, 64, tf.float16)

     def test_attn_size_fp32(self):
-        for size in [32, 64, 80, 96, 128, 144, 160, 192, 224, 256]:
+        for size in [32, 64, 80, 96, 112, 128, 144, 160, 192, 224, 256]:
             tf.reset_default_graph()
             self.run_attn(4, 128, 12, size, tf.float32)

     def test_attn_size_fp16(self):
-        for size in [32, 64, 80, 96, 128, 144, 160, 192, 224, 256]:
+        for size in [32, 64, 80, 96, 112, 128, 144, 160, 192, 224, 256]:
             tf.reset_default_graph()
             self.run_attn(4, 128, 12, size, tf.float16)

@@ -171,4 +171,4 @@ def run_attn(self, batch_size, seq_len, head_num, size_per_head, data_type):
         assert(v_cache_max_diff < threshold)

 if __name__ == "__main__":
-    unittest.main()
+    unittest.main()
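With 112 added to both size lists, the fused-attention unit test now exercises the new head size in fp32 and fp16; since the file already uses the standard unittest.main() entry point, it can presumably be run on its own with python tests/decoding/tf_fused_self_multihead_attention_unit_test.py.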
