Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Static cache v2 #30495

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
80 changes: 29 additions & 51 deletions src/transformers/cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,14 @@ def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -
return max_length - new_seq_length
return previous_seq_length

def reorder_cache(self, beam_idx: torch.LongTensor):
"""Reorders the cache for beam search, given the selected beam indices."""
for layer_idx in range(len(self.key_cache)):
device = self.key_cache[layer_idx].device
self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
device = self.value_cache[layer_idx].device
self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))

@property
def seen_tokens(self):
logger.warning_once(
Expand Down Expand Up @@ -150,6 +158,7 @@ def update(

def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
"""Returns the sequence length of the cached states. A layer index can be optionally passed."""
logger.warning_once("This function is deprecated. Use the `cache_positions` instead.")
if len(self.key_cache) <= layer_idx:
return 0
return self.key_cache[layer_idx].shape[-2]
Expand All @@ -158,14 +167,6 @@ def get_max_length(self) -> Optional[int]:
"""Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
return None

def reorder_cache(self, beam_idx: torch.LongTensor):
"""Reorders the cache for beam search, given the selected beam indices."""
for layer_idx in range(len(self.key_cache)):
device = self.key_cache[layer_idx].device
self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
device = self.value_cache[layer_idx].device
self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))

def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
"""Converts the `DynamicCache` instance into the its equivalent in the legacy cache format."""
legacy_cache = ()
Expand Down Expand Up @@ -332,23 +333,14 @@ def update(

return self.key_cache[layer_idx], self.value_cache[layer_idx]

def reorder_cache(self, beam_idx: torch.LongTensor):
"""Reorders the cache for beam search, given the selected beam indices."""
for layer_idx in range(len(self.key_cache)):
device = self.key_cache[layer_idx].device
self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
device = self.value_cache[layer_idx].device
self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))


class StaticCache(Cache):
"""
Static Cache class to be used with `torch.compile(model)`.

Parameters:
config (`PretrainedConfig`):
The configuration file defining the `max_position_embeddings`, `hidden_size` and `num_attention_heads`
required to initialize the static cache.
The configuration file defining the shape-related attributes required to initialize the static cache.
max_batch_size (`int`):
The maximum batch size with which the model will be used.
max_cache_len (`int`):
Expand All @@ -373,9 +365,18 @@ def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len:
config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
)

self.key_cache: List[torch.Tensor] = []
self.value_cache: List[torch.Tensor] = []
cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
self.key_cache: torch.Tensor = torch.zeros(cache_shape, dtype=self.dtype, device=device)
self.value_cache: torch.Tensor = torch.zeros(cache_shape, dtype=self.dtype, device=device)
for _ in range(config.num_hidden_layers):
# Note: `mark_static_address` is used to tag the cache as a fixed data pointer, preventing cuda graph
# breaks when updating the cache.
new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
torch._dynamo.mark_static_address(new_layer_key_cache)
torch._dynamo.mark_static_address(new_layer_value_cache)
self.key_cache.append(new_layer_key_cache)
self.value_cache.append(new_layer_value_cache)

def update(
self,
Expand All @@ -394,42 +395,19 @@ def update(
value_states (`torch.Tensor`):
The new value states to cache.
layer_idx (`int`):
The index of the layer to cache the states for. Kept for backward compatibility
The index of the layer to cache the states for.
cache_kwargs (`Dict[str, Any]`, `optional`):
Additional arguments for the cache subclass. The `StaticCache` just needs the `q_len`
to know how much of the cache it should overwrite.
Additional arguments for the cache subclass. The `StaticCache` needs the `cache_position` input
to know where to write in the cache.

Return:
A tuple containing the updated key and value states.
"""
new_cache_positions = cache_kwargs.get("cache_position")
k_out = self.key_cache
v_out = self.value_cache
cache_position = cache_kwargs.get("cache_position")
k_out = self.key_cache[layer_idx]
v_out = self.value_cache[layer_idx]

k_out[:, :, new_cache_positions] = key_states
v_out[:, :, new_cache_positions] = value_states
k_out[:, :, cache_position] = key_states
v_out[:, :, cache_position] = value_states

return k_out, v_out

def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
"""Returns the sequence length of the cached states that were seen by the model. `layer_idx` kept for BC"""
# Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
# limit the check to the first batch member and head dimension.
# TODO: This is error prone, a filled cache may be `0.0`. Let's use a stateless integer instead, after
# https://github.com/pytorch/pytorch/issues/120248 is fixed
return (self.key_cache[0, 0].any(dim=-1)).sum()

def get_max_length(self) -> Optional[int]:
"""Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
return self.max_cache_len

def reorder_cache(self, beam_idx: torch.LongTensor):
"""Reorders the cache for beam search, given the selected beam indices."""
device = self.key_cache.device
self.key_cache = self.key_cache.index_select(0, beam_idx.to(device))
device = self.value_cache.device
self.value_cache = self.value_cache.index_select(0, beam_idx.to(device))

def to_legacy_cache(self):
"""Dummy function for BC. We have to keep it because otherwise the call in the forward of models will break it"""
return None
41 changes: 22 additions & 19 deletions src/transformers/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1514,19 +1514,30 @@ def generate(
input_ids_length=input_ids_length,
)

if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING:
if generation_config.cache_implementation is not None and model_kwargs.get("past_key_values") is not None:
raise ValueError(
"Passing both `cache_implementation` (used to initialize certain caches) and `past_key_values` (a "
"Cache object) is unsupported. Please use only one of the two."
)
elif generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING:
if not self._supports_cache_class:
raise ValueError(
"This model does not support the `cache_implementation` argument. Please check the following "
"issue: https://github.com/huggingface/transformers/issues/28981."
)
if generation_config.cache_implementation == "static":
if model_kwargs.get("past_key_values", False) is not False:
raise ValueError(
"Using `past_key_values` argument with `generate()` when using a static KV cache is not supported. Please open an issue in Transformers GitHub repository."
)
cache_cls = NEED_SETUP_CACHE_CLASSES_MAPPING["static"]
if not callable(getattr(self, "_setup_cache", None)):
raise ValueError(
"The `generation_config` defines a `cache_implementation` that is not compatible with this model."
" Make sure it has a `_setup_cache` function."
)
self._setup_cache(cache_cls, max_batch_size=batch_size, max_cache_len=generation_config.max_length)
if hasattr(self.config, "_pre_quantization_dtype"):
cache_dtype = self.config._pre_quantization_dtype
else:
cache_dtype = self.dtype
model_kwargs["past_key_values"] = cache_cls(
config=self.config,
max_batch_size=batch_size,
max_cache_len=generation_config.max_length,
device=self.device,
dtype=cache_dtype,
)

self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)

Expand Down Expand Up @@ -1844,14 +1855,6 @@ def typeerror():
**model_kwargs,
)

if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING:
if not callable(getattr(self, "_reset_cache", None)):
raise ValueError(
"A `static_cache` was used to generate but there was a failure when trying to release the cache. "
" Make sure this model implements a `_reset_cache` function."
)
self._reset_cache()

return result

def _has_unfinished_sequences(self, this_peer_finished: bool, synced_gpus: bool, device: torch.device) -> bool:
Expand Down