
Commit 8d94219

rgerganov and ggerganov authored
ggml : add ggml_set_rows (#14274)
* ggml : add ggml_set_rows

  Add ggml_set_rows(a, b, c) which copies rows from 'b' into 'a' using indices from 'c'.

  ref: #8366

* use I64 for indices
* ggml : add repeat impl for i64
* ggml : add ggml_is_contiguous_rows
* ggml : ggml_set_rows support broadcast
* ggml : ggml_set_rows support quantized dst
  ggml-ci
* ggml : support GGML_TYPE_F32 ".from_float" trait
* ggml : ggml_set_rows update comment + better index name
* tests : add ggml_set_rows
* metal : add ggml_set_rows implementation
  ggml-ci
* ggml : simplify forward_dup_f32
* ggml : fix supports_op
* tests : add comment to set_rows
* ggml : leave the repeat_i64 for a separate PR
  ggml-ci
* ggml : set_rows use std::min instead of MIN
* ggml : better error message for set_rows unsupported type
* metal : perform op->type check only once
* tests : more consistent implementation + more tests
  ggml-ci

---------

Co-authored-by: Georgi Gerganov <[email protected]>
1 parent f667f1e commit 8d94219

File tree

12 files changed: +653, −204 lines

examples/eval-callback/eval-callback.cpp

Lines changed: 2 additions & 0 deletions

@@ -55,6 +55,8 @@ static void ggml_print_tensor(uint8_t * data, ggml_type type, const int64_t * ne
         v = ggml_fp16_to_fp32(*(ggml_fp16_t *) &data[i]);
     } else if (type == GGML_TYPE_F32) {
         v = *(float *) &data[i];
+    } else if (type == GGML_TYPE_I64) {
+        v = (float) *(int64_t *) &data[i];
     } else if (type == GGML_TYPE_I32) {
         v = (float) *(int32_t *) &data[i];
     } else if (type == GGML_TYPE_I16) {
ggml/include/ggml-cpu.h

Lines changed: 1 addition & 0 deletions

@@ -134,6 +134,7 @@ extern "C" {

     GGML_BACKEND_API ggml_backend_reg_t ggml_backend_cpu_reg(void);

+    GGML_BACKEND_API void ggml_cpu_fp32_to_fp32(const float *, float *, int64_t);
     GGML_BACKEND_API void ggml_cpu_fp32_to_fp16(const float *, ggml_fp16_t *, int64_t);
     GGML_BACKEND_API void ggml_cpu_fp16_to_fp32(const ggml_fp16_t *, float *, int64_t);
     GGML_BACKEND_API void ggml_cpu_fp32_to_bf16(const float *, ggml_bf16_t *, int64_t);

ggml/include/ggml.h

Lines changed: 21 additions & 0 deletions

@@ -470,6 +470,7 @@ extern "C" {
         GGML_OP_TRANSPOSE,
         GGML_OP_GET_ROWS,
         GGML_OP_GET_ROWS_BACK,
+        GGML_OP_SET_ROWS,
         GGML_OP_DIAG,
         GGML_OP_DIAG_MASK_INF,
         GGML_OP_DIAG_MASK_ZERO,
@@ -687,6 +688,9 @@ extern "C" {
     // true for tensor that is stored in memory as CxWxHxN and has been permuted to WxHxCxN
     GGML_API bool ggml_is_contiguous_channels(const struct ggml_tensor * tensor);

+    // true if the elements in dimension 0 are contiguous, or there is just 1 block of elements
+    GGML_API bool ggml_is_contiguous_rows(const struct ggml_tensor * tensor);
+
     GGML_API bool ggml_are_same_shape (const struct ggml_tensor * t0, const struct ggml_tensor * t1);
     GGML_API bool ggml_are_same_stride(const struct ggml_tensor * t0, const struct ggml_tensor * t1);

@@ -1375,6 +1379,23 @@ extern "C" {
             struct ggml_tensor  * b,  // row indices
             struct ggml_tensor  * c); // data for ggml_get_rows, only used for its shape

+    // a   TD  [n_embd, ne1,    ne2,  ne3]
+    // b   TS  [n_embd, n_rows, ne02, ne03] | ne02 == ne2, ne03 == ne3
+    // c   I64 [n_rows, ne11,   ne12, 1]    | c[i] in [0, ne1)
+    //
+    // undefined behavior if destination rows overlap
+    //
+    // broadcast:
+    //   ne2 % ne11 == 0
+    //   ne3 % ne12 == 0
+    //
+    // return view(a)
+    GGML_API struct ggml_tensor * ggml_set_rows(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,  // destination
+            struct ggml_tensor  * b,  // source
+            struct ggml_tensor  * c); // row indices
+
     GGML_API struct ggml_tensor * ggml_diag(
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
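For orientation, here is a minimal CPU-only sketch of how the new API can be called; the 16 MiB context size, tensor shapes, index values, and thread count are illustrative, not taken from this commit:

```cpp
#include "ggml.h"
#include "ggml-cpu.h"

int main(void) {
    // small scratch context; the size is an arbitrary, generous choice for this toy example
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    // destination with 8 rows of 4 elements, source with 2 rows to scatter into it
    struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4, 8); // destination
    struct ggml_tensor * b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4, 2); // source rows
    struct ggml_tensor * c = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, 2);    // row indices

    // scatter row 0 of b into row 3 of a, and row 1 of b into row 5 of a
    ((int64_t *) c->data)[0] = 3;
    ((int64_t *) c->data)[1] = 5;
    // ... fill a->data and b->data as needed ...

    struct ggml_tensor * out = ggml_set_rows(ctx, a, b, c); // returns a view of a

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, out);
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads =*/ 1);

    ggml_free(ctx);
    return 0;
}
```

Per the header comment, the op returns a view of the destination, and behavior is undefined if two indices select the same destination row.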

ggml/src/ggml-cpu/ggml-cpu.c

Lines changed: 10 additions & 0 deletions

@@ -195,6 +195,7 @@ typedef pthread_t ggml_thread_t;

 static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
     [GGML_TYPE_F32] = {
+        .from_float   = (ggml_from_float_t) ggml_cpu_fp32_to_fp32,
         .vec_dot      = (ggml_vec_dot_t) ggml_vec_dot_f32,
         .vec_dot_type = GGML_TYPE_F32,
         .nrows        = 1,
@@ -1817,6 +1818,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_get_rows_back(params, tensor);
             } break;
+        case GGML_OP_SET_ROWS:
+            {
+                ggml_compute_forward_set_rows(params, tensor);
+            } break;
         case GGML_OP_DIAG:
             {
                 ggml_compute_forward_diag(params, tensor);
@@ -2170,6 +2175,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
                 n_tasks = n_threads;
             } break;
         case GGML_OP_GET_ROWS:
+        case GGML_OP_SET_ROWS:
             {
                 // FIXME: get_rows can use additional threads, but the cost of launching additional threads
                 // decreases performance with GPU offloading
@@ -3124,6 +3130,10 @@ enum ggml_status ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct g
     return ggml_graph_compute(cgraph, &cplan);
 }

+void ggml_cpu_fp32_to_fp32(const float * x, float * y, int64_t n) {
+    memcpy(y, x, n * sizeof(float));
+}
+
 void ggml_cpu_fp32_to_fp16(const float * x, ggml_fp16_t * y, int64_t n) {
     int64_t i = 0;
 #if defined(__F16C__)

ggml/src/ggml-cpu/ggml-cpu.cpp

Lines changed: 1 addition & 0 deletions

@@ -416,6 +416,7 @@ static bool ggml_backend_cpu_device_supports_op(ggml_backend_dev_t dev, const st

     switch (op->op) {
         case GGML_OP_CPY:
+        case GGML_OP_SET_ROWS:
            return
                op->type != GGML_TYPE_IQ3_XXS &&
                op->type != GGML_TYPE_IQ3_S   &&

ggml/src/ggml-cpu/ops.cpp

Lines changed: 77 additions & 19 deletions

@@ -696,24 +696,8 @@ static void ggml_compute_forward_dup_f32(
     if (ggml_is_contiguous(dst)) {
         // TODO: simplify
         if (nb00 == sizeof(float)) {
-            if (dst->type == GGML_TYPE_F32) {
-                size_t id = 0;
-                const size_t rs = ne00 * nb00;
-                char * dst_ptr = (char *) dst->data;
-
-                for (int i03 = 0; i03 < ne03; i03++) {
-                    for (int i02 = 0; i02 < ne02; i02++) {
-                        id += rs * ir0;
-                        for (int i01 = ir0; i01 < ir1; i01++) {
-                            const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
-                            memcpy(dst_ptr + id, src0_ptr, rs);
-                            id += rs;
-                        }
-                        id += rs * (ne01 - ir1);
-                    }
-                }
-            } else if (ggml_get_type_traits_cpu(dst->type)->from_float) {
-                ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(dst->type)->from_float;
+            if (ggml_get_type_traits_cpu(dst->type)->from_float) {
+                ggml_from_float_t const from_float = ggml_get_type_traits_cpu(dst->type)->from_float;

                 size_t id = 0;
                 size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
@@ -724,7 +708,7 @@ static void ggml_compute_forward_dup_f32(
                         id += rs * ir0;
                         for (int i01 = ir0; i01 < ir1; i01++) {
                             const float * src0_ptr = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
-                            quantize_row_q(src0_ptr, dst_ptr + id, ne00);
+                            from_float(src0_ptr, dst_ptr + id, ne00);
                             id += rs;
                         }
                         id += rs * (ne01 - ir1);
@@ -2300,6 +2284,12 @@ void ggml_compute_forward_repeat(
            {
                ggml_compute_forward_repeat_f32(params, dst);
            } break;
+        // TODO: templateify the implemenation and support for I64
+        //       ref https://github.com/ggml-org/llama.cpp/pull/14274#discussion_r2169492225
+        //case GGML_TYPE_I64:
+        //    {
+        //        ggml_compute_forward_repeat_i64(params, dst);
+        //    } break;
        default:
            {
                GGML_ABORT("fatal error");
@@ -4470,6 +4460,74 @@ void ggml_compute_forward_get_rows(
     //}
 }

+static void ggml_compute_forward_set_rows_f32(
+        const ggml_compute_params * params,
+              ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const int64_t nc = ne00;
+    const int64_t nr = ne01;
+
+    assert(ne0  == nc);
+    assert(ne2  == ne02);
+    assert(ne3  == ne03);
+    assert(src0->type == GGML_TYPE_F32);
+    assert(ne02 % ne11 == 0);
+    assert(ne03 % ne12 == 0);
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    // rows per thread
+    const int64_t dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int64_t ir0 = dr*ith;
+    const int64_t ir1 = std::min(ir0 + dr, nr);
+
+    ggml_from_float_t const from_float = ggml_get_type_traits_cpu(dst->type)->from_float;
+
+    for (int64_t i03 = 0; i03 < ne03; ++i03) {
+        for (int64_t i02 = 0; i02 < ne02; ++i02) {
+            for (int64_t i = ir0; i < ir1; ++i) {
+                const int64_t i12 = i03%ne12;
+                const int64_t i11 = i02%ne11;
+                const int64_t i10 = i;
+
+                const int64_t i1 = *(int64_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
+
+                GGML_ASSERT(i1 >= 0 && i1 < ne1);
+
+                from_float(
+                        (const float *) ((char *) src0->data + i*nb01 + i02*nb02 + i03*nb03),
+                        ((char *) dst->data + i1*nb1 + i02*nb2 + i03*nb3), nc);
+            }
+        }
+    }
+}
+
+void ggml_compute_forward_set_rows(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_set_rows_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("src0->type = %d (%s) not supported", src0->type, ggml_type_name(src0->type));
+            }
+    }
+}
+
 // ggml_compute_forward_get_rows_back

 static void ggml_compute_forward_get_rows_back_f32_f16(
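The CPU kernel above splits the source rows evenly across threads (the [ir0, ir1) range) and resolves the I64 index tensor with modulo broadcasting over dims 2 and 3; because overlapping destination rows are declared undefined behavior, the per-row writes need no inter-thread synchronization. The following standalone sketch (plain C++, illustrative shapes, no ggml types) reproduces the same scatter with a single broadcast dimension, which may help when reading the index arithmetic:

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    // illustrative shapes: rows of 4 floats, 3 source rows per slice, 2 slices,
    // 8 destination rows per slice; one index tensor (ne11 == 1) shared across slices
    const int64_t ne00 = 4;   // row length
    const int64_t ne01 = 3;   // source rows per slice
    const int64_t ne02 = 2;   // slices
    const int64_t ne1  = 8;   // destination rows per slice
    const int64_t ne11 = 1;   // index tensor slices (broadcast: ne02 % ne11 == 0)

    std::vector<float>   src(ne00*ne01*ne02, 1.0f);
    std::vector<float>   dst(ne00*ne1 *ne02, 0.0f);
    std::vector<int64_t> idx = {6, 0, 3};     // destination row for each source row

    for (int64_t i02 = 0; i02 < ne02; ++i02) {
        for (int64_t i = 0; i < ne01; ++i) {
            const int64_t i11 = i02 % ne11;   // broadcast the index tensor across slices
            const int64_t i1  = idx[i11*ne01 + i];

            // copy source row (i, i02) into destination row (i1, i02)
            for (int64_t k = 0; k < ne00; ++k) {
                dst[(i02*ne1 + i1)*ne00 + k] = src[(i02*ne01 + i)*ne00 + k];
            }
        }
    }

    printf("slice 0, dst row 6, element 0: %.1f\n", dst[6*ne00]); // prints 1.0
    return 0;
}
```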

ggml/src/ggml-cpu/ops.h

Lines changed: 1 addition & 0 deletions

@@ -53,6 +53,7 @@ void ggml_compute_forward_permute(const struct ggml_compute_params * params, str
 void ggml_compute_forward_transpose(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_get_rows(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_get_rows_back(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_set_rows(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_diag(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_diag_mask_inf(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_diag_mask_zero(const struct ggml_compute_params * params, struct ggml_tensor * dst);

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 16 additions & 0 deletions

@@ -521,6 +521,22 @@ typedef struct {
     uint64_t nb2;
 } ggml_metal_kargs_get_rows;

+typedef struct {
+    int32_t  nk0;
+    int32_t  ne01;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int32_t  ne11;
+    int32_t  ne12;
+    uint64_t nb10;
+    uint64_t nb11;
+    uint64_t nb12;
+    uint64_t nb1;
+    uint64_t nb2;
+    uint64_t nb3;
+} ggml_metal_kargs_set_rows;
+
 typedef struct {
     int64_t ne00;
     int64_t ne01;
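The new ggml_metal_kargs_set_rows struct carries the extents (ne*) and byte strides (nb*) the Metal kernel needs. As a rough, assumption-laden sketch of how the host side could populate it from the three tensors involved (the field meanings are inferred from the naming convention of the other kargs structs; in particular, nk0 is assumed to be the destination row length in elements or quantized blocks, and this is not the code from this commit):

```cpp
#include "ggml.h"
#include "ggml-metal-impl.h"   // internal header that defines ggml_metal_kargs_set_rows

// Illustrative only: src0 = source rows (F32), src1 = I64 row indices, dst = destination.
static void set_rows_fill_args(
        const struct ggml_tensor * src0,
        const struct ggml_tensor * src1,
        const struct ggml_tensor * dst,
        ggml_metal_kargs_set_rows * args) {
    // assumption: nk0 counts elements (or blocks, for quantized dst) per destination row
    args->nk0  = (int32_t)(src0->ne[0] / ggml_blck_size(dst->type));
    args->ne01 = (int32_t) src0->ne[1];
    args->nb01 = src0->nb[1];
    args->nb02 = src0->nb[2];
    args->nb03 = src0->nb[3];
    args->ne11 = (int32_t) src1->ne[1];
    args->ne12 = (int32_t) src1->ne[2];
    args->nb10 = src1->nb[0];
    args->nb11 = src1->nb[1];
    args->nb12 = src1->nb[2];
    args->nb1  = dst->nb[1];
    args->nb2  = dst->nb[2];
    args->nb3  = dst->nb[3];
}
```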
