Assisted generation slower than with base model alone #36337

Open
2 of 4 tasks
sahilsuneja1 opened this issue Feb 21, 2025 · 1 comment

sahilsuneja1 commented Feb 21, 2025

System Info

  • transformers version: 4.49.0
  • Platform: Linux-5.14.0-284.73.1.el9_2.x86_64-x86_64-with-glibc2.35
  • Python version: 3.11.11
  • Huggingface_hub version: 0.27.1
  • Safetensors version: 0.5.2
  • Accelerate version: 1.2.1
  • Accelerate config: not found
  • DeepSpeed version: not installed
  • PyTorch version (GPU?): 2.2.2+cu121 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?: No
  • Using GPU in script?: Yes
  • GPU type: NVIDIA A100-SXM4-80GB

Who can help?

@gante

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

I ran the script https://github.com/gante/huggingface-demos/blob/main/experiments/faster_generation/benchmark_decoder_open.py as:

python benchmark_decoder_open.py /path/to/Llama-3.1-8B --aux-model /path/to/Llama-3.2-1B

Assisted generation turned out to be slower than using the base model alone:

Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████| 4/4 [00:05<00:00,  1.36s/it]
Resolving data files: 100%|██████████████████████████████████████████████████████████| 1024/1024 [00:02<00:00, 357.24it/s]
Resolving data files: 100%|████████████████████████████████████████████████████████| 1024/1024 [00:00<00:00, 38278.20it/s]

ASSISTED model: 100%|█████████████████████████████████████████████████████████████████████| 20/20 [00:58<00:00,  2.92s/it]
Average time per input (ms): 2729.76
Average time per token (ms): 31.12

Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████| 4/4 [00:06<00:00,  1.64s/it]
Resolving data files: 100%|██████████████████████████████████████████████████████████| 1024/1024 [00:03<00:00, 335.68it/s]
Resolving data files: 100%|████████████████████████████████████████████████████████| 1024/1024 [00:00<00:00, 37536.86it/s]

ORIGINAL model: 100%|█████████████████████████████████████████████████████████████████████| 20/20 [00:54<00:00,  2.73s/it]
Average time per input (ms): 2137.03
Average time per token (ms): 24.36
Mismatches: 0

I am not sure whether I am missing something or whether this is a bug.
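To make the slowdown concrete, the per-token averages printed above can be turned into a speedup ratio. The sketch below is only illustrative: it assumes the log format shown in this report, and the `LOG` string and the `per_token_ms` helper are hypothetical names, not part of the benchmark script.

```python
import re

# Benchmark output copied from the run above (transformers 4.49.0),
# trimmed to the run labels and the averages.
LOG = """\
ASSISTED model
Average time per input (ms): 2729.76
Average time per token (ms): 31.12
ORIGINAL model
Average time per input (ms): 2137.03
Average time per token (ms): 24.36
"""


def per_token_ms(log: str) -> dict[str, float]:
    """Map each run label (ASSISTED/ORIGINAL) to its average ms per token."""
    results = {}
    label = None
    for line in log.splitlines():
        header = re.match(r"(ASSISTED|ORIGINAL) model", line)
        if header:
            label = header.group(1)
            continue
        metric = re.match(r"Average time per token \(ms\): ([\d.]+)", line)
        if metric and label is not None:
            results[label] = float(metric.group(1))
    return results


times = per_token_ms(LOG)
# A speedup below 1.0 means assisted generation is slower than the base model.
speedup = times["ORIGINAL"] / times["ASSISTED"]
print(f"speedup: {speedup:.2f}x")
```

With the numbers above this works out to roughly 0.78x, i.e. assisted generation takes about 28% longer per token than the base model alone.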

Expected behavior

Some speedup as shown at: https://huggingface.co/blog/dynamic_speculation_lookahead

| Target model | Draft (Assistant) model | Task | Speedup - heuristic | Speedup - dynamic |
| --- | --- | --- | --- | --- |
| meta-llama/Llama-3.1-8B | meta-llama/Llama-3.2-1B | open-ended generation | 1.00x | 1.18x |

sahilsuneja1 commented Feb 21, 2025

Seems like some kind of regression.

Results with transformers==4.46.0:

Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████| 4/4 [00:08<00:00,  2.07s/it]
Resolving data files: 100%|██████████████████████████████████████████████████████████| 1024/1024 [00:04<00:00, 225.76it/s]
Resolving data files: 100%|████████████████████████████████████████████████████████| 1024/1024 [00:00<00:00, 60903.38it/s]
ASSISTED model:   0%|                                                                              | 0/20 [00:00<?, ?it/s]
From v4.47 onwards, when a model cache is to be returned, `generate` will return a `Cache` instance instead by default (as opposed to the legacy tuple of tuples format). If you want to keep returning the legacy format, please set `return_legacy_cache=True`.
ASSISTED model: 100%|█████████████████████████████████████████████████████████████████████| 20/20 [00:47<00:00,  2.39s/it]
Average time per input (ms): 2109.56
Average time per token (ms): 24.05
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████| 4/4 [00:08<00:00,  2.04s/it]
Resolving data files: 100%|██████████████████████████████████████████████████████████| 1024/1024 [00:04<00:00, 224.12it/s]
Resolving data files: 100%|████████████████████████████████████████████████████████| 1024/1024 [00:00<00:00, 36315.24it/s]
ORIGINAL model: 100%|█████████████████████████████████████████████████████████████████████| 20/20 [00:43<00:00,  2.18s/it]
Average time per input (ms): 2137.57
Average time per token (ms): 24.37
Mismatches: 0

Results with transformers==4.45.0 (note: this is not the 1.18x speedup mentioned in the blog, https://huggingface.co/blog/dynamic_speculation_lookahead):

Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████| 4/4 [00:05<00:00,  1.38s/it]
Resolving data files: 100%|██████████████████████████████████████████████████████████| 1024/1024 [00:03<00:00, 325.33it/s]
Resolving data files: 100%|████████████████████████████████████████████████████████| 1024/1024 [00:00<00:00, 37243.58it/s]
ASSISTED model: 100%|█████████████████████████████████████████████████████████████████████| 20/20 [00:47<00:00,  2.36s/it]
Average time per input (ms): 2162.39
Average time per token (ms): 24.65
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████| 4/4 [00:06<00:00,  1.57s/it]
Resolving data files: 100%|██████████████████████████████████████████████████████████| 1024/1024 [00:04<00:00, 235.23it/s]
Resolving data files: 100%|██████████████████████████████████████████████████████████| 1024/1024 [00:07<00:00, 136.37it/s]
ORIGINAL model: 100%|█████████████████████████████████████████████████████████████████████| 20/20 [00:43<00:00,  2.20s/it]
Average time per input (ms): 2154.06
Average time per token (ms): 24.56
Mismatches: 0
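
For a side-by-side view, the per-token averages from the three runs in this thread can be compared against the blog figure. This is a rough sketch that only does arithmetic on the numbers already posted; `RUNS`, `BLOG_SPEEDUP`, and `speedups` are hypothetical names introduced for illustration.

```python
# (assisted, original) average ms per token, copied from the runs in this thread.
RUNS = {
    "4.49.0": (31.12, 24.36),
    "4.46.0": (24.05, 24.37),
    "4.45.0": (24.65, 24.56),
}

# Dynamic-speculation speedup quoted from the blog table for this model pair.
BLOG_SPEEDUP = 1.18


def speedups(runs: dict[str, tuple[float, float]]) -> dict[str, float]:
    """Return original/assisted time per token for each transformers version."""
    return {version: original / assisted
            for version, (assisted, original) in runs.items()}


for version, ratio in speedups(RUNS).items():
    print(f"transformers {version}: {ratio:.2f}x (blog: {BLOG_SPEEDUP:.2f}x)")
```

On these numbers the ratios come out to about 0.78x for 4.49.0, 1.01x for 4.46.0, and 1.00x for 4.45.0, so none of the three versions reaches the 1.18x reported in the blog, and only 4.49.0 shows a clear slowdown.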
