Skip to content

Commit

Permalink
Fix OOM in MTP tests
Browse files Browse the repository at this point in the history
Signed-off-by: Lu Fang <[email protected]>
  • Loading branch information
luccafong committed Feb 17, 2025
1 parent 641ce86 commit d705f84
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 0 deletions.
15 changes: 15 additions & 0 deletions tests/spec_decode/e2e/test_mtp_correctness.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@
# Main model
"model_name": MAIN_MODEL,
# GPU memory utilization
"gpu_memory_utilization": 0.85
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
Expand Down Expand Up @@ -88,6 +91,9 @@ def test_mtp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
# Main model
"model_name": MAIN_MODEL,
# GPU memory utilization
"gpu_memory_utilization": 0.85
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
Expand Down Expand Up @@ -184,6 +190,9 @@ def test_mtp_e2e_greedy_correctness_cuda_graph(vllm_runner, common_llm_kwargs,
# Main model
"model_name": MAIN_MODEL,
# GPU memory utilization
"gpu_memory_utilization": 0.9
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
Expand Down Expand Up @@ -224,6 +233,9 @@ def test_mtp_e2e_greedy_correctness_with_preemption(
# Main model
"model_name": MAIN_MODEL,
# GPU memory utilization
"gpu_memory_utilization": 0.9
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
Expand Down Expand Up @@ -268,6 +280,9 @@ def test_mtp_different_k(vllm_runner, common_llm_kwargs,
# Main model
"model_name": MAIN_MODEL,
# GPU memory utilization
"gpu_memory_utilization": 0.9
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
Expand Down
1 change: 1 addition & 0 deletions vllm/spec_decode/spec_decode_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,7 @@ def create_worker(
allow_zero_draft_token_step=allow_zero_draft_token_step,
enable_lm_head_weight_load=enable_lm_head_weight_load,
num_spec_prefill_steps=num_spec_prefill_steps)

def __init__(
self,
proposer_worker: ProposerWorkerBase,
Expand Down

0 comments on commit d705f84

Please sign in to comment.