@@ -118,6 +118,52 @@ __global__ void AddBiasTransposeQKV(const T* input, const T* biases, T* output)
   }
 }
 
+template <typename T>
+__global__ void AddBiasTransposeQKV(const T* input, const T* biases, T* output, int v_head_size) {
+  // Input:  BxSxMxNxH (Format 1)
+  // Output: MxBxNxSxH
+  // B is batch_size, S is sequence_length, M is number of matrices, N is num_heads, H is head_size
+
+  int n = threadIdx.y;        // head_num_id
+  int s = blockIdx.x;         // sequence_id
+  int b = blockIdx.y;         // batch_id
+  int m = blockIdx.z;         // matrix id (Q=0, K=1, V=2)
+  const int h = threadIdx.x;  // head_element_id
+
+  const int qk_head_size = blockDim.x;
+  const int num_heads = blockDim.y;
+
+  const int sequence_length = gridDim.x;
+  const int batch_size = gridDim.y;
+
+  const int qkv_head_sizes[3] = {qk_head_size, qk_head_size, v_head_size};
+
+  const int total_head_size = num_heads * (qkv_head_sizes[0] + qkv_head_sizes[1] + qkv_head_sizes[2]);
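+  // Example: with num_heads = 12, qk_head_size = 64 and v_head_size = 32,
+  // total_head_size = 12 * (64 + 64 + 32) = 1920 elements per (b, s) row of the packed input.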
+
+  int in_offset;
+  int out_offset;
+  int bias_offset;
+  in_offset = b * (total_head_size * sequence_length) +  // B
+              s * (total_head_size) +                    // S
+              m * (qk_head_size * num_heads) +           // M
+              n * qkv_head_sizes[m] +                    // N
+              h;                                         // H
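+  // The matrix stride of the input uses qk_head_size because Q and K (m = 0, 1) share that width;
+  // V (m = 2) then starts at 2 * num_heads * qk_head_size within each (b, s) row.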
+
+  out_offset = m * (num_heads * qk_head_size * sequence_length * batch_size) +  // M
+               b * (num_heads * qkv_head_sizes[m] * sequence_length) +          // B
+               n * (sequence_length * qkv_head_sizes[m]) +                      // N
+               s * (qkv_head_sizes[m]) +                                        // S
+               h;                                                               // H
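+  // Each matrix gets a slot of batch_size * num_heads * sequence_length * qk_head_size in the output;
+  // within its slot, V is packed contiguously with its own (possibly smaller) head size.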
+
+  bias_offset = m * (num_heads * qk_head_size) +  // QKV
+                n * (qkv_head_sizes[m]) +         // N
+                h;                                // H
+
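+  // blockDim.x equals qk_head_size, so when v_head_size is smaller the extra threads of the V matrix write nothing.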
+  if (h < qkv_head_sizes[m]) {
+    output[out_offset] = input[in_offset] + biases[bias_offset];
+  }
+}
+
 template <typename T>
 __global__ void AddBiasTransposeQKVLarge(const int head_size, const T* input, const T* biases, T* output) {
   int n = threadIdx.y;
@@ -203,80 +249,86 @@ __global__ void AddBiasTransposeLarge(const int head_size, const T* input, const
 template <typename T>
 void InvokeAddBiasTranspose(
     cudaStream_t stream, const int num_matrices, const int format, const int max_threads_per_block,
-    const int batch_size, const int sequence_length, const int num_heads, const int head_size,
-    const T* input, const T* biases, T* output) {
+    const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size,
+    const T* input, const T* biases, T* output, const int v_head_size) {
   const dim3 grid(sequence_length, batch_size, num_matrices);
-  if (head_size * num_heads <= max_threads_per_block) {
-    const dim3 block(head_size, num_heads, 1);
+  if (qk_head_size * num_heads <= max_threads_per_block) {
+    const dim3 block(qk_head_size, num_heads, 1);
     if (format == 2) {
       AddBiasTransposeTrt<T><<<grid, block, 0, stream>>>(input, biases, output);
     } else if (format == 1) {
-      AddBiasTransposeQKV<T><<<grid, block, 0, stream>>>(input, biases, output);
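+      // Use the original kernel when no separate value head size is given (-1) or it equals qk_head_size;
+      // otherwise dispatch to the overload that also takes v_head_size.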
+      if ((v_head_size == -1) || (qk_head_size == v_head_size)) {
+        AddBiasTransposeQKV<T><<<grid, block, 0, stream>>>(input, biases, output);
+      } else {
+        AddBiasTransposeQKV<T><<<grid, block, 0, stream>>>(input, biases, output, v_head_size);
+      }
     } else {
       AddBiasTranspose<T><<<grid, block, 0, stream>>>(input, biases, output);
     }
   } else {
     const dim3 block(CeilDiv(max_threads_per_block, num_heads), num_heads, 1);
     if (format == 2) {
-      AddBiasTransposeTrtLarge<T><<<grid, block, 0, stream>>>(head_size, input, biases, output);
+      AddBiasTransposeTrtLarge<T><<<grid, block, 0, stream>>>(qk_head_size, input, biases, output);
     } else if (format == 1) {
-      AddBiasTransposeQKVLarge<T><<<grid, block, 0, stream>>>(head_size, input, biases, output);
+      AddBiasTransposeQKVLarge<T><<<grid, block, 0, stream>>>(qk_head_size, input, biases, output);
    } else {
-      AddBiasTransposeLarge<T><<<grid, block, 0, stream>>>(head_size, input, biases, output);
+      AddBiasTransposeLarge<T><<<grid, block, 0, stream>>>(qk_head_size, input, biases, output);
     }
   }
 }
 
 template <>
 void LaunchAddBiasTranspose(
     cudaStream_t stream, const int num_matrices, const int format, const int max_threads_per_block,
-    const int batch_size, const int sequence_length, const int num_heads, const int head_size,
+    const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size,
     const half* input, const half* biases, half* output,
-    bool enable_half4) {
-  if (enable_half4 && 0 == (head_size % 4)) {
-    const int H = head_size / 4;
+    bool enable_half4, const int v_head_size) {
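+  // Prefer Half4 (4 elements per thread) when both head sizes are multiples of 4,
+  // then fall back to half2 and finally to scalar half.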
+  if (enable_half4 && 0 == (qk_head_size % 4) && 0 == (v_head_size % 4)) {
+    const int H_q = qk_head_size / 4;
+    const int H_v = v_head_size / 4;
     const Half4* input2 = reinterpret_cast<const Half4*>(input);
     const Half4* biases2 = reinterpret_cast<const Half4*>(biases);
     Half4* output2 = reinterpret_cast<Half4*>(output);
     InvokeAddBiasTranspose<Half4>(stream, num_matrices, format, max_threads_per_block,
-                                  batch_size, sequence_length, num_heads, H, input2, biases2, output2);
-  } else if (0 == (head_size & 1)) {
-    const int H = head_size / 2;
+                                  batch_size, sequence_length, num_heads, H_q, input2, biases2, output2, H_v);
+  } else if (0 == (qk_head_size & 1) && 0 == (v_head_size % 2)) {
+    const int H_q = qk_head_size / 2;
+    const int H_v = v_head_size / 2;
     const half2* input2 = reinterpret_cast<const half2*>(input);
     const half2* biases2 = reinterpret_cast<const half2*>(biases);
     half2* output2 = reinterpret_cast<half2*>(output);
     InvokeAddBiasTranspose<half2>(stream, num_matrices, format, max_threads_per_block,
-                                  batch_size, sequence_length, num_heads, H, input2, biases2, output2);
+                                  batch_size, sequence_length, num_heads, H_q, input2, biases2, output2, H_v);
   } else {
     InvokeAddBiasTranspose<half>(stream, num_matrices, format, max_threads_per_block,
-                                 batch_size, sequence_length, num_heads, head_size, input, biases, output);
+                                 batch_size, sequence_length, num_heads, qk_head_size, input, biases, output, v_head_size);
   }
 }
 
 template <>
 void LaunchAddBiasTranspose(
     cudaStream_t stream, const int num_matrices, const int format, const int max_threads_per_block,
-    const int batch_size, const int sequence_length, const int num_heads, const int head_size,
+    const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size,
     const float* input, const float* biases, float* output,
-    bool /*enable_half4*/) {
-  if (0 == (head_size % 4)) {
-    const int H = head_size / 4;
+    bool /*enable_half4*/, const int v_head_size) {
+  if (0 == (qk_head_size % 4) && 0 == (v_head_size % 4)) {
+    const int H = qk_head_size / 4;
     const float4* input2 = reinterpret_cast<const float4*>(input);
     const float4* biases2 = reinterpret_cast<const float4*>(biases);
     float4* output2 = reinterpret_cast<float4*>(output);
     InvokeAddBiasTranspose<float4>(stream, num_matrices, format, max_threads_per_block,
-                                   batch_size, sequence_length, num_heads, H, input2, biases2, output2);
-  } else if (0 == (head_size & 1)) {
-    const int H = head_size / 2;
+                                   batch_size, sequence_length, num_heads, H, input2, biases2, output2, v_head_size / 4);
+  } else if (0 == (qk_head_size & 1) && 0 == (v_head_size & 1)) {
+    const int H = qk_head_size / 2;
     const float2* input2 = reinterpret_cast<const float2*>(input);
     const float2* biases2 = reinterpret_cast<const float2*>(biases);
     float2* output2 = reinterpret_cast<float2*>(output);
 
     InvokeAddBiasTranspose<float2>(stream, num_matrices, format, max_threads_per_block,
-                                   batch_size, sequence_length, num_heads, H, input2, biases2, output2);
+                                   batch_size, sequence_length, num_heads, H, input2, biases2, output2, v_head_size / 2);
   } else {
     InvokeAddBiasTranspose<float>(stream, num_matrices, format, max_threads_per_block,
-                                  batch_size, sequence_length, num_heads, head_size, input, biases, output);
+                                  batch_size, sequence_length, num_heads, qk_head_size, input, biases, output, v_head_size);
   }
 }