
Commit

clean-up
zucchini-nlp committed May 10, 2024
1 parent 73fcfb2 commit 16731be
Showing 11 changed files with 459 additions and 18 deletions.
3 changes: 3 additions & 0 deletions docker/transformers-all-latest-gpu/Dockerfile
@@ -48,6 +48,9 @@ RUN python3 -m pip install --no-cache-dir decord av==9.2.0
# Some slow tests require bnb
RUN python3 -m pip install --no-cache-dir bitsandbytes

# Some tests require quanto
RUN python3 -m pip install --no-cache-dir quanto

# For `dinat` model
# The `XXX` part in `torchXXX` needs to match `PYTORCH` (to some extent)
RUN python3 -m pip install --no-cache-dir natten==0.15.1+torch220$CUDA -f https://shi-labs.com/natten/wheels
38 changes: 38 additions & 0 deletions docs/source/en/generation_strategies.md
@@ -173,6 +173,44 @@ your screen, one word at a time:
An increasing sequence: one, two, three, four, five, six, seven, eight, nine, ten, eleven,
```


## KV Cache Quantization

The `generate()` method supports caching keys and values to improve efficiency and avoid re-computation. However, the key and value
cache can occupy a large portion of memory, becoming a bottleneck for long-context generation, especially for Large Language Models.
Quantizing the cache when using `generate()` can significantly reduce memory requirements, at the cost of speed.
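
To get a sense of the memory involved, the full-precision cache stores one key tensor and one value tensor of shape `[batch_size, num_heads, seq_len, head_dim]` per layer. The back-of-the-envelope sketch below estimates the cache size for Llama-2-7B-sized dimensions (32 layers, 32 heads, head dimension 128) with a long 32k-token context; it is an illustration only and ignores the small residual cache and the quantization scales.

```python
def kv_cache_bytes(batch, layers, heads, head_dim, seq_len, bytes_per_elem):
    # one key tensor and one value tensor of [batch, heads, seq_len, head_dim] per layer
    return 2 * layers * batch * heads * seq_len * head_dim * bytes_per_elem

# Llama-2-7B-like dimensions with a 32k-token context
fp16_cache = kv_cache_bytes(1, 32, 32, 128, 32_768, 2)    # float16 -> 2 bytes per element
int4_cache = kv_cache_bytes(1, 32, 32, 128, 32_768, 0.5)  # 4-bit   -> 0.5 bytes per element

print(f"fp16: {fp16_cache / 1e9:.1f} GB, int4: ~{int4_cache / 1e9:.1f} GB")  # fp16: 17.2 GB, int4: ~4.3 GB
```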

KV Cache quantization in `transformers` is largely inspired by the paper [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache](https://arxiv.org/abs/2402.02750). For more information on the inner workings, see the paper.

To enable quantization of the key-value cache, pass `cache_implementation="quantized"` in the `generation_config`.
Quantization-related arguments should be passed to the `generation_config` either as a `dict` or as an instance of the [`QuantizedCacheConfig`] class.

<Tip warning={true}>

Cache quantization can be detrimental in terms of latency if the context length is short and there is enough GPU VRAM available to run without cache quantization.

</Tip>


```python
>>> import torch
>>> from transformers import AutoTokenizer, AutoModelForCausalLM

>>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
>>> model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", torch_dtype=torch.float16).to("cuda:0")
>>> inputs = tokenizer("I like rock music because", return_tensors="pt").to(model.device)

>>> out = model.generate(**inputs, do_sample=False, max_new_tokens=20, cache_implementation="quantized", cache_config={"nbits": 4})
>>> print(tokenizer.batch_decode(out, skip_special_tokens=True)[0])
I like rock music because it's loud and energetic. It's a great way to express myself and rel

>>> out = model.generate(**inputs, do_sample=False, max_new_tokens=20)
>>> print(tokenizer.batch_decode(out, skip_special_tokens=True)[0])
I like rock music because it's loud and energetic. I like to listen to it when I'm feeling
```
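
The quantization arguments can equivalently be wrapped in a [`QuantizedCacheConfig`] instance instead of a raw `dict`. The snippet below is a minimal sketch of that variant, reusing the `model` and `inputs` objects from the example above and assuming the `cache_config` argument accepts the config object in the same way it accepts a `dict`:

```python
>>> from transformers import QuantizedCacheConfig

>>> cache_config = QuantizedCacheConfig(nbits=4, q_group_size=64, residual_length=128)
>>> out = model.generate(**inputs, do_sample=False, max_new_tokens=20, cache_implementation="quantized", cache_config=cache_config)
```

The explicit config object keeps the accepted keys and their defaults visible in one place, and its `validate()` method can be used to check the values before generation.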


## Decoding strategies

Certain combinations of the `generate()` parameters, and ultimately `generation_config`, can be used to enable specific
12 changes: 11 additions & 1 deletion docs/source/en/internal/generation_utils.md
@@ -356,13 +356,23 @@ A [`Constraint`] can be used to force the generation to include specific tokens
[[autodoc]] Cache
- update

[[autodoc]] CacheConfig
- update

[[autodoc]] QuantizedCacheConfig
- validate

[[autodoc]] DynamicCache
- update
- get_seq_length
- reorder_cache
- to_legacy_cache
- from_legacy_cache

[[autodoc]] QuantoQuantizedCache
- update
- get_seq_length

[[autodoc]] SinkCache
- update
- get_seq_length
@@ -371,4 +381,4 @@ A [`Constraint`] can be used to force the generation to include specific tokens
[[autodoc]] StaticCache
- update
- get_seq_length
- reorder_cache
- reset
20 changes: 18 additions & 2 deletions src/transformers/__init__.py
@@ -1173,7 +1173,15 @@
_import_structure["activations"] = []
_import_structure["benchmark.benchmark"] = ["PyTorchBenchmark"]
_import_structure["benchmark.benchmark_args"] = ["PyTorchBenchmarkArguments"]
_import_structure["cache_utils"] = ["Cache", "DynamicCache", "SinkCache", "StaticCache"]
_import_structure["cache_utils"] = [
"Cache",
"CacheConfig",
"DynamicCache",
"QuantizedCacheConfig",
"QuantoQuantizedCache",
"SinkCache",
"StaticCache",
]
_import_structure["data.datasets"] = [
"GlueDataset",
"GlueDataTrainingArguments",
@@ -5730,7 +5738,15 @@
# Benchmarks
from .benchmark.benchmark import PyTorchBenchmark
from .benchmark.benchmark_args import PyTorchBenchmarkArguments
from .cache_utils import Cache, DynamicCache, SinkCache, StaticCache
from .cache_utils import (
Cache,
CacheConfig,
DynamicCache,
QuantizedCacheConfig,
QuantoQuantizedCache,
SinkCache,
StaticCache,
)
from .data.datasets import (
GlueDataset,
GlueDataTrainingArguments,
243 changes: 241 additions & 2 deletions src/transformers/cache_utils.py
@@ -1,12 +1,18 @@
import copy
import json
import os
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple, Union

import torch

from .configuration_utils import PretrainedConfig
from .utils import logging
from .utils import is_quanto_available, logging


if is_quanto_available():
from quanto import QBitsTensor, qint2, qint4

logger = logging.get_logger(__name__)


@@ -82,6 +88,140 @@ def seen_tokens(self):
return None


@dataclass
class CacheConfig:
"""
Base class for cache configs
"""

cache_implementation: None

@classmethod
def from_dict(cls, config_dict, **kwargs):
"""
Constructs a CacheConfig instance from a dictionary of parameters.
Args:
config_dict (Dict[str, Any]): Dictionary containing configuration parameters.
**kwargs: Additional keyword arguments to override dictionary values.
Returns:
CacheConfig: Instance of CacheConfig constructed from the dictionary.
"""
config = cls(**config_dict)
to_remove = []
for key, value in kwargs.items():
if hasattr(config, key):
setattr(config, key, value)
to_remove.append(key)
for key in to_remove:
kwargs.pop(key, None)
return config

def to_json_file(self, json_file_path: Union[str, os.PathLike]):
"""
Save this instance to a JSON file.
Args:
json_file_path (Union[str, os.PathLike]): Path to the JSON file in which this configuration instance's parameters will be saved.
"""
with open(json_file_path, "w", encoding="utf-8") as writer:
config_dict = self.to_dict()
json_string = json.dumps(config_dict, indent=2, sort_keys=True) + "\n"

writer.write(json_string)

def to_dict(self) -> Dict[str, Any]:
"""
Serializes this instance to a Python dictionary.
Returns:
Dict[str, Any]: Dictionary of all the attributes that make up this configuration instance.
"""
output = copy.deepcopy(self.__dict__)
return output

def __iter__(self):
for attr, value in copy.deepcopy(self.__dict__).items():
yield attr, value

def __repr__(self):
return f"{self.__class__.__name__} {self.to_json_string()}"

def to_json_string(self):
"""
Serializes this instance to a JSON formatted string.
Returns:
str: JSON formatted string representing the configuration instance.
"""
return json.dumps(self.__dict__, indent=2) + "\n"

def update(self, **kwargs):
"""
Update the configuration attributes with new values.
Args:
**kwargs: Keyword arguments representing configuration attributes and their new values.
"""
for key, value in kwargs.items():
if hasattr(self, key):
setattr(self, key, value)


@dataclass
class QuantizedCacheConfig(CacheConfig):
"""
Configuration class for quantized cache settings.
Attributes:
nbits (`Optional[int]`, *optional*, defaults to 4):
            Number of bits, can be 2 or 4. Defaults to 4.
q_group_size (`Optional[int]`, *optional*, defaults to 64):
Size of the quantization group, should be a divisor of the model's hidden dimension.
Defaults to 64.
residual_length (`Optional[int]`, *optional*, defaults to 128):
            Length of the residual cache which will always be stored in original precision.
Defaults to 128.
"""

def __init__(
self,
nbits: Optional[int] = 4,
q_group_size: Optional[int] = 64,
residual_length: Optional[int] = 128,
):
self.nbits = nbits
self.q_group_size = q_group_size
self.residual_length = residual_length

def validate(self):
"""Validates if the arguments passed are correct"""

        incorrect_arg_msg = (
            "Some of the keys in `cache_config` are defined incorrectly. `{key}` should be {correct_value} "
            "but found {found_value}"
        )
if self.nbits not in [2, 4]:
raise ValueError(
incorrect_arg_msg.format(
key="nbits",
correct_value="2 or 4",
found_value=self.nbits,
),
)
if self.q_group_size <= 0:
raise ValueError(
incorrect_arg_msg.format(
key="q_group_size",
correct_value="a positive integer",
found_value=self.q_group_size,
),
)
if self.residual_length < 0:
raise ValueError(
incorrect_arg_msg.format(
key="residual_length",
correct_value="a positive integer",
found_value=self.residual_length,
),
)
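
Taken together, `from_dict`, `to_dict`, `to_json_string`, and `validate` form a small round-trip API for cache configs. The following short sketch (an illustration only, not part of this commit) shows how they compose; the argument values are arbitrary examples.

```python
from transformers import QuantizedCacheConfig

# Build a config from a plain dict, overriding one field through kwargs
config = QuantizedCacheConfig.from_dict({"nbits": 4, "q_group_size": 64}, residual_length=256)

# Raises ValueError for unsupported values (e.g. nbits=3 or a negative residual_length)
config.validate()

print(config.to_dict())         # {'nbits': 4, 'q_group_size': 64, 'residual_length': 256}
print(config.to_json_string())  # same content as a JSON-formatted string
```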


class DynamicCache(Cache):
"""
A cache that grows dynamically as more tokens are generated. This is the default for generative models.
@@ -186,6 +326,105 @@ def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTens
return cache


class QuantoQuantizedCache(DynamicCache):
"""
    A cache similar to the one described in the [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache paper](https://arxiv.org/abs/2402.02750).
    It allows the model to generate longer sequences without allocating too much memory for the Key and Value cache by applying quantization.
    The cache has two types of storage, one for the original precision and one for the quantized cache. A `residual_length` is set as the maximum capacity of the
    original-precision cache. When the length goes beyond this capacity, the original-precision cache is quantized, moved into the quantized storage, and emptied. The
    quantization is done per-channel with a set `q_group_size` for both Keys and Values, in contrast to what was described in the paper. The current implementation
    supports the `int2` and `int4` dtypes from `quanto`.
    The cache stores the original-precision Key and Value states as a list of tensors, one for each layer. The maximum expected shape for each tensor is
    `[batch_size, num_heads, residual_length, head_dim]`. Quantized Keys and Values are stored separately as a list of quantized tensors, one for each layer.
    The size of each quantized tensor is `[batch_size, num_heads, seq_len - residual_length, head_dim]`.
Parameters:
nbits (`Optional[int]`, *optional*, defaults to 4):
            Number of bits, can be 2 or 4. Defaults to 4.
q_group_size (`Optional[int]`, *optional*, defaults to 64):
Size of the quantization group, should be a divisor of the model's hidden dimension.
Defaults to 64.
residual_length (`Optional[int]`, *optional*, defaults to 128):
            Length of the residual cache which will always be stored in original precision.
Defaults to 128.
"""

def __init__(self, nbits: int = 4, q_group_size: int = 64, residual_length: int = 128) -> None:
if nbits not in [2, 4]:
raise ValueError(f"`nbits` has to be one of [`2`, `4`] but got {nbits}")

self._key_cache_quant: List[torch.Tensor] = []
self._value_cache_quant: List[torch.Tensor] = []

self.residual_length = residual_length
self.qtype = qint4 if nbits == 4 else qint2
self.q_group_size = q_group_size

        super().__init__()

def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: int,
cache_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
# Update the number of seen tokens
if layer_idx == 0:
            self._seen_tokens += key_states.shape[-2]

if len(self.key_cache) <= layer_idx:
            self._key_cache_quant.append(self._quantize(key_states.contiguous()))
self._value_cache_quant.append(self._quantize(value_states.contiguous()))
self.key_cache.append(torch.zeros(0, dtype=key_states.dtype, device=key_states.device))
self.value_cache.append(torch.zeros(0, dtype=key_states.dtype, device=key_states.device))
keys_to_return, values_to_return = key_states, value_states
else:
dequant_key = self._key_cache_quant[layer_idx].dequantize()
dequant_value = self._value_cache_quant[layer_idx].dequantize()
keys_to_return = torch.cat(
[
dequant_key,
self.key_cache[layer_idx],
key_states,
],
dim=-2,
)
values_to_return = torch.cat(
[
dequant_value,
self.value_cache[layer_idx],
value_states,
],
dim=-2,
)
if (
self.key_cache[layer_idx].dim() == 4
and self.key_cache[layer_idx].shape[-2] + 1 >= self.residual_length
):
self._key_cache_quant[layer_idx] = self._quantize(keys_to_return.contiguous())
self._value_cache_quant[layer_idx] = self._quantize(values_to_return.contiguous())
self.key_cache[layer_idx] = torch.zeros(0, dtype=key_states.dtype, device=key_states.device)
self.value_cache[layer_idx] = torch.zeros(0, dtype=key_states.dtype, device=key_states.device)
else:
self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)

return keys_to_return, values_to_return

def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
"""Returns the sequence length of the cached states. A layer index can be optionally passed."""
if len(self.key_cache) <= layer_idx:
return 0
        return self._seen_tokens

def _quantize(self, tensor):
qtensor = QBitsTensor.quantize(tensor, axis=0, qtype=self.qtype, group_size=self.q_group_size)
return qtensor


class SinkCache(Cache):
"""
    A cache as described in the [Attention Sinks paper](https://arxiv.org/abs/2309.17453). It allows the model to
