fix multi step and support same model checkpoint for main and spec model
luccafong committed Feb 5, 2025
1 parent 12a153e commit 9afc75e
Showing 10 changed files with 85 additions and 45 deletions.
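In practical terms, this change lets a single DeepSeek-V3 style checkpoint serve as both the target and the speculative (MTP) draft model: when only `num_speculative_tokens` is given, the draft now defaults to the main model. A minimal usage sketch, mirroring the kwargs in the updated test below; the model name and engine kwargs are taken from that test, and the `LLM` entry point is assumed to accept them as engine arguments:

```python
# Minimal sketch; assumes vLLM's LLM forwards these engine kwargs,
# as the test_llm_kwargs below do via the test runner.
from vllm import LLM, SamplingParams

llm = LLM(
    model="luccafong/deepseek_mtp_main_random",  # main model; the MTP draft reuses this checkpoint
    num_speculative_tokens=1,                    # should match num_nextn_predict_layers in config.json
    # no speculative_model argument: for deepseek_v3 it now defaults to the main model
)
out = llm.generate(["San Francisco is"], SamplingParams(temperature=0.0, max_tokens=16))
print(out[0].outputs[0].text)
```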
13 changes: 2 additions & 11 deletions tests/spec_decode/e2e/test_mtp_correctness.py
@@ -27,12 +27,10 @@
# main model
MAIN_MODEL = "luccafong/deepseek_mtp_main_random"

# speculative model
SPEC_MODEL = "luccafong/deepseek_mtp_draft_random"

# max. number of speculative tokens: this corresponds to
# num_heads in the config.json of the speculator model.
MAX_SPEC_TOKENS = 3
# num_nextn_predict_layers in the config.json of the speculator model.
MAX_SPEC_TOKENS = 1

# precision
PRECISION = "bfloat16"
@@ -57,7 +55,6 @@
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
])
@@ -97,12 +94,10 @@ def test_mtp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
"disable_logprobs_during_spec_decoding": False,
},
{
"speculative_model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
"disable_logprobs_during_spec_decoding": True,
},
@@ -152,7 +147,6 @@ def test_mtp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
])
@@ -196,7 +190,6 @@ def test_mtp_e2e_greedy_correctness_cuda_graph(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
])
@@ -239,7 +232,6 @@ def test_mtp_e2e_greedy_correctness_with_preemption(
"test_llm_kwargs",
[
{
"speculative_model": SPEC_MODEL,
"num_speculative_tokens": k,
}
# Try a range of num. speculative tokens
@@ -282,7 +274,6 @@ def test_mtp_different_k(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
[{
"speculative_model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
"speculative_disable_by_batch_size": 4
}])
40 changes: 32 additions & 8 deletions vllm/config.py
@@ -850,8 +850,12 @@ def get_num_attention_heads(self,
def get_layers_start_end_indices(
self, parallel_config: "ParallelConfig") -> Tuple[int, int]:
from vllm.distributed.utils import get_pp_indices
total_num_hidden_layers = getattr(self.hf_text_config,
"num_hidden_layers", 0)
if self.hf_text_config.model_type == "deepseek_mtp":
total_num_hidden_layers = getattr(self.hf_text_config,
"num_nextn_predict_layers", 0)
else:
total_num_hidden_layers = getattr(self.hf_text_config,
"num_hidden_layers", 0)
pp_rank = parallel_config.rank // parallel_config.tensor_parallel_size
pp_size = parallel_config.pipeline_parallel_size
start, end = get_pp_indices(total_num_hidden_layers, pp_rank, pp_size)
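Put differently, when the config being partitioned is the MTP draft (`deepseek_mtp`), the layer count used for pipeline-parallel slicing is `num_nextn_predict_layers` rather than `num_hidden_layers`. A standalone sketch of that selection, with a plain dict standing in for `hf_text_config` (illustrative only, not the vLLM method):

```python
# Illustrative stand-in for the branch above; a dict plays the role of hf_text_config.
def total_layers_for_pp(hf_text_config: dict) -> int:
    if hf_text_config.get("model_type") == "deepseek_mtp":
        # The MTP draft "model" consists only of the next-n-predict layers.
        return hf_text_config.get("num_nextn_predict_layers", 0)
    return hf_text_config.get("num_hidden_layers", 0)

assert total_layers_for_pp({"model_type": "deepseek_mtp", "num_nextn_predict_layers": 1}) == 1
assert total_layers_for_pp({"model_type": "deepseek_v3", "num_hidden_layers": 61}) == 61
```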
@@ -1664,6 +1668,21 @@ def compute_hash(self) -> str:
hash_str = hashlib.md5(str(factors).encode()).hexdigest()
return hash_str

@staticmethod
def hf_config_override(
hf_config: PretrainedConfig
) -> PretrainedConfig:
if hf_config.model_type == "deepseek_v3":
hf_config.model_type = "deepseek_mtp"
if hf_config.model_type == "deepseek_mtp":
n_predict = getattr(
hf_config, "num_nextn_predict_layers", None)
hf_config.update({
"n_predict": n_predict,
"architectures": ["DeepSeekMTPModel"]
})
return hf_config
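This override is what lets the same checkpoint load as the draft: a `deepseek_v3` config is relabeled as `deepseek_mtp`, `n_predict` is copied from `num_nextn_predict_layers`, and the architecture is pointed at `DeepSeekMTPModel`. A rough before/after illustration, assuming `transformers` and `vllm` are importable; the field values are made up:

```python
# Rough illustration of the override's effect (made-up field values).
from transformers import PretrainedConfig
from vllm.config import SpeculativeConfig

cfg = PretrainedConfig(model_type="deepseek_v3", num_nextn_predict_layers=1)
cfg = SpeculativeConfig.hf_config_override(cfg)

assert cfg.model_type == "deepseek_mtp"
assert cfg.n_predict == 1
assert cfg.architectures == ["DeepSeekMTPModel"]
```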

@staticmethod
def maybe_create_spec_config(
target_model_config: ModelConfig,
@@ -1746,12 +1765,16 @@ def maybe_create_spec_config(
Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if
the necessary conditions are met, else None.
"""

[CI annotation, GitHub Actions / pre-commit: Ruff E501 at vllm/config.py:1767:81, line too long (82 > 80)]

if speculative_model is None:
if num_speculative_tokens is not None:
raise ValueError("num_speculative_tokens was provided without "
if target_model_config.hf_text_config.model_type == "deepseek_v3":
# use the draft model from the same model:
speculative_model = target_model_config.model
else:
raise ValueError("num_speculative_tokens was provided without "
"speculative_model.")
return None
else:
return None

if (speculative_disable_by_batch_size is not None
and speculative_disable_by_batch_size < 2):
@@ -1805,14 +1828,14 @@ def maybe_create_spec_config(
max_seq_len_to_capture=target_model_config.
max_seq_len_to_capture,
max_logprobs=target_model_config.max_logprobs,
hf_overrides=SpeculativeConfig.hf_config_override,
)

draft_hf_config = draft_model_config.hf_config

if (num_speculative_tokens is not None
and hasattr(draft_hf_config, "num_lookahead_tokens")):
draft_hf_config.num_lookahead_tokens = num_speculative_tokens

n_predict = getattr(draft_hf_config, "n_predict", None)
if n_predict is not None:
if num_speculative_tokens is None:
@@ -1922,11 +1945,12 @@ def _verify_and_get_draft_model_tensor_parallel_size(
# If speculative_draft_tensor_parallel_size is unset then set it
# appropriately else verify that it is set correctly.
if speculative_draft_tensor_parallel_size is None:
if draft_hf_config.model_type == "mlp_speculator":
if draft_hf_config.model_type in ("mlp_speculator", "deepseek_mtp"):
speculative_draft_tensor_parallel_size = 1
if target_parallel_config.tensor_parallel_size > 1:
logger.warning(
"MLPSpeculator cannot currently be run with tp>1; "
f"{draft_hf_config.model_type} cannot currently "
"be run with tp>1; "

[CI annotation, GitHub Actions / pre-commit: Ruff G004 at vllm/config.py:1951:25, logging statement uses f-string]
"setting speculative_draft_tensor_parallel_size=1")
else:
speculative_draft_tensor_parallel_size = \
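As with MLPSpeculator, the DeepSeek MTP draft is pinned to tensor parallelism of 1 when no draft TP is requested, even if the target runs with tp>1. A standalone sketch of that rule (hypothetical helper, not the vLLM function; the real code logs a warning instead of printing):

```python
# Hypothetical helper illustrating the draft-TP rule above.
from typing import Optional

def resolve_draft_tp(draft_model_type: str,
                     target_tp: int,
                     requested_draft_tp: Optional[int] = None) -> int:
    if requested_draft_tp is not None:
        return requested_draft_tp  # validated elsewhere in the real code
    if draft_model_type in ("mlp_speculator", "deepseek_mtp"):
        if target_tp > 1:
            print("%s cannot currently be run with tp>1; "
                  "setting speculative_draft_tensor_parallel_size=1" % draft_model_type)
        return 1
    return target_tp

assert resolve_draft_tp("deepseek_mtp", target_tp=8) == 1
assert resolve_draft_tp("deepseek_mtp", target_tp=8, requested_draft_tp=8) == 8
```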
31 changes: 16 additions & 15 deletions vllm/model_executor/models/deepseek_mtp.py
@@ -58,8 +58,8 @@ def __init__(

self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.eh_proj = nn.Linear(config.model.hidden_size * 2,
config.model.hidden_size,
self.eh_proj = nn.Linear(config.hidden_size * 2,
config.hidden_size,
bias=False)
self.shared_head = SharedHead(config=config, quant_config=quant_config)
self.block = DeepseekV3DecoderLayer(config, prefix, model_config,
@@ -73,11 +73,13 @@ def forward(
attn_metadata: AttentionMetadata,
previous_hidden_states: torch.Tensor,
inputs_embeds: Optional[torch.Tensor] = None,
spec_step_index: int = 0,
) -> torch.Tensor:
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
assert inputs_embeds is not None
inputs_embeds[positions == 0] = 0 # masking inputs at position=0
inputs_embeds[positions <= spec_step_index] = 0
# masking inputs at position<=k, token from k+1
inputs_embeds = self.enorm(inputs_embeds)
previous_hidden_states = self.hnorm(previous_hidden_states)

@@ -123,24 +125,25 @@ def forward(
attn_metadata: AttentionMetadata,
previous_hidden_states: torch.Tensor,
inputs_embeds: Optional[torch.Tensor] = None,
step_idx: int = 0,
spec_step_idx: int = 0,
) -> torch.Tensor:
return self.layers[str(self.mtp_start_layer_idx + step_idx)](
return self.layers[str(self.mtp_start_layer_idx + spec_step_idx)](
input_ids,
positions,
kv_caches[step_idx],
kv_caches[spec_step_idx],
attn_metadata,
previous_hidden_states,
inputs_embeds,
spec_step_idx,
)

def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
step_idx: int = 0,
spec_step_idx: int = 0,
) -> torch.Tensor:
mtp_layer = self.layers[str(self.mtp_start_layer_idx + step_idx)]
mtp_layer = self.layers[str(self.mtp_start_layer_idx + spec_step_idx)]
logits = self.logits_processor(mtp_layer.shared_head.head,
hidden_states, sampling_metadata)
return logits
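The predictor keeps its MTP layers keyed by absolute layer index, so speculative step k (now passed in as `spec_step_idx`) is served by layer `mtp_start_layer_idx + k` for both the forward pass and `compute_logits`. A toy illustration of that dispatch, with plain callables in place of nn.Modules and hypothetical names:

```python
# Toy illustration of per-step MTP layer dispatch (not the vLLM module).
class ToyMultiTokenPredictor:
    def __init__(self, num_hidden_layers: int, num_nextn_predict_layers: int):
        self.mtp_start_layer_idx = num_hidden_layers
        # Keyed by absolute layer index as strings, mirroring self.layers in the real model.
        self.layers = {
            str(num_hidden_layers + i): (lambda i=i: f"ran MTP layer {i}")
            for i in range(num_nextn_predict_layers)
        }

    def forward(self, spec_step_idx: int = 0) -> str:
        return self.layers[str(self.mtp_start_layer_idx + spec_step_idx)]()


predictor = ToyMultiTokenPredictor(num_hidden_layers=61, num_nextn_predict_layers=2)
assert predictor.forward(spec_step_idx=0) == "ran MTP layer 0"
assert predictor.forward(spec_step_idx=1) == "ran MTP layer 1"
```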
@@ -150,9 +153,7 @@ class DeepSeekMTP(nn.Module):

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
self.config = config
self.model_config = config.model
self.config = vllm_config.model_config.hf_config
self.model = DeepSeekMultiTokenPredictor(vllm_config=vllm_config,
prefix=maybe_prefix(
prefix, "model"))
@@ -168,21 +169,21 @@ def forward(
previous_hidden_states: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
step_idx: int = 0,
spec_step_idx: int = 0,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, previous_hidden_states,
inputs_embeds, step_idx)
inputs_embeds, spec_step_idx)
return hidden_states

def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
step_idx: int = 0,
spec_step_idx: int = 0,
) -> Optional[torch.Tensor]:
return self.model.compute_logits(hidden_states, sampling_metadata,
step_idx)
spec_step_idx)

def sample(
self,
2 changes: 2 additions & 0 deletions vllm/sequence.py
@@ -1304,6 +1304,8 @@ class ExecuteModelRequest(
previous_hidden_states: Optional[HiddenStates] = None
# The number of forward steps to run.
num_steps: int = 1
# The step index for spec model input.
spec_step_idx: int = 0
# Finished request ids since last step.
finished_requests_ids: List[str] = msgspec.field(default_factory=list)
# The last sampled token ids for multi step decoding.
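The new `spec_step_idx` field defaults to 0, so requests from non-MTP speculative paths behave exactly as before; only the multi-step worker below mutates it. A minimal stand-in for the shape of the extended request (a dataclass here purely for illustration; vLLM's request is an msgspec struct):

```python
# Dataclass stand-in for the extended request; vLLM uses an msgspec struct.
from dataclasses import dataclass, field
from typing import List

@dataclass
class ToyExecuteModelRequest:
    num_steps: int = 1            # number of forward steps to run
    spec_step_idx: int = 0        # which MTP layer the draft uses on this pass
    finished_requests_ids: List[str] = field(default_factory=list)

req = ToyExecuteModelRequest()
assert req.spec_step_idx == 0     # default keeps existing callers unchanged
```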
10 changes: 6 additions & 4 deletions vllm/spec_decode/draft_model_runner.py
@@ -175,6 +175,7 @@ def execute_model(
previous_hidden_states: Optional[torch.Tensor] = None,
intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1,
**kwargs,
) -> Optional[List[SamplerOutput]]:
"""Executes num_steps forward passes with advacement of input tensors
on the GPU. Look at supports_gpu_multi_step(..) for pre-conditions.
@@ -271,16 +272,17 @@ def execute_model(
for step in range(num_steps):
multi_modal_kwargs = model_input.multi_modal_kwargs or {}

kwargs = {"previous_hidden_states": hidden_states} \
model_execute_kwargs = {"previous_hidden_states": hidden_states} \
if previous_hidden_states is not None else {}

compute_logits_kwargs = {}
# Run model
if hasattr(self.model.config, "num_nextn_predict_layers"):
# for DeepSeek MTP only to use the corresponding layer for
# each step
kwargs["step_idx"] = step
compute_logits_kwargs["step_idx"] = step
spec_step_idx = kwargs.get("spec_step_idx", 0)
model_execute_kwargs["spec_step_idx"] = spec_step_idx
compute_logits_kwargs["spec_step_idx"] = spec_step_idx
with set_forward_context(model_input.attn_metadata,
self.vllm_config):
hidden_states = model_executable(
@@ -291,7 +293,7 @@ def execute_model(
intermediate_tensors=intermediate_tensors,
**MultiModalKwargs.as_kwargs(multi_modal_kwargs,
device=self.device),
**kwargs,
**model_execute_kwargs,
)

# Compute the logits.
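The notable part of this hunk is that the per-step index no longer comes from the local loop counter: the runner reads `spec_step_idx` from the kwargs handed in by its caller and forwards it, under separate dicts, to both the model call and `compute_logits`. A stripped-down sketch of that plumbing (the model and logits call are hypothetical stand-ins):

```python
# Stripped-down sketch of the kwarg plumbing above; model/compute_logits are hypothetical callables.
def run_draft_steps(model, compute_logits, hidden_states, num_steps: int,
                    is_deepseek_mtp: bool, **kwargs):
    outputs = []
    for step in range(num_steps):
        model_execute_kwargs = {"previous_hidden_states": hidden_states}
        compute_logits_kwargs = {}
        if is_deepseek_mtp:
            # The step index comes from the caller, not the local loop counter.
            spec_step_idx = kwargs.get("spec_step_idx", 0)
            model_execute_kwargs["spec_step_idx"] = spec_step_idx
            compute_logits_kwargs["spec_step_idx"] = spec_step_idx
        hidden_states = model(**model_execute_kwargs)
        outputs.append(compute_logits(hidden_states, **compute_logits_kwargs))
    return outputs
```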
6 changes: 4 additions & 2 deletions vllm/spec_decode/multi_step_worker.py
@@ -2,7 +2,7 @@

import copy
import weakref
from typing import Dict, List, Set, Tuple
from typing import Dict, List, Set, Tuple, Optional

[CI annotation, GitHub Actions / pre-commit: Ruff F401 at vllm/spec_decode/multi_step_worker.py:5:44, `typing.Optional` imported but unused]

import torch

@@ -95,9 +95,10 @@ def sampler_output(
# TODO: Remove this branch once DraftModelRunner supports TP>1
# and other restrictions that are part of DraftModelRunner's
# supports_gpu_multi_step(..)
for _ in range(sample_len):
for i in range(sample_len):
model_output: List[SamplerOutput] = self.worker.execute_model(
execute_model_req=expanded_request)
expanded_request.spec_step_idx += 1
assert (len(model_output) == 1
), "composing multistep workers not supported"
model_output = model_output[0]
Expand All @@ -106,6 +107,7 @@ def sampler_output(
model_output, expanded_request.seq_group_metadata_list,
indices_of_seq_with_bonus_tokens)
model_outputs.append(model_output)
expanded_request.spec_step_idx = 0

# move indices to device to avoid stream sync
indices_of_seq_with_bonus_tokens = torch.tensor(
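This loop is the multi-step fix named in the commit title: on the non-GPU-multi-step path, each draft pass bumps `spec_step_idx` on the expanded request so pass k runs MTP layer k, and the counter is reset to 0 once all `sample_len` passes are done. A toy sketch of that bookkeeping, using hypothetical request/worker stand-ins:

```python
# Toy sketch of the per-pass step-index bookkeeping (hypothetical stand-ins).
class ToyRequest:
    def __init__(self):
        self.spec_step_idx = 0

def run_draft_passes(execute_model, request: ToyRequest, sample_len: int):
    outputs = []
    for _ in range(sample_len):
        outputs.append(execute_model(request))  # sees spec_step_idx = 0, 1, 2, ...
        request.spec_step_idx += 1              # next pass selects the next MTP layer
    request.spec_step_idx = 0                   # reset for the next speculation round
    return outputs

seen = run_draft_passes(lambda r: r.spec_step_idx, ToyRequest(), sample_len=3)
assert seen == [0, 1, 2]
```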
(Diffs for the remaining 4 changed files were not loaded in this view.)
