Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
750955d
Fix contrastive_search for new cache structure, and improve performan…
Cyrilvallez May 2, 2024
c8a43a8
Fix _contrastive_search for non-standard cache using ellipsis slicing
Cyrilvallez May 4, 2024
9342fff
Fix all outputs.logits memory leaks for all decoding strategies!
Cyrilvallez May 7, 2024
15e8615
Fix small error in _contrastive_search()
Cyrilvallez May 7, 2024
c0e40d4
Make all necessary change and revert for the new class
Cyrilvallez May 13, 2024
1de3148
Apply coding style
Cyrilvallez May 13, 2024
e170e96
Remove pipes in type hints for compatibility
Cyrilvallez May 13, 2024
d7c4359
correct type hint
Cyrilvallez May 13, 2024
3986f1d
apply style
Cyrilvallez May 13, 2024
b9f7e04
Use DynamicCache by default and solve conflicts
Cyrilvallez May 14, 2024
614e052
Fix rebase issues
Cyrilvallez May 14, 2024
241b851
Add `_supports_dynamic_cache_class` in models for models that support…
Cyrilvallez May 14, 2024
d31eea4
Create generation config to return legacy format by default, or to ch…
Cyrilvallez May 14, 2024
17525ab
style
Cyrilvallez May 14, 2024
c47e6ce
Fix case when use_cache is False
Cyrilvallez May 14, 2024
b9bbfd9
Remove default DynamicCache in assiste_decoding if assistant_model do…
Cyrilvallez May 14, 2024
3b59cf6
Update prepare_inputs_for_generation() for case with empty DynamicCache
Cyrilvallez May 15, 2024
2a809ad
Correct return of args in _assisted_decoding
Cyrilvallez May 15, 2024
e96adcb
Remove EfficientDynamicCache as it is no longer needed
Cyrilvallez May 15, 2024
20174ec
Correct mistake in generation config
Cyrilvallez May 21, 2024
7e39b92
Move cache logic of assisted decoding to AssistedCandidateGenerator._…
Cyrilvallez May 21, 2024
f3e3161
change DynamicCache function names from "split" to "batch_split" for …
Cyrilvallez May 21, 2024
8abe055
Remove `_supports_dynamic_cache_class` attribute after rebase
Cyrilvallez May 21, 2024
e9d0b25
Correct missing line lost in conflict resolution during rebasing
Cyrilvallez May 21, 2024
c902dc1
Add special case for Jamba
Cyrilvallez May 21, 2024
2f83867
Fix jamba test
Cyrilvallez May 21, 2024
2c51e03
Coding style
Cyrilvallez May 21, 2024
3c0999b
coding style
Cyrilvallez May 23, 2024
b494dd5
Correct missing import in rebasing
Cyrilvallez May 23, 2024
70a0185
Simplify _validate_model_kwargs based on removal of _supports_dynamic…
Cyrilvallez May 23, 2024
d38a966
Simplify code paths in _contrastive_search
Cyrilvallez May 23, 2024
c8edaef
coding style
Cyrilvallez May 23, 2024
1e020d6
Update docstrings of cache methods
Cyrilvallez Jun 3, 2024
7bd2e3e
Update prepare_inputs_for_generation() -> past_key_values are always …
Cyrilvallez Jun 3, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 55 additions & 2 deletions src/transformers/cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,22 +373,75 @@ def get_max_length(self) -> Optional[int]:
return None

def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
    """Converts the `DynamicCache` instance into its equivalent in the legacy cache format (a tuple of
    `(key, value)` tensor pairs, one per layer). Used for backward compatibility."""
    # len(self) is the number of layers currently held in the cache.
    return tuple((self.key_cache[layer_idx], self.value_cache[layer_idx]) for layer_idx in range(len(self)))

@classmethod
def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache":
    """Converts a cache in the legacy cache format (tuple of per-layer `(key, value)` pairs) into an
    equivalent `DynamicCache`. Used for backward compatibility."""
    cache = cls()
    if past_key_values is None:
        # No legacy cache given -> return an empty cache.
        return cache
    for layer_idx, (key_states, value_states) in enumerate(past_key_values):
        cache.update(key_states, value_states, layer_idx)
    return cache

def crop(self, maximum_length: int):
    """Crop the past key values up to a new `maximum_length` in terms of tokens. `maximum_length` can also be
    negative to remove `maximum_length` tokens. This is used in assisted decoding and contrastive search."""
    # A negative value means "drop that many tokens from the end".
    if maximum_length < 0:
        maximum_length = self.get_seq_length() + maximum_length

    # Already short enough -> nothing to crop.
    if self.get_seq_length() <= maximum_length:
        return

    self._seen_tokens = maximum_length
    num_layers = len(self.key_cache)
    for layer_idx in range(num_layers):
        # Ellipsis keeps this valid for non-standard cache layouts; -2 axis is the sequence dimension.
        self.key_cache[layer_idx] = self.key_cache[layer_idx][..., :maximum_length, :]
        self.value_cache[layer_idx] = self.value_cache[layer_idx][..., :maximum_length, :]

