Change head_size parameter dependent on qkv_hidden_size (#12933)
**Description**: Add `qkv_hidden_size` support to the CUDA Attention layer implementation.

Changes include:

- Modify the unit test to exercise both the GPU and CPU implementations
- Add an overload of the CUDA kernel `AddBiasTransposeQKV` for the case where V_HIDDEN_SIZE != QK_HIDDEN_SIZE (see the launch sketch below)
- Rename `head_size` variables to `qkv_head_sizes[0]` or `qkv_head_sizes[2]` as appropriate
- Change function signatures to pass `qkv_hidden_sizes` or `qkv_head_sizes` through

Note that this feature is not yet supported in the ROCm EP or in quantized attention.
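
For context, here is a minimal host-side sketch of how the new overload is dispatched when V uses a smaller head size than Q/K, assuming the `AddBiasTransposeQKV` kernels from `add_bias_transpose.cu` below are in scope. The `LaunchExample` wrapper and the shape values are hypothetical and chosen only for illustration:

```cpp
#include <cuda_runtime.h>

// Hypothetical wrapper (not part of this PR) illustrating the dispatch added here.
template <typename T>
void LaunchExample(cudaStream_t stream, const T* input, const T* biases, T* output) {
  const int batch_size = 2, sequence_length = 128, num_heads = 12;
  const int qk_head_size = 64;  // head size of Q and K
  const int v_head_size = 32;   // head size of V, smaller than Q/K in this example

  // One block per (sequence position, batch, matrix) and one thread per (head element, head),
  // the same geometry used by InvokeAddBiasTranspose when
  // qk_head_size * num_heads <= max_threads_per_block.
  const dim3 grid(sequence_length, batch_size, /*num_matrices=*/3);
  const dim3 block(qk_head_size, num_heads, 1);

  if (v_head_size == -1 || qk_head_size == v_head_size) {
    // Original kernel: all three matrices share one head size.
    AddBiasTransposeQKV<T><<<grid, block, 0, stream>>>(input, biases, output);
  } else {
    // New overload: for the V matrix, threads with h >= v_head_size write nothing.
    AddBiasTransposeQKV<T><<<grid, block, 0, stream>>>(input, biases, output, v_head_size);
  }
}
```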

**Motivation and Context**: The current CUDA implementation of the attention layer does not support the `qkv_hidden_size` parameter that was added to the CPU implementation in PR [8039](#8039).

Co-authored-by: Peter Mcaughan <[email protected]>
petermcaughan and Peter Mcaughan authored Oct 11, 2022
1 parent b9e23bd commit febd5fa
Showing 9 changed files with 194 additions and 98 deletions.
104 changes: 78 additions & 26 deletions onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu
@@ -118,6 +118,52 @@ __global__ void AddBiasTransposeQKV(const T* input, const T* biases, T* output)
}
}

template <typename T>
__global__ void AddBiasTransposeQKV(const T* input, const T* biases, T* output, int v_head_size) {
// Input:  BxSxMxNxH (Format 1)
// Output: MxBxNxSxH
// B is batch_size, S is sequence_length, M is number of matrices (Q, K, V), N is num_heads,
// H is the per-matrix head size: qk_head_size for Q and K, v_head_size for V

int n = threadIdx.y; // head_num_id
int s = blockIdx.x; // sequence_id
int b = blockIdx.y; // batch_id
int m = blockIdx.z; // matrix id (Q=0, K=1, V=2)
const int h = threadIdx.x; // head_element_id

const int qk_head_size = blockDim.x;
const int num_heads = blockDim.y;

const int sequence_length = gridDim.x;
const int batch_size = gridDim.y;

const int qkv_head_sizes[3] = {qk_head_size, qk_head_size, v_head_size};

const int total_head_size = num_heads * (qkv_head_sizes[0] + qkv_head_sizes[1] + qkv_head_sizes[2]);

int in_offset;
int out_offset;
int bias_offset;
in_offset = b * (total_head_size * sequence_length) + // B
s * (total_head_size) + // S
m * (qk_head_size * num_heads) + // M
n * qkv_head_sizes[m] + // N
h; // H

out_offset = m * (num_heads * qk_head_size * sequence_length * batch_size) + // M
b * (num_heads * qkv_head_sizes[m] * sequence_length) + // B
n * (sequence_length * qkv_head_sizes[m]) + // N
s * (qkv_head_sizes[m]) + // S
h; // H

bias_offset = m * (num_heads * qk_head_size) +  // QKV
n * (qkv_head_sizes[m]) + // N
h; // H

if (h < qkv_head_sizes[m]) {
output[out_offset] = input[in_offset] + biases[bias_offset];
}
}
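
// Worked example of the indexing above, using illustrative values that are not part of this
// commit: batch_size B = 1, sequence_length S = 2, num_heads N = 2, qk_head_size = 4 and
// v_head_size = 2, so qkv_head_sizes = {4, 4, 2} and total_head_size = 2 * (4 + 4 + 2) = 20.
// For the V matrix (m = 2) with b = 0, s = 1, n = 1, h = 1:
//   in_offset   = 0*(20*2) + 1*20 + 2*(4*2) + 1*2 + 1 = 39  (last element of the packed BxSx[Q|K|V] input)
//   out_offset  = 2*(2*4*2*1) + 0 + 1*(2*2) + 1*2 + 1 = 39  (last element of the MxBxNxSxH output)
//   bias_offset = 2*(2*4) + 1*2 + 1 = 19                    (last element of the packed QKV bias)
// Threads with m == 2 and h in [v_head_size, qk_head_size) fail the guard and write nothing.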

template <typename T>
__global__ void AddBiasTransposeQKVLarge(const int head_size, const T* input, const T* biases, T* output) {
int n = threadIdx.y;
@@ -203,80 +249,86 @@ __global__ void AddBiasTransposeLarge(const int head_size, const T* input, const
template <typename T>
void InvokeAddBiasTranspose(
cudaStream_t stream, const int num_matrices, const int format, const int max_threads_per_block,
const int batch_size, const int sequence_length, const int num_heads, const int head_size,
const T* input, const T* biases, T* output) {
const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size,
const T* input, const T* biases, T* output, const int v_head_size) {
const dim3 grid(sequence_length, batch_size, num_matrices);
if (head_size * num_heads <= max_threads_per_block) {
const dim3 block(head_size, num_heads, 1);
if (qk_head_size * num_heads <= max_threads_per_block) {
const dim3 block(qk_head_size, num_heads, 1);
if (format == 2) {
AddBiasTransposeTrt<T><<<grid, block, 0, stream>>>(input, biases, output);
} else if (format == 1) {
AddBiasTransposeQKV<T><<<grid, block, 0, stream>>>(input, biases, output);
if ((v_head_size == -1) || (qk_head_size == v_head_size)) {
AddBiasTransposeQKV<T><<<grid, block, 0, stream>>>(input, biases, output);
} else {
AddBiasTransposeQKV<T><<<grid, block, 0, stream>>>(input, biases, output, v_head_size);
}
} else {
AddBiasTranspose<T><<<grid, block, 0, stream>>>(input, biases, output);
}
} else {
const dim3 block(CeilDiv(max_threads_per_block, num_heads), num_heads, 1);
if (format == 2) {
AddBiasTransposeTrtLarge<T><<<grid, block, 0, stream>>>(head_size, input, biases, output);
AddBiasTransposeTrtLarge<T><<<grid, block, 0, stream>>>(qk_head_size, input, biases, output);
} else if (format == 1) {
AddBiasTransposeQKVLarge<T><<<grid, block, 0, stream>>>(head_size, input, biases, output);
AddBiasTransposeQKVLarge<T><<<grid, block, 0, stream>>>(qk_head_size, input, biases, output);
} else {
AddBiasTransposeLarge<T><<<grid, block, 0, stream>>>(head_size, input, biases, output);
AddBiasTransposeLarge<T><<<grid, block, 0, stream>>>(qk_head_size, input, biases, output);
}
}
}

template <>
void LaunchAddBiasTranspose(
cudaStream_t stream, const int num_matrices, const int format, const int max_threads_per_block,
const int batch_size, const int sequence_length, const int num_heads, const int head_size,
const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size,
const half* input, const half* biases, half* output,
bool enable_half4) {
if (enable_half4 && 0 == (head_size % 4)) {
const int H = head_size / 4;
bool enable_half4, const int v_head_size) {
if (enable_half4 && 0 == (qk_head_size % 4) && 0 == (v_head_size % 4)) {
const int H_q = qk_head_size / 4;
const int H_v = v_head_size / 4;
const Half4* input2 = reinterpret_cast<const Half4*>(input);
const Half4* biases2 = reinterpret_cast<const Half4*>(biases);
Half4* output2 = reinterpret_cast<Half4*>(output);
InvokeAddBiasTranspose<Half4>(stream, num_matrices, format, max_threads_per_block,
batch_size, sequence_length, num_heads, H, input2, biases2, output2);
} else if (0 == (head_size & 1)) {
const int H = head_size / 2;
batch_size, sequence_length, num_heads, H_q, input2, biases2, output2, H_v);
} else if (0 == (qk_head_size & 1) && 0 == (v_head_size & 1)) {
const int H_q = qk_head_size / 2;
const int H_v = v_head_size / 2;
const half2* input2 = reinterpret_cast<const half2*>(input);
const half2* biases2 = reinterpret_cast<const half2*>(biases);
half2* output2 = reinterpret_cast<half2*>(output);
InvokeAddBiasTranspose<half2>(stream, num_matrices, format, max_threads_per_block,
batch_size, sequence_length, num_heads, H, input2, biases2, output2);
batch_size, sequence_length, num_heads, H_q, input2, biases2, output2, H_v);
} else {
InvokeAddBiasTranspose<half>(stream, num_matrices, format, max_threads_per_block,
batch_size, sequence_length, num_heads, head_size, input, biases, output);
batch_size, sequence_length, num_heads, qk_head_size, input, biases, output, v_head_size);
}
}

template <>
void LaunchAddBiasTranspose(
cudaStream_t stream, const int num_matrices, const int format, const int max_threads_per_block,
const int batch_size, const int sequence_length, const int num_heads, const int head_size,
const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size,
const float* input, const float* biases, float* output,
bool /*enable_half4*/) {
if (0 == (head_size % 4)) {
const int H = head_size / 4;
bool /*enable_half4*/, const int v_head_size) {
if (0 == (qk_head_size % 4)) {
const int H = qk_head_size / 4;
const float4* input2 = reinterpret_cast<const float4*>(input);
const float4* biases2 = reinterpret_cast<const float4*>(biases);
float4* output2 = reinterpret_cast<float4*>(output);
InvokeAddBiasTranspose<float4>(stream, num_matrices, format, max_threads_per_block,
batch_size, sequence_length, num_heads, H, input2, biases2, output2);
} else if (0 == (head_size & 1)) {
const int H = head_size / 2;
batch_size, sequence_length, num_heads, H, input2, biases2, output2, v_head_size / 4);
} else if (0 == (qk_head_size & 1)) {
const int H = qk_head_size / 2;
const float2* input2 = reinterpret_cast<const float2*>(input);
const float2* biases2 = reinterpret_cast<const float2*>(biases);
float2* output2 = reinterpret_cast<float2*>(output);

InvokeAddBiasTranspose<float2>(stream, num_matrices, format, max_threads_per_block,
batch_size, sequence_length, num_heads, H, input2, biases2, output2);
batch_size, sequence_length, num_heads, H, input2, biases2, output2, v_head_size / 2);
} else {
InvokeAddBiasTranspose<float>(stream, num_matrices, format, max_threads_per_block,
batch_size, sequence_length, num_heads, head_size, input, biases, output);
batch_size, sequence_length, num_heads, qk_head_size, input, biases, output, v_head_size);
}
}

4 changes: 2 additions & 2 deletions onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.h
@@ -24,8 +24,8 @@ namespace cuda {
template <typename T>
void LaunchAddBiasTranspose(
cudaStream_t stream, const int num_matrices, const int format, const int max_threads_per_block,
const int batch_size, const int sequence_length, const int num_heads, const int head_size,
const T* input, const T* biases, T* output, bool enable_half4);
const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size,
const T* input, const T* biases, T* output, bool enable_half4, const int v_head_size);

} // namespace cuda
} // namespace contrib
41 changes: 29 additions & 12 deletions onnxruntime/contrib_ops/cuda/bert/attention.cc
@@ -82,18 +82,31 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {

// bias shape (3 * hidden_size)
const auto& bias_shape = bias->Shape();
int hidden_size = static_cast<int>(bias_shape[0]) / 3;
int q_hidden_size;
int k_hidden_size;
int v_hidden_size;


if (qkv_hidden_sizes_.size() == 0) {
q_hidden_size = static_cast<int>(bias_shape[0]) / 3;
k_hidden_size = static_cast<int>(bias_shape[0]) / 3;
v_hidden_size = static_cast<int>(bias_shape[0]) / 3;
} else {
q_hidden_size = static_cast<int>(qkv_hidden_sizes_[0]);
k_hidden_size = static_cast<int>(qkv_hidden_sizes_[1]);
v_hidden_size = static_cast<int>(qkv_hidden_sizes_[2]);
}

int head_size = hidden_size / num_heads_;
const int qkv_head_size[3] = {q_hidden_size / num_heads_, k_hidden_size / num_heads_, v_hidden_size / num_heads_};

TensorShapeVector output_shape(3);
output_shape[0] = shape[0];
output_shape[1] = shape[1];
output_shape[2] = static_cast<int64_t>(hidden_size);
output_shape[2] = static_cast<int64_t>(v_hidden_size);
Tensor* output = context->Output(0, output_shape);

int past_sequence_length = 0;
Tensor* present = GetPresent(context, past, batch_size, head_size, sequence_length, past_sequence_length);
Tensor* present = GetPresent(context, past, batch_size, qkv_head_size[1], sequence_length, past_sequence_length);

// Check whether we can use fused kernel
int sm = device_prop.major * 10 + device_prop.minor;
@@ -103,12 +116,14 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
nullptr == present &&
nullptr == extra_add_qk &&
!is_unidirectional_ &&
HasFusedFp16Kernel(sm, head_size, sequence_length));
qkv_head_size[0] == qkv_head_size[1] &&
qkv_head_size[1] == qkv_head_size[2] &&
HasFusedFp16Kernel(sm, qkv_head_size[0], sequence_length));

MHARunner* fused_runner = nullptr;
if (use_fused_runner) {
if (nullptr == fused_fp16_runner_.get()) {
fused_fp16_runner_.reset(new FusedMHARunnerFP16v2(num_heads_, head_size, sm));
fused_fp16_runner_.reset(new FusedMHARunnerFP16v2(num_heads_, qkv_head_size[0], sm));
}
// In case some kernel not loaded due to shared memory limit, we need to double check here.
if (fused_fp16_runner_->isValid(sequence_length)) {
@@ -121,9 +136,9 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {

// Use GEMM for fully connection.
int m = batch_size * sequence_length;
int n = 3 * hidden_size;
int n = (q_hidden_size + k_hidden_size + v_hidden_size);
int k = input_hidden_size;
size_t gemm_buffer_size = static_cast<size_t>(batch_size) * sequence_length * 3 * hidden_size * element_size;
size_t gemm_buffer_size = static_cast<size_t>(batch_size) * sequence_length * n * element_size;
auto gemm_buffer = GetScratchBuffer<T>(gemm_buffer_size);

typedef typename ToCudaType<T>::MappedType CudaT;
@@ -140,10 +155,11 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
size_t workSpaceSize = GetAttentionWorkspaceSize(element_size,
batch_size,
num_heads_,
head_size,
qkv_head_size[0],
sequence_length,
past_sequence_length,
fused_runner);
fused_runner,
qkv_head_size[2]);

auto work_space = GetScratchBuffer<void>(workSpaceSize);
ORT_RETURN_IF_ERROR(LaunchAttentionKernel(
@@ -154,7 +170,7 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
batch_size,
sequence_length,
num_heads_,
head_size,
qkv_head_size[0],
past_sequence_length,
is_unidirectional_,
reinterpret_cast<const void*>(gemm_buffer.get()),
Expand All @@ -166,7 +182,8 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
work_space.get(),
output->MutableData<T>(),
nullptr == present ? nullptr : present->MutableData<T>(),
fused_runner));
fused_runner,
qkv_head_size[2]));

return Status::OK();
}