@@ -5964,7 +5964,30 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx
5964
5964
if (src2->ne[1] == 1 && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type))) {
5965
5965
ggml_vk_mul_mat_vec_id_q_f16(ctx, subctx, src0, src1, src2, dst, dryrun);
5966
5966
} else {
5967
- ggml_vk_mul_mat_id_q_f16(ctx, subctx, src0, src1, src2, dst, dryrun);
5967
+ // Split based on number of ids, to fix in shared memory
5968
+ const uint32_t nei0 = (uint32_t)src2->ne[0];
5969
+ const uint32_t nei1 = (uint32_t)src2->ne[1];
5970
+
5971
+ GGML_ASSERT(nei0 <= 4096);
5972
+ const uint32_t split_size = std::min(nei1, 4096u / nei0);
5973
+
5974
+ for (uint32_t token_start = 0; token_start < nei1; token_start += split_size) {
5975
+ const uint32_t n_tokens = std::min(split_size, nei1 - token_start);
5976
+
5977
+ ggml_tensor src1_copy = *src1;
5978
+ ggml_tensor src2_copy = *src2;
5979
+ ggml_tensor dst_copy = *dst;
5980
+
5981
+ src1_copy.view_offs += token_start * src1_copy.nb[2];
5982
+ src2_copy.view_offs += token_start * src2_copy.nb[1];
5983
+ dst_copy.view_offs += token_start * dst_copy.nb[2];
5984
+
5985
+ src1_copy.ne[2] = n_tokens;
5986
+ src2_copy.ne[1] = n_tokens;
5987
+ dst_copy.ne[2] = n_tokens;
5988
+
5989
+ ggml_vk_mul_mat_id_q_f16(ctx, subctx, src0, &src1_copy, &src2_copy, &dst_copy, dryrun);
5990
+ }
5968
5991
}
5969
5992
}
5970
5993
@@ -10127,9 +10150,15 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
10127
10150
ggml_type src0_type = op->src[0]->type;
10128
10151
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
10129
10152
const vk_device& device = ggml_vk_get_device(ctx->device);
10130
- if (op->op == GGML_OP_MUL_MAT_ID && !device->mul_mat_id_s[src0_type] && !device->mul_mat_id_m[src0_type] && !device->mul_mat_id_l[src0_type]) {
10131
- // If there's not enough shared memory for row_ids and the result tile, fallback to CPU
10132
- return false;
10153
+ if (op->op == GGML_OP_MUL_MAT_ID) {
10154
+ if (!device->mul_mat_id_s[src0_type] && !device->mul_mat_id_m[src0_type] && !device->mul_mat_id_l[src0_type]) {
10155
+ // If there's not enough shared memory for row_ids and the result tile, fallback to CPU
10156
+ return false;
10157
+ }
10158
+ // Check against size of shared memory variable
10159
+ if (op->src[2]->ne[0] > 4096) {
10160
+ return false;
10161
+ }
10133
10162
}
10134
10163
switch (src0_type) {
10135
10164
case GGML_TYPE_F32:
0 commit comments