Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cache: Static cache as a standalone object #30476

Merged
merged 12 commits into from Apr 30, 2024
Merged

Conversation

gante
Copy link
Member

@gante gante commented Apr 25, 2024

What does this PR do?

Replaces the current format of StaticCache[an object living inside a model, containing the cache for one layer] with a standalone object matching the other Cache objects. The new format preserves the existing torch.compile capabilities while being easier to manipulate, especially outside a model.

In the process, removes all traces of the previous format across all models, tests, and docs.

Fixes #30417 (In place of #30437)
Fixes #30351


Benchmarks

(RTX3090, tiny-llama model, torch==2.4.0.dev20240424+cu121)

Benchmark code
import torch

import os
os.environ["TOKENIZERS_PARALLELISM"] = "0"
import torch._dynamo.config
import torch._inductor.config
from torch.utils import benchmark
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache, set_seed , StaticCache
from typing import Optional

torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.triton.unique_kernel_names = True
# torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future
# TORCH_LOGS="perf_hints,recompiles,graph_breaks" python ../joao_scripts/benchmark_compile.py &> ~/logs.txt

torch.set_printoptions(linewidth=200)  # you can better see how the mask is shaped


NUM_ITER = 100
FRANCE_ARTICLE = (  # @noqa
  """<s>Marseille, France (CNN)The French prosecutor leading an investigation into the crash of Germanwings Flight 9525 insisted Wednesday that he was not aware of any video footage from on board the plane. Marseille prosecutor Brice Robin told CNN that "so far no videos were used in the crash investigation." He added, "A person who has such a video needs to immediately give it to the investigators." Robin\'s comments follow claims by two magazines, German daily Bild and French Paris Match, of a cell phone video showing the harrowing final seconds from on board Germanwings Flight 9525 as it crashed into the French Alps. All 150 on board were killed. Paris Match and Bild reported that the video was recovered from a phone at the wreckage site. The two publications described the supposed video, but did not post it on their websites. The publications said that they watched the video, which was found by a source close to the investigation. \"One can hear cries of 'My God' in several languages,\" Paris Match reported. "Metallic banging can also be heard more than three times, perhaps of the pilot trying to open the cockpit door with a heavy object.  Towards the end, after a heavy shake, stronger than the others, the screaming intensifies. Then nothing." "It is a very disturbing scene," said Julian Reichelt, editor-in-chief of Bild online. An official with France's accident investigation agency, the BEA, said the agency is not aware of any such video. Lt. Col. Jean-Marc Menichini, a French Gendarmerie spokesman in charge of communications on rescue efforts around the Germanwings crash site, told CNN that the reports were "completely wrong" and "unwarranted." Cell phones have been collected at the site, he said, but that they "hadn\'t been exploited yet." Menichini said he believed the cell phones would need to be sent to the Criminal Research Institute in Rosny sous-Bois, near Paris, in order to be analyzed by specialized technicians working hand-in-hand with investigators. But none of the cell phones found so far have been sent to the institute, Menichini said. Asked whether staff involved in the search could have leaked a memory card to the media, Menichini answered with a categorical "no." Reichelt told "Erin Burnett: Outfront" that he had watched the video and stood by the report, saying Bild and Paris Match are "very confident" that the clip is real. He noted that investigators only revealed they\'d recovered cell phones from the crash site after Bild and Paris Match published their reports. "That is something we did not know before. ... Overall we can say many things of the investigation weren't revealed by the investigation at the beginning," he said. What was mental state of Germanwings co-pilot? German airline Lufthansa confirmed Tuesday that co-pilot Andreas Lubitz had battled depression years before he took the controls of Germanwings Flight 9525, which he's accused of deliberately crashing last week in the French Alps. Lubitz told his Lufthansa flight training school in 2009 that he had a "previous episode of severe depression," the airline said Tuesday. Email correspondence between Lubitz and the school discovered in an internal investigation, Lufthansa said, included medical documents he submitted in connection with resuming his flight training. The announcement indicates that Lufthansa, the parent company of Germanwings, knew of Lubitz's battle with depression, allowed him to continue training and ultimately put him in the cockpit. Lufthansa, whose CEO Carsten Spohr previously said Lubitz was 100% fit to fly, described its statement Tuesday as a "swift and seamless clarification" and said it was sharing the information and documents -- including training and medical records -- with public prosecutors. Spohr traveled to the crash site Wednesday, where recovery teams have been working for the past week to recover human remains and plane debris scattered across a steep mountainside. He saw the crisis center set up in Seyne-les-Alpes, laid a wreath in the village of Le Vernet, closer to the crash site, where grieving families have left flowers at a simple stone memorial. Menichini told CNN late Tuesday that no visible human remains were left at the site but recovery teams would keep searching. French President Francois Hollande, speaking Tuesday, said that it should be possible to identify all the victims using DNA analysis by the end of the week, sooner than authorities had previously suggested. In the meantime, the recovery of the victims' personal belongings will start Wednesday, Menichini said. Among those personal belongings could be more cell phones belonging to the 144 passengers and six crew on board. Check out the latest from our correspondents . The details about Lubitz's correspondence with the flight school during his training were among several developments as investigators continued to delve into what caused the crash and Lubitz's possible motive for downing the jet. A Lufthansa spokesperson told CNN on Tuesday that Lubitz had a valid medical certificate, had passed all his examinations and "held all the licenses required." Earlier, a spokesman for the prosecutor\'s office in Dusseldorf, Christoph Kumpa, said medical records reveal Lubitz suffered from suicidal tendencies at some point before his aviation career and underwent psychotherapy before he got his pilot's license. Kumpa emphasized there's no evidence suggesting Lubitz was suicidal or acting aggressively before the crash. Investigators are looking into whether Lubitz feared his medical condition would cause him to lose his pilot's license, a European government official briefed on the investigation told CNN on Tuesday. While flying was "a big part of his life," the source said, it\'s only one theory being considered. Another source, a law enforcement official briefed on the investigation, also told CNN that authorities believe the primary motive for Lubitz to bring down the plane was that he feared he would not be allowed to fly because of his medical problems. Lubitz's girlfriend told investigators he had seen an eye doctor and a neuropsychologist, both of whom deemed him unfit to work recently and concluded he had psychological issues, the European government official said. But no matter what details emerge about his previous mental health struggles, there's more to the story, said Brian Russell, a forensic psychologist. "Psychology can explain why somebody would turn rage inward on themselves about the fact that maybe they weren't going to keep doing their job and they're upset about that and so they're suicidal," he said. "But there is no mental illness that explains why somebody then feels entitled to also take that rage and turn it outward on 149 other people who had nothing to do with the person's problems." Germanwings crash compensation: What we know . Who was the captain of Germanwings Flight 9525? CNN's Margot Haddad reported from Marseille and Pamela Brown from Dusseldorf, while Laura Smith-Spark wrote from London. CNN's Frederik Pleitgen, Pamela Boykoff, Antonia Mortensen, Sandrine Amiel and Anna-Maja Rappard contributed to this report."""
)
THROUGHPUT_LABEL = "Throughput (time/token)"


