
Commit cd6aecd
finalized mvp
gante committed Apr 29, 2024
1 parent 90771df commit cd6aecd
Showing 2 changed files with 24 additions and 14 deletions.
docs/source/en/llm_optims.md: 9 changes (3 additions, 6 deletions)
@@ -70,7 +70,7 @@ Under the hood, `generate` will attempt to reuse the same cache object, removing
</hfoption>
<hfoption id="Static Cache">

A [`StaticCache`] object can be passed to the model's forward pass under the `past_key_values` argument, enabling the use of this object as a static kv-cache. Using this strategy, you'll need to write your own function to decode the next token given the current token and position and cache position of previously generated tokens.
A [`StaticCache`] object can be passed to the model's forward pass under the `past_key_values` argument, enabling the use of this object as a static kv-cache. Using this strategy, you can write your own function to decode the next token given the current token and position and cache position of previously generated tokens. You can also pass the [`StaticCache`] object to [`~GenerationMixin.generate`] and use it across calls, like you would do with a dynamic cache.

```py
from transformers import LlamaTokenizer, LlamaForCausalLM, StaticCache, logging
@@ -142,11 +142,8 @@ text
'My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p']
```

Please note that the cache has to be manually reset if you want to repeat this process multiple times reusing the same cache object.

```py
past_key_values.reset() # Clears the cache's contents without destroying the objects
```
> [!TIP]
> If you want to reuse the [`StaticCache`] object on a new prompt, be sure to reset its contents with the `.reset()` method
</hfoption>
</hfoptions>
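
The updated doc text above says a [`StaticCache`] can be handed directly to [`~GenerationMixin.generate`] and reused across calls, with `.reset()` clearing it between prompts. A minimal sketch of that flow, assuming a Llama checkpoint and illustrative cache settings (the model name, `max_cache_len`, and dtype are placeholders, not taken from this diff):

```py
import torch
from transformers import LlamaTokenizer, LlamaForCausalLM, StaticCache

model_id = "meta-llama/Llama-2-7b-hf"  # placeholder checkpoint
tokenizer = LlamaTokenizer.from_pretrained(model_id)
model = LlamaForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")

# Allocate a fixed-size cache once and reuse it across generate() calls
past_key_values = StaticCache(
    config=model.config,
    max_batch_size=1,
    max_cache_len=256,  # placeholder upper bound on prompt + generated tokens
    device=model.device,
    dtype=model.dtype,
)

inputs = tokenizer("My favorite all time favorite condiment is", return_tensors="pt").to(model.device)
out = model.generate(**inputs, past_key_values=past_key_values, max_new_tokens=32)
print(tokenizer.decode(out[0], skip_special_tokens=True))

# Before reusing the same cache object on a new prompt, clear its contents
past_key_values.reset()
```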
tests/quantization/aqlm_integration/test_aqlm.py: 29 changes (21 additions, 8 deletions)
@@ -198,7 +198,12 @@ def test_quantized_model_compile(self):
# Sample tokens greedily
def decode_one_tokens(model, cur_token, input_pos, cache_position, past_key_values):
logits = model(
cur_token, position_ids=input_pos, cache_position=cache_position, past_key_values=past_key_values, return_dict=False, use_cache=True
cur_token,
position_ids=input_pos,
cache_position=cache_position,
past_key_values=past_key_values,
return_dict=False,
use_cache=True,
)[0]
new_token = torch.argmax(logits[:, [-1]], dim=-1).to(torch.int)

@@ -209,12 +214,12 @@ def decode_one_tokens(model, cur_token, input_pos, cache_position, past_key_valu
seq_length = input_ids.shape[1]

# Setup static KV cache for generation
if hasattr(self.quantized_model.config, "_pre_quantization_dtype"):
cache_dtype = self.quantized_model.config._pre_quantization_dtype
else:
cache_dtype = self.quantized_model.dtype
past_key_values = StaticCache(
config=self.quantized_model.config, max_batch_size=2, max_cache_len=seq_length + self.max_new_tokens + 1, device=torch_device, dtype=cache_dtype
config=self.quantized_model.config,
max_batch_size=1,
max_cache_len=seq_length + self.max_new_tokens + 1,
device=torch_device,
dtype=self.quantized_model.config._pre_quantization_dtype,
)

# Allocate token ids to be generated and copy prefix ids
@@ -223,7 +228,13 @@ def decode_one_tokens(model, cur_token, input_pos, cache_position, past_key_valu
generated_ids[:, cache_position] = input_ids.to(torch_device).to(torch.int)

# Do a forward pass to fill the prefix cache and compile the kernels if necessary
logits = self.quantized_model(input_ids, cache_position=cache_position, past_key_values=past_key_values, return_dict=False, use_cache=True)[0]
logits = self.quantized_model(
input_ids,
cache_position=cache_position,
past_key_values=past_key_values,
return_dict=False,
use_cache=True,
)[0]
next_token = torch.argmax(logits[:, [-1]], dim=-1).to(torch.int)
generated_ids[:, [seq_length]] = next_token

@@ -235,7 +246,9 @@ def decode_one_tokens(model, cur_token, input_pos, cache_position, past_key_valu
cache_position = torch.tensor([seq_length + 1], device=torch_device)
for _ in range(1, self.max_new_tokens):
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
next_token = decode_one_tokens(self.quantized_model, next_token.clone(), None, cache_position, past_key_values)
next_token = decode_one_tokens(
self.quantized_model, next_token.clone(), None, cache_position, past_key_values
)
generated_ids.index_copy_(1, cache_position, next_token)
cache_position += 1
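
For context, the loop above decodes one token per step against a fixed-size [`StaticCache`], so every forward pass sees the same tensor shapes; that is what lets the test compile the decode step. The compile call itself is not visible in these hunks, but it would typically look roughly like the sketch below (the mode and flags are assumptions):

```py
# Rough sketch (not shown in the visible hunks): compile the single-token decode
# step once; with static cache shapes the compiled graph is reused every iteration.
decode_one_tokens = torch.compile(decode_one_tokens, mode="reduce-overhead", fullgraph=True)
```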

