From d705f8402df4b9b8e3267570cb916dedffde405e Mon Sep 17 00:00:00 2001 From: Lu Fang Date: Fri, 14 Feb 2025 18:24:05 -0800 Subject: [PATCH] fix oom of mtp tests Signed-off-by: Lu Fang --- tests/spec_decode/e2e/test_mtp_correctness.py | 15 +++++++++++++++ vllm/spec_decode/spec_decode_worker.py | 1 + 2 files changed, 16 insertions(+) diff --git a/tests/spec_decode/e2e/test_mtp_correctness.py b/tests/spec_decode/e2e/test_mtp_correctness.py index 47b9aeb731fb1..0bad19f61d305 100644 --- a/tests/spec_decode/e2e/test_mtp_correctness.py +++ b/tests/spec_decode/e2e/test_mtp_correctness.py @@ -49,6 +49,9 @@ # Main model "model_name": MAIN_MODEL, + + # GPU memory utilization + "gpu_memory_utilization": 0.85 }]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @@ -88,6 +91,9 @@ def test_mtp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs, # Main model "model_name": MAIN_MODEL, + + # GPU memory utilization + "gpu_memory_utilization": 0.85 }]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @@ -184,6 +190,9 @@ def test_mtp_e2e_greedy_correctness_cuda_graph(vllm_runner, common_llm_kwargs, # Main model "model_name": MAIN_MODEL, + + # GPU memory utilization + "gpu_memory_utilization": 0.9 }]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @@ -224,6 +233,9 @@ def test_mtp_e2e_greedy_correctness_with_preemption( # Main model "model_name": MAIN_MODEL, + + # GPU memory utilization + "gpu_memory_utilization": 0.9 }]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @@ -268,6 +280,9 @@ def test_mtp_different_k(vllm_runner, common_llm_kwargs, # Main model "model_name": MAIN_MODEL, + + # GPU memory utilization + "gpu_memory_utilization": 0.9 }]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 2e6108b34211f..fce06a81ff04a 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -256,6 +256,7 @@ def create_worker( allow_zero_draft_token_step=allow_zero_draft_token_step, enable_lm_head_weight_load=enable_lm_head_weight_load, num_spec_prefill_steps=num_spec_prefill_steps) + def __init__( self, proposer_worker: ProposerWorkerBase,