def multinomial_sample_one_no_sync(probs_sort): # Does multinomial sampling without a cuda synchronization
  q = torch.empty_like(probs_sort).exponential_(1)
  return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)

def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None):
  logits = logits / max(temperature, 1e-5)

  if top_k is not None:
      v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
      pivot = v.select(-1, -1).unsqueeze(-1)
      logits = torch.where(logits < pivot, -float("Inf"), logits)
  probs = torch.nn.functional.softmax(logits, dim=-1)
  return probs

def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
  probs = logits_to_probs(logits[:, -1], temperature, top_k)
  idx_next = multinomial_sample_one_no_sync(probs)
  return idx_next, probs


device = "cuda"
attn_implementation = "sdpa"
all_dtype = torch.bfloat16

repo = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
# repo = "meta-llama/Llama-2-7b-chat-hf"
tokenizer = AutoTokenizer.from_pretrained(repo, padding_side="left", pad_token = "<s>")
model = AutoModelForCausalLM.from_pretrained(repo, torch_dtype=all_dtype, attn_implementation=attn_implementation).to(device,all_dtype)
model = model.eval()

is_legacy = hasattr(model, "_setup_cache") # up to v4.40

def print_results(throughput):
  print("\n")
  compare = benchmark.Compare(throughput)
  compare.trim_significant_figures()
  compare.colorize(rowwise = True)
  compare.print()


