Commit 8b5c0a4

Support F16 operations

1 parent 29b85f1

1 file changed: 34 additions, 31 deletions

ggml/src/ggml-cpu/ops.cpp

@@ -6059,29 +6059,29 @@ void ggml_compute_forward_im2col_back_f32(
     }
 }

-static void ggml_call_mul_mat(
-    const ggml_compute_params * params,
-    int64_t m, int64_t n, int64_t k,
-    void * a, void * b, void * c) {
-
+static void ggml_call_mul_mat(ggml_type T, const ggml_compute_params * params, int64_t m, int64_t n, int64_t k,
+                              void * a, void * b, void * c) {
+    const ggml_type_traits * traits = ggml_get_type_traits(T);
     struct ggml_tensor src1 = {};
+    src1.type = T;
     src1.ne[0] = k;
     src1.ne[1] = m;
     src1.ne[2] = 1;
     src1.ne[3] = 1;
-    src1.nb[0] = sizeof(float);
-    src1.nb[1] = k * sizeof(float);
+    src1.nb[0] = traits->type_size;
+    src1.nb[1] = k * traits->type_size;
     src1.nb[2] = src1.nb[1];
     src1.nb[3] = src1.nb[2];
     src1.data = a;

     struct ggml_tensor src0 = {};
+    src0.type = T;
     src0.ne[0] = k;
     src0.ne[1] = n;
     src0.ne[2] = 1;
     src0.ne[3] = 1;
-    src0.nb[0] = sizeof(float);
-    src0.nb[1] = k * sizeof(float);
+    src0.nb[0] = traits->type_size;
+    src0.nb[1] = k * traits->type_size;
     src0.nb[2] = src0.nb[1];
     src0.nb[3] = src0.nb[2];
     src0.data = b;
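
The key change in this hunk: ggml_call_mul_mat now receives the element type explicitly and derives byte strides from ggml_get_type_traits instead of hardcoding sizeof(float), so the same GEMM wrapper serves F32 and F16 operands. A minimal sketch of the sizing arithmetic this enables (the helper name mat_bytes is hypothetical, not part of the commit):

    #include "ggml.h"

    // Hypothetical helper: bytes occupied by a contiguous rows x k matrix of type T.
    // type_size is 4 for GGML_TYPE_F32 and 2 for GGML_TYPE_F16, so an F16
    // operand needs half the workspace of an F32 one.
    static size_t mat_bytes(ggml_type T, int64_t rows, int64_t k) {
        const ggml_type_traits * traits = ggml_get_type_traits(T);
        return (size_t)(rows * k) * traits->type_size;
    }
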
@@ -6102,17 +6102,18 @@ static void ggml_call_mul_mat(
     ggml_compute_forward_mul_mat(params, &dst);
 }

-
 // ggml_compute_forward_conv_2d

-static void ggml_compute_forward_conv_2d_f32(
-    const ggml_compute_params * params,
-    const ggml_tensor * kernel,  // [KW, KH, IC, OC] - fp32
-    const ggml_tensor * src,     // [W, H, C, N]
-    ggml_tensor * dst) {         // [OW, OH, OC, N]
+static void ggml_compute_forward_conv_2d_impl(const ggml_compute_params * params,
+                                              const ggml_tensor * kernel,  // [KW, KH, IC, OC]
+                                              const ggml_tensor * src,     // [W, H, C, N]
+                                              ggml_tensor * dst,           // [OW, OH, OC, N]
+                                              ggml_type kernel_type) {

     GGML_ASSERT(ggml_is_contiguous(kernel));
-    GGML_ASSERT(kernel->type == GGML_TYPE_F32);
+    GGML_ASSERT(kernel->type == kernel_type);
+
+    const ggml_type_traits * traits = ggml_get_type_traits(kernel_type);

     const int32_t stride_x = dst->op_params[0];
     const int32_t stride_y = dst->op_params[1];
@@ -6133,20 +6134,20 @@ static void ggml_compute_forward_conv_2d_f32(
     const int64_t dst_h = dst->ne[1];

     float * src_data = (float*) src->data;
-    float * knl_data = (float*) kernel->data;
+    void  * knl_data = kernel->data;
     float * dst_data = (float*) dst->data;

     const int64_t knl_n       = knl_w * knl_h * c_in;
     const int64_t patch_total = dst->ne[3] * dst_w * dst_h;

-    const int64_t space_per_patch   = knl_n * sizeof(float) + c_out * sizeof(float);
+    const int64_t space_per_patch   = knl_n * traits->type_size + c_out * sizeof(float);
     const int64_t batch_size        = params->wsize / space_per_patch;
     const int64_t patches_per_batch = batch_size > 8 ? (batch_size / 8) * 8 : batch_size;
     const int64_t batch_n           = (patch_total + patches_per_batch - 1) / patches_per_batch;

     GGML_ASSERT(patches_per_batch > 0 && batch_size >= 1);

-    float * tmp = (float *) params->wdata;
+    void * tmp = params->wdata;

     for (int64_t batch_i = 0; batch_i < batch_n; ++batch_i) {

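Because each im2col row now holds knl_n elements of the kernel's type while the GEMM output remains F32, the per-patch workspace shrinks when the kernel is F16. A hedged sketch of the batching math used above (patches_that_fit is an illustrative name, not from the commit):

    // How many patches fit in a scratch buffer of wsize bytes, mirroring the
    // space_per_patch / batch_size / patches_per_batch computation above.
    static int64_t patches_that_fit(size_t wsize, int64_t knl_n, int64_t c_out, size_t type_size) {
        const int64_t space_per_patch = knl_n * (int64_t) type_size + c_out * (int64_t) sizeof(float);
        const int64_t batch_size      = (int64_t) (wsize / space_per_patch);
        // round down to a multiple of 8 when there is room, as the kernel does
        return batch_size > 8 ? (batch_size / 8) * 8 : batch_size;
    }
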
@@ -6166,7 +6167,7 @@ static void ggml_compute_forward_conv_2d_f32(
             const int64_t src_y = p % dst_w;

             float * src_base = (float *)((char *)src_data + batch_n * src->nb[3]);
-            float * dst_row  = tmp + (p % patches_per_batch) * knl_n;
+            char  * dst_row  = (char *) tmp + (p % patches_per_batch) * knl_n * traits->type_size;

             for (int64_t ic = 0; ic < c_in; ++ic) {
                 for (int64_t ky = 0; ky < knl_h; ++ky) {
@@ -6176,11 +6177,19 @@ static void ggml_compute_forward_conv_2d_f32(

                         int64_t dst_idx = ic * (knl_h * knl_w) + ky * knl_w + kx;

+                        float src_val;
                         if (sy < 0 || sy >= src_h || sx < 0 || sx >= src_w) {
-                            dst_row[dst_idx] = 0.0f;
+                            src_val = 0.0f;
                         } else {
                             float * src_ptr = (float *)((char *)src_base + sx * src->nb[0] + sy * src->nb[1] + ic * src->nb[2]);
-                            dst_row[dst_idx] = *src_ptr;
+                            src_val = *src_ptr;
+                        }
+
+                        char * element_ptr = dst_row + dst_idx * traits->type_size;
+                        if (kernel_type == GGML_TYPE_F32) {
+                            *(float *) element_ptr = src_val;
+                        } else if (kernel_type == GGML_TYPE_F16) {
+                            *(ggml_fp16_t *) element_ptr = GGML_FP32_TO_FP16(src_val);
                         }
                     }
                 }
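
Boundary handling stays in F32: each tap is first resolved into src_val, and only the final store converts to the kernel's type, so the padding logic is type-independent. A standalone sketch of that typed store, assuming ggml's ggml_fp16_t and GGML_FP32_TO_FP16 (both used in the hunk above):

    // Store one F32 value into an im2col buffer whose element type matches the kernel.
    static void store_patch_elem(void * buf, int64_t idx, const ggml_type_traits * traits,
                                 ggml_type T, float v) {
        char * p = (char *) buf + idx * traits->type_size;
        if (T == GGML_TYPE_F32) {
            *(float *) p = v;                          // 4-byte store, no conversion
        } else if (T == GGML_TYPE_F16) {
            *(ggml_fp16_t *) p = GGML_FP32_TO_FP16(v); // narrow to half precision
        }
    }
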
@@ -6189,11 +6198,10 @@ static void ggml_compute_forward_conv_2d_f32(

         ggml_barrier(params->threadpool);

-        float * gemm_output = tmp + patches_per_batch * knl_n;
+        float * gemm_output = (float *) ((char *) tmp + patches_per_batch * knl_n * traits->type_size);

         // GEMM: patches[patch_n, knl_n] × kernel[knl_n, c_out] = output[patch_n, c_out]
-        ggml_call_mul_mat(params, patch_n, c_out, knl_n,
-                          tmp, knl_data, gemm_output);
+        ggml_call_mul_mat(kernel_type, params, patch_n, c_out, knl_n, tmp, knl_data, gemm_output);

         ggml_barrier(params->threadpool);

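Note that gemm_output is still addressed as float: the mul_mat destination stays F32 even when both operands are F16, so only the byte offset of the output region depends on the kernel type. Worked arithmetic with illustrative (hypothetical) numbers:

    // knl_n = 576 elements per patch, patches_per_batch = 64:
    //   F16 kernel: im2col region = 64 * 576 * 2 =  73,728 bytes
    //   F32 kernel: im2col region = 64 * 576 * 4 = 147,456 bytes
    // gemm_output begins immediately after that region, always as float *.
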
@@ -6211,7 +6219,6 @@ static void ggml_compute_forward_conv_2d_f32(

         for (int64_t oc = 0; oc < c_out; ++oc) {
             const float value = gemm_output[i * c_out + oc];
-            // Write to WHCN layout: dst[w, h, c, n]
             float * dst_ptr = (float *)((char *)dst_data + dst_x * dst->nb[0] + dst_y * dst->nb[1] + oc * dst->nb[2] + batch_n * dst->nb[3]);
             *dst_ptr = value;
         }
@@ -6226,11 +6233,7 @@ void ggml_compute_forward_conv_2d(
     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];

-    if (src0->type == GGML_TYPE_F16) {
-        GGML_ASSERT(false && "F16 not supported yet");
-    } else {
-        ggml_compute_forward_conv_2d_f32(params, src0, src1, dst);
-    }
+    ggml_compute_forward_conv_2d_impl(params, src0, src1, dst, src0->type);
 }

 // ggml_compute_forward_conv_transpose_2d
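
With the dispatch collapsed into a single call, an F16 kernel now takes the same path as F32 and the old "F16 not supported yet" assert is gone; src0->type selects the conversion behavior inside the impl. A hedged usage sketch (the shape constants KW, KH, IC, OC, W, H, N are illustrative; only ggml_new_tensor_4d and the type constants are standard ggml API):

    // An F16 convolution kernel now flows through ggml_compute_forward_conv_2d
    // unchanged, alongside an F32 input tensor.
    struct ggml_tensor * kernel = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, KW, KH, IC, OC);
    struct ggml_tensor * input  = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, W,  H,  IC, N);
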
