Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

metal : move mm_id indices to shared mem #5982

Merged
merged 1 commit into from Mar 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 3 additions & 3 deletions ggml-metal.m
Expand Up @@ -1642,8 +1642,8 @@ static enum ggml_status ggml_metal_graph_compute(
// TODO: make this more general
GGML_ASSERT(n_as <= 8);

// max size of the src1ids array in the kernel stack
GGML_ASSERT(ne11 <= 512);
// max size of the src1ids array in the kernel shared buffer
GGML_ASSERT(ne11 <= 4096);

const int64_t ne20 = src2 ? src2->ne[0] : 0;
const int64_t ne21 = src2 ? src2->ne[1] : 0;
Expand Down Expand Up @@ -1741,7 +1741,7 @@ static enum ggml_status ggml_metal_graph_compute(
[encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:19 + j];
}

[encoder setThreadgroupMemoryLength:8192 atIndex:0];
[encoder setThreadgroupMemoryLength:GGML_PAD(8192 + 2*ne11, 16) atIndex:0];

[encoder dispatchThreadgroups:MTLSizeMake((ne11 + 31)/32, (ne21 + 63)/64, n_as*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
} else {
Expand Down
6 changes: 3 additions & 3 deletions ggml-metal.metal
Expand Up @@ -5386,7 +5386,7 @@ template<typename block_q, short nl, void (*dequantize_func)(device const block_
void kernel_mul_mm_id_impl(
device const uchar * src0,
device const uchar * src1,
thread short * src1ids,
threadgroup short * src1ids,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne02,
Expand Down Expand Up @@ -5589,9 +5589,9 @@ kernel void kernel_mul_mm_id(
tgpig.z = tgpig.z%(ne12*ne13);

// row indices of src1 for expert id
int64_t _ne1 = 0;
short src1ids[512];
threadgroup short * src1ids = (threadgroup short *)(shared_memory + 8192);

int64_t _ne1 = 0;
for (int64_t i1 = 0; i1 < ne1; i1++) {
if (((device int32_t *) (ids + i1*nbi1))[idx] == id) {
src1ids[_ne1++] = i1;
Expand Down