llama : save and restore kv cache for single seq id (#6341)
* llama : save and restore kv cache for single seq id

* remove trailing whitespace

* respond error in case there's no space in the kv cache

* add kv seq save restore to test case

* add --slot-save-path arg to enable save restore and restrict save location

* Returning 0 for some cases, instead of asserting.

* cleanup error cases

* rename sequence state functions

* rename state get set functions

* add previous function names back in with DEPRECATED notice

* update doc

* adjust endpoints to preferred style

* fix restoring zero cell count

* handle seq rm return value

* unused param

* keep in the size check

* fix return types

* add server test case for slot save restore

* cleanup

* add cake

* cleanup style

* add special

* removing a whole sequence never fails

* move sequence state file functionality from server to llama to match session api and add version tags

* catch exceptions on save as well

* error log messages

* check types for stricter restore

* update server doc

* readme : update API changes date

* strict filename validation

* move include, reject bom as well

* also reject empty filename

* reject whitespace and trailing dot

---------

Co-authored-by: Martin Evans <[email protected]>
Co-authored-by: Georgi Gerganov <[email protected]>
3 people committed Apr 8, 2024
1 parent 87fb5b4 commit beea6e1
Showing 11 changed files with 1,086 additions and 31 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -10,6 +10,7 @@ Inference of Meta's [LLaMA](https://arxiv.org/abs/2302.13971) model (and others)

### Recent API changes

- [2024 Apr 4] State and session file functions reorganized under `llama_state_*` https://github.com/ggerganov/llama.cpp/pull/6341
- [2024 Mar 26] Logits and embeddings API updated for compactness https://github.com/ggerganov/llama.cpp/pull/6122
- [2024 Mar 13] Add `llama_synchronize()` + `llama_context_params.n_ubatch` https://github.com/ggerganov/llama.cpp/pull/6017
- [2024 Mar 8] `llama_kv_cache_seq_rm()` returns a `bool` instead of `void`, and new `llama_n_seq_max()` returns the upper limit of acceptable `seq_id` in batches (relevant when dealing with multiple sequences) https://github.com/ggerganov/llama.cpp/pull/5328
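The `llama_state_*` reorganization listed above, together with the per-sequence functions added by this commit, can be summarized with a minimal sketch. The function names are taken from the calls visible in the diff below; the two helper wrappers are illustrative only and not part of the API:

```cpp
#include <cstdint>
#include <vector>

#include "llama.h"

// Whole-context state: the old llama_get_state_size / llama_copy_state_data /
// llama_set_state_data calls now live under llama_state_* (the old names are kept
// with a DEPRECATED notice, per the commit message).
static void copy_whole_state(llama_context * src, llama_context * dst) {
    std::vector<uint8_t> buf(llama_state_get_size(src));
    llama_state_get_data(src, buf.data());
    llama_state_set_data(dst, buf.data());
}

// New in this commit: save and restore the kv cache of a single sequence id.
static void copy_one_sequence(llama_context * ctx, llama_seq_id src, llama_seq_id dst) {
    std::vector<uint8_t> buf(llama_state_seq_get_size(ctx, src));
    llama_state_seq_get_data(ctx, buf.data(), src);
    llama_state_seq_set_data(ctx, buf.data(), dst);
}
```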
73 changes: 72 additions & 1 deletion common/common.cpp
@@ -16,6 +16,7 @@
#include <unordered_set>
#include <vector>
#include <cinttypes>
#include <codecvt>

#if defined(__APPLE__) && defined(__MACH__)
#include <sys/types.h>
@@ -27,7 +28,6 @@
#ifndef NOMINMAX
# define NOMINMAX
#endif
#include <codecvt>
#include <locale>
#include <windows.h>
#include <fcntl.h>
@@ -1500,6 +1500,77 @@ std::string gpt_random_prompt(std::mt19937 & rng) {
GGML_UNREACHABLE();
}

// Validate if a filename is safe to use
// To validate a full path, split the path by the OS-specific path separator, and validate each part with this function
bool validate_file_name(const std::string & filename) {
    if (!filename.length()) {
        // Empty filename invalid
        return false;
    }
    if (filename.length() > 255) {
        // Limit at common largest possible filename on Linux filesystems
        // to avoid unnecessary further validation
        // (On systems with smaller limits it will be caught by the OS)
        return false;
    }

    std::u32string filename_utf32;
    try {
        std::wstring_convert<std::codecvt_utf8<char32_t>, char32_t> converter;
        filename_utf32 = converter.from_bytes(filename);

        // If the reverse conversion mismatches, it means overlong UTF-8 sequences were used,
        // or invalid encodings were encountered. Reject such attempts
        std::string filename_reencoded = converter.to_bytes(filename_utf32);
        if (filename_reencoded != filename) {
            return false;
        }
    } catch (const std::exception &) {
        return false;
    }

    // Check for forbidden codepoints:
    // - Control characters
    // - Unicode equivalents of illegal characters
    // - UTF-16 surrogate pairs
    // - UTF-8 replacement character
    // - Byte order mark (BOM)
    // - Illegal characters: / \ : * ? " < > |
    for (char32_t c : filename_utf32) {
        if (c <= 0x1F // Control characters (C0)
            || c == 0x7F // Control characters (DEL)
            || (c >= 0x80 && c <= 0x9F) // Control characters (C1)
            || c == 0xFF0E // Fullwidth Full Stop (period equivalent)
            || c == 0x2215 // Division Slash (forward slash equivalent)
            || c == 0x2216 // Set Minus (backslash equivalent)
            || (c >= 0xD800 && c <= 0xDFFF) // UTF-16 surrogate pairs
            || c == 0xFFFD // Replacement Character (UTF-8)
            || c == 0xFEFF // Byte Order Mark (BOM)
            || c == '/' || c == '\\' || c == ':' || c == '*' // Illegal characters
            || c == '?' || c == '"' || c == '<' || c == '>' || c == '|') {
            return false;
        }
    }

    // Reject any leading or trailing ' ', or any trailing '.', these are stripped on Windows and will cause a different filename
    // Unicode and other whitespace is not affected, only 0x20 space
    if (filename.front() == ' ' || filename.back() == ' ' || filename.back() == '.') {
        return false;
    }

    // Reject any ".." (currently stricter than necessary, it should be fine to just check for == ".." instead)
    if (filename.find("..") != std::string::npos) {
        return false;
    }

    // Reject "."
    if (filename == ".") {
        return false;
    }

    return true;
}

//
// String utils
//
2 changes: 2 additions & 0 deletions common/common.h
@@ -179,6 +179,8 @@ std::string gpt_random_prompt(std::mt19937 & rng);

void process_escapes(std::string& input);

bool validate_file_name(const std::string & filename);

//
// String utils
//
6 changes: 3 additions & 3 deletions examples/main/main.cpp
@@ -235,7 +235,7 @@ int main(int argc, char ** argv) {
// The file exists and is not empty
session_tokens.resize(n_ctx);
size_t n_token_count_out = 0;
if (!llama_load_session_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.capacity(), &n_token_count_out)) {
if (!llama_state_load_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.capacity(), &n_token_count_out)) {
LOG_TEE("%s: error: failed to load session file '%s'\n", __func__, path_session.c_str());
return 1;
}
@@ -693,7 +693,7 @@ int main(int argc, char ** argv) {
// optionally save the session on first sample (for faster prompt loading next time)
if (!path_session.empty() && need_to_save_session && !params.prompt_cache_ro) {
need_to_save_session = false;
llama_save_session_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size());
llama_state_save_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size());

LOG("saved session to %s\n", path_session.c_str());
}
@@ -935,7 +935,7 @@ int main(int argc, char ** argv) {

if (!path_session.empty() && params.prompt_cache_all && !params.prompt_cache_ro) {
LOG_TEE("\n%s: saving final output to session file '%s'\n", __func__, path_session.c_str());
llama_save_session_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size());
llama_state_save_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size());
}

llama_print_timings(ctx);
101 changes: 95 additions & 6 deletions examples/save-load-state/save-load-state.cpp
@@ -24,6 +24,7 @@ int main(int argc, char ** argv) {

std::string result0;
std::string result1;
std::string result2;

// init
llama_model * model;
@@ -44,8 +45,8 @@

// save state (rng, logits, embedding and kv_cache) to file
{
std::vector<uint8_t> state_mem(llama_get_state_size(ctx));
const size_t written = llama_copy_state_data(ctx, state_mem.data());
std::vector<uint8_t> state_mem(llama_state_get_size(ctx));
const size_t written = llama_state_get_data(ctx, state_mem.data());

FILE *fp_write = fopen("dump_state.bin", "wb");
fwrite(state_mem.data(), 1, written, fp_write);
@@ -97,13 +98,13 @@

// load state (rng, logits, embedding and kv_cache) from file
{
std::vector<uint8_t> state_mem(llama_get_state_size(ctx2));
std::vector<uint8_t> state_mem(llama_state_get_size(ctx2));

FILE * fp_read = fopen("dump_state.bin", "rb");
const size_t read = fread(state_mem.data(), 1, state_mem.size(), fp_read);
fclose(fp_read);

if (read != llama_set_state_data(ctx2, state_mem.data())) {
if (read != llama_state_set_data(ctx2, state_mem.data())) {
fprintf(stderr, "\n%s : failed to read state\n", __func__);
llama_free(ctx2);
llama_free_model(model);
@@ -141,16 +142,104 @@
n_past += 1;
}

printf("\n");
printf("\n\n");

llama_free(ctx2);
llama_free_model(model);

if (result0 != result1) {
fprintf(stderr, "\n%s : error : the 2 generations are different\n", __func__);
return 1;
}

    // make new context
    auto* ctx3 = llama_new_context_with_model(model, llama_context_params_from_gpt_params(params));

    printf("\nsingle seq run: %s", params.prompt.c_str());

    // load state (rng, logits, embedding and kv_cache) from file
    {
        std::vector<uint8_t> state_mem(llama_state_get_size(ctx3));

        FILE * fp_read = fopen("dump_state.bin", "rb");
        const size_t read = fread(state_mem.data(), 1, state_mem.size(), fp_read);
        fclose(fp_read);

        if (read != llama_state_set_data(ctx3, state_mem.data())) {
            fprintf(stderr, "\n%s : failed to read state\n", __func__);
            llama_free(ctx3);
            llama_free_model(model);
            return 1;
        }

        fprintf(stderr, "%s : deserialized state from %zd out of a maximum of %zd bytes\n", __func__, read, state_mem.size());
    }

    // restore state (last tokens)
    n_past = n_past_saved;

    // save seq 0 and load into seq 1
    {
        // save kv of seq 0
        std::vector<uint8_t> seq_store(llama_state_seq_get_size(ctx3, 0));
        const size_t ncopy = llama_state_seq_get_data(ctx3, seq_store.data(), 0);
        if (ncopy != seq_store.size()) {
            fprintf(stderr, "\n%s : seq copy data length %zd does not match expected length %zd\n", __func__, ncopy, seq_store.size());
            llama_free(ctx3);
            llama_free_model(model);
            return 1;
        }
        fprintf(stderr, "%s : seq 0 copied, %zd bytes\n", __func__, ncopy);

        // erase whole kv
        llama_kv_cache_clear(ctx3);
        fprintf(stderr, "%s : kv cache cleared\n", __func__);

        // restore kv into seq 1
        const size_t nset = llama_state_seq_set_data(ctx3, seq_store.data(), 1);
        if (nset != seq_store.size()) {
            fprintf(stderr, "\n%s : seq set data length %zd does not match expected length %zd\n", __func__, nset, seq_store.size());
            llama_free(ctx3);
            llama_free_model(model);
            return 1;
        }
        fprintf(stderr, "%s : seq 1 restored, %zd bytes\n", __func__, nset);
    }

    // third run with seq 1 instead of 0
    for (auto i = 0; i < params.n_predict; i++) {
        auto * logits = llama_get_logits(ctx3);
        auto n_vocab = llama_n_vocab(model);
        std::vector<llama_token_data> candidates;
        candidates.reserve(n_vocab);
        for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
            candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
        }
        llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
        auto next_token = llama_sample_token(ctx3, &candidates_p);
        auto next_token_str = llama_token_to_piece(ctx3, next_token);

        printf("%s", next_token_str.c_str());
        result2 += next_token_str;

        if (llama_decode(ctx3, llama_batch_get_one(&next_token, 1, n_past, 1))) {
            fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
            llama_free(ctx3);
            llama_free_model(model);
            return 1;
        }
        n_past += 1;
    }

    printf("\n");

    llama_free(ctx3);
    llama_free_model(model);

    if (result0 != result2) {
        fprintf(stderr, "\n%s : error : the seq restore generation is different\n", __func__);
        return 1;
    }

    fprintf(stderr, "\n%s : success\n", __func__);

    return 0;
52 changes: 52 additions & 0 deletions examples/server/README.md
@@ -57,6 +57,7 @@ page cache before using this. See https://github.com/ggerganov/llama.cpp/issues/
- `-n N, --n-predict N`: Set the maximum tokens to predict. Default: `-1`
- `--slots-endpoint-disable`: To disable slots state monitoring endpoint. Slots state may contain user data, prompts included.
- `--metrics`: enable prometheus `/metrics` compatible endpoint. Default: disabled
- `--slot-save-path PATH`: Specifies the path where the state of slots (the prompt cache) can be stored. If not provided, the slot management endpoints will be disabled.
- `--chat-template JINJA_TEMPLATE`: Set custom jinja chat template. This parameter accepts a string, not a file name. Default: template taken from model's metadata. We only support [some pre-defined templates](https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template)
- `--log-disable`: Output logs to stdout only, not to `llama.log`. Default: enabled
- `--log-format FORMAT`: Define the log output to FORMAT: json or text. Default: `json`
@@ -517,6 +518,57 @@ Available metrics:
- `llamacpp:requests_processing`: Number of requests processing.
- `llamacpp:requests_deferred`: Number of requests deferred.

- **POST** `/slots/{id_slot}?action=save`: Save the prompt cache of the specified slot to a file.

*Options:*

`filename`: Name of the file to save the slot's prompt cache. The file will be saved in the directory specified by the `--slot-save-path` server parameter.

### Result JSON

```json
{
"id_slot": 0,
"filename": "slot_save_file.bin",
"n_saved": 1745,
"n_written": 14309796,
"timings": {
"save_ms": 49.865
}
}
```

- **POST** `/slots/{id_slot}?action=restore`: Restore the prompt cache of the specified slot from a file.

*Options:*

`filename`: Name of the file to restore the slot's prompt cache from. The file should be located in the directory specified by the `--slot-save-path` server parameter.

### Result JSON

```json
{
"id_slot": 0,
"filename": "slot_save_file.bin",
"n_restored": 1745,
"n_read": 14309796,
"timings": {
"restore_ms": 42.937
}
}
```

- **POST** `/slots/{id_slot}?action=erase`: Erase the prompt cache of the specified slot.

### Result JSON

```json
{
"id_slot": 0,
"n_erased": 1745
}
```
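For illustration, a request body for the save and restore actions above might look as follows; this assumes the `filename` option is passed as a JSON field in the POST body (the exact request format is not shown on this page):

```json
{
  "filename": "slot_save_file.bin"
}
```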

## More examples

### Change system prompt on runtime