[Bugfix] Fix 2 Node and Spec Decode tests (#13341)
Signed-off-by: DarkLight1337 <[email protected]>
DarkLight1337 authored Feb 16, 2025
1 parent a0231b7 commit 5d2965b
Showing 2 changed files with 17 additions and 9 deletions.
10 changes: 5 additions & 5 deletions tests/distributed/test_pipeline_parallel.py
@@ -275,11 +275,11 @@ def _compare_tp(
     if load_format == "dummy":
         # Avoid OOM
         text_overrides = {
-            "num_layers": 1,
-            "num_hidden_layers": 1,
-            "num_experts": 2,
-            "num_experts_per_tok": 2,
-            "num_local_experts": 2,
+            "num_hidden_layers": 4,
+            "hidden_size": 512,
+            "intermediate_size": 800,
+            "num_attention_heads": 4,
+            "num_key_value_heads": 1,
         }
 
     if is_multimodal:
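For context, a minimal usage sketch (not part of the commit) of how overrides like these are typically fed to vLLM through the hf_overrides engine argument together with dummy weight loading; the model id is a placeholder and the exact wiring inside the test may differ.

from vllm import LLM

# Hypothetical sketch: shrink the HF config so a dummy-weight model stays small
# during a distributed test. The model id is a placeholder; the override values
# mirror the new ones in this diff.
text_overrides = {
    "num_hidden_layers": 4,
    "hidden_size": 512,
    "intermediate_size": 800,
    "num_attention_heads": 4,
    "num_key_value_heads": 1,
}

llm = LLM(
    model="meta-llama/Llama-3.2-1B-Instruct",  # placeholder model id
    load_format="dummy",          # random weights, so only the config shape matters
    hf_overrides=text_overrides,  # assumed path for applying the overrides
)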
16 changes: 12 additions & 4 deletions vllm/spec_decode/ngram_worker.py
@@ -6,6 +6,7 @@
 import torch
 import torch.nn as nn
 
+from vllm.config import VllmConfig
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.sequence import ExecuteModelRequest
 from vllm.spec_decode.interfaces import SpeculativeProposals
@@ -25,11 +26,18 @@ class NGramWorker(NonLLMProposerWorkerBase):
     which don't rely on LLM model to give proposals.
     """
 
-    def __init__(self, *args, **kwargs):
+    def __init__(
+        self,
+        vllm_config: VllmConfig,
+        local_rank: int,
+        device_type: str = "cuda",
+        **kwargs,
+    ):
+        super().__init__(vllm_config)
+
         # Get local_rank/vocab_size from kwargs attribute
-        self.local_rank = kwargs["local_rank"]
-        self.vocab_size = kwargs["vllm_config"].model_config.get_vocab_size()
-        self.device_type = kwargs.get("device_type", "cuda")
+        self.local_rank = local_rank
+        self.device_type = device_type
 
         # Lazy initialization list.
         self._proposer: Top1Proposer
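As a usage note (a sketch under assumptions, not taken from the commit): the worker is now constructed from explicit keyword parameters instead of reading values out of **kwargs, so a caller looks roughly like the helper below, where cfg stands in for an already-built VllmConfig.

from vllm.config import VllmConfig
from vllm.spec_decode.ngram_worker import NGramWorker

def build_ngram_worker(cfg: VllmConfig) -> NGramWorker:
    # cfg is assumed to be a fully populated VllmConfig built elsewhere
    # (for example from EngineArgs); this only illustrates the new call shape.
    return NGramWorker(
        vllm_config=cfg,
        local_rank=0,          # explicit parameter, no longer kwargs["local_rank"]
        device_type="cuda",    # matches the new default
    )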
