Cache: Static cache as a standalone object #30476

Merged · 12 commits · Apr 30, 2024
1 change: 1 addition & 0 deletions docs/source/en/internal/generation_utils.md
@@ -362,3 +362,4 @@ A [`Constraint`] can be used to force the generation to include specific tokens
[[autodoc]] StaticCache
- update
- get_seq_length
- reorder_cache
65 changes: 38 additions & 27 deletions docs/source/en/llm_optims.md
@@ -65,13 +65,12 @@ tokenizer.batch_decode(outputs, skip_special_tokens=True)
['The theory of special relativity states 1. The speed of light is constant in all inertial reference']
```

</hfoption>
<hfoption id="setup_cache">
Under the hood, `generate` will attempt to reuse the same cache object, removing the need for re-compilation at each call. However, if the batch size or the maximum output length increases between calls, the cache will have to be reinitialized, triggering a new compilation.
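
For example, here is a minimal sketch of that reuse behavior (it assumes a `model` whose forward pass has been compiled with `torch.compile` and with `model.generation_config.cache_implementation = "static"` set, plus the matching `tokenizer`; the prompts and token counts are arbitrary):

```py
prompt_1 = tokenizer("The theory of special relativity states ", return_tensors="pt").to(model.device)
prompt_2 = tokenizer("Simply put, the theory of relativity states that ", return_tensors="pt").to(model.device)

out_1 = model.generate(**prompt_1, max_new_tokens=16)   # cache created, forward pass compiled
out_2 = model.generate(**prompt_2, max_new_tokens=16)   # same shapes: cache and compiled graph reused
out_3 = model.generate(**prompt_2, max_new_tokens=128)  # larger maximum length: cache reinitialized, recompilation
```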

> [!WARNING]
> The `_setup_cache` method is an internal and private method that is still under development. This means it may not be backward compatible and the API design may change in the future.
</hfoption>
<hfoption id="Static Cache">

The `_setup_cache` method doesn't support [`~GenerationMixin.generate`] yet, so this method is a bit more involved. You'll need to write your own function to decode the next token given the current token and position and cache position of previously generated tokens.
A [`StaticCache`] object can be passed to the model's forward pass under the `past_key_values` argument, enabling the use of this object as a static kv-cache. Using this strategy, you can write your own function to decode the next token given the current token, its position, and the cache position of previously generated tokens. You can also pass the [`StaticCache`] object to [`~GenerationMixin.generate`] and reuse it across calls, just as you would with a dynamic cache (see the sketch at the end of this section).

```py
from transformers import LlamaTokenizer, LlamaForCausalLM, StaticCache, logging
@@ -90,17 +89,22 @@ tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", pad_token
model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", device_map="sequential")
inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)

def decode_one_tokens(model, cur_token, input_pos, cache_position):
def decode_one_tokens(model, cur_token, input_pos, cache_position, past_key_values):
logits = model(
cur_token, position_ids=input_pos, cache_position=cache_position, return_dict=False, use_cache=True
cur_token,
position_ids=input_pos,
cache_position=cache_position,
past_key_values=past_key_values,
return_dict=False,
use_cache=True
)[0]
new_token = torch.argmax(logits[:, -1], dim=-1)[:, None]
return new_token
```

There are a few important things you must do to enable static kv-cache and torch.compile with the `_setup_cache` method:
There are a few important things you must do to enable the static kv-cache and torch.compile with the `StaticCache` class:

1. Access the model's `_setup_cache` method and pass it the [`StaticCache`] class. This is a more flexible method because it allows you to configure parameters like the maximum batch size and sequence length.
1. Initialize the [`StaticCache`] instance before using the model for inference. This is where you configure parameters like the maximum batch size and sequence length.

2. Call torch.compile on the model to compile the forward pass with the static kv-cache.

@@ -109,31 +113,38 @@ There are a few important things you must do to enable static kv-cache and torch
```py
batch_size, seq_length = inputs["input_ids"].shape
with torch.no_grad():
model._setup_cache(StaticCache, 2, max_cache_len=4096)
cache_position = torch.arange(seq_length, device=torch_device)
generated_ids = torch.zeros(
batch_size, seq_length + NUM_TOKENS_TO_GENERATE + 1, dtype=torch.int, device=torch_device
)
generated_ids[:, cache_position] = inputs["input_ids"].to(torch_device).to(torch.int)

logits = model(**inputs, cache_position=cache_position, return_dict=False, use_cache=True)[0]
next_token = torch.argmax(logits[:, -1], dim=-1)[:, None]
generated_ids[:, seq_length] = next_token[:, 0]

decode_one_tokens = torch.compile(decode_one_tokens, mode="reduce-overhead", fullgraph=True)
cache_position = torch.tensor([seq_length + 1], device=torch_device)
for _ in range(1, NUM_TOKENS_TO_GENERATE):
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
next_token = decode_one_tokens(model, next_token.clone(), None, cache_position)
generated_ids[:, cache_position] = next_token.int()
cache_position += 1
past_key_values = StaticCache(
config=model.config, max_batch_size=2, max_cache_len=4096, device=torch_device, dtype=model.dtype
)
cache_position = torch.arange(seq_length, device=torch_device)
generated_ids = torch.zeros(
batch_size, seq_length + NUM_TOKENS_TO_GENERATE + 1, dtype=torch.int, device=torch_device
)
generated_ids[:, cache_position] = inputs["input_ids"].to(torch_device).to(torch.int)

logits = model(
**inputs, cache_position=cache_position, past_key_values=past_key_values, return_dict=False, use_cache=True
)[0]
next_token = torch.argmax(logits[:, -1], dim=-1)[:, None]
generated_ids[:, seq_length] = next_token[:, 0]

decode_one_tokens = torch.compile(decode_one_tokens, mode="reduce-overhead", fullgraph=True)
cache_position = torch.tensor([seq_length + 1], device=torch_device)
for _ in range(1, NUM_TOKENS_TO_GENERATE):
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
next_token = decode_one_tokens(model, next_token.clone(), None, cache_position, past_key_values)
generated_ids[:, cache_position] = next_token.int()
cache_position += 1

text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
text
['Simply put, the theory of relativity states that 1) the speed of light is constant, 2) the speed of light is the same for all observers, and 3) the laws of physics are the same for all observers.',
'My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p']
```

> [!TIP]
> If you want to reuse the [`StaticCache`] object on a new prompt, be sure to reset its contents with the `.reset()` method.
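
For example, here is a minimal sketch of this `generate` usage, reusing the `model`, `tokenizer`, and `inputs` defined above (the `max_new_tokens` value is arbitrary):

```py
# Sketch: pass the same StaticCache object to several `generate` calls.
past_key_values = StaticCache(
    config=model.config, max_batch_size=2, max_cache_len=4096, device=model.device, dtype=model.dtype
)
outputs = model.generate(**inputs, past_key_values=past_key_values, max_new_tokens=32)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))

# Reset the cache before reusing it on a new batch of prompts (same batch size).
past_key_values.reset()
new_inputs = tokenizer(["The theory of special relativity states "] * 2, return_tensors="pt", padding=True).to(model.device)
new_outputs = model.generate(**new_inputs, past_key_values=past_key_values, max_new_tokens=32)
```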

</hfoption>
</hfoptions>

84 changes: 41 additions & 43 deletions src/transformers/cache_utils.py
@@ -44,6 +44,7 @@ def update(

def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
"""Returns the sequence length of the cached states. A layer index can be optionally passed."""
# TODO: deprecate this function in favor of `cache_position`
raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.")

def get_max_length(self) -> Optional[int]:
@@ -61,6 +62,14 @@ def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -
return max_length - new_seq_length
return previous_seq_length

def reorder_cache(self, beam_idx: torch.LongTensor):
"""Reorders the cache for beam search, given the selected beam indices."""
for layer_idx in range(len(self.key_cache)):
device = self.key_cache[layer_idx].device
self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
device = self.value_cache[layer_idx].device
self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
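# Usage sketch (illustrative): during beam search, `beam_idx` maps each surviving beam to
# the batch row it continues from, and `reorder_cache` permutes every cached key/value
# tensor along the batch dimension accordingly:
#
#   beam_idx = torch.tensor([0, 0, 2, 1])  # beams 0 and 1 both continue old beam 0
#   past_key_values.reorder_cache(beam_idx)
#   # each key_cache[i] / value_cache[i] is now key_cache[i].index_select(0, beam_idx)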

@property
def seen_tokens(self):
logger.warning_once(
@@ -150,6 +159,7 @@ def update(

def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
"""Returns the sequence length of the cached states. A layer index can be optionally passed."""
# TODO: deprecate this function in favor of `cache_position`
if len(self.key_cache) <= layer_idx:
return 0
return self.key_cache[layer_idx].shape[-2]
@@ -158,14 +168,6 @@ def get_max_length(self) -> Optional[int]:
"""Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
return None

def reorder_cache(self, beam_idx: torch.LongTensor):
"""Reorders the cache for beam search, given the selected beam indices."""
for layer_idx in range(len(self.key_cache)):
device = self.key_cache[layer_idx].device
self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
device = self.value_cache[layer_idx].device
self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))

def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
"""Converts the `DynamicCache` instance into the its equivalent in the legacy cache format."""
legacy_cache = ()
@@ -244,6 +246,7 @@ def _get_rerotation_cos_sin(

def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
"""Returns the sequence length of the cached states. A layer index can be optionally passed."""
# TODO: deprecate this function in favor of `cache_position`
# Workaround to make 'key_states.shape[-2] + past_key_value.get_seq_length(self.layer_idx)' <= window_length
if len(self.key_cache) <= layer_idx:
return 0
@@ -332,23 +335,14 @@ def update(

return self.key_cache[layer_idx], self.value_cache[layer_idx]

def reorder_cache(self, beam_idx: torch.LongTensor):
"""Reorders the cache for beam search, given the selected beam indices."""
for layer_idx in range(len(self.key_cache)):
device = self.key_cache[layer_idx].device
self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
device = self.value_cache[layer_idx].device
self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))


class StaticCache(Cache):
"""
Static Cache class to be used with `torch.compile(model)`.

Parameters:
config (`PretrainedConfig`):
The configuration file defining the `max_position_embeddings`, `hidden_size` and `num_attention_heads`
required to initialize the static cache.
The configuration file defining the shape-related attributes required to initialize the static cache.
max_batch_size (`int`):
The maximum batch size with which the model will be used.
max_cache_len (`int`):
@@ -373,9 +367,18 @@ def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len:
config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
)

self.key_cache: List[torch.Tensor] = []
self.value_cache: List[torch.Tensor] = []
cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
self.key_cache: torch.Tensor = torch.zeros(cache_shape, dtype=self.dtype, device=device)
self.value_cache: torch.Tensor = torch.zeros(cache_shape, dtype=self.dtype, device=device)
for _ in range(config.num_hidden_layers):
# Note: `mark_static_address` is used to tag the cache as a fixed data pointer, preventing CUDA graph
# breaks when updating the cache.
new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
torch._dynamo.mark_static_address(new_layer_key_cache)
torch._dynamo.mark_static_address(new_layer_value_cache)
self.key_cache.append(new_layer_key_cache)
self.value_cache.append(new_layer_value_cache)

def update(
self,
@@ -394,42 +397,37 @@ def update(
value_states (`torch.Tensor`):
The new value states to cache.
layer_idx (`int`):
The index of the layer to cache the states for. Kept for backward compatibility
The index of the layer to cache the states for.
cache_kwargs (`Dict[str, Any]`, `optional`):
Additional arguments for the cache subclass. The `StaticCache` just needs the `q_len`
to know how much of the cache it should overwrite.
Additional arguments for the cache subclass. The `StaticCache` needs the `cache_position` input
to know where to write in the cache.

Return:
A tuple containing the updated key and value states.
"""
new_cache_positions = cache_kwargs.get("cache_position")
k_out = self.key_cache
v_out = self.value_cache
cache_position = cache_kwargs.get("cache_position")
k_out = self.key_cache[layer_idx]
v_out = self.value_cache[layer_idx]

k_out[:, :, new_cache_positions] = key_states
v_out[:, :, new_cache_positions] = value_states
k_out[:, :, cache_position] = key_states
v_out[:, :, cache_position] = value_states

return k_out, v_out
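# Usage sketch (illustrative): a single decoding step writes the new key/value slice for
# this layer at the current position and returns the full static buffers:
#
#   cache_kwargs = {"cache_position": torch.tensor([seq_len])}
#   k, v = past_key_values.update(key_states, value_states, layer_idx, cache_kwargs)
#   # key_states / value_states: (batch, num_kv_heads, 1, head_dim)
#   # k / v:                     (max_batch_size, num_kv_heads, max_cache_len, head_dim)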

def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
"""Returns the sequence length of the cached states that were seen by the model. `layer_idx` kept for BC"""
"""Returns the sequence length of the cached states that were seen by the model."""
# Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
# limit the check to the first batch member and head dimension.
# TODO: This is error prone, a filled cache may be `0.0`. Let's use a stateless integer instead, after
# https://github.com/pytorch/pytorch/issues/120248 is fixed
return (self.key_cache[0, 0].any(dim=-1)).sum()
# TODO: deprecate this function in favor of `cache_position`
return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum()

def get_max_length(self) -> Optional[int]:
"""Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
"""Returns the maximum sequence length of the cached states."""
return self.max_cache_len

def reorder_cache(self, beam_idx: torch.LongTensor):
"""Reorders the cache for beam search, given the selected beam indices."""
device = self.key_cache.device
self.key_cache = self.key_cache.index_select(0, beam_idx.to(device))
device = self.value_cache.device
self.value_cache = self.value_cache.index_select(0, beam_idx.to(device))

def to_legacy_cache(self):
"""Dummy function for BC. We have to keep it because otherwise the call in the forward of models will break it"""
return None
def reset(self):
"""Resets the cache values while preserving the objects"""
for layer_idx in range(len(self.key_cache)):
# In-place ops prevent breaking the static address
self.key_cache[layer_idx].zero_()
self.value_cache[layer_idx].zero_()
60 changes: 40 additions & 20 deletions src/transformers/generation/utils.py
@@ -1310,6 +1310,34 @@ def _get_initial_cache_position(self, input_ids, model_kwargs):
model_kwargs["cache_position"] = torch.arange(past_length, cur_len, device=input_ids.device)
return model_kwargs

def _get_static_cache(self, max_batch_size: int, max_cache_len: int) -> StaticCache:
"""
Sets a static cache for `generate` that will persist across calls. A new cache will only be initialized if a
new `generate` call requires a larger cache.

Returns the resulting static cache object.
"""
needs_new_cache = (
not hasattr(self, "_static_cache")
or self._static_cache.max_batch_size < max_batch_size
or self._static_cache.max_cache_len < max_cache_len
)
if needs_new_cache:
if hasattr(self.config, "_pre_quantization_dtype"):
cache_dtype = self.config._pre_quantization_dtype
else:
cache_dtype = self.dtype
self._static_cache = StaticCache(
config=self.config,
max_batch_size=max_batch_size,
max_cache_len=max_cache_len,
device=self.device,
dtype=cache_dtype,
)
else:
self._static_cache.reset() # reset the cache for a new generation
return self._static_cache
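# Usage sketch (illustrative): the cache persists on the model across `generate` calls and
# is only re-allocated when a larger one is needed:
#
#   cache_a = model._get_static_cache(max_batch_size=2, max_cache_len=1024)
#   cache_b = model._get_static_cache(max_batch_size=2, max_cache_len=512)   # same object, reset
#   cache_c = model._get_static_cache(max_batch_size=2, max_cache_len=2048)  # new, larger cache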

@torch.no_grad()
def generate(
self,
@@ -1514,19 +1542,19 @@ def generate(
input_ids_length=input_ids_length,
)

if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING:
if generation_config.cache_implementation is not None and model_kwargs.get("past_key_values") is not None:
raise ValueError(
"Passing both `cache_implementation` (used to initialize certain caches) and `past_key_values` (a "
"Cache object) is unsupported. Please use only one of the two."
)
elif generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING:
if not self._supports_cache_class:
raise ValueError(
"This model does not support the `cache_implementation` argument. Please check the following "
"issue: https://github.com/huggingface/transformers/issues/28981."
)
if generation_config.cache_implementation == "static":
if model_kwargs.get("past_key_values", False) is not False:
raise ValueError(
"Using `past_key_values` argument with `generate()` when using a static KV cache is not supported. Please open an issue in Transformers GitHub repository."
)
cache_cls = NEED_SETUP_CACHE_CLASSES_MAPPING["static"]
if not callable(getattr(self, "_setup_cache", None)):
raise ValueError(
"The `generation_config` defines a `cache_implementation` that is not compatible with this model."
" Make sure it has a `_setup_cache` function."
)
self._setup_cache(cache_cls, max_batch_size=batch_size, max_cache_len=generation_config.max_length)
model_kwargs["past_key_values"] = self._get_static_cache(batch_size, generation_config.max_length)

self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)

@@ -1844,14 +1872,6 @@ def typeerror():
**model_kwargs,
)

if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING:
if not callable(getattr(self, "_reset_cache", None)):
raise ValueError(
"A `static_cache` was used to generate but there was a failure when trying to release the cache. "
" Make sure this model implements a `_reset_cache` function."
)
self._reset_cache()

return result

def _has_unfinished_sequences(self, this_peer_finished: bool, synced_gpus: bool, device: torch.device) -> bool: