Skip to content

Commit 8412441

Browse files
committed
Handle f16 input and f16 kernel; more optimizations
1 parent f555aa3 commit 8412441

File tree

4 files changed

+351
-137
lines changed

4 files changed

+351
-137
lines changed

ggml/src/ggml-opencl/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ set(GGML_OPENCL_KERNELS
106106
repeat
107107
mul_mat_f16_f32
108108
conv2d
109+
conv2d_f16_f32
109110
)
110111

111112
foreach (K ${GGML_OPENCL_KERNELS})

ggml/src/ggml-opencl/ggml-opencl.cpp

Lines changed: 80 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -390,7 +390,9 @@ struct ggml_backend_opencl_context {
390390
cl_program program_tanh;
391391
cl_program program_upscale;
392392
cl_program program_concat;
393-
cl_program program_conv_2d;
393+
cl_program program_conv_2d_f16;
394+
cl_program program_conv_2d_f32;
395+
cl_program program_conv_2d_f16_f32;
394396
cl_program program_tsembd;
395397
cl_program program_mul_mv_id_q4_0_f32_8x_flat;
396398

@@ -442,7 +444,9 @@ struct ggml_backend_opencl_context {
442444
cl_kernel kernel_upscale_bilinear;
443445
cl_kernel kernel_concat_f32_contiguous;
444446
cl_kernel kernel_concat_f32_non_contiguous;
445-
cl_kernel kernel_conv_2d;
447+
cl_kernel kernel_conv_2d_f16;
448+
cl_kernel kernel_conv_2d_f32;
449+
cl_kernel kernel_conv_2d_f16_f32;
446450
cl_kernel kernel_timestep_embedding;
447451
cl_kernel kernel_mul_mv_id_q4_0_f32_8x_flat;
448452

@@ -1480,25 +1484,45 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
14801484
GGML_LOG_CONT(".");
14811485
}
14821486

1483-
// conv2d
1484-
{
1485-
#ifdef GGML_OPENCL_EMBED_KERNELS
1486-
const std::string kernel_src {
1487-
#include "conv2d.cl.h"
1488-
};
1489-
#else
1490-
const std::string kernel_src = read_file("conv2d.cl");
1491-
#endif
1492-
if (!kernel_src.empty()) {
1493-
backend_ctx->program_conv_2d =
1494-
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
1495-
CL_CHECK((backend_ctx->kernel_conv_2d = clCreateKernel(backend_ctx->program_conv_2d, "kernel_conv_2d", &err), err));
1496-
GGML_LOG_CONT(".");
1497-
} else {
1498-
GGML_LOG_WARN("ggml_opencl: conv2d kernel source not found or empty. This op will not be available.\n");
1499-
backend_ctx->program_conv_2d = nullptr;
1500-
backend_ctx->kernel_conv_2d = nullptr;
1501-
}
1487+
// conv2d
1488+
{
1489+
#ifdef GGML_OPENCL_EMBED_KERNELS
1490+
const std::string kernel_src {
1491+
#include "conv2d.cl.h"
1492+
};
1493+
const std::string kernel_src_f16_f32 {
1494+
#include "conv2d_f16_f32.cl.h"
1495+
};
1496+
#else
1497+
const std::string kernel_src = read_file("conv2d.cl");
1498+
const std::string kernel_src_f16_f32 = read_file("conv2d_f16_f32.cl");
1499+
#endif
1500+
if (!kernel_src.empty()) {
1501+
backend_ctx->program_conv_2d_f16 =
1502+
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), (std::string(compile_opts) + " -DUSE_FP16=1").c_str());
1503+
CL_CHECK((backend_ctx->kernel_conv_2d_f16 = clCreateKernel(backend_ctx->program_conv_2d_f16, "kernel_conv_2d", &err), err));
1504+
GGML_LOG_CONT(".");
1505+
backend_ctx->program_conv_2d_f32 =
1506+
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
1507+
CL_CHECK((backend_ctx->kernel_conv_2d_f32 = clCreateKernel(backend_ctx->program_conv_2d_f32, "kernel_conv_2d", &err), err));
1508+
GGML_LOG_CONT(".");
1509+
} else {
1510+
GGML_LOG_WARN("ggml_opencl: conv2d kernel source not found or empty. This op will not be available.\n");
1511+
backend_ctx->program_conv_2d_f16 = nullptr;
1512+
backend_ctx->kernel_conv_2d_f16 = nullptr;
1513+
backend_ctx->program_conv_2d_f32 = nullptr;
1514+
backend_ctx->kernel_conv_2d_f32 = nullptr;
1515+
}
1516+
if (!kernel_src_f16_f32.empty()) {
1517+
backend_ctx->program_conv_2d_f16_f32 =
1518+
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_f16_f32.c_str(), compile_opts);
1519+
CL_CHECK((backend_ctx->kernel_conv_2d_f16_f32 = clCreateKernel(backend_ctx->program_conv_2d_f16_f32, "kernel_conv_2d", &err), err));
1520+
GGML_LOG_CONT(".");
1521+
} else {
1522+
GGML_LOG_WARN("ggml_opencl: conv2d_f16_f32 kernel source not found or empty. This op will not be available.\n");
1523+
backend_ctx->program_conv_2d_f16_f32 = nullptr;
1524+
backend_ctx->kernel_conv_2d_f16_f32 = nullptr;
1525+
}
15021526
}
15031527