def record_fwd(function):
  with torch.no_grad():
      start = torch.cuda.Event(enable_timing=True)
      end = torch.cuda.Event(enable_timing=True)
      start.record()
      outputs = function()
      inputs_ids = sample(outputs[0],temperature=0.6, top_k=5)[0]
      end.record()
      torch.cuda.synchronize()
  return start.elapsed_time(end), inputs_ids


def record_generate(model, input_ids, generation_config):
  start = torch.cuda.Event(enable_timing=True)
  end = torch.cuda.Event(enable_timing=True)
  start.record()
  outputs = model.generate(input_ids, generation_config=generation_config)
  end.record()
  torch.cuda.synchronize()
  total_time = start.elapsed_time(end)/1000
  generate_time = (total_time - eager_prefill_time)/(NUM_ITER - 1)
  return generate_time



throughput = []
compilation_time = []
prefill_time = []

for bs in reversed([1, 4, 16]):
  for max_cache_length in [2048, 512]:
      for seq_length in reversed([1, 1024]):
          if seq_length>=max_cache_length:
              continue
          description = f"batch, cache_len, seq_length: {bs, max_cache_length, seq_length}"
          model.generation_config.max_length = max_cache_length

          input_ids = tokenizer([FRANCE_ARTICLE]*bs, return_tensors="pt").to(device).input_ids[:bs, :seq_length]

          model.eval()

          # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
          task = "eager dynamic fwd"
          task_spec = benchmark.TaskSpec(stmt="", setup="", description=task, label=THROUGHPUT_LABEL, sub_label=description)
          generated_ids = torch.zeros(bs, NUM_ITER, dtype = torch.long)
          past_key_values = DynamicCache()
          res = []
          set_seed(123)
          for i in range(NUM_ITER):

              if i == 0:
                  time, new_token = record_fwd(lambda: model(
                      input_ids,
                      past_key_values=past_key_values,
                      position_ids=torch.arange(0, seq_length).to(device).unsqueeze(0),
                      cache_position=torch.arange(0, seq_length).to(device),
                      return_dict=False,
                      use_cache = True)
                  )
              else:
                  time, new_token = record_fwd(lambda: model(
                      new_token,
                      past_key_values=past_key_values,
                      position_ids=torch.arange(seq_length + i - 1, seq_length + i).to(device).unsqueeze(0),
                      cache_position=torch.arange(seq_length + i - 1, seq_length + i).to(device),
                      return_dict=False,
                      use_cache = True)
                  )

              res.append(time/1000)
              generated_ids[:, i] = new_token[:,0]

          torch.cuda.synchronize()
          throughput.append(benchmark.Measurement(1, res[3:], task_spec, metadata=None))
          print_results(throughput)

          # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
          # task = "eager dynamic generate"
          # task_spec = benchmark.TaskSpec(stmt="", setup="", description=task, label=THROUGHPUT_LABEL, sub_label=description)
          # res = []
          # generation_config = copy.deepcopy(model.generation_config)
          # generation_config.update(**{"do_sample": False, "max_new_tokens": NUM_ITER, "max_length": None})
          # for i in range(5):
          #     generate_time = record_generate(model, input_ids, generation_config)
          #     res.append(generate_time)

          # throughput.append(benchmark.Measurement(1, res[3:], task_spec, metadata=None))
          # print_results(throughput)

          # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
          task = "compiled static fwd"
          task_spec = benchmark.TaskSpec(stmt="", setup="", description=task, label=THROUGHPUT_LABEL, sub_label=description)
          if is_legacy:
              model._setup_cache(StaticCache, bs, max_cache_len=max_cache_length)
              past_key_values = None
          else:
              past_key_values = StaticCache(model.config, bs, max_cache_length, model.device, model.dtype)

          generated_ids = torch.zeros(bs, NUM_ITER)
          res = []
          torch.compiler.reset()
          set_seed(123)

          compiled_model = torch.compile(model, mode="reduce-overhead",fullgraph=True)
          for i in range(NUM_ITER):
              if i == 0:
                  time, new_token = record_fwd(lambda: compiled_model(
                      input_ids,
                      past_key_values=past_key_values,
                      position_ids=torch.arange(0, seq_length).to(device).unsqueeze(0),
                      cache_position=torch.arange(0, seq_length).to(device),
                      return_dict=False,
                      use_cache = True,
                      )
                  )
              else:
                  position_ids = torch.arange(seq_length + i - 1, seq_length + i).to(device).unsqueeze(0)
                  cache_position = torch.arange(seq_length + i - 1, seq_length + i).to(device)
                  time, new_token = record_fwd(lambda: compiled_model(
                      new_token,
                      past_key_values=past_key_values,
                      position_ids=position_ids,
                      cache_position=cache_position,
                      return_dict=False,
                      use_cache = True)
                  )

              res.append(time/1000)
              generated_ids[:, i] = new_token[:,0]

          throughput.append(benchmark.Measurement(1, res[3:], task_spec, metadata=None))
          print_results(throughput)

          label = "forward exec time (with compilation)"
          for i in range(1, 6):
              task = f"iter={i}"
              task_spec = benchmark.TaskSpec(stmt="", setup="", description=task, label=label, sub_label=description)
              compilation_time.append(benchmark.Measurement(1, [res[i]], task_spec, metadata=None))

          if is_legacy:
              model._reset_cache()

          # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
          # torch.compiler.reset()
          # task = "compiled static generate"
          # task_spec = benchmark.TaskSpec(stmt="", setup="", description=task, label=THROUGHPUT_LABEL, sub_label=description)
          # res = []
          # generation_config = copy.deepcopy(model.generation_config)
          # generation_config.update(**{"do_sample": False, "max_new_tokens": NUM_ITER, "cache_implementation": "static", "max_length": None})
          # model.generate = torch.compile(model.generate, mode="reduce-overhead", fullgraph=True)
          # for i in range(5):
          #     generate_time = record_generate(model, input_ids, generation_config)
          #     res.append(generate_time)

          # throughput.append(benchmark.Measurement(1, res[3:], task_spec, metadata=None))
          # print_results(throughput)


