Support traceable DynamicCache #36311
base: main
Conversation
cc @gante @zucchini-nlp as well!
@tugsbayasgalan This is great! I think it can apply to StaticCache as well, right? By making it pytree friendly, we get the option to trace the cache as an input arg instead of forcing it to be wrapped inside the module. Of course, to make HF models consumable with llama_runner out of the box, we would still wrap the cache during export, but if we apply this to StaticCache too, users get the new option of composing the cache through the model IO during inference.
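For illustration only, here is a minimal sketch of what "pytree friendly" means in practice. ToyCache is a made-up stand-in, not this PR's code; the real change would target DynamicCache/StaticCache, and the registration API is assumed to be torch.utils._pytree.register_pytree_node as in recent PyTorch releases.

import torch
import torch.utils._pytree as pytree


class ToyCache:
    """Hypothetical stand-in for an HF cache holding per-layer key/value tensors."""

    def __init__(self, key_cache=None, value_cache=None):
        self.key_cache = list(key_cache or [])
        self.value_cache = list(value_cache or [])


def _flatten_toy_cache(cache):
    # Return the tensor-bearing children plus a (here empty) rebuild context.
    return [cache.key_cache, cache.value_cache], None


def _unflatten_toy_cache(children, _context):
    key_cache, value_cache = children
    return ToyCache(key_cache, value_cache)


# Once registered, export/trace machinery can flatten the cache into plain
# tensor leaves on the way in and rebuild the container on the way out.
pytree.register_pytree_node(ToyCache, _flatten_toy_cache, _unflatten_toy_cache)

cache = ToyCache([torch.zeros(1, 2, 0, 4)], [torch.zeros(1, 2, 0, 4)])
leaves, spec = pytree.tree_flatten(cache)
rebuilt = pytree.tree_unflatten(leaves, spec)
assert isinstance(rebuilt, ToyCache) and len(leaves) == 2

With a registration like this, the cache behaves like any other nested tensor input/output rather than a closed-over module attribute.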
past_key_values = DynamicCache()
ep = torch.export.export(
    model,
    (),
    {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "past_key_values": past_key_values,
        "use_cache": True,
    },
    strict=False,
)
@tugsbayasgalan Curious how DynamicCache would work differently from StaticCache when it comes to export. When we export with DynamicCache, isn't its size specialized to a fixed number (its capacity)?
Yep, it is specialized to the number of layers. As a result, the input and output specs for DynamicCache are a bit different: the former holds 0 tensors while the latter holds 2*num_layers tensors. To properly enable dynamic shapes for the concat dimension, I think we need to initialize zero-size tensors at the start of DynamicCache.
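To make the spec mismatch concrete, a rough sketch of the zero-size seeding idea, assuming the DynamicCache API at the time of this PR (key_cache/value_cache lists and an update(key_states, value_states, layer_idx) method); the shapes are arbitrary examples:

import torch
from transformers import DynamicCache

num_layers, batch, num_heads, head_dim = 2, 1, 4, 8

# A brand-new DynamicCache holds no tensors, so its flattened input spec is empty.
fresh = DynamicCache()
print(len(fresh.key_cache), len(fresh.value_cache))  # 0 0

# Hypothetical workaround from the comment above: pre-seed every layer with
# zero-length tensors along the sequence dimension so the input spec already
# contains 2 * num_layers tensors and the concat dimension can be made dynamic.
seeded = DynamicCache()
for layer_idx in range(num_layers):
    zeros = torch.zeros(batch, num_heads, 0, head_dim)
    seeded.update(zeros, zeros.clone(), layer_idx)

print(len(seeded.key_cache), len(seeded.value_cache))  # num_layers each

Seeding like this would make the flattened input spec line up with the output spec and give export an actual sequence dimension to mark dynamic.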
What does this PR do?
Recently, #35873 landed to support static KV cache with export. As a side effect, it also made it easy to support dynamic KV cache. Since all HF models accept a KV cache as an extra input, we can represent the dynamic KV cache as a pytree container. Credit goes to @xadupre, who attempted this approach in pytorch/pytorch#147326.
cc @guangy10 @IlyasMoutawwakil
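For reference, a rough end-to-end sketch of the intended usage. This assumes the pytree registration from this PR is in place; the tiny test checkpoint is only an example, and the output attribute access assumes the model's output pytree is preserved by export.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache

model_id = "hf-internal-testing/tiny-random-LlamaForCausalLM"  # example checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id).eval()

inputs = tokenizer("Hello, world", return_tensors="pt")

# Export with an empty DynamicCache passed through the kwargs, as in the diff above.
ep = torch.export.export(
    model,
    (),
    {
        "input_ids": inputs["input_ids"],
        "attention_mask": inputs["attention_mask"],
        "past_key_values": DynamicCache(),
        "use_cache": True,
    },
    strict=False,
)

# The exported program can be called like the original module; the returned
# past_key_values should now carry 2 * num_layers tensors.
out = ep.module()(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    past_key_values=DynamicCache(),
    use_cache=True,
)
print(out.logits.shape)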