15041528
// mul_mv_id_q4_0_f32_8x_flat
@@ -2385,7 +2409,9 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
23852409
case GGML_OP_UPSCALE:
23862410
return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
23872411
case GGML_OP_CONV_2D:
2388-
return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
2412+
return (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F16) ||
2413+
(op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32) ||
2414+
(op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32);
23892415
case GGML_OP_CONCAT:
23902416
return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
23912417
case GGML_OP_TIMESTEP_EMBEDDING:
@@ -5035,25 +5061,44 @@ static void ggml_cl_conv_2d(ggml_backend_t backend, const ggml_tensor * src0, co
50355061
const cl_uint p0 = dst->op_params[2]; const cl_uint p1 = dst->op_params[3];
50365062
const cl_uint d0 = dst->op_params[4]; const cl_uint d1 = dst->op_params[5];
50375063

5038-
const cl_uint cl_nb01 = nb01/nb00; const cl_uint cl_nb02 = nb02/nb00; const cl_uint cl_nb03 = nb03/nb00;
5039-
const cl_uint cl_nb11 = nb11/nb10; const cl_uint cl_nb12 = nb12/nb10; const cl_uint cl_nb13 = nb13/nb10;
5040-
const cl_uint cl_nb1 = nb1/nb0; const cl_uint cl_nb2 = nb2/nb0; const cl_uint cl_nb3 = nb3/nb0;
5064+
const cl_uint cl_nb01 = nb01/ggml_type_size(src0->type); const cl_uint cl_nb02 = nb02/ggml_type_size(src0->type); const cl_uint cl_nb03 = nb03/ggml_type_size(src0->type);
5065+
const cl_uint cl_nb11 = nb11/ggml_type_size(src1->type); const cl_uint cl_nb12 = nb12/ggml_type_size(src1->type); const cl_uint cl_nb13 = nb13/ggml_type_size(src1->type);
5066+
const cl_uint cl_nb1 = nb1/ggml_type_size(dst->type); const cl_uint cl_nb2 = nb2/ggml_type_size(dst->type); const cl_uint cl_nb3 = nb3/ggml_type_size(dst->type);
50415067

50425068
const int64_t NPQ = (int64_t)N * OW * OH;
50435069

5044-
const uint32_t WG_SIZE = 128;
5045-
const uint32_t BS_K = 128;
5046-
const uint32_t BS_CRS = 16;
5070+
const uint32_t BS_K = 64;
50475071
const uint32_t BS_NPQ = 64;
5072+
const uint32_t BS_CRS = 16;
50485073
const uint32_t VEC_SIZE = 4;
50495074

5075+
const uint32_t TS_K = 4;
5076+
const uint32_t TS_NPQ = 8;
5077+
5078+
const uint32_t WG_K = BS_K / TS_K;
5079+
const uint32_t WG_NPQ = BS_NPQ / TS_NPQ;
5080+
50505081
auto splitWork = [](uint32_t work_size, uint32_t block_size) { return (block_size + work_size - 1) / block_size; };
50515082
const uint32_t NB_K = splitWork(Cout, BS_K);
50525083
const uint32_t NB_NPQ = splitWork(NPQ, BS_NPQ);
50535084

5054-
const size_t shmem_size = (size_t)(BS_K * (BS_CRS + 1) * sizeof(cl_half) + BS_CRS * (BS_NPQ / VEC_SIZE + 1) * sizeof(cl_half4));
5085+
cl_kernel kernel;
5086+
size_t shmem_size;
5087+
5088+
if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
5089+
kernel = backend_ctx->kernel_conv_2d_f16;
5090+
shmem_size = (size_t)(BS_K * BS_CRS * sizeof(cl_half) + BS_CRS * (BS_NPQ / VEC_SIZE) * sizeof(cl_half4));
5091+
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
5092+
kernel = backend_ctx->kernel_conv_2d_f32;
5093+
shmem_size = (size_t)(BS_K * BS_CRS * sizeof(cl_float) + BS_CRS * (BS_NPQ / VEC_SIZE) * sizeof(cl_float4));
5094+
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
5095+
kernel = backend_ctx->kernel_conv_2d_f16_f32;
5096+
shmem_size = (size_t)(BS_K * BS_CRS * sizeof(cl_half) + BS_CRS * (BS_NPQ / VEC_SIZE) * sizeof(cl_float4));
5097+
} else {
5098+
GGML_ASSERT(false && "Unsupported data type combination for conv2d");
5099+
return;
5100+
}
50555101

