llama : greatly reduce output buffer memory usage #6122

Merged: 26 commits into master from compilade/smaller-output-buffer on Mar 26, 2024

Conversation

@compilade (Collaborator) commented Mar 17, 2024

Supersedes #2700.

As I noted in #6017 (comment), the logits buffer is far too big and mostly unused when the batch size is very large. This PR fixes that waste of memory by introducing another, smaller buffer of ids (ctx->output_ids) which point into the logits and embeddings buffers, allowing most of the behavior of llama_get_logits_ith() and llama_get_embeddings_ith() to be kept while making the output buffer's content contiguous. While I was at it, I also noticed that the new output buffer layout makes it relatively easy to skip computing unused logits.
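
Conceptually, the mapping works roughly like this (an illustrative C++ sketch with made-up toy_* names, not the actual llama.cpp code):

```cpp
// Illustrative sketch only: a small id buffer maps a token's position in the
// batch to its row in the compacted, contiguous logits buffer.
#include <cstddef>
#include <cstdint>
#include <vector>

struct toy_output_map {
    std::vector<int32_t> output_ids; // batch position -> row in the compact logits buffer, or -1
    std::vector<float>   logits;     // n_outputs * n_vocab floats, stored contiguously
    int32_t              n_vocab = 0;

    // analogous to llama_get_logits_ith(): return the logits row for batch position i,
    // or nullptr if no logits were requested for that token
    const float * get_logits_ith(int32_t i) const {
        if (i < 0 || (size_t) i >= output_ids.size()) { return nullptr; }
        const int32_t row = output_ids[i];
        if (row < 0) { return nullptr; } // no output was requested for this token
        return logits.data() + (size_t) row * n_vocab;
    }
};
```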

Following @slaren's suggestion in #6017 (comment), this PR allocates space for n_seq_max outputs, then dynamically reallocates the output buffer when more logits and/or embeddings are necessary.
(Thanks for making me think to look for #2700 when alluding to it in #6017 (comment))
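
A minimal sketch of that reserve-then-grow behavior, using made-up names (the real logic lives in llama_output_reserve and also accounts for embeddings):

```cpp
// Illustrative sketch only: grow the output buffer when a batch requests
// more outputs than are currently reserved, otherwise keep the allocation.
#include <cstddef>
#include <vector>

struct toy_output_buffer {
    std::vector<float> data;
    size_t n_vocab    = 32000; // example value, depends on the model
    size_t n_reserved = 0;     // number of output rows currently allocated

    void reserve_outputs(size_t n_outputs) {
        if (n_outputs <= n_reserved) {
            return; // already big enough, keep the existing allocation
        }
        data.resize(n_outputs * n_vocab); // grow only when needed
        n_reserved = n_outputs;
    }
};
```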

API changes

  • llama_get_logits and llama_get_embeddings now return contiguous arrays (still float *).
    • The logits and/or the embeddings for which llama_batch.logits[i] == true are stored contiguously in the order they have in the batch.
    • The layout is the same as before if logits_all == true and when using llama_batch_get_one() (which doesn't set llama_batch.logits).
    • Only the necessary logits and embeddings are computed, so there is now a slight performance incentive to reduce the use of logits_all.
  • llama_get_logits_ith and llama_get_embeddings_ith now always verify that the passed id is valid (see the usage sketch after this list).
    • The result was undefined in some cases before.
    • Note that the perplexity example previously relied on this not being verified (it hit an assertion error when running a Debug build on master); this has been fixed (in other words, this fixes #6246, perplexity assert fails).
  • The session file format changed slightly, and is now much smaller (a lot of space for unused logits was wasted before). 🎉
    • For example, storing the state of Tinyllama with a batch size of 512 previously used at least 63 MiB, and around 126 MiB with a batch size of 1024. The equivalent session file now takes only about 1 MiB regardless of the batch size (mostly composed of the used KV cache cells from my test prompt; it now only includes the needed logits from the last batch).
    • It's now possible to load a session file with a different context size than what it was saved with, as long as the new context size is at least enough to contain the state for the saved tokens.
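
A minimal usage sketch of the changed API (assuming the usual llama.h entry points; error handling and sampling are omitted, and the helper function itself is only illustrative):

```cpp
#include "llama.h"

// Illustrative helper, not part of the library: decode a prompt and read the
// logits of its last token.
static void example_last_token_logits(llama_context * ctx, const llama_model * model,
                                      const llama_token * tokens, int32_t n_tokens) {
    llama_batch batch = llama_batch_init(n_tokens, 0, 1);
    batch.n_tokens = n_tokens;
    for (int32_t i = 0; i < n_tokens; ++i) {
        batch.token[i]     = tokens[i];
        batch.pos[i]       = i;
        batch.n_seq_id[i]  = 1;
        batch.seq_id[i][0] = 0;
        batch.logits[i]    = i == n_tokens - 1; // request logits only for the last token
    }

    llama_decode(ctx, batch);

    // llama_get_logits(ctx) now returns only the requested rows, stored contiguously
    // in batch order; llama_get_logits_ith(ctx, i) validates i and maps it to its row.
    const float * logits  = llama_get_logits_ith(ctx, n_tokens - 1);
    const int32_t n_vocab = llama_n_vocab(model);
    (void) logits; (void) n_vocab; // e.g. pass these on to the sampling code

    llama_batch_free(batch);
}
```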

Notes

The perplexity example relied very heavily on the previous layout of the logits.
I've adapted most of it, but the parts that still use logits_all (which still works) will need to be changed in the future to benefit from skipping the computation of unused logits.

The perplexity example should have the exact same output as master when compiled with the same flags. Well, except with Winogrande, because I've modified the implementation. (EDIT: Winogrande output should be the same as master as of 8f70dcb.)
(Note to @ikawrakow: the logic for skipping the choice words in winogrande_score() shouldn't be required, but I still kept it and simplified it. I've also fixed the logits used when evaluating the second choice: previously, the logits of the end of the first choice were used there instead of the logits of the end of the common prefix, which caused a big skew in the log-likelihood of some second choices. This will need to be fixed in a separate PR.)

TODO

Since a (small) model-specific change was required for each of the 23+ architectures supported by llama.cpp, I'd like to at least ensure I didn't break any of them. I'd really like to know if it works on GPU and/or with MoE models.

Feel free to edit the list below to add more tested models and examples.

Compare the output with --temp 0 when using this PR vs master.

  • Mamba (on CPU: main, parallel (with and without -cb), perplexity (v1, v2, hellaswag, multiple-choices), imatrix, save-load-state, embedding)
  • Llama (with Tinyllama (on CPU: speculative (with Llama-160M as a draft model), save-load-state))
    • EDIT (2024-03-19): seems to differ now, investigating... False alarm: repetition penalty is now disabled by default on master since #6127 (common : disable repeat penalties by default), so --repeat-penalty 1 now has to be passed for the outputs to match.
  • Phi-2 (on CPU: main, save-load-state)
  • GritLM
  • BERT (from the CI test-suite: server)
  • Any MoE model (Mixtral (with hipBLAS: llama-bench, main))
  • (17 other model types (TODO: expand this list))

Known issues:

  • server with embeddings
  • Some backends don't support GGML_OP_GET_ROWS and don't fall back properly (e.g. the Vulkan backend)

The first logits used to evaluate the second choice were not from
the end of the common prefix; instead, they were the logits from the end
of the first choice. This has been corrected.

The previous implementation sometimes had outliers in the scores of
choices for some tasks, and the logic to skip choices words
in the log-likelihood evaluation probably was an attempt to reduce those,
but it was complex and didn't quite seem to be the right thing.

This is simpler now, and the outlier scores aren't there anymore.
A mismatch happened when using a smaller n_ubatch than n_batch and then using
llama_batch_get_one(). The decision of what n_outputs should be now almost
fully depends on how lctx.n_outputs is set in llama_decode_internal.
The conditions are simpler this way.

* llama : when saving the state, recalculate n_outputs

This ensures the correct number of outputs for the entire previous batch
is stored in the session file, even when n_ubatch is smaller than n_batch.
@slaren (Collaborator) commented Mar 18, 2024

As it is, this breaks pipeline parallelism because changes in the graph topology force a synchronization. I think it should be possible to fix this if the final get_rows is done unconditionally.

You can test this without multiple GPUs by building in debug with LLAMA_DEBUG, and using llama-bench with n_batch > n_ubatch, or perplexity with multiple sequences. If you get ggml_backend_sched_alloc_splits: failed to allocate graph, reserving messages on every eval, then it is breaking pipeline parallelism.

It previously worked because lctx.inp_out_ids was not initialized,
so it pointed to some garbage address which was somehow still valid when I
ran my tests.
@compilade (Collaborator, Author) replied:

As it is, this breaks pipeline parallelism because changes in the graph topology force a synchronization. I think it should be possible to fix this if the final get_rows is done unconditionally.

Yes, this should be possible. I initially skipped the rest of the graph there because tensors with 0 rows caused division-by-zero problems in ggml_can_repeat, but working around this could allow keeping the same graphs (assuming no other problems, like more divisions by zero elsewhere).
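
For illustration, the emptiness guard boils down to something like this (a toy, self-contained sketch; not the actual ggml source, which adds ggml_is_empty and adjusts ggml_can_repeat):

```cpp
// Toy sketch of the emptiness guard (not the actual ggml code).
#include <cstdint>

struct toy_tensor {
    int64_t ne[4]; // dimensions, like ggml's ne[GGML_MAX_DIMS]
};

// a tensor with any dimension equal to 0 holds no elements
static bool toy_is_empty(const toy_tensor & t) {
    for (int i = 0; i < 4; ++i) {
        if (t.ne[i] == 0) {
            return true;
        }
    }
    return false;
}

// repeat-compatibility check that never takes a modulo by zero
static bool toy_can_repeat(const toy_tensor & src, const toy_tensor & dst) {
    if (toy_is_empty(src)) {
        return toy_is_empty(dst); // an empty tensor can only repeat into another empty one
    }
    for (int i = 0; i < 4; ++i) {
        if (dst.ne[i] % src.ne[i] != 0) {
            return false;
        }
    }
    return true;
}
```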

You can test this without multiple GPUs by building in debug with LLAMA_DEBUG, and using llama-bench with n_batch > n_ubatch, or perplexity with multiple sequences. If you get ggml_backend_sched_alloc_splits: failed to allocate graph, reserving messages on every eval, then it is breaking pipeline parallelism.

I'll see if I can try this on my Intel UHD Graphics 615 (since splits don't seem to happen on CPU-only).
But on CPU, when using the parallel example (and also perplexity with multiple sequences), I do see a bunch of

ggml_gallocr_needs_realloc: graph has different number of nodes
ggml_gallocr_alloc_graph: reallocating buffers automatically

which may be problematic when using a GPU.

I'll try to avoid changing the graph topology.

@slaren (Collaborator) commented Mar 18, 2024

I'll see if I can try this on my Intel UHD Graphics 615 (since splits don't seem to happen on CPU-only).

Yes, I forgot that the reallocation is automatic if there is only one backend, but that message also shows the problem. If that message disappears, then it should also work with pipeline parallelism.

@fgdfgfthgr-fox commented Mar 18, 2024

Test results using a Radeon VII (hipBLAS build) under Linux.
I don't see any significant decrease in VRAM use, though. In what cases should this PR's change be significant?
(Please ignore the koboldcpp file path. It's running llama.cpp; I just store my model weights under koboldcpp's file path.)
master:

$ ./llama-bench -m /mnt/2878EBCCAED823C6/koboldcpp-rocm/mixtral/mixtral-8x7b-iq3_xxs.gguf -ngl 22
ggml_init_cublas: GGML_CUDA_FORCE_MMQ:   no
ggml_init_cublas: CUDA_USE_TENSOR_CORES: yes
ggml_init_cublas: found 1 ROCm devices:
  Device 0: AMD Radeon VII, compute capability 9.0, VMM: no
| model                          |       size |     params | backend    | ngl | test       |              t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | ---------- | ---------------: |
| llama 7B IQ3_XXS - 3.0625 bpw  |  16.99 GiB |    46.70 B | ROCm       |  22 | pp 512     |     87.74 ± 0.11 |
| llama 7B IQ3_XXS - 3.0625 bpw  |  16.99 GiB |    46.70 B | ROCm       |  22 | tg 128     |      7.55 ± 0.21 |

build: d01b3c4c (2450)
Peak VRAM Used: 13376MB

$ ./main -ngl 22 -m /mnt/2878EBCCAED823C6/koboldcpp-rocm/mixtral/mixtral-8x7b-iq3_xxs.gguf -n 64 --temp 0  -p "In a land before time"
...
generate: n_ctx = 512, n_batch = 2048, n_predict = 64, n_keep = 1


 In a land before time, when the internet was still in its infancy and social media didn’t exist, there were only two ways to get your message out: print or broadcast.

Print is great for reaching people who are already interested in what you have to say – but it can be expensive and hard to measure results. Broad
llama_print_timings:        load time =    2615.59 ms
llama_print_timings:      sample time =       7.09 ms /    64 runs   (    0.11 ms per token,  9026.80 tokens per second)
llama_print_timings: prompt eval time =     715.49 ms /     6 tokens (  119.25 ms per token,     8.39 tokens per second)
llama_print_timings:        eval time =    8572.94 ms /    63 runs   (  136.08 ms per token,     7.35 tokens per second)
llama_print_timings:       total time =    9315.96 ms /    69 tokens

pr:

$ ./llama-bench -m /mnt/2878EBCCAED823C6/koboldcpp-rocm/mixtral/mixtral-8x7b-iq3_xxs.gguf -ngl 22
ggml_init_cublas: GGML_CUDA_FORCE_MMQ:   no
ggml_init_cublas: CUDA_USE_TENSOR_CORES: yes
ggml_init_cublas: found 1 ROCm devices:
  Device 0: AMD Radeon VII, compute capability 9.0, VMM: no
| model                          |       size |     params | backend    | ngl | test       |              t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | ---------- | ---------------: |
| llama 7B IQ3_XXS - 3.0625 bpw  |  16.99 GiB |    46.70 B | ROCm       |  22 | pp 512     |     90.56 ± 0.13 |
| llama 7B IQ3_XXS - 3.0625 bpw  |  16.99 GiB |    46.70 B | ROCm       |  22 | tg 128     |      7.41 ± 0.13 |

build: 6bf7f3f4 (2464)
Peak VRAM Used: 13325MB

$ ./main -ngl 22 -m /mnt/2878EBCCAED823C6/koboldcpp-rocm/mixtral/mixtral-8x7b-iq3_xxs.gguf -n 64 --temp 0  -p "In a land before time"
...
generate: n_ctx = 512, n_batch = 2048, n_predict = 64, n_keep = 1


 In a land before time, when the internet was still in its infancy and social media didn’t exist, there were only two ways to get your message out: print or broadcast.

Print is great for reaching people who are already interested in what you have to say – but it can be expensive and hard to measure results. Broad
llama_print_timings:        load time =    2659.17 ms
llama_print_timings:      sample time =       7.00 ms /    64 runs   (    0.11 ms per token,  9148.08 tokens per second)
llama_print_timings: prompt eval time =     551.66 ms /     6 tokens (   91.94 ms per token,    10.88 tokens per second)
llama_print_timings:        eval time =    8741.20 ms /    63 runs   (  138.75 ms per token,     7.21 tokens per second)
llama_print_timings:       total time =    9320.75 ms /    69 tokens

@Dampfinchen commented:
That sounds really promising! However, when trying to run this PR, I'm getting this error:

CUDA error: invalid configuration argument
  current device: 0, in function ggml_cuda_op_flatten at /userdir/llama.cpp/ggml-cuda.cu:9149
  cudaGetLastError()
GGML_ASSERT: /userdir/llama.cpp/ggml-cuda.cu:267: !"CUDA error"
Aborted (core dumped)

I compiled it with CMake and CUDA Toolkit 12.3, on an RTX 2060 with a Core i7-9750H and 32 GB of RAM.

Does this PR only support Llama, or its derivatives like Mistral as well?

My model was Fimbulvetr-v2 at IQ4_XS, which is a Solar 10.7B merge.

* llama : rework reallocation logic for llama_output_reserve

Now comparing the actual size with the new total size of the output buffer
to allow more efficient enabling and disabling of the embeddings
and/or logits output in the future.
@compilade (Collaborator, Author) commented Mar 19, 2024

Answering 2 messages at once:

I don't see any significant decrease in VRAM use though.

@fgdfgfthgr-fox The expected decrease is at most n_vocab*(n_batch - 1)*sizeof(float) bytes, so for Mixtral, if its n_vocab is around 50000 and if using n_batch == 512, the maximum decrease will be 97 MiB, while with llama-bench, you're seeing a decrease of 50 MiB, which seems reasonable. BTW, thanks for testing this with Mixtral :)

In what cases should this PR's change be significant?

It should be significant when using a very large batch size. For example, with -b 8192, this PR could save up to 1.5 GiB of memory (or 3.05 GiB of memory with -b 16384, etc.) compared to master. (EDIT: this should only affect RAM, not VRAM, because the output buffer is using a CPU buffer (ref: #6122 (comment)). This still means bigger logical batch sizes will be usable without wasting lots of RAM, hopefully making pipeline parallelism more scalable.)
With llama-bench, I think this can be tested by adding the flags -b 8192 -p 8192, while with main, -b 8192 -c 8192 could be used (to ensure the batch size won't be clamped by the context size). Of course, if the KV cache (unaffected by this PR) is too big for your system, you can use smaller numbers.
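
As a quick back-of-the-envelope check of those numbers (using the n_vocab ≈ 50000 guess from above, not an exact value for any model):

```cpp
// Check of the savings formula n_vocab*(n_batch - 1)*sizeof(float), using the
// n_vocab ~= 50000 guess from the comment above (not an exact model value).
#include <cstdio>

int main() {
    const double n_vocab   = 50000.0;
    const double batches[] = { 512.0, 8192.0, 16384.0 };
    for (double n_batch : batches) {
        const double saved = n_vocab * (n_batch - 1.0) * sizeof(float);
        std::printf("n_batch = %5.0f -> up to %7.1f MiB (%.2f GiB) saved\n",
                    n_batch,
                    saved / (1024.0 * 1024.0),
                    saved / (1024.0 * 1024.0 * 1024.0));
    }
    return 0;
}
// prints roughly 97.5 MiB, 1.53 GiB and 3.05 GiB, matching the estimates above
```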


That sounds really promising! However, when trying to run this PR, I'm getting this error:

@Dampfinchen Thanks for testing this with CUDA. A more complete backtrace would be helpful¹, but in this case it is probably happening because I hadn't yet made backends other than CPU skip empty tensors, so what you're seeing is likely the symptom of a GPU-accelerated division by zero. Hopefully fixed by 8b826c5. This didn't happen to @fgdfgfthgr-fox because they didn't offload enough layers for the last one to be offloaded (that layer can now have tensors with no elements when no logits are used, which causes problems when dividing dimensions by each other if those tensors are not skipped).

Does this PR only support llama or its derivatives like Mistral as well?

The goal is for this to support all of the 23+ model architectures supported by llama.cpp. In theory, they should all work, but maybe there's something I overlooked somewhere², which is why I encourage testing models of different types on different backends. Mistral should work, and from #6122 (comment), Mixtral does too.

Footnotes

  1. Apparently, GGML_ASSERT prints full backtraces when gdb is installed on Linux and Unix-like systems. I didn't know for a long time because the only debugger I had already installed was lldb. Knowing this (or coredumpctl debug --debugger=lldb, followed by bt) two months ago would have saved me a lot of printf-based debugging.

  2. Like the Vulkan backend not properly supporting GGML_OP_GET_ROWS... This probably wasn't too important before, because ggml_get_rows was mostly used on input tensors, which always use the CPU backend. (But ggml_get_rows was also used for MoE inference, so backends on which Mixtral worked likely do not have this problem.)

@slaren (Collaborator) commented Mar 19, 2024

Works well with CUDA, improves pp performance by 2-3% with a single GPU.

| GPU | Model | Model Size [GiB] | Test | t/s master | t/s PR | Speedup |
| --- | --- | --- | --- | --- | --- | --- |
| RTX 3090 Ti | llama 7B F16 | 12.55 | pp512 | 5288.37 | 5378.25 | 1.02 |
| RTX 3090 Ti | llama 7B F16 | 12.55 | pp1024 | 5031.00 | 5144.94 | 1.02 |
| RTX 3090 Ti | llama 7B F16 | 12.55 | pp2048 | 4600.25 | 4710.85 | 1.02 |
| RTX 3090 Ti | llama 7B F16 | 12.55 | pp4096 | 3956.06 | 4032.93 | 1.02 |
| RTX 3090 Ti | llama 7B Q4_0 | 3.56 | pp512 | 4395.17 | 4532.78 | 1.03 |
| RTX 3090 Ti | llama 7B Q4_0 | 3.56 | pp1024 | 4190.07 | 4316.03 | 1.03 |
| RTX 3090 Ti | llama 7B Q4_0 | 3.56 | pp2048 | 3904.58 | 4011.76 | 1.03 |
| RTX 3090 Ti | llama 7B Q4_0 | 3.56 | pp4096 | 3429.24 | 3512.89 | 1.02 |
| RTX 3090 Ti | llama 7B F16 | 12.55 | tg128 | 54.76 | 54.81 | 1.00 |
| RTX 3090 Ti | llama 7B Q4_0 | 3.56 | tg128 | 126.81 | 126.19 | 1.00 |

| GPU | Model | Model Size [GiB] | Test | t/s master | t/s PR | Speedup |
| --- | --- | --- | --- | --- | --- | --- |
| RTX 3090 Ti/RTX 3080 | llama 7B F16 | 12.55 | pp512 | 4726.59 | 4861.25 | 1.03 |
| RTX 3090 Ti/RTX 3080 | llama 7B F16 | 12.55 | pp1024 | 5724.30 | 5759.07 | 1.01 |
| RTX 3090 Ti/RTX 3080 | llama 7B F16 | 12.55 | pp2048 | 6146.06 | 6129.15 | 1.00 |
| RTX 3090 Ti/RTX 3080 | llama 7B F16 | 12.55 | pp4096 | 5783.30 | 5803.46 | 1.00 |
| RTX 3090 Ti/RTX 3080 | llama 7B Q4_0 | 3.56 | pp512 | 3949.96 | 3869.63 | 0.98 |
| RTX 3090 Ti/RTX 3080 | llama 7B Q4_0 | 3.56 | pp1024 | 4825.07 | 4909.32 | 1.02 |
| RTX 3090 Ti/RTX 3080 | llama 7B Q4_0 | 3.56 | pp2048 | 5275.33 | 5319.94 | 1.01 |
| RTX 3090 Ti/RTX 3080 | llama 7B Q4_0 | 3.56 | pp4096 | 5026.15 | 5091.23 | 1.01 |
| RTX 3090 Ti/RTX 3080 | llama 7B F16 | 12.55 | tg128 | 48.34 | 47.83 | 0.99 |
| RTX 3090 Ti/RTX 3080 | llama 7B Q4_0 | 3.56 | tg128 | 108.05 | 108.19 | 1.00 |

@slaren (Collaborator) commented Mar 19, 2024

Note that the output buffer is always allocated in a CPU buffer, so this shouldn't affect VRAM usage.

@Dampfinchen commented:
@compilade I can confirm the issue I've had has been fixed. Good work!

Notably includes the new repetition penalty default, support for grok-1,
and support for split GGUF.
@slaren (Collaborator) left a review:

The only blocker for this PR is the lack of support for get_rows in the Vulkan backend. It is not clear when this will be implemented, and we cannot delay this indefinitely, so I think this should be merged before more conflicts arise.

When loading a session file, the context size is now only required to be
at least enough to load the KV cells contained in that session file,
instead of requiring to use exactly the same context size as when saving.

Doing this enables the use-case of extending or shrinking the context size
of a saved session.

This breaks existing session files because the meaning of kv_buf_size
is slightly changed (previously it was the size of the whole KV cache,
now it's only the size of the saved part of it). This allows for
finer-grained sanity checks when loading in an effort to keep kv_buf_size
useful even when the kv_size is changed.
@0cc4m (Collaborator) commented Mar 26, 2024

The only blocker for this PR is the lack of support for get_rows in the Vulkan backend. It is not clear when this will be implemented, and we cannot delay this indefinitely, so I think this should be merged before more conflicts arise.

I've got that mostly done this weekend, but didn't have the time to figure out the last bugs in the implementation.

ggml-ci
@ggerganov (Owner) left a review:

Let's merge and finish the Vulkan kernels from master. @0cc4m if it takes longer to resolve the issues, we can put a notice in the README with the last stable version (i.e. before this PR) until it is ready

@compilade force-pushed the compilade/smaller-output-buffer branch from 457d900 to 20248e8 on March 26, 2024 at 14:33
@ggerganov ggerganov merged commit 557410b into master Mar 26, 2024
51 of 57 checks passed
@ggerganov ggerganov deleted the compilade/smaller-output-buffer branch March 26, 2024 14:46
hxer7963 pushed a commit to hxer7963/llama.cpp that referenced this pull request Mar 27, 2024
@ikawrakow (Contributor) commented:
This PR breaks MoE models. Well, I don't know if all MoE models, but for sure Mixtral-8x7B.

Everything is fine until 55c1b2a. Then

git co 557410b8f06380560155ac7fcb8316d71ddc9837 && make -j

./perplexity -m some_model -f some_test -t 1 -ngl 100
main: build = 2540 (557410b8)
main: built with cc (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0 for x86_64-linux-gnu
main: seed  = 1711546930
llama_model_loader: loaded meta data with 25 key-value pairs and 995 tensors from jun2.bin (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = llama
llama_model_loader: - kv   1:                               general.name str              = hf
llama_model_loader: - kv   2:                       llama.context_length u32              = 32768
llama_model_loader: - kv   3:                     llama.embedding_length u32              = 4096
llama_model_loader: - kv   4:                          llama.block_count u32              = 32
llama_model_loader: - kv   5:                  llama.feed_forward_length u32              = 14336
llama_model_loader: - kv   6:                 llama.rope.dimension_count u32              = 128
llama_model_loader: - kv   7:                 llama.attention.head_count u32              = 32
llama_model_loader: - kv   8:              llama.attention.head_count_kv u32              = 8
llama_model_loader: - kv   9:                         llama.expert_count u32              = 8
llama_model_loader: - kv  10:                    llama.expert_used_count u32              = 2
llama_model_loader: - kv  11:     llama.attention.layer_norm_rms_epsilon f32              = 0.000010
llama_model_loader: - kv  12:                       llama.rope.freq_base f32              = 1000000.000000
llama_model_loader: - kv  13:                          general.file_type u32              = 24
llama_model_loader: - kv  14:                       tokenizer.ggml.model str              = llama
llama_model_loader: - kv  15:                      tokenizer.ggml.tokens arr[str,32000]   = ["<unk>", "<s>", "</s>", "<0x00>", "<...
llama_model_loader: - kv  16:                      tokenizer.ggml.scores arr[f32,32000]   = [0.000000, 0.000000, 0.000000, 0.0000...
llama_model_loader: - kv  17:                  tokenizer.ggml.token_type arr[i32,32000]   = [2, 3, 3, 6, 6, 6, 6, 6, 6, 6, 6, 6, ...
llama_model_loader: - kv  18:                      tokenizer.ggml.merges arr[str,58980]   = ["▁ t", "i n", "e r", "▁ a", "h e...
llama_model_loader: - kv  19:                tokenizer.ggml.bos_token_id u32              = 1
llama_model_loader: - kv  20:                tokenizer.ggml.eos_token_id u32              = 2
llama_model_loader: - kv  21:            tokenizer.ggml.unknown_token_id u32              = 0
llama_model_loader: - kv  22:               tokenizer.ggml.add_bos_token bool             = true
llama_model_loader: - kv  23:               tokenizer.ggml.add_eos_token bool             = false
llama_model_loader: - kv  24:               general.quantization_version u32              = 2
llama_model_loader: - type  f32:   65 tensors
llama_model_loader: - type  f16:   32 tensors
llama_model_loader: - type q2_K:   33 tensors
llama_model_loader: - type q4_K:   64 tensors
llama_model_loader: - type q5_K:   33 tensors
llama_model_loader: - type iq1_s:  768 tensors
llm_load_vocab: special tokens definition check successful ( 259/32000 ).
llm_load_print_meta: format           = GGUF V3 (latest)
llm_load_print_meta: arch             = llama
llm_load_print_meta: vocab type       = SPM
llm_load_print_meta: n_vocab          = 32000
llm_load_print_meta: n_merges         = 0
llm_load_print_meta: n_ctx_train      = 32768
llm_load_print_meta: n_embd           = 4096
llm_load_print_meta: n_head           = 32
llm_load_print_meta: n_head_kv        = 8
llm_load_print_meta: n_layer          = 32
llm_load_print_meta: n_rot            = 128
llm_load_print_meta: n_embd_head_k    = 128
llm_load_print_meta: n_embd_head_v    = 128
llm_load_print_meta: n_gqa            = 4
llm_load_print_meta: n_embd_k_gqa     = 1024
llm_load_print_meta: n_embd_v_gqa     = 1024
llm_load_print_meta: f_norm_eps       = 0.0e+00
llm_load_print_meta: f_norm_rms_eps   = 1.0e-05
llm_load_print_meta: f_clamp_kqv      = 0.0e+00
llm_load_print_meta: f_max_alibi_bias = 0.0e+00
llm_load_print_meta: f_logit_scale    = 0.0e+00
llm_load_print_meta: n_ff             = 14336
llm_load_print_meta: n_expert         = 8
llm_load_print_meta: n_expert_used    = 2
llm_load_print_meta: causal attn      = 1
llm_load_print_meta: pooling type     = 0
llm_load_print_meta: rope type        = 0
llm_load_print_meta: rope scaling     = linear
llm_load_print_meta: freq_base_train  = 1000000.0
llm_load_print_meta: freq_scale_train = 1
llm_load_print_meta: n_yarn_orig_ctx  = 32768
llm_load_print_meta: rope_finetuned   = unknown
llm_load_print_meta: ssm_d_conv       = 0
llm_load_print_meta: ssm_d_inner      = 0
llm_load_print_meta: ssm_d_state      = 0
llm_load_print_meta: ssm_dt_rank      = 0
llm_load_print_meta: model type       = 7B
llm_load_print_meta: model ftype      = IQ1_S - 1.5625 bpw
llm_load_print_meta: model params     = 46.70 B
llm_load_print_meta: model size       = 9.14 GiB (1.68 BPW) 
llm_load_print_meta: general.name     = hf
llm_load_print_meta: BOS token        = 1 '<s>'
llm_load_print_meta: EOS token        = 2 '</s>'
llm_load_print_meta: UNK token        = 0 '<unk>'
llm_load_print_meta: LF token         = 13 '<0x0A>'
ggml_cuda_init: GGML_CUDA_FORCE_MMQ:   no
ggml_cuda_init: CUDA_USE_TENSOR_CORES: yes
ggml_cuda_init: found 1 CUDA devices:
  Device 0: NVIDIA GeForce RTX 4080, compute capability 8.9, VMM: yes
llm_load_tensors: ggml ctx size =    0.76 MiB
llm_load_tensors: offloading 32 repeating layers to GPU
llm_load_tensors: offloading non-repeating layers to GPU
llm_load_tensors: offloaded 33/33 layers to GPU
llm_load_tensors:        CPU buffer size =    41.02 MiB
llm_load_tensors:      CUDA0 buffer size =  9322.95 MiB
....................................................................................................
llama_new_context_with_model: n_ctx      = 4096
llama_new_context_with_model: n_batch    = 4096
llama_new_context_with_model: n_ubatch   = 512
llama_new_context_with_model: freq_base  = 1000000.0
llama_new_context_with_model: freq_scale = 1
llama_kv_cache_init:      CUDA0 KV buffer size =   512.00 MiB
llama_new_context_with_model: KV self size  =  512.00 MiB, K (f16):  256.00 MiB, V (f16):  256.00 MiB
llama_new_context_with_model:  CUDA_Host  output buffer size =     0.49 MiB
llama_new_context_with_model:      CUDA0 compute buffer size =   316.04 MiB
llama_new_context_with_model:  CUDA_Host compute buffer size =    16.01 MiB
llama_new_context_with_model: graph nodes  = 1670
llama_new_context_with_model: graph splits = 2

system_info: n_threads = 1 / 64 | AVX = 1 | AVX_VNNI = 0 | AVX2 = 1 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 1 | NEON = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 1 | SSSE3 = 1 | VSX = 0 | MATMUL_INT8 = 0 | 
perplexity: tokenizing the input ..
perplexity: tokenization took 658.483 ms
perplexity: calculating perplexity over 80 chunks, n_ctx=4096, batch_size=4096, n_seq=1
GGML_ASSERT: /home/iwan/other/llama.cpp/ggml.c:2941: view_src == NULL || data_size + view_offs <= ggml_nbytes(view_src)
Could not attach to process.  If your uid matches the uid of the target
process, check the setting of /proc/sys/kernel/yama/ptrace_scope, or try
again as the root user.  For more details, see /etc/sysctl.d/10-ptrace.conf
ptrace: Operation not permitted.
No stack.
The program is not being run.
Aborted (core dumped)

@Wuzzooy commented Mar 28, 2024

command-r seems broken; it works until build b2536.
I'm getting "GGML_ASSERT: D:\a\llama.cpp\llama.cpp\ggml.c:3648: ggml_can_repeat(b, a)"

INFO [ server_params_parse] logging to file is disabled. | tid="196984" timestamp=1711615937
INFO [ main] build info | tid="196984" timestamp=1711615937 build=2554 commit="25f4a613"
INFO [ main] system info | tid="196984" timestamp=1711615937 n_threads=7 n_threads_batch=-1 total_threads=8 system_info="AVX = 1 | AVX_VNNI = 0 | AVX2 = 1 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 1 | NEON = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 1 | SSSE3 = 0 | VSX = 0 | MATMUL_INT8 = 0 | "
llama_model_loader: loaded meta data with 23 key-value pairs and 322 tensors from G:\AI\models\c4ai-command-r-v01-Q4_K_S.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv 0: general.architecture str = command-r
llama_model_loader: - kv 1: general.name str = c4ai-command-r-v01
llama_model_loader: - kv 2: command-r.block_count u32 = 40
llama_model_loader: - kv 3: command-r.context_length u32 = 131072
llama_model_loader: - kv 4: command-r.embedding_length u32 = 8192
llama_model_loader: - kv 5: command-r.feed_forward_length u32 = 22528
llama_model_loader: - kv 6: command-r.attention.head_count u32 = 64
llama_model_loader: - kv 7: command-r.attention.head_count_kv u32 = 64
llama_model_loader: - kv 8: command-r.rope.freq_base f32 = 8000000.000000
llama_model_loader: - kv 9: command-r.attention.layer_norm_epsilon f32 = 0.000010
llama_model_loader: - kv 10: general.file_type u32 = 14
llama_model_loader: - kv 11: command-r.logit_scale f32 = 0.062500
llama_model_loader: - kv 12: command-r.rope.scaling.type str = none
llama_model_loader: - kv 13: tokenizer.ggml.model str = gpt2
llama_model_loader: - kv 14: tokenizer.ggml.tokens arr[str,256000] = ["", "", "", "", ...
llama_model_loader: - kv 15: tokenizer.ggml.token_type arr[i32,256000] = [3, 3, 3, 3, 3, 3, 3, 3, 1, 1, 1, 1, ...
llama_model_loader: - kv 16: tokenizer.ggml.merges arr[str,253333] = ["─á ─á", "─á t", "e r", "i n", "─á a...
llama_model_loader: - kv 17: tokenizer.ggml.bos_token_id u32 = 5
llama_model_loader: - kv 18: tokenizer.ggml.eos_token_id u32 = 255001
llama_model_loader: - kv 19: tokenizer.ggml.padding_token_id u32 = 0
llama_model_loader: - kv 20: tokenizer.ggml.add_bos_token bool = true
llama_model_loader: - kv 21: tokenizer.ggml.add_eos_token bool = false
llama_model_loader: - kv 22: general.quantization_version u32 = 2
llama_model_loader: - type f32: 41 tensors
llama_model_loader: - type q4_K: 271 tensors
llama_model_loader: - type q5_K: 9 tensors
llama_model_loader: - type q6_K: 1 tensors
llm_load_vocab: special tokens definition check successful ( 1008/256000 ).
llm_load_print_meta: format = GGUF V3 (latest)
llm_load_print_meta: arch = command-r
llm_load_print_meta: vocab type = BPE
llm_load_print_meta: n_vocab = 256000
llm_load_print_meta: n_merges = 253333
llm_load_print_meta: n_ctx_train = 131072
llm_load_print_meta: n_embd = 8192
llm_load_print_meta: n_head = 64
llm_load_print_meta: n_head_kv = 64
llm_load_print_meta: n_layer = 40
llm_load_print_meta: n_rot = 128
llm_load_print_meta: n_embd_head_k = 128
llm_load_print_meta: n_embd_head_v = 128
llm_load_print_meta: n_gqa = 1
llm_load_print_meta: n_embd_k_gqa = 8192
llm_load_print_meta: n_embd_v_gqa = 8192
llm_load_print_meta: f_norm_eps = 1.0e-05
llm_load_print_meta: f_norm_rms_eps = 0.0e+00
llm_load_print_meta: f_clamp_kqv = 0.0e+00
llm_load_print_meta: f_max_alibi_bias = 0.0e+00
llm_load_print_meta: f_logit_scale = 6.2e-02
llm_load_print_meta: n_ff = 22528
llm_load_print_meta: n_expert = 0
llm_load_print_meta: n_expert_used = 0
llm_load_print_meta: causal attn = 1
llm_load_print_meta: pooling type = 0
llm_load_print_meta: rope type = 0
llm_load_print_meta: rope scaling = none
llm_load_print_meta: freq_base_train = 8000000.0
llm_load_print_meta: freq_scale_train = 1
llm_load_print_meta: n_yarn_orig_ctx = 131072
llm_load_print_meta: rope_finetuned = unknown
llm_load_print_meta: ssm_d_conv = 0
llm_load_print_meta: ssm_d_inner = 0
llm_load_print_meta: ssm_d_state = 0
llm_load_print_meta: ssm_dt_rank = 0
llm_load_print_meta: model type = 35B
llm_load_print_meta: model ftype = Q4_K - Small
llm_load_print_meta: model params = 34.98 B
llm_load_print_meta: model size = 18.97 GiB (4.66 BPW)
llm_load_print_meta: general.name = c4ai-command-r-v01
llm_load_print_meta: BOS token = 5 '<BOS_TOKEN>'
llm_load_print_meta: EOS token = 255001 '<|END_OF_TURN_TOKEN|>'
llm_load_print_meta: PAD token = 0 ''
llm_load_print_meta: LF token = 136 'Ä'
ggml_cuda_init: GGML_CUDA_FORCE_MMQ: no
ggml_cuda_init: CUDA_USE_TENSOR_CORES: yes
ggml_cuda_init: found 2 CUDA devices:
Device 0: NVIDIA GeForce RTX 4070 Ti, compute capability 8.9, VMM: yes
Device 1: NVIDIA GeForce RTX 3090 Ti, compute capability 8.6, VMM: yes
llm_load_tensors: ggml ctx size = 0.37 MiB
llm_load_tensors: offloading 40 repeating layers to GPU
llm_load_tensors: offloading non-repeating layers to GPU
llm_load_tensors: offloaded 41/41 layers to GPU
llm_load_tensors: CPU buffer size = 1640.62 MiB
llm_load_tensors: CUDA0 buffer size = 5434.38 MiB
llm_load_tensors: CUDA1 buffer size = 13989.53 MiB
.......................................................................................
llama_new_context_with_model: n_ctx = 4096
llama_new_context_with_model: n_batch = 2048
llama_new_context_with_model: n_ubatch = 512
llama_new_context_with_model: freq_base = 8000000.0
llama_new_context_with_model: freq_scale = 1
llama_kv_cache_init: CUDA0 KV buffer size = 1536.00 MiB
llama_kv_cache_init: CUDA1 KV buffer size = 3584.00 MiB
llama_new_context_with_model: KV self size = 5120.00 MiB, K (f16): 2560.00 MiB, V (f16): 2560.00 MiB
llama_new_context_with_model: CUDA_Host output buffer size = 1.95 MiB
llama_new_context_with_model: pipeline parallelism enabled (n_copies=4)
llama_new_context_with_model: CUDA0 compute buffer size = 672.01 MiB
llama_new_context_with_model: CUDA1 compute buffer size = 672.02 MiB
llama_new_context_with_model: CUDA_Host compute buffer size = 48.02 MiB
llama_new_context_with_model: graph nodes = 1247
llama_new_context_with_model: graph splits = 3
[1711615946] warming up the model with an empty run
INFO [ init] initializing slots | tid="196984" timestamp=1711615946 n_slots=1
INFO [ init] new slot | tid="196984" timestamp=1711615946 id_slot=0 n_ctx_slot=4096
INFO [ main] model loaded | tid="196984" timestamp=1711615946
INFO [ main] chat template | tid="196984" timestamp=1711615946 chat_example="<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\nHi there<|im_end|>\n<|im_start|>user\nHow are you?<|im_end|>\n<|im_start|>assistant\n" built_in=true
INFO [ main] HTTP server listening | tid="196984" timestamp=1711615946 hostname="127.0.0.1" port="8080" n_threads_http="7"
INFO [ update_slots] all slots are idle | tid="196984" timestamp=1711615946
INFO [ launch_slot_with_task] slot is processing task | tid="196984" timestamp=1711615976 id_slot=0 id_task=0
INFO [ update_slots] kv cache rm [p0, end) | tid="196984" timestamp=1711615976 id_slot=0 id_task=0 p0=0
GGML_ASSERT: D:\a\llama.cpp\llama.cpp\ggml.c:3648: ggml_can_repeat(b, a)

@compilade (Collaborator, Author) replied:

command-r seems broken; it works until build b2536.
I'm getting "GGML_ASSERT: D:\a\llama.cpp\llama.cpp\ggml.c:3648: ggml_can_repeat(b, a)"

@Wuzzooy I'm truly sorry about that. It seems the model's graph was structured a bit differently from the other models (inpSA was named ffn_inp instead), and I overlooked a tensor which should have gone through the new inp_out_ids dance.

Should be fixed by #6367 (hopefully).

slaren pushed a commit that referenced this pull request Mar 29, 2024
* Support xverse model convert to gguf format.

* 1. Convert xverse models to gguf;
2. Add LLM_ARCH_XVERSE inference in llama.cpp;
3. Add xverse item in Supported models in README.md;

* * gguf-py: remove redundant logs
* llama: remove the init_mapping_prefetch custom parameter

* llama.cpp: Include the changes from #6122 to exclude the unused outputs of the last layers.

* - Fix format issues
- Remove duplicate set kqv_out to llm_build_kv

* Update llama.cpp

---------

Co-authored-by: willhe <[email protected]>
Co-authored-by: willhe <[email protected]>
woachk added a commit to woachk/llama.cpp that referenced this pull request Mar 31, 2024
op_getrows_f32 is required since ggerganov#6122
for the Vulkan w/ Kompute backend to be functional.

As such, implement this op to make this backend functional again.
hodlen pushed a commit to hodlen/llama.cpp that referenced this pull request Apr 1, 2024
* llama : greatly reduce logits memory usage

* llama : more compact state saving and reloading

* llama : fix lctx.n_outputs not being set before building graph

* perplexity : adapt to the logits API changes

* perplexity : fix Winogrande, use correct logits for second choice start

The first logits used to evaluate the second choice were not from
the end of the common prefix; instead, they were the logits from the end
of the first choice. This has been corrected.

The previous implementation sometimes had outliers in the scores of
choices for some tasks, and the logic to skip choices words
in the log-likelihood evaluation probably was an attempt to reduce those,
but it was complex and didn't quite seem to be the right thing.

This is simpler now, and the outlier scores aren't there anymore.

* perplexity : normalize spaces and punctuation in Winogrande sentences

* llama : fix embedding conditions

* llama : fix llama_get_embeddings_ith when the resulting id is 0

* llama : fix wrong n_outputs in llama_set_inputs

A mismatch happened when using a smaller n_ubatch than n_batch and then using
llama_batch_get_one(). The decision of what n_outputs should be now almost
fully depends on how lctx.n_outputs is set in llama_decode_internal.
The conditions are simpler this way.

* llama : when saving the state, recalculate n_outputs

This ensures the correct number of outputs for the entire previous batch
is stored in the session file, even when n_ubatch is smaller than n_batch.

* llama : fix not-skipping outputs of non-causal models

* llama : fix running a batch with n_outputs == 0

It previously worked because lctx.inp_out_ids was not initialized,
so it pointed to some garbage address which was somehow still valid when I
ran my tests.

* llama : keep same graph topology even when n_outputs == 0

* ggml : saner ggml_can_repeat with empty tensors

*  ggml : future-proof ggml_is_empty by using GGML_MAX_DIMS - 1

* ggml : do not multi-thread ops returning empty tensors

* ggml : make ggml_is_empty public and work with views

* llama : use a vector for ctx->output_ids

* llama : rework reallocation logic for llama_output_reserve

Now comparing the actual size with the new total size of the output buffer
to allow more efficient enabling and disabling of the embeddings
and/or logits output in the future.

* ggml : skip empty tensors in all backends

* llama : fix llama_output_reserve nullptr deref when new_size is 0

* perplexity : make Winogrande work as it does on master

The problems with the Winogrande implementation will
need to be fixed in a separate PR to ease review.

* llama : clearer error messages for invalid logits or embeddings ids

* llama : assert all models that can have inp_out_ids

Since the graph topology is now constant, this presence check
can be done even when there are no outputs.

* llama : assert logits and embd buffers exist before writing to them

* llama : handle errors from llama_output_reserve at call sites

* perplexity : make hellaswag and multiple-choice outputs identical to master

Due to how the KV cache is updated, the logprobs for tokens in a batch
are very slightly affected by the other tokens present in the batch,
so to make hellaswag and multiple-choice return exactly the same results
as on master, the last token of each sequence needs to be evaluated
even though its output is not used at all.

This will probably be changed back in the future to make these benchmarks
a tiny bit faster.

* perplexity : fix division by zero when using less than 100 multiple-choice tasks

* llama : allow loading state saved with a different ctx size

When loading a session file, the context size is now only required to be
at least enough to load the KV cells contained in that session file,
instead of requiring to use exactly the same context size as when saving.

Doing this enables the use-case of extending or shrinking the context size
of a saved session.

This breaks existing session files because the meaning of kv_buf_size
is slightly changed (previously it was the size of the whole KV cache,
now it's only the size of the saved part of it). This allows for
finer-grained sanity checks when loading in an effort to keep kv_buf_size
useful even when the kv_size is changed.

* llama : minor

ggml-ci

* readme : update recent API changes, and warn about Vulkan

---------

Co-authored-by: Georgi Gerganov <[email protected]>
hodlen pushed a commit to hodlen/llama.cpp that referenced this pull request Apr 1, 2024
hodlen pushed a commit to hodlen/llama.cpp that referenced this pull request Apr 3, 2024
hodlen pushed a commit to hodlen/llama.cpp that referenced this pull request Apr 3, 2024
tybalex pushed a commit to tybalex/function.cpp that referenced this pull request Apr 17, 2024
tybalex pushed a commit to tybalex/function.cpp that referenced this pull request Apr 17, 2024