Memory efficient context handling #1183
base: master
Conversation
LLama/LLamaEmbedder.cs
Outdated
```csharp
/// </summary>
/// <param name="text"></param>
/// <returns></returns>
public int CountTokens(string text)
```
CountTokens and GetTokens methods are duplicated on LLamaEmbedder and LLamaStatelessExecutor. Also I don't think either of them actually requires a context (which is a very expensive object to create and destroy)!

Can these methods be moved up to the LLamaWeights class instead? That's a more appropriate place for methods relating to tokens/vocabulary.
> CountTokens and GetTokens methods are duplicated on LLamaEmbedder and LLamaStatelessExecutor.

The reason for this is that the contexts are created with the parameters of each specific object (text generator or embedding generator).
> Also I don't think either of them actually requires a context (which is a very expensive object to create and destroy)!

The code is now streamlined so that no context is kept around; one is created when needed and then destroyed. The logic is the same as in each object itself: for example, GetEmbeddingsWithTokenCount() does this in LLamaEmbedder, and InferAsync() does it in StatelessExecutor. So the code is now consistent throughout. The overhead of creating the context on the fly is very small, and when using KernelMemory with this update, about 30% less GPU memory is used than before.
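A minimal sketch of this create-on-demand pattern (field names _weights and _params are illustrative; the calls mirror the snippet reviewed further down):

```csharp
// Sketch only: the context lives just for this call and is disposed
// immediately, so no GPU memory is held between calls.
public int CountTokens(string text)
{
    using var context = _weights.CreateContext(_params);
    return context.Tokenize(text, special: true).Length;
}
```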
> Can these methods be moved up to the LLamaWeights class instead? That's a more appropriate place for methods relating to tokens/vocabulary.
If we moved them to LLamaWeights, we would need to keep the parameters in each object (they are required to create the context and differ between LLamaEmbedder and LLamaStatelessExecutor) and pass them to these functions in several places in the code. That means a lot of modifications wherever these functions are used. I think the simplest and cleanest option is to leave them where they are now.

In conclusion, I would keep it as it is now. Please let me know what you think.
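For reference, the alternative under discussion would look roughly like this (a sketch, assuming each object keeps its own IContextParams and delegates to a shared LLamaWeights instance; names are illustrative):

```csharp
// Each executor/embedder keeps only its context parameters, not a live context.
private readonly LLamaWeights _weights;
private readonly IContextParams _params;

public int CountTokens(string text)
    => _weights.CountTokens(text, _params);
```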
Martin, on second thought, I think you may be right (only the params need to be kept). I will look into it!
LLama/LLamaStatelessExecutor.cs
Outdated
```diff
@@ -169,5 +169,44 @@ public async IAsyncEnumerable<string> InferAsync(string prompt, IInferenceParams
         throw new LLamaDecodeError(returnCode);
     }
 }
+
+/// <inheritdoc/>
+public int CountTokens(string text)
```
See other comment about these methods
Moved the code to LLamaWeights.
```csharp
var embeddings = await generator.GenerateAsync(
    [
        "The cat is cute",
if (false)
```
Does this need resolving before merge?
generator.GetService<EmbeddingGeneratorMetadata>() uses the context and thus will fail, because with memory-efficient context handling we do not keep the context around. That was the main aim of this PR.

The code in the test assumes that a context exists. I think that for the test code to work we would need some extra work to create an embedding service that keeps the context (this could be done in a follow-up PR, if anybody is interested in doing it). The aim of the embedder in our code is different. In my opinion the test code is wrong because it assumes the embedder is a live service, and it should not be one if GPU memory is to be handled efficiently. There are two options: delete the test code, or leave it switched off with the TODO comment I have added.
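Roughly, the disabled segment is guarded like this (illustrative only; the exact test body and assertion differ):

```csharp
// TODO: re-enable once an embedding service that keeps a live context exists;
// GetService<EmbeddingGeneratorMetadata>() needs a context, which the
// memory-efficient handling no longer keeps around.
if (false)
{
    var metadata = generator.GetService<EmbeddingGeneratorMetadata>();
    Assert.NotNull(metadata);
}
```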
LLama/LLamaWeights.cs
Outdated
```csharp
using var context = CreateContext(parameters);
var count = context.Tokenize(text, special: true).Length;
return count;
```
There's a Tokenize method on the lower-level model handle, so there's no need to use a context: https://github.com/SciSharp/LLamaSharp/blob/master/LLama/Native/SafeLlamaModelHandle.cs#L480
(i.e. just writing a wrapper over that method should suffice)
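Such a wrapper on LLamaWeights might look like this (a sketch; the Tokenize parameter names and the UTF-8 encoding are assumptions based on the linked source):

```csharp
// Count tokens directly on the model handle; no LLamaContext is created.
public int CountTokens(string text)
{
    return NativeHandle.Tokenize(text, add_bos: true, special: true, System.Text.Encoding.UTF8).Length;
}
```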
I did not realize that there is a tokenize method at the model level. I will use that then, and since it does not need a context I can move all the code back to KM. The main aim of this PR was to remove the context that was created in several places, because it unnecessarily fills GPU memory (this saves about 30% of memory!).
LLama/LLamaWeights.cs
Outdated
```csharp
/// <remarks>
/// It throws if text is null, and includes an empty stop token because addBos is left true, to be consistent with the CountTokens implementation.</remarks>
/// <see cref="CountTokens(string, IContextParams)"/>
public IReadOnlyList<string> GetTokens(string text, IContextParams parameters)
```
This implementation doesn't seem correct to me (I realise you're just moving it from LlamaSharpTextGenerator to LLamaWeights, but I don't really work on the KM stuff, so I haven't closely reviewed these methods before).

In general LLamaSharp is quite careful about never treating tokens as text; that's not safe for a number of reasons - for example, a token could be half of a character, in which case it can't be decoded into text. That's what the StreamingTokenDecoder is for: you could add 10 tokens and get back just one character of text. At the very least, that means GetTokens and CountTokens would have a mismatch.
Obviously KM needs something back to satisfy the contract of ITextTokenizer etc, so I'm not really sure what the right answer is here. Maybe we should move it back into KM, as an extension method on LLamaWeights? That way you can still use it as if it's here, but it's not part of the main lib. I'm open to other ideas though.
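To illustrate the token/text mismatch (a sketch, assuming a StreamingTokenDecoder constructor taking an Encoding and a model handle, with weights and tokens in scope):

```csharp
// Decoding token-by-token can yield empty strings until a full character
// is available, so "one token == one string" does not hold in general.
var decoder = new StreamingTokenDecoder(System.Text.Encoding.UTF8, weights.NativeHandle);
var decodedPieces = 0;
foreach (var token in tokens)
{
    decoder.Add(token);
    if (decoder.Read().Length > 0) // Read() may return "" mid-character
        decodedPieces++;
}
// decodedPieces can be smaller than the number of tokens.
```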
The main aim of this PR was to remove the context that was created in several places in the KM code, because it unnecessarily fills GPU memory (this saves about 30% of memory!). By using the native tokenize method and moving the code back to KM, I think we have the right solution. Furthermore, I think the code is only used in unit tests and will probably never be used by anybody. This is also a reason why we should not keep a context there!
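A KM-side extension along the lines suggested above might look like this (a sketch; the class name is hypothetical and the Tokenize parameter names are assumed):

```csharp
using System.Text;
using LLama;

// Hypothetical extension class in the KernelMemory integration, keeping the
// token helpers out of the core library while reading as if they were on LLamaWeights.
public static class LLamaWeightsTokenExtensions
{
    public static int CountTokens(this LLamaWeights weights, string text)
        => weights.NativeHandle.Tokenize(text, add_bos: true, special: true, Encoding.UTF8).Length;
}
```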
…LamaSharp into UpdateContextHandling
I have updated the code.
Martin, in LLamaEmbedderTests, in CompareEmbeddings(), I had to disable a code segment related to Microsoft.Extensions.AI.IEmbeddingGenerator that does not work with the new efficient context handling. Please look at that code to decide what can be done to use it, or we can also remove it. I guess that if we want that kind of functionality, we will need to create a LLamaEmbedderService that is compatible with Microsoft.Extensions.AI.IEmbeddingGenerator.
I also had to disable SafeLlamaModelHandleTests -> MetadataValByKey_ReturnsCorrectly for Mac and Linux; on Windows it is OK. Please look at this as well to decide whether it is important. It has nothing to do with this PR; I guess it is related to llama.cpp.