diff --git a/README.md b/README.md index ce678f0c36bd9..a56a60049c08d 100644 --- a/README.md +++ b/README.md @@ -10,6 +10,7 @@ Inference of Meta's [LLaMA](https://arxiv.org/abs/2302.13971) model (and others) ### Recent API changes +- [2024 Mar 26] Logits and embeddings API updated for compactness https://github.com/ggerganov/llama.cpp/pull/6122 - [2024 Mar 13] Add `llama_synchronize()` + `llama_context_params.n_ubatch` https://github.com/ggerganov/llama.cpp/pull/6017 - [2024 Mar 8] `llama_kv_cache_seq_rm()` returns a `bool` instead of `void`, and new `llama_n_seq_max()` returns the upper limit of acceptable `seq_id` in batches (relevant when dealing with multiple sequences) https://github.com/ggerganov/llama.cpp/pull/5328 - [2024 Mar 4] Embeddings API updated https://github.com/ggerganov/llama.cpp/pull/5796 @@ -630,6 +631,15 @@ Building the program with BLAS support may lead to some performance improvements - #### Vulkan +> [!WARNING] +> +> Vulkan support has been broken in https://github.com/ggerganov/llama.cpp/pull/6122 +> due to relying on `GGML_OP_GET_ROWS` which is not yet properly supported by the Vulkan backend, +> but should be fixed relatively soon (possibly in https://github.com/ggerganov/llama.cpp/pull/6155 +> (ref: https://github.com/ggerganov/llama.cpp/pull/6122#issuecomment-2015327635)). +> +> Meanwhile, if you want to use the Vulkan backend, you should use the commit right before the breaking change, https://github.com/ggerganov/llama.cpp/commit/55c1b2a3bbd470e9e2a3a0618b92cf64a885f806 + **With docker**: You don't need to install Vulkan SDK. It will be installed inside the container. diff --git a/examples/imatrix/imatrix.cpp b/examples/imatrix/imatrix.cpp index 264e73f4e66f9..12d34462b78ec 100644 --- a/examples/imatrix/imatrix.cpp +++ b/examples/imatrix/imatrix.cpp @@ -424,6 +424,7 @@ static bool compute_imatrix(llama_context * ctx, const gpt_params & params, bool tokens[batch_start] = llama_token_bos(llama_get_model(ctx)); } + // TODO: use batch.logits to save computations instead of relying on logits_all == true if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) { fprintf(stderr, "%s : failed to eval\n", __func__); return false; diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp index a2ef0fb039c3f..f66c91013eaeb 100644 --- a/examples/parallel/parallel.cpp +++ b/examples/parallel/parallel.cpp @@ -132,7 +132,6 @@ int main(int argc, char ** argv) { llama_context * ctx = NULL; // load the target model - params.logits_all = true; std::tie(model, ctx) = llama_init_from_gpt_params(params); // load the prompts from an external file if there are any diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index d766aef6ac1b1..c70385c62bb07 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -380,6 +380,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params & const int batch_size = std::min(end - batch_start, n_batch); //fprintf(stderr, " Batch %d: starts at %d, size is %d, n_past is %d\n",j,batch_start,batch_size,j * n_batch); + // TODO: use llama_batch.logits instead of relying on logits_all == true if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) { //fprintf(stderr, "%s : failed to eval\n", __func__); return {tokens, -1, logit_history, prob_history}; @@ -552,6 +553,8 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par const int batch_start = start + j * n_batch; const int batch_size = std::min(end - batch_start, n_batch); + int n_outputs = 0; + batch.n_tokens = 0; for (int seq = 0; seq < n_seq_batch; seq++) { int seq_start = batch_start + seq*n_ctx; @@ -566,11 +569,13 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par for (int k = 0; k < batch_size; ++k) { const int idx = seq*n_ctx + k; - batch.token[idx] = tokens[seq_start + k]; - batch.pos[idx] = j*n_batch + k; - batch.n_seq_id[idx] = 1; - batch.seq_id[idx][0] = seq; - batch.logits[idx] = batch.pos[idx] >= first ? 1 : 0; + batch.token [idx] = tokens[seq_start + k]; + batch.pos [idx] = j*n_batch + k; + batch.n_seq_id[idx] = 1; + batch.seq_id [idx][0] = seq; + batch.logits [idx] = batch.pos[idx] >= first ? 1 : 0; + + n_outputs += batch.logits[idx] != 0; } batch.n_tokens += batch_size; @@ -583,9 +588,9 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par return {tokens, -1, logit_history, prob_history}; } - if (num_batches > 1) { + if (num_batches > 1 && n_outputs > 0) { const auto * batch_logits = llama_get_logits(ctx); - logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab); + logits.insert(logits.end(), batch_logits, batch_logits + n_outputs * n_vocab); } } @@ -604,14 +609,15 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par } for (int seq = 0; seq < n_seq_batch; seq++) { - const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits_ith(ctx, seq*n_ctx); + const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits_ith(ctx, seq*n_ctx + first); + llama_token * tokens_data = tokens.data() + start + seq*n_ctx + first; if (!params.logits_file.empty()) { - process_logits(logits_stream, n_vocab, all_logits + first*n_vocab, + process_logits(logits_stream, n_vocab, all_logits, tokens_data, n_ctx - 1 - first, workers, log_probs, nll, nll2); } else { - process_logits(n_vocab, all_logits + first*n_vocab, + process_logits(n_vocab, all_logits, tokens_data, n_ctx - 1 - first, workers, nll, nll2, logit_history.data() + start + seq*n_ctx + first, @@ -652,6 +658,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par } static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector & batch_logits, int32_t n_batch, int32_t n_vocab) { + int prev_outputs = 0; for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) { const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i)); @@ -672,7 +679,14 @@ static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector< return false; } - memcpy(batch_logits.data() + i*n_vocab, llama_get_logits(ctx), n_tokens*n_vocab*sizeof(float)); + int n_outputs = 0; + for (int i = 0; i < n_tokens; ++i) { + n_outputs += batch_view.logits[i] != 0; + } + + memcpy(batch_logits.data() + prev_outputs*n_vocab, llama_get_logits(ctx), n_outputs*n_vocab*sizeof(float)); + + prev_outputs += n_outputs; } return true; @@ -779,7 +793,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) { size_t ending_logprob_count[4]; double ending_logprob[4]; - size_t i_batch; // starting index in the llama_batch + size_t i_logits; // starting index of logits in the llama_batch size_t common_prefix; // max number of initial tokens that are the same in all sentences size_t required_tokens; // needed number of tokens to evaluate all 4 endings std::vector seq_tokens[4]; @@ -844,9 +858,10 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) { const int max_tasks_per_batch = 32; const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_seq_max(ctx)); - llama_batch batch = llama_batch_init(n_ctx, 0, max_seq); + llama_batch batch = llama_batch_init(n_ctx, 0, 4); std::vector tok_logits(n_vocab); + // TODO: this could be made smaller; it's currently the worst-case size std::vector batch_logits(n_vocab*n_ctx); std::vector> eval_pairs; @@ -857,16 +872,17 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) { int n_cur = 0; size_t i1 = i0; - size_t i_batch = 0; // this tells us where in `llama_batch` we are currently + size_t i_logits = 0; // this tells us how many logits were needed before this point in the batch llama_batch_clear(batch); // batch as much tasks as possible into the available context - // each task has 4 unique seuqnce ids - one for each ending + // each task has 4 unique sequence ids - one for each ending // the common prefix is shared among the 4 sequences to save tokens // we extract logits only from the last common token and from all ending tokens of each sequence while (n_cur + (int) hs_data[i1].required_tokens <= n_ctx) { auto & hs_cur = hs_data[i1]; + int n_logits = 0; const int s0 = 4*(i1 - i0); if (s0 + 4 > max_seq) { @@ -874,18 +890,23 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) { } for (size_t i = 0; i < hs_cur.common_prefix; ++i) { - llama_batch_add(batch, hs_cur.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3}, false); + llama_batch_add(batch, hs_cur.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3 }, false); } batch.logits[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix + n_logits += 1; for (int s = 0; s < 4; ++s) { - for (size_t i = hs_cur.common_prefix; i < hs_cur.seq_tokens[s].size(); ++i) { - llama_batch_add(batch, hs_cur.seq_tokens[s][i], i, { s0 + s }, true); + const size_t seq_tokens_size = hs_cur.seq_tokens[s].size(); + // TODO: don't evaluate the last token of each sequence + for (size_t i = hs_cur.common_prefix; i < seq_tokens_size; ++i) { + const bool needs_logits = i < seq_tokens_size - 1; + llama_batch_add(batch, hs_cur.seq_tokens[s][i], i, { s0 + s }, needs_logits); + n_logits += needs_logits; } } - hs_cur.i_batch = i_batch; - i_batch += hs_cur.required_tokens; + hs_cur.i_logits = i_logits; + i_logits += n_logits; n_cur += hs_data[i1].required_tokens; if (++i1 == hs_task_count) { @@ -911,12 +932,11 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) { eval_pairs.clear(); for (size_t i = i0; i < i1; ++i) { auto & hs_cur = hs_data[i]; - size_t li = hs_cur.common_prefix; + size_t li = 1; // skip the last logit of the common prefix (computed separately below) for (int s = 0; s < 4; ++s) { for (size_t j = hs_cur.common_prefix; j < hs_cur.seq_tokens[s].size() - 1; j++) { - eval_pairs.emplace_back(hs_cur.i_batch + li++, hs_cur.seq_tokens[s][j + 1]); + eval_pairs.emplace_back(hs_cur.i_logits + li++, hs_cur.seq_tokens[s][j + 1]); } - ++li; } } // Then we do the actual calculation @@ -928,7 +948,8 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) { for (size_t i = i0; i < i1; ++i) { auto & hs_cur = hs_data[i]; - std::memcpy(tok_logits.data(), batch_logits.data() + n_vocab*(hs_cur.i_batch + hs_cur.common_prefix - 1), n_vocab*sizeof(float)); + // get the logits of the last token of the common prefix + std::memcpy(tok_logits.data(), batch_logits.data() + n_vocab*hs_cur.i_logits, n_vocab*sizeof(float)); const auto first_probs = softmax(tok_logits); @@ -978,7 +999,7 @@ struct winogrande_entry { std::array choices; int answer; - size_t i_batch; + size_t i_logits; size_t common_prefix; size_t required_tokens; size_t n_base1; // number of tokens for context + choice 1 @@ -1104,6 +1125,7 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) { task.common_prefix++; } + // TODO: the last token of each of the sequences don't need to be evaluated task.required_tokens = task.common_prefix + task.seq_tokens[0].size() - task.common_prefix + task.seq_tokens[1].size() - task.common_prefix; @@ -1121,9 +1143,10 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) { const int max_tasks_per_batch = 128; const int max_seq = std::min(2*max_tasks_per_batch, (int) llama_n_seq_max(ctx)); - llama_batch batch = llama_batch_init(n_ctx, 0, max_seq); + llama_batch batch = llama_batch_init(n_ctx, 0, 2); std::vector tok_logits(n_vocab); + // TODO: this could be made smaller; it's currently the worst-case size std::vector batch_logits(n_vocab*n_ctx); std::vector> eval_pairs; @@ -1137,29 +1160,33 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) { int n_cur = 0; size_t i1 = i0; - size_t i_batch = 0; + size_t i_logits = 0; llama_batch_clear(batch); while (n_cur + (int) data[i1].required_tokens <= n_ctx) { + int n_logits = 0; const int s0 = 2*(i1 - i0); if (s0 + 2 > max_seq) { break; } for (size_t i = 0; i < data[i1].common_prefix; ++i) { - llama_batch_add(batch, data[i1].seq_tokens[0][i], i, { s0 + 0, s0 + 1}, false); + llama_batch_add(batch, data[i1].seq_tokens[0][i], i, { s0 + 0, s0 + 1 }, false); } batch.logits[batch.n_tokens - 1] = true; + n_logits += 1; for (int s = 0; s < 2; ++s) { + // TODO: end before the last token, no need to predict past the end of the sequences for (size_t i = data[i1].common_prefix; i < data[i1].seq_tokens[s].size(); ++i) { llama_batch_add(batch, data[i1].seq_tokens[s][i], i, { s0 + s }, true); + n_logits += 1; } } - data[i1].i_batch = i_batch; - i_batch += data[i1].required_tokens; + data[i1].i_logits = i_logits; + i_logits += n_logits; n_cur += data[i1].required_tokens; if (++i1 == data.size()) { @@ -1190,15 +1217,16 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) { const auto& n_base1 = skip_choice ? task.n_base1 : task.common_prefix; const int last_1st = task.seq_tokens[0].size() - n_base1 > 1 ? 1 : 0; - size_t li = n_base1 - 1; + size_t li = n_base1 - task.common_prefix; for (size_t j = n_base1-1; j < task.seq_tokens[0].size()-1-last_1st; ++j) { - eval_pairs.emplace_back(task.i_batch + li++, task.seq_tokens[0][j+1]); + eval_pairs.emplace_back(task.i_logits + li++, task.seq_tokens[0][j+1]); } const auto& n_base2 = skip_choice ? task.n_base2 : task.common_prefix; const int last_2nd = task.seq_tokens[1].size() - n_base2 > 1 ? 1 : 0; - li = task.seq_tokens[0].size() - task.common_prefix + n_base2 - 1; + // FIXME: this uses the wrong first logits when not skipping the choice word + li = task.seq_tokens[0].size() - task.common_prefix + n_base2 - task.common_prefix; for (size_t j = n_base2-1; j < task.seq_tokens[1].size()-1-last_2nd; ++j) { - eval_pairs.emplace_back(task.i_batch + li++, task.seq_tokens[1][j+1]); + eval_pairs.emplace_back(task.i_logits + li++, task.seq_tokens[1][j+1]); } } compute_logprobs(batch_logits.data(), n_vocab, workers, eval_pairs, eval_results); @@ -1287,7 +1315,7 @@ struct multiple_choice_task { } // For evaluation - size_t i_batch; // starting index in the llama_batch + size_t i_logits; // starting index of logits in the llama_batch size_t common_prefix; // max number of initial tokens that are the same in all sentences size_t required_tokens; // needed number of tokens to evaluate all answers std::vector> seq_tokens; @@ -1366,7 +1394,7 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params std::vector task_pos(n_task); strstream.read((char *)task_pos.data(), task_pos.size()*sizeof(uint32_t)); if (strstream.fail()) { - printf("%s: failed to raad task positions from prompt\n", __func__); + printf("%s: failed to read task positions from prompt\n", __func__); return; } @@ -1447,7 +1475,7 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params return; } } else { - int n_dot = n_task/100; + int n_dot = std::max((int) n_task/100, 1); int i_task = 0; for (auto& task : tasks) { ++i_task; @@ -1491,17 +1519,18 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params int n_cur = 0; size_t i1 = i0; - size_t i_batch = 0; // this tells us where in `llama_batch` we are currently + size_t i_logits = 0; // this tells us how many logits were needed before this point in the batch llama_batch_clear(batch); // batch as much tasks as possible into the available context - // each task has 4 unique seuqnce ids - one for each ending + // each task has 4 unique sequence ids - one for each ending // the common prefix is shared among the 4 sequences to save tokens // we extract logits only from the last common token and from all ending tokens of each sequence int s0 = 0; while (n_cur + (int) tasks[i1].required_tokens <= n_ctx) { auto& cur_task = tasks[i1]; + int n_logits = 0; int num_answers = cur_task.seq_tokens.size(); if (s0 + num_answers > max_seq) { @@ -1518,17 +1547,22 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params llama_batch_add(batch, cur_task.seq_tokens[0][i], i, batch_indeces, false); } batch.logits[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix + n_logits += 1; for (int s = 0; s < int(cur_task.seq_tokens.size()); ++s) { - for (size_t i = cur_task.common_prefix; i < cur_task.seq_tokens[s].size(); ++i) { - llama_batch_add(batch, cur_task.seq_tokens[s][i], i, { s0 + s }, true); + const size_t seq_tokens_size = cur_task.seq_tokens[s].size(); + // TODO: don't evaluate the last token of each sequence + for (size_t i = cur_task.common_prefix; i < seq_tokens_size; ++i) { + const bool needs_logits = i < seq_tokens_size - 1; + llama_batch_add(batch, cur_task.seq_tokens[s][i], i, { s0 + s }, needs_logits); + n_logits += needs_logits; } } s0 += num_answers; - cur_task.i_batch = i_batch; - i_batch += cur_task.required_tokens; + cur_task.i_logits = i_logits; + i_logits += n_logits; n_cur += cur_task.required_tokens; if (++i1 == tasks.size()) { @@ -1554,12 +1588,11 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params eval_pairs.clear(); for (size_t i = i0; i < i1; ++i) { auto& cur_task = tasks[i]; - size_t li = cur_task.common_prefix; + size_t li = 1; // skip the last logit of the common prefix (computed separately below) for (int s = 0; s < int(cur_task.seq_tokens.size()); ++s) { for (size_t j = cur_task.common_prefix; j < cur_task.seq_tokens[s].size() - 1; j++) { - eval_pairs.emplace_back(cur_task.i_batch + li++, cur_task.seq_tokens[s][j + 1]); + eval_pairs.emplace_back(cur_task.i_logits + li++, cur_task.seq_tokens[s][j + 1]); } - ++li; } } // Then we do the actual calculation @@ -1578,7 +1611,8 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params //} //printf("\n common_prefix: %zu\n", cur_task.common_prefix); - std::memcpy(tok_logits.data(), batch_logits.data() + n_vocab*(cur_task.i_batch + cur_task.common_prefix - 1), n_vocab*sizeof(float)); + // get the logits of the last token of the common prefix + std::memcpy(tok_logits.data(), batch_logits.data() + n_vocab*cur_task.i_logits, n_vocab*sizeof(float)); const auto first_probs = softmax(tok_logits); @@ -1730,6 +1764,7 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) { tokens[batch_start] = llama_token_bos(llama_get_model(ctx)); } + // TODO: use llama_batch.logits instead of relying on logits_all == true if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) { fprintf(stderr, "%s : failed to eval\n", __func__); return; diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 526de596e34c0..53ad9239efb99 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -747,7 +747,8 @@ struct server_context { { const int32_t n_batch = llama_n_batch(ctx); - batch = llama_batch_init(n_batch, 0, params.n_parallel); + // only a single seq_id per token is needed + batch = llama_batch_init(n_batch, 0, 1); } metrics.init(); diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index 8b31b678a6849..6e0815b369986 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -65,7 +65,6 @@ int main(int argc, char ** argv) { llama_context * ctx_dft = NULL; // load the target model - params.logits_all = true; std::tie(model_tgt, ctx_tgt) = llama_init_from_gpt_params(params); // load the draft model diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 48232b6e18d6c..be8e33a56c40f 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -2505,7 +2505,7 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t for (int i = 0; i < cgraph->n_nodes; i++) { ggml_tensor * node = cgraph->nodes[i]; - if (node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) { + if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) { continue; } diff --git a/ggml-kompute.cpp b/ggml-kompute.cpp index 81dd5067864ce..407062e6fd476 100644 --- a/ggml-kompute.cpp +++ b/ggml-kompute.cpp @@ -1430,6 +1430,10 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml struct ggml_tensor * dst = gf->nodes[i]; GGML_ASSERT(dst->data != nullptr); + if (ggml_is_empty(dst)) { + continue; + } + switch (dst->op) { case GGML_OP_NONE: case GGML_OP_RESHAPE: diff --git a/ggml-metal.m b/ggml-metal.m index cbe22aa3792b4..a08abbc291802 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -847,6 +847,10 @@ static enum ggml_status ggml_metal_graph_compute( struct ggml_tensor * src2 = gf->nodes[i]->src[2]; struct ggml_tensor * dst = gf->nodes[i]; + if (ggml_is_empty(dst)) { + continue; + } + switch (dst->op) { case GGML_OP_NONE: case GGML_OP_RESHAPE: diff --git a/ggml-opencl.cpp b/ggml-opencl.cpp index aa73d67df84b0..b3f8b7eaf0a3b 100644 --- a/ggml-opencl.cpp +++ b/ggml-opencl.cpp @@ -2234,6 +2234,11 @@ static ggml_backend_buffer_type_t ggml_backend_opencl_get_default_buffer_type(gg static ggml_status ggml_backend_opencl_graph_compute(ggml_backend_t backend, ggml_cgraph * graph) { for (int i = 0; i < graph->n_nodes; ++i) { ggml_tensor * node = graph->nodes[i]; + + if (ggml_is_empty(node)) { + continue; + } + switch (node->op) { case GGML_OP_MUL_MAT: ggml_cl_mul_mat(node->src[0], node->src[1], node, nullptr, 0); diff --git a/ggml-sycl.cpp b/ggml-sycl.cpp index fc4d2964ccac9..789ba97bfba39 100644 --- a/ggml-sycl.cpp +++ b/ggml-sycl.cpp @@ -16973,7 +16973,7 @@ GGML_CALL static ggml_status ggml_backend_sycl_graph_compute(ggml_backend_t back params.ith = 0; for (int i = 0; i < cgraph->n_nodes; i++) { ggml_tensor * node = cgraph->nodes[i]; - if (node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) { + if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) { continue; } #ifndef NDEBUG diff --git a/ggml-vulkan.cpp b/ggml-vulkan.cpp index cbceaa19fbacd..521a1314b3565 100644 --- a/ggml-vulkan.cpp +++ b/ggml-vulkan.cpp @@ -5566,7 +5566,7 @@ GGML_CALL static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backen for (int i = 0; i < cgraph->n_nodes; i++) { ggml_tensor * node = cgraph->nodes[i]; - if (node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) { + if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) { continue; } diff --git a/ggml.c b/ggml.c index a86b41c158558..eb469d0f7953d 100644 --- a/ggml.c +++ b/ggml.c @@ -2607,6 +2607,16 @@ static inline bool ggml_is_padded_1d(const struct ggml_tensor * tensor) { tensor->nb[3] == tensor->nb[2]*tensor->ne[2]; } +GGML_CALL bool ggml_is_empty(const struct ggml_tensor * tensor) { + for (int i = 0; i < GGML_MAX_DIMS; ++i) { + if (tensor->ne[i] == 0) { + // empty if any dimension has no elements + return true; + } + } + return false; +} + bool ggml_are_same_shape(const struct ggml_tensor * t0, const struct ggml_tensor * t1) { static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); @@ -2621,7 +2631,7 @@ bool ggml_are_same_shape(const struct ggml_tensor * t0, const struct ggml_tensor static inline bool ggml_can_repeat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) { static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); - return + return ggml_is_empty(t0) ? ggml_is_empty(t1) : (t1->ne[0]%t0->ne[0] == 0) && (t1->ne[1]%t0->ne[1] == 0) && (t1->ne[2]%t0->ne[2] == 0) && @@ -16114,7 +16124,7 @@ static void ggml_compute_forward_cross_entropy_loss_back( static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) { GGML_ASSERT(params); - if (tensor->op == GGML_OP_NONE) { + if (tensor->op == GGML_OP_NONE || ggml_is_empty(tensor)) { return; } @@ -17983,6 +17993,12 @@ static void ggml_graph_compute_perf_stats_node(struct ggml_tensor * node, const static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads, int n_cur_threads) { int n_tasks = 0; + if (ggml_is_empty(node)) { + // no need to multi-thread a no-op + n_tasks = 1; + return n_tasks; + } + switch (node->op) { case GGML_OP_CPY: case GGML_OP_DUP: diff --git a/ggml.h b/ggml.h index 425c9b6ab2d6d..5d4a4ceb65c7e 100644 --- a/ggml.h +++ b/ggml.h @@ -750,6 +750,7 @@ extern "C" { GGML_API GGML_CALL bool ggml_is_transposed(const struct ggml_tensor * tensor); GGML_API GGML_CALL bool ggml_is_contiguous(const struct ggml_tensor * tensor); GGML_API GGML_CALL bool ggml_is_permuted (const struct ggml_tensor * tensor); + GGML_API GGML_CALL bool ggml_is_empty (const struct ggml_tensor * tensor); GGML_API bool ggml_is_scalar (const struct ggml_tensor * tensor); GGML_API bool ggml_is_vector (const struct ggml_tensor * tensor); GGML_API bool ggml_is_matrix (const struct ggml_tensor * tensor); diff --git a/llama.cpp b/llama.cpp index 68c360c7d8036..22db79d6341c0 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1777,6 +1777,7 @@ struct llama_cparams { uint32_t n_ctx; // context size used during inference uint32_t n_batch; uint32_t n_ubatch; + uint32_t n_seq_max; uint32_t n_threads; // number of threads to use for generation uint32_t n_threads_batch; // number of threads to use for batch processing @@ -2139,20 +2140,20 @@ struct llama_context { // host buffer for the model output (logits and embeddings) ggml_backend_buffer_t buf_output = nullptr; - // decode output (2-dimensional array: [n_tokens][n_vocab]) - size_t logits_size = 0; - float * logits = nullptr; + // decode output (2-dimensional array: [n_outputs][n_vocab]) + size_t logits_size = 0; // capacity (of floats) for logits + float * logits = nullptr; + + std::vector output_ids; // map batch token positions to ids of the logits and embd buffers + size_t output_size = 0; // capacity (of tokens positions) for the output buffers + int32_t n_outputs = 0; // number of actually-used outputs in the current ubatch -#ifndef NDEBUG - // guard against access to unset logits - std::vector logits_valid; -#endif bool logits_all = false; - // embeddings output (2-dimensional array: [n_tokens][n_embd]) + // embeddings output (2-dimensional array: [n_outputs][n_embd]) // populated only when pooling_type == LLAMA_POOLING_TYPE_NONE - size_t embd_size = 0; - float * embd = nullptr; + size_t embd_size = 0; // capacity (of floats) for embeddings + float * embd = nullptr; // sequence embeddings output (map of [n_embd] vectors) // populated only when pooling_type != LLAMA_POOLING_TYPE_NONE @@ -2169,14 +2170,15 @@ struct llama_context { struct ggml_tensor * inp_tokens; // I32 [n_batch] struct ggml_tensor * inp_embd; // F32 [n_embd, n_batch] struct ggml_tensor * inp_pos; // I32 [n_batch] + struct ggml_tensor * inp_out_ids; // I32 [n_outputs] struct ggml_tensor * inp_KQ_mask; // F32 [kv_size, n_batch] - struct ggml_tensor * inp_KQ_pos; // F32 [kv_size] + struct ggml_tensor * inp_KQ_pos; // F32 [n_kv] struct ggml_tensor * inp_K_shift; // I32 [kv_size] struct ggml_tensor * inp_mean; // F32 [n_batch, n_batch] struct ggml_tensor * inp_cls; // I32 [n_batch] struct ggml_tensor * inp_s_copy; // I32 [kv_size] - struct ggml_tensor * inp_s_mask; // F32 [1, kv_size] - struct ggml_tensor * inp_s_seq; // I32 [kv_size, n_batch] + struct ggml_tensor * inp_s_mask; // F32 [1, n_kv] + struct ggml_tensor * inp_s_seq; // I32 [n_kv, n_batch] // control vectors struct llama_control_vector cvec; @@ -5846,7 +5848,8 @@ struct llm_build_context { const float norm_rms_eps; const int32_t n_tokens; - const int32_t n_kv; // size of KV cache to consider (n_kv <= n_ctx) + const int32_t n_kv; // size of KV cache to consider (n_kv <= kv_self.size) + const int32_t n_outputs; const int32_t kv_head; // index of where we store new KV data in the cache const int32_t n_orig_ctx; @@ -5893,6 +5896,7 @@ struct llm_build_context { norm_rms_eps (hparams.f_norm_rms_eps), n_tokens (batch.n_tokens), n_kv (worst_case ? kv_self.size : kv_self.n), + n_outputs (worst_case ? n_tokens : lctx.n_outputs), kv_head (worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head), n_orig_ctx (cparams.n_yarn_orig_ctx), pooling_type (cparams.pooling_type), @@ -5914,6 +5918,7 @@ struct llm_build_context { lctx.inp_tokens = nullptr; lctx.inp_embd = nullptr; lctx.inp_pos = nullptr; + lctx.inp_out_ids = nullptr; lctx.inp_KQ_mask = nullptr; lctx.inp_KQ_pos = nullptr; lctx.inp_K_shift = nullptr; @@ -6037,6 +6042,13 @@ struct llm_build_context { return lctx.inp_pos; } + struct ggml_tensor * build_inp_out_ids() { + lctx.inp_out_ids = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_outputs); + cb(lctx.inp_out_ids, "inp_out_ids", -1); + ggml_set_input(lctx.inp_out_ids); + return lctx.inp_out_ids; + } + struct ggml_tensor * build_inp_KQ_mask(bool causal = true) { if (causal) { lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, n_tokens); @@ -6093,6 +6105,9 @@ struct llm_build_context { struct ggml_cgraph * build_llama() { struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); + // mutable variable, needed during the last layer of the computation to skip unused tokens + int32_t n_tokens = this->n_tokens; + const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); GGML_ASSERT(n_embd_head == hparams.n_rot); @@ -6160,6 +6175,14 @@ struct llm_build_context { Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + n_tokens = n_outputs; + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); cb(ffn_inp, "ffn_inp", il); @@ -6339,6 +6362,13 @@ struct llm_build_context { Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); cb(ffn_inp, "ffn_inp", il); @@ -6454,6 +6484,14 @@ struct llm_build_context { Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + attn_norm = ggml_get_rows(ctx0, attn_norm, inp_out_ids); + } + struct ggml_tensor * ffn_inp = cur; // feed forward @@ -6497,6 +6535,9 @@ struct llm_build_context { struct ggml_cgraph * build_grok() { struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); + // mutable variable, needed during the last layer of the computation to skip unused tokens + int32_t n_tokens = this->n_tokens; + const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); GGML_ASSERT(n_embd_head == hparams.n_rot); @@ -6568,6 +6609,14 @@ struct llm_build_context { Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f, cb, il); } + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + n_tokens = n_outputs; + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + // Grok // if attn_out_norm is present then apply it before adding the input if (model.layers[il].attn_out_norm) { @@ -6745,6 +6794,13 @@ struct llm_build_context { Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + } + // add the input struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL); cb(ffn_inp, "ffn_inp", il); @@ -6942,6 +6998,13 @@ struct llm_build_context { Kcur, Vcur, Q, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + residual = ggml_get_rows(ctx0, residual, inp_out_ids); + } + struct ggml_tensor * ffn_inp = ggml_add(ctx0, residual, cur); cb(ffn_inp, "ffn_inp", il); @@ -7031,6 +7094,13 @@ struct llm_build_context { Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); cb(ffn_inp, "ffn_inp", il); @@ -7188,6 +7258,13 @@ struct llm_build_context { } cb(cur, "kqv_out", il); + if (il == n_layer - 1 && pooling_type == LLAMA_POOLING_TYPE_NONE) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + } + // re-add the layer input cur = ggml_add(ctx0, cur, inpL); @@ -7310,6 +7387,13 @@ struct llm_build_context { Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + } + // Add the input struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL); cb(ffn_inp, "ffn_inp", il); @@ -7408,6 +7492,13 @@ struct llm_build_context { Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + } + // Add the input struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL); cb(ffn_inp, "ffn_inp", il); @@ -7521,6 +7612,13 @@ struct llm_build_context { Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); cb(ffn_inp, "ffn_inp", il); @@ -7627,6 +7725,13 @@ struct llm_build_context { Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); cb(ffn_inp, "ffn_inp", il); @@ -7739,6 +7844,13 @@ struct llm_build_context { Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); cb(ffn_inp, "ffn_inp", il); @@ -7857,6 +7969,14 @@ struct llm_build_context { Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f, cb, il); } + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + attn_norm_output = ggml_get_rows(ctx0, attn_norm_output, inp_out_ids); + } + // FF { ffn_output = llm_build_ffn(ctx0, attn_norm_output, @@ -7954,6 +8074,14 @@ struct llm_build_context { cur = attention_norm; + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + sa_out = ggml_get_rows(ctx0, sa_out, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + } + // feed-forward network { cur = llm_build_ffn(ctx0, cur, @@ -8046,6 +8174,13 @@ struct llm_build_context { Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + } + // add the input struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL); cb(ffn_inp, "ffn_inp", il); @@ -8146,6 +8281,13 @@ struct llm_build_context { Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + } + // add the input struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL); cb(ffn_inp, "ffn_inp", il); @@ -8255,6 +8397,13 @@ struct llm_build_context { Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); cb(ffn_inp, "ffn_inp", il); @@ -8365,6 +8514,13 @@ struct llm_build_context { Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); cb(ffn_inp, "ffn_inp", il); @@ -8488,6 +8644,13 @@ struct llm_build_context { Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + // scale_res - scale the hidden states for residual connection const float scale_res = scale_depth/sqrtf(float(n_layer)); cur = ggml_scale(ctx0, cur, scale_res); @@ -8602,6 +8765,13 @@ struct llm_build_context { Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f, cb, il); } + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + } + struct ggml_tensor * sa_out = ggml_add(ctx0, cur, inpL); cb(sa_out, "sa_out", il); @@ -8714,6 +8884,13 @@ struct llm_build_context { Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); cb(ffn_inp, "ffn_inp", il); @@ -8861,6 +9038,15 @@ struct llm_build_context { struct ggml_tensor * y = ggml_view_2d(ctx0, y_ssm_states, d_inner, n_tokens, d_inner*ggml_element_size(y_ssm_states), 0); + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + x = ggml_get_rows(ctx0, x, inp_out_ids); + y = ggml_get_rows(ctx0, y, inp_out_ids); + z = ggml_get_rows(ctx0, z, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + } + // {d_inner, n_tokens} * {d_inner} => {d_inner, n_tokens} y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d)); y = ggml_mul(ctx0, y, ggml_silu(ctx0, z)); @@ -8963,6 +9149,13 @@ struct llm_build_context { Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + } + struct ggml_tensor * attn_out = cur; // feed-forward network @@ -9260,9 +9453,39 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos)); } + if (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) { + GGML_ASSERT(lctx.inp_out_ids && "every model that can must skip unused outputs"); + const int64_t n_tokens = batch.n_tokens; + + GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_out_ids->buffer)); + int32_t * data = (int32_t *) lctx.inp_out_ids->data; + + if (lctx.n_outputs == n_tokens) { + for (int i = 0; i < n_tokens; ++i) { + data[i] = i; + } + } else if (batch.logits) { + int32_t n_outputs = 0; + for (int i = 0; i < n_tokens; ++i) { + if (batch.logits[i]) { + data[n_outputs++] = i; + } + } + // the graph needs to have been passed the correct number of outputs + GGML_ASSERT(lctx.n_outputs == n_outputs); + } else if (lctx.n_outputs == 1) { + // only keep last output + data[0] = n_tokens - 1; + } else { + GGML_ASSERT(lctx.n_outputs == 0); + } + } + GGML_ASSERT( + // (!a || b) is a logical implication (a -> b) + // !hparams.causal_attn -> !cparams.causal_attn (hparams.causal_attn || !cparams.causal_attn) && - "non-causal attention with generative models is not supported" + "causal attention with embedding models is not supported" ); if (lctx.inp_KQ_mask) { @@ -9441,6 +9664,74 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { } } +// Make sure enough space is available for outputs. +// Returns max number of outputs for which space was reserved. +static size_t llama_output_reserve(llama_context & lctx, size_t n_outputs) { + const auto & cparams = lctx.cparams; + const auto & hparams = lctx.model.hparams; + + const size_t n_outputs_max = std::max(n_outputs, (size_t) cparams.n_seq_max); + + const auto n_batch = cparams.n_batch; + const auto n_vocab = hparams.n_vocab; + const auto n_embd = hparams.n_embd; + + // TODO: use a per-batch flag for logits presence instead + const bool has_logits = cparams.causal_attn; + const bool has_embd = cparams.embeddings && (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE); + + const size_t logits_size = has_logits ? n_vocab*n_outputs_max : 0; + const size_t embd_size = has_embd ? n_embd*n_outputs_max : 0; + + if (lctx.output_ids.empty()) { + // init, never resized afterwards + lctx.output_ids.resize(n_batch); + } + + const size_t prev_size = lctx.buf_output ? ggml_backend_buffer_get_size(lctx.buf_output) : 0; + const size_t new_size = (logits_size + embd_size) * sizeof(float); + + // alloc only when more than the current capacity is required + // TODO: also consider shrinking the buffer + if (!lctx.buf_output || prev_size < new_size) { + if (lctx.buf_output) { +#ifndef NDEBUG + // This doesn't happen often, but may be annoying in some cases (like the HellaSwag benchmark) + LLAMA_LOG_INFO("%s: reallocating output buffer from size %.02f MiB to %.02f MiB\n", __func__, prev_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0); +#endif + ggml_backend_buffer_free(lctx.buf_output); + lctx.buf_output = nullptr; + lctx.logits = nullptr; + lctx.embd = nullptr; + } + + lctx.buf_output = ggml_backend_buft_alloc_buffer(llama_default_buffer_type_cpu(true), new_size); + if (lctx.buf_output == nullptr) { + LLAMA_LOG_ERROR("%s: failed to allocate output buffer of size %.2f MiB\n", __func__, new_size / (1024.0 * 1024.0)); + return 0; + } + } + + float * output_base = (float *) ggml_backend_buffer_get_base(lctx.buf_output); + + lctx.logits = has_logits ? output_base : nullptr; + lctx.embd = has_embd ? output_base + logits_size : nullptr; + + lctx.output_size = n_outputs_max; + lctx.logits_size = logits_size; + lctx.embd_size = embd_size; + + // set all ids as invalid (negative) + std::fill(lctx.output_ids.begin(), lctx.output_ids.end(), -1); + + ggml_backend_buffer_clear(lctx.buf_output, 0); + + lctx.n_outputs = 0; + + return n_outputs_max; +} + + static void llama_graph_compute( llama_context & lctx, ggml_cgraph * gf, @@ -9516,16 +9807,8 @@ static int llama_decode_internal( const int64_t n_embd = hparams.n_embd; const int64_t n_vocab = hparams.n_vocab; - - auto * logits_out = lctx.logits; - -#ifndef NDEBUG - auto & logits_valid = lctx.logits_valid; - logits_valid.clear(); - logits_valid.resize(n_tokens_all); - - memset(logits_out, 0, lctx.logits_size*sizeof(float)); -#endif + uint32_t n_outputs = 0; + uint32_t n_outputs_prev = 0; const auto n_ubatch = cparams.n_ubatch; @@ -9534,6 +9817,38 @@ static int llama_decode_internal( std::vector seq_id_arr; std::vector> seq_id; + // count outputs + if (batch_all.logits) { + for (uint32_t i = 0; i < n_tokens_all; ++i) { + n_outputs += batch_all.logits[i] != 0; + } + } else if (lctx.logits_all || (cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE)) { + n_outputs = n_tokens_all; + } else { + // keep last output only + n_outputs = 1; + } + + // reserve output buffer + if (llama_output_reserve(lctx, n_outputs) < n_outputs) { + LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_outputs); + return -2; + }; + + // set output mappings + if (batch_all.logits) { + int32_t i_logits = 0; + for (uint32_t i = 0; i < n_tokens_all; ++i) { + if (batch_all.logits[i]) { + lctx.output_ids[i] = i_logits++; + } + } + } else { + for (uint32_t i = 0; i < n_outputs; ++i) { + lctx.output_ids[i] = i; + } + } + for (uint32_t cur_token = 0; cur_token < n_tokens_all; cur_token += n_ubatch) { const uint32_t n_tokens = std::min(n_ubatch, n_tokens_all - cur_token); llama_batch u_batch = { @@ -9549,6 +9864,27 @@ static int llama_decode_internal( /* .all_seq_id = */ batch_all.all_seq_id, }; + // count the outputs in this u_batch + { + int32_t n_outputs_new = 0; + + if (u_batch.logits) { + for (uint32_t i = 0; i < n_tokens; i++) { + n_outputs_new += u_batch.logits[i] != 0; + } + } else if (n_outputs == n_tokens_all) { + n_outputs_new = n_tokens; + } else { + // keep last output only + if (cur_token + n_tokens >= n_tokens_all) { + n_outputs_new = 1; + } + } + + // needs to happen before the graph is built + lctx.n_outputs = n_outputs_new; + } + int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch; GGML_ASSERT(n_threads > 0); @@ -9612,23 +9948,37 @@ static int llama_decode_internal( struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1]; struct ggml_tensor * embd = gf->nodes[gf->n_nodes - 2]; - if (!hparams.causal_attn) { + if (lctx.n_outputs == 0) { + // no output + res = nullptr; + embd = nullptr; + } else if (!hparams.causal_attn) { res = nullptr; // do not extract logits for embedding models such as BERT // token or sequence embeddings embd = gf->nodes[gf->n_nodes - 1]; GGML_ASSERT(strcmp(embd->name, "result_embd") == 0 || strcmp(embd->name, "result_embd_pooled") == 0); - } else { - if (strcmp(res->name, "result_output") == 0) { - // the token embeddings could be the second to last tensor, or the third to last tensor - if (strcmp(embd->name, "result_norm") != 0) { - embd = gf->nodes[gf->n_nodes - 3]; - GGML_ASSERT(strcmp(embd->name, "result_norm") == 0); - } - } else { - GGML_ASSERT(false && "missing result_output tensor"); + } else if (cparams.embeddings) { + // the embeddings could be in the second to last tensor, or any of the previous tensors + int i_embd = gf->n_nodes - 2; + for (int i = 3; strcmp(embd->name, "result_norm") != 0; ++i) { + i_embd = gf->n_nodes - i; + if (i_embd < 0) { break; } + embd = gf->nodes[i_embd]; + } + GGML_ASSERT(i_embd >= 0 && "missing result_norm tensor"); + + // TODO: use a per-batch flag to know when to skip logits while keeping embeddings + if (!cparams.causal_attn) { + res = nullptr; // do not extract logits when not needed + // skip computing logits + // TODO: is this safe? + gf->n_nodes = i_embd + 1; } + } else { + embd = nullptr; // do not extract embeddings when not needed + GGML_ASSERT(strcmp(res->name, "result_output") == 0 && "missing result_output tensor"); } // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs); @@ -9671,50 +10021,23 @@ static int llama_decode_internal( //} // extract logits - // TODO: do not compute and extract logits if only embeddings are needed - // update the graphs to skip "result_output" if logits are not needed if (res) { ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(lctx.sched, res); GGML_ASSERT(backend_res != nullptr); - if (u_batch.logits) { - int32_t i_first = -1; - for (uint32_t i = 0; i < n_tokens; i++) { - if (u_batch.logits[i] && i_first == -1) { - i_first = (int32_t) i; - } - if (u_batch.logits[i] == 0 || i == n_tokens - 1) { - if (i_first != -1) { - int i_last = u_batch.logits[i] == 0 ? i : i + 1; - // extract logits for the range [i_first, i_last) - // group the requests to minimize the number of calls to the backend - ggml_backend_tensor_get_async(backend_res, res, - logits_out + n_vocab*(cur_token + i_first), - i_first*n_vocab*sizeof(float), - (i_last - i_first)*n_vocab*sizeof(float)); - i_first = -1; - } - } -#ifndef NDEBUG - logits_valid[cur_token + i] = u_batch.logits[i] != 0;; -#endif - } - } else if (lctx.logits_all) { - ggml_backend_tensor_get_async(backend_res, res, logits_out + n_vocab*cur_token, 0, n_vocab*n_tokens*sizeof(float)); -#ifndef NDEBUG - std::fill(logits_valid.begin() + cur_token, logits_valid.begin() + cur_token + n_tokens, true); -#endif - } else { - if (cur_token + n_tokens >= n_tokens_all) { - ggml_backend_tensor_get_async(backend_res, res, logits_out, n_vocab*(n_tokens - 1)*sizeof(float), n_vocab*sizeof(float)); -#ifndef NDEBUG - logits_valid[0] = true; -#endif - } + GGML_ASSERT(lctx.logits != nullptr); + + float * logits_out = lctx.logits + n_outputs_prev*n_vocab; + const int32_t n_outputs_new = lctx.n_outputs; + + if (n_outputs_new) { + GGML_ASSERT( n_outputs_prev + n_outputs_new <= n_outputs); + GGML_ASSERT((n_outputs_prev + n_outputs_new)*n_vocab <= (int64_t) lctx.logits_size); + ggml_backend_tensor_get_async(backend_res, res, logits_out, 0, n_outputs_new*n_vocab*sizeof(float)); } } // extract embeddings - if (cparams.embeddings && embd) { + if (embd) { ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(lctx.sched, embd); GGML_ASSERT(backend_embd != nullptr); @@ -9722,16 +10045,14 @@ static int llama_decode_internal( case LLAMA_POOLING_TYPE_NONE: { // extract token embeddings - auto & embd_out = lctx.embd; - - if (u_batch.logits) { - //embd_out.resize(n_embd * n_tokens); - for (uint32_t i = 0; i < n_tokens; i++) { - if (u_batch.logits[i] == 0) { - continue; - } - ggml_backend_tensor_get_async(backend_embd, embd, embd_out + n_embd*(i + cur_token), (n_embd*i)*sizeof(float), n_embd*sizeof(float)); - } + GGML_ASSERT(lctx.embd != nullptr); + float * embd_out = lctx.embd + n_outputs_prev*n_embd; + const int32_t n_outputs_new = lctx.n_outputs; + + if (n_outputs_new) { + GGML_ASSERT( n_outputs_prev + n_outputs_new <= n_outputs); + GGML_ASSERT((n_outputs_prev + n_outputs_new)*n_embd <= (int64_t) lctx.embd_size); + ggml_backend_tensor_get_async(backend_embd, embd, embd_out, 0, n_outputs_new*n_embd*sizeof(float)); } } break; case LLAMA_POOLING_TYPE_CLS: @@ -9758,6 +10079,7 @@ static int llama_decode_internal( } break; } } + n_outputs_prev += lctx.n_outputs; } // wait for the computation to finish (automatically done when obtaining the model output) @@ -13531,7 +13853,7 @@ struct llama_context * llama_new_context_with_model( const auto & hparams = model->hparams; auto & cparams = ctx->cparams; - // TODO: maybe add n_seq_max here too + cparams.n_seq_max = std::max(1u, params.n_seq_max); cparams.n_threads = params.n_threads; cparams.n_threads_batch = params.n_threads_batch; cparams.yarn_ext_factor = params.yarn_ext_factor; @@ -13733,25 +14055,12 @@ struct llama_context * llama_new_context_with_model( // graph outputs buffer { - // resized during inference, reserve maximum - ctx->logits_size = hparams.n_vocab*cparams.n_batch; - ctx->embd_size = params.embeddings ? hparams.n_embd*cparams.n_batch : 0; - - const size_t buf_output_size = (ctx->logits_size + ctx->embd_size)*sizeof(float); - - ctx->buf_output = ggml_backend_buft_alloc_buffer(llama_default_buffer_type_cpu(true), buf_output_size); - if (ctx->buf_output == nullptr) { - LLAMA_LOG_ERROR("%s: failed to allocate logits buffer\n", __func__); + // resized during inference when a batch uses more outputs + if (llama_output_reserve(*ctx, params.n_seq_max) < params.n_seq_max) { + LLAMA_LOG_ERROR("%s: failed to reserve initial output buffer\n", __func__); llama_free(ctx); return nullptr; } - ggml_backend_buffer_clear(ctx->buf_output, 0); - - - ctx->logits = (float *) ggml_backend_buffer_get_base(ctx->buf_output); - if (params.embeddings) { - ctx->embd = ctx->logits + ctx->logits_size; - } LLAMA_LOG_INFO("%s: %10s output buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(ctx->buf_output), @@ -14268,27 +14577,33 @@ void llama_kv_cache_update(struct llama_context * ctx) { // Returns the *maximum* size of the state size_t llama_get_state_size(const struct llama_context * ctx) { + const auto & cparams = ctx->cparams; + const auto & hparams = ctx->model.hparams; + // we don't know size of rng until we actually serialize it. so reserve more than enough memory for its serialized state. // for reference, std::mt19937(1337) serializes to 6701 bytes. const size_t s_rng_size = sizeof(size_t); const size_t s_rng = LLAMA_MAX_RNG_STATE; + const size_t s_n_outputs = sizeof(size_t); + // assume worst case for outputs although only currently set ones are serialized + const size_t s_output_pos = ctx->cparams.n_batch * sizeof(int32_t); const size_t s_logits_size = sizeof(size_t); - // assume worst case for logits although only currently set ones are serialized - const size_t s_logits = ctx->logits_size * sizeof(float); + const size_t s_logits = ctx->logits_size ? cparams.n_batch * hparams.n_vocab * sizeof(float) : 0; const size_t s_embedding_size = sizeof(size_t); - const size_t s_embedding = ctx->embd_size * sizeof(float); + const size_t s_embedding = ctx->embd_size ? cparams.n_batch * hparams.n_embd * sizeof(float) : 0; const size_t s_kv_buf_size = sizeof(size_t); const size_t s_kv_head = sizeof(uint32_t); const size_t s_kv_size = sizeof(uint32_t); const size_t s_kv_used = sizeof(uint32_t); const size_t s_kv = ctx->kv_self.total_size(); - // TODO: assume the max is more than 1 seq_id per KV cell - const size_t s_kv_cell = sizeof(llama_pos) + sizeof(size_t) + sizeof(llama_seq_id); + const size_t s_kv_cell = sizeof(llama_pos) + sizeof(size_t) + cparams.n_seq_max*sizeof(llama_seq_id); const size_t s_kv_cells = ctx->kv_self.size * s_kv_cell; const size_t s_total = ( + s_rng_size + s_rng + + s_n_outputs + + s_output_pos + s_logits_size + s_logits + s_embedding_size @@ -14363,7 +14678,7 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat std::ostringstream rng_ss; rng_ss << ctx->rng; - const std::string & rng_str = rng_ss.str(); + const std::string & rng_str = rng_ss.str(); const size_t rng_size = rng_str.size(); GGML_ASSERT(rng_size <= LLAMA_MAX_RNG_STATE); @@ -14372,25 +14687,61 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat data_ctx->write(rng_str.data(), rng_size); } - // copy logits + // copy outputs { - const size_t logits_size = ctx->logits_size; + // Can't use ctx->n_outputs because it's not for the + // entire last batch when n_ubatch is smaller than n_batch + size_t n_outputs = 0; - data_ctx->write(&logits_size, sizeof(logits_size)); + // copy output ids + { + std::vector output_pos; - if (logits_size) { - data_ctx->write(ctx->logits, logits_size * sizeof(float)); + const size_t n_batch = ctx->cparams.n_batch; + const auto & output_ids = ctx->output_ids; + + output_pos.resize(ctx->output_size); + + // build a more compact representation of the output ids + for (size_t i = 0; i < n_batch; ++i) { + // map an output id to a position in the batch + int32_t pos = output_ids[i]; + if (pos >= 0) { + if ((size_t) pos >= n_outputs) { + n_outputs = pos + 1; + } + GGML_ASSERT((size_t) pos < ctx->output_size); + output_pos[pos] = i; + } + } + + data_ctx->write(&n_outputs, sizeof(n_outputs)); + + if (n_outputs) { + data_ctx->write(output_pos.data(), n_outputs * sizeof(int32_t)); + } } - } - // copy embeddings - { - const size_t embeddings_size = ctx->embd_size; + // copy logits + { + const size_t logits_size = std::min(ctx->logits_size, n_outputs * ctx->model.hparams.n_vocab); - data_ctx->write(&embeddings_size, sizeof(embeddings_size)); + data_ctx->write(&logits_size, sizeof(logits_size)); - if (embeddings_size) { - data_ctx->write(ctx->embd, embeddings_size * sizeof(float)); + if (logits_size) { + data_ctx->write(ctx->logits, logits_size * sizeof(float)); + } + } + + // copy embeddings + { + const size_t embeddings_size = std::min(ctx->embd_size, n_outputs * ctx->model.hparams.n_embd); + + data_ctx->write(&embeddings_size, sizeof(embeddings_size)); + + if (embeddings_size) { + data_ctx->write(ctx->embd, embeddings_size * sizeof(float)); + } } } @@ -14403,9 +14754,10 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s(); const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s(); - const size_t kv_buf_size = kv_self.total_size(); + // NOTE: kv_size and kv_buf_size are mostly used for sanity checks const uint32_t kv_head = llama_kv_cache_cell_max(kv_self); const uint32_t kv_size = kv_self.size; + const size_t kv_buf_size = kv_self.total_size() / (kv_size ? kv_size : 1) * kv_head; const uint32_t kv_used = kv_self.used; data_ctx->write(&kv_buf_size, sizeof(kv_buf_size)); @@ -14414,6 +14766,8 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat data_ctx->write(&kv_used, sizeof(kv_used)); if (kv_buf_size) { + const size_t pre_kv_buf_size = data_ctx->get_size_written(); + std::vector tmp_buf; for (int il = 0; il < (int) n_layer; ++il) { const size_t k_size = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*kv_head); @@ -14443,6 +14797,7 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat data_ctx->write(tmp_buf.data(), tmp_buf.size()); } } + GGML_ASSERT(kv_buf_size == data_ctx->get_size_written() - pre_kv_buf_size); } for (uint32_t i = 0; i < kv_head; ++i) { @@ -14487,6 +14842,28 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) { GGML_ASSERT(!rng_ss.fail()); } + // set output ids + { + size_t n_outputs; + std::vector output_pos; + + memcpy(&n_outputs, inp, sizeof(n_outputs)); inp += sizeof(n_outputs); + + GGML_ASSERT(n_outputs <= llama_output_reserve(*ctx, n_outputs)); + + if (n_outputs) { + output_pos.resize(n_outputs); + memcpy(output_pos.data(), inp, n_outputs * sizeof(int32_t)); + inp += n_outputs * sizeof(int32_t); + + for (int32_t i = 0; i < (int32_t) output_pos.size(); ++i) { + int32_t id = output_pos[i]; + GGML_ASSERT((uint32_t) id < ctx->cparams.n_batch); + ctx->output_ids[id] = i; + } + } + } + // set logits { size_t logits_size; @@ -14507,7 +14884,7 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) { memcpy(&embeddings_size, inp, sizeof(embeddings_size)); inp += sizeof(embeddings_size); - GGML_ASSERT(ctx->embd_size == embeddings_size); + GGML_ASSERT(ctx->embd_size >= embeddings_size); if (embeddings_size) { memcpy(ctx->embd, inp, embeddings_size * sizeof(float)); @@ -14534,8 +14911,18 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) { memcpy(&kv_size, inp, sizeof(kv_size)); inp += sizeof(kv_size); memcpy(&kv_used, inp, sizeof(kv_used)); inp += sizeof(kv_used); + if (kv_self.size != kv_size) { + // the KV cache needs to be big enough to load all the KV cells from the saved state + GGML_ASSERT(kv_self.size >= kv_head); + + LLAMA_LOG_INFO("%s: state contains %d KV cells, was saved with kv_size=%d, but is loaded with kv_size=%d (fine, but different)\n", + __func__, kv_head, kv_size, kv_self.size); + } + if (kv_buf_size) { - GGML_ASSERT(kv_self.total_size() == kv_buf_size); + const size_t pre_kv_buf_size = inp - src; + + GGML_ASSERT(kv_self.total_size() >= kv_buf_size); for (int il = 0; il < (int) n_layer; ++il) { const size_t k_size = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*kv_head); @@ -14555,23 +14942,21 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) { // v is not contiguous, copy row by row const size_t v_row_size = ggml_row_size(kv_self.v_l[il]->type, kv_head); - const size_t v_row_stride = ggml_row_size(kv_self.v_l[il]->type, kv_size); + const size_t v_row_stride = ggml_row_size(kv_self.v_l[il]->type, kv_self.size); for (int ir = 0; ir < (int) n_embd_v_gqa; ++ir) { ggml_backend_tensor_set(kv_self.v_l[il], inp, ir*v_row_stride, v_row_size); inp += v_row_size; } } + GGML_ASSERT(kv_buf_size == inp - src - pre_kv_buf_size); } - GGML_ASSERT(kv_self.size == kv_size); + llama_kv_cache_clear(ctx); ctx->kv_self.head = kv_head; - ctx->kv_self.size = kv_size; ctx->kv_self.used = kv_used; - ctx->kv_self.cells.resize(kv_size); - for (uint32_t i = 0; i < kv_head; ++i) { llama_pos pos; size_t seq_id_size; @@ -14588,11 +14973,6 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) { ctx->kv_self.cells[i].seq_id.insert(seq_id); } } - - for (uint32_t i = kv_head; i < kv_size; ++i) { - ctx->kv_self.cells[i].pos = -1; - ctx->kv_self.cells[i].seq_id.clear(); - } } const size_t nread = inp - src; @@ -14798,11 +15178,33 @@ float * llama_get_logits(struct llama_context * ctx) { } float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) { - assert(ctx->logits_valid.at(i)); - llama_synchronize(ctx); - return ctx->logits + i*ctx->model.hparams.n_vocab; + try { + if (ctx->logits == nullptr) { + throw std::runtime_error("no logits"); + } + if ((size_t) i >= ctx->output_ids.size()) { + throw std::runtime_error(format("out of range [0, %lu)", ctx->output_ids.size())); + } + const int32_t j = ctx->output_ids[i]; + + if (j < 0) { + throw std::runtime_error(format("batch.logits[%d] != true", i)); + } + if ((size_t) j >= ctx->output_size) { + // This should not happen + throw std::runtime_error(format("corrupt output buffer (j=%d, output_size=%lu)", j, ctx->output_size)); + } + + return ctx->logits + j*ctx->model.hparams.n_vocab; + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: invalid logits id %d, reason: %s\n", __func__, i, err.what()); +#ifndef NDEBUG + GGML_ASSERT(false); +#endif + return nullptr; + } } float * llama_get_embeddings(struct llama_context * ctx) { @@ -14814,7 +15216,31 @@ float * llama_get_embeddings(struct llama_context * ctx) { float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i) { llama_synchronize(ctx); - return ctx->embd + i*ctx->model.hparams.n_embd; + try { + if (ctx->embd == nullptr) { + throw std::runtime_error("no embeddings"); + } + if ((size_t) i >= ctx->output_ids.size()) { + throw std::runtime_error(format("out of range [0, %lu)", ctx->output_ids.size())); + } + const int32_t j = ctx->output_ids[i]; + + if (j < 0) { + throw std::runtime_error(format("batch.logits[%d] != true", i)); + } + if ((size_t) j >= ctx->output_size) { + // This should not happen + throw std::runtime_error(format("corrupt output buffer (j=%d, output_size=%lu)", j, ctx->output_size)); + } + + return ctx->embd + j*ctx->model.hparams.n_embd; + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: invalid embeddings id %d, reason: %s\n", __func__, i, err.what()); +#ifndef NDEBUG + GGML_ASSERT(false); +#endif + return nullptr; + } } float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id) { diff --git a/llama.h b/llama.h index 54d6224070402..1fe4af495820f 100644 --- a/llama.h +++ b/llama.h @@ -39,7 +39,7 @@ #define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn' #define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN -#define LLAMA_SESSION_VERSION 4 +#define LLAMA_SESSION_VERSION 5 #ifdef __cplusplus extern "C" { @@ -678,23 +678,29 @@ extern "C" { LLAMA_API void llama_synchronize(struct llama_context * ctx); // Token logits obtained from the last call to llama_decode() - // The logits for the last token are stored in the last row - // Logits for which llama_batch.logits[i] == 0 are undefined - // Rows: n_tokens provided with llama_batch + // The logits for which llama_batch.logits[i] != 0 are stored contiguously + // in the order they have appeared in the batch. + // Rows: number of tokens for which llama_batch.logits[i] != 0 // Cols: n_vocab LLAMA_API float * llama_get_logits(struct llama_context * ctx); // Logits for the ith token. Equivalent to: - // llama_get_logits(ctx) + i*n_vocab + // llama_get_logits(ctx) + ctx->output_ids[i]*n_vocab + // returns NULL for invalid ids. LLAMA_API float * llama_get_logits_ith(struct llama_context * ctx, int32_t i); - // Get all output token embeddings - // shape: [n_tokens*n_embd] (1-dimensional) + // Get all output token embeddings. + // when pooling_type == LLAMA_POOLING_TYPE_NONE or when using a generative model, + // the embeddings for which llama_batch.logits[i] != 0 are stored contiguously + // in the order they have appeared in the batch. + // shape: [n_outputs*n_embd] + // Otherwise, returns NULL. LLAMA_API float * llama_get_embeddings(struct llama_context * ctx); - // Get the embeddings for the ith token - // llama_get_embeddings(ctx) + i*n_embd + // Get the embeddings for the ith token. Equivalent to: + // llama_get_embeddings(ctx) + ctx->output_ids[i]*n_embd // shape: [n_embd] (1-dimensional) + // returns NULL for invalid ids. LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i); // Get the embeddings for a sequence id