def batch_split(self, full_batch_size: int, split_size: int) -> List["DynamicCache"]:
    """Split the current instance into a list of `DynamicCache` by the batch size. This will be used by
    `_split_model_inputs()` in `generation.utils`"""
    splits = []
    for start in range(0, full_batch_size, split_size):
        end = start + split_size
        piece = DynamicCache()
        # The split shares the token count; only the batch dimension is sliced.
        piece._seen_tokens = self._seen_tokens
        piece.key_cache = [layer[start:end] for layer in self.key_cache]
        piece.value_cache = [layer[start:end] for layer in self.value_cache]
        splits.append(piece)
    return splits

@classmethod
def from_batch_splits(cls, splits: List["DynamicCache"]) -> "DynamicCache":
    """This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in
    `generation.utils`"""
    cache = cls()
    # All splits hold the same number of layers; concatenate each layer along the batch dimension.
    num_layers = len(splits[0])
    for layer_idx in range(num_layers):
        keys = torch.cat([split.key_cache[layer_idx] for split in splits], dim=0)
        values = torch.cat([split.value_cache[layer_idx] for split in splits], dim=0)
        cache.update(keys, values, layer_idx)
    return cache

def batch_repeat_interleave(self, repeats: int):
    """Repeat the cache `repeats` times in the batch dimension. Used in contrastive search."""
    num_layers = len(self)
    for layer_idx in range(num_layers):
        # dim=0 is the batch dimension of the cached states.
        self.key_cache[layer_idx] = self.key_cache[layer_idx].repeat_interleave(repeats, dim=0)
        self.value_cache[layer_idx] = self.value_cache[layer_idx].repeat_interleave(repeats, dim=0)

def batch_select_indices(self, indices: torch.Tensor):
    """Only keep the `indices` in the batch dimension of the cache. Used in contrastive search."""
    num_layers = len(self)
    for layer_idx in range(num_layers):
        # Fancy-index the batch dimension; remaining dimensions are kept as-is.
        self.key_cache[layer_idx] = self.key_cache[layer_idx][indices, ...]
        self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...]


class QuantizedCache(DynamicCache):
"""
Expand Down
18 changes: 14 additions & 4 deletions src/transformers/generation/candidate_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,19 @@ def __init__(
value.detach().to(device) if isinstance(value, torch.Tensor) else copy.deepcopy(value)
)

# Remove potential default DynamicCache if assistant does not support it
if "past_key_values" in assistant_kwargs.keys():
if (
isinstance(assistant_kwargs["past_key_values"], DynamicCache)
and not self.assistant_model._supports_cache_class
):
# Cache is empty -> remove it from kwargs
if len(assistant_kwargs["past_key_values"]) == 0:
del assistant_kwargs["past_key_values"]
# Cache is not empty -> convert to legacy
else:
assistant_kwargs["past_key_values"] = assistant_kwargs["past_key_values"].to_legacy_cache()

if "assistant_encoder_outputs" in model_kwargs:
assistant_kwargs["encoder_outputs"] = model_kwargs["assistant_encoder_outputs"]
elif assistant_model.config.is_encoder_decoder:
Expand Down Expand Up @@ -387,10 +400,7 @@ def _crop_past_key_values(model, past_key_values, maximum_length):
for idx in range(len(past_key_values)):
past_key_values[idx] = past_key_values[idx][:, :, :maximum_length, :]
elif isinstance(past_key_values, DynamicCache):
for idx in range(len(past_key_values.key_cache)):
if past_key_values.value_cache[idx].shape[-1] != 0:
past_key_values.key_cache[idx] = past_key_values.key_cache[idx][:, :, :maximum_length, :]
past_key_values.value_cache[idx] = past_key_values.value_cache[idx][:, :, :maximum_length, :]
past_key_values.crop(maximum_length)

elif past_key_values is not None:
for idx in range(len(past_key_values)):
Expand Down
3 changes: 3 additions & 0 deletions src/transformers/generation/configuration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,8 @@ class GenerationConfig(PushToHubMixin):
Arguments used in the key-value cache class can be passed in `cache_config`. Can be passed as a `Dict` and
it will be converted to its respective `CacheConfig` internally.
Otherwise can be passed as a `CacheConfig` class matching the indicated `cache_implementation`.
return_legacy_cache (`bool`, *optional*, defaults to `True`):
Whether to return the legacy or the new format of the cache when `DynamicCache` is used by default.

> Wild card

Expand Down Expand Up @@ -400,6 +402,7 @@ def __init__(self, **kwargs):
self.cache_config = cache_config_class()
elif isinstance(self.cache_config, dict):
self.cache_config = cache_config_class.from_dict(self.cache_config)
self.return_legacy_cache = kwargs.pop("return_legacy_cache", True)

# Prompt lookup decoding
self.prompt_lookup_num_tokens = kwargs.pop("prompt_lookup_num_tokens", None)
Expand Down
Loading