Skip to content

Commit

Permalink
metal : move mm_id indices to shared mem (ggerganov#5982)
Browse files (browse the repository at this point in the history)
  • Loading branch information
ggerganov authored and hodlen committed Apr 1, 2024
1 parent 95b2071 commit f6fa30e
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
6 changes: 3 additions & 3 deletions ggml-metal.m
Expand Up @@ -1642,8 +1642,8 @@ static enum ggml_status ggml_metal_graph_compute(
// TODO: make this more general
GGML_ASSERT(n_as <= 8);

- // max size of the src1ids array in the kernel stack
- GGML_ASSERT(ne11 <= 512);
+ // max size of the src1ids array in the kernel shared buffer
+ GGML_ASSERT(ne11 <= 4096);

const int64_t ne20 = src2 ? src2->ne[0] : 0;
const int64_t ne21 = src2 ? src2->ne[1] : 0;
Expand Down Expand Up @@ -1741,7 +1741,7 @@ static enum ggml_status ggml_metal_graph_compute(
[encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:19 + j];
}

- [encoder setThreadgroupMemoryLength:8192 atIndex:0];
+ [encoder setThreadgroupMemoryLength:GGML_PAD(8192 + 2*ne11, 16) atIndex:0];

[encoder dispatchThreadgroups:MTLSizeMake((ne11 + 31)/32, (ne21 + 63)/64, n_as*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
} else {
Expand Down
6 changes: 3 additions & 3 deletions ggml-metal.metal
Expand Up @@ -5386,7 +5386,7 @@ template<typename block_q, short nl, void (*dequantize_func)(device const block_
void kernel_mul_mm_id_impl(
device const uchar * src0,
device const uchar * src1,
- thread short * src1ids,
+ threadgroup short * src1ids,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne02,
Expand Down Expand Up @@ -5589,9 +5589,9 @@ kernel void kernel_mul_mm_id(
tgpig.z = tgpig.z%(ne12*ne13);

// row indices of src1 for expert id
- int64_t _ne1 = 0;
- short src1ids[512];
+ threadgroup short * src1ids = (threadgroup short *)(shared_memory + 8192);

+ int64_t _ne1 = 0;
for (int64_t i1 = 0; i1 < ne1; i1++) {
if (((device int32_t *) (ids + i1*nbi1))[idx] == id) {
src1ids[_ne1++] = i1;
Expand Down

0 comments on commit f6fa30e

Please sign in to comment.