New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Quantized KV Cache #30483
base: main
Are you sure you want to change the base?
Quantized KV Cache #30483
Conversation
As we discussed quantized cache can be started to be integrated to the library, given the results we got so far. All the possible speed optimizations/pre-fill stage optimizations can be done further, as we will be getting feedback from the community. So, I would like to get a review on the PR :) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
API wise looks really great ! I did not spotted anything critical here that needs to be addressed (and I will let joao give a deeper review on the cache file changes) - except for guarding quanto imports (also I would say safer to make local imports whenever possible - e.g. at QuantCache
init)
You raised a concern about switching between cache implementations - I made an attempt while ago: #29030 that got stale (😅 ) maybe that PR might solve your concern?
Maybe we could also track models that support quant cache with a private attribute _supports_quant_cache
in xxxPreTrainedModel
- what do you think?
Thanks for the comments!
Okey noted!
I love the generalized cache implementation idea. Not sure how this will work on overall API level, given that Joao and Arthur are working on changing cache thing. I'll let Joao to decide about that
Hmm, Actually quant cache should be supported abywhere dynamicCache is, that means everything except for old models like bart/t5. Yeah I think we can add it for explicitness, until the cache API is refactored to be same everywhere |
Thanks !
Ok that's great if that's the case then, i would say no need for that ! |
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Will need a rebase due to #30476, but I love this POC -- in fact, I've reviewed it as if it was not a POC 😉
After removing the extra .py
files and adding some docs, I believe it is ready to be launched! And I also think it deserves a blog post :D
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A few comments for you to work on, but let's gather the benchmarks first :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great work @zucchini-nlp ! 🚀 I only left nits and one open question with respect to tests, otherwise it looks really great !
@gante added benchmark results on the PR description. Right now int4 has almost same performance as fp16, sometimes a bit better. Also added some comparison with the KIVI paper. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM 🙌 Thank you for iterating on this very cool project!
(CI needs fixing -- possibly a simple |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Very interesting work! Having both cache and quantizing on the fly when needed is very interesting!
Co-authored-by: Arthur <[email protected]>
Co-authored-by: Arthur <[email protected]>
Co-authored-by: Arthur <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great work ! Left one nit about tests !
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great work! Last few nits and should be good to go!
Co-authored-by: Younes Belkada <[email protected]>
Co-authored-by: Arthur <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just a few small nits
cache_config = ( | ||
generation_config.cache_config | ||
if generation_config.cache_config is not None | ||
else QuantizedCacheConfig() | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
cache_config = ( | |
generation_config.cache_config | |
if generation_config.cache_config is not None | |
else QuantizedCacheConfig() | |
) |
can be removed if we pass generation_config.cache_config instead to the quantoQauntizedCache
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No, because generation_config.cache_config
can be None if we are using the model's generation config and the user does not pass any cache config, In that case we have to init the Cache Config with default values ourselves.
The one line below code will fail otherwise, if we pass in None
as input
else QuantizedCacheConfig() | ||
) | ||
|
||
model_kwargs["past_key_values"] = QuantoQuantizedCache(**cache_config.to_dict()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
model_kwargs["past_key_values"] = QuantoQuantizedCache(**cache_config.to_dict()) | |
model_kwargs["past_key_values"] = QuantoQuantizedCache(cache_config) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not really, the class does not accept as input a dataclass
object so we pass it as **mapping
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ah yeah, it's annoying
Just curious, would this new cache work with |
Nope, this one specifically no as it inherits from Dynamic cache, but another implementation based on static cache could. Compile is not super happy with if else and device placements espacially if it's input dependent (here depends on the length of the processed input) |
@ArthurZucker @ydshieh: "torch.compile with quanto is only supported for 8 bits quantization for now" (from @SunMarc, on a related conversation on slack) |
I made the KV cache work with HQQ as a backend. It can be simply plugged in if a user writes their own "CacheClass". I am not planning to add it now as it needs more evaluation and experiments, but wanted to show how anyone can add more backends. Do you think I should continue experimenting with HQQ or we can simply put the below code as example for users? BTW, if we were to actually support more cache quant classes in the library, maybe we'll need to change the current QuantCache API a bit to be more versatile. from typing import Any, Dict, List, Optional, Tuple, Union
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from hqq.core.quantize import Quantizer as HQQQuantizer
class HQQQuantizedCache(DynamicCache):
def __init__(
self,
nbits: int = 4,
axis: int = 0,
q_group_size: int = 64,
residual_length: int = 128,
compute_dtype: torch.dtype = torch.float16,
device: str = "cpu",
) -> None:
if nbits not in [2, 4, 8]:
raise ValueError(f"`nbits` has to be one of [`2`, `4`, `8`] but got {nbits}")
if axis not in [0, 1]:
raise ValueError(f"`axis` has to be one of [`1`, `2`] but got {axis}")
self._quantized_key_cache: List[Tuple[torch.Tensor, Dict]] = []
self._quantized_value_cache: List[Tuple[torch.Tensor, Dict]] = []
self.nbits = nbits
self.axis = axis
self.residual_length = residual_length
self.q_group_size = q_group_size
self.compute_dtype = compute_dtype
self.quantizer = HQQQuantizer
self.device = device
super().__init__()
def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: int,
cache_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
# Update the number of seen tokens
if layer_idx == 0:
self._seen_tokens += key_states.shape[-2]
if len(self.key_cache) <= layer_idx:
q_key, meta_key = self._quantize(key_states.contiguous())
self._quantized_key_cache.append((q_key, meta_key))
q_value, meta_value = self._quantize(value_states.contiguous())
self._quantized_value_cache.append((q_value, meta_value))
self.key_cache.append(torch.zeros(0, dtype=key_states.dtype, device=key_states.device))
self.value_cache.append(torch.zeros(0, dtype=key_states.dtype, device=key_states.device))
keys_to_return, values_to_return = key_states, value_states
else:
quant_key, meta_key = self._quantized_key_cache[layer_idx]
dequant_key = self.quantizer.dequantize(quant_key, meta_key)
quant_value, meta_value = self._quantized_value_cache[layer_idx]
dequant_value = self.quantizer.dequantize(quant_value, meta_value)
keys_to_return = [dequant_key, self.key_cache[layer_idx], key_states]
values_to_return = [dequant_value, self.value_cache[layer_idx], value_states]
keys_to_return = torch.cat(keys_to_return, dim=-2)
values_to_return = torch.cat(values_to_return, dim=-2)
if (
self.key_cache[layer_idx].dim() == 4
and self.key_cache[layer_idx].shape[-2] + 1 >= self.residual_length
):
q_key, meta_key = self._quantize(keys_to_return.contiguous())
self._quantized_key_cache[layer_idx] = (q_key, meta_key)
q_value, meta_value = self._quantize(values_to_return.contiguous())
self._quantized_key_cache[layer_idx] = (q_value, meta_value)
self.key_cache[layer_idx] = torch.zeros(0, dtype=key_states.dtype, device=key_states.device)
self.value_cache[layer_idx] = torch.zeros(0, dtype=key_states.dtype, device=key_states.device)
else:
self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
return keys_to_return, values_to_return
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
"""Returns the sequence length of the cached states. A layer index can be optionally passed."""
if len(self.key_cache) <= layer_idx:
return 0
# since we cannot get the seq_length of each layer directly and rely on `_seen_tokens` which is
# updated every "layer_idx" == 0, this is a hack to get the actual seq_length for the given layer_idx
# this part of code otherwise fails when used to verify attn_weight shape in some models
return self._seen_tokens if layer_idx == 0 else self._seen_tokens - 1
def _quantize(self, tensor):
qtensor, meta = self.quantizer.quantize(
tensor,
axis=self.axis,
device=self.device,
compute_dtype=self.compute_dtype,
nbits=self.nbits,
group_size=self.q_group_size,
)
meta["compute_dtype"] = self.compute_dtype
return qtensor, meta
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
tokenizer.pad_token_id = tokenizer.eos_token_id
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", torch_dtype=torch.float16, attn_implementation="eager", device_map = "auto")
inputs = tokenizer("I like rock music because" return_tensors="pt").to(model.device)
out = model.generate(
**inputs,
do_sample=False,
max_new_tokens=50,
past_key_values=HQQQuantizedCache(
nbits=2,
axis=1, # 2bit with axis=0 generates garbage
compute_dtype=torch.float16,
device=model.device
),
)
print(f"text with HQQ backend: {tokenizer.batch_decode(out)}") |
What does this PR do?
An implementation of quantized cache with
quanto
library. Introduces a newCacheConfig
to store cache related arguments and a new cache classQuantoQuantizedCache
. The implementation is based partially on the KIVI paper, but in this case we do a per-token quantization for both: keys and values.Example usage:
Perplexity plots
Here the results are different from what we got earlier because I was calculating perplexity in one forward pass, by quantizing and then dequantizing all keys and values. The new script uses cache object and calculates pplx per new token.Eval on LongBench (scripts taken from LongBench repo)
This is to compare with the KIVI method, since they did the same evals on all datasets from LongBench.I cannot find KIVI results on all of the LongBench, so here will be only
transformers
version.Memory vs Latency plots
Same old plots showing memory consumption and latency for differeny cache types: