diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 9991060a31621..3918e3e867695 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -2,7 +2,7 @@ # adding a new command to an existing step. See different options here for examples. # This script will be feed into Jinja template in `test-template-aws.j2` at -# https://github.com/vllm-project/buildkite-ci/blob/main/scripts/test-template-aws.j2 +# https://github.com/vllm-project/buildkite-ci/blob/main/scripts/test-template-aws.j2 # to generate the final pipeline yaml file. # Documentation @@ -15,7 +15,7 @@ # mirror_hardwares(list): the list of hardwares to run the test on as well. currently only supports [amd] # gpu(str): override the GPU selection for the test. default is on L4 GPUs. currently only supports a100 # num_gpus(int): override the number of GPUs for the test. default to 1 GPU. currently support 2,4. -# num_nodes(int): whether to simulate multi-node setup by launch multiple containers on one host, +# num_nodes(int): whether to simulate multi-node setup by launch multiple containers on one host, # in this case, commands must be specified. the first command runs on first host, the second # command runs on the second host. # working_dir(str): specify the place where command should execute, default to /vllm-workspace/tests @@ -24,8 +24,8 @@ # When adding a test # - If the test belong to an existing group, add it there # - If the test is short, add to any existing step -# - If the test takes more than 10min, then it is okay to create a new step. -# Note that all steps execute in parallel. +# - If the test takes more than 10min, then it is okay to create a new step. +# Note that all steps execute in parallel. steps: ##### fast check tests ##### @@ -145,14 +145,14 @@ steps: - RAY_DEDUP_LOGS=0 python3 ../examples/offline_inference/rlhf_colocate.py - label: Metrics, Tracing Test # 10min - num_gpus: 2 + num_gpus: 2 fast_check: true source_file_dependencies: - vllm/ - tests/metrics - tests/tracing commands: - - pytest -v -s metrics + - pytest -v -s metrics - "pip install \ 'opentelemetry-sdk>=1.26.0,<1.27.0' \ 'opentelemetry-api>=1.26.0,<1.27.0' \ @@ -254,7 +254,7 @@ steps: - vllm/model_executor/guided_decoding - tests/test_logits_processor - tests/model_executor/test_guided_processors - commands: + commands: - pytest -v -s test_logits_processor.py - pytest -v -s model_executor/test_guided_processors.py @@ -265,7 +265,7 @@ steps: - vllm/model_executor/models/eagle.py commands: - pytest -v -s spec_decode/e2e/test_multistep_correctness.py - - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s spec_decode --ignore=spec_decode/e2e/test_multistep_correctness.py + - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s spec_decode --ignore=spec_decode/e2e/test_multistep_correctness.py --ignore=spec_decode/e2e/test_mtp_correctness.py - pytest -v -s spec_decode/e2e/test_eagle_correctness.py - label: LoRA Test %N # 15min each @@ -580,7 +580,7 @@ steps: - export VLLM_WORKER_MULTIPROC_METHOD=spawn # This test runs llama 13B, so it is required to run on 4 GPUs. - pytest -v -s -x lora/test_long_context.py - # There is some Tensor Parallelism related processing logic in LoRA that + # There is some Tensor Parallelism related processing logic in LoRA that # requires multi-GPU testing for validation. 
- pytest -v -s -x lora/test_chatglm3_tp.py - pytest -v -s -x lora/test_llama_tp.py @@ -605,7 +605,7 @@ steps: - vllm/ - tests/weight_loading commands: - - bash weight_loading/run_model_weight_loading_test.sh -c weight_loading/models-large.txt + - bash weight_loading/run_model_weight_loading_test.sh -c weight_loading/models-large.txt ##### multi gpus test ##### @@ -617,7 +617,7 @@ steps: num_gpus: 4 source_file_dependencies: - vllm/ - commands: + commands: # NOTE: don't test llama model here, it seems hf implementation is buggy # see https://github.com/vllm-project/vllm/pull/5689 for details - pytest -v -s distributed/test_custom_all_reduce.py diff --git a/tests/models/registry.py b/tests/models/registry.py index 17bfe1d21e4ad..04148ac64508e 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -295,6 +295,9 @@ def check_available_online( speculative_model="abhigoyal/vllm-medusa-llama-68m-random"), # noqa: E501 "MLPSpeculatorPreTrainedModel": _HfExamplesInfo("JackFram/llama-160m", speculative_model="ibm-ai-platform/llama-160m-accelerator"), # noqa: E501 + "DeepSeekMTPModel": _HfExamplesInfo("luccafong/deepseek_mtp_main_random", + speculative_model="luccafong/deepseek_mtp_draft_random", # noqa: E501 + trust_remote_code=True), } _FALLBACK_MODEL = { diff --git a/tests/spec_decode/e2e/test_mtp_correctness.py b/tests/spec_decode/e2e/test_mtp_correctness.py new file mode 100644 index 0000000000000..0bad19f61d305 --- /dev/null +++ b/tests/spec_decode/e2e/test_mtp_correctness.py @@ -0,0 +1,318 @@ +# SPDX-License-Identifier: Apache-2.0 +"""This docstring details important information on the testing methodology. + +Most of the tests rely on "greedy equality", where we expect the output of +speculative decoding on a sequence to exactly match the output of normal non- +speculative decoding. + +Since speculative decoding with rejection sampling guarantees that the output +distribution matches the target model's output distribution (up to hardware +numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy +equality. + +However, we still need to verify below scenario could be passed: + * Batch size 1 greedy equality + * Batch size >1 greedy equality + * Test greedy equality under preemption + * Test greedy equality under various number of speculative tokens. + +With those tests, we can say at least, mtp would not break the +correctess for the target model outputs. +""" + +import pytest + +from .conftest import run_equality_correctness_test + +# main model +MAIN_MODEL = "luccafong/deepseek_mtp_main_random" + +# max. number of speculative tokens: this corresponds to +# num_nextn_predict_layers in the config.json of the speculator model. +MAX_SPEC_TOKENS = 1 + +# precision +PRECISION = "bfloat16" + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Print spec metrics. 
+ "disable_log_stats": False, + + # Precision + "dtype": PRECISION, + + # Main model + "model_name": MAIN_MODEL, + + # GPU memory utilization + "gpu_memory_utilization": 0.85 + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "num_speculative_tokens": MAX_SPEC_TOKENS, + }, +]) +@pytest.mark.parametrize("output_len", [ + 128, +]) +@pytest.mark.parametrize("batch_size", [1, 32]) +@pytest.mark.parametrize("seed", [1]) +def test_mtp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, test_llm_kwargs, + batch_size: int, output_len: int, + seed: int): + + run_equality_correctness_test(vllm_runner, common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, test_llm_kwargs, + batch_size, output_len, seed) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Print spec metrics. + "disable_log_stats": False, + + # Precision + "dtype": PRECISION, + + # Main model + "model_name": MAIN_MODEL, + + # GPU memory utilization + "gpu_memory_utilization": 0.85 + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "num_speculative_tokens": MAX_SPEC_TOKENS, + "disable_logprobs_during_spec_decoding": False, + }, + { + "num_speculative_tokens": MAX_SPEC_TOKENS, + "disable_logprobs_during_spec_decoding": True, + }, +]) +@pytest.mark.parametrize("output_len", [ + 128, +]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("seed", [1]) +@pytest.mark.parametrize("logprobs", [1, 6]) +def test_mtp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, test_llm_kwargs, + batch_size: int, output_len: int, seed: int, + logprobs: int): + + run_equality_correctness_test(vllm_runner, + common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, + test_llm_kwargs, + batch_size, + output_len, + seed, + logprobs=logprobs, + prompt_logprobs=logprobs, + disable_logprobs=test_llm_kwargs[ + 'disable_logprobs_during_spec_decoding']) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "enforce_eager": False, + + # Print spec metrics. + "disable_log_stats": False, + + # Precision + "dtype": PRECISION, + + # Main model + "model_name": MAIN_MODEL, + "gpu_memory_utilization": 0.85 + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "num_speculative_tokens": MAX_SPEC_TOKENS, + }, +]) +@pytest.mark.parametrize("output_len", [ + 128, +]) +@pytest.mark.parametrize("batch_size", [1, 32]) +@pytest.mark.parametrize("seed", [1]) +def test_mtp_e2e_greedy_correctness_cuda_graph(vllm_runner, common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, + test_llm_kwargs, + batch_size: int, + output_len: int, seed: int): + """Verify greedy equality with cuda graph enabled and different + batch sizes.""" + run_equality_correctness_test(vllm_runner, common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, test_llm_kwargs, + batch_size, output_len, seed) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "block_size": 8, + # 2 for small prompt, 256//8 for generated. 
+ "num_gpu_blocks_override": 2 + 256 // 8, + "max_model_len": (2 + 256 // 8) * 8, + + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Precision + "dtype": PRECISION, + + # Main model + "model_name": MAIN_MODEL, + + # GPU memory utilization + "gpu_memory_utilization": 0.9 + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "num_speculative_tokens": MAX_SPEC_TOKENS, + }, +]) +@pytest.mark.parametrize( + "output_len", + [ + # Use small output len for fast test. + 128, + ]) +@pytest.mark.parametrize("batch_size", [4]) +@pytest.mark.parametrize("seed", [1]) +def test_mtp_e2e_greedy_correctness_with_preemption( + vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, + baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int, + seed: int): + """Verify greedy equality, even when some sequences are preempted mid- + generation. + """ + run_equality_correctness_test(vllm_runner, common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, test_llm_kwargs, + batch_size, output_len, seed) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Precision + "dtype": PRECISION, + + # Main model + "model_name": MAIN_MODEL, + + # GPU memory utilization + "gpu_memory_utilization": 0.9 + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize( + "test_llm_kwargs", + [ + { + "num_speculative_tokens": k, + } + # Try a range of num. speculative tokens + for k in range(1, 1 + MAX_SPEC_TOKENS) + ]) +@pytest.mark.parametrize("batch_size", [2]) +@pytest.mark.parametrize( + "output_len", + [ + # Use smaller output len for fast test. + 32, + ]) +@pytest.mark.parametrize("seed", [1]) +def test_mtp_different_k(vllm_runner, common_llm_kwargs, + per_test_common_llm_kwargs, baseline_llm_kwargs, + test_llm_kwargs, batch_size: int, output_len: int, + seed: int): + """Verify that mtp speculative decoding produces exact equality + to without spec decode with different values of num_speculative_tokens. + """ + run_equality_correctness_test(vllm_runner, common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, test_llm_kwargs, + batch_size, output_len, seed) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Precision + "dtype": PRECISION, + + # Main model + "model_name": MAIN_MODEL, + + # GPU memory utilization + "gpu_memory_utilization": 0.9 + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", + [{ + "num_speculative_tokens": MAX_SPEC_TOKENS, + "speculative_disable_by_batch_size": 4 + }]) +@pytest.mark.parametrize("batch_size", [1, 5]) +@pytest.mark.parametrize( + "output_len", + [ + # Use smaller output len for fast test. + 32, + ]) +@pytest.mark.parametrize("seed", [1]) +def test_mtp_disable_queue(vllm_runner, common_llm_kwargs, + per_test_common_llm_kwargs, baseline_llm_kwargs, + test_llm_kwargs, batch_size: int, output_len: int, + seed: int): + """Verify that mtp speculative decoding produces exact equality + to without spec decode when speculation is disabled for large + batch sizes. 
+ """ + run_equality_correctness_test(vllm_runner, common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, test_llm_kwargs, + batch_size, output_len, seed) + + +if __name__ == "__main__": + import pytest + pytest.main([__file__]) diff --git a/vllm/config.py b/vllm/config.py index 5c220ed136301..16c39baa2b841 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -762,7 +762,7 @@ def get_hidden_size(self) -> int: def is_deepseek_mla(self) -> bool: return (hasattr(self.hf_text_config, "model_type")) \ and (self.hf_text_config.model_type in \ - ('deepseek_v2', 'deepseek_v3'))\ + ('deepseek_v2', 'deepseek_v3', 'deepseek_mtp'))\ and (self.hf_text_config.kv_lora_rank is not None) def get_head_size(self) -> int: @@ -855,8 +855,12 @@ def get_num_attention_heads(self, def get_layers_start_end_indices( self, parallel_config: "ParallelConfig") -> Tuple[int, int]: from vllm.distributed.utils import get_pp_indices - total_num_hidden_layers = getattr(self.hf_text_config, - "num_hidden_layers", 0) + if self.hf_text_config.model_type == "deepseek_mtp": + total_num_hidden_layers = getattr(self.hf_text_config, + "num_nextn_predict_layers", 0) + else: + total_num_hidden_layers = getattr(self.hf_text_config, + "num_hidden_layers", 0) pp_rank = parallel_config.rank // parallel_config.tensor_parallel_size pp_size = parallel_config.pipeline_parallel_size start, end = get_pp_indices(total_num_hidden_layers, pp_rank, pp_size) @@ -1688,6 +1692,18 @@ def compute_hash(self) -> str: hash_str = hashlib.md5(str(factors).encode()).hexdigest() return hash_str + @staticmethod + def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig: + if hf_config.model_type == "deepseek_v3": + hf_config.model_type = "deepseek_mtp" + if hf_config.model_type == "deepseek_mtp": + n_predict = getattr(hf_config, "num_nextn_predict_layers", None) + hf_config.update({ + "n_predict": n_predict, + "architectures": ["DeepSeekMTPModel"] + }) + return hf_config + @staticmethod def maybe_create_spec_config( target_model_config: ModelConfig, @@ -1770,12 +1786,18 @@ def maybe_create_spec_config( Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if the necessary conditions are met, else None. """ - if speculative_model is None: if num_speculative_tokens is not None: - raise ValueError("num_speculative_tokens was provided without " - "speculative_model.") - return None + if target_model_config.hf_text_config.model_type \ + == "deepseek_v3": + # use the draft model from the same model: + speculative_model = target_model_config.model + else: + raise ValueError( + "num_speculative_tokens was provided without " + "speculative_model.") + else: + return None if (speculative_disable_by_batch_size is not None and speculative_disable_by_batch_size < 2): @@ -1829,6 +1851,7 @@ def maybe_create_spec_config( max_seq_len_to_capture=target_model_config. 
max_seq_len_to_capture, max_logprobs=target_model_config.max_logprobs, + hf_overrides=SpeculativeConfig.hf_config_override, ) draft_hf_config = draft_model_config.hf_config @@ -1845,7 +1868,6 @@ def maybe_create_spec_config( if (num_speculative_tokens is not None and hasattr(draft_hf_config, "num_lookahead_tokens")): draft_hf_config.num_lookahead_tokens = num_speculative_tokens - n_predict = getattr(draft_hf_config, "n_predict", None) if n_predict is not None: if num_speculative_tokens is None: @@ -1959,8 +1981,9 @@ def _verify_and_get_draft_model_tensor_parallel_size( speculative_draft_tensor_parallel_size = 1 if target_parallel_config.tensor_parallel_size > 1: logger.warning( - "MLPSpeculator cannot currently be run with tp>1; " - "setting speculative_draft_tensor_parallel_size=1") + "%s cannot currently be run with tp>1; " + "setting speculative_draft_tensor_parallel_size=1", + draft_hf_config.model_type) else: speculative_draft_tensor_parallel_size = \ target_parallel_config.tensor_parallel_size diff --git a/vllm/model_executor/models/deepseek_mtp.py b/vllm/model_executor/models/deepseek_mtp.py new file mode 100644 index 0000000000000..1a051992a3065 --- /dev/null +++ b/vllm/model_executor/models/deepseek_mtp.py @@ -0,0 +1,284 @@ +# SPDX-License-Identifier: Apache-2.0 +from typing import Iterable, List, Optional, Set, Tuple + +import torch +import torch.nn as nn +from transformers import PretrainedConfig + +from vllm.attention.backends.abstract import AttentionMetadata +from vllm.config import CacheConfig, ModelConfig, VllmConfig +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .deepseek_v2 import (DeepseekV2DecoderLayer, + get_spec_layer_idx_from_weight_name) +from .utils import maybe_prefix + + +class SharedHead(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return self.norm(hidden_states) + + +class DeepSeekMultiTokenPredictorLayer(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + prefix: str, + model_config: ModelConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + ) + + self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.eh_proj = nn.Linear(config.hidden_size * 2, + config.hidden_size, + bias=False) + self.shared_head = SharedHead(config=config, quant_config=quant_config) + self.mtp_block = DeepseekV2DecoderLayer(config, prefix, model_config, + cache_config, quant_config) + + 
def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + previous_hidden_states: torch.Tensor, + inputs_embeds: Optional[torch.Tensor] = None, + spec_step_index: int = 0, + ) -> torch.Tensor: + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + assert inputs_embeds is not None + # masking inputs at position 0, as not needed by MTP + inputs_embeds[positions == 0] = 0 + inputs_embeds = self.enorm(inputs_embeds) + previous_hidden_states = self.hnorm(previous_hidden_states) + + hidden_states = self.eh_proj( + torch.cat([inputs_embeds, previous_hidden_states], dim=-1)) + + hidden_states, residual = self.mtp_block(positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + residual=None) + hidden_states = residual + hidden_states + return self.shared_head(hidden_states) + + +class DeepSeekMultiTokenPredictor(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + self.mtp_start_layer_idx = config.num_hidden_layers + self.num_mtp_layers = config.num_nextn_predict_layers + # to map the exact layer index from weights + self.layers = torch.nn.ModuleDict({ + str(idx): + DeepSeekMultiTokenPredictorLayer( + config, + f"{prefix}.layers.{idx}", + model_config=vllm_config.model_config, + cache_config=vllm_config.cache_config, + quant_config=vllm_config.quant_config, + ) + for idx in range(self.mtp_start_layer_idx, + self.mtp_start_layer_idx + self.num_mtp_layers) + }) + + self.logits_processor = LogitsProcessor(config.vocab_size) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + previous_hidden_states: torch.Tensor, + inputs_embeds: Optional[torch.Tensor] = None, + spec_step_idx: int = 0, + ) -> torch.Tensor: + return self.layers[str(self.mtp_start_layer_idx + spec_step_idx)]( + input_ids, + positions, + kv_caches[spec_step_idx], + attn_metadata, + previous_hidden_states, + inputs_embeds, + spec_step_idx, + ) + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + spec_step_idx: int = 0, + ) -> torch.Tensor: + mtp_layer = self.layers[str(self.mtp_start_layer_idx + spec_step_idx)] + logits = self.logits_processor(mtp_layer.shared_head.head, + hidden_states, sampling_metadata) + return logits + + +class DeepSeekMTP(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + self.config = vllm_config.model_config.hf_config + self.model = DeepSeekMultiTokenPredictor(vllm_config=vllm_config, + prefix=maybe_prefix( + prefix, "model")) + + self.sampler = get_sampler() + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + previous_hidden_states: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + spec_step_idx: int = 0, + ) -> torch.Tensor: + hidden_states = self.model(input_ids, positions, kv_caches, + attn_metadata, previous_hidden_states, + inputs_embeds, spec_step_idx) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + spec_step_idx: int = 0, + ) -> Optional[torch.Tensor]: + return self.model.compute_logits(hidden_states, 
sampling_metadata, + spec_step_idx) + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.n_routed_experts) + + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + spec_layer = get_spec_layer_idx_from_weight_name(self.config, name) + if spec_layer is None: + continue + name = self._rewrite_spec_layer_name(spec_layer, name) + for (param_name, weight_name, shard_id) in stacked_params_mapping: + # Skip non-stacked layers and experts (experts handled below). + if weight_name not in name: + continue + # We have mlp.experts[0].gate_proj in the checkpoint. + # Since we handle the experts below in expert_params_mapping, + # we need to skip here BEFORE we update the name, otherwise + # name will be updated to mlp.experts[0].gate_up_proj, which + # will then be updated below in expert_params_mapping + # for mlp.experts[0].gate_gate_up_proj, which breaks load. + if (("mlp.experts." in name) and name not in params_dict): + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str: + """ + Rewrite the weight name to match the format of the original model. 
+ Add .mtp_block for modules in transformer layer block for spec layer + """ + spec_layer_weight_names = [ + "embed_tokens", "enorm", "hnorm", "eh_proj", "shared_head" + ] + spec_layer_weight = False + for weight_name in spec_layer_weight_names: + if weight_name in name: + spec_layer_weight = True + break + if not spec_layer_weight: + # treat rest weights as weights for transformer layer block + name = name.replace(f"model.layers.{spec_layer}.", + f"model.layers.{spec_layer}.mtp_block.") + return name diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index fd0e58fa1458d..a4d52c613b3e1 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -732,13 +732,9 @@ def load_weights(self, weights: Iterable[Tuple[str, if "rotary_emb.inv_freq" in name: continue - # TODO(simon): support nextn predict layers - if hasattr(self.config, "num_nextn_predict_layers" - ) and self.config.num_nextn_predict_layers > 0: - assert self.config.num_nextn_predict_layers == 1 - layer_idx = self.config.num_hidden_layers - if name.startswith(f"model.layers.{layer_idx}"): - continue + spec_layer = get_spec_layer_idx_from_weight_name(self.config, name) + if spec_layer is not None: + continue # skip spec decode layers for main model for (param_name, weight_name, shard_id) in stacked_params_mapping: # Skip non-stacked layers and experts (experts handled below). @@ -805,3 +801,15 @@ def load_weights(self, weights: Iterable[Tuple[str, class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM): pass + + +def get_spec_layer_idx_from_weight_name(config: PretrainedConfig, + weight_name: str) -> Optional[int]: + if hasattr(config, + "num_nextn_predict_layers") and (config.num_nextn_predict_layers + > 0): + layer_idx = config.num_hidden_layers + for i in range(config.num_nextn_predict_layers): + if weight_name.startswith(f"model.layers.{layer_idx+i}."): + return layer_idx + i + return None diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 775398e003cd2..81623defd3379 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -187,6 +187,7 @@ _SPECULATIVE_DECODING_MODELS = { "EAGLEModel": ("eagle", "EAGLE"), + "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"), "MedusaModel": ("medusa", "Medusa"), "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"), } diff --git a/vllm/sequence.py b/vllm/sequence.py index 45d0e5bc76804..c0425ba33c9af 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -1307,6 +1307,8 @@ class ExecuteModelRequest( previous_hidden_states: Optional[HiddenStates] = None # The number of forward steps to run. num_steps: int = 1 + # The step index for spec model input. + spec_step_idx: Optional[int] = None # Finished request ids since last step. finished_requests_ids: List[str] = msgspec.field(default_factory=list) # The last sampled token ids for multi step decoding. 
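The DeepSeekMultiTokenPredictorLayer introduced above implements one multi-token-prediction step: the current token is re-embedded, the embedding and the target model's previous hidden state are each normalized, concatenated and projected back to the hidden size by eh_proj, then passed through a single DeepseekV2 decoder block whose output feeds the shared LM head. The following is a minimal, self-contained sketch of that data flow, not the module defined in this diff: it uses plain torch modules (LayerNorm standing in for RMSNorm, a generic encoder layer standing in for DeepseekV2DecoderLayer) and illustrative names such as TinyMTPLayer.

```python
import torch
import torch.nn as nn


class TinyMTPLayer(nn.Module):
    """Illustrative stand-in for DeepSeekMultiTokenPredictorLayer."""

    def __init__(self, hidden_size: int, vocab_size: int, nhead: int = 4):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        # The real layer uses RMSNorm; LayerNorm keeps this sketch dependency-free.
        self.enorm = nn.LayerNorm(hidden_size)
        self.hnorm = nn.LayerNorm(hidden_size)
        self.eh_proj = nn.Linear(2 * hidden_size, hidden_size, bias=False)
        # Stand-in for the DeepseekV2DecoderLayer stored as `mtp_block`.
        self.mtp_block = nn.TransformerEncoderLayer(hidden_size, nhead,
                                                    batch_first=True)
        # Stand-in for the shared LM head (`shared_head.head`).
        self.head = nn.Linear(hidden_size, vocab_size, bias=False)

    def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
                previous_hidden_states: torch.Tensor) -> torch.Tensor:
        emb = self.embed_tokens(input_ids)
        # Mask inputs at position 0, mirroring the real layer (not needed by MTP).
        emb = emb * (positions != 0).unsqueeze(-1)
        # Fuse token embedding with the main model's hidden state, then run one block.
        fused = self.eh_proj(torch.cat(
            [self.enorm(emb), self.hnorm(previous_hidden_states)], dim=-1))
        hidden = self.mtp_block(fused)
        return self.head(hidden)  # logits for the speculated next token


# Usage: batch of 2 sequences, 5 tokens each, hidden size 64.
layer = TinyMTPLayer(hidden_size=64, vocab_size=1000)
ids = torch.randint(0, 1000, (2, 5))
pos = torch.arange(5).expand(2, 5)
prev = torch.randn(2, 5, 64)        # hidden states from the main model
logits = layer(ids, pos, prev)      # -> (2, 5, 1000)
```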
diff --git a/vllm/spec_decode/draft_model_runner.py b/vllm/spec_decode/draft_model_runner.py index 3948298db40c2..7353d3c53ae97 100644 --- a/vllm/spec_decode/draft_model_runner.py +++ b/vllm/spec_decode/draft_model_runner.py @@ -153,7 +153,7 @@ def supports_gpu_multi_step(self, execute_model_req: ExecuteModelRequest): return False # TODO: Add support for other attn backends - if self.attn_backend.get_name() != "FLASH_ATTN": + if self.attn_backend.get_name() not in ("FLASH_ATTN", "TRITON_MLA"): return False # TODO: Add support for LORA @@ -175,6 +175,7 @@ def execute_model( previous_hidden_states: Optional[torch.Tensor] = None, intermediate_tensors: Optional[IntermediateTensors] = None, num_steps: int = 1, + **kwargs, ) -> Optional[List[SamplerOutput]]: """Executes num_steps forward passes with advacement of input tensors on the GPU. Look at supports_gpu_multi_step(..) for pre-conditions. @@ -271,10 +272,17 @@ def execute_model( for step in range(num_steps): multi_modal_kwargs = model_input.multi_modal_kwargs or {} - kwargs = {"previous_hidden_states": hidden_states} \ + model_execute_kwargs = {"previous_hidden_states": hidden_states} \ if previous_hidden_states is not None else {} + compute_logits_kwargs = {} # Run model + if hasattr(self.model.config, "num_nextn_predict_layers"): + # for DeepSeek MTP only to use the corresponding layer for + # each step + spec_step_idx = kwargs.get("spec_step_idx", step) + model_execute_kwargs["spec_step_idx"] = spec_step_idx + compute_logits_kwargs["spec_step_idx"] = spec_step_idx with set_forward_context(model_input.attn_metadata, self.vllm_config): hidden_states = model_executable( @@ -285,13 +293,15 @@ def execute_model( intermediate_tensors=intermediate_tensors, **MultiModalKwargs.as_kwargs(multi_modal_kwargs, device=self.device), - **kwargs, + **model_execute_kwargs, ) # Compute the logits. logits = self.model.compute_logits(hidden_states, - model_input.sampling_metadata) - + model_input.sampling_metadata, + **compute_logits_kwargs) + if not self.is_driver_worker: + return [] # Sample the next token. 
output = self.model.sample( logits=logits, diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 33b1be54c8b3c..fce06a81ff04a 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -108,6 +108,7 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker": typical_acceptance_sampler_posterior_alpha, disable_logprobs=speculative_config.disable_logprobs, disable_log_stats=speculative_config.disable_log_stats, + num_speculative_tokens=speculative_config.num_speculative_tokens, ) return spec_decode_worker @@ -153,10 +154,12 @@ def create_worker( typical_acceptance_sampler_posterior_alpha: float, disable_logprobs: bool, disable_log_stats: bool, + num_speculative_tokens: int, ) -> "SpecDecodeWorker": allow_zero_draft_token_step = True enable_lm_head_weight_load = False + num_spec_prefill_steps = 1 ngram_prompt_lookup_max = ( draft_worker_kwargs.pop("ngram_prompt_lookup_max")) ngram_prompt_lookup_min = ( @@ -179,14 +182,16 @@ def create_worker( elif draft_model_config.hf_config.model_type == "medusa": proposer_worker = MedusaWorker(**draft_worker_kwargs) else: - if draft_tp == 1: + if draft_tp == 1 or draft_model_config.hf_config.model_type ==\ + "deepseek_mtp": if current_platform.is_cuda_alike(): draft_worker_kwargs[ "model_runner_cls"] = TP1DraftModelRunner else: if draft_model_config.hf_config.model_type == "eagle": raise NotImplementedError( - "EAGLE does not support TP > 1 yet") + f"{draft_model_config.hf_config.model_type} " + "does not support TP > 1 yet") allow_zero_draft_token_step = False @@ -195,6 +200,8 @@ def create_worker( enable_lm_head_weight_load = True proposer_worker = MultiStepWorker(**draft_worker_kwargs) + if draft_model_config.hf_config.model_type == "deepseek_mtp": + num_spec_prefill_steps = num_speculative_tokens proposer_worker = SmallerTpProposerWorker.maybe_wrap_worker( proposer_worker, draft_tp, target_tp) @@ -247,7 +254,8 @@ def create_worker( disable_by_batch_size=disable_by_batch_size, spec_decode_sampler=spec_decode_sampler, allow_zero_draft_token_step=allow_zero_draft_token_step, - enable_lm_head_weight_load=enable_lm_head_weight_load) + enable_lm_head_weight_load=enable_lm_head_weight_load, + num_spec_prefill_steps=num_spec_prefill_steps) def __init__( self, @@ -261,6 +269,7 @@ def __init__( disable_by_batch_size: Optional[int] = None, allow_zero_draft_token_step: Optional[bool] = True, enable_lm_head_weight_load: Optional[bool] = False, + num_spec_prefill_steps: int = 1, ): """ Create a SpecDecodeWorker. @@ -293,6 +302,10 @@ def __init__( draft model is larger than 1 (TODO: #5814) enable_lm_head_weight_load: whether to load lm_head weight for draft models like eagle. + num_spec_prefill_steps: number of speculative prefill steps to run + before the speculative decoding starts. This is only used when + the draft model is a deepseek_mtp model that requires prefill + kv cache separately for each MTP layer. """ self.proposer_worker = proposer_worker self.scorer_worker = scorer_worker @@ -326,6 +339,7 @@ def __init__( self.previous_hidden_states: Optional[HiddenStates] = None self._disable_logprobs = disable_logprobs self._disable_log_stats = disable_log_stats + self._num_spec_prefill_steps = num_spec_prefill_steps def init_device(self) -> None: """Initialize both scorer and proposer models. 
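Because each MTP layer keeps its own KV cache, a prefill request has to be replayed once per speculative layer before drafting can start; that is what the new num_spec_prefill_steps plumbing and the spec_step_idx field added to ExecuteModelRequest are for, and the _run_no_spec hunk that follows implements the loop. A rough sketch of the control flow, with simplified, illustrative names:

```python
# Rough sketch only. In the diff, the loop lives in _run_no_spec and the index
# travels via ExecuteModelRequest.spec_step_idx -> worker_base kwargs ->
# TP1DraftModelRunner, which forwards it so that the draft model runs
# layers[num_hidden_layers + spec_step_idx] for that pass.
def prefill_draft_kv_caches(proposer_worker, request,
                            num_spec_prefill_steps: int) -> None:
    for step_idx in range(num_spec_prefill_steps):  # 1 for non-MTP draft models
        request.spec_step_idx = step_idx
        proposer_worker.execute_model(request)
```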
@@ -685,8 +699,9 @@ def _run_no_spec(self, execute_model_req: ExecuteModelRequest, execute_model_req.previous_hidden_states = \ prepare_prefill_hidden_states( sampler_output.prefill_hidden_states) - - self.proposer_worker.execute_model(execute_model_req) + for i in range(self._num_spec_prefill_steps): + execute_model_req.spec_step_idx = i + self.proposer_worker.execute_model(execute_model_req) sampler_output_to_return = (self._serialize_sampler_output_no_logprobs( execute_model_req=execute_model_req, sampler_output=sampler_output) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index c7814f17375b2..31dab15dd7dd1 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -99,6 +99,7 @@ class ModelInputForGPU(ModelRunnerInputBase): virtual_engine: int = 0 async_callback: Optional[Callable] = None scheduler_outputs: Optional[SchedulerOutputs] = None + previous_hidden_states: Optional[torch.Tensor] = None def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: tensor_dict = { @@ -1649,6 +1650,7 @@ def execute_model( kv_caches: List[torch.Tensor], intermediate_tensors: Optional[IntermediateTensors] = None, num_steps: int = 1, + **kwargs, ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]: if num_steps > 1: raise ValueError("num_steps > 1 is not supported in ModelRunner") @@ -1706,6 +1708,10 @@ def execute_model( "finished_requests_ids": model_input.finished_requests_ids, "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids, } if self.has_inner_state else {} + previous_hidden_states = kwargs.get("previous_hidden_states") + model_kwargs = {} + if previous_hidden_states is not None: + model_kwargs["previous_hidden_states"] = previous_hidden_states if (self.observability_config is not None and self.observability_config.collect_model_forward_time): model_forward_start = torch.cuda.Event(enable_timing=True) @@ -1723,7 +1729,9 @@ def execute_model( intermediate_tensors=intermediate_tensors, **MultiModalKwargs.as_kwargs(multi_modal_kwargs, device=self.device), - **seqlen_agnostic_kwargs) + **seqlen_agnostic_kwargs, + **model_kwargs, + ) if (self.observability_config is not None and self.observability_config.collect_model_forward_time): @@ -1815,7 +1823,7 @@ def need_recv_kv(self, model_input, kv_caches) -> bool: 1. current vLLM instance is KV cache consumer/decode vLLM instance 2. this batch is not a profiling run 3. this batch is a prefill run - + Args: model_input: input to the model executable kv_caches: vLLM's paged memory @@ -1840,7 +1848,7 @@ def need_send_kv(self, model_input, kv_caches) -> bool: 1. current vLLM instance is KV cache producer/prefill vLLM instance 2. this batch is not a profiling run 3. this batch is a prefill run - + Args: model_input: input to the model executable kv_caches: vLLM's paged memory @@ -1976,7 +1984,11 @@ def forward( # Copy the input tensors to the input buffers. 
self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True) if positions is not None: - self.input_buffers["positions"].copy_(positions, non_blocking=True) + # in some case like MLA, it will reuse positions in metadata + # but truncate them to the original size + # so the shape is not padded, we need to copy partial only + self.input_buffers["positions"][:positions.shape[0]].copy_( + positions, non_blocking=True) if self.backend_name != "NO_ATTENTION": self.input_buffers["slot_mapping"].copy_( diff --git a/vllm/worker/model_runner_base.py b/vllm/worker/model_runner_base.py index 38d2b712eff57..bae37cb7155f0 100644 --- a/vllm/worker/model_runner_base.py +++ b/vllm/worker/model_runner_base.py @@ -46,7 +46,10 @@ def _init_attn_metadata_from_tensor_dict( valid_attn_kwargs = {} for field in dataclasses.fields(attn_backend.get_metadata_cls()): if field.name in tensor_dict: - valid_attn_kwargs[field.name] = tensor_dict.pop(field.name) + if field.name == "input_positions": + valid_attn_kwargs[field.name] = tensor_dict[field.name] + else: + valid_attn_kwargs[field.name] = tensor_dict.pop(field.name) attn_metadata = attn_backend.make_metadata(**valid_attn_kwargs) tensor_dict["attn_metadata"] = attn_metadata diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 582aa460eb4fa..ff38e3bfc207b 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -68,10 +68,10 @@ def __init__( speculative_config = self.speculative_config model_config = self.model_config speculative_args = {} if speculative_config is None \ - or (speculative_config.draft_model_config.model == - model_config.model) \ + or (speculative_config.draft_model_config.hf_config.model_type == + model_config.hf_config.model_type) \ or (speculative_config.draft_model_config.hf_config.model_type - not in ["medusa", "mlp_speculator", "eagle"]) \ + not in ("medusa", "mlp_speculator", "eagle", "deepseek_mtp")) \ else {"return_hidden_states": True} ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 83fcf0865ae1c..190429074d56c 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -397,6 +397,8 @@ def execute_model( model_input, worker_input, kwargs = inputs num_steps = worker_input.num_steps + if (execute_model_req is not None and execute_model_req.spec_step_idx): + kwargs["spec_step_idx"] = execute_model_req.spec_step_idx self.execute_worker(worker_input)
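With these changes, enabling MTP drafting for a DeepSeek-V3 checkpoint only requires asking for speculative tokens: maybe_create_spec_config now falls back to the target checkpoint itself as the draft model, and hf_config_override re-types it as deepseek_mtp with the DeepSeekMTPModel architecture. The sketch below mirrors the kwargs exercised in test_mtp_correctness.py; it is illustrative only, and the exact keyword arguments accepted by LLM/EngineArgs may differ across vLLM versions.

```python
from vllm import LLM, SamplingParams

llm = LLM(
    model="luccafong/deepseek_mtp_main_random",  # random-weight test model from this diff
    trust_remote_code=True,
    num_speculative_tokens=1,    # no speculative_model needed for deepseek_v3 targets:
                                 # the MTP head is loaded from the same checkpoint
    enforce_eager=True,
    gpu_memory_utilization=0.85,
)
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)
```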