5056-
cl_kernel kernel = backend_ctx->kernel_conv_2d;
50575102
cl_uint idx = 0;
50585103
CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_mem), &extra0->data_device)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_ulong), &offset0));
50595104
CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_mem), &extra1->data_device)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_ulong), &offset1));
@@ -5068,18 +5113,18 @@ static void ggml_cl_conv_2d(ggml_backend_t backend, const ggml_tensor * src0, co
50685113
CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb11)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb12)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb13));
50695114
CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb1)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb2)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb3));
50705115

5071-
size_t global_work_size[] = { (size_t)NB_K * WG_SIZE, (size_t)NB_NPQ, 1 };
5072-
size_t local_work_size[] = { (size_t)WG_SIZE, 1, 1 };
5116+
size_t global_work_size[] = { (size_t)NB_K * WG_K, (size_t)NB_NPQ * WG_NPQ, 1 };
5117+
size_t local_work_size[] = { (size_t)WG_K, (size_t)WG_NPQ, 1 };
50735118

50745119
#ifdef GGML_OPENCL_PROFILING
50755120
cl_event evt;
5076-
CL_CHECK(clEnqueueNDRangeKernel(backend_ctx->queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
5121+
CL_CHECK(clEnqueueNDRangeKernel(backend_ctx->queue, kernel, 2, NULL, global_work_size, local_work_size, 0, NULL, &evt));
50775122

50785123
backend_ctx->profiling_info.emplace_back();
5079-
populateProfilingInfo(backend_ctx->profiling_info.back(), evt, kernel, 3, global_work_size, local_work_size, dst);
5124+
populateProfilingInfo(backend_ctx->profiling_info.back(), evt, kernel, 2, global_work_size, local_work_size, dst);
50805125
#else
50815126
GGML_UNUSED(dst);
5082-
CL_CHECK(clEnqueueNDRangeKernel(backend_ctx->queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL));
5127+
CL_CHECK(clEnqueueNDRangeKernel(backend_ctx->queue, kernel, 2, NULL, global_work_size, local_work_size, 0, NULL, NULL));
50835128
#endif
50845129
>>>>>>> 4d5d5a83 (add conv2d kernel)
50855130
}

0 commit comments

Comments
 (0)