Commit 75bbfd5

Cache: Static cache as a standalone object (#30476)
1 parent 0ae789e commit 75bbfd5

File tree: 20 files changed, +381 -428 lines

docs/source/en/internal/generation_utils.md

Lines changed: 1 addition & 0 deletions
@@ -362,3 +362,4 @@ A [`Constraint`] can be used to force the generation to include specific tokens
 [[autodoc]] StaticCache
     - update
     - get_seq_length
+    - reorder_cache

docs/source/en/llm_optims.md

Lines changed: 38 additions & 27 deletions
@@ -65,13 +65,12 @@ tokenizer.batch_decode(outputs, skip_special_tokens=True)
 ['The theory of special relativity states 1. The speed of light is constant in all inertial reference']
 ```

-</hfoption>
-<hfoption id="setup_cache">
+Under the hood, `generate` will attempt to reuse the same cache object, removing the need for re-compilation at each call. However, if the batch size or the maximum output length increase between calls, the cache will have to be reinitialized, triggering a new compilation.

-> [!WARNING]
-> The `_setup_cache` method is an internal and private method that is still under development. This means it may not be backward compatible and the API design may change in the future.
+</hfoption>
+<hfoption id="Static Cache">

-The `_setup_cache` method doesn't support [`~GenerationMixin.generate`] yet, so this method is a bit more involved. You'll need to write your own function to decode the next token given the current token and position and cache position of previously generated tokens.
+A [`StaticCache`] object can be passed to the model's forward pass under the `past_key_values` argument, enabling the use of this object as a static kv-cache. Using this strategy, you can write your own function to decode the next token given the current token and position and cache position of previously generated tokens. You can also pass the [`StaticCache`] object to [`~GenerationMixin.generate`] and use it across calls, like you would do with a dynamic cache.

 ```py
 from transformers import LlamaTokenizer, LlamaForCausalLM, StaticCache, logging
@@ -90,17 +89,22 @@ tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", pad_token
 model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", device_map="sequential")
 inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)

-def decode_one_tokens(model, cur_token, input_pos, cache_position):
+def decode_one_tokens(model, cur_token, input_pos, cache_position, past_key_values):
     logits = model(
-        cur_token, position_ids=input_pos, cache_position=cache_position, return_dict=False, use_cache=True
+        cur_token,
+        position_ids=input_pos,
+        cache_position=cache_position,
+        past_key_values=past_key_values,
+        return_dict=False,
+        use_cache=True
     )[0]
     new_token = torch.argmax(logits[:, -1], dim=-1)[:, None]
     return new_token
 ```

-There are a few important things you must do to enable static kv-cache and torch.compile with the `_setup_cache` method:
+There are a few important things you must do to enable static kv-cache and torch.compile with the `StaticCache` method:

-1. Access the model's `_setup_cache` method and pass it the [`StaticCache`] class. This is a more flexible method because it allows you to configure parameters like the maximum batch size and sequence length.
+1. Initialize the [`StaticCache`] instance before using the model for inference. There you can configure parameters like the maximum batch size and sequence length.

 2. Call torch.compile on the model to compile the forward pass with the static kv-cache.

@@ -109,31 +113,38 @@ There are a few important things you must do to enable static kv-cache and torch
 ```py
 batch_size, seq_length = inputs["input_ids"].shape
 with torch.no_grad():
-    model._setup_cache(StaticCache, 2, max_cache_len=4096)
-    cache_position = torch.arange(seq_length, device=torch_device)
-    generated_ids = torch.zeros(
-        batch_size, seq_length + NUM_TOKENS_TO_GENERATE + 1, dtype=torch.int, device=torch_device
-    )
-    generated_ids[:, cache_position] = inputs["input_ids"].to(torch_device).to(torch.int)
-
-    logits = model(**inputs, cache_position=cache_position, return_dict=False, use_cache=True)[0]
-    next_token = torch.argmax(logits[:, -1], dim=-1)[:, None]
-    generated_ids[:, seq_length] = next_token[:, 0]
-
-    decode_one_tokens = torch.compile(decode_one_tokens, mode="reduce-overhead", fullgraph=True)
-    cache_position = torch.tensor([seq_length + 1], device=torch_device)
-    for _ in range(1, NUM_TOKENS_TO_GENERATE):
-        with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
-            next_token = decode_one_tokens(model, next_token.clone(), None, cache_position)
-        generated_ids[:, cache_position] = next_token.int()
-        cache_position += 1
+    past_key_values = StaticCache(
+        config=model.config, max_batch_size=2, max_cache_len=4096, device=torch_device, dtype=model.dtype
+    )
+    cache_position = torch.arange(seq_length, device=torch_device)
+    generated_ids = torch.zeros(
+        batch_size, seq_length + NUM_TOKENS_TO_GENERATE + 1, dtype=torch.int, device=torch_device
+    )
+    generated_ids[:, cache_position] = inputs["input_ids"].to(torch_device).to(torch.int)
+
+    logits = model(
+        **inputs, cache_position=cache_position, past_key_values=past_key_values,return_dict=False, use_cache=True
+    )[0]
+    next_token = torch.argmax(logits[:, -1], dim=-1)[:, None]
+    generated_ids[:, seq_length] = next_token[:, 0]
+
+    decode_one_tokens = torch.compile(decode_one_tokens, mode="reduce-overhead", fullgraph=True)
+    cache_position = torch.tensor([seq_length + 1], device=torch_device)
+    for _ in range(1, NUM_TOKENS_TO_GENERATE):
+        with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
+            next_token = decode_one_tokens(model, next_token.clone(), None, cache_position, past_key_values)
+        generated_ids[:, cache_position] = next_token.int()
+        cache_position += 1

 text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
 text
 ['Simply put, the theory of relativity states that 1) the speed of light is constant, 2) the speed of light is the same for all observers, and 3) the laws of physics are the same for all observers.',
 'My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p']
 ```

+> [!TIP]
+> If you want to reuse the [`StaticCache`] object on a new prompt, be sure to reset its contents with the `.reset()` method
+
 </hfoption>
 </hfoptions>
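
The new documentation above also notes that a [`StaticCache`] can be handed directly to [`~GenerationMixin.generate`] and reused across calls, with `.reset()` clearing it between unrelated prompts. A minimal sketch of that usage (checkpoint name and generation settings are illustrative, not part of this commit):

```py
# Hypothetical sketch: reuse one StaticCache across `generate` calls.
from transformers import AutoModelForCausalLM, AutoTokenizer, StaticCache

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", device_map="auto")
inputs = tokenizer("The theory of special relativity states", return_tensors="pt").to(model.device)

# Allocate the static cache once, sized for the largest generation you expect.
past_key_values = StaticCache(
    config=model.config, max_batch_size=1, max_cache_len=4096, device=model.device, dtype=model.dtype
)
out = model.generate(**inputs, past_key_values=past_key_values, max_new_tokens=20)
print(tokenizer.batch_decode(out, skip_special_tokens=True))

# Before decoding a new, unrelated prompt, clear the cached values in place.
past_key_values.reset()
```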

src/transformers/cache_utils.py

Lines changed: 41 additions & 43 deletions
@@ -44,6 +44,7 @@ def update(

     def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
         """Returns the sequence length of the cached states. A layer index can be optionally passed."""
+        # TODO: deprecate this function in favor of `cache_position`
         raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.")

     def get_max_length(self) -> Optional[int]:
@@ -61,6 +62,14 @@ def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -
             return max_length - new_seq_length
         return previous_seq_length

+    def reorder_cache(self, beam_idx: torch.LongTensor):
+        """Reorders the cache for beam search, given the selected beam indices."""
+        for layer_idx in range(len(self.key_cache)):
+            device = self.key_cache[layer_idx].device
+            self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
+            device = self.value_cache[layer_idx].device
+            self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
+
     @property
     def seen_tokens(self):
         logger.warning_once(
@@ -150,6 +159,7 @@ def update(

     def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
         """Returns the sequence length of the cached states. A layer index can be optionally passed."""
+        # TODO: deprecate this function in favor of `cache_position`
         if len(self.key_cache) <= layer_idx:
             return 0
         return self.key_cache[layer_idx].shape[-2]
@@ -158,14 +168,6 @@ def get_max_length(self) -> Optional[int]:
         """Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
         return None

-    def reorder_cache(self, beam_idx: torch.LongTensor):
-        """Reorders the cache for beam search, given the selected beam indices."""
-        for layer_idx in range(len(self.key_cache)):
-            device = self.key_cache[layer_idx].device
-            self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
-            device = self.value_cache[layer_idx].device
-            self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
-
     def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
         """Converts the `DynamicCache` instance into the its equivalent in the legacy cache format."""
         legacy_cache = ()
@@ -244,6 +246,7 @@ def _get_rerotation_cos_sin(

     def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
         """Returns the sequence length of the cached states. A layer index can be optionally passed."""
+        # TODO: deprecate this function in favor of `cache_position`
         # Workaround to make 'key_states.shape[-2] + past_key_value.get_seq_length(self.layer_idx)' <= window_length
         if len(self.key_cache) <= layer_idx:
             return 0
@@ -332,23 +335,14 @@ def update(

         return self.key_cache[layer_idx], self.value_cache[layer_idx]

-    def reorder_cache(self, beam_idx: torch.LongTensor):
-        """Reorders the cache for beam search, given the selected beam indices."""
-        for layer_idx in range(len(self.key_cache)):
-            device = self.key_cache[layer_idx].device
-            self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
-            device = self.value_cache[layer_idx].device
-            self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
-

 class StaticCache(Cache):
     """
     Static Cache class to be used with `torch.compile(model)`.

     Parameters:
         config (`PretrainedConfig):
-            The configuration file defining the `max_position_embeddings`, `hidden_size` and `num_attention_heads`
-            required to initialize the static cache.
+            The configuration file defining the shape-related attributes required to initialize the static cache.
         max_batch_size (`int`):
             The maximum batch size with which the model will be used.
         max_cache_len (`int`):
@@ -373,9 +367,18 @@ def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len:
             config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
         )

+        self.key_cache: List[torch.Tensor] = []
+        self.value_cache: List[torch.Tensor] = []
         cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
-        self.key_cache: torch.Tensor = torch.zeros(cache_shape, dtype=self.dtype, device=device)
-        self.value_cache: torch.Tensor = torch.zeros(cache_shape, dtype=self.dtype, device=device)
+        for _ in range(config.num_hidden_layers):
+            # Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph
+            # breaks when updating the cache.
+            new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
+            new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
+            torch._dynamo.mark_static_address(new_layer_key_cache)
+            torch._dynamo.mark_static_address(new_layer_value_cache)
+            self.key_cache.append(new_layer_key_cache)
+            self.value_cache.append(new_layer_value_cache)

     def update(
         self,
@@ -394,42 +397,37 @@ def update(
             value_states (`torch.Tensor`):
                 The new value states to cache.
             layer_idx (`int`):
-                The index of the layer to cache the states for. Kept for backward compatibility
+                The index of the layer to cache the states for.
             cache_kwargs (`Dict[str, Any]`, `optional`):
-                Additional arguments for the cache subclass. The `StaticCache` just needs the `q_len`
-                to know how much of the cache it should overwrite.
+                Additional arguments for the cache subclass. The `StaticCache` needs the `cache_position` input
+                to know how where to write in the cache.

         Return:
             A tuple containing the updated key and value states.
         """
-        new_cache_positions = cache_kwargs.get("cache_position")
-        k_out = self.key_cache
-        v_out = self.value_cache
+        cache_position = cache_kwargs.get("cache_position")
+        k_out = self.key_cache[layer_idx]
+        v_out = self.value_cache[layer_idx]

-        k_out[:, :, new_cache_positions] = key_states
-        v_out[:, :, new_cache_positions] = value_states
+        k_out[:, :, cache_position] = key_states
+        v_out[:, :, cache_position] = value_states

         return k_out, v_out

     def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
-        """Returns the sequence length of the cached states that were seen by the model. `layer_idx` kept for BC"""
+        """Returns the sequence length of the cached states that were seen by the model."""
         # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
         # limit the check to the first batch member and head dimension.
-        # TODO: This is error prone, a filled cache may be `0.0`. Let's use a stateless integer instead, after
-        # https://github.com/pytorch/pytorch/issues/120248 is fixed
-        return (self.key_cache[0, 0].any(dim=-1)).sum()
+        # TODO: deprecate this function in favor of `cache_position`
        return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum()

     def get_max_length(self) -> Optional[int]:
-        """Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
+        """Returns the maximum sequence length of the cached states."""
         return self.max_cache_len

-    def reorder_cache(self, beam_idx: torch.LongTensor):
-        """Reorders the cache for beam search, given the selected beam indices."""
-        device = self.key_cache.device
-        self.key_cache = self.key_cache.index_select(0, beam_idx.to(device))
-        device = self.value_cache.device
-        self.value_cache = self.value_cache.index_select(0, beam_idx.to(device))
-
-    def to_legacy_cache(self):
-        """Dummy function for BC. We have to keep it because otherwise the call in the forward of models will break it"""
-        return None
+    def reset(self):
+        """Resets the cache values while preserving the objects"""
+        for layer_idx in range(len(self.key_cache)):
+            # In-place ops prevent breaking the static address
+            self.key_cache[layer_idx].zero_()
+            self.value_cache[layer_idx].zero_()
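
To make the refactor above concrete, here is a small, hypothetical sketch of the per-layer `StaticCache` API it introduces: one pre-allocated `(batch, num_key_value_heads, max_cache_len, head_dim)` buffer per layer, writes addressed by `cache_position` through `update`, and an in-place `reset`. The checkpoint used for the config is illustrative.

```py
import torch
from transformers import AutoConfig, StaticCache

config = AutoConfig.from_pretrained("meta-llama/Llama-2-7b-hf")
cache = StaticCache(config=config, max_batch_size=1, max_cache_len=32, device="cpu", dtype=torch.float32)

# Pretend layer 0 just produced keys/values for a 4-token prefill.
num_kv_heads = config.num_key_value_heads
head_dim = config.hidden_size // config.num_attention_heads
key_states = torch.randn(1, num_kv_heads, 4, head_dim)
value_states = torch.randn(1, num_kv_heads, 4, head_dim)

k_out, v_out = cache.update(
    key_states, value_states, layer_idx=0, cache_kwargs={"cache_position": torch.arange(4)}
)
print(k_out.shape)             # the full static buffer: (1, num_kv_heads, 32, head_dim)
print(cache.get_seq_length())  # number of occupied slots in layer 0 (here 4)

cache.reset()                  # zero the buffers in place, keeping their static addresses
print(cache.get_seq_length())  # back to 0
```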

src/transformers/generation/utils.py

Lines changed: 40 additions & 20 deletions
@@ -1310,6 +1310,34 @@ def _get_initial_cache_position(self, input_ids, model_kwargs):
         model_kwargs["cache_position"] = torch.arange(past_length, cur_len, device=input_ids.device)
         return model_kwargs

+    def _get_static_cache(self, max_batch_size: int, max_cache_len: int) -> StaticCache:
+        """
+        Sets a static cache for `generate`, that will persist across calls. A new cache will only be initialized a
+        new `generate` call requires a larger cache.
+
+        Returns the resulting static cache object.
+        """
+        needs_new_cache = (
+            not hasattr(self, "_static_cache")
+            or self._static_cache.max_batch_size < max_batch_size
+            or self._static_cache.max_cache_len < max_cache_len
+        )
+        if needs_new_cache:
+            if hasattr(self.config, "_pre_quantization_dtype"):
+                cache_dtype = self.config._pre_quantization_dtype
+            else:
+                cache_dtype = self.dtype
+            self._static_cache = StaticCache(
+                config=self.config,
+                max_batch_size=max_batch_size,
+                max_cache_len=max_cache_len,
+                device=self.device,
+                dtype=cache_dtype,
+            )
+        else:
+            self._static_cache.reset()  # reset the cache for a new generation
+        return self._static_cache
+
     @torch.no_grad()
     def generate(
         self,
@@ -1514,19 +1542,19 @@ def generate(
             input_ids_length=input_ids_length,
         )

-        if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING:
+        if generation_config.cache_implementation is not None and model_kwargs.get("past_key_values") is not None:
+            raise ValueError(
+                "Passing both `cache_implementation` (used to initialize certain caches) and `past_key_values` (a "
+                "Cache object) is unsupported. Please use only one of the two."
+            )
+        elif generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING:
+            if not self._supports_cache_class:
+                raise ValueError(
+                    "This model does not support the `cache_implementation` argument. Please check the following "
+                    "issue: https://github.com/huggingface/transformers/issues/28981."
+                )
             if generation_config.cache_implementation == "static":
-                if model_kwargs.get("past_key_values", False) is not False:
-                    raise ValueError(
-                        "Using `past_key_values` argument with `generate()` when using a static KV cache is not supported. Please open an issue in Transformers GitHub repository."
-                    )
-                cache_cls = NEED_SETUP_CACHE_CLASSES_MAPPING["static"]
-                if not callable(getattr(self, "_setup_cache", None)):
-                    raise ValueError(
-                        "The `generation_config` defines a `cache_implementation` that is not compatible with this model."
-                        " Make sure it has a `_setup_cache` function."
-                    )
-                self._setup_cache(cache_cls, max_batch_size=batch_size, max_cache_len=generation_config.max_length)
+                model_kwargs["past_key_values"] = self._get_static_cache(batch_size, generation_config.max_length)

         self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)

@@ -1844,14 +1872,6 @@ def typeerror():
             **model_kwargs,
         )

-        if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING:
-            if not callable(getattr(self, "_reset_cache", None)):
-                raise ValueError(
-                    "A `static_cache` was used to generate but there was a failure when trying to release the cache. "
-                    " Make sure this model implements a `_reset_cache` function."
-                )
-            self._reset_cache()
-
         return result

     def _has_unfinished_sequences(self, this_peer_finished: bool, synced_gpus: bool, device: torch.device) -> bool: