Quantized KV Cache #30483

Open · wants to merge 14 commits into base: main
Conversation

@zucchini-nlp (Member) commented Apr 25, 2024

What does this PR do?

An implementation of a quantized KV cache using the quanto library. It introduces a new CacheConfig to store cache-related arguments and a new cache class QuantoQuantizedCache. The implementation is partially based on the KIVI paper, but here we do per-token quantization for both keys and values.
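For intuition, below is a minimal conceptual sketch of per-token asymmetric quantization in plain PyTorch. This is not the quanto code path used in the PR; the shapes, the uint8 storage, and the `nbits` value are illustrative assumptions only.

import torch

def quantize_per_token(x: torch.Tensor, nbits: int = 4):
    # x: [batch, num_heads, seq_len, head_dim]; one scale/zero-point per token (and per head)
    qmax = 2 ** nbits - 1
    x_min = x.amin(dim=-1, keepdim=True)
    x_max = x.amax(dim=-1, keepdim=True)
    scale = (x_max - x_min).clamp(min=1e-6) / qmax
    q = torch.clamp(torch.round((x - x_min) / scale), 0, qmax).to(torch.uint8)  # real int2/int4 packing is left to the backend
    return q, scale, x_min

def dequantize_per_token(q, scale, zero_point):
    return q.to(scale.dtype) * scale + zero_point

keys = torch.randn(1, 8, 16, 64)  # toy key states; the same treatment is applied to values
q_keys, scale, zero_point = quantize_per_token(keys, nbits=4)
print((keys - dequantize_per_token(q_keys, scale, zero_point)).abs().max())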

Example usage:

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
tokenizer.pad_token_id = tokenizer.eos_token_id
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", torch_dtype=torch.float16, attn_implementation="eager").to("cuda:0")

inputs = tokenizer("Hello, how are you?", truncation=True, return_tensors="pt").to(model.device)

out = model.generate(**inputs, do_sample=False, max_new_tokens=20, cache_implementation="quantized")
out_fp16 = model.generate(**inputs, do_sample=False, max_new_tokens=20)

print(f"text with quant cache: {tokenizer.batch_decode(out)}")
print(f"text with fp16 cache: {tokenizer.batch_decode(out_fp16)}")
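Cache-related arguments can also be tweaked through the new config class. A minimal sketch, assuming QuantizedCacheConfig is exported at the top level, exposes an `nbits` argument, and can be passed to generate via `cache_config` (exact names may differ in the final API):

from transformers import QuantizedCacheConfig

cache_config = QuantizedCacheConfig(nbits=4)  # assumed argument name; defaults are used when no config is passed
out_int4 = model.generate(
    **inputs,
    do_sample=False,
    max_new_tokens=20,
    cache_implementation="quantized",
    cache_config=cache_config,  # assumed to land on generation_config.cache_config
)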
Perplexity plots: here the results differ from what we got earlier, because previously I was calculating perplexity in one forward pass, quantizing and then dequantizing all keys and values. The new script uses the cache object and calculates perplexity per new token. [Plots: perplexity; latency]
Eval on LongBench (scripts taken from the LongBench repo). This is to compare with the KIVI method, since they ran the same evals on datasets from LongBench.

| Dataset | KIVI fp16 | KIVI int2 | Ours fp16 | Ours int4 | Ours int2 |
|---|---|---|---|---|---|
| TREC | 63.0 | 67.5 | 63.0 | 63.0 | 55.0 |
| SAMSum | 41.12 | 42.18 | 41.12 | 41.3 | 14.04 |

I could not find KIVI results for all of LongBench, so the table below shows only the transformers version.

| Dataset | fp16 | int4 | int2 |
|---|---|---|---|
| TriviaQA | 84.28 | 84.76 | 63.64 |
| HotPotQA | 30.08 | 30.04 | 17.3 |
| Passage_retrieval_en | 8.5 | 9.5 | 4.82 |
Memory vs latency plots: the same plots as before, showing memory consumption and latency for different cache types. [Plots: latency as a function of batch size; memory consumption as a function of batch size; memory consumption as a function of max new tokens]

@zucchini-nlp (Member, Author):

As we discussed, the quantized cache can start being integrated into the library, given the results we have so far. Further speed and pre-fill stage optimizations can be done later, as we gather feedback from the community.

So, I would like to get a review on the PR :)

@younesbelkada (Contributor) left a comment

API-wise this looks really great! I did not spot anything critical here that needs to be addressed (and I will let Joao give a deeper review on the cache file changes), except for guarding the quanto imports (I would also say it is safer to make local imports whenever possible, e.g. at QuantCache init) - a minimal sketch of that pattern follows this comment.
You raised a concern about switching between cache implementations - I made an attempt a while ago: #29030, which went stale (😅) - maybe that PR might solve your concern?
Maybe we could also track models that support the quantized cache with a private attribute _supports_quant_cache in xxxPreTrainedModel - what do you think?
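(A minimal sketch of the guarded, local import pattern mentioned above - the availability helper and the stub constructor body are assumptions, not code from this PR:)

from transformers import DynamicCache
from transformers.utils import is_quanto_available  # assumed availability helper


class QuantoQuantizedCache(DynamicCache):
    def __init__(self, nbits: int = 4):
        super().__init__()
        if not is_quanto_available():
            raise ImportError("QuantoQuantizedCache requires the `quanto` package (`pip install quanto`).")
        import quanto  # noqa: F401 - local import, so `import transformers` still works without quanto installed
        self.nbits = nbits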

@zucchini-nlp (Member, Author):

Thanks for the comments!

> except for guarding the quanto imports (I would also say it is safer to make local imports whenever possible, e.g. at QuantCache init)

Okay, noted!

> You raised a concern about switching between cache implementations - I made an attempt a while ago: #29030, which went stale (😅) - maybe that PR might solve your concern?

I love the generalized cache implementation idea. I'm not sure how this will work at the overall API level, given that Joao and Arthur are working on changing the cache anyway. I'll let Joao decide on that.

> Maybe we could also track models that support the quantized cache with a private attribute _supports_quant_cache in xxxPreTrainedModel - what do you think?

Hmm, actually the quantized cache should be supported anywhere DynamicCache is, which means everything except old models like bart/t5. Yeah, I think we can add it for explicitness, until the cache API is refactored to be the same everywhere.

@younesbelkada (Contributor):

Thanks !

> Hmm, actually the quantized cache should be supported anywhere DynamicCache is, which means everything except old models like bart/t5. Yeah, I think we can add it for explicitness, until the cache API is refactored to be the same everywhere.

Ok, that's great if that's the case - then I would say there is no need for that!

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@gante (Member) left a comment

Will need a rebase due to #30476, but I love this POC -- in fact, I've reviewed it as if it was not a POC 😉

After removing the extra .py files and adding some docs, I believe it is ready to be launched! And I also think it deserves a blog post :D

@zucchini-nlp marked this pull request as ready for review May 2, 2024 10:33
@gante (Member) left a comment

A few comments for you to work on, but let's gather the benchmarks first :)

@younesbelkada (Contributor) left a comment

Great work @zucchini-nlp ! 🚀 I only left nits and one open question with respect to tests, otherwise it looks really great !

@zucchini-nlp (Member, Author) commented May 3, 2024

@gante added benchmark results to the PR description. Right now int4 has almost the same performance as fp16, sometimes a bit better. I also added a comparison with the KIVI paper.

@gante (Member) left a comment

LGTM 🙌 Thank you for iterating on this very cool project!

@gante (Member) commented May 8, 2024

(CI needs fixing -- possibly a simple make fix-copies)

@gante requested a review from ArthurZucker May 8, 2024 14:20
@ArthurZucker (Collaborator) left a comment

Very interesting work! Having both the cache and on-the-fly quantization when needed is a really nice combination!

@zucchini-nlp changed the title from [POC] Quantized KV Cache to Quantized KV Cache May 9, 2024
@younesbelkada (Contributor) left a comment

Great work ! Left one nit about tests !

@ArthurZucker (Collaborator) left a comment

Great work! Last few nits and should be good to go!

@ArthurZucker (Collaborator) left a comment

Just a few small nits

Comment on lines +1602 to +1606
cache_config = (
    generation_config.cache_config
    if generation_config.cache_config is not None
    else QuantizedCacheConfig()
)
Collaborator:

Suggested change (remove these lines):

    cache_config = (
        generation_config.cache_config
        if generation_config.cache_config is not None
        else QuantizedCacheConfig()
    )

This can be removed if we pass generation_config.cache_config directly to QuantoQuantizedCache instead.

@zucchini-nlp (Member, Author):

No, because generation_config.cache_config can be None if we are using the model's generation config and the user does not pass any cache config. In that case we have to init the cache config with default values ourselves.

The code one line below would otherwise fail if we passed in None:

model_kwargs["past_key_values"] = QuantoQuantizedCache(**cache_config.to_dict())
Collaborator:

Suggested change:

-    model_kwargs["past_key_values"] = QuantoQuantizedCache(**cache_config.to_dict())
+    model_kwargs["past_key_values"] = QuantoQuantizedCache(cache_config)

@zucchini-nlp (Member, Author):

Not really - the class does not accept a dataclass object as input, so we pass it as an unpacked **mapping.

Collaborator:

ah yeah, it's annoying
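(For illustration only, a hypothetical constructor that would take the config object directly, as suggested in this thread - not what the PR implements; the field names are assumptions:)

from transformers import DynamicCache


class ConfigDrivenQuantizedCache(DynamicCache):
    # hypothetical variant: accept the config object instead of unpacked **kwargs
    def __init__(self, cache_config) -> None:
        super().__init__()
        self.nbits = cache_config.nbits                      # assumed field
        self.q_group_size = cache_config.q_group_size        # assumed field
        self.residual_length = cache_config.residual_length  # assumed field

This would let generate pass generation_config.cache_config straight through, at the cost of each cache backend having to know about the config class.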

@ydshieh (Collaborator) commented May 13, 2024

Just curious, would this new cache work with torch.compile?

@ArthurZucker (Collaborator):

Nope, this one specifically won't, as it inherits from DynamicCache, but another implementation based on the static cache could. Compile is not super happy with if/else branches and device placements, especially when they are input-dependent (here it depends on the length of the processed input).

@gante (Member) commented May 13, 2024

@ArthurZucker @ydshieh: "torch.compile with quanto is only supported for 8 bits quantization for now" (from @SunMarc, on a related conversation on slack)

@zucchini-nlp (Member, Author):

I made the KV cache work with HQQ as a backend. It can simply be plugged in if a user writes their own cache class. I am not planning to add it now, as it needs more evaluation and experiments, but I wanted to show how anyone can add more backends. Do you think I should continue experimenting with HQQ, or can we simply provide the code below as an example for users?

BTW, if we were to actually support more cache quantization classes in the library, we might need to change the current QuantCache API a bit to make it more versatile.

from typing import Any, Dict, List, Optional, Tuple

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache
from hqq.core.quantize import Quantizer as HQQQuantizer


class HQQQuantizedCache(DynamicCache):
    def __init__(
        self,
        nbits: int = 4,
        axis: int = 0,
        q_group_size: int = 64,
        residual_length: int = 128,
        compute_dtype: torch.dtype = torch.float16,
        device: str = "cpu",
    ) -> None:
        if nbits not in [2, 4, 8]:
            raise ValueError(f"`nbits` has to be one of [`2`, `4`, `8`] but got {nbits}")

        if axis not in [0, 1]:
            raise ValueError(f"`axis` has to be one of [`0`, `1`] but got {axis}")

        self._quantized_key_cache: List[Tuple[torch.Tensor, Dict]] = []
        self._quantized_value_cache: List[Tuple[torch.Tensor, Dict]] = []
        self.nbits = nbits
        self.axis = axis

        self.residual_length = residual_length
        self.q_group_size = q_group_size
        self.compute_dtype = compute_dtype
        self.quantizer = HQQQuantizer
        self.device = device

        super().__init__()

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Update the number of seen tokens
        if layer_idx == 0:
            self._seen_tokens += key_states.shape[-2]

        if len(self.key_cache) <= layer_idx:
            q_key, meta_key = self._quantize(key_states.contiguous())
            self._quantized_key_cache.append((q_key, meta_key))

            q_value, meta_value = self._quantize(value_states.contiguous())
            self._quantized_value_cache.append((q_value, meta_value))

            self.key_cache.append(torch.zeros(0, dtype=key_states.dtype, device=key_states.device))
            self.value_cache.append(torch.zeros(0, dtype=key_states.dtype, device=key_states.device))
            keys_to_return, values_to_return = key_states, value_states
        else:
            quant_key, meta_key = self._quantized_key_cache[layer_idx]
            dequant_key = self.quantizer.dequantize(quant_key, meta_key)

            quant_value, meta_value = self._quantized_value_cache[layer_idx]
            dequant_value = self.quantizer.dequantize(quant_value, meta_value)

            keys_to_return = [dequant_key, self.key_cache[layer_idx], key_states]
            values_to_return = [dequant_value, self.value_cache[layer_idx], value_states]

            keys_to_return = torch.cat(keys_to_return, dim=-2)
            values_to_return = torch.cat(values_to_return, dim=-2)
            if (
                self.key_cache[layer_idx].dim() == 4
                and self.key_cache[layer_idx].shape[-2] + 1 >= self.residual_length
            ):
                q_key, meta_key = self._quantize(keys_to_return.contiguous())
                self._quantized_key_cache[layer_idx] = (q_key, meta_key)

                q_value, meta_value = self._quantize(values_to_return.contiguous())
                self._quantized_value_cache[layer_idx] = (q_value, meta_value)

                self.key_cache[layer_idx] = torch.zeros(0, dtype=key_states.dtype, device=key_states.device)
                self.value_cache[layer_idx] = torch.zeros(0, dtype=key_states.dtype, device=key_states.device)
            else:
                self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
                self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)

        return keys_to_return, values_to_return

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
        if len(self.key_cache) <= layer_idx:
            return 0
        # since we cannot get the seq_length of each layer directly and rely on `_seen_tokens`, which is
        # updated every time `layer_idx` == 0, this is a hack to get the actual seq_length for the given layer_idx;
        # this part of the code otherwise fails when used to verify the attn_weights shape in some models
        return self._seen_tokens if layer_idx == 0 else self._seen_tokens - 1

    def _quantize(self, tensor):
        qtensor, meta = self.quantizer.quantize(
            tensor,
            axis=self.axis,
            device=self.device,
            compute_dtype=self.compute_dtype,
            nbits=self.nbits,
            group_size=self.q_group_size,
        )
        meta["compute_dtype"] = self.compute_dtype
        return qtensor, meta


tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
tokenizer.pad_token_id = tokenizer.eos_token_id
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", torch_dtype=torch.float16, attn_implementation="eager", device_map = "auto")

inputs = tokenizer("I like rock music because", return_tensors="pt").to(model.device)

out = model.generate(
    **inputs,
    do_sample=False,
    max_new_tokens=50,
    past_key_values=HQQQuantizedCache(
        nbits=2,
        axis=1, # 2bit with axis=0 generates garbage
        compute_dtype=torch.float16,
        device=model.device
    ),
)


print(f"text with HQQ backend: {tokenizer.batch_decode(out)}")