# Finalize: print compilation and prefill times
print_results(compilation_time)
print_results(prefill_time)
  • commit == 14b19c4ef365f90797e07b2a20caaaaf3901b2d2
    Screenshot 2024-04-25 at 10 05 23

  • v4.39.0
    Screenshot 2024-04-25 at 10 05 48

@gante gante requested a review from ArthurZucker April 25, 2024 09:08
Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall LFGTM

Comment on lines 413 to 420
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
"""Returns the sequence length of the cached states that were seen by the model. `layer_idx` kept for BC"""
"""Returns the sequence length of the cached states that were seen by the model."""
# Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
# limit the check to the first batch member and head dimension.
# TODO: This is error prone, a filled cache may be `0.0`. Let's use a stateless integer instead, after
# https://github.com/pytorch/pytorch/issues/120248 is fixed
return (self.key_cache[0, 0].any(dim=-1)).sum()
return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

will remove this one

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's slow and not reliable, generate should never use it

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(needs deprecation cycle and it's easer to do after we isolate the prefill stage, I'm going to leave it off this PR)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fine by me to deprecate

Comment on lines +433 to +435
raise ValueError(
"`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
"make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would be compatible if we slice the q k v efficiently, but that's too much trouble

@ArthurZucker
Copy link
Collaborator

Taking this on to finish!

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@ArthurZucker
Copy link
Collaborator

image
LGT

@ArthurZucker
Copy link
Collaborator

If you use the memory efficient kernel it's 20% slower. That's what we use by default

@ArthurZucker
Copy link
Collaborator

@gante gante changed the title Cache: Static cache as a stand-alone object Cache: Static cache as a standalone object Apr 26, 2024
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
# to infer the attention mask.
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
using_static_cache = isinstance(past_key_values, StaticCache)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As I understand it, once the StaticCache is initialized, there is no need to pass it in past_key_values argument. That's why additional condition is necessary. Suggestion:
using_static_cache = isinstance(past_key_values, StaticCache) or isinstance( getattr(self.layers[0].self_attn, "past_key_value", None), StaticCache )

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@poedator This PR changes precisely the assumption you wrote: we will always need to pass the cache, after this PR it is an object that does NOT live inside the model.

This change will make the transformers' team work easier 🤗

@gante gante marked this pull request as ready for review April 26, 2024 13:07
}
# Note on `EXPECTED_TEXT_COMPLETION`'s diff: the current value matches the original test if the original test
# was changed to have a cache of 53 tokens (as opposed to 4096).
EXPECTED_TEXT_COMPLETION = [
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same comment as here: #30437 (comment) please make sure to validate these tests on the T4 and A10 runners 🙏

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There was indeed a mismatch on T4 🤗

Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Absolute great work

Comment on lines 413 to 420
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
"""Returns the sequence length of the cached states that were seen by the model. `layer_idx` kept for BC"""
"""Returns the sequence length of the cached states that were seen by the model."""
# Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
# limit the check to the first batch member and head dimension.
# TODO: This is error prone, a filled cache may be `0.0`. Let's use a stateless integer instead, after
# https://github.com/pytorch/pytorch/issues/120248 is fixed
return (self.key_cache[0, 0].any(dim=-1)).sum()
return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fine by me to deprecate

Comment on lines 430 to 431
self.key_cache[layer_idx] *= 0.0
self.value_cache[layer_idx] *= 0.0
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
self.key_cache[layer_idx] *= 0.0
self.value_cache[layer_idx] *= 0.0
self.key_cache[layer_idx] = 0.0
self.value_cache[layer_idx] = 0.0

might be faster?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

setting to a new tensor produces a graph break 💔 (I'm assuming you meant self.key_cache[layer_idx] = torch.zeros(...))

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No no, I think just filling them with zeros should work

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That would result in TypeError: 'float' object is not subscriptable when indexing the cache :D

But filling with zeros with tensor.zero_() works 👍

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok 👍🏻 let's go with that then!


if cache_position is None:
if isinstance(past_key_values, StaticCache):
raise ValueError("cache_position is a required argument when using StaticCache.")
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Arf alright, let's add maybe a TODO? as we won't be initializing with get_seq_length later on!

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a todo on get_seq_length 👍

Comment on lines +976 to +981
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
# to infer the attention mask.
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
using_static_cache = isinstance(past_key_values, StaticCache)
if self.config._attn_implementation == "sdpa" and not using_static_cache:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is new, and since we pass cahce position, let's use cache_position[0]

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed in theory, can't do in practice: breaks torch.fx tests 💔

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah thought so

Comment on lines +990 to +991
if using_static_cache:
target_length = past_key_values.get_max_length()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can't we always use get_max_length()?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

get_max_length() is None in the dynamic caches

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It should be seq_length

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

but alright

@@ -684,15 +683,25 @@ def test_model_13b_greedy_generation(self):
@require_torch_gpu
@require_read_token
def test_compile_static_cache(self):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should require torch > 2.2

Comment on lines +730 to +736
# Static Cache + compile
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
generated_ids = model.generate(
**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static"
)
static_compiled_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], static_compiled_text)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good thanks

@gante gante merged commit 75bbfd5 into huggingface:main Apr 30, 2024
24 checks passed
@gante gante deleted the static_cache_v2 branch April 30, 2024 15:37
poedator added a commit to poedator/transformers that referenced this pull request May 5, 2024
ArthurZucker pushed a commit that referenced this pull request May 13, 2024
* 4d mask fixes

* Update custom 4D mask logic

* test moved to mixin

* extra tests 4d mask

* upd 4d mask and StaticCache handling

* added Mask4DTestHard to mistral tests

* post-rebase fixes

* test fixes for StaticCache

* make fix-copies

* upd 1 after #30476

* fix common tests

* rm elif attention_mask.dim() == 4:

* tests combined, fixed, mixtral supported

* bigbird style chg reverted

* rm if attention_mask.dim() == 2

* modeling_llama formatting chg

---------

Co-authored-by: Joao Gante <[email protected]>
Comment on lines +687 to +688
# `torch==2.2` will throw an error on this test (as in other compilation tests), but torch==2.1.2 and torch>2.2
# work as intended. See https://github.com/pytorch/pytorch/issues/121943
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and 2.2.1 works as well

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
5 participants