
Commit febd5fa

Authored by petermcaughan (Peter Mcaughan)
Change head_size parameter dependent on qkv_hidden_size (#12933)
**Description**: Add `qkv_hidden_size` support to the CUDA Attention layer implementation. Changes include:

- Modify the unit test to exercise both the GPU and CPU implementations.
- Add an overload of the CUDA kernel `AddBiasTransposeQKV` to support the scenario where V_HIDDEN_SIZE != QK_HIDDEN_SIZE.
- Update variable names from `head_size` to `qkv_head_sizes[0]` or `qkv_head_sizes[2]`.
- Modify function definitions to pass `qkv_hidden_sizes` or `qkv_head_sizes` through the call chain.

Note that this feature is not currently supported in the ROCm EP or in quantized attention.

**Motivation and Context**
The current CUDA implementation of the attention layer does not support the `qkv_hidden_size` parameter that was added to the CPU implementation in PR [8039](#8039).

Co-authored-by: Peter Mcaughan <[email protected]>
1 parent b9e23bd commit febd5fa
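
For orientation, here is a minimal standalone C++ sketch (not part of the commit; sizes are hypothetical) of the shape bookkeeping this change introduces: when the optional `qkv_hidden_sizes` attribute is provided, the Q/K and V projections may have different hidden sizes, per-head sizes are derived by dividing by the number of heads, and the op output keeps the V hidden size.

```cpp
// Illustrative sketch only (hypothetical sizes, not taken from the commit).
#include <cstdio>
#include <vector>

int main() {
  const int num_heads = 8;
  // Hypothetical qkv_hidden_sizes attribute: V projection wider than Q/K.
  const std::vector<int> qkv_hidden_sizes = {512, 512, 1024};

  const int q_hidden_size = qkv_hidden_sizes[0];
  const int k_hidden_size = qkv_hidden_sizes[1];
  const int v_hidden_size = qkv_hidden_sizes[2];

  // Per-head sizes, mirroring the qkv_head_size array computed in attention.cc below.
  const int qkv_head_size[3] = {q_hidden_size / num_heads,
                                k_hidden_size / num_heads,
                                v_hidden_size / num_heads};

  // The packed bias covers Q, K and V; the op output keeps the V hidden size.
  std::printf("qk_head_size=%d v_head_size=%d bias_length=%d output_hidden_size=%d\n",
              qkv_head_size[0], qkv_head_size[2],
              q_hidden_size + k_hidden_size + v_hidden_size, v_hidden_size);
  return 0;
}
```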

File tree

9 files changed: +194 -98 lines changed

onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu

Lines changed: 78 additions & 26 deletions
@@ -118,6 +118,52 @@ __global__ void AddBiasTransposeQKV(const T* input, const T* biases, T* output)
   }
 }
 
+template <typename T>
+__global__ void AddBiasTransposeQKV(const T* input, const T* biases, T* output, int v_head_size) {
+  // Input: BxSxMxNxH (Format 1)
+  // Output: MxBxNxSxH
+  // B is batch_size, S is sequence_length, M is number of matrices, N is num_heads, H is head_size
+
+  int n = threadIdx.y;        // head_num_id
+  int s = blockIdx.x;         // sequence_id
+  int b = blockIdx.y;         // batch_id
+  int m = blockIdx.z;         // matrix id (Q=0, K=1, V=2)
+  const int h = threadIdx.x;  // head_element_id
+
+  const int qk_head_size = blockDim.x;
+  const int num_heads = blockDim.y;
+
+  const int sequence_length = gridDim.x;
+  const int batch_size = gridDim.y;
+
+  const int qkv_head_sizes[3] = {qk_head_size, qk_head_size, v_head_size};
+
+  const int total_head_size = num_heads * (qkv_head_sizes[0] + qkv_head_sizes[1] + qkv_head_sizes[2]);
+
+  int in_offset;
+  int out_offset;
+  int bias_offset;
+  in_offset = b * (total_head_size * sequence_length) +  // B
+              s * (total_head_size) +                    // S
+              m * (qk_head_size * num_heads) +           // M
+              n * qkv_head_sizes[m] +                    // N
+              h;                                         // H
+
+  out_offset = m * (num_heads * qk_head_size * sequence_length * batch_size) +  // M
+               b * (num_heads * qkv_head_sizes[m] * sequence_length) +          // B
+               n * (sequence_length * qkv_head_sizes[m]) +                      // N
+               s * (qkv_head_sizes[m]) +                                        // S
+               h;                                                               // H
+
+  bias_offset = m * (num_heads * qk_head_size) +  // QKV
+                n * (qkv_head_sizes[m]) +         // N
+                h;                                // H
+
+  if (h < qkv_head_sizes[m]) {
+    output[out_offset] = input[in_offset] + biases[bias_offset];
+  }
+}
+
 template <typename T>
 __global__ void AddBiasTransposeQKVLarge(const int head_size, const T* input, const T* biases, T* output) {
   int n = threadIdx.y;
@@ -203,80 +249,86 @@ __global__ void AddBiasTransposeLarge(const int head_size, const T* input, const
 template <typename T>
 void InvokeAddBiasTranspose(
     cudaStream_t stream, const int num_matrices, const int format, const int max_threads_per_block,
-    const int batch_size, const int sequence_length, const int num_heads, const int head_size,
-    const T* input, const T* biases, T* output) {
+    const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size,
+    const T* input, const T* biases, T* output, const int v_head_size) {
   const dim3 grid(sequence_length, batch_size, num_matrices);
-  if (head_size * num_heads <= max_threads_per_block) {
-    const dim3 block(head_size, num_heads, 1);
+  if (qk_head_size * num_heads <= max_threads_per_block) {
+    const dim3 block(qk_head_size, num_heads, 1);
     if (format == 2) {
       AddBiasTransposeTrt<T><<<grid, block, 0, stream>>>(input, biases, output);
     } else if (format == 1) {
-      AddBiasTransposeQKV<T><<<grid, block, 0, stream>>>(input, biases, output);
+      if ((v_head_size == -1) || (qk_head_size == v_head_size)) {
+        AddBiasTransposeQKV<T><<<grid, block, 0, stream>>>(input, biases, output);
+      } else {
+        AddBiasTransposeQKV<T><<<grid, block, 0, stream>>>(input, biases, output, v_head_size);
+      }
     } else {
       AddBiasTranspose<T><<<grid, block, 0, stream>>>(input, biases, output);
     }
   } else {
     const dim3 block(CeilDiv(max_threads_per_block, num_heads), num_heads, 1);
     if (format == 2) {
-      AddBiasTransposeTrtLarge<T><<<grid, block, 0, stream>>>(head_size, input, biases, output);
+      AddBiasTransposeTrtLarge<T><<<grid, block, 0, stream>>>(qk_head_size, input, biases, output);
     } else if (format == 1) {
-      AddBiasTransposeQKVLarge<T><<<grid, block, 0, stream>>>(head_size, input, biases, output);
+      AddBiasTransposeQKVLarge<T><<<grid, block, 0, stream>>>(qk_head_size, input, biases, output);
     } else {
-      AddBiasTransposeLarge<T><<<grid, block, 0, stream>>>(head_size, input, biases, output);
+      AddBiasTransposeLarge<T><<<grid, block, 0, stream>>>(qk_head_size, input, biases, output);
     }
   }
 }
 
 template <>
 void LaunchAddBiasTranspose(
     cudaStream_t stream, const int num_matrices, const int format, const int max_threads_per_block,
-    const int batch_size, const int sequence_length, const int num_heads, const int head_size,
+    const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size,
     const half* input, const half* biases, half* output,
-    bool enable_half4) {
-  if (enable_half4 && 0 == (head_size % 4)) {
-    const int H = head_size / 4;
+    bool enable_half4, const int v_head_size) {
+  if (enable_half4 && 0 == (qk_head_size % 4) && 0 == (v_head_size % 4)) {
+    const int H_q = qk_head_size / 4;
+    const int H_v = v_head_size / 4;
     const Half4* input2 = reinterpret_cast<const Half4*>(input);
     const Half4* biases2 = reinterpret_cast<const Half4*>(biases);
     Half4* output2 = reinterpret_cast<Half4*>(output);
     InvokeAddBiasTranspose<Half4>(stream, num_matrices, format, max_threads_per_block,
-                                  batch_size, sequence_length, num_heads, H, input2, biases2, output2);
-  } else if (0 == (head_size & 1)) {
-    const int H = head_size / 2;
+                                  batch_size, sequence_length, num_heads, H_q, input2, biases2, output2, H_v);
+  } else if (0 == (qk_head_size & 1) && 0 == (v_head_size % 1)) {
+    const int H_q = qk_head_size / 2;
+    const int H_v = v_head_size / 2;
     const half2* input2 = reinterpret_cast<const half2*>(input);
     const half2* biases2 = reinterpret_cast<const half2*>(biases);
     half2* output2 = reinterpret_cast<half2*>(output);
     InvokeAddBiasTranspose<half2>(stream, num_matrices, format, max_threads_per_block,
-                                  batch_size, sequence_length, num_heads, H, input2, biases2, output2);
+                                  batch_size, sequence_length, num_heads, H_q, input2, biases2, output2, H_v);
   } else {
     InvokeAddBiasTranspose<half>(stream, num_matrices, format, max_threads_per_block,
-                                 batch_size, sequence_length, num_heads, head_size, input, biases, output);
+                                 batch_size, sequence_length, num_heads, qk_head_size, input, biases, output, v_head_size);
   }
 }
 
 template <>
 void LaunchAddBiasTranspose(
     cudaStream_t stream, const int num_matrices, const int format, const int max_threads_per_block,
-    const int batch_size, const int sequence_length, const int num_heads, const int head_size,
+    const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size,
     const float* input, const float* biases, float* output,
-    bool /*enable_half4*/) {
-  if (0 == (head_size % 4)) {
-    const int H = head_size / 4;
+    bool /*enable_half4*/, const int v_head_size) {
+  if (0 == (qk_head_size % 4)) {
+    const int H = qk_head_size / 4;
     const float4* input2 = reinterpret_cast<const float4*>(input);
     const float4* biases2 = reinterpret_cast<const float4*>(biases);
     float4* output2 = reinterpret_cast<float4*>(output);
     InvokeAddBiasTranspose<float4>(stream, num_matrices, format, max_threads_per_block,
-                                   batch_size, sequence_length, num_heads, H, input2, biases2, output2);
-  } else if (0 == (head_size & 1)) {
-    const int H = head_size / 2;
+                                   batch_size, sequence_length, num_heads, H, input2, biases2, output2, v_head_size / 4);
+  } else if (0 == (qk_head_size & 1)) {
+    const int H = qk_head_size / 2;
     const float2* input2 = reinterpret_cast<const float2*>(input);
     const float2* biases2 = reinterpret_cast<const float2*>(biases);
     float2* output2 = reinterpret_cast<float2*>(output);
 
     InvokeAddBiasTranspose<float2>(stream, num_matrices, format, max_threads_per_block,
-                                   batch_size, sequence_length, num_heads, H, input2, biases2, output2);
+                                   batch_size, sequence_length, num_heads, H, input2, biases2, output2, v_head_size / 2);
   } else {
     InvokeAddBiasTranspose<float>(stream, num_matrices, format, max_threads_per_block,
-                                  batch_size, sequence_length, num_heads, head_size, input, biases, output);
+                                  batch_size, sequence_length, num_heads, qk_head_size, input, biases, output, v_head_size);
   }
 }
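
To make the new overload's indexing easier to follow, here is a host-side restatement of its offset arithmetic (illustration only, not code from the commit). It assumes the launch configuration set up by `InvokeAddBiasTranspose` above: grid = (sequence_length, batch_size, 3) and block = (qk_head_size, num_heads, 1), so `(b, s, m, n, h)` stand in for `blockIdx.y`, `blockIdx.x`, `blockIdx.z`, `threadIdx.y` and `threadIdx.x`.

```cpp
// Host-side restatement of the new kernel's offset arithmetic, for illustration only.
#include <cstdio>

void ComputeOffsets(int b, int s, int m, int n, int h,
                    int batch_size, int sequence_length, int num_heads,
                    int qk_head_size, int v_head_size) {
  const int qkv_head_sizes[3] = {qk_head_size, qk_head_size, v_head_size};
  const int total_head_size =
      num_heads * (qkv_head_sizes[0] + qkv_head_sizes[1] + qkv_head_sizes[2]);

  // Input is packed as BxSxMxNxH, output as MxBxNxSxH (Format 1).
  const int in_offset = b * (total_head_size * sequence_length) +
                        s * total_head_size +
                        m * (qk_head_size * num_heads) +
                        n * qkv_head_sizes[m] +
                        h;
  const int out_offset = m * (num_heads * qk_head_size * sequence_length * batch_size) +
                         b * (num_heads * qkv_head_sizes[m] * sequence_length) +
                         n * (sequence_length * qkv_head_sizes[m]) +
                         s * qkv_head_sizes[m] +
                         h;
  const int bias_offset = m * (num_heads * qk_head_size) +
                          n * qkv_head_sizes[m] +
                          h;

  if (h < qkv_head_sizes[m]) {  // threads past the per-matrix head size write nothing
    std::printf("in=%d out=%d bias=%d\n", in_offset, out_offset, bias_offset);
  }
}

int main() {
  // Hypothetical shapes: batch 2, sequence 4, 2 heads, qk_head_size 8, v_head_size 4.
  ComputeOffsets(/*b=*/1, /*s=*/2, /*m=*/2, /*n=*/1, /*h=*/3, 2, 4, 2, 8, 4);
  return 0;
}
```

The `h < qkv_head_sizes[m]` guard is what lets a single launch geometry serve matrices with different head sizes: threads beyond the per-matrix head size simply skip the write.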

onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.h

Lines changed: 2 additions & 2 deletions
@@ -24,8 +24,8 @@ namespace cuda {
 template <typename T>
 void LaunchAddBiasTranspose(
     cudaStream_t stream, const int num_matrices, const int format, const int max_threads_per_block,
-    const int batch_size, const int sequence_length, const int num_heads, const int head_size,
-    const T* input, const T* biases, T* output, bool enable_half4);
+    const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size,
+    const T* input, const T* biases, T* output, bool enable_half4, const int v_head_size);
 
 }  // namespace cuda
 }  // namespace contrib

onnxruntime/contrib_ops/cuda/bert/attention.cc

Lines changed: 29 additions & 12 deletions
@@ -82,18 +82,31 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
 
   // bias shape (3 * hidden_size)
   const auto& bias_shape = bias->Shape();
-  int hidden_size = static_cast<int>(bias_shape[0]) / 3;
+  int q_hidden_size;
+  int k_hidden_size;
+  int v_hidden_size;
+
+
+  if (qkv_hidden_sizes_.size() == 0) {
+    q_hidden_size = static_cast<int>(bias_shape[0]) / 3;
+    k_hidden_size = static_cast<int>(bias_shape[0]) / 3;
+    v_hidden_size = static_cast<int>(bias_shape[0]) / 3;
+  } else {
+    q_hidden_size = static_cast<int>(qkv_hidden_sizes_[0]);
+    k_hidden_size = static_cast<int>(qkv_hidden_sizes_[1]);
+    v_hidden_size = static_cast<int>(qkv_hidden_sizes_[2]);
+  }
 
-  int head_size = hidden_size / num_heads_;
+  const int qkv_head_size[3] = {q_hidden_size / num_heads_, k_hidden_size / num_heads_, v_hidden_size / num_heads_};
 
   TensorShapeVector output_shape(3);
   output_shape[0] = shape[0];
   output_shape[1] = shape[1];
-  output_shape[2] = static_cast<int64_t>(hidden_size);
+  output_shape[2] = static_cast<int64_t>(v_hidden_size);
   Tensor* output = context->Output(0, output_shape);
 
   int past_sequence_length = 0;
-  Tensor* present = GetPresent(context, past, batch_size, head_size, sequence_length, past_sequence_length);
+  Tensor* present = GetPresent(context, past, batch_size, qkv_head_size[1], sequence_length, past_sequence_length);
 
   // Check whether we can use fused kernel
   int sm = device_prop.major * 10 + device_prop.minor;
@@ -103,12 +116,14 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
                            nullptr == present &&
                            nullptr == extra_add_qk &&
                            !is_unidirectional_ &&
-                           HasFusedFp16Kernel(sm, head_size, sequence_length));
+                           qkv_head_size[0] == qkv_head_size[1] &&
+                           qkv_head_size[1] == qkv_head_size[2] &&
+                           HasFusedFp16Kernel(sm, qkv_head_size[0], sequence_length));
 
   MHARunner* fused_runner = nullptr;
   if (use_fused_runner) {
     if (nullptr == fused_fp16_runner_.get()) {
-      fused_fp16_runner_.reset(new FusedMHARunnerFP16v2(num_heads_, head_size, sm));
+      fused_fp16_runner_.reset(new FusedMHARunnerFP16v2(num_heads_, qkv_head_size[0], sm));
     }
     // In case some kernel not loaded due to shared memory limit, we need to double check here.
     if (fused_fp16_runner_->isValid(sequence_length)) {
@@ -121,9 +136,9 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
 
   // Use GEMM for fully connection.
   int m = batch_size * sequence_length;
-  int n = 3 * hidden_size;
+  int n = (q_hidden_size + k_hidden_size + v_hidden_size);
   int k = input_hidden_size;
-  size_t gemm_buffer_size = static_cast<size_t>(batch_size) * sequence_length * 3 * hidden_size * element_size;
+  size_t gemm_buffer_size = static_cast<size_t>(batch_size) * sequence_length * n * element_size;
   auto gemm_buffer = GetScratchBuffer<T>(gemm_buffer_size);
 
   typedef typename ToCudaType<T>::MappedType CudaT;
@@ -140,10 +155,11 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
   size_t workSpaceSize = GetAttentionWorkspaceSize(element_size,
                                                    batch_size,
                                                    num_heads_,
-                                                   head_size,
+                                                   qkv_head_size[0],
                                                    sequence_length,
                                                    past_sequence_length,
-                                                   fused_runner);
+                                                   fused_runner,
+                                                   qkv_head_size[2]);
 
   auto work_space = GetScratchBuffer<void>(workSpaceSize);
   ORT_RETURN_IF_ERROR(LaunchAttentionKernel(
@@ -154,7 +170,7 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
       batch_size,
       sequence_length,
      num_heads_,
-      head_size,
+      qkv_head_size[0],
      past_sequence_length,
      is_unidirectional_,
      reinterpret_cast<const void*>(gemm_buffer.get()),
@@ -166,7 +182,8 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
      work_space.get(),
      output->MutableData<T>(),
      nullptr == present ? nullptr : present->MutableData<T>(),
-      fused_runner));
+      fused_runner,
+      qkv_head_size[2]));
 
   return Status::OK();
 }
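
As a quick sanity check on the sizing changes above (illustrative values only, not from the PR): the packed QKV GEMM now produces `n = q_hidden_size + k_hidden_size + v_hidden_size` columns per token instead of `3 * hidden_size`, and the scratch buffer is sized from that sum.

```cpp
// Sketch with hypothetical sizes, mirroring the GEMM sizing arithmetic in
// Attention<T>::ComputeInternal above.
#include <cstddef>
#include <cstdio>

int main() {
  const int batch_size = 2;
  const int sequence_length = 128;
  const size_t element_size = sizeof(float);  // 4 bytes for fp32, 2 for fp16

  // Hypothetical attribute values with a wider V projection.
  const int q_hidden_size = 768;
  const int k_hidden_size = 768;
  const int v_hidden_size = 1536;

  // GEMM output width: one packed row of Q, K and V per token.
  const int n = q_hidden_size + k_hidden_size + v_hidden_size;  // 3072, not 3 * 768

  // Scratch buffer that holds the packed QKV GEMM result.
  const size_t gemm_buffer_size =
      static_cast<size_t>(batch_size) * sequence_length * n * element_size;

  std::printf("n=%d gemm_buffer_size=%zu bytes\n", n, gemm_buffer_size);
  return 0;
}
```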
