Skip to content

Commit

Permalink
[V1][Spec Decode] Optimize N-gram matching with Numba (#13365)
Browse files Browse the repository at this point in the history
Signed-off-by: Woosuk Kwon <[email protected]>
  • Loading branch information
WoosukKwon authored Feb 18, 2025
1 parent c8d70e2 commit 4c82229
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 60 deletions.
1 change: 1 addition & 0 deletions requirements-common.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
psutil
sentencepiece # Required for LLaMA tokenizer.
numpy < 2.0.0
numba == 0.60.0 # v0.61 doesn't support Python 3.9. Required for N-gram speculative decoding.
requests >= 2.26.0
tqdm
blake3
Expand Down
113 changes: 55 additions & 58 deletions vllm/v1/spec_decode/ngram_proposer.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
# SPDX-License-Identifier: Apache-2.0
from typing import List, Optional
from typing import Optional

import numpy as np
from numba import jit


class NgramProposer:

def __init__(self):
pass

def propose(
self,
context_token_ids: np.ndarray,
Expand All @@ -21,7 +19,7 @@ def propose(
that match.
Args:
context_token_ids: List of token IDs representing the
context_token_ids: Numpy array of token IDs representing the
context sequence.
n: Length of the n-gram to match.
            k: Number of tokens that follow the match. If there are fewer
Expand All @@ -41,66 +39,65 @@ def propose(
followed that pattern. Here we will return [4,2,3] because
we only have three tokens after the match.
"""
# TODO: Use c++ to implement the _find_subarray_kmp to
# improve the efficiency
return self._find_subarray_kmp(context_token_ids, n, k)
return _find_subarray_kmp(context_token_ids, n, k)

@jit(nopython=True)
def _kmp_lps_array(pattern: np.ndarray) -> np.ndarray:
    """Build the KMP "longest proper prefix which is also suffix" table.

    lps[i] is the length of the longest proper prefix of pattern[:i + 1]
    that is also a suffix of it. Compiled with Numba in nopython mode so
    the tight Python loop runs at native speed.

    Args:
        pattern: 1-D array of token IDs to build the failure table for.

    Returns:
        int32 array of the same length as ``pattern``.
    """
    lps = np.zeros(len(pattern), dtype=np.int32)
    prev_lps = 0  # length of the previous longest prefix-suffix
    i = 1

    while i < len(pattern):
        if pattern[i] == pattern[prev_lps]:
            # Extend the current prefix-suffix match by one token.
            prev_lps += 1
            lps[i] = prev_lps
            i += 1
        elif prev_lps != 0:
            # Fall back to the next-shorter candidate prefix-suffix
            # without advancing i.
            prev_lps = lps[prev_lps - 1]
        else:
            # No prefix-suffix ends at position i.
            lps[i] = 0
            i += 1

    return lps

@jit(nopython=True)
def _find_subarray_kmp(
    context_token_ids: np.ndarray,
    n: int,
    k: int,
) -> Optional[np.ndarray]:
    """Search the context for its own last-n-token suffix via KMP.

    The last ``n`` tokens of ``context_token_ids`` form the pattern; the
    earlier portion of the array is scanned for the first occurrence of
    that pattern. On a match, up to ``k`` tokens immediately following
    the matched span are returned as the speculative draft.

    Args:
        context_token_ids: 1-D array of token IDs; its suffix is the
            pattern being searched for.
        n: Pattern length. Must be positive.
        k: Number of tokens to return after the match. Fewer may be
            returned if the match is near the end of the array.

    Returns:
        Array of at most ``k`` token IDs following the first match, or
        None if the pattern does not occur before the suffix itself.
    """
    context_len = context_token_ids.shape[0]
    assert n > 0

    pattern = context_token_ids[-n:]
    # Precompute the KMP failure table for the pattern.
    lps = _kmp_lps_array(pattern)

    i = 0  # position in the searched prefix of the context
    j = 0  # number of pattern tokens currently matched
    # Scan only up to context_len - n: the trailing n tokens ARE the
    # pattern, so matching them against themselves would be vacuous.
    while i < context_len - n:
        if context_token_ids[i] == pattern[j]:
            i += 1
            j += 1
            if j == n:
                # Full match found: gather the next (up to) k tokens.
                return context_token_ids[i:i + k]
        elif j != 0:
            # Mismatch: use the failure table to avoid re-checking
            # already-matched tokens.
            j = lps[j - 1]
        else:
            i += 1

    # Pattern not found in the context prefix.
    return None
13 changes: 11 additions & 2 deletions vllm/v1/worker/gpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,11 +120,20 @@ def __init__(
# Set up speculative decoding.
self.use_spec_decode = False
if self.speculative_config:
self.use_spec_decode = True

# TODO: find a better way to check if we are using ngram.
assert self.speculative_config.ngram_prompt_lookup_min, \
"Currently, only ngram spec decode is supported in V1."
self.drafter = NgramProposer()
self.use_spec_decode = True
if get_pp_group().is_last_rank:
self.drafter = NgramProposer()
# Trigger Numba JIT compilation for N-gram proposer.
# This usually takes less than 1 second.
self.drafter.propose(
np.zeros(1024, dtype=np.int32),
self.speculative_config.ngram_prompt_lookup_min,
self.speculative_config.num_speculative_tokens,
)

# Request states.
self.requests: Dict[str, CachedRequestState] = {}
Expand Down

0 comments on commit 4c82229

Please sign in to comment.