@@ -390,6 +390,9 @@ struct ggml_backend_opencl_context {
     cl_program program_tanh;
     cl_program program_upscale;
     cl_program program_concat;
+    cl_program program_conv_2d_f16;
+    cl_program program_conv_2d_f32;
+    cl_program program_conv_2d_f16_f32;
     cl_program program_tsembd;
     cl_program program_mul_mv_id_q4_0_f32_8x_flat;
@@ -441,6 +444,9 @@ struct ggml_backend_opencl_context {
     cl_kernel kernel_upscale_bilinear;
     cl_kernel kernel_concat_f32_contiguous;
     cl_kernel kernel_concat_f32_non_contiguous;
+    cl_kernel kernel_conv_2d_f16;
+    cl_kernel kernel_conv_2d_f32;
+    cl_kernel kernel_conv_2d_f16_f32;
     cl_kernel kernel_timestep_embedding;
     cl_kernel kernel_mul_mv_id_q4_0_f32_8x_flat;
@@ -1478,6 +1484,47 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
         GGML_LOG_CONT(".");
     }
 
+    // conv2d
+    {
+#ifdef GGML_OPENCL_EMBED_KERNELS
+        const std::string kernel_src {
+            #include "conv2d.cl.h"
+        };
+        const std::string kernel_src_f16_f32 {
+            #include "conv2d_f16_f32.cl.h"
+        };
+#else
+        const std::string kernel_src = read_file("conv2d.cl");
+        const std::string kernel_src_f16_f32 = read_file("conv2d_f16_f32.cl");
+#endif
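+        // the same conv2d.cl source is compiled twice: with -DUSE_FP16=1 for the
+        // f16 kernel and with the default options for the f32 kernel;
+        // conv2d_f16_f32.cl covers f16 weights with f32 activations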
+        if (!kernel_src.empty()) {
+            backend_ctx->program_conv_2d_f16 =
+                build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), (std::string(compile_opts) + " -DUSE_FP16=1").c_str());
+            CL_CHECK((backend_ctx->kernel_conv_2d_f16 = clCreateKernel(backend_ctx->program_conv_2d_f16, "kernel_conv_2d", &err), err));
+            GGML_LOG_CONT(".");
+            backend_ctx->program_conv_2d_f32 =
+                build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
+            CL_CHECK((backend_ctx->kernel_conv_2d_f32 = clCreateKernel(backend_ctx->program_conv_2d_f32, "kernel_conv_2d", &err), err));
+            GGML_LOG_CONT(".");
+        } else {
+            GGML_LOG_WARN("ggml_opencl: conv2d kernel source not found or empty. This op will not be available.\n");
+            backend_ctx->program_conv_2d_f16 = nullptr;
+            backend_ctx->kernel_conv_2d_f16 = nullptr;
+            backend_ctx->program_conv_2d_f32 = nullptr;
+            backend_ctx->kernel_conv_2d_f32 = nullptr;
+        }
+        if (!kernel_src_f16_f32.empty()) {
+            backend_ctx->program_conv_2d_f16_f32 =
+                build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_f16_f32.c_str(), compile_opts);
+            CL_CHECK((backend_ctx->kernel_conv_2d_f16_f32 = clCreateKernel(backend_ctx->program_conv_2d_f16_f32, "kernel_conv_2d", &err), err));
+            GGML_LOG_CONT(".");
+        } else {
+            GGML_LOG_WARN("ggml_opencl: conv2d_f16_f32 kernel source not found or empty. This op will not be available.\n");
+            backend_ctx->program_conv_2d_f16_f32 = nullptr;
+            backend_ctx->kernel_conv_2d_f16_f32 = nullptr;
+        }
+    }
+
     // mul_mv_id_q4_0_f32_8x_flat
     {
 #ifdef GGML_OPENCL_EMBED_KERNELS
@@ -2361,6 +2408,10 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
                    op->src[0]->ne[3] == 1 && op->ne[3] == 1;
         case GGML_OP_UPSCALE:
             return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
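+        // all-f16, all-f32, or f16 weights with f32 activations (f32 output)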
+        case GGML_OP_CONV_2D:
+            return (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F16) ||
+                   (op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32) ||
+                   (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32);
         case GGML_OP_CONCAT:
             return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
         case GGML_OP_TIMESTEP_EMBEDDING:
@@ -4998,6 +5049,83 @@ static void ggml_cl_mul_mat_f16_f32_tiled(ggml_backend_t backend, const ggml_ten
     backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size, local_work_size, dst);
 }
 
+static void ggml_cl_conv_2d(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_TENSOR_BINARY_OP_LOCALS;
+    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+
+    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
+    ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
+    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
+
+    cl_ulong offset0 = extra0->offset + src0->view_offs;
+    cl_ulong offset1 = extra1->offset + src1->view_offs;
+    cl_ulong offsetd = extrad->offset + dst->view_offs;
+
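+    // src0 is the kernel [KW, KH, Cin, Cout], src1 the input [W, H, Cin, N],
+    // dst the output [OW, OH, Cout, N] (ggml dimension order ne0..ne3)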
+    const cl_uint Cout = ne03; const cl_uint Cin = ne02; const cl_uint N = ne13;
+    const cl_uint KW = ne00; const cl_uint KH = ne01; const cl_uint W = ne10; const cl_uint H = ne11; const cl_uint OW = ne0; const cl_uint OH = ne1;
+
+    const cl_uint s0 = dst->op_params[0]; const cl_uint s1 = dst->op_params[1];
+    const cl_uint p0 = dst->op_params[2]; const cl_uint p1 = dst->op_params[3];
+    const cl_uint d0 = dst->op_params[4]; const cl_uint d1 = dst->op_params[5];
+
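+    // ggml strides (nb*) are in bytes; the kernel expects strides in elements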
+    const cl_uint cl_nb01 = nb01/ggml_type_size(src0->type); const cl_uint cl_nb02 = nb02/ggml_type_size(src0->type); const cl_uint cl_nb03 = nb03/ggml_type_size(src0->type);
+    const cl_uint cl_nb11 = nb11/ggml_type_size(src1->type); const cl_uint cl_nb12 = nb12/ggml_type_size(src1->type); const cl_uint cl_nb13 = nb13/ggml_type_size(src1->type);
+    const cl_uint cl_nb1 = nb1/ggml_type_size(dst->type); const cl_uint cl_nb2 = nb2/ggml_type_size(dst->type); const cl_uint cl_nb3 = nb3/ggml_type_size(dst->type);
+
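+    // in effect an implicit-GEMM view of the convolution: rows = Cout, columns
+    // = NPQ, the N * OH * OW output positions (the reduction runs over
+    // CRS = Cin * KH * KW)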
+    const int64_t NPQ = (int64_t)N * OW * OH;
+
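+    // blocking parameters: each workgroup produces a BS_K x BS_NPQ output tile,
+    // stepping over the CRS dimension in chunks of BS_CRS; each thread computes
+    // a TS_K x TS_NPQ sub-tile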
+    const uint32_t BS_K = 64;
+    const uint32_t BS_NPQ = 64;
+    const uint32_t BS_CRS = 16;
+    const uint32_t VEC_SIZE = 4;
+
+    const uint32_t TS_K = 4;
+    const uint32_t TS_NPQ = 8;
+
+    const uint32_t WG_K = BS_K / TS_K;
+    const uint32_t WG_NPQ = BS_NPQ / TS_NPQ;
+
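+    // ceil-division: number of block_size-sized blocks needed to cover work_size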
+    auto splitWork = [](uint32_t work_size, uint32_t block_size) { return (block_size + work_size - 1) / block_size; };
+    const uint32_t NB_K = splitWork(Cout, BS_K);
+    const uint32_t NB_NPQ = splitWork(NPQ, BS_NPQ);
+
+    cl_kernel kernel;
+    size_t shmem_size;
+
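+    // pick the kernel variant and size its local-memory scratch: a BS_K x BS_CRS
+    // weight tile plus a BS_CRS x BS_NPQ activation tile stored as VEC_SIZE-wide
+    // vectors; element types follow the variant's input types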
+    if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
+        kernel = backend_ctx->kernel_conv_2d_f16;
+        shmem_size = (size_t)(BS_K * BS_CRS * sizeof(cl_half) + BS_CRS * (BS_NPQ / VEC_SIZE) * sizeof(cl_half4));
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
+        kernel = backend_ctx->kernel_conv_2d_f32;
+        shmem_size = (size_t)(BS_K * BS_CRS * sizeof(cl_float) + BS_CRS * (BS_NPQ / VEC_SIZE) * sizeof(cl_float4));
+    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
+        kernel = backend_ctx->kernel_conv_2d_f16_f32;
+        shmem_size = (size_t)(BS_K * BS_CRS * sizeof(cl_half) + BS_CRS * (BS_NPQ / VEC_SIZE) * sizeof(cl_float4));
+    } else {
+        GGML_ASSERT(false && "Unsupported data type combination for conv2d");
+        return;
+    }
+
+    cl_uint idx = 0;
+    CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_mem), &extra0->data_device)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_ulong), &offset0));
+    CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_mem), &extra1->data_device)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_ulong), &offset1));
+    CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_mem), &extrad->data_device)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_ulong), &offsetd));
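+    // a NULL arg value with a non-zero size allocates shmem_size bytes of
+    // __local memory for the kernel's scratch tiles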
+    CL_CHECK(clSetKernelArg(kernel, idx++, shmem_size, NULL));
+    CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &Cout)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &Cin)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &N));
+    CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &KW)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &KH)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &W)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &H));
+    CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &OW)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &OH));
+    CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &s0)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &s1)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &p0)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &p1));
+    CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &d0)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &d1));
+    CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb01)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb02)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb03));
+    CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb11)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb12)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb13));
+    CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb1)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb2)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb3));
+
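+    // 2-D launch: NB_K x NB_NPQ workgroups of WG_K x WG_NPQ threads,
+    // one BS_K x BS_NPQ output tile per workgroup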
+    size_t global_work_size[] = { (size_t)NB_K * WG_K, (size_t)NB_NPQ * WG_NPQ, 1 };
+    size_t local_work_size[] = { (size_t)WG_K, (size_t)WG_NPQ, 1 };
+
+    backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size, local_work_size, dst);
+}
+
 static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(src0);
     GGML_ASSERT(src0->extra);
@@ -6752,6 +6880,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor
             }
             ggml_cl_upscale(backend, tensor->src[0], tensor);
             return true;
+        case GGML_OP_CONV_2D:
+            if (!any_on_device) {
+                return false;
+            }
+            func = ggml_cl_conv_2d;
+            break;
         case GGML_OP_CONCAT:
             if (!any_on_device) {
                 return false;