namespace qwen {

+ ggml_tensor *tensor_assign_buffers(ggml_tensor *tensor) {
+ #ifdef GGML_USE_CUBLAS
+   ggml_cuda_assign_buffers(tensor);
+ #endif
+   return tensor;
+ }
+
auto tensor_to_device(ggml_tensor *tensor) -> ggml_tensor * {
+ #ifdef GGML_USE_CUBLAS
+   if (tensor->backend == GGML_BACKEND_CPU) {
+     tensor->backend = GGML_BACKEND_GPU;
+     ggml_cuda_transform_tensor(tensor->data, tensor);
+   }
+ #endif
  return tensor;
}

auto tensor_to_cpu(ggml_tensor *tensor) -> ggml_tensor * {
+ #ifdef GGML_USE_CUBLAS
+   if (tensor->backend != GGML_BACKEND_CPU) {
+     ggml_cuda_free_data(tensor);
+     tensor->backend = GGML_BACKEND_CPU;
+   }
+ #endif
  return tensor;
}
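These three helpers carry the whole GPU story of this commit: tensor_assign_buffers registers a freshly created graph node with the CUDA backend, while tensor_to_device and tensor_to_cpu move a host tensor's data onto the GPU and release it again. A minimal usage sketch, not part of the diff (gctx, qlen and n_past stand in for the real graph context and sizes), mirroring the KQ_pos handling added to generate_next_token() further down:

// Sketch only: round-trip a positions tensor through the device, as
// generate_next_token() does below; every call is a no-op without GGML_USE_CUBLAS.
ggml_tensor *KQ_pos = ggml_new_tensor_1d(gctx, GGML_TYPE_I32, qlen);
int *pos = static_cast<int *>(KQ_pos->data);
for (int i = 0; i < qlen; i++) pos[i] = n_past + i;  // fill positions on the host
tensor_to_device(KQ_pos);  // copy to GPU under GGML_USE_CUBLAS
// ... build the graph and run ggml_graph_compute ...
tensor_to_cpu(KQ_pos);     // free the device copy before the next decoding step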
@@ -210,18 +229,18 @@ auto Embedding::forward(ModelContext *ctx, ggml_tensor *input) const -> ggml_ten
auto Linear::forward(ModelContext *ctx, ggml_tensor *input) const -> ggml_tensor * {
  // input: [seqlen, in_features]
  ggml_context *gctx = ctx->ctx_b.get();
-   ggml_tensor *output = ggml_mul_mat(gctx, weight, input); // [seqlen, out_features]
+   ggml_tensor *output = tensor_assign_buffers(ggml_mul_mat(gctx, weight, input)); // [seqlen, out_features]
  if (bias) {
-     output = ggml_add_inplace(gctx, output, bias);
+     output = tensor_assign_buffers(ggml_add_inplace(gctx, output, bias));
  }
  return output;
}

auto RMSNorm::forward(ModelContext *ctx, ggml_tensor *input, float eps) const -> ggml_tensor * {
  ggml_context *gctx = ctx->ctx_b.get();
  auto ggml_rms_norm_fn = inplace ? ggml_rms_norm_inplace : ggml_rms_norm;
-   ggml_tensor *output = ggml_rms_norm_fn(gctx, input, eps);
-   output = ggml_mul_inplace(gctx, output, weight);
+   ggml_tensor *output = tensor_assign_buffers(ggml_rms_norm_fn(gctx, input, eps));
+   output = tensor_assign_buffers(ggml_mul_inplace(gctx, output, weight));
  return output;
}

@@ -261,9 +280,9 @@ QwenTokenizer::QwenTokenizer(const std::string & tiktoken_path, const QwenConfig
  }

  std::vector<std::string> special_tokens_s{"<|endoftext|>", "<|im_start|>", "<|im_end|>"};
-   char buffer[12];
+   char buffer[14];
  for (size_t i = 0; i < 205; i++) {
-     snprintf(buffer, 12, "<|extra_%zu|>", i);
+     snprintf(buffer, 14, "<|extra_%zu|>", i);
    special_tokens_s.push_back(buffer);
  }
  size_t encoder_size = encoder.size();
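The buffer bump from 12 to 14 bytes fixes a silent truncation: "<|extra_9|>" (11 characters) was the longest name that fit in 12 bytes with its terminator, so every token from <|extra_10|> up to <|extra_204|> (13 characters) was cut short. A standalone check of the arithmetic, in plain C++ independent of the tokenizer code:

#include <cstdio>
#include <cstring>

int main() {
  char buf[14];
  // "<|extra_204|>" is 13 characters, so snprintf needs 13 + 1 bytes for the '\0'.
  std::snprintf(buf, sizeof(buf), "<|extra_%zu|>", static_cast<size_t>(204));
  std::printf("%s -> %zu chars\n", buf, std::strlen(buf));  // prints 13
  return 0;
}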
@@ -328,71 +347,79 @@ QwenAttention::QwenAttention(ModelContext *ctx, int hidden_size, int num_attenti
      v_cache(ggml_new_tensor_3d(ctx->ctx_kv.get(), GGML_TYPE_F16, max_length, hidden_size / num_attention_heads,
                                 num_kv_heads)) {}

- auto QwenAttention::forward(ModelContext *ctx, ggml_tensor *hidden_states, ggml_tensor *KQ_pos) const -> ggml_tensor * {
+ auto QwenAttention::forward(ModelContext *ctx, ggml_tensor *hidden_states, ggml_tensor *KQ_pos, int n_ctx) const -> ggml_tensor * {
  ggml_context *gctx = ctx->ctx_b.get();

  const int hidden_size = hidden_states->ne[0];
  const int qlen = hidden_states->ne[1];
  const int head_size = hidden_size / num_attention_heads;
  const int rope_dim = head_size;
-   const int mqa_scale = num_attention_heads / num_kv_heads;
  const int n_past = static_cast<int *>(KQ_pos->data)[0];

  ggml_tensor *qkv = c_attn.forward(ctx, hidden_states); // [qlen, hidden + 2 * kv_hidden]
  ggml_tensor *query_layer =
      ggml_view_3d(gctx, qkv, head_size, num_attention_heads, qlen, head_size * ggml_element_size(qkv), qkv->nb[1],
                   0); // [qlen, heads, head_size]
-   query_layer = ggml_rope_inplace(gctx, query_layer, KQ_pos, rope_dim, 2, 0);
-   query_layer = ggml_cont(gctx, ggml_permute(gctx, query_layer, 0, 2, 1, 3)); // [heads, qlen, head_size]
-   query_layer = ggml_reshape_3d(gctx, query_layer, head_size, mqa_scale * qlen, num_kv_heads); // [kv_heads, mqa_scale * qlen, head_size]
+ #ifdef GGML_USE_CUBLAS
+   if (!ggml_is_contiguous(query_layer)) {
+     query_layer = tensor_assign_buffers(ggml_cont(gctx, query_layer));
+   }
+ #endif
+   query_layer = tensor_assign_buffers(ggml_rope_inplace(gctx, query_layer, KQ_pos, rope_dim, 2, n_ctx));
+   query_layer = tensor_assign_buffers(ggml_cont(gctx, ggml_permute(gctx, query_layer, 0, 2, 1, 3))); // [heads, qlen, head_size]

  ggml_tensor *key_layer =
      ggml_view_3d(gctx, qkv, head_size, num_kv_heads, qlen, head_size * ggml_element_size(qkv), qkv->nb[1],
                   hidden_size * ggml_element_size(qkv)); // [qlen, kv_heads, head_size]
-   key_layer = ggml_rope_inplace(gctx, key_layer, KQ_pos, rope_dim, 2, 0);
-   key_layer = ggml_permute(gctx, key_layer, 0, 2, 1, 3); // [kv_heads, qlen, head_size]
+ #ifdef GGML_USE_CUBLAS
+   if (!ggml_is_contiguous(key_layer)) {
+     key_layer = tensor_assign_buffers(ggml_cont(gctx, key_layer));
+   }
+ #endif
+   key_layer = tensor_assign_buffers(ggml_rope_inplace(gctx, key_layer, KQ_pos, rope_dim, 2, n_ctx));
+   key_layer = tensor_assign_buffers(ggml_permute(gctx, key_layer, 0, 2, 1, 3)); // [kv_heads, qlen, head_size]

  ggml_tensor *value_layer =
      ggml_view_3d(gctx, qkv, head_size, num_kv_heads, qlen, head_size * ggml_element_size(qkv), qkv->nb[1],
                   (hidden_size + head_size * num_kv_heads) * ggml_element_size(qkv)); // [qlen, kv_heads, head_size]
-   value_layer = ggml_permute(gctx, value_layer, 1, 2, 0, 3); // [kv_heads, head_size, qlen]
+   value_layer = tensor_assign_buffers(ggml_permute(gctx, value_layer, 1, 2, 0, 3)); // [kv_heads, head_size, qlen]

  // store key & value to cache
-   ggml_tensor *k_cache_view = ggml_view_3d(
-       gctx, k_cache, head_size, qlen, num_kv_heads, k_cache->nb[1], k_cache->nb[2],
-       n_past * head_size * ggml_element_size(k_cache)); // [kv_heads, qlen, head_size]
+   ggml_tensor *k_cache_view = tensor_assign_buffers(
+       ggml_view_3d(gctx, k_cache, head_size, qlen, num_kv_heads, k_cache->nb[1], k_cache->nb[2],
+                    n_past * head_size * ggml_element_size(k_cache))); // [kv_heads, qlen, head_size]
  ggml_build_forward_expand(&ctx->gf, ggml_cpy(gctx, key_layer, k_cache_view));
-   ggml_tensor *v_cache_view = ggml_view_3d(
-       gctx, v_cache, qlen, head_size, num_kv_heads, v_cache->nb[1], v_cache->nb[2],
-       n_past * ggml_element_size(v_cache)); // [kv_heads, head_size, qlen]
+   ggml_tensor *v_cache_view = tensor_assign_buffers(
+       ggml_view_3d(gctx, v_cache, qlen, head_size, num_kv_heads, v_cache->nb[1], v_cache->nb[2],
+                    n_past * ggml_element_size(v_cache))); // [kv_heads, head_size, qlen]
  ggml_build_forward_expand(&ctx->gf, ggml_cpy(gctx, value_layer, v_cache_view));

  // concat key & value with past kv
-   key_layer = ggml_view_3d(gctx, k_cache, head_size, n_past + qlen, num_kv_heads,
-                            k_cache->nb[1], k_cache->nb[2],
-                            0); // [kv_heads, klen, head_size]
-   value_layer = ggml_view_3d(gctx, v_cache, n_past + qlen, head_size, num_kv_heads,
-                              v_cache->nb[1], v_cache->nb[2],
-                              0); // [kv_heads, head_size, klen]
+   key_layer = tensor_assign_buffers(
+       ggml_view_3d(gctx, k_cache, head_size, n_past + qlen, num_kv_heads,
+                    k_cache->nb[1], k_cache->nb[2], 0)); // [kv_heads, klen, head_size]
+   value_layer = tensor_assign_buffers(
+       ggml_view_3d(gctx, v_cache, n_past + qlen, head_size, num_kv_heads,
+                    v_cache->nb[1], v_cache->nb[2], 0)); // [kv_heads, head_size, klen]

  // attention
-   ggml_tensor *attn_scores = ggml_mul_mat(gctx, key_layer, query_layer); // [kv_heads, mqa_scale * qlen, klen]
-   attn_scores = ggml_scale_inplace(gctx, attn_scores, ggml_new_f32(gctx, 1.f / std::sqrt(head_size)));
+   ggml_tensor *attn_scores =
+       tensor_assign_buffers(ggml_mul_mat(gctx, key_layer, query_layer)); // [kv_heads, mqa_scale * qlen, klen]
+   attn_scores = tensor_assign_buffers(
+       ggml_scale_inplace(gctx, attn_scores, ggml_new_f32(gctx, 1.f / std::sqrt(head_size))));
  if (n_past == 0) {
    // build attention mask for context input
-     attn_scores = ggml_reshape_3d(gctx, attn_scores, n_past + qlen, qlen,
-                                   num_attention_heads); // [heads, qlen, klen]
-     attn_scores = ggml_diag_mask_inf_inplace(gctx, attn_scores, n_past);
-     attn_scores = ggml_reshape_3d(gctx, attn_scores, n_past + qlen, mqa_scale * qlen,
-                                   num_kv_heads); // [kv_heads, mqa_scale * qlen, klen]
+     attn_scores = tensor_assign_buffers(ggml_diag_mask_inf_inplace(gctx, attn_scores, n_past));
  }
-   ggml_tensor *attn_probs = ggml_soft_max_inplace(gctx, attn_scores); // [kv_heads, mqa_scale * qlen, klen]
+   ggml_tensor *attn_probs =
+       tensor_assign_buffers(ggml_soft_max_inplace(gctx, attn_scores)); // [kv_heads, mqa_scale * qlen, klen]

-   ggml_tensor *context_layer = ggml_mul_mat(gctx, value_layer, attn_probs); // [kv_heads, mqa_scale * qlen, head_size]
-   context_layer = ggml_reshape_3d(gctx, context_layer, head_size, qlen,
-                                   num_attention_heads); // [heads, qlen, head_size]
-   context_layer = ggml_cont(gctx, ggml_permute(gctx, context_layer, 0, 2, 1, 3)); // [qlen, heads, head_size]
-   context_layer = ggml_reshape_2d(gctx, context_layer, hidden_size, qlen); // [qlen, hidden]
+   ggml_tensor *context_layer = tensor_assign_buffers(
+       ggml_mul_mat(gctx, value_layer, attn_probs)); // [kv_heads, mqa_scale * qlen, head_size]
+   context_layer = tensor_assign_buffers(
+       ggml_cont(gctx, ggml_permute(gctx, context_layer, 0, 2, 1, 3))); // [qlen, heads, head_size]
+   context_layer = tensor_assign_buffers(
+       ggml_reshape_2d(gctx, context_layer, hidden_size, qlen)); // [qlen, hidden]

  ggml_tensor *attn_output = c_proj.forward(ctx, context_layer);
  return attn_output;
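To make the ggml_view_3d byte offsets into the packed qkv tensor and the 1.f / std::sqrt(head_size) scale concrete, here is a standalone arithmetic sketch. The numbers (hidden_size 4096, 32 heads, 32 KV heads, 4-byte activations) are illustrative assumptions, not values read from the code:

#include <cmath>
#include <cstddef>
#include <cstdio>

int main() {
  const size_t hidden_size = 4096, num_heads = 32, num_kv_heads = 32, elt = 4;
  const size_t head_size = hidden_size / num_heads;                        // 128
  const size_t q_offset = 0;                                               // query block starts at 0
  const size_t k_offset = hidden_size * elt;                               // key block offset, as in the view above
  const size_t v_offset = (hidden_size + head_size * num_kv_heads) * elt;  // value block offset
  std::printf("head_size=%zu q@%zu k@%zu v@%zu scale=%.4f\n",
              head_size, q_offset, k_offset, v_offset, 1.0 / std::sqrt(static_cast<double>(head_size)));
  return 0;  // prints head_size=128 q@0 k@16384 v@32768 scale=0.0884
}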
@@ -402,26 +429,26 @@ auto QwenMLP::forward(ModelContext *ctx, ggml_tensor *hidden_states) const -> gg
  ggml_context *gctx = ctx->ctx_b.get();

  ggml_tensor *a2 = w2.forward(ctx, hidden_states);
-   a2 = ggml_silu_inplace(gctx, a2);
+   a2 = tensor_assign_buffers(ggml_silu_inplace(gctx, a2));
  ggml_tensor *a1 = w1.forward(ctx, hidden_states);

-   ggml_tensor *output = ggml_mul_inplace(gctx, a2, a1);
+   ggml_tensor *output = tensor_assign_buffers(ggml_mul_inplace(gctx, a2, a1));
  output = c_proj.forward(ctx, output);
  return output;
}

- auto QwenBlock::forward(ModelContext *ctx, ggml_tensor *hidden_states, ggml_tensor *KQ_pos) const -> ggml_tensor * {
+ auto QwenBlock::forward(ModelContext *ctx, ggml_tensor *hidden_states, ggml_tensor *KQ_pos, int n_ctx) const -> ggml_tensor * {
  ggml_context *gctx = ctx->ctx_b.get();

  ggml_tensor *residual = hidden_states;
  hidden_states = ln_1.forward(ctx, hidden_states, 1e-6f);
-   hidden_states = attn.forward(ctx, hidden_states, KQ_pos);
-   hidden_states = ggml_add_inplace(gctx, hidden_states, residual);
+   hidden_states = attn.forward(ctx, hidden_states, KQ_pos, n_ctx);
+   hidden_states = tensor_assign_buffers(ggml_add_inplace(gctx, hidden_states, residual));

  residual = hidden_states;
  hidden_states = ln_2.forward(ctx, hidden_states, 1e-6f);
  hidden_states = mlp.forward(ctx, hidden_states);
-   hidden_states = ggml_add_inplace(gctx, hidden_states, residual);
+   hidden_states = tensor_assign_buffers(ggml_add_inplace(gctx, hidden_states, residual));

  return hidden_states;
}
@@ -435,12 +462,12 @@ QwenModel::QwenModel(ModelContext *ctx, const QwenConfig &config)
    }
  }

- auto QwenModel::forward(ModelContext *ctx, ggml_tensor *input_ids, ggml_tensor *KQ_pos) const -> ggml_tensor * {
+ auto QwenModel::forward(ModelContext *ctx, ggml_tensor *input_ids, ggml_tensor *KQ_pos, int n_ctx) const -> ggml_tensor * {
  ggml_context *gctx = ctx->ctx_b.get();
  ggml_tensor *hidden_states = wte.forward(ctx, input_ids);
  for (const auto &layer : layers) {
    ggml_set_scratch(gctx, ctx->scratch);
-     hidden_states = layer.forward(ctx, hidden_states, KQ_pos);
+     hidden_states = layer.forward(ctx, hidden_states, KQ_pos, n_ctx);
  }
  ggml_scratch empty_scratch = {0, 0, nullptr};
  ggml_set_scratch(gctx, empty_scratch);
@@ -455,17 +482,25 @@ auto get_num_physical_cores() -> int {
}

auto get_default_num_threads() -> int {
+ #ifdef GGML_USE_CUBLAS
+   return 1;
+ #else
  return std::min(get_num_physical_cores(), 16);
+ #endif
}

QwenForCausalLM::QwenForCausalLM(const QwenConfig &config)
    : config(config) {
  ctx_.compute_buffer.resize(MEM_SIZE);
  ctx_.scratch_buffer.resize(SCRATCH_SIZE);
  ctx_.scratch = {0, ctx_.scratch_buffer.size(), ctx_.scratch_buffer.data()};
+ #ifdef GGML_USE_CUBLAS
+   ggml_cuda_set_scratch_size(SCRATCH_SIZE);
+ #endif
  constexpr size_t tensor_ovhd = GGML_TENSOR_SIZE + GGML_OBJECT_SIZE;
  const size_t ctx_w_size = (3 + config.num_hidden_layers * 8) * tensor_ovhd;
-   const size_t ctx_kv_size = 2 * config.num_hidden_layers * (config.max_length * config.hidden_size / config.num_attention_heads * config.num_kv_heads * ggml_type_size(GGML_TYPE_F16) + tensor_ovhd);
+   const size_t ctx_kv_size = 2 * config.num_hidden_layers *
+       (config.max_length * config.hidden_size / config.num_attention_heads * config.num_kv_heads * ggml_type_size(GGML_TYPE_F16) + tensor_ovhd);
  ctx_.dtype = config.dtype;
  ctx_.ctx_w = make_unique_ggml_context(ctx_w_size, nullptr, true);
  ctx_.ctx_kv = make_unique_ggml_context(ctx_kv_size + 1 * MB, nullptr, false);
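The reflowed ctx_kv_size expression is simply the KV cache footprint: two F16 tensors (K and V) per layer, each max_length x head_dim x num_kv_heads, plus per-tensor metadata overhead. A quick sanity check with hypothetical Qwen-7B-like numbers (32 layers, hidden 4096, 32 heads, 32 KV heads, max_length 2048; these are assumptions for illustration, the real values come from QwenConfig):

#include <cstddef>
#include <cstdio>

int main() {
  const size_t num_hidden_layers = 32, hidden_size = 4096;
  const size_t num_attention_heads = 32, num_kv_heads = 32, max_length = 2048;
  const size_t f16_size = 2;  // ggml_type_size(GGML_TYPE_F16)

  // One K cache and one V cache per layer, each [max_length, head_dim, kv_heads] in F16.
  const size_t per_tensor = max_length * (hidden_size / num_attention_heads) * num_kv_heads * f16_size;
  const size_t kv_bytes = 2 * num_hidden_layers * per_tensor;  // metadata overhead omitted here
  std::printf("KV cache: %.0f MiB\n", kv_bytes / (1024.0 * 1024.0));  // prints 1024 MiB
  return 0;
}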
@@ -537,9 +572,15 @@ auto QwenForCausalLM::generate_next_token(
  for (int i = 0; i < curr_input_ids_size; ++i) {
    data[i] = n_past + i;
  }
+   if (KQ_pos) {
+     tensor_to_device(KQ_pos);
+   }

  ggml_tensor *lm_logits = forward(&ctx_, curr_input_ids, KQ_pos, n_ctx);
  lm_logits->backend = GGML_BACKEND_CPU;
+   if (KQ_pos) {
+     tensor_to_cpu(KQ_pos);
+   }

  ggml_build_forward_expand(&ctx_.gf, lm_logits);
  ggml_graph_compute_helper(ctx_.work_buffer, &ctx_.gf, n_threads);
@@ -731,12 +772,12 @@ auto QwenForCausalLM::forward(
    ggml_tensor *KQ_pos,
    int n_ctx
) const -> ggml_tensor * {
-   ggml_tensor *transformer_outputs = transformer.forward(ctx, input_ids, KQ_pos);
+   ggml_tensor *transformer_outputs = transformer.forward(ctx, input_ids, KQ_pos, n_ctx);
  // NOTE: only compute next_token_logits for the last token
  if (input_ids->ne[0] > 1) {
-     transformer_outputs =
+     transformer_outputs = tensor_assign_buffers(
        ggml_view_1d(ctx->ctx_b.get(), transformer_outputs, config.hidden_size,
-                    (input_ids->ne[0] - 1) * config.hidden_size * ggml_element_size(transformer_outputs));
+                    (input_ids->ne[0] - 1) * config.hidden_size * ggml_element_size(transformer_outputs)));
  }
  ggml_tensor *lm_logits = lm_head.forward(ctx, transformer_outputs);
  return lm_logits;