
Commit 72cfb74

support cuda
1 parent 752c18f commit 72cfb74

File tree

4 files changed: +109, -56 lines changed


CMakeLists.txt

Lines changed: 7 additions & 1 deletion
@@ -20,10 +20,16 @@ add_subdirectory(third_party/abseil-cpp)
 
 add_subdirectory(third_party/re2)
 
-add_compile_definitions(GGML_CUDA_MMV_Y=2) # for large vocab
+add_compile_definitions(GGML_CUDA_MMV_Y=3) # for large vocab
 include_directories(third_party/ggml/include/ggml third_party/ggml/src)
 add_subdirectory(third_party/ggml)
 
+if (GGML_CUBLAS)
+  add_compile_definitions(GGML_USE_CUBLAS)
+  set(CUDA_ARCHITECTURES "52;61;70;75;80;86" CACHE STRING "qwen: cuda architectures to compile")
+  set_property(TARGET ggml PROPERTY CUDA_ARCHITECTURES ${CUDA_ARCHITECTURES})
+endif ()
+
 file(GLOB CPP_SOURCES
     ${PROJECT_SOURCE_DIR}/*.h
     ${PROJECT_SOURCE_DIR}/*.cpp)
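The CUDA path is opt-in: configuring with the new flag (e.g. `cmake -B build -DGGML_CUBLAS=ON`) defines `GGML_USE_CUBLAS` for the sources below and sets a default, overridable list of CUDA architectures on the ggml target; without the flag the build is unchanged. The `GGML_CUDA_MMV_Y` bump from 2 to 3 raises the y-block size of ggml's CUDA mul-mat-vec kernel, which the comment ties to Qwen's large vocabulary.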

README.md

Lines changed: 1 addition & 1 deletion
@@ -11,7 +11,7 @@ Highlights:
 * [x] Python binding.
 
 Support Matrix:
-* Hardwares: x86/arm CPU
+* Hardwares: x86/arm CPU, NVIDIA GPU
 * Platforms: Linux, MacOS
 * Models: [Qwen-LM](https://github.com/QwenLM/Qwen)
 

qwen.cpp

Lines changed: 92 additions & 51 deletions
@@ -22,11 +22,30 @@
 
 namespace qwen {
 
+ggml_tensor *tensor_assign_buffers(ggml_tensor *tensor) {
+#ifdef GGML_USE_CUBLAS
+  ggml_cuda_assign_buffers(tensor);
+#endif
+  return tensor;
+}
+
 auto tensor_to_device(ggml_tensor *tensor) -> ggml_tensor * {
+#ifdef GGML_USE_CUBLAS
+  if (tensor->backend == GGML_BACKEND_CPU) {
+    tensor->backend = GGML_BACKEND_GPU;
+    ggml_cuda_transform_tensor(tensor->data, tensor);
+  }
+#endif
   return tensor;
 }
 
 auto tensor_to_cpu(ggml_tensor *tensor) -> ggml_tensor * {
+#ifdef GGML_USE_CUBLAS
+  if (tensor->backend != GGML_BACKEND_CPU) {
+    ggml_cuda_free_data(tensor);
+    tensor->backend = GGML_BACKEND_CPU;
+  }
+#endif
   return tensor;
 }
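These three helpers carry the entire CPU/GPU split: `tensor_assign_buffers` marks a freshly built graph node for GPU execution, `tensor_to_device` uploads a long-lived tensor (weights, the KV cache, `KQ_pos`) to the device, and `tensor_to_cpu` frees its device copy; in a CPU-only build all three reduce to identity functions, so every call site below stays portable. A minimal sketch of the intended pattern (the wrapper name `offloaded_matmul` is illustrative, not from this commit):

ggml_tensor *offloaded_matmul(ggml_context *gctx, ggml_tensor *weight, ggml_tensor *input) {
  // one-time host-to-device copy of the weight; a no-op without GGML_USE_CUBLAS
  weight = tensor_to_device(weight);
  // the matmul node gets a device buffer, so ggml evaluates it on the GPU
  return tensor_assign_buffers(ggml_mul_mat(gctx, weight, input));
}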

@@ -210,18 +229,18 @@ auto Embedding::forward(ModelContext *ctx, ggml_tensor *input) const -> ggml_tensor *
 auto Linear::forward(ModelContext *ctx, ggml_tensor *input) const -> ggml_tensor * {
   // input: [seqlen, in_features]
   ggml_context *gctx = ctx->ctx_b.get();
-  ggml_tensor *output = ggml_mul_mat(gctx, weight, input); // [seqlen, out_features]
+  ggml_tensor *output = tensor_assign_buffers(ggml_mul_mat(gctx, weight, input)); // [seqlen, out_features]
   if (bias) {
-    output = ggml_add_inplace(gctx, output, bias);
+    output = tensor_assign_buffers(ggml_add_inplace(gctx, output, bias));
   }
   return output;
 }
 
 auto RMSNorm::forward(ModelContext *ctx, ggml_tensor *input, float eps) const -> ggml_tensor * {
   ggml_context *gctx = ctx->ctx_b.get();
   auto ggml_rms_norm_fn = inplace ? ggml_rms_norm_inplace : ggml_rms_norm;
-  ggml_tensor *output = ggml_rms_norm_fn(gctx, input, eps);
-  output = ggml_mul_inplace(gctx, output, weight);
+  ggml_tensor *output = tensor_assign_buffers(ggml_rms_norm_fn(gctx, input, eps));
+  output = tensor_assign_buffers(ggml_mul_inplace(gctx, output, weight));
   return output;
 }

@@ -261,9 +280,9 @@ QwenTokenizer::QwenTokenizer(const std::string & tiktoken_path, const QwenConfig
   }
 
   std::vector<std::string> special_tokens_s{"<|endoftext|>", "<|im_start|>", "<|im_end|>"};
-  char buffer[12];
+  char buffer[14];
   for (size_t i = 0; i < 205; i++) {
-    snprintf(buffer, 12, "<|extra_%zu|>", i);
+    snprintf(buffer, 14, "<|extra_%zu|>", i);
     special_tokens_s.push_back(buffer);
   }
   size_t encoder_size = encoder.size();
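Why 14 bytes: with `i` running up to 204, the longest generated token is `<|extra_204|>`, which is 13 characters plus the terminating NUL, so the old 12-byte buffer silently truncated every token from `<|extra_10|>` onward. A compile-time guard one could add here (a sketch, not part of the commit):

static_assert(sizeof("<|extra_204|>") == 14, "buffer must hold the longest special token plus NUL");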
@@ -328,71 +347,79 @@ QwenAttention::QwenAttention(ModelContext *ctx, int hidden_size, int num_attention_heads
       v_cache(ggml_new_tensor_3d(ctx->ctx_kv.get(), GGML_TYPE_F16, max_length, hidden_size / num_attention_heads,
                                  num_kv_heads)) {}
 
-auto QwenAttention::forward(ModelContext *ctx, ggml_tensor *hidden_states, ggml_tensor *KQ_pos) const -> ggml_tensor * {
+auto QwenAttention::forward(ModelContext *ctx, ggml_tensor *hidden_states, ggml_tensor *KQ_pos, int n_ctx) const -> ggml_tensor * {
   ggml_context *gctx = ctx->ctx_b.get();
 
   const int hidden_size = hidden_states->ne[0];
   const int qlen = hidden_states->ne[1];
   const int head_size = hidden_size / num_attention_heads;
   const int rope_dim = head_size;
-  const int mqa_scale = num_attention_heads / num_kv_heads;
   const int n_past = static_cast<int *>(KQ_pos->data)[0];
 
   ggml_tensor *qkv = c_attn.forward(ctx, hidden_states); // [qlen, hidden + 2 * kv_hidden]
   ggml_tensor *query_layer =
       ggml_view_3d(gctx, qkv, head_size, num_attention_heads, qlen, head_size * ggml_element_size(qkv), qkv->nb[1],
                    0); // [qlen, heads, head_size]
-  query_layer = ggml_rope_inplace(gctx, query_layer, KQ_pos, rope_dim, 2, 0);
-  query_layer = ggml_cont(gctx, ggml_permute(gctx, query_layer, 0, 2, 1, 3)); // [heads, qlen, head_size]
-  query_layer = ggml_reshape_3d(gctx, query_layer, head_size, mqa_scale * qlen, num_kv_heads); // [kv_heads, mqa_scale * qlen, head_size]
+#ifdef GGML_USE_CUBLAS
+  if (!ggml_is_contiguous(query_layer)) {
+    query_layer = tensor_assign_buffers(ggml_cont(gctx, query_layer));
+  }
+#endif
+  query_layer = tensor_assign_buffers(ggml_rope_inplace(gctx, query_layer, KQ_pos, rope_dim, 2, n_ctx));
+  query_layer = tensor_assign_buffers(ggml_cont(gctx, ggml_permute(gctx, query_layer, 0, 2, 1, 3))); // [heads, qlen, head_size]
 
   ggml_tensor *key_layer =
       ggml_view_3d(gctx, qkv, head_size, num_kv_heads, qlen, head_size * ggml_element_size(qkv), qkv->nb[1],
                    hidden_size * ggml_element_size(qkv)); // [qlen, kv_heads, head_size]
-  key_layer = ggml_rope_inplace(gctx, key_layer, KQ_pos, rope_dim, 2, 0);
-  key_layer = ggml_permute(gctx, key_layer, 0, 2, 1, 3); // [kv_heads, qlen, head_size]
+#ifdef GGML_USE_CUBLAS
+  if (!ggml_is_contiguous(key_layer)) {
+    key_layer = tensor_assign_buffers(ggml_cont(gctx, key_layer));
+  }
+#endif
+  key_layer = tensor_assign_buffers(ggml_rope_inplace(gctx, key_layer, KQ_pos, rope_dim, 2, n_ctx));
+  key_layer = tensor_assign_buffers(ggml_permute(gctx, key_layer, 0, 2, 1, 3)); // [kv_heads, qlen, head_size]
 
   ggml_tensor *value_layer =
       ggml_view_3d(gctx, qkv, head_size, num_kv_heads, qlen, head_size * ggml_element_size(qkv), qkv->nb[1],
                    (hidden_size + head_size * num_kv_heads) * ggml_element_size(qkv)); // [qlen, kv_heads, head_size]
-  value_layer = ggml_permute(gctx, value_layer, 1, 2, 0, 3); // [kv_heads, head_size, qlen]
+  value_layer = tensor_assign_buffers(ggml_permute(gctx, value_layer, 1, 2, 0, 3)); // [kv_heads, head_size, qlen]
 
   // store key & value to cache
-  ggml_tensor *k_cache_view = ggml_view_3d(
-      gctx, k_cache, head_size, qlen, num_kv_heads, k_cache->nb[1], k_cache->nb[2],
-      n_past * head_size * ggml_element_size(k_cache)); // [kv_heads, qlen, head_size]
+  ggml_tensor *k_cache_view = tensor_assign_buffers(
+      ggml_view_3d(gctx, k_cache, head_size, qlen, num_kv_heads, k_cache->nb[1], k_cache->nb[2],
+                   n_past * head_size * ggml_element_size(k_cache))); // [kv_heads, qlen, head_size]
   ggml_build_forward_expand(&ctx->gf, ggml_cpy(gctx, key_layer, k_cache_view));
-  ggml_tensor *v_cache_view = ggml_view_3d(
-      gctx, v_cache, qlen, head_size, num_kv_heads, v_cache->nb[1], v_cache->nb[2],
-      n_past * ggml_element_size(v_cache)); // [kv_heads, head_size, qlen]
+  ggml_tensor *v_cache_view = tensor_assign_buffers(
+      ggml_view_3d(gctx, v_cache, qlen, head_size, num_kv_heads, v_cache->nb[1], v_cache->nb[2],
+                   n_past * ggml_element_size(v_cache))); // [kv_heads, head_size, qlen]
   ggml_build_forward_expand(&ctx->gf, ggml_cpy(gctx, value_layer, v_cache_view));
 
   // concat key & value with past kv
-  key_layer = ggml_view_3d(gctx, k_cache, head_size, n_past + qlen, num_kv_heads,
-                           k_cache->nb[1], k_cache->nb[2],
-                           0); // [kv_heads, klen, head_size]
-  value_layer = ggml_view_3d(gctx, v_cache, n_past + qlen, head_size, num_kv_heads,
-                             v_cache->nb[1], v_cache->nb[2],
-                             0); // [kv_heads, head_size, klen]
+  key_layer = tensor_assign_buffers(
+      ggml_view_3d(gctx, k_cache, head_size, n_past + qlen, num_kv_heads,
+                   k_cache->nb[1], k_cache->nb[2], 0)); // [kv_heads, klen, head_size]
+  value_layer = tensor_assign_buffers(
+      ggml_view_3d(gctx, v_cache, n_past + qlen, head_size, num_kv_heads,
+                   v_cache->nb[1], v_cache->nb[2], 0)); // [kv_heads, head_size, klen]
 
   // attention
-  ggml_tensor *attn_scores = ggml_mul_mat(gctx, key_layer, query_layer); // [kv_heads, mqa_scale * qlen, klen]
-  attn_scores = ggml_scale_inplace(gctx, attn_scores, ggml_new_f32(gctx, 1.f / std::sqrt(head_size)));
+  ggml_tensor *attn_scores =
+      tensor_assign_buffers(ggml_mul_mat(gctx, key_layer, query_layer)); // [kv_heads, mqa_scale * qlen, klen]
+  attn_scores = tensor_assign_buffers(
+      ggml_scale_inplace(gctx, attn_scores, ggml_new_f32(gctx, 1.f / std::sqrt(head_size))));
   if (n_past == 0) {
     // build attention mask for context input
-    attn_scores = ggml_reshape_3d(gctx, attn_scores, n_past + qlen, qlen,
-                                  num_attention_heads); // [heads, qlen, klen]
-    attn_scores = ggml_diag_mask_inf_inplace(gctx, attn_scores, n_past);
-    attn_scores = ggml_reshape_3d(gctx, attn_scores, n_past + qlen, mqa_scale * qlen,
-                                  num_kv_heads); // [kv_heads, mqa_scale * qlen, klen]
+    attn_scores = tensor_assign_buffers(ggml_diag_mask_inf_inplace(gctx, attn_scores, n_past));
   }
-  ggml_tensor *attn_probs = ggml_soft_max_inplace(gctx, attn_scores); // [kv_heads, mqa_scale * qlen, klen]
+  ggml_tensor *attn_probs =
+      tensor_assign_buffers(ggml_soft_max_inplace(gctx, attn_scores)); // [kv_heads, mqa_scale * qlen, klen]
 
-  ggml_tensor *context_layer = ggml_mul_mat(gctx, value_layer, attn_probs); // [kv_heads, mqa_scale * qlen, head_size]
-  context_layer = ggml_reshape_3d(gctx, context_layer, head_size, qlen,
-                                  num_attention_heads); // [heads, qlen, head_size]
-  context_layer = ggml_cont(gctx, ggml_permute(gctx, context_layer, 0, 2, 1, 3)); // [qlen, heads, head_size]
-  context_layer = ggml_reshape_2d(gctx, context_layer, hidden_size, qlen); // [qlen, hidden]
+  ggml_tensor *context_layer = tensor_assign_buffers(
+      ggml_mul_mat(gctx, value_layer, attn_probs)); // [kv_heads, mqa_scale * qlen, head_size]
+  context_layer = tensor_assign_buffers(
+      ggml_cont(gctx, ggml_permute(gctx, context_layer, 0, 2, 1, 3))); // [qlen, heads, head_size]
+  context_layer = tensor_assign_buffers(
+      ggml_reshape_2d(gctx, context_layer, hidden_size, qlen)); // [qlen, hidden]
 
   ggml_tensor *attn_output = c_proj.forward(ctx, context_layer);
   return attn_output;
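Besides routing every intermediate through `tensor_assign_buffers`, this hunk makes three substantive changes: the final argument of `ggml_rope_inplace` becomes the caller-supplied `n_ctx` instead of a hard-coded 0; the CUDA build inserts a `ggml_cont` before rope whenever the QKV view is non-contiguous, presumably because the offloaded rope kernel expects contiguous input; and the `mqa_scale` reshape around the causal mask is dropped, so `ggml_diag_mask_inf_inplace` now runs directly on the `[kv_heads, mqa_scale * qlen, klen]` scores.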
@@ -402,26 +429,26 @@ auto QwenMLP::forward(ModelContext *ctx, ggml_tensor *hidden_states) const -> ggml_tensor * {
   ggml_context *gctx = ctx->ctx_b.get();
 
   ggml_tensor *a2 = w2.forward(ctx, hidden_states);
-  a2 = ggml_silu_inplace(gctx, a2);
+  a2 = tensor_assign_buffers(ggml_silu_inplace(gctx, a2));
   ggml_tensor *a1 = w1.forward(ctx, hidden_states);
 
-  ggml_tensor *output = ggml_mul_inplace(gctx, a2, a1);
+  ggml_tensor *output = tensor_assign_buffers(ggml_mul_inplace(gctx, a2, a1));
   output = c_proj.forward(ctx, output);
   return output;
 }
 
-auto QwenBlock::forward(ModelContext *ctx, ggml_tensor *hidden_states, ggml_tensor *KQ_pos) const -> ggml_tensor * {
+auto QwenBlock::forward(ModelContext *ctx, ggml_tensor *hidden_states, ggml_tensor *KQ_pos, int n_ctx) const -> ggml_tensor * {
   ggml_context *gctx = ctx->ctx_b.get();
 
   ggml_tensor *residual = hidden_states;
   hidden_states = ln_1.forward(ctx, hidden_states, 1e-6f);
-  hidden_states = attn.forward(ctx, hidden_states, KQ_pos);
-  hidden_states = ggml_add_inplace(gctx, hidden_states, residual);
+  hidden_states = attn.forward(ctx, hidden_states, KQ_pos, n_ctx);
+  hidden_states = tensor_assign_buffers(ggml_add_inplace(gctx, hidden_states, residual));
 
   residual = hidden_states;
   hidden_states = ln_2.forward(ctx, hidden_states, 1e-6f);
   hidden_states = mlp.forward(ctx, hidden_states);
-  hidden_states = ggml_add_inplace(gctx, hidden_states, residual);
+  hidden_states = tensor_assign_buffers(ggml_add_inplace(gctx, hidden_states, residual));
 
   return hidden_states;
 }
@@ -435,12 +462,12 @@ QwenModel::QwenModel(ModelContext *ctx, const QwenConfig &config)
   }
 }
 
-auto QwenModel::forward(ModelContext *ctx, ggml_tensor *input_ids, ggml_tensor *KQ_pos) const -> ggml_tensor * {
+auto QwenModel::forward(ModelContext *ctx, ggml_tensor *input_ids, ggml_tensor *KQ_pos, int n_ctx) const -> ggml_tensor * {
   ggml_context *gctx = ctx->ctx_b.get();
   ggml_tensor *hidden_states = wte.forward(ctx, input_ids);
   for (const auto &layer : layers) {
     ggml_set_scratch(gctx, ctx->scratch);
-    hidden_states = layer.forward(ctx, hidden_states, KQ_pos);
+    hidden_states = layer.forward(ctx, hidden_states, KQ_pos, n_ctx);
   }
   ggml_scratch empty_scratch = {0, 0, nullptr};
   ggml_set_scratch(gctx, empty_scratch);
@@ -455,17 +482,25 @@ auto get_num_physical_cores() -> int {
 }
 
 auto get_default_num_threads() -> int {
+#ifdef GGML_USE_CUBLAS
+  return 1;
+#else
   return std::min(get_num_physical_cores(), 16);
+#endif
 }
 
 QwenForCausalLM::QwenForCausalLM(const QwenConfig &config)
     : config(config) {
   ctx_.compute_buffer.resize(MEM_SIZE);
   ctx_.scratch_buffer.resize(SCRATCH_SIZE);
   ctx_.scratch = {0, ctx_.scratch_buffer.size(), ctx_.scratch_buffer.data()};
+#ifdef GGML_USE_CUBLAS
+  ggml_cuda_set_scratch_size(SCRATCH_SIZE);
+#endif
   constexpr size_t tensor_ovhd = GGML_TENSOR_SIZE + GGML_OBJECT_SIZE;
   const size_t ctx_w_size = (3 + config.num_hidden_layers * 8) * tensor_ovhd;
-  const size_t ctx_kv_size = 2 * config.num_hidden_layers * (config.max_length * config.hidden_size / config.num_attention_heads * config.num_kv_heads * ggml_type_size(GGML_TYPE_F16) + tensor_ovhd);
+  const size_t ctx_kv_size = 2 * config.num_hidden_layers *
+      (config.max_length * config.hidden_size / config.num_attention_heads * config.num_kv_heads * ggml_type_size(GGML_TYPE_F16) + tensor_ovhd);
   ctx_.dtype = config.dtype;
   ctx_.ctx_w = make_unique_ggml_context(ctx_w_size, nullptr, true);
   ctx_.ctx_kv = make_unique_ggml_context(ctx_kv_size + 1 * MB, nullptr, false);
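Two CUDA defaults land here: with cuBLAS enabled the graph runs on the GPU, so the default CPU worker count drops to a single thread, and ggml's device-side scratch allocator is sized to the same `SCRATCH_SIZE` as the host scratch buffer. The `ctx_kv_size` edit is only a line re-wrap.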
@@ -537,9 +572,15 @@ auto QwenForCausalLM::generate_next_token(
   for (int i = 0; i < curr_input_ids_size; ++i) {
     data[i] = n_past + i;
   }
+  if (KQ_pos) {
+    tensor_to_device(KQ_pos);
+  }
 
   ggml_tensor *lm_logits = forward(&ctx_, curr_input_ids, KQ_pos, n_ctx);
   lm_logits->backend = GGML_BACKEND_CPU;
+  if (KQ_pos) {
+    tensor_to_cpu(KQ_pos);
+  }
 
   ggml_build_forward_expand(&ctx_.gf, lm_logits);
   ggml_graph_compute_helper(ctx_.work_buffer, &ctx_.gf, n_threads);
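`KQ_pos` holds the rope positions and is rewritten on the host every step, so under cuBLAS it is uploaded before the graph is evaluated and its device copy released afterwards; on CPU-only builds both helpers are no-ops. Keeping `lm_logits->backend` pinned to `GGML_BACKEND_CPU` means the final logits stay readable from host memory for sampling.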
@@ -731,12 +772,12 @@ auto QwenForCausalLM::forward(
     ggml_tensor *KQ_pos,
     int n_ctx
 ) const -> ggml_tensor * {
-  ggml_tensor *transformer_outputs = transformer.forward(ctx, input_ids, KQ_pos);
+  ggml_tensor *transformer_outputs = transformer.forward(ctx, input_ids, KQ_pos, n_ctx);
   // NOTE: only compute next_token_logits for the last token
   if (input_ids->ne[0] > 1) {
-    transformer_outputs =
+    transformer_outputs = tensor_assign_buffers(
         ggml_view_1d(ctx->ctx_b.get(), transformer_outputs, config.hidden_size,
-                     (input_ids->ne[0] - 1) * config.hidden_size * ggml_element_size(transformer_outputs));
+                     (input_ids->ne[0] - 1) * config.hidden_size * ggml_element_size(transformer_outputs)));
   }
   ggml_tensor *lm_logits = lm_head.forward(ctx, transformer_outputs);
   return lm_logits;

qwen.h

Lines changed: 9 additions & 3 deletions
@@ -8,6 +8,10 @@
 #include <unordered_map>
 #include <vector>
 
+#ifdef GGML_USE_CUBLAS
+#include <ggml-cuda.h>
+#endif
+
 namespace qwen {
 
 class QwenTokenizer;
@@ -33,6 +37,8 @@ class LogMessageFatal {
   if (!(cond)) \
   QWEN_THROW << "check failed (" #cond ") "
 
+ggml_tensor *tensor_assign_buffers(ggml_tensor *tensor);
+
 auto tensor_to_device(ggml_tensor *tensor) -> ggml_tensor *;
 
 auto tensor_to_cpu(ggml_tensor *tensor) -> ggml_tensor *;
@@ -289,7 +295,7 @@ class QwenAttention {
   QwenAttention() : num_attention_heads(0), num_kv_heads(0) {}
   QwenAttention(ModelContext *ctx, int hidden_size, int num_attention_heads, int num_kv_heads, int max_length);
 
-  auto forward(ModelContext *ctx, ggml_tensor *hidden_states, ggml_tensor *KQ_pos) const -> ggml_tensor *;
+  auto forward(ModelContext *ctx, ggml_tensor *hidden_states, ggml_tensor *KQ_pos, int n_ctx) const -> ggml_tensor *;
 
   int num_attention_heads;
   int num_kv_heads;
@@ -323,7 +329,7 @@ class QwenBlock {
         ln_2(ctx, hidden_size, false),
         mlp(ctx, hidden_size, intermediate_size) {}
 
-  auto forward(ModelContext *ctx, ggml_tensor *hidden_states, ggml_tensor *KQ_pos) const -> ggml_tensor *;
+  auto forward(ModelContext *ctx, ggml_tensor *hidden_states, ggml_tensor *KQ_pos, int n_ctx) const -> ggml_tensor *;
 
   RMSNorm ln_1;
   QwenAttention attn;
@@ -336,7 +342,7 @@ class QwenModel {
   QwenModel() = default;
   QwenModel(ModelContext *ctx, const QwenConfig &config);
 
-  auto forward(ModelContext *ctx, ggml_tensor *input_ids, ggml_tensor *KQ_pos) const -> ggml_tensor *;
+  auto forward(ModelContext *ctx, ggml_tensor *input_ids, ggml_tensor *KQ_pos, int n_ctx) const -> ggml_tensor *;
 
   Embedding wte;
   std::vector<QwenBlock> layers;
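The header side mirrors the implementation: `ggml-cuda.h` is included only under `GGML_USE_CUBLAS`, the new `tensor_assign_buffers` helper is declared alongside the existing `tensor_to_device`/`tensor_to_cpu` movers, and the extra `n_ctx` parameter is threaded through the `forward` signatures of `QwenAttention`, `QwenBlock`, and `QwenModel`.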
