diff --git a/tests/spec_decode/e2e/test_mtp_correctness.py b/tests/spec_decode/e2e/test_mtp_correctness.py
index eec37a9858865..f0fca64fcba49 100644
--- a/tests/spec_decode/e2e/test_mtp_correctness.py
+++ b/tests/spec_decode/e2e/test_mtp_correctness.py
@@ -27,12 +27,10 @@
 
 # main model
 MAIN_MODEL = "luccafong/deepseek_mtp_main_random"
 
-# speculative model
-SPEC_MODEL = "luccafong/deepseek_mtp_draft_random"
-
 # max. number of speculative tokens: this corresponds to
-# num_heads in the config.json of the speculator model.
-MAX_SPEC_TOKENS = 3
+# num_nextn_predict_layers in the config.json of the speculator model.
+MAX_SPEC_TOKENS = 1
 
 # precision
 PRECISION = "bfloat16"
@@ -57,7 +55,6 @@
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
 @pytest.mark.parametrize("test_llm_kwargs", [
     {
-        "speculative_model": SPEC_MODEL,
         "num_speculative_tokens": MAX_SPEC_TOKENS,
     },
 ])
@@ -97,12 +94,10 @@ def test_mtp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
 @pytest.mark.parametrize("test_llm_kwargs", [
     {
-        "speculative_model": SPEC_MODEL,
         "num_speculative_tokens": MAX_SPEC_TOKENS,
         "disable_logprobs_during_spec_decoding": False,
     },
     {
-        "speculative_model": SPEC_MODEL,
         "num_speculative_tokens": MAX_SPEC_TOKENS,
         "disable_logprobs_during_spec_decoding": True,
     },
@@ -152,7 +147,6 @@ def test_mtp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
 @pytest.mark.parametrize("test_llm_kwargs", [
     {
-        "speculative_model": SPEC_MODEL,
         "num_speculative_tokens": MAX_SPEC_TOKENS,
     },
 ])
@@ -196,7 +190,6 @@ def test_mtp_e2e_greedy_correctness_cuda_graph(vllm_runner, common_llm_kwargs,
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
 @pytest.mark.parametrize("test_llm_kwargs", [
     {
-        "speculative_model": SPEC_MODEL,
         "num_speculative_tokens": MAX_SPEC_TOKENS,
     },
 ])
@@ -239,7 +232,6 @@ def test_mtp_e2e_greedy_correctness_with_preemption(
     "test_llm_kwargs",
     [
         {
-            "speculative_model": SPEC_MODEL,
             "num_speculative_tokens": k,
         }
         # Try a range of num. speculative tokens
@@ -282,7 +274,6 @@ def test_mtp_different_k(vllm_runner, common_llm_kwargs,
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
 @pytest.mark.parametrize("test_llm_kwargs", [{
-    "speculative_model": SPEC_MODEL,
     "num_speculative_tokens": MAX_SPEC_TOKENS,
     "speculative_disable_by_batch_size": 4
 }])
diff --git a/vllm/config.py b/vllm/config.py
index d12607d21f439..e18d5daf85b45 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -850,8 +850,12 @@ def get_num_attention_heads(self,
     def get_layers_start_end_indices(
             self, parallel_config: "ParallelConfig") -> Tuple[int, int]:
         from vllm.distributed.utils import get_pp_indices
-        total_num_hidden_layers = getattr(self.hf_text_config,
-                                          "num_hidden_layers", 0)
+        if self.hf_text_config.model_type == "deepseek_mtp":
+            total_num_hidden_layers = getattr(self.hf_text_config,
+                                              "num_nextn_predict_layers", 0)
+        else:
+            total_num_hidden_layers = getattr(self.hf_text_config,
+                                              "num_hidden_layers", 0)
         pp_rank = parallel_config.rank // parallel_config.tensor_parallel_size
         pp_size = parallel_config.pipeline_parallel_size
         start, end = get_pp_indices(total_num_hidden_layers, pp_rank, pp_size)
@@ -1664,6 +1668,21 @@ def compute_hash(self) -> str:
         hash_str = hashlib.md5(str(factors).encode()).hexdigest()
         return hash_str
 
+    @staticmethod
+    def hf_config_override(
+        hf_config: PretrainedConfig
+    ) -> PretrainedConfig:
+        if hf_config.model_type == "deepseek_v3":
+            hf_config.model_type = "deepseek_mtp"
+        if hf_config.model_type == "deepseek_mtp":
+            n_predict = getattr(
+                hf_config, "num_nextn_predict_layers", None)
+            hf_config.update({
+                "n_predict": n_predict,
+                "architectures": ["DeepSeekMTPModel"]
+            })
+        return hf_config
+
     @staticmethod
     def maybe_create_spec_config(
         target_model_config: ModelConfig,
@@ -1746,12 +1765,16 @@ def maybe_create_spec_config(
             Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if
                 the necessary conditions are met, else None.
         """
-
         if speculative_model is None:
             if num_speculative_tokens is not None:
-                raise ValueError("num_speculative_tokens was provided without "
+                if target_model_config.hf_text_config.model_type == "deepseek_v3":
+                    # use the draft model from the same model:
+                    speculative_model = target_model_config.model
+                else:
+                    raise ValueError("num_speculative_tokens was provided without "
                                  "speculative_model.")
-            return None
+            else:
+                return None
 
         if (speculative_disable_by_batch_size is not None
                 and speculative_disable_by_batch_size < 2):
@@ -1805,6 +1828,7 @@ def maybe_create_spec_config(
                 max_seq_len_to_capture=target_model_config.
                 max_seq_len_to_capture,
                 max_logprobs=target_model_config.max_logprobs,
+                hf_overrides=SpeculativeConfig.hf_config_override,
             )
 
             draft_hf_config = draft_model_config.hf_config
@@ -1812,7 +1836,6 @@ def maybe_create_spec_config(
             if (num_speculative_tokens is not None
                     and hasattr(draft_hf_config, "num_lookahead_tokens")):
                 draft_hf_config.num_lookahead_tokens = num_speculative_tokens
-
             n_predict = getattr(draft_hf_config, "n_predict", None)
             if n_predict is not None:
                 if num_speculative_tokens is None:
@@ -1922,11 +1945,12 @@ def _verify_and_get_draft_model_tensor_parallel_size(
         # If speculative_draft_tensor_parallel_size is unset then set it
         # appropriately else verify that it is set correctly.
         if speculative_draft_tensor_parallel_size is None:
-            if draft_hf_config.model_type == "mlp_speculator":
+            if draft_hf_config.model_type in ("mlp_speculator", "deepseek_mtp"):
                 speculative_draft_tensor_parallel_size = 1
                 if target_parallel_config.tensor_parallel_size > 1:
                     logger.warning(
-                        "MLPSpeculator cannot currently be run with tp>1; "
+                        f"{draft_hf_config.model_type} cannot currently "
+                        "be run with tp>1; "
                         "setting speculative_draft_tensor_parallel_size=1")
             else:
                 speculative_draft_tensor_parallel_size = \
diff --git a/vllm/model_executor/models/deepseek_mtp.py b/vllm/model_executor/models/deepseek_mtp.py
index 318032a774239..c862eb9bfb157 100644
--- a/vllm/model_executor/models/deepseek_mtp.py
+++ b/vllm/model_executor/models/deepseek_mtp.py
@@ -58,8 +58,8 @@ def __init__(
 
         self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.eh_proj = nn.Linear(config.model.hidden_size * 2,
-                                 config.model.hidden_size,
+        self.eh_proj = nn.Linear(config.hidden_size * 2,
+                                 config.hidden_size,
                                  bias=False)
         self.shared_head = SharedHead(config=config, quant_config=quant_config)
         self.block = DeepseekV3DecoderLayer(config, prefix, model_config,
@@ -73,11 +73,13 @@ def forward(
         attn_metadata: AttentionMetadata,
         previous_hidden_states: torch.Tensor,
         inputs_embeds: Optional[torch.Tensor] = None,
+        spec_step_index: int = 0,
     ) -> torch.Tensor:
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
         assert inputs_embeds is not None
-        inputs_embeds[positions == 0] = 0  # masking inputs at position=0
+        inputs_embeds[positions <= spec_step_index] = 0
+        # masking inputs at position<=k, token from k+1
         inputs_embeds = self.enorm(inputs_embeds)
         previous_hidden_states = self.hnorm(previous_hidden_states)
 
@@ -123,24 +125,25 @@ def forward(
         attn_metadata: AttentionMetadata,
         previous_hidden_states: torch.Tensor,
         inputs_embeds: Optional[torch.Tensor] = None,
-        step_idx: int = 0,
+        spec_step_idx: int = 0,
     ) -> torch.Tensor:
-        return self.layers[str(self.mtp_start_layer_idx + step_idx)](
+        return self.layers[str(self.mtp_start_layer_idx + spec_step_idx)](
             input_ids,
             positions,
-            kv_caches[step_idx],
+            kv_caches[spec_step_idx],
             attn_metadata,
             previous_hidden_states,
             inputs_embeds,
+            spec_step_idx,
         )
 
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
-        step_idx: int = 0,
+        spec_step_idx: int = 0,
     ) -> torch.Tensor:
-        mtp_layer = self.layers[str(self.mtp_start_layer_idx + step_idx)]
+        mtp_layer = self.layers[str(self.mtp_start_layer_idx + spec_step_idx)]
         logits = self.logits_processor(mtp_layer.shared_head.head,
                                        hidden_states, sampling_metadata)
         return logits
@@ -150,9 +153,7 @@ class DeepSeekMTP(nn.Module):
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
-        config = vllm_config.model_config.hf_config
-        self.config = config
-        self.model_config = config.model
+        self.config = vllm_config.model_config.hf_config
         self.model = DeepSeekMultiTokenPredictor(vllm_config=vllm_config,
                                                  prefix=maybe_prefix(
                                                      prefix, "model"))
@@ -168,21 +169,21 @@ def forward(
         previous_hidden_states: torch.Tensor,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
-        step_idx: int = 0,
+        spec_step_idx: int = 0,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, kv_caches,
                                    attn_metadata, previous_hidden_states,
-                                   inputs_embeds, step_idx)
+                                   inputs_embeds, spec_step_idx)
         return hidden_states
 
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
-        step_idx: int = 0,
+        spec_step_idx: int = 0,
     ) -> Optional[torch.Tensor]:
         return self.model.compute_logits(hidden_states, sampling_metadata,
-                                         step_idx)
+                                         spec_step_idx)
 
     def sample(
         self,
diff --git a/vllm/sequence.py b/vllm/sequence.py
index 534b9e60610a2..b60ee3304f842 100644
--- a/vllm/sequence.py
+++ b/vllm/sequence.py
@@ -1304,6 +1304,8 @@ class ExecuteModelRequest(
     previous_hidden_states: Optional[HiddenStates] = None
     # The number of forward steps to run.
    num_steps: int = 1
+    # The step index for spec model input.
+    spec_step_idx: int = 0
     # Finished request ids since last step.
     finished_requests_ids: List[str] = msgspec.field(default_factory=list)
     # The last sampled token ids for multi step decoding.
diff --git a/vllm/spec_decode/draft_model_runner.py b/vllm/spec_decode/draft_model_runner.py
index b57ca0cde01ac..09a5f77a3872d 100644
--- a/vllm/spec_decode/draft_model_runner.py
+++ b/vllm/spec_decode/draft_model_runner.py
@@ -175,6 +175,7 @@ def execute_model(
         previous_hidden_states: Optional[torch.Tensor] = None,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         num_steps: int = 1,
+        **kwargs,
     ) -> Optional[List[SamplerOutput]]:
         """Executes num_steps forward passes with advacement of input tensors
         on the GPU. Look at supports_gpu_multi_step(..) for pre-conditions.
@@ -271,7 +272,7 @@ def execute_model(
         for step in range(num_steps):
             multi_modal_kwargs = model_input.multi_modal_kwargs or {}
 
-            kwargs = {"previous_hidden_states": hidden_states} \
+            model_execute_kwargs = {"previous_hidden_states": hidden_states} \
                 if previous_hidden_states is not None else {}
 
             compute_logits_kwargs = {}
@@ -279,8 +280,9 @@ def execute_model(
             if hasattr(self.model.config, "num_nextn_predict_layers"):
                 # for DeepSeek MTP only to use the corresponding layer for
                 # each step
-                kwargs["step_idx"] = step
-                compute_logits_kwargs["step_idx"] = step
+                spec_step_idx = kwargs.get("spec_step_idx", 0)
+                model_execute_kwargs["spec_step_idx"] = spec_step_idx
+                compute_logits_kwargs["spec_step_idx"] = spec_step_idx
             with set_forward_context(model_input.attn_metadata,
                                      self.vllm_config):
                 hidden_states = model_executable(
@@ -291,7 +293,7 @@ def execute_model(
                     intermediate_tensors=intermediate_tensors,
                     **MultiModalKwargs.as_kwargs(multi_modal_kwargs,
                                                  device=self.device),
-                    **kwargs,
+                    **model_execute_kwargs,
                 )
 
             # Compute the logits.
diff --git a/vllm/spec_decode/multi_step_worker.py b/vllm/spec_decode/multi_step_worker.py
index 5474917a6fab7..55d3b20a284fb 100644
--- a/vllm/spec_decode/multi_step_worker.py
+++ b/vllm/spec_decode/multi_step_worker.py
@@ -2,7 +2,7 @@
 
 import copy
 import weakref
-from typing import Dict, List, Set, Tuple
+from typing import Dict, List, Set, Tuple, Optional
 
 import torch
 
@@ -95,9 +95,10 @@ def sampler_output(
             # TODO: Remove this branch once DraftModelRunner supports TP>1
             # and other restrictions that are part of DraftModelRunner's
             # supports_gpu_multi_step(..)
-            for _ in range(sample_len):
+            for i in range(sample_len):
                 model_output: List[SamplerOutput] = self.worker.execute_model(
                     execute_model_req=expanded_request)
+                expanded_request.spec_step_idx += 1
                 assert (len(model_output) == 1
                         ), "composing multistep workers not supported"
                 model_output = model_output[0]
@@ -106,6 +107,7 @@ def sampler_output(
                     model_output, expanded_request.seq_group_metadata_list,
                     indices_of_seq_with_bonus_tokens)
                 model_outputs.append(model_output)
+            expanded_request.spec_step_idx = 0
 
         # move indices to device to avoid stream sync
         indices_of_seq_with_bonus_tokens = torch.tensor(
diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py
index bf3aa8e40b0db..f7c9b33707f00 100644
--- a/vllm/spec_decode/spec_decode_worker.py
+++ b/vllm/spec_decode/spec_decode_worker.py
@@ -107,6 +107,7 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
         typical_acceptance_sampler_posterior_alpha,
         disable_logprobs=speculative_config.disable_logprobs,
         disable_log_stats=speculative_config.disable_log_stats,
+        num_speculative_tokens=speculative_config.num_speculative_tokens,
     )
 
     return spec_decode_worker
@@ -152,9 +153,11 @@ def create_worker(
         typical_acceptance_sampler_posterior_alpha: float,
         disable_logprobs: bool,
         disable_log_stats: bool,
+        num_speculative_tokens: int,
     ) -> "SpecDecodeWorker":
         allow_zero_draft_token_step = True
+        num_spec_prefill_steps = 1
         ngram_prompt_lookup_max = (
             draft_worker_kwargs.pop("ngram_prompt_lookup_max"))
         ngram_prompt_lookup_min = (
             draft_worker_kwargs.pop("ngram_prompt_lookup_min"))
@@ -190,6 +193,8 @@ def create_worker(
                 allow_zero_draft_token_step = False
             proposer_worker = MultiStepWorker(**draft_worker_kwargs)
+            if draft_model_config.hf_config.model_type == "deepseek_mtp":
+                num_spec_prefill_steps = num_speculative_tokens
 
         proposer_worker = SmallerTpProposerWorker.maybe_wrap_worker(
             proposer_worker, draft_tp, target_tp)
@@ -241,7 +246,9 @@ def create_worker(
             disable_log_stats=disable_log_stats,
             disable_by_batch_size=disable_by_batch_size,
             spec_decode_sampler=spec_decode_sampler,
-            allow_zero_draft_token_step=allow_zero_draft_token_step)
+            allow_zero_draft_token_step=allow_zero_draft_token_step,
+            num_spec_prefill_steps=num_spec_prefill_steps,
+        )
 
     def __init__(
         self,
@@ -254,6 +261,7 @@ def __init__(
         metrics_collector: Optional[AsyncMetricsCollector] = None,
         disable_by_batch_size: Optional[int] = None,
         allow_zero_draft_token_step: Optional[bool] = True,
+        num_spec_prefill_steps: int = 1,
     ):
         """
         Create a SpecDecodeWorker.
@@ -284,6 +292,10 @@ def __init__(
             allow_zero_draft_token_step: whether to allow a step where the
                 draft model generates no draft token; should disallow when
                 the tp of draft model is larger than 1 (TODO: #5814)
+            num_spec_prefill_steps: number of speculative prefill steps to run
+                before the speculative decoding starts. This is only used when
+                the draft model is a deepseek_mtp model that requires prefill
+                kv cache separately for each step layer.
         """
         self.proposer_worker = proposer_worker
         self.scorer_worker = scorer_worker
@@ -316,6 +328,7 @@ def __init__(
         self.previous_hidden_states: Optional[HiddenStates] = None
         self._disable_logprobs = disable_logprobs
         self._disable_log_stats = disable_log_stats
+        self._num_spec_prefill_steps = num_spec_prefill_steps
 
     def init_device(self) -> None:
         """Initialize both scorer and proposer models.
@@ -664,8 +677,10 @@ def _run_no_spec(self, execute_model_req: ExecuteModelRequest,
             execute_model_req.previous_hidden_states = \
                 prepare_prefill_hidden_states(
                     sampler_output.prefill_hidden_states)
-
-        self.proposer_worker.execute_model(execute_model_req)
+        execute_model_req.spec_step_idx = 0
+        for _ in range(self._num_spec_prefill_steps):
+            self.proposer_worker.execute_model(execute_model_req)
+            execute_model_req.spec_step_idx += 1
 
         sampler_output_to_return = (self._serialize_sampler_output_no_logprobs(
             execute_model_req=execute_model_req, sampler_output=sampler_output)
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index 12baecde6e42c..8f0d1bde8645f 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -1650,6 +1650,7 @@ def execute_model(
         kv_caches: List[torch.Tensor],
         intermediate_tensors: Optional[IntermediateTensors] = None,
         num_steps: int = 1,
+        **kwargs,
     ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:
         if num_steps > 1:
             raise ValueError("num_steps > 1 is not supported in ModelRunner")
diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py
index bd07608f788f0..ff38e3bfc207b 100644
--- a/vllm/worker/worker.py
+++ b/vllm/worker/worker.py
@@ -68,8 +68,8 @@ def __init__(
         speculative_config = self.speculative_config
         model_config = self.model_config
         speculative_args = {} if speculative_config is None \
-            or (speculative_config.draft_model_config.model ==
-                model_config.model) \
+            or (speculative_config.draft_model_config.hf_config.model_type ==
+                model_config.hf_config.model_type) \
             or (speculative_config.draft_model_config.hf_config.model_type
                 not in ("medusa", "mlp_speculator", "eagle", "deepseek_mtp")) \
             else {"return_hidden_states": True}
diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py
index 819b81fbfdbb2..89b4652f201f0 100644
--- a/vllm/worker/worker_base.py
+++ b/vllm/worker/worker_base.py
@@ -392,6 +392,8 @@ def execute_model(
         model_input, worker_input, kwargs = inputs
         num_steps = worker_input.num_steps
+        if execute_model_req is not None:
+            kwargs["spec_step_idx"] = execute_model_req.spec_step_idx
 
         self.execute_worker(worker_input)
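
Note: the sketch below is not part of the patch. It is a simplified, self-contained illustration of the control flow the diff introduces: SpecDecodeWorker._run_no_spec runs one prefill pass per MTP layer (num_spec_prefill_steps), bumping spec_step_idx on the ExecuteModelRequest; MultiStepWorker does the same per draft step; worker_base forwards the index as a kwarg; and the draft runner routes each step to layers[str(mtp_start_layer_idx + spec_step_idx)]. Apart from spec_step_idx, num_spec_prefill_steps, mtp_start_layer_idx, and num_nextn_predict_layers, all names here (ToyRequest, ToyMTPStack, run_spec_prefill) are hypothetical stand-ins, not vLLM APIs.

# Toy illustration (plain Python, no vLLM imports): how a per-request
# spec_step_idx selects one MTP sub-layer per speculative step.
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class ToyRequest:
    # Mirrors the new ExecuteModelRequest.spec_step_idx field in the diff.
    prompt_token_ids: List[int]
    spec_step_idx: int = 0


class ToyMTPStack:
    """Stand-in for DeepSeekMultiTokenPredictor: one sub-layer per MTP step."""

    def __init__(self, num_nextn_predict_layers: int,
                 mtp_start_layer_idx: int):
        self.mtp_start_layer_idx = mtp_start_layer_idx
        # Keyed by absolute layer index, like the real module's `layers` dict.
        self.layers: Dict[str, str] = {
            str(mtp_start_layer_idx + i): f"mtp_layer_{i}"
            for i in range(num_nextn_predict_layers)
        }

    def forward(self, req: ToyRequest) -> str:
        # Same indexing scheme as the diff: start index + spec_step_idx.
        return self.layers[str(self.mtp_start_layer_idx + req.spec_step_idx)]


def run_spec_prefill(model: ToyMTPStack, req: ToyRequest,
                     num_spec_prefill_steps: int) -> List[str]:
    """Mirrors the _run_no_spec loop: one prefill pass per MTP layer, bumping
    spec_step_idx after each pass so each step layer fills its own KV cache,
    then resetting the index, as multi_step_worker does after drafting."""
    used_layers = []
    req.spec_step_idx = 0
    for _ in range(num_spec_prefill_steps):
        used_layers.append(model.forward(req))
        req.spec_step_idx += 1
    req.spec_step_idx = 0
    return used_layers


if __name__ == "__main__":
    # 61 is only an illustrative start index for the MTP layers.
    stack = ToyMTPStack(num_nextn_predict_layers=2, mtp_start_layer_idx=61)
    req = ToyRequest(prompt_token_ids=[1, 2, 3])
    print(run_spec_prefill(stack, req, num_spec_prefill_steps=2))
    # -> ['mtp_layer_0', 'mtp_layer_1']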