diff --git a/tests/spec_decode/e2e/test_mtp_correctness.py b/tests/spec_decode/e2e/test_mtp_correctness.py
index 47b9aeb731fb1..0bad19f61d305 100644
--- a/tests/spec_decode/e2e/test_mtp_correctness.py
+++ b/tests/spec_decode/e2e/test_mtp_correctness.py
@@ -49,6 +49,9 @@
 
         # Main model
         "model_name": MAIN_MODEL,
+
+        # GPU memory utilization
+        "gpu_memory_utilization": 0.85
     }])
 @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@@ -88,6 +91,9 @@ def test_mtp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
 
         # Main model
         "model_name": MAIN_MODEL,
+
+        # GPU memory utilization
+        "gpu_memory_utilization": 0.85
     }])
 @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@@ -184,6 +190,9 @@ def test_mtp_e2e_greedy_correctness_cuda_graph(vllm_runner, common_llm_kwargs,
 
         # Main model
         "model_name": MAIN_MODEL,
+
+        # GPU memory utilization
+        "gpu_memory_utilization": 0.9
     }])
 @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@@ -224,6 +233,9 @@ def test_mtp_e2e_greedy_correctness_with_preemption(
 
         # Main model
         "model_name": MAIN_MODEL,
+
+        # GPU memory utilization
+        "gpu_memory_utilization": 0.9
     }])
 @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@@ -268,6 +280,9 @@ def test_mtp_different_k(vllm_runner, common_llm_kwargs,
 
         # Main model
         "model_name": MAIN_MODEL,
+
+        # GPU memory utilization
+        "gpu_memory_utilization": 0.9
     }])
 @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py
index 2e6108b34211f..fce06a81ff04a 100644
--- a/vllm/spec_decode/spec_decode_worker.py
+++ b/vllm/spec_decode/spec_decode_worker.py
@@ -256,6 +256,7 @@ def create_worker(
             allow_zero_draft_token_step=allow_zero_draft_token_step,
             enable_lm_head_weight_load=enable_lm_head_weight_load,
             num_spec_prefill_steps=num_spec_prefill_steps)
+
     def __init__(
         self,
         proposer_worker: ProposerWorkerBase,