From 662aaea8c9c624f9a6622229e0ed01b7d37248d1 Mon Sep 17 00:00:00 2001 From: Jan Boon Date: Wed, 27 Mar 2024 16:56:35 +0800 Subject: [PATCH 01/33] llama : save and restore kv cache for single seq id --- examples/server/server.cpp | 223 +++++++++++++++++++++++++++++++- llama.cpp | 253 +++++++++++++++++++++++++++++++++++++ llama.h | 14 ++ 3 files changed, 489 insertions(+), 1 deletion(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 53ad9239efb99..4e9a0e9e32624 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -61,7 +61,10 @@ enum server_task_type { SERVER_TASK_TYPE_COMPLETION, SERVER_TASK_TYPE_CANCEL, SERVER_TASK_TYPE_NEXT_RESPONSE, - SERVER_TASK_TYPE_METRICS + SERVER_TASK_TYPE_METRICS, + SERVER_TASK_TYPE_SLOT_SAVE, + SERVER_TASK_TYPE_SLOT_RESTORE, + SERVER_TASK_TYPE_SLOT_ERASE, }; struct server_task { @@ -1612,6 +1615,142 @@ struct server_context { } queue_results.send(res); } break; + case SERVER_TASK_TYPE_SLOT_SAVE: + { + int id_slot = task.data["id_slot"]; + server_slot * slot = get_slot(id_slot); + if (slot == nullptr) { + send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST); + break; + } + + const size_t token_count = slot->cache_tokens.size(); + const int64_t t_start = ggml_time_us(); + + std::string filename = task.data["filename"]; + size_t state_size = llama_get_seq_size(ctx, slot->id + 1); + std::vector state_data(state_size + sizeof(size_t) + token_count * sizeof(llama_token)); + size_t nwrite = llama_copy_seq_data(ctx, state_data.data(), slot->id + 1); + GGML_ASSERT(nwrite <= state_size); + + // write the cached token count of the slot->cache_tokens.size() + memcpy(state_data.data() + nwrite, &token_count, sizeof(size_t)); + nwrite += sizeof(size_t); + + // write the cached tokens (loop) + for (size_t i = 0; i < token_count; i++) { + const llama_token token = slot->cache_tokens[i]; + memcpy(state_data.data() + nwrite, &token, sizeof(llama_token)); + nwrite += sizeof(llama_token); + } + GGML_ASSERT(nwrite <= state_data.size()); + + std::ofstream outfile(filename, std::ios::binary); + outfile.write(reinterpret_cast(state_data.data()), nwrite); + outfile.close(); + + const int64_t t_end = ggml_time_us(); + const double t_save_ms = (t_end - t_start) / 1000.0; + + server_task_result result; + result.id = task.id; + result.stop = true; + result.error = false; + result.data = json { + { "id_slot", id_slot }, + { "filename", filename }, + { "n_saved", token_count }, // tokens saved + { "n_written", nwrite }, // bytes written + { "timings", { + { "save_ms", t_save_ms } + } } + }; + queue_results.send(result); + } break; + case SERVER_TASK_TYPE_SLOT_RESTORE: + { + int id_slot = task.data["id_slot"]; + server_slot * slot = get_slot(id_slot); + if (slot == nullptr) { + send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST); + break; + } + + const int64_t t_start = ggml_time_us(); + + std::string filename = task.data["filename"]; // TODO: restrict to files in path specified in server params? + std::ifstream infile(filename, std::ios::binary); + if (!infile.is_open()) { + send_error(task, "Failed to open file", ERROR_TYPE_INVALID_REQUEST); + break; + } + + std::vector state_data((std::istreambuf_iterator(infile)), std::istreambuf_iterator()); + infile.close(); + + size_t nread = llama_set_seq_data(ctx, state_data.data(), slot->id + 1); + GGML_ASSERT(nread <= state_data.size()); + + // restore cached token values + size_t token_count = 0; + if (nread + sizeof(size_t) <= state_data.size()) { + token_count = *reinterpret_cast(state_data.data() + nread); + nread += sizeof(size_t); + } + slot->cache_tokens.resize(token_count); + GGML_ASSERT(nread + (token_count * sizeof(llama_token)) <= state_data.size()); + + // tokens are of type llama_token (an integer) + for (size_t i = 0; i < token_count; i++) { + if (nread + sizeof(llama_token) <= state_data.size()) { + slot->cache_tokens[i] = *reinterpret_cast(state_data.data() + nread); + nread += sizeof(llama_token); + } + } + GGML_ASSERT(nread <= state_data.size()); + + const int64_t t_end = ggml_time_us(); + const double t_restore_ms = (t_end - t_start) / 1000.0; + + server_task_result result; + result.id = task.id; + result.stop = true; + result.error = false; + result.data = json { + { "id_slot", id_slot }, + { "filename", filename }, + { "n_restored", token_count }, // tokens restored + { "n_read", nread }, // bytes read + { "timings", { + { "restore_ms", t_restore_ms } + } } + }; + queue_results.send(result); + } break; + case SERVER_TASK_TYPE_SLOT_ERASE: + { + int id_slot = task.data["id_slot"]; + server_slot * slot = get_slot(id_slot); + if (slot == nullptr) { + send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST); + break; + } + + // Erase token cache + const size_t n_erased = slot->cache_tokens.size(); + llama_kv_cache_seq_rm(ctx, slot->id + 1, -1, -1); + slot->cache_tokens.clear(); + + server_task_result result; + result.id = task.id; + result.stop = true; + result.error = false; + result.data = json { + { "id_slot", id_slot }, + { "n_erased", n_erased } + }; + queue_results.send(result); + } break; } } @@ -3157,6 +3296,85 @@ int main(int argc, char ** argv) { res.status = 200; // HTTP OK }; + const auto handle_slot_save = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) { + res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); + + json request_data = json::parse(req.body); + int id_slot = request_data["id_slot"]; + std::string filename = request_data["filename"]; + + server_task task; + task.type = SERVER_TASK_TYPE_SLOT_SAVE; + task.data = { + { "id_slot", id_slot }, + { "filename", filename }, + }; + + const int id_task = ctx_server.queue_tasks.post(task); + ctx_server.queue_results.add_waiting_task_id(id_task); + + server_task_result result = ctx_server.queue_results.recv(id_task); + ctx_server.queue_results.remove_waiting_task_id(id_task); + + if (result.error) { + res_error(res, result.data); + } else { + res.set_content(result.data.dump(), "application/json"); + } + }; + + const auto handle_slot_restore = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) { + res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); + + json request_data = json::parse(req.body); + int id_slot = request_data["id_slot"]; + std::string filename = request_data["filename"]; + + server_task task; + task.type = SERVER_TASK_TYPE_SLOT_RESTORE; + task.data = { + { "id_slot", id_slot }, + { "filename", filename }, + }; + + const int id_task = ctx_server.queue_tasks.post(task); + ctx_server.queue_results.add_waiting_task_id(id_task); + + server_task_result result = ctx_server.queue_results.recv(id_task); + ctx_server.queue_results.remove_waiting_task_id(id_task); + + if (result.error) { + res_error(res, result.data); + } else { + res.set_content(result.data.dump(), "application/json"); + } + }; + + const auto handle_slot_erase = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) { + res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); + + json request_data = json::parse(req.body); + int id_slot = request_data["id_slot"]; + + server_task task; + task.type = SERVER_TASK_TYPE_SLOT_ERASE; + task.data = { + { "id_slot", id_slot }, + }; + + const int id_task = ctx_server.queue_tasks.post(task); + ctx_server.queue_results.add_waiting_task_id(id_task); + + server_task_result result = ctx_server.queue_results.recv(id_task); + ctx_server.queue_results.remove_waiting_task_id(id_task); + + if (result.error) { + res_error(res, result.data); + } else { + res.set_content(result.data, "application/json"); + } + }; + const auto handle_props = [&ctx_server](const httplib::Request & req, httplib::Response & res) { res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); json data = { @@ -3519,6 +3737,9 @@ int main(int argc, char ** argv) { svr->Post("/v1/embeddings", handle_embeddings); svr->Post("/tokenize", handle_tokenize); svr->Post("/detokenize", handle_detokenize); + svr->Post("/slot/save", handle_slot_save); + svr->Post("/slot/restore", handle_slot_restore); + svr->Post("/slot/erase", handle_slot_erase); // // Start the server diff --git a/llama.cpp b/llama.cpp index 892d46fbcfcec..e3a9eea4c384b 100644 --- a/llama.cpp +++ b/llama.cpp @@ -15059,6 +15059,259 @@ bool llama_save_session_file(struct llama_context * ctx, const char * path_sessi return true; } +size_t llama_get_seq_size(struct llama_context* ctx, llama_seq_id seq_id) { + // save the size of size_t as a uint32_t for safety check + const size_t size_t_size_size = sizeof(uint32_t); + + // other values + const size_t s_cell_count_size = sizeof(uint32_t); + const size_t s_layer_count_size = sizeof(uint32_t); + const size_t n_embd_v_gqa_size = sizeof(uint32_t); + + size_t s_cell_count = 0; + size_t s_cell_data_size = 0; + const auto& kv_self = ctx->kv_self; + const auto& hparams = ctx->model.hparams; + + const uint32_t n_layer = hparams.n_layer; + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s(); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s(); + + for (uint32_t i = 0; i < kv_self.size; ++i) { + const auto& cell = kv_self.cells[i]; + if (cell.seq_id.count(seq_id) > 0) { + ++s_cell_count; + s_cell_data_size += sizeof(llama_pos); + } + } + + for (int il = 0; il < (int)n_layer; ++il) { + // k_size_row and v_size_el values of layer + s_cell_data_size += sizeof(size_t) * 2; + + // keys + const size_t k_size_row = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa); + s_cell_data_size += k_size_row * s_cell_count; + + // values (transposed) + const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type); + s_cell_data_size += v_size_el * s_cell_count * n_embd_v_gqa; + } + + const size_t s_total = ( + size_t_size_size + + s_cell_count_size + + s_layer_count_size + + n_embd_v_gqa_size + + s_cell_data_size + ); + + return s_total; +} + +size_t llama_copy_seq_data(struct llama_context * ctx, uint8_t * dst, llama_seq_id seq_id) { + llama_data_buffer_context data_ctx(dst); + + // Save the size of size_t as a uint32_t for safety check + const uint32_t size_t_size = sizeof(size_t); + data_ctx.write(&size_t_size, sizeof(size_t_size)); + + const auto& kv_self = ctx->kv_self; + std::vector> cell_ranges; // ranges, from inclusive, to exclusive + uint32_t cell_count = 0; + + // Count the number of cells with the specified seq_id + // Find all the ranges of cells with this seq id + { + uint32_t cell_range_begin = kv_self.size; + for (uint32_t i = 0; i < kv_self.size; ++i) { + const auto& cell = kv_self.cells[i]; + if (cell.has_seq_id(seq_id)) { + ++cell_count; + if (cell_range_begin == kv_self.size) { + cell_range_begin = i; + } + } + else { + if (cell_range_begin != kv_self.size) { + cell_ranges.push_back({ cell_range_begin, i }); + cell_range_begin = kv_self.size; + } + } + } + if (cell_range_begin != kv_self.size) { + cell_ranges.push_back({ cell_range_begin, kv_self.size }); + } + + // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count + uint32_t cell_count_check = 0; + for (const auto& range : cell_ranges) { + cell_count_check += range.second - range.first; + } + GGML_ASSERT(cell_count == cell_count_check); + } + + // Write the cell count + data_ctx.write(&cell_count, sizeof(cell_count)); + + const auto & hparams = ctx->model.hparams; + const uint32_t n_layer = hparams.n_layer; + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s(); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s(); + + // Write the layer count + data_ctx.write(&n_layer, sizeof(n_layer)); + + // Write n_embd_v_gqa + data_ctx.write(&n_embd_v_gqa, sizeof(n_embd_v_gqa)); + + // Iterate the ranges and write all the pos (this is the token position in the prompt) + for (const auto & range : cell_ranges) { + for (uint32_t i = range.first; i < range.second; ++i) { + const auto & cell = kv_self.cells[i]; + data_ctx.write(&cell.pos, sizeof(cell.pos)); + } + } + + // Iterate and write all the keys first, each row is a cell + // Get whole range at a time + std::vector tmp_buf; + for (int il = 0; il < (int)n_layer; ++il) { + // Write row size of key + const size_t k_size_row = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa); + data_ctx.write(&k_size_row, sizeof(k_size_row)); + + // Read each range of cells of k_size length each into tmp_buf and write out + for (const auto & range : cell_ranges) { + const size_t range_size = range.second - range.first; + tmp_buf.resize(range_size * k_size_row); + ggml_backend_tensor_get(kv_self.k_l[il], tmp_buf.data(), range.first * k_size_row, range_size * k_size_row); + data_ctx.write(&tmp_buf[0], tmp_buf.size()); + } + } + + // For the values, they are transposed, so we also need the element size and get the element ranges from each row + const uint32_t kv_size = kv_self.size; + for (int il = 0; il < (int)n_layer; ++il) { + // Write element size + const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type); + data_ctx.write(&v_size_el, sizeof(v_size_el)); + + // For each row, we get the element values of each cell + for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { + // Read each range of cells of v_size_el length each into tmp_buf and write out + for (const auto & range : cell_ranges) { + const size_t range_size = range.second - range.first; + const size_t src_offset = (range.first + j * kv_size) * v_size_el; + tmp_buf.resize(range_size * v_size_el); + ggml_backend_tensor_get(kv_self.v_l[il], tmp_buf.data(), src_offset, tmp_buf.size()); + data_ctx.write(&tmp_buf[0], tmp_buf.size()); + } + } + } + + return data_ctx.get_size_written(); +} + +size_t llama_set_seq_data(struct llama_context * ctx, const uint8_t * src, llama_seq_id dest_seq_id) { + auto & kv_self = ctx->kv_self; + + // Wipe the slot + llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1); + + const uint8_t * inp = src; + + // Read size of size_t + uint32_t size_t_size; + memcpy(&size_t_size, inp, sizeof(size_t_size)); + inp += sizeof(size_t_size); + GGML_ASSERT(size_t_size == sizeof(size_t)); + + // Read the cell count + uint32_t cell_count; + memcpy(&cell_count, inp, sizeof(cell_count)); + inp += sizeof(cell_count); + + // Read the layer count + uint32_t n_layer_ref; + memcpy(&n_layer_ref, inp, sizeof(n_layer_ref)); + inp += sizeof(n_layer_ref); + + // Read n_embd_v_gqa + uint32_t n_embd_v_gqa_ref; + memcpy(&n_embd_v_gqa_ref, inp, sizeof(n_embd_v_gqa_ref)); + inp += sizeof(n_embd_v_gqa_ref); + + // Allocate the new cells for the slot + llama_batch batch = llama_batch_init(cell_count, 0, 1); + batch.n_tokens = cell_count; + for (uint32_t i = 0; i < cell_count; ++i) { + llama_pos pos; + memcpy(&pos, inp, sizeof(pos)); + inp += sizeof(pos); + + batch.pos[i] = pos; + batch.n_seq_id[i] = 1; + batch.seq_id[i][0] = dest_seq_id; + } + llama_kv_cache_find_slot(kv_self, batch); + + // DEBUG CHECK: kv_self.head should be our first cell, kv_self.head + cell_count - 1 should be our last cell (verify seq_id and pos values) + // Assume that this is one contiguous block of cells + GGML_ASSERT(kv_self.head + cell_count <= kv_self.size); + GGML_ASSERT(kv_self.cells[kv_self.head].pos == batch.pos[0]); + GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].pos == batch.pos[cell_count - 1]); + GGML_ASSERT(kv_self.cells[kv_self.head].has_seq_id(dest_seq_id)); + GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].has_seq_id(dest_seq_id)); + + const auto& hparams = ctx->model.hparams; + const uint32_t n_layer = hparams.n_layer; + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s(); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s(); + const uint32_t kv_size = kv_self.size; + const uint32_t kv_head = kv_self.head; + GGML_ASSERT(n_layer == n_layer_ref); + GGML_ASSERT(n_embd_v_gqa == n_embd_v_gqa_ref); + + // For each layer, read the keys for each cell, one row is one cell, read as one contiguous blo + for (int il = 0; il < (int)n_layer; ++il) { + // Read row size of key + size_t k_size_row_ref; + memcpy(&k_size_row_ref, inp, sizeof(k_size_row_ref)); + inp += sizeof(k_size_row_ref); + const size_t k_size_row = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa); + GGML_ASSERT(k_size_row == k_size_row_ref); + + // Read and set the keys for the whole cell range + ggml_backend_tensor_set(kv_self.k_l[il], inp, kv_head * k_size_row, cell_count * k_size_row); + inp += cell_count * k_size_row; + } + + // For each layer, read the values for each cell (transposed) + for (int il = 0; il < (int)n_layer; ++il) { + // Read element size of value + size_t v_size_el_ref; + memcpy(&v_size_el_ref, inp, sizeof(v_size_el_ref)); + inp += sizeof(v_size_el_ref); + + const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type); + GGML_ASSERT(v_size_el == v_size_el_ref); + + // For each row in the transposed matrix, read the values for the whole cell range + for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { + const size_t dst_offset = (kv_head + j * kv_size) * v_size_el; + ggml_backend_tensor_set(kv_self.v_l[il], inp, dst_offset, cell_count * v_size_el); + inp += cell_count * v_size_el; + } + } + + // Cleanup + llama_batch_free(batch); + + const size_t nread = inp - src; + return nread; +} + void llama_set_n_threads(struct llama_context * ctx, uint32_t n_threads, uint32_t n_threads_batch) { ctx->cparams.n_threads = n_threads; ctx->cparams.n_threads_batch = n_threads_batch; diff --git a/llama.h b/llama.h index 1fe4af495820f..33164a33af217 100644 --- a/llama.h +++ b/llama.h @@ -623,6 +623,20 @@ extern "C" { const llama_token * tokens, size_t n_token_count); + LLAMA_API size_t llama_get_seq_size( + struct llama_context * ctx, + llama_seq_id seq_id); + + LLAMA_API size_t llama_copy_seq_data( + struct llama_context * ctx, + uint8_t * dst, + llama_seq_id seq_id); + + LLAMA_API size_t llama_set_seq_data( + struct llama_context * ctx, + const uint8_t * src, + llama_seq_id dest_seq_id); + // // Decoding // From 5462817851c8b50a761d930b49ab285b0d1dca5f Mon Sep 17 00:00:00 2001 From: Jan Boon Date: Wed, 27 Mar 2024 18:49:00 +0800 Subject: [PATCH 02/33] remove trailing whitespace --- llama.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/llama.cpp b/llama.cpp index e3a9eea4c384b..e4ee7b3e6b146 100644 --- a/llama.cpp +++ b/llama.cpp @@ -15196,7 +15196,7 @@ size_t llama_copy_seq_data(struct llama_context * ctx, uint8_t * dst, llama_seq_ // Write element size const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type); data_ctx.write(&v_size_el, sizeof(v_size_el)); - + // For each row, we get the element values of each cell for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { // Read each range of cells of v_size_el length each into tmp_buf and write out @@ -15215,7 +15215,7 @@ size_t llama_copy_seq_data(struct llama_context * ctx, uint8_t * dst, llama_seq_ size_t llama_set_seq_data(struct llama_context * ctx, const uint8_t * src, llama_seq_id dest_seq_id) { auto & kv_self = ctx->kv_self; - + // Wipe the slot llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1); @@ -15226,7 +15226,7 @@ size_t llama_set_seq_data(struct llama_context * ctx, const uint8_t * src, llama memcpy(&size_t_size, inp, sizeof(size_t_size)); inp += sizeof(size_t_size); GGML_ASSERT(size_t_size == sizeof(size_t)); - + // Read the cell count uint32_t cell_count; memcpy(&cell_count, inp, sizeof(cell_count)); From ab1c46a7bfa1196e1e5cb669ad13c1be7cc2ce21 Mon Sep 17 00:00:00 2001 From: Jan Boon Date: Wed, 27 Mar 2024 19:11:47 +0800 Subject: [PATCH 03/33] respond error in case there's no space in the kv cache --- examples/server/server.cpp | 4 +++ llama.cpp | 55 +++++++++++++++++++++----------------- 2 files changed, 35 insertions(+), 24 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 4e9a0e9e32624..227bb3c6b60fe 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1689,6 +1689,10 @@ struct server_context { infile.close(); size_t nread = llama_set_seq_data(ctx, state_data.data(), slot->id + 1); + if (nread == 0) { + send_error(task, "Unable to restore slot, no available space in KV cache", ERROR_TYPE_INVALID_REQUEST); + break; + } GGML_ASSERT(nread <= state_data.size()); // restore cached token values diff --git a/llama.cpp b/llama.cpp index e4ee7b3e6b146..0faaac0134395 100644 --- a/llama.cpp +++ b/llama.cpp @@ -15111,12 +15111,13 @@ size_t llama_get_seq_size(struct llama_context* ctx, llama_seq_id seq_id) { size_t llama_copy_seq_data(struct llama_context * ctx, uint8_t * dst, llama_seq_id seq_id) { llama_data_buffer_context data_ctx(dst); + const auto& kv_self = ctx->kv_self; + GGML_ASSERT(!kv_self.recurrent); // not implemented // Save the size of size_t as a uint32_t for safety check const uint32_t size_t_size = sizeof(size_t); data_ctx.write(&size_t_size, sizeof(size_t_size)); - const auto& kv_self = ctx->kv_self; std::vector> cell_ranges; // ranges, from inclusive, to exclusive uint32_t cell_count = 0; @@ -15215,6 +15216,7 @@ size_t llama_copy_seq_data(struct llama_context * ctx, uint8_t * dst, llama_seq_ size_t llama_set_seq_data(struct llama_context * ctx, const uint8_t * src, llama_seq_id dest_seq_id) { auto & kv_self = ctx->kv_self; + GGML_ASSERT(!kv_self.recurrent); // not implemented // Wipe the slot llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1); @@ -15243,26 +15245,34 @@ size_t llama_set_seq_data(struct llama_context * ctx, const uint8_t * src, llama inp += sizeof(n_embd_v_gqa_ref); // Allocate the new cells for the slot - llama_batch batch = llama_batch_init(cell_count, 0, 1); - batch.n_tokens = cell_count; - for (uint32_t i = 0; i < cell_count; ++i) { - llama_pos pos; - memcpy(&pos, inp, sizeof(pos)); - inp += sizeof(pos); - - batch.pos[i] = pos; - batch.n_seq_id[i] = 1; - batch.seq_id[i][0] = dest_seq_id; - } - llama_kv_cache_find_slot(kv_self, batch); - - // DEBUG CHECK: kv_self.head should be our first cell, kv_self.head + cell_count - 1 should be our last cell (verify seq_id and pos values) - // Assume that this is one contiguous block of cells - GGML_ASSERT(kv_self.head + cell_count <= kv_self.size); - GGML_ASSERT(kv_self.cells[kv_self.head].pos == batch.pos[0]); - GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].pos == batch.pos[cell_count - 1]); - GGML_ASSERT(kv_self.cells[kv_self.head].has_seq_id(dest_seq_id)); - GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].has_seq_id(dest_seq_id)); + { + llama_batch batch = llama_batch_init(cell_count, 0, 1); + batch.n_tokens = cell_count; + for (uint32_t i = 0; i < cell_count; ++i) { + llama_pos pos; + memcpy(&pos, inp, sizeof(pos)); + inp += sizeof(pos); + + batch.pos[i] = pos; + batch.n_seq_id[i] = 1; + batch.seq_id[i][0] = dest_seq_id; + } + if (!llama_kv_cache_find_slot(kv_self, batch)) { + llama_batch_free(batch); + return 0; + } + + // DEBUG CHECK: kv_self.head should be our first cell, kv_self.head + cell_count - 1 should be our last cell (verify seq_id and pos values) + // Assume that this is one contiguous block of cells + GGML_ASSERT(kv_self.head + cell_count <= kv_self.size); + GGML_ASSERT(kv_self.cells[kv_self.head].pos == batch.pos[0]); + GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].pos == batch.pos[cell_count - 1]); + GGML_ASSERT(kv_self.cells[kv_self.head].has_seq_id(dest_seq_id)); + GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].has_seq_id(dest_seq_id)); + + // Cleanup + llama_batch_free(batch); + } const auto& hparams = ctx->model.hparams; const uint32_t n_layer = hparams.n_layer; @@ -15305,9 +15315,6 @@ size_t llama_set_seq_data(struct llama_context * ctx, const uint8_t * src, llama } } - // Cleanup - llama_batch_free(batch); - const size_t nread = inp - src; return nread; } From 02a184065a55552145919b466609f3b8f7ce6a05 Mon Sep 17 00:00:00 2001 From: Jan Boon Date: Wed, 27 Mar 2024 22:39:28 +0800 Subject: [PATCH 04/33] add kv seq save restore to test case --- examples/save-load-state/save-load-state.cpp | 93 +++++++++++++++++++- 1 file changed, 91 insertions(+), 2 deletions(-) diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp index ef952e2bd987c..6a1966712204a 100644 --- a/examples/save-load-state/save-load-state.cpp +++ b/examples/save-load-state/save-load-state.cpp @@ -24,6 +24,7 @@ int main(int argc, char ** argv) { std::string result0; std::string result1; + std::string result2; // init llama_model * model; @@ -141,16 +142,104 @@ int main(int argc, char ** argv) { n_past += 1; } - printf("\n"); + printf("\n\n"); llama_free(ctx2); - llama_free_model(model); if (result0 != result1) { fprintf(stderr, "\n%s : error : the 2 generations are different\n", __func__); return 1; } + // make new context + auto* ctx3 = llama_new_context_with_model(model, llama_context_params_from_gpt_params(params)); + + printf("\nsingle seq run: %s", params.prompt.c_str()); + + // load state (rng, logits, embedding and kv_cache) from file + { + std::vector state_mem(llama_get_state_size(ctx3)); + + FILE * fp_read = fopen("dump_state.bin", "rb"); + const size_t read = fread(state_mem.data(), 1, state_mem.size(), fp_read); + fclose(fp_read); + + if (read != llama_set_state_data(ctx3, state_mem.data())) { + fprintf(stderr, "\n%s : failed to read state\n", __func__); + llama_free(ctx3); + llama_free_model(model); + return 1; + } + + fprintf(stderr, "%s : deserialized state from %zd out of a maximum of %zd bytes\n", __func__, read, state_mem.size()); + } + + // restore state (last tokens) + n_past = n_past_saved; + + // save seq 0 and load into seq 1 + { + // save kv of seq 0 + std::vector seq_store(llama_get_seq_size(ctx3, 0)); + const size_t ncopy = llama_copy_seq_data(ctx3, seq_store.data(), 0); + if (ncopy != seq_store.size()) { + fprintf(stderr, "\n%s : seq copy data length %zd does not match expected length %zd\n", __func__, ncopy, seq_store.size()); + llama_free(ctx3); + llama_free_model(model); + return 1; + } + fprintf(stderr, "%s : seq 0 copied, %zd bytes\n", __func__, ncopy); + + // erase whole kv + llama_kv_cache_clear(ctx3); + fprintf(stderr, "%s : kv cache cleared\n", __func__); + + // restore kv into seq 1 + const size_t nset = llama_set_seq_data(ctx3, seq_store.data(), 1); + if (nset != seq_store.size()) { + fprintf(stderr, "\n%s : seq set data length %zd does not match expected length %zd\n", __func__, nset, seq_store.size()); + llama_free(ctx3); + llama_free_model(model); + return 1; + } + fprintf(stderr, "%s : seq 1 restored, %zd bytes\n", __func__, nset); + } + + // third run with seq 1 instead of 0 + for (auto i = 0; i < params.n_predict; i++) { + auto * logits = llama_get_logits(ctx3); + auto n_vocab = llama_n_vocab(model); + std::vector candidates; + candidates.reserve(n_vocab); + for (llama_token token_id = 0; token_id < n_vocab; token_id++) { + candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f}); + } + llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; + auto next_token = llama_sample_token(ctx3, &candidates_p); + auto next_token_str = llama_token_to_piece(ctx3, next_token); + + printf("%s", next_token_str.c_str()); + result2 += next_token_str; + + if (llama_decode(ctx3, llama_batch_get_one(&next_token, 1, n_past, 1))) { + fprintf(stderr, "\n%s : failed to evaluate\n", __func__); + llama_free(ctx3); + llama_free_model(model); + return 1; + } + n_past += 1; + } + + printf("\n"); + + llama_free(ctx3); + llama_free_model(model); + + if (result0 != result2) { + fprintf(stderr, "\n%s : error : the seq restore generation is different\n", __func__); + return 1; + } + fprintf(stderr, "\n%s : success\n", __func__); return 0; From b8e8facb0ee25cafc36cca5edf27c74f21f6edc8 Mon Sep 17 00:00:00 2001 From: Jan Boon Date: Thu, 28 Mar 2024 00:05:56 +0800 Subject: [PATCH 05/33] add --slot-save-path arg to enable save restore and restrict save location --- examples/server/server.cpp | 45 +++++++++++++++++++++++++++++++------- 1 file changed, 37 insertions(+), 8 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 227bb3c6b60fe..a86a20ae68d46 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -131,6 +131,7 @@ struct server_params { bool slots_endpoint = true; bool metrics_endpoint = false; + std::string slot_save_path; }; struct server_slot { @@ -1628,6 +1629,7 @@ struct server_context { const int64_t t_start = ggml_time_us(); std::string filename = task.data["filename"]; + std::string filepath = task.data["filepath"]; size_t state_size = llama_get_seq_size(ctx, slot->id + 1); std::vector state_data(state_size + sizeof(size_t) + token_count * sizeof(llama_token)); size_t nwrite = llama_copy_seq_data(ctx, state_data.data(), slot->id + 1); @@ -1645,7 +1647,7 @@ struct server_context { } GGML_ASSERT(nwrite <= state_data.size()); - std::ofstream outfile(filename, std::ios::binary); + std::ofstream outfile(filepath, std::ios::binary); outfile.write(reinterpret_cast(state_data.data()), nwrite); outfile.close(); @@ -1678,8 +1680,9 @@ struct server_context { const int64_t t_start = ggml_time_us(); - std::string filename = task.data["filename"]; // TODO: restrict to files in path specified in server params? - std::ifstream infile(filename, std::ios::binary); + std::string filename = task.data["filename"]; + std::string filepath = task.data["filepath"]; + std::ifstream infile(filepath, std::ios::binary); if (!infile.is_open()) { send_error(task, "Failed to open file", ERROR_TYPE_INVALID_REQUEST); break; @@ -2392,6 +2395,7 @@ static void server_print_usage(const char * argv0, const gpt_params & params, co printf(" --log-disable disables logging to a file.\n"); printf(" --slots-endpoint-disable disables slots monitoring endpoint.\n"); printf(" --metrics enable prometheus compatible metrics endpoint (default: %s).\n", sparams.metrics_endpoint ? "enabled" : "disabled"); + printf(" --slot-save-path PATH path to save slot kv cache (default: disabled)\n"); printf("\n"); printf(" -n, --n-predict maximum tokens to predict (default: %d)\n", params.n_predict); printf(" --override-kv KEY=TYPE:VALUE\n"); @@ -2798,6 +2802,16 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams, sparams.slots_endpoint = false; } else if (arg == "--metrics") { sparams.metrics_endpoint = true; + } else if (arg == "--slot-save-path") { + if (++i >= argc) { + invalid_param = true; + break; + } + sparams.slot_save_path = argv[i]; + // if doesn't end with DIRECTORY_SEPARATOR, add it + if (!sparams.slot_save_path.empty() && sparams.slot_save_path[sparams.slot_save_path.size() - 1] != DIRECTORY_SEPARATOR) { + sparams.slot_save_path += DIRECTORY_SEPARATOR; + } } else if (arg == "--chat-template") { if (++i >= argc) { invalid_param = true; @@ -3300,18 +3314,24 @@ int main(int argc, char ** argv) { res.status = 200; // HTTP OK }; - const auto handle_slot_save = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) { + const auto handle_slot_save = [&ctx_server, &res_error, &sparams](const httplib::Request & req, httplib::Response & res) { res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); json request_data = json::parse(req.body); int id_slot = request_data["id_slot"]; std::string filename = request_data["filename"]; + if (filename.find('/') != std::string::npos || filename.find('\\') != std::string::npos || filename.find("..") != std::string::npos) { + res_error(res, "Invalid filename"); + return; + } + std::string filepath = sparams.slot_save_path + filename; server_task task; task.type = SERVER_TASK_TYPE_SLOT_SAVE; task.data = { { "id_slot", id_slot }, { "filename", filename }, + { "filepath", filepath } }; const int id_task = ctx_server.queue_tasks.post(task); @@ -3327,18 +3347,24 @@ int main(int argc, char ** argv) { } }; - const auto handle_slot_restore = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) { + const auto handle_slot_restore = [&ctx_server, &res_error, &sparams](const httplib::Request & req, httplib::Response & res) { res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); json request_data = json::parse(req.body); int id_slot = request_data["id_slot"]; std::string filename = request_data["filename"]; + if (filename.find('/') != std::string::npos || filename.find('\\') != std::string::npos || filename.find("..") != std::string::npos) { + res_error(res, "Invalid filename"); + return; + } + std::string filepath = sparams.slot_save_path + filename; server_task task; task.type = SERVER_TASK_TYPE_SLOT_RESTORE; task.data = { { "id_slot", id_slot }, { "filename", filename }, + { "filepath", filepath } }; const int id_task = ctx_server.queue_tasks.post(task); @@ -3741,9 +3767,12 @@ int main(int argc, char ** argv) { svr->Post("/v1/embeddings", handle_embeddings); svr->Post("/tokenize", handle_tokenize); svr->Post("/detokenize", handle_detokenize); - svr->Post("/slot/save", handle_slot_save); - svr->Post("/slot/restore", handle_slot_restore); - svr->Post("/slot/erase", handle_slot_erase); + if (!sparams.slot_save_path.empty()) { + // only enable slot endpoints if slot_save_path is set + svr->Post("/slot/save", handle_slot_save); + svr->Post("/slot/restore", handle_slot_restore); + svr->Post("/slot/erase", handle_slot_erase); + } // // Start the server From b182f8f67f72fe41ec0aca3333ef9a06ac4b2d47 Mon Sep 17 00:00:00 2001 From: Martin Evans Date: Wed, 27 Mar 2024 16:31:27 +0000 Subject: [PATCH 06/33] Returning 0 for some cases, instead of asserting. --- llama.cpp | 20 +++++++++++++++----- llama.h | 4 ++++ 2 files changed, 19 insertions(+), 5 deletions(-) diff --git a/llama.cpp b/llama.cpp index 0faaac0134395..6987a53449480 100644 --- a/llama.cpp +++ b/llama.cpp @@ -15227,7 +15227,9 @@ size_t llama_set_seq_data(struct llama_context * ctx, const uint8_t * src, llama uint32_t size_t_size; memcpy(&size_t_size, inp, sizeof(size_t_size)); inp += sizeof(size_t_size); - GGML_ASSERT(size_t_size == sizeof(size_t)); + if (size_t_size != sizeof(size_t)) { + return 0; + } // Read the cell count uint32_t cell_count; @@ -15244,6 +15246,18 @@ size_t llama_set_seq_data(struct llama_context * ctx, const uint8_t * src, llama memcpy(&n_embd_v_gqa_ref, inp, sizeof(n_embd_v_gqa_ref)); inp += sizeof(n_embd_v_gqa_ref); + // Sanity check model compatibility + const auto& hparams = ctx->model.hparams; + const uint32_t n_layer = hparams.n_layer; + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s(); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s(); + if (n_layer != n_layer_ref) { + return 0; + } + if (n_embd_v_gqa != n_embd_v_gqa_ref) { + return 0; + } + // Allocate the new cells for the slot { llama_batch batch = llama_batch_init(cell_count, 0, 1); @@ -15274,10 +15288,6 @@ size_t llama_set_seq_data(struct llama_context * ctx, const uint8_t * src, llama llama_batch_free(batch); } - const auto& hparams = ctx->model.hparams; - const uint32_t n_layer = hparams.n_layer; - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s(); - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s(); const uint32_t kv_size = kv_self.size; const uint32_t kv_head = kv_self.head; GGML_ASSERT(n_layer == n_layer_ref); diff --git a/llama.h b/llama.h index 33164a33af217..943f8fbb10a2e 100644 --- a/llama.h +++ b/llama.h @@ -632,6 +632,10 @@ extern "C" { uint8_t * dst, llama_seq_id seq_id); + // Copy the sequence data (originally copied with `llama_copy_seq_data`) into a sequence. + // Returns: + // - Positive: Ok + // - Zero: Failed to load LLAMA_API size_t llama_set_seq_data( struct llama_context * ctx, const uint8_t * src, From a2b48b95f59fd96007fd5a59c52744671a0f7c49 Mon Sep 17 00:00:00 2001 From: Jan Boon Date: Thu, 28 Mar 2024 01:11:07 +0800 Subject: [PATCH 07/33] cleanup error cases --- examples/server/server.cpp | 2 +- llama.cpp | 12 ++++++++---- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index a86a20ae68d46..c39b59a83edc0 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1693,7 +1693,7 @@ struct server_context { size_t nread = llama_set_seq_data(ctx, state_data.data(), slot->id + 1); if (nread == 0) { - send_error(task, "Unable to restore slot, no available space in KV cache", ERROR_TYPE_INVALID_REQUEST); + send_error(task, "Unable to restore slot, no available space in KV cache or invalid slot save file", ERROR_TYPE_INVALID_REQUEST); break; } GGML_ASSERT(nread <= state_data.size()); diff --git a/llama.cpp b/llama.cpp index 6987a53449480..8151b24e62367 100644 --- a/llama.cpp +++ b/llama.cpp @@ -15290,8 +15290,6 @@ size_t llama_set_seq_data(struct llama_context * ctx, const uint8_t * src, llama const uint32_t kv_size = kv_self.size; const uint32_t kv_head = kv_self.head; - GGML_ASSERT(n_layer == n_layer_ref); - GGML_ASSERT(n_embd_v_gqa == n_embd_v_gqa_ref); // For each layer, read the keys for each cell, one row is one cell, read as one contiguous blo for (int il = 0; il < (int)n_layer; ++il) { @@ -15300,7 +15298,10 @@ size_t llama_set_seq_data(struct llama_context * ctx, const uint8_t * src, llama memcpy(&k_size_row_ref, inp, sizeof(k_size_row_ref)); inp += sizeof(k_size_row_ref); const size_t k_size_row = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa); - GGML_ASSERT(k_size_row == k_size_row_ref); + if (k_size_row != k_size_row_ref) { + llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1); + return 0; + } // Read and set the keys for the whole cell range ggml_backend_tensor_set(kv_self.k_l[il], inp, kv_head * k_size_row, cell_count * k_size_row); @@ -15315,7 +15316,10 @@ size_t llama_set_seq_data(struct llama_context * ctx, const uint8_t * src, llama inp += sizeof(v_size_el_ref); const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type); - GGML_ASSERT(v_size_el == v_size_el_ref); + if (v_size_el != v_size_el_ref) { + llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1); + return 0; + } // For each row in the transposed matrix, read the values for the whole cell range for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { From c4443d7ad4f8417f88148b064f50c8da46fd8e52 Mon Sep 17 00:00:00 2001 From: Jan Boon Date: Thu, 28 Mar 2024 22:10:04 +0800 Subject: [PATCH 08/33] rename sequence state functions --- examples/save-load-state/save-load-state.cpp | 6 +++--- examples/server/server.cpp | 6 +++--- llama.cpp | 6 +++--- llama.h | 8 ++++---- 4 files changed, 13 insertions(+), 13 deletions(-) diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp index 6a1966712204a..d9f8c937e4caf 100644 --- a/examples/save-load-state/save-load-state.cpp +++ b/examples/save-load-state/save-load-state.cpp @@ -180,8 +180,8 @@ int main(int argc, char ** argv) { // save seq 0 and load into seq 1 { // save kv of seq 0 - std::vector seq_store(llama_get_seq_size(ctx3, 0)); - const size_t ncopy = llama_copy_seq_data(ctx3, seq_store.data(), 0); + std::vector seq_store(llama_state_seq_get_size(ctx3, 0)); + const size_t ncopy = llama_state_seq_get_data(ctx3, seq_store.data(), 0); if (ncopy != seq_store.size()) { fprintf(stderr, "\n%s : seq copy data length %zd does not match expected length %zd\n", __func__, ncopy, seq_store.size()); llama_free(ctx3); @@ -195,7 +195,7 @@ int main(int argc, char ** argv) { fprintf(stderr, "%s : kv cache cleared\n", __func__); // restore kv into seq 1 - const size_t nset = llama_set_seq_data(ctx3, seq_store.data(), 1); + const size_t nset = llama_state_seq_set_data(ctx3, seq_store.data(), 1); if (nset != seq_store.size()) { fprintf(stderr, "\n%s : seq set data length %zd does not match expected length %zd\n", __func__, nset, seq_store.size()); llama_free(ctx3); diff --git a/examples/server/server.cpp b/examples/server/server.cpp index c39b59a83edc0..28d8b31d9b6ef 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1630,9 +1630,9 @@ struct server_context { std::string filename = task.data["filename"]; std::string filepath = task.data["filepath"]; - size_t state_size = llama_get_seq_size(ctx, slot->id + 1); + size_t state_size = llama_state_seq_get_size(ctx, slot->id + 1); std::vector state_data(state_size + sizeof(size_t) + token_count * sizeof(llama_token)); - size_t nwrite = llama_copy_seq_data(ctx, state_data.data(), slot->id + 1); + size_t nwrite = llama_state_seq_get_data(ctx, state_data.data(), slot->id + 1); GGML_ASSERT(nwrite <= state_size); // write the cached token count of the slot->cache_tokens.size() @@ -1691,7 +1691,7 @@ struct server_context { std::vector state_data((std::istreambuf_iterator(infile)), std::istreambuf_iterator()); infile.close(); - size_t nread = llama_set_seq_data(ctx, state_data.data(), slot->id + 1); + size_t nread = llama_state_seq_set_data(ctx, state_data.data(), slot->id + 1); if (nread == 0) { send_error(task, "Unable to restore slot, no available space in KV cache or invalid slot save file", ERROR_TYPE_INVALID_REQUEST); break; diff --git a/llama.cpp b/llama.cpp index 8151b24e62367..69f5324af61c8 100644 --- a/llama.cpp +++ b/llama.cpp @@ -15059,7 +15059,7 @@ bool llama_save_session_file(struct llama_context * ctx, const char * path_sessi return true; } -size_t llama_get_seq_size(struct llama_context* ctx, llama_seq_id seq_id) { +size_t llama_state_seq_get_size(struct llama_context* ctx, llama_seq_id seq_id) { // save the size of size_t as a uint32_t for safety check const size_t size_t_size_size = sizeof(uint32_t); @@ -15109,7 +15109,7 @@ size_t llama_get_seq_size(struct llama_context* ctx, llama_seq_id seq_id) { return s_total; } -size_t llama_copy_seq_data(struct llama_context * ctx, uint8_t * dst, llama_seq_id seq_id) { +size_t llama_state_seq_get_data(struct llama_context * ctx, uint8_t * dst, llama_seq_id seq_id) { llama_data_buffer_context data_ctx(dst); const auto& kv_self = ctx->kv_self; GGML_ASSERT(!kv_self.recurrent); // not implemented @@ -15214,7 +15214,7 @@ size_t llama_copy_seq_data(struct llama_context * ctx, uint8_t * dst, llama_seq_ return data_ctx.get_size_written(); } -size_t llama_set_seq_data(struct llama_context * ctx, const uint8_t * src, llama_seq_id dest_seq_id) { +size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, llama_seq_id dest_seq_id) { auto & kv_self = ctx->kv_self; GGML_ASSERT(!kv_self.recurrent); // not implemented diff --git a/llama.h b/llama.h index 943f8fbb10a2e..375435acd3fe2 100644 --- a/llama.h +++ b/llama.h @@ -623,20 +623,20 @@ extern "C" { const llama_token * tokens, size_t n_token_count); - LLAMA_API size_t llama_get_seq_size( + LLAMA_API size_t llama_state_seq_get_size( struct llama_context * ctx, llama_seq_id seq_id); - LLAMA_API size_t llama_copy_seq_data( + LLAMA_API size_t llama_state_seq_get_data( struct llama_context * ctx, uint8_t * dst, llama_seq_id seq_id); - // Copy the sequence data (originally copied with `llama_copy_seq_data`) into a sequence. + // Copy the sequence data (originally copied with `llama_state_seq_get_data`) into a sequence. // Returns: // - Positive: Ok // - Zero: Failed to load - LLAMA_API size_t llama_set_seq_data( + LLAMA_API size_t llama_state_seq_set_data( struct llama_context * ctx, const uint8_t * src, llama_seq_id dest_seq_id); From 4d5356bbbb6d666d84f936c33aec1aa3a09c67d8 Mon Sep 17 00:00:00 2001 From: Jan Boon Date: Thu, 28 Mar 2024 22:19:57 +0800 Subject: [PATCH 09/33] rename state get set functions --- examples/main/main.cpp | 6 ++-- examples/save-load-state/save-load-state.cpp | 12 ++++---- llama.cpp | 30 ++++++++++---------- llama.h | 10 +++---- 4 files changed, 29 insertions(+), 29 deletions(-) diff --git a/examples/main/main.cpp b/examples/main/main.cpp index e2d07a6319d50..711f162d79fca 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -235,7 +235,7 @@ int main(int argc, char ** argv) { // The file exists and is not empty session_tokens.resize(n_ctx); size_t n_token_count_out = 0; - if (!llama_load_session_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.capacity(), &n_token_count_out)) { + if (!llama_state_load_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.capacity(), &n_token_count_out)) { LOG_TEE("%s: error: failed to load session file '%s'\n", __func__, path_session.c_str()); return 1; } @@ -693,7 +693,7 @@ int main(int argc, char ** argv) { // optionally save the session on first sample (for faster prompt loading next time) if (!path_session.empty() && need_to_save_session && !params.prompt_cache_ro) { need_to_save_session = false; - llama_save_session_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size()); + llama_state_save_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size()); LOG("saved session to %s\n", path_session.c_str()); } @@ -935,7 +935,7 @@ int main(int argc, char ** argv) { if (!path_session.empty() && params.prompt_cache_all && !params.prompt_cache_ro) { LOG_TEE("\n%s: saving final output to session file '%s'\n", __func__, path_session.c_str()); - llama_save_session_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size()); + llama_state_save_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size()); } llama_print_timings(ctx); diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp index d9f8c937e4caf..c3b766882dbec 100644 --- a/examples/save-load-state/save-load-state.cpp +++ b/examples/save-load-state/save-load-state.cpp @@ -45,8 +45,8 @@ int main(int argc, char ** argv) { // save state (rng, logits, embedding and kv_cache) to file { - std::vector state_mem(llama_get_state_size(ctx)); - const size_t written = llama_copy_state_data(ctx, state_mem.data()); + std::vector state_mem(llama_state_get_size(ctx)); + const size_t written = llama_state_get_data(ctx, state_mem.data()); FILE *fp_write = fopen("dump_state.bin", "wb"); fwrite(state_mem.data(), 1, written, fp_write); @@ -98,13 +98,13 @@ int main(int argc, char ** argv) { // load state (rng, logits, embedding and kv_cache) from file { - std::vector state_mem(llama_get_state_size(ctx2)); + std::vector state_mem(llama_state_get_size(ctx2)); FILE * fp_read = fopen("dump_state.bin", "rb"); const size_t read = fread(state_mem.data(), 1, state_mem.size(), fp_read); fclose(fp_read); - if (read != llama_set_state_data(ctx2, state_mem.data())) { + if (read != llama_state_set_data(ctx2, state_mem.data())) { fprintf(stderr, "\n%s : failed to read state\n", __func__); llama_free(ctx2); llama_free_model(model); @@ -158,13 +158,13 @@ int main(int argc, char ** argv) { // load state (rng, logits, embedding and kv_cache) from file { - std::vector state_mem(llama_get_state_size(ctx3)); + std::vector state_mem(llama_state_get_size(ctx3)); FILE * fp_read = fopen("dump_state.bin", "rb"); const size_t read = fread(state_mem.data(), 1, state_mem.size(), fp_read); fclose(fp_read); - if (read != llama_set_state_data(ctx3, state_mem.data())) { + if (read != llama_state_set_data(ctx3, state_mem.data())) { fprintf(stderr, "\n%s : failed to read state\n", __func__); llama_free(ctx3); llama_free_model(model); diff --git a/llama.cpp b/llama.cpp index 69f5324af61c8..1d11e8822c001 100644 --- a/llama.cpp +++ b/llama.cpp @@ -14570,7 +14570,7 @@ void llama_kv_cache_update(struct llama_context * ctx) { // Returns the *maximum* size of the state -size_t llama_get_state_size(const struct llama_context * ctx) { +size_t llama_state_get_size(const struct llama_context * ctx) { const auto & cparams = ctx->cparams; const auto & hparams = ctx->model.hparams; @@ -14658,15 +14658,15 @@ struct llama_data_file_context : llama_data_context { * file context: * llama_file file("/path", "wb"); * llama_data_file_context data_ctx(&file); - * llama_copy_state_data(ctx, &data_ctx); + * llama_state_get_data(ctx, &data_ctx); * * buffer context: * std::vector buf(max_size, 0); * llama_data_buffer_context data_ctx(&buf.data()); - * llama_copy_state_data(ctx, &data_ctx); + * llama_state_get_data(ctx, &data_ctx); * */ -static void llama_copy_state_data_internal(struct llama_context * ctx, llama_data_context * data_ctx) { +static void llama_state_get_data_internal(struct llama_context * ctx, llama_data_context * data_ctx) { // copy rng { std::ostringstream rng_ss; @@ -14810,15 +14810,15 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat } } -size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst) { +size_t llama_state_get_data(struct llama_context * ctx, uint8_t * dst) { llama_data_buffer_context data_ctx(dst); - llama_copy_state_data_internal(ctx, &data_ctx); + llama_state_get_data_internal(ctx, &data_ctx); return data_ctx.get_size_written(); } // Sets the state reading from the specified source address -size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) { +size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src) { const uint8_t * inp = src; // set rng @@ -14970,14 +14970,14 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) { } const size_t nread = inp - src; - const size_t max_size = llama_get_state_size(ctx); + const size_t max_size = llama_state_get_size(ctx); GGML_ASSERT(nread <= max_size); return nread; } -static bool llama_load_session_file_internal(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { +static bool llama_state_load_file_internal(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { llama_file file(path_session, "rb"); // sanity checks @@ -15015,7 +15015,7 @@ static bool llama_load_session_file_internal(struct llama_context * ctx, const c // restore the context state { const size_t n_state_size_cur = file.size - file.tell(); - const size_t n_state_size_max = llama_get_state_size(ctx); + const size_t n_state_size_max = llama_state_get_size(ctx); if (n_state_size_cur > n_state_size_max) { LLAMA_LOG_ERROR("%s : the state size in session file is too big! max %zu, got %zu\n", __func__, n_state_size_max, n_state_size_cur); @@ -15025,22 +15025,22 @@ static bool llama_load_session_file_internal(struct llama_context * ctx, const c std::vector state_data(n_state_size_max); file.read_raw(state_data.data(), n_state_size_cur); - llama_set_state_data(ctx, state_data.data()); + llama_state_set_data(ctx, state_data.data()); } return true; } -bool llama_load_session_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { +bool llama_state_load_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { try { - return llama_load_session_file_internal(ctx, path_session, tokens_out, n_token_capacity, n_token_count_out); + return llama_state_load_file_internal(ctx, path_session, tokens_out, n_token_capacity, n_token_count_out); } catch (const std::exception & err) { LLAMA_LOG_ERROR("error loading session file: %s\n", err.what()); return false; } } -bool llama_save_session_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) { +bool llama_state_save_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) { llama_file file(path_session, "wb"); file.write_u32(LLAMA_SESSION_MAGIC); @@ -15054,7 +15054,7 @@ bool llama_save_session_file(struct llama_context * ctx, const char * path_sessi // save the context state using stream saving llama_data_file_context data_ctx(&file); - llama_copy_state_data_internal(ctx, &data_ctx); + llama_state_get_data_internal(ctx, &data_ctx); return true; } diff --git a/llama.h b/llama.h index 375435acd3fe2..6290c1d6c212f 100644 --- a/llama.h +++ b/llama.h @@ -594,30 +594,30 @@ extern "C" { // Returns the maximum size in bytes of the state (rng, logits, embedding // and kv_cache) - will often be smaller after compacting tokens - LLAMA_API size_t llama_get_state_size(const struct llama_context * ctx); + LLAMA_API size_t llama_state_get_size(const struct llama_context * ctx); // Copies the state to the specified destination address. // Destination needs to have allocated enough memory. // Returns the number of bytes copied - LLAMA_API size_t llama_copy_state_data( + LLAMA_API size_t llama_state_get_data( struct llama_context * ctx, uint8_t * dst); // Set the state reading from the specified address // Returns the number of bytes read - LLAMA_API size_t llama_set_state_data( + LLAMA_API size_t llama_state_set_data( struct llama_context * ctx, const uint8_t * src); // Save/load session file - LLAMA_API bool llama_load_session_file( + LLAMA_API bool llama_state_load_file( struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out); - LLAMA_API bool llama_save_session_file( + LLAMA_API bool llama_state_save_file( struct llama_context * ctx, const char * path_session, const llama_token * tokens, From bbcbf47b6dcf419cae192375cea1a6628f3a18e6 Mon Sep 17 00:00:00 2001 From: Jan Boon Date: Sat, 30 Mar 2024 01:53:51 +0800 Subject: [PATCH 10/33] add previous function names back in with DEPRECATED notice --- llama.cpp | 24 ++++++++++++++++++++++++ llama.h | 23 +++++++++++++++++++++++ 2 files changed, 47 insertions(+) diff --git a/llama.cpp b/llama.cpp index 1d11e8822c001..ffbac73824a07 100644 --- a/llama.cpp +++ b/llama.cpp @@ -14568,6 +14568,30 @@ void llama_kv_cache_update(struct llama_context * ctx) { llama_kv_cache_update_internal(*ctx); } +// deprecated +size_t llama_get_state_size(const struct llama_context * ctx) { + return llama_state_get_size(ctx); +} + +// deprecated +size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst) { + return llama_state_get_data(ctx, dst); +} + +// deprecated +size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) { + return llama_state_set_data(ctx, src); +} + +// deprecated +bool llama_load_session_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { + return llama_state_load_file(ctx, path_session, tokens_out, n_token_capacity, n_token_count_out); +} + +// deprecated +bool llama_save_session_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) { + return llama_state_save_file(ctx, path_session, tokens, n_token_count); +} // Returns the *maximum* size of the state size_t llama_state_get_size(const struct llama_context * ctx) { diff --git a/llama.h b/llama.h index 6290c1d6c212f..f3e0c00229f48 100644 --- a/llama.h +++ b/llama.h @@ -595,6 +595,8 @@ extern "C" { // Returns the maximum size in bytes of the state (rng, logits, embedding // and kv_cache) - will often be smaller after compacting tokens LLAMA_API size_t llama_state_get_size(const struct llama_context * ctx); + LLAMA_API DEPRECATED(size_t llama_get_state_size(const struct llama_context * ctx), + "use llama_state_get_size instead"); // Copies the state to the specified destination address. // Destination needs to have allocated enough memory. @@ -602,12 +604,20 @@ extern "C" { LLAMA_API size_t llama_state_get_data( struct llama_context * ctx, uint8_t * dst); + LLAMA_API DEPRECATED(size_t llama_copy_state_data( + struct llama_context * ctx, + uint8_t * dst), + "use llama_state_get_data instead"); // Set the state reading from the specified address // Returns the number of bytes read LLAMA_API size_t llama_state_set_data( struct llama_context * ctx, const uint8_t * src); + LLAMA_API DEPRECATED(size_t llama_set_state_data( + struct llama_context * ctx, + const uint8_t * src), + "use llama_state_set_data instead"); // Save/load session file LLAMA_API bool llama_state_load_file( @@ -616,12 +626,25 @@ extern "C" { llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out); + LLAMA_API DEPRECATED(bool llama_load_session_file( + struct llama_context * ctx, + const char * path_session, + llama_token * tokens_out, + size_t n_token_capacity, + size_t * n_token_count_out), + "use llama_state_load_file instead"); LLAMA_API bool llama_state_save_file( struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count); + LLAMA_API DEPRECATED(bool llama_save_session_file( + struct llama_context * ctx, + const char * path_session, + const llama_token * tokens, + size_t n_token_count), + "use llama_state_save_file instead"); LLAMA_API size_t llama_state_seq_get_size( struct llama_context * ctx, From 8b5ae299ecd625d793591f7e14fa1e3014f84d8a Mon Sep 17 00:00:00 2001 From: Jan Boon Date: Sat, 30 Mar 2024 02:03:28 +0800 Subject: [PATCH 11/33] update doc --- README.md | 1 + llama.h | 4 +++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 5517bf0939df6..6cd05be6a2660 100644 --- a/README.md +++ b/README.md @@ -10,6 +10,7 @@ Inference of Meta's [LLaMA](https://arxiv.org/abs/2302.13971) model (and others) ### Recent API changes +- [2024 Mar 30] State and session file functions reorganized under `llama_state_*` https://github.com/ggerganov/llama.cpp/pull/6341 - [2024 Mar 26] Logits and embeddings API updated for compactness https://github.com/ggerganov/llama.cpp/pull/6122 - [2024 Mar 13] Add `llama_synchronize()` + `llama_context_params.n_ubatch` https://github.com/ggerganov/llama.cpp/pull/6017 - [2024 Mar 8] `llama_kv_cache_seq_rm()` returns a `bool` instead of `void`, and new `llama_n_seq_max()` returns the upper limit of acceptable `seq_id` in batches (relevant when dealing with multiple sequences) https://github.com/ggerganov/llama.cpp/pull/5328 diff --git a/llama.h b/llama.h index f3e0c00229f48..3c313b884d62a 100644 --- a/llama.h +++ b/llama.h @@ -646,16 +646,18 @@ extern "C" { size_t n_token_count), "use llama_state_save_file instead"); + // Get the exact size needed to copy the KV cache of a single sequence LLAMA_API size_t llama_state_seq_get_size( struct llama_context * ctx, llama_seq_id seq_id); + // Copy the KV cache of a single sequence into the specified buffer LLAMA_API size_t llama_state_seq_get_data( struct llama_context * ctx, uint8_t * dst, llama_seq_id seq_id); - // Copy the sequence data (originally copied with `llama_state_seq_get_data`) into a sequence. + // Copy the sequence data (originally copied with `llama_state_seq_get_data`) into the specified sequence // Returns: // - Positive: Ok // - Zero: Failed to load From a71ec3db7b3921de0b70cf344af574e7991e9f29 Mon Sep 17 00:00:00 2001 From: Jan Boon Date: Sat, 30 Mar 2024 02:19:50 +0800 Subject: [PATCH 12/33] adjust endpoints to preferred style --- examples/server/server.cpp | 47 ++++++++++++++++++++++++-------------- 1 file changed, 30 insertions(+), 17 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 28d8b31d9b6ef..27d8f2d10f9b8 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -3314,11 +3314,8 @@ int main(int argc, char ** argv) { res.status = 200; // HTTP OK }; - const auto handle_slot_save = [&ctx_server, &res_error, &sparams](const httplib::Request & req, httplib::Response & res) { - res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); - + const auto handle_slots_save = [&ctx_server, &res_error, &sparams](const httplib::Request & req, httplib::Response & res, int id_slot) { json request_data = json::parse(req.body); - int id_slot = request_data["id_slot"]; std::string filename = request_data["filename"]; if (filename.find('/') != std::string::npos || filename.find('\\') != std::string::npos || filename.find("..") != std::string::npos) { res_error(res, "Invalid filename"); @@ -3347,11 +3344,8 @@ int main(int argc, char ** argv) { } }; - const auto handle_slot_restore = [&ctx_server, &res_error, &sparams](const httplib::Request & req, httplib::Response & res) { - res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); - + const auto handle_slots_restore = [&ctx_server, &res_error, &sparams](const httplib::Request & req, httplib::Response & res, int id_slot) { json request_data = json::parse(req.body); - int id_slot = request_data["id_slot"]; std::string filename = request_data["filename"]; if (filename.find('/') != std::string::npos || filename.find('\\') != std::string::npos || filename.find("..") != std::string::npos) { res_error(res, "Invalid filename"); @@ -3380,12 +3374,7 @@ int main(int argc, char ** argv) { } }; - const auto handle_slot_erase = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) { - res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); - - json request_data = json::parse(req.body); - int id_slot = request_data["id_slot"]; - + const auto handle_slots_erase = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res, int id_slot) { server_task task; task.type = SERVER_TASK_TYPE_SLOT_ERASE; task.data = { @@ -3405,6 +3394,32 @@ int main(int argc, char ** argv) { } }; + const auto handle_slots_action = [&ctx_server, &res_error, &sparams, &handle_slots_save, &handle_slots_restore, &handle_slots_erase](const httplib::Request & req, httplib::Response & res) { + res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); + + std::string id_slot_str = req.path_params.at("id_slot"); + int id_slot; + + try { + id_slot = std::stoi(id_slot_str); + } catch (const std::exception &) { + res_error(res, "Invalid slot ID"); + return; + } + + std::string action = req.get_param_value("action"); + + if (action == "save") { + handle_slots_save(req, res, id_slot); + } else if (action == "restore") { + handle_slots_restore(req, res, id_slot); + } else if (action == "erase") { + handle_slots_erase(req, res, id_slot); + } else { + res_error(res, "Invalid action"); + } + }; + const auto handle_props = [&ctx_server](const httplib::Request & req, httplib::Response & res) { res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); json data = { @@ -3769,9 +3784,7 @@ int main(int argc, char ** argv) { svr->Post("/detokenize", handle_detokenize); if (!sparams.slot_save_path.empty()) { // only enable slot endpoints if slot_save_path is set - svr->Post("/slot/save", handle_slot_save); - svr->Post("/slot/restore", handle_slot_restore); - svr->Post("/slot/erase", handle_slot_erase); + svr->Post("/slots/:id_slot", handle_slots_action); } // From bf1d4932f8a4748bbc1ac11e4a63cc703078d24a Mon Sep 17 00:00:00 2001 From: Jan Boon Date: Sat, 30 Mar 2024 02:35:33 +0800 Subject: [PATCH 13/33] fix restoring zero cell count --- llama.cpp | 32 ++++++++++++++++++-------------- 1 file changed, 18 insertions(+), 14 deletions(-) diff --git a/llama.cpp b/llama.cpp index ffbac73824a07..f708d690cda1b 100644 --- a/llama.cpp +++ b/llama.cpp @@ -15283,7 +15283,7 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, } // Allocate the new cells for the slot - { + if (cell_count) { llama_batch batch = llama_batch_init(cell_count, 0, 1); batch.n_tokens = cell_count; for (uint32_t i = 0; i < cell_count; ++i) { @@ -15327,9 +15327,11 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, return 0; } - // Read and set the keys for the whole cell range - ggml_backend_tensor_set(kv_self.k_l[il], inp, kv_head * k_size_row, cell_count * k_size_row); - inp += cell_count * k_size_row; + if (cell_count) { + // Read and set the keys for the whole cell range + ggml_backend_tensor_set(kv_self.k_l[il], inp, kv_head * k_size_row, cell_count * k_size_row); + inp += cell_count * k_size_row; + } } // For each layer, read the values for each cell (transposed) @@ -15339,17 +15341,19 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, memcpy(&v_size_el_ref, inp, sizeof(v_size_el_ref)); inp += sizeof(v_size_el_ref); - const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type); - if (v_size_el != v_size_el_ref) { - llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1); - return 0; - } + if (cell_count) { + const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type); + if (v_size_el != v_size_el_ref) { + llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1); + return 0; + } - // For each row in the transposed matrix, read the values for the whole cell range - for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { - const size_t dst_offset = (kv_head + j * kv_size) * v_size_el; - ggml_backend_tensor_set(kv_self.v_l[il], inp, dst_offset, cell_count * v_size_el); - inp += cell_count * v_size_el; + // For each row in the transposed matrix, read the values for the whole cell range + for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { + const size_t dst_offset = (kv_head + j * kv_size) * v_size_el; + ggml_backend_tensor_set(kv_self.v_l[il], inp, dst_offset, cell_count * v_size_el); + inp += cell_count * v_size_el; + } } } From 8ab1a17251312a16704e8ba902eacfaa402ddbb0 Mon Sep 17 00:00:00 2001 From: Jan Boon Date: Sat, 30 Mar 2024 02:39:33 +0800 Subject: [PATCH 14/33] handle seq rm return value --- examples/server/server.cpp | 5 ++++- llama.cpp | 4 +++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 27d8f2d10f9b8..f08cad382fb56 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1745,7 +1745,10 @@ struct server_context { // Erase token cache const size_t n_erased = slot->cache_tokens.size(); - llama_kv_cache_seq_rm(ctx, slot->id + 1, -1, -1); + if (!llama_kv_cache_seq_rm(ctx, slot->id + 1, -1, -1)) { + send_error(task, "Failed to erase slot KV cache", ERROR_TYPE_INVALID_REQUEST); + break; + } slot->cache_tokens.clear(); server_task_result result; diff --git a/llama.cpp b/llama.cpp index f708d690cda1b..c9ebd58f248f2 100644 --- a/llama.cpp +++ b/llama.cpp @@ -15243,7 +15243,9 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, GGML_ASSERT(!kv_self.recurrent); // not implemented // Wipe the slot - llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1); + if (!llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1)) { + return 0; + } const uint8_t * inp = src; From 0d2213678ce22044290f9dacd7c7ff0c7008743c Mon Sep 17 00:00:00 2001 From: Jan Boon Date: Sat, 30 Mar 2024 02:41:28 +0800 Subject: [PATCH 15/33] unused param --- examples/server/server.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index f08cad382fb56..20152d50dd9f0 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -3377,7 +3377,7 @@ int main(int argc, char ** argv) { } }; - const auto handle_slots_erase = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res, int id_slot) { + const auto handle_slots_erase = [&ctx_server, &res_error](const httplib::Request & /* req */, httplib::Response & res, int id_slot) { server_task task; task.type = SERVER_TASK_TYPE_SLOT_ERASE; task.data = { From 29f18c29b441ff289e4d851e6fe1a9d4bfd8d7ff Mon Sep 17 00:00:00 2001 From: Jan Boon Date: Sat, 30 Mar 2024 02:49:31 +0800 Subject: [PATCH 16/33] keep in the size check --- llama.cpp | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/llama.cpp b/llama.cpp index c9ebd58f248f2..ef1f3c7c6e89d 100644 --- a/llama.cpp +++ b/llama.cpp @@ -15343,13 +15343,13 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, memcpy(&v_size_el_ref, inp, sizeof(v_size_el_ref)); inp += sizeof(v_size_el_ref); - if (cell_count) { - const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type); - if (v_size_el != v_size_el_ref) { - llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1); - return 0; - } + const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type); + if (v_size_el != v_size_el_ref) { + llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1); + return 0; + } + if (cell_count) { // For each row in the transposed matrix, read the values for the whole cell range for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { const size_t dst_offset = (kv_head + j * kv_size) * v_size_el; From f2e41b32393a8e33ba71dbec9081697c4b334591 Mon Sep 17 00:00:00 2001 From: Jan Boon Date: Sat, 30 Mar 2024 06:01:38 +0800 Subject: [PATCH 17/33] fix return types --- examples/server/server.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 20152d50dd9f0..542598532dee9 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -3321,7 +3321,7 @@ int main(int argc, char ** argv) { json request_data = json::parse(req.body); std::string filename = request_data["filename"]; if (filename.find('/') != std::string::npos || filename.find('\\') != std::string::npos || filename.find("..") != std::string::npos) { - res_error(res, "Invalid filename"); + res_error(res, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST)); return; } std::string filepath = sparams.slot_save_path + filename; @@ -3351,7 +3351,7 @@ int main(int argc, char ** argv) { json request_data = json::parse(req.body); std::string filename = request_data["filename"]; if (filename.find('/') != std::string::npos || filename.find('\\') != std::string::npos || filename.find("..") != std::string::npos) { - res_error(res, "Invalid filename"); + res_error(res, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST)); return; } std::string filepath = sparams.slot_save_path + filename; @@ -3393,7 +3393,7 @@ int main(int argc, char ** argv) { if (result.error) { res_error(res, result.data); } else { - res.set_content(result.data, "application/json"); + res.set_content(result.data.dump(), "application/json"); } }; @@ -3406,7 +3406,7 @@ int main(int argc, char ** argv) { try { id_slot = std::stoi(id_slot_str); } catch (const std::exception &) { - res_error(res, "Invalid slot ID"); + res_error(res, format_error_response("Invalid slot ID", ERROR_TYPE_INVALID_REQUEST)); return; } @@ -3419,7 +3419,7 @@ int main(int argc, char ** argv) { } else if (action == "erase") { handle_slots_erase(req, res, id_slot); } else { - res_error(res, "Invalid action"); + res_error(res, format_error_response("Invalid action", ERROR_TYPE_INVALID_REQUEST)); } }; From 92c468105bdb871b8a1150a81f25d02609566705 Mon Sep 17 00:00:00 2001 From: Jan Boon Date: Sat, 30 Mar 2024 06:03:41 +0800 Subject: [PATCH 18/33] add server test case for slot save restore --- .../server/tests/features/slotsave.feature | 48 +++++++++++++++ examples/server/tests/features/steps/steps.py | 60 +++++++++++++++++++ 2 files changed, 108 insertions(+) create mode 100644 examples/server/tests/features/slotsave.feature diff --git a/examples/server/tests/features/slotsave.feature b/examples/server/tests/features/slotsave.feature new file mode 100644 index 0000000000000..37eefd5c073c1 --- /dev/null +++ b/examples/server/tests/features/slotsave.feature @@ -0,0 +1,48 @@ +@llama.cpp +@server +Feature: llama.cpp server slot management + + Background: Server startup + Given a server listening on localhost:8080 + And a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models + And prompt caching is enabled + And 2 slots + And . as slot save path + And 2048 KV cache size + And 42 as server seed + And 24 max tokens to predict + Then the server is starting + Then the server is healthy + + Scenario: Save and Restore Slot + Given a user prompt "What is the capital of France?" + And using slot id 1 + And a completion request with no api error + Then 24 tokens are predicted matching Lily + And 22 prompt tokens are processed + When the slot 1 is saved with filename "slot1.bin" + Then the server responds with status code 200 + Given a user prompt "What is the capital of Germany?" + And a completion request with no api error + Then 24 tokens are predicted matching Thank + And 7 prompt tokens are processed + When the slot 2 is restored with filename "slot1.bin" + Then the server responds with status code 200 + Given a user prompt "What is the capital of France?" + And using slot id 2 + And a completion request with no api error + Then 24 tokens are predicted matching Lily + And 1 prompt tokens are processed + + Scenario: Erase Slot + Given a user prompt "What is the capital of France?" + And using slot id 1 + And a completion request with no api error + Then 24 tokens are predicted matching Lily + And 22 prompt tokens are processed + When the slot 1 is erased + Then the server responds with status code 200 + Given a user prompt "What is the capital of France?" + And a completion request with no api error + Then 24 tokens are predicted matching Lily + And 22 prompt tokens are processed diff --git a/examples/server/tests/features/steps/steps.py b/examples/server/tests/features/steps/steps.py index 86c3339dc7183..0dcf6a86653bc 100644 --- a/examples/server/tests/features/steps/steps.py +++ b/examples/server/tests/features/steps/steps.py @@ -49,6 +49,9 @@ def step_server_config(context, server_fqdn, server_port): context.n_predict = None context.n_prompts = 0 context.n_server_predict = None + context.slot_save_path = None + context.id_slot = None + context.cache_prompt = None context.n_slots = None context.prompt_prefix = None context.prompt_suffix = None @@ -119,6 +122,21 @@ def step_server_n_predict(context, n_predict): context.n_server_predict = n_predict +@step('{slot_save_path} as slot save path') +def step_slot_save_path(context, slot_save_path): + context.slot_save_path = slot_save_path + + +@step('using slot id {id_slot:d}') +def step_id_slot(context, id_slot): + context.id_slot = id_slot + + +@step('prompt caching is enabled') +def step_enable_prompt_cache(context): + context.cache_prompt = True + + @step('continuous batching') def step_server_continuous_batching(context): context.server_continuous_batching = True @@ -212,6 +230,8 @@ async def step_request_completion(context, api_error): context.base_url, debug=context.debug, n_predict=context.n_predict, + cache_prompt=context.cache_prompt, + id_slot=context.id_slot, seed=await completions_seed(context), expect_api_error=expect_api_error, user_api_key=context.user_api_key) @@ -711,12 +731,48 @@ async def concurrent_requests(context, f_completion, *args, **kwargs): await asyncio.sleep(0.1) +@step('the slot {slot_id:d} is saved with filename "{filename}"') +@async_run_until_complete +async def step_save_slot(context, slot_id, filename): + async with aiohttp.ClientSession() as session: + async with session.post(f'{context.base_url}/slots/{slot_id}?action=save', + json={"filename": filename}, + headers={"Content-Type": "application/json"}) as response: + context.response = response + + +@step('the slot {slot_id:d} is restored with filename "{filename}"') +@async_run_until_complete +async def step_restore_slot(context, slot_id, filename): + async with aiohttp.ClientSession() as session: + async with session.post(f'{context.base_url}/slots/{slot_id}?action=restore', + json={"filename": filename}, + headers={"Content-Type": "application/json"}) as response: + context.response = response + + +@step('the slot {slot_id:d} is erased') +@async_run_until_complete +async def step_erase_slot(context, slot_id): + async with aiohttp.ClientSession() as session: + async with session.post(f'{context.base_url}/slots/{slot_id}?action=erase', + headers={"Content-Type": "application/json"}) as response: + context.response = response + + +@step('the server responds with status code {status_code:d}') +def step_server_responds_with_status_code(context, status_code): + assert context.response.status == status_code + + async def request_completion(prompt, base_url, debug=False, prompt_prefix=None, prompt_suffix=None, n_predict=None, + cache_prompt=False, + id_slot=None, seed=None, expect_api_error=None, user_api_key=None): @@ -738,6 +794,8 @@ async def request_completion(prompt, "prompt": prompt, "input_suffix": prompt_suffix, "n_predict": n_predict if n_predict is not None else -1, + "cache_prompt": cache_prompt, + "id_slot": id_slot, "seed": seed if seed is not None else 42 }, headers=headers, @@ -1104,6 +1162,8 @@ def start_server_background(context): server_args.extend(['--parallel', context.n_slots]) if context.n_server_predict: server_args.extend(['--n-predict', context.n_server_predict]) + if context.slot_save_path: + server_args.extend(['--slot-save-path', context.slot_save_path]) if context.server_api_key: server_args.extend(['--api-key', context.server_api_key]) if context.n_ga: From 60f685ff7a2986448642b5dbd7f4df6e5d0cac8d Mon Sep 17 00:00:00 2001 From: Jan Boon Date: Sat, 30 Mar 2024 06:14:33 +0800 Subject: [PATCH 19/33] cleanup --- examples/server/server.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 542598532dee9..fdd6ff01f88ac 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -3397,7 +3397,7 @@ int main(int argc, char ** argv) { } }; - const auto handle_slots_action = [&ctx_server, &res_error, &sparams, &handle_slots_save, &handle_slots_restore, &handle_slots_erase](const httplib::Request & req, httplib::Response & res) { + const auto handle_slots_action = [&res_error, &handle_slots_save, &handle_slots_restore, &handle_slots_erase](const httplib::Request & req, httplib::Response & res) { res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); std::string id_slot_str = req.path_params.at("id_slot"); From d38eef468f6a2178697d35026b59836c0eae8f7d Mon Sep 17 00:00:00 2001 From: Jan Boon Date: Sat, 30 Mar 2024 23:23:21 +0800 Subject: [PATCH 20/33] add cake --- .../server/tests/features/slotsave.feature | 24 +++++++++++++------ 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/examples/server/tests/features/slotsave.feature b/examples/server/tests/features/slotsave.feature index 37eefd5c073c1..9f1e58d23cbba 100644 --- a/examples/server/tests/features/slotsave.feature +++ b/examples/server/tests/features/slotsave.feature @@ -1,5 +1,5 @@ @llama.cpp -@server +@slotsave Feature: llama.cpp server slot management Background: Server startup @@ -15,34 +15,44 @@ Feature: llama.cpp server slot management Then the server is healthy Scenario: Save and Restore Slot + # First prompt in slot 1 should be fully processed Given a user prompt "What is the capital of France?" And using slot id 1 And a completion request with no api error - Then 24 tokens are predicted matching Lily + Then 24 tokens are predicted matching (Lily|cake) And 22 prompt tokens are processed When the slot 1 is saved with filename "slot1.bin" Then the server responds with status code 200 + # Since we have cache, this should only process the last tokens Given a user prompt "What is the capital of Germany?" And a completion request with no api error Then 24 tokens are predicted matching Thank And 7 prompt tokens are processed - When the slot 2 is restored with filename "slot1.bin" + # Loading the original cache into slot 0, + # we should only be processing 1 prompt token and get the same output + When the slot 0 is restored with filename "slot1.bin" Then the server responds with status code 200 Given a user prompt "What is the capital of France?" - And using slot id 2 + And using slot id 0 And a completion request with no api error - Then 24 tokens are predicted matching Lily + Then 24 tokens are predicted matching (Lily|cake) + And 1 prompt tokens are processed + # For verification that slot 1 was not corrupted during slot 0 load, same thing + Given a user prompt "What is the capital of Germany?" + And using slot id 1 + And a completion request with no api error + Then 24 tokens are predicted matching Thank And 1 prompt tokens are processed Scenario: Erase Slot Given a user prompt "What is the capital of France?" And using slot id 1 And a completion request with no api error - Then 24 tokens are predicted matching Lily + Then 24 tokens are predicted matching (Lily|cake) And 22 prompt tokens are processed When the slot 1 is erased Then the server responds with status code 200 Given a user prompt "What is the capital of France?" And a completion request with no api error - Then 24 tokens are predicted matching Lily + Then 24 tokens are predicted matching (Lily|cake) And 22 prompt tokens are processed From ea717f773edfbbcb062dea1af3ae3de06fb593cb Mon Sep 17 00:00:00 2001 From: Jan Boon Date: Sat, 30 Mar 2024 23:39:53 +0800 Subject: [PATCH 21/33] cleanup style --- llama.cpp | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/llama.cpp b/llama.cpp index ef1f3c7c6e89d..ac8703ca2d770 100644 --- a/llama.cpp +++ b/llama.cpp @@ -15094,15 +15094,15 @@ size_t llama_state_seq_get_size(struct llama_context* ctx, llama_seq_id seq_id) size_t s_cell_count = 0; size_t s_cell_data_size = 0; - const auto& kv_self = ctx->kv_self; - const auto& hparams = ctx->model.hparams; + const auto & kv_self = ctx->kv_self; + const auto & hparams = ctx->model.hparams; const uint32_t n_layer = hparams.n_layer; const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s(); const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s(); for (uint32_t i = 0; i < kv_self.size; ++i) { - const auto& cell = kv_self.cells[i]; + const auto & cell = kv_self.cells[i]; if (cell.seq_id.count(seq_id) > 0) { ++s_cell_count; s_cell_data_size += sizeof(llama_pos); @@ -15135,7 +15135,7 @@ size_t llama_state_seq_get_size(struct llama_context* ctx, llama_seq_id seq_id) size_t llama_state_seq_get_data(struct llama_context * ctx, uint8_t * dst, llama_seq_id seq_id) { llama_data_buffer_context data_ctx(dst); - const auto& kv_self = ctx->kv_self; + const auto & kv_self = ctx->kv_self; GGML_ASSERT(!kv_self.recurrent); // not implemented // Save the size of size_t as a uint32_t for safety check @@ -15150,7 +15150,7 @@ size_t llama_state_seq_get_data(struct llama_context * ctx, uint8_t * dst, llama { uint32_t cell_range_begin = kv_self.size; for (uint32_t i = 0; i < kv_self.size; ++i) { - const auto& cell = kv_self.cells[i]; + const auto & cell = kv_self.cells[i]; if (cell.has_seq_id(seq_id)) { ++cell_count; if (cell_range_begin == kv_self.size) { @@ -15170,7 +15170,7 @@ size_t llama_state_seq_get_data(struct llama_context * ctx, uint8_t * dst, llama // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count uint32_t cell_count_check = 0; - for (const auto& range : cell_ranges) { + for (const auto & range : cell_ranges) { cell_count_check += range.second - range.first; } GGML_ASSERT(cell_count == cell_count_check); @@ -15211,7 +15211,7 @@ size_t llama_state_seq_get_data(struct llama_context * ctx, uint8_t * dst, llama const size_t range_size = range.second - range.first; tmp_buf.resize(range_size * k_size_row); ggml_backend_tensor_get(kv_self.k_l[il], tmp_buf.data(), range.first * k_size_row, range_size * k_size_row); - data_ctx.write(&tmp_buf[0], tmp_buf.size()); + data_ctx.write(tmp_buf.data(), tmp_buf.size()); } } @@ -15230,7 +15230,7 @@ size_t llama_state_seq_get_data(struct llama_context * ctx, uint8_t * dst, llama const size_t src_offset = (range.first + j * kv_size) * v_size_el; tmp_buf.resize(range_size * v_size_el); ggml_backend_tensor_get(kv_self.v_l[il], tmp_buf.data(), src_offset, tmp_buf.size()); - data_ctx.write(&tmp_buf[0], tmp_buf.size()); + data_ctx.write(tmp_buf.data(), tmp_buf.size()); } } } @@ -15273,7 +15273,7 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, inp += sizeof(n_embd_v_gqa_ref); // Sanity check model compatibility - const auto& hparams = ctx->model.hparams; + const auto & hparams = ctx->model.hparams; const uint32_t n_layer = hparams.n_layer; const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s(); const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s(); From b509b8b3de4bb89eb3803d0d939a6519a41c929a Mon Sep 17 00:00:00 2001 From: Jan Boon Date: Sat, 30 Mar 2024 23:57:38 +0800 Subject: [PATCH 22/33] add special --- examples/server/tests/features/slotsave.feature | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/server/tests/features/slotsave.feature b/examples/server/tests/features/slotsave.feature index 9f1e58d23cbba..1c281c0741afe 100644 --- a/examples/server/tests/features/slotsave.feature +++ b/examples/server/tests/features/slotsave.feature @@ -26,7 +26,7 @@ Feature: llama.cpp server slot management # Since we have cache, this should only process the last tokens Given a user prompt "What is the capital of Germany?" And a completion request with no api error - Then 24 tokens are predicted matching Thank + Then 24 tokens are predicted matching (Thank|special) And 7 prompt tokens are processed # Loading the original cache into slot 0, # we should only be processing 1 prompt token and get the same output @@ -41,7 +41,7 @@ Feature: llama.cpp server slot management Given a user prompt "What is the capital of Germany?" And using slot id 1 And a completion request with no api error - Then 24 tokens are predicted matching Thank + Then 24 tokens are predicted matching (Thank|special) And 1 prompt tokens are processed Scenario: Erase Slot From 129b6ffea63120e53d413a4828b8e85ab86deff3 Mon Sep 17 00:00:00 2001 From: Jan Boon Date: Sun, 31 Mar 2024 00:43:47 +0800 Subject: [PATCH 23/33] removing a whole sequence never fails --- examples/server/server.cpp | 5 +---- llama.cpp | 4 +--- llama.h | 1 + 3 files changed, 3 insertions(+), 7 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index fdd6ff01f88ac..adcfa79f9442e 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1745,10 +1745,7 @@ struct server_context { // Erase token cache const size_t n_erased = slot->cache_tokens.size(); - if (!llama_kv_cache_seq_rm(ctx, slot->id + 1, -1, -1)) { - send_error(task, "Failed to erase slot KV cache", ERROR_TYPE_INVALID_REQUEST); - break; - } + llama_kv_cache_seq_rm(ctx, slot->id + 1, -1, -1); slot->cache_tokens.clear(); server_task_result result; diff --git a/llama.cpp b/llama.cpp index ac8703ca2d770..145942078cf59 100644 --- a/llama.cpp +++ b/llama.cpp @@ -15243,9 +15243,7 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, GGML_ASSERT(!kv_self.recurrent); // not implemented // Wipe the slot - if (!llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1)) { - return 0; - } + llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1); const uint8_t * inp = src; diff --git a/llama.h b/llama.h index 3c313b884d62a..0473f726abeef 100644 --- a/llama.h +++ b/llama.h @@ -523,6 +523,7 @@ extern "C" { struct llama_context * ctx); // Removes all tokens that belong to the specified sequence and have positions in [p0, p1) + // Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails // seq_id < 0 : match any sequence // p0 < 0 : [0, p1] // p1 < 0 : [p0, inf) From 8af72118ec9ef155f942de265e3ba2a80c2e880d Mon Sep 17 00:00:00 2001 From: Jan Boon Date: Sun, 31 Mar 2024 03:26:25 +0800 Subject: [PATCH 24/33] move sequence state file functionality from server to llama to match session api and add version tags --- examples/server/server.cpp | 51 +++--------------------- llama.cpp | 80 +++++++++++++++++++++++++++++++++++++- llama.h | 19 +++++++++ 3 files changed, 102 insertions(+), 48 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index adcfa79f9442e..de05f47b9f995 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1630,26 +1630,8 @@ struct server_context { std::string filename = task.data["filename"]; std::string filepath = task.data["filepath"]; - size_t state_size = llama_state_seq_get_size(ctx, slot->id + 1); - std::vector state_data(state_size + sizeof(size_t) + token_count * sizeof(llama_token)); - size_t nwrite = llama_state_seq_get_data(ctx, state_data.data(), slot->id + 1); - GGML_ASSERT(nwrite <= state_size); - - // write the cached token count of the slot->cache_tokens.size() - memcpy(state_data.data() + nwrite, &token_count, sizeof(size_t)); - nwrite += sizeof(size_t); - - // write the cached tokens (loop) - for (size_t i = 0; i < token_count; i++) { - const llama_token token = slot->cache_tokens[i]; - memcpy(state_data.data() + nwrite, &token, sizeof(llama_token)); - nwrite += sizeof(llama_token); - } - GGML_ASSERT(nwrite <= state_data.size()); - std::ofstream outfile(filepath, std::ios::binary); - outfile.write(reinterpret_cast(state_data.data()), nwrite); - outfile.close(); + const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id + 1, slot->cache_tokens.data(), token_count); const int64_t t_end = ggml_time_us(); const double t_save_ms = (t_end - t_start) / 1000.0; @@ -1682,39 +1664,16 @@ struct server_context { std::string filename = task.data["filename"]; std::string filepath = task.data["filepath"]; - std::ifstream infile(filepath, std::ios::binary); - if (!infile.is_open()) { - send_error(task, "Failed to open file", ERROR_TYPE_INVALID_REQUEST); - break; - } - std::vector state_data((std::istreambuf_iterator(infile)), std::istreambuf_iterator()); - infile.close(); - - size_t nread = llama_state_seq_set_data(ctx, state_data.data(), slot->id + 1); + slot->cache_tokens.resize(slot->n_ctx); + size_t token_count = 0; + size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id + 1, slot->cache_tokens.data(), slot->cache_tokens.size(), &token_count); if (nread == 0) { + slot->cache_tokens.resize(0); send_error(task, "Unable to restore slot, no available space in KV cache or invalid slot save file", ERROR_TYPE_INVALID_REQUEST); break; } - GGML_ASSERT(nread <= state_data.size()); - - // restore cached token values - size_t token_count = 0; - if (nread + sizeof(size_t) <= state_data.size()) { - token_count = *reinterpret_cast(state_data.data() + nread); - nread += sizeof(size_t); - } slot->cache_tokens.resize(token_count); - GGML_ASSERT(nread + (token_count * sizeof(llama_token)) <= state_data.size()); - - // tokens are of type llama_token (an integer) - for (size_t i = 0; i < token_count; i++) { - if (nread + sizeof(llama_token) <= state_data.size()) { - slot->cache_tokens[i] = *reinterpret_cast(state_data.data() + nread); - nread += sizeof(llama_token); - } - } - GGML_ASSERT(nread <= state_data.size()); const int64_t t_end = ggml_time_us(); const double t_restore_ms = (t_end - t_start) / 1000.0; diff --git a/llama.cpp b/llama.cpp index 145942078cf59..5e1747842b2ad 100644 --- a/llama.cpp +++ b/llama.cpp @@ -15133,8 +15133,7 @@ size_t llama_state_seq_get_size(struct llama_context* ctx, llama_seq_id seq_id) return s_total; } -size_t llama_state_seq_get_data(struct llama_context * ctx, uint8_t * dst, llama_seq_id seq_id) { - llama_data_buffer_context data_ctx(dst); +static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llama_data_context & data_ctx, llama_seq_id seq_id) { const auto & kv_self = ctx->kv_self; GGML_ASSERT(!kv_self.recurrent); // not implemented @@ -15238,6 +15237,11 @@ size_t llama_state_seq_get_data(struct llama_context * ctx, uint8_t * dst, llama return data_ctx.get_size_written(); } +size_t llama_state_seq_get_data(struct llama_context* ctx, uint8_t* dst, llama_seq_id seq_id) { + llama_data_buffer_context data_ctx(dst); + return llama_state_seq_get_data_internal(ctx, data_ctx, seq_id); +} + size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, llama_seq_id dest_seq_id) { auto & kv_self = ctx->kv_self; GGML_ASSERT(!kv_self.recurrent); // not implemented @@ -15361,6 +15365,78 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, return nread; } +size_t llama_state_seq_save_file(struct llama_context * ctx, const char * filepath, llama_seq_id seq_id, const llama_token * tokens, size_t n_token_count) { + llama_file file(filepath, "wb"); + + file.write_u32(LLAMA_STATE_SEQ_MAGIC); + file.write_u32(LLAMA_STATE_SEQ_VERSION); + + // save the prompt + file.write_u32((uint32_t)n_token_count); + file.write_raw(tokens, sizeof(llama_token) * n_token_count); + + // save the context state using stream saving + llama_data_file_context data_ctx(&file); + llama_state_seq_get_data_internal(ctx, data_ctx, seq_id); + + const size_t res = file.tell(); + GGML_ASSERT(res == sizeof(uint32_t) * 3 + sizeof(llama_token) * n_token_count + data_ctx.get_size_written()); + return res; +} + +static size_t llama_state_seq_load_file_internal(struct llama_context * ctx, const char * filepath, llama_seq_id dest_seq_id, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { + llama_file file(filepath, "rb"); + + // version checks + { + const uint32_t magic = file.read_u32(); + const uint32_t version = file.read_u32(); + + if (magic != LLAMA_STATE_SEQ_MAGIC || version != LLAMA_STATE_SEQ_VERSION) { + LLAMA_LOG_ERROR("%s : unknown (magic, version) for sequence state file: %08x, %08x\n", __func__, magic, version); + return 0; + } + } + + // load the prompt + { + const uint32_t n_token_count = file.read_u32(); + + if (n_token_count > n_token_capacity) { + LLAMA_LOG_ERROR("%s : token count in sequence state file exceeded capacity! %u > %zu\n", __func__, n_token_count, n_token_capacity); + return 0; + } + + file.read_raw(tokens_out, sizeof(llama_token) * n_token_count); + *n_token_count_out = n_token_count; + } + + // restore the context state + { + const size_t state_size = file.size - file.tell(); + std::vector state_data(state_size); + file.read_raw(state_data.data(), state_size); + const size_t nread = llama_state_seq_set_data(ctx, state_data.data(), dest_seq_id); + if (!nread) { + LLAMA_LOG_ERROR("%s : failed to restore sequence state\n", __func__); + return 0; + } + GGML_ASSERT(nread <= state_size); + GGML_ASSERT(nread + sizeof(uint32_t) * 3 + sizeof(llama_token) * *n_token_count_out == file.tell()); + } + + return file.tell(); +} + +size_t llama_state_seq_load_file(struct llama_context * ctx, const char * filepath, llama_seq_id dest_seq_id, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { + try { + return llama_state_seq_load_file_internal(ctx, filepath, dest_seq_id, tokens_out, n_token_capacity, n_token_count_out); + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("error loading sequence state file: %s\n", err.what()); + return false; + } +} + void llama_set_n_threads(struct llama_context * ctx, uint32_t n_threads, uint32_t n_threads_batch) { ctx->cparams.n_threads = n_threads; ctx->cparams.n_threads_batch = n_threads_batch; diff --git a/llama.h b/llama.h index 0473f726abeef..d6e8b2ca66669 100644 --- a/llama.h +++ b/llama.h @@ -37,10 +37,14 @@ #define LLAMA_FILE_MAGIC_GGLA 0x67676c61u // 'ggla' #define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn' +#define LLAMA_FILE_MAGIC_GGSQ 0x67677371u // 'ggsq' #define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN #define LLAMA_SESSION_VERSION 5 +#define LLAMA_STATE_SEQ_MAGIC LLAMA_FILE_MAGIC_GGSQ +#define LLAMA_STATE_SEQ_VERSION 1 + #ifdef __cplusplus extern "C" { #endif @@ -667,6 +671,21 @@ extern "C" { const uint8_t * src, llama_seq_id dest_seq_id); + LLAMA_API size_t llama_state_seq_save_file( + struct llama_context * ctx, + const char * filepath, + llama_seq_id seq_id, + const llama_token * tokens, + size_t n_token_count); + + LLAMA_API size_t llama_state_seq_load_file( + struct llama_context * ctx, + const char * filepath, + llama_seq_id dest_seq_id, + llama_token * tokens_out, + size_t n_token_capacity, + size_t * n_token_count_out); + // // Decoding // From 3d6fa5bdd72ab49b9a7c5dafaa81d0f981ccb8f1 Mon Sep 17 00:00:00 2001 From: Jan Boon Date: Tue, 2 Apr 2024 04:06:23 +0800 Subject: [PATCH 25/33] catch exceptions on save as well --- llama.cpp | 24 +++++++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/llama.cpp b/llama.cpp index 5e1747842b2ad..615804b56de0d 100644 --- a/llama.cpp +++ b/llama.cpp @@ -15064,7 +15064,7 @@ bool llama_state_load_file(struct llama_context * ctx, const char * path_session } } -bool llama_state_save_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) { +static bool llama_state_save_file_internal(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) { llama_file file(path_session, "wb"); file.write_u32(LLAMA_SESSION_MAGIC); @@ -15083,6 +15083,15 @@ bool llama_state_save_file(struct llama_context * ctx, const char * path_session return true; } +bool llama_state_save_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) { + try { + return llama_state_save_file_internal(ctx, path_session, tokens, n_token_count); + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("error saving session file: %s\n", err.what()); + return false; + } +} + size_t llama_state_seq_get_size(struct llama_context* ctx, llama_seq_id seq_id) { // save the size of size_t as a uint32_t for safety check const size_t size_t_size_size = sizeof(uint32_t); @@ -15365,7 +15374,7 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, return nread; } -size_t llama_state_seq_save_file(struct llama_context * ctx, const char * filepath, llama_seq_id seq_id, const llama_token * tokens, size_t n_token_count) { +static size_t llama_state_seq_save_file_internal(struct llama_context * ctx, const char * filepath, llama_seq_id seq_id, const llama_token * tokens, size_t n_token_count) { llama_file file(filepath, "wb"); file.write_u32(LLAMA_STATE_SEQ_MAGIC); @@ -15428,12 +15437,21 @@ static size_t llama_state_seq_load_file_internal(struct llama_context * ctx, con return file.tell(); } +size_t llama_state_seq_save_file(struct llama_context * ctx, const char * filepath, llama_seq_id seq_id, const llama_token * tokens, size_t n_token_count) { + try { + return llama_state_seq_save_file_internal(ctx, filepath, seq_id, tokens, n_token_count); + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("error saving sequence state file: %s\n", err.what()); + return 0; + } +} + size_t llama_state_seq_load_file(struct llama_context * ctx, const char * filepath, llama_seq_id dest_seq_id, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { try { return llama_state_seq_load_file_internal(ctx, filepath, dest_seq_id, tokens_out, n_token_capacity, n_token_count_out); } catch (const std::exception & err) { LLAMA_LOG_ERROR("error loading sequence state file: %s\n", err.what()); - return false; + return 0; } } From b3f6da3d60d1609a2a50b9c8472cb1530470265d Mon Sep 17 00:00:00 2001 From: Jan Boon Date: Tue, 2 Apr 2024 04:10:17 +0800 Subject: [PATCH 26/33] error log messages --- llama.cpp | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/llama.cpp b/llama.cpp index 615804b56de0d..1dbb9a93dd414 100644 --- a/llama.cpp +++ b/llama.cpp @@ -15265,6 +15265,7 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, memcpy(&size_t_size, inp, sizeof(size_t_size)); inp += sizeof(size_t_size); if (size_t_size != sizeof(size_t)) { + LLAMA_LOG_ERROR("%s: size_t size mismatch\n", __func__); return 0; } @@ -15289,9 +15290,11 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s(); const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s(); if (n_layer != n_layer_ref) { + LLAMA_LOG_ERROR("%s: mismatched n_layer (%d != %d)\n", __func__, n_layer, n_layer_ref); return 0; } if (n_embd_v_gqa != n_embd_v_gqa_ref) { + LLAMA_LOG_ERROR("%s: mismatched n_embd_v_gqa (%d != %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref); return 0; } @@ -15310,6 +15313,7 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, } if (!llama_kv_cache_find_slot(kv_self, batch)) { llama_batch_free(batch); + LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__); return 0; } @@ -15337,6 +15341,7 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, const size_t k_size_row = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa); if (k_size_row != k_size_row_ref) { llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1); + LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, k_size_row_ref, il); return 0; } @@ -15357,6 +15362,7 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type); if (v_size_el != v_size_el_ref) { llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1); + LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, v_size_el_ref, il); return 0; } @@ -15402,7 +15408,7 @@ static size_t llama_state_seq_load_file_internal(struct llama_context * ctx, con const uint32_t version = file.read_u32(); if (magic != LLAMA_STATE_SEQ_MAGIC || version != LLAMA_STATE_SEQ_VERSION) { - LLAMA_LOG_ERROR("%s : unknown (magic, version) for sequence state file: %08x, %08x\n", __func__, magic, version); + LLAMA_LOG_ERROR("%s: unknown (magic, version) for sequence state file: %08x, %08x\n", __func__, magic, version); return 0; } } @@ -15412,7 +15418,7 @@ static size_t llama_state_seq_load_file_internal(struct llama_context * ctx, con const uint32_t n_token_count = file.read_u32(); if (n_token_count > n_token_capacity) { - LLAMA_LOG_ERROR("%s : token count in sequence state file exceeded capacity! %u > %zu\n", __func__, n_token_count, n_token_capacity); + LLAMA_LOG_ERROR("%s: token count in sequence state file exceeded capacity! %u > %zu\n", __func__, n_token_count, n_token_capacity); return 0; } @@ -15427,7 +15433,7 @@ static size_t llama_state_seq_load_file_internal(struct llama_context * ctx, con file.read_raw(state_data.data(), state_size); const size_t nread = llama_state_seq_set_data(ctx, state_data.data(), dest_seq_id); if (!nread) { - LLAMA_LOG_ERROR("%s : failed to restore sequence state\n", __func__); + LLAMA_LOG_ERROR("%s: failed to restore sequence state\n", __func__); return 0; } GGML_ASSERT(nread <= state_size); From be714a0fdaa78e35c74ff539474d1ebcf564ab5a Mon Sep 17 00:00:00 2001 From: Jan Boon Date: Tue, 2 Apr 2024 04:17:15 +0800 Subject: [PATCH 27/33] check types for stricter restore --- llama.cpp | 33 ++++++++++++++++++++++++++++++++- 1 file changed, 32 insertions(+), 1 deletion(-) diff --git a/llama.cpp b/llama.cpp index 1dbb9a93dd414..726db8218398e 100644 --- a/llama.cpp +++ b/llama.cpp @@ -15119,6 +15119,8 @@ size_t llama_state_seq_get_size(struct llama_context* ctx, llama_seq_id seq_id) } for (int il = 0; il < (int)n_layer; ++il) { + // types of keys and values + s_cell_data_size += sizeof(int32_t) * 2; // k_size_row and v_size_el values of layer s_cell_data_size += sizeof(size_t) * 2; @@ -15210,6 +15212,10 @@ static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llam // Get whole range at a time std::vector tmp_buf; for (int il = 0; il < (int)n_layer; ++il) { + // Write key type + const int32_t k_type_i = (int32_t)kv_self.k_l[il]->type; + data_ctx.write(&k_type_i, sizeof(k_type_i)); + // Write row size of key const size_t k_size_row = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa); data_ctx.write(&k_size_row, sizeof(k_size_row)); @@ -15226,6 +15232,10 @@ static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llam // For the values, they are transposed, so we also need the element size and get the element ranges from each row const uint32_t kv_size = kv_self.size; for (int il = 0; il < (int)n_layer; ++il) { + // Write value type + const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type; + data_ctx.write(&v_type_i, sizeof(v_type_i)); + // Write element size const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type); data_ctx.write(&v_size_el, sizeof(v_size_el)); @@ -15334,6 +15344,17 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, // For each layer, read the keys for each cell, one row is one cell, read as one contiguous blo for (int il = 0; il < (int)n_layer; ++il) { + // Read type of key + int32_t k_type_i_ref; + memcpy(&k_type_i_ref, inp, sizeof(k_type_i_ref)); + inp += sizeof(k_type_i_ref); + const int32_t k_type_i = (int32_t)kv_self.k_l[il]->type; + if (k_type_i != k_type_i_ref) { + llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1); + LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il); + return 0; + } + // Read row size of key size_t k_size_row_ref; memcpy(&k_size_row_ref, inp, sizeof(k_size_row_ref)); @@ -15354,11 +15375,21 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, // For each layer, read the values for each cell (transposed) for (int il = 0; il < (int)n_layer; ++il) { + // Read type of value + int32_t v_type_i_ref; + memcpy(&v_type_i_ref, inp, sizeof(v_type_i_ref)); + inp += sizeof(v_type_i_ref); + const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type; + if (v_type_i != v_type_i_ref) { + llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1); + LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il); + return 0; + } + // Read element size of value size_t v_size_el_ref; memcpy(&v_size_el_ref, inp, sizeof(v_size_el_ref)); inp += sizeof(v_size_el_ref); - const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type); if (v_size_el != v_size_el_ref) { llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1); From 0ccfbf2f6187bdd366f7a6d577b5f06999c6634c Mon Sep 17 00:00:00 2001 From: Jan Boon Date: Tue, 2 Apr 2024 05:13:55 +0800 Subject: [PATCH 28/33] update server doc --- examples/server/README.md | 52 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 52 insertions(+) diff --git a/examples/server/README.md b/examples/server/README.md index aadc73b4ba81f..e240e624c6a64 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -59,6 +59,7 @@ see https://github.com/ggerganov/llama.cpp/issues/1437 - `-n N, --n-predict N`: Set the maximum tokens to predict (default: -1) - `--slots-endpoint-disable`: To disable slots state monitoring endpoint. Slots state may contain user data, prompts included. - `--metrics`: enable prometheus `/metrics` compatible endpoint (default: disabled) +- `--slot-save-path PATH`: Specifies the path where the state of slots (the prompt cache) can be stored. If not provided, the slot management endpoints will be disabled. - `--chat-template JINJA_TEMPLATE`: Set custom jinja chat template. This parameter accepts a string, not a file name (default: template taken from model's metadata). We only support [some pre-defined templates](https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template) - `--log-disable`: Output logs to stdout only, not to `llama.log`. default: enabled. - `--log-format FORMAT`: Define the log output to FORMAT: json or text (default: json) @@ -519,6 +520,57 @@ Available metrics: - `llamacpp:requests_processing`: Number of request processing. - `llamacpp:requests_deferred`: Number of request deferred. +- **POST** `/slots/{id_slot}?action=save`: Save the prompt cache of the specified slot to a file. + + *Options:* + + `filename`: Name of the file to save the slot's prompt cache. The file will be saved in the directory specified by the `--slot-save-path` server parameter. + +### Result JSON + +```json +{ + "id_slot": 0, + "filename": "slot_save_file.bin", + "n_saved": 1745, + "n_written": 14309796, + "timings": { + "save_ms": 49.865 + } +} +``` + +- **POST** `/slots/{id_slot}?action=restore`: Restore the prompt cache of the specified slot from a file. + + *Options:* + + `filename`: Name of the file to restore the slot's prompt cache from. The file should be located in the directory specified by the `--slot-save-path` server parameter. + +### Result JSON + +```json +{ + "id_slot": 0, + "filename": "slot_save_file.bin", + "n_restored": 1745, + "n_read": 14309796, + "timings": { + "restore_ms": 42.937 + } +} +``` + +- **POST** `/slots/{id_slot}?action=erase`: Erase the prompt cache of the specified slot. + +### Result JSON + +```json +{ + "id_slot": 0, + "n_erased": 1745 +} +``` + ## More examples ### Change system prompt on runtime From 205c44c212919b7e8226a398dd2be03f3c767805 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 4 Apr 2024 11:44:26 +0300 Subject: [PATCH 29/33] readme : update API changes date --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 6cd05be6a2660..e8d811509e5b6 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,7 @@ Inference of Meta's [LLaMA](https://arxiv.org/abs/2302.13971) model (and others) ### Recent API changes -- [2024 Mar 30] State and session file functions reorganized under `llama_state_*` https://github.com/ggerganov/llama.cpp/pull/6341 +- [2024 Apr 4] State and session file functions reorganized under `llama_state_*` https://github.com/ggerganov/llama.cpp/pull/6341 - [2024 Mar 26] Logits and embeddings API updated for compactness https://github.com/ggerganov/llama.cpp/pull/6122 - [2024 Mar 13] Add `llama_synchronize()` + `llama_context_params.n_ubatch` https://github.com/ggerganov/llama.cpp/pull/6017 - [2024 Mar 8] `llama_kv_cache_seq_rm()` returns a `bool` instead of `void`, and new `llama_n_seq_max()` returns the upper limit of acceptable `seq_id` in batches (relevant when dealing with multiple sequences) https://github.com/ggerganov/llama.cpp/pull/5328 From f2a4777d4afd1346d3d5a7a686b20ccdd9ff69a5 Mon Sep 17 00:00:00 2001 From: Jan Boon Date: Sat, 6 Apr 2024 00:44:37 +0800 Subject: [PATCH 30/33] strict filename validation --- common/common.cpp | 57 ++++++++++++++++++++++++++++++++++++++ common/common.h | 2 ++ examples/server/server.cpp | 4 +-- 3 files changed, 61 insertions(+), 2 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 3e2df6e34adb4..5c4b6f1e9ce4b 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1500,6 +1500,63 @@ std::string gpt_random_prompt(std::mt19937 & rng) { GGML_UNREACHABLE(); } +bool validate_file_name(const std::string & filename) { + if (filename.length() > 255) { + // Limit at common largest possible filename on Linux filesystems + // to avoid unnecessary further validation + // (On systems with smaller limits it will be caught by the OS) + return false; + } + + std::u32string filename_utf32; + try { + std::wstring_convert, char32_t> converter; + filename_utf32 = converter.from_bytes(filename); + + // If the reverse conversion mismatches, it means overlong UTF-8 sequences were used, + // or invalid encodings were encountered. Reject such attempts + std::string filename_reencoded = converter.to_bytes(filename_utf32); + if (filename_reencoded != filename) { + return false; + } + } catch (const std::exception &) { + return false; + } + + // Check for forbidden codepoints: + // - Control characters + // - Unicode equivalents of illegal characters + // - UTF-16 surrogate pairs + // - UTF-8 replacement character + // - Illegal characters: / \ : * ? " < > | + for (char32_t c : filename_utf32) { + if (c <= 0x1F // Control characters (C0) + || c == 0x7F // Control characters (DEL) + || (c >= 0x80 && c <= 0x9F) // Control characters (C1) + || c == 0xFF0E // Fullwidth Full Stop (period equivalent) + || c == 0x2215 // Division Slash (forward slash equivalent) + || c == 0x2216 // Set Minus (backslash equivalent) + || (c >= 0xD800 && c <= 0xDFFF) // UTF-16 surrogate pairs + || c == 0xFFFD // Replacement Character (UTF-8) + || c == '/' || c == '\\' || c == ':' || c == '*' // Illegal characters + || c == '?' || c == '"' || c == '<' || c == '>' || c == '|') { + return false; + } + } + + // Reject any ".." (this is stricter than checking for ../ combinations) + if (filename.find("..") != std::string::npos) { + return false; + } + + // Reject "." + if (filename == ".") { + return false; + } + + return true; +} + // // String utils // diff --git a/common/common.h b/common/common.h index 99ee90bc3c728..4635e05d6381f 100644 --- a/common/common.h +++ b/common/common.h @@ -179,6 +179,8 @@ std::string gpt_random_prompt(std::mt19937 & rng); void process_escapes(std::string& input); +bool validate_file_name(const std::string & filename); + // // String utils // diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 1a9b7f49d19e1..6c64fe3e17dec 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -3278,7 +3278,7 @@ int main(int argc, char ** argv) { const auto handle_slots_save = [&ctx_server, &res_error, &sparams](const httplib::Request & req, httplib::Response & res, int id_slot) { json request_data = json::parse(req.body); std::string filename = request_data["filename"]; - if (filename.find('/') != std::string::npos || filename.find('\\') != std::string::npos || filename.find("..") != std::string::npos) { + if (!validate_file_name(filename)) { res_error(res, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST)); return; } @@ -3308,7 +3308,7 @@ int main(int argc, char ** argv) { const auto handle_slots_restore = [&ctx_server, &res_error, &sparams](const httplib::Request & req, httplib::Response & res, int id_slot) { json request_data = json::parse(req.body); std::string filename = request_data["filename"]; - if (filename.find('/') != std::string::npos || filename.find('\\') != std::string::npos || filename.find("..") != std::string::npos) { + if (!validate_file_name(filename)) { res_error(res, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST)); return; } From 4a4f3993e73591702588eadcc735c5e807c57dc7 Mon Sep 17 00:00:00 2001 From: Jan Boon Date: Sat, 6 Apr 2024 00:59:34 +0800 Subject: [PATCH 31/33] move include, reject bom as well --- common/common.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/common/common.cpp b/common/common.cpp index 5c4b6f1e9ce4b..8e9b57d706352 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -16,6 +16,7 @@ #include #include #include +#include #if defined(__APPLE__) && defined(__MACH__) #include @@ -27,7 +28,6 @@ #ifndef NOMINMAX # define NOMINMAX #endif -#include #include #include #include @@ -1528,6 +1528,7 @@ bool validate_file_name(const std::string & filename) { // - Unicode equivalents of illegal characters // - UTF-16 surrogate pairs // - UTF-8 replacement character + // - Byte order mark (BOM) // - Illegal characters: / \ : * ? " < > | for (char32_t c : filename_utf32) { if (c <= 0x1F // Control characters (C0) @@ -1538,6 +1539,7 @@ bool validate_file_name(const std::string & filename) { || c == 0x2216 // Set Minus (backslash equivalent) || (c >= 0xD800 && c <= 0xDFFF) // UTF-16 surrogate pairs || c == 0xFFFD // Replacement Character (UTF-8) + || c == 0xFEFF // Byte Order Mark (BOM) || c == '/' || c == '\\' || c == ':' || c == '*' // Illegal characters || c == '?' || c == '"' || c == '<' || c == '>' || c == '|') { return false; From 2fbf0c34953deb0c4377bbc10a49a061e0a08e01 Mon Sep 17 00:00:00 2001 From: Jan Boon Date: Sat, 6 Apr 2024 02:43:13 +0800 Subject: [PATCH 32/33] also reject empty filename --- common/common.cpp | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/common/common.cpp b/common/common.cpp index 8e9b57d706352..63fdb2399f305 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1500,7 +1500,13 @@ std::string gpt_random_prompt(std::mt19937 & rng) { GGML_UNREACHABLE(); } +// Validate if a filename is safe to use +// To validate a full path, split the path by the OS-specific path separator, and validate each part with this function bool validate_file_name(const std::string & filename) { + if (!filename.length()) { + // Empty filename invalid + return false; + } if (filename.length() > 255) { // Limit at common largest possible filename on Linux filesystems // to avoid unnecessary further validation @@ -1546,7 +1552,7 @@ bool validate_file_name(const std::string & filename) { } } - // Reject any ".." (this is stricter than checking for ../ combinations) + // Reject any ".." (currently stricter than necessary, it should be fine to just check for == ".." instead) if (filename.find("..") != std::string::npos) { return false; } From bf94e9f788da5acd17c7744889f26ccc958ec914 Mon Sep 17 00:00:00 2001 From: Jan Boon Date: Sat, 6 Apr 2024 03:14:39 +0800 Subject: [PATCH 33/33] reject whitespace and trailing dot --- common/common.cpp | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/common/common.cpp b/common/common.cpp index 63fdb2399f305..7d983a453c68f 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1552,6 +1552,12 @@ bool validate_file_name(const std::string & filename) { } } + // Reject any leading or trailing ' ', or any trailing '.', these are stripped on Windows and will cause a different filename + // Unicode and other whitespace is not affected, only 0x20 space + if (filename.front() == ' ' || filename.back() == ' ' || filename.back() == '.') { + return false; + } + // Reject any ".." (currently stricter than necessary, it should be fine to just check for == ".." instead) if (filename.find("..") != std::string::npos) { return false;