Add VLLM_T_COMPILE_FULLGRAPH flag #932

Merged 2 commits on Mar 24, 2025

124 changes: 97 additions & 27 deletions .jenkins/test_config_t_compile.yaml
@@ -4,98 +4,168 @@ stages:
steps:
- name: gsm8k_small_g3_tp1
flavor: g3
command: cd .jenkins/lm-eval-harness && PT_HPU_LAZY_MODE=0 bash run-tests.sh -c configs/models-small.txt -t 1
command: >
cd .jenkins/lm-eval-harness &&
VLLM_T_COMPILE_FULLGRAPH=True PT_HPU_LAZY_MODE=0
bash run-tests.sh -c configs/models-small.txt -t 1
- name: gsm8k_small_g3_tp2
flavor: g3.s
command: cd .jenkins/lm-eval-harness && PT_HPU_LAZY_MODE=0 bash run-tests.sh -c configs/models-small.txt -t 2
command: >
cd .jenkins/lm-eval-harness &&
VLLM_T_COMPILE_FULLGRAPH=True PT_HPU_LAZY_MODE=0
bash run-tests.sh -c configs/models-small.txt -t 2
- name: gsm8k_small_g2_tp1
flavor: g2
command: cd .jenkins/lm-eval-harness && PT_HPU_LAZY_MODE=0 bash run-tests.sh -c configs/models-small.txt -t 1
command: >
cd .jenkins/lm-eval-harness &&
VLLM_T_COMPILE_FULLGRAPH=True PT_HPU_LAZY_MODE=0
bash run-tests.sh -c configs/models-small.txt -t 1
- name: gsm8k_small_g2_tp2
flavor: g2.s
command: cd .jenkins/lm-eval-harness && PT_HPU_LAZY_MODE=0 bash run-tests.sh -c configs/models-small.txt -t 2
command: >
cd .jenkins/lm-eval-harness &&
VLLM_T_COMPILE_FULLGRAPH=True PT_HPU_LAZY_MODE=0
bash run-tests.sh -c configs/models-small.txt -t 2
- name: test_gsm8k_large_models
steps:
- name: gsm8k_large_g3_tp2
flavor: g3.s
command: cd .jenkins/lm-eval-harness && PT_HPU_LAZY_MODE=0 bash run-tests.sh -c configs/models-large.txt -t 2
command: >
cd .jenkins/lm-eval-harness &&
VLLM_T_COMPILE_FULLGRAPH=True PT_HPU_LAZY_MODE=0
bash run-tests.sh -c configs/models-large.txt -t 2
- name: gsm8k_large_g2_tp4
flavor: g2.m
command: cd .jenkins/lm-eval-harness && PT_HPU_LAZY_MODE=0 bash run-tests.sh -c configs/models-large.txt -t 4
command: >
cd .jenkins/lm-eval-harness &&
VLLM_T_COMPILE_FULLGRAPH=True PT_HPU_LAZY_MODE=0
bash run-tests.sh -c configs/models-large.txt -t 4
- name: test_gsm8k_fp8
steps:
- name: gsm8k_small_g3_tp1_fp8
flavor: g3
command: cd .jenkins/lm-eval-harness && PT_HPU_LAZY_MODE=0 bash run-tests.sh -c configs/models-fp8.txt -t 1
command: >
cd .jenkins/lm-eval-harness &&
VLLM_T_COMPILE_FULLGRAPH=True PT_HPU_LAZY_MODE=0
bash run-tests.sh -c configs/models-fp8.txt -t 1
- name: gsm8k_small_g3_tp2_fp8
flavor: g3.s
command: cd .jenkins/lm-eval-harness && PT_HPU_LAZY_MODE=0 bash run-tests.sh -c configs/models-fp8.txt -t 2
command: >
cd .jenkins/lm-eval-harness &&
VLLM_T_COMPILE_FULLGRAPH=True PT_HPU_LAZY_MODE=0
bash run-tests.sh -c configs/models-fp8.txt -t 2
- name: test_gsm8k_mss
steps:
- name: gsm8k_small_g3_tp1_mss
flavor: g3
command: cd .jenkins/lm-eval-harness && PT_HPU_LAZY_MODE=0 bash run-tests.sh -c configs/models-mss.txt -t 1
command: >
cd .jenkins/lm-eval-harness &&
VLLM_T_COMPILE_FULLGRAPH=True PT_HPU_LAZY_MODE=0
bash run-tests.sh -c configs/models-mss.txt -t 1
- name: gsm8k_small_g2_tp1_mss
flavor: g2
command: cd .jenkins/lm-eval-harness && PT_HPU_LAZY_MODE=0 bash run-tests.sh -c configs/models-mss.txt -t 1
command: >
cd .jenkins/lm-eval-harness &&
VLLM_T_COMPILE_FULLGRAPH=True PT_HPU_LAZY_MODE=0
bash run-tests.sh -c configs/models-mss.txt -t 1
- name: gsm8k_small_g3_tp2_mss
flavor: g3.s
command: cd .jenkins/lm-eval-harness && PT_HPU_LAZY_MODE=0 bash run-tests.sh -c configs/models-mss.txt -t 2
command: >
cd .jenkins/lm-eval-harness &&
VLLM_T_COMPILE_FULLGRAPH=True PT_HPU_LAZY_MODE=0
bash run-tests.sh -c configs/models-mss.txt -t 2
- name: gsm8k_small_g2_tp2_mss
flavor: g2.s
command: cd .jenkins/lm-eval-harness && PT_HPU_LAZY_MODE=0 bash run-tests.sh -c configs/models-mss.txt -t 2
command: >
cd .jenkins/lm-eval-harness &&
VLLM_T_COMPILE_FULLGRAPH=True PT_HPU_LAZY_MODE=0
bash run-tests.sh -c configs/models-mss.txt -t 2
- name: gsm8k_small_g2_tp1_spec_decode
flavor: g2
command: cd .jenkins/lm-eval-harness && PT_HPU_LAZY_MODE=0 bash run-tests.sh -c configs/models-mss.txt -t 1
command: >
cd .jenkins/lm-eval-harness &&
VLLM_T_COMPILE_FULLGRAPH=True PT_HPU_LAZY_MODE=0
bash run-tests.sh -c configs/models-mss.txt -t 1
- name: test_gsm8k_spec_decode
steps:
- name: gsm8k_small_g2_tp1_mlp_spec_decode
flavor: g2
command: PT_HPU_LAZY_MODE=0 TORCH_COMPILE_DISABLE=true VLLM_CONTIGUOUS_PA=false VLLM_SKIP_WARMUP=True pytest -v tests/spec_decode/e2e/test_mlp_correctness.py::test_mlp_e2e_greedy_correctness
command: >
VLLM_T_COMPILE_FULLGRAPH=True PT_HPU_LAZY_MODE=0 TORCH_COMPILE_DISABLE=true VLLM_CONTIGUOUS_PA=false VLLM_SKIP_WARMUP=True
pytest -v tests/spec_decode/e2e/test_mlp_correctness.py::test_mlp_e2e_greedy_correctness
- name: gsm8k_small_g2_tp1_medusa_spec_decode
flavor: g2
command: PT_HPU_LAZY_MODE=0 TORCH_COMPILE_DISABLE=true VLLM_CONTIGUOUS_PA=false VLLM_SKIP_WARMUP=True pytest -v tests/spec_decode/e2e/test_medusa_correctness.py::test_medusa_e2e_greedy_correctness
command: >
VLLM_T_COMPILE_FULLGRAPH=True PT_HPU_LAZY_MODE=0 TORCH_COMPILE_DISABLE=true VLLM_CONTIGUOUS_PA=false VLLM_SKIP_WARMUP=True
pytest -v tests/spec_decode/e2e/test_medusa_correctness.py::test_medusa_e2e_greedy_correctness
- name: gsm8k_small_g2_tp1_eagle_spec_decode
flavor: g2
command: PT_HPU_LAZY_MODE=0 VLLM_COS_SIN_RECOMPUTE=true TORCH_COMPILE_DISABLE=true VLLM_CONTIGUOUS_PA=false VLLM_SKIP_WARMUP=True pytest -v tests/spec_decode/e2e/test_eagle_correctness.py::test_eagle_e2e_greedy_correctness
command: >
VLLM_T_COMPILE_FULLGRAPH=True PT_HPU_LAZY_MODE=0 VLLM_COS_SIN_RECOMPUTE=true TORCH_COMPILE_DISABLE=true VLLM_CONTIGUOUS_PA=false VLLM_SKIP_WARMUP=True
pytest -v tests/spec_decode/e2e/test_eagle_correctness.py::test_eagle_e2e_greedy_correctness
- name: tests_lora
steps:
- name: test_llama_lora
flavor: g2
command: PT_HPU_LAZY_MODE=0 VLLM_SKIP_WARMUP=true pytest -v tests/lora/test_llama_hpu.py::test_llama_lora_1x
command: >
VLLM_T_COMPILE_FULLGRAPH=True PT_HPU_LAZY_MODE=0 VLLM_SKIP_WARMUP=true
pytest -v tests/lora/test_llama_hpu.py::test_llama_lora_1x
- name: test_multilora
flavor: g2
command: PT_HPU_LAZY_MODE=0 VLLM_SKIP_WARMUP=true pytest -v tests/lora/test_multilora_hpu.py::test_llama_multilora_1x
command: >
VLLM_T_COMPILE_FULLGRAPH=True PT_HPU_LAZY_MODE=0 VLLM_SKIP_WARMUP=true
pytest -v tests/lora/test_multilora_hpu.py::test_llama_multilora_1x
# - name: test_long_context
# flavor: g2
# command: PT_HPU_LAZY_MODE=0 VLLM_SKIP_WARMUP=true pytest -v tests/lora/test_long_context_hpu.py::test_quality
# command: VLLM_T_COMPILE_FULLGRAPH=True PT_HPU_LAZY_MODE=0 VLLM_SKIP_WARMUP=true pytest -v tests/lora/test_long_context_hpu.py::test_quality
- name: tests_multimodal
steps:
- name: multimodal_small_g3_tp1
flavor: g3
command: cd .jenkins/vision && PT_HPU_LAZY_MODE=0 bash run-tests.sh -c configs/models-small.txt -t 1
command: >
cd .jenkins/vision &&
VLLM_T_COMPILE_FULLGRAPH=True PT_HPU_LAZY_MODE=0
bash run-tests.sh -c configs/models-small.txt -t 1
- name: multimodal_small_g3_tp2
flavor: g3.s
command: cd .jenkins/vision && PT_HPU_LAZY_MODE=0 bash run-tests.sh -c configs/models-small.txt -t 2
command: >
cd .jenkins/vision &&
VLLM_T_COMPILE_FULLGRAPH=True PT_HPU_LAZY_MODE=0
bash run-tests.sh -c configs/models-small.txt -t 2
- name: multimodal_small_g3_tp1_mss
flavor: g3
command: cd .jenkins/vision && PT_HPU_LAZY_MODE=0 bash run-tests.sh -c configs/models-mss.txt -t 1
command: >
cd .jenkins/vision && VLLM_T_COMPILE_FULLGRAPH=True PT_HPU_LAZY_MODE=0
bash run-tests.sh -c configs/models-mss.txt -t 1
- name: multimodal_small_g3_tp2_mss
flavor: g3.s
command: cd .jenkins/vision && PT_HPU_LAZY_MODE=0 bash run-tests.sh -c configs/models-mss.txt -t 2
command: >
cd .jenkins/vision &&
VLLM_T_COMPILE_FULLGRAPH=True PT_HPU_LAZY_MODE=0
bash run-tests.sh -c configs/models-mss.txt -t 2
- name: tests_int4_quantization
steps:
- name: test_awq
flavor: g2
command: PT_HPU_LAZY_MODE=0 VLLM_SKIP_WARMUP=true pytest -v tests/quantization/test_awq.py::test_awq
command: >
VLLM_T_COMPILE_FULLGRAPH=True PT_HPU_LAZY_MODE=0 VLLM_SKIP_WARMUP=true
pytest -v tests/quantization/test_awq.py::test_awq
- name: test_gptq
flavor: g2
command: PT_HPU_LAZY_MODE=0 VLLM_SKIP_WARMUP=true pytest -v tests/quantization/test_gptq.py::test_gptq
command: >
VLLM_T_COMPILE_FULLGRAPH=True PT_HPU_LAZY_MODE=0 VLLM_SKIP_WARMUP=true
pytest -v tests/quantization/test_gptq.py::test_gptq
- name: tests_guided_decode
steps:
- name: test_lazy_outlines
flavor: g2
command: export VLLM_SKIP_WARMUP=true && pip install -e tests/vllm_test_utils && PT_HPU_LAZY_MODE=0 pytest -v tests/entrypoints/llm/test_lazy_outlines.py -s -vvv --log-cli-level=INFO
command: >
export VLLM_SKIP_WARMUP=true && pip install -e tests/vllm_test_utils &&
VLLM_T_COMPILE_FULLGRAPH=True PT_HPU_LAZY_MODE=0
pytest -v tests/entrypoints/llm/test_lazy_outlines.py -s -vvv --log-cli-level=INFO
- name: test_guided_generate
flavor: g2
command: export VLLM_SKIP_WARMUP=true && pip install -e tests/vllm_test_utils && PT_HPU_LAZY_MODE=0 pytest -v tests/entrypoints/llm/test_guided_generate.py -s -vvv --log-cli-level=INFO
command: >
export VLLM_SKIP_WARMUP=true && pip install -e tests/vllm_test_utils &&
VLLM_T_COMPILE_FULLGRAPH=True PT_HPU_LAZY_MODE=0
pytest -v tests/entrypoints/llm/test_guided_generate.py -s -vvv --log-cli-level=INFO
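
A note on the YAML form used in the commands above (illustrative commentary, not part of the diff): the `>` folded block scalar joins the wrapped lines with single spaces, so each multi-line `command` still reaches the shell as one line, with the environment variables prefixed to the original invocation. A minimal standalone check, assuming PyYAML is installed:

```python
import yaml  # PyYAML, assumed available for this illustration only

snippet = """
command: >
  cd .jenkins/lm-eval-harness &&
  VLLM_T_COMPILE_FULLGRAPH=True PT_HPU_LAZY_MODE=0
  bash run-tests.sh -c configs/models-small.txt -t 1
"""

# The folded scalar ('>') collapses the line breaks into spaces, so the value
# is one shell command line, identical in effect to the single-line original.
print(yaml.safe_load(snippet)["command"])
```
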
1 change: 1 addition & 0 deletions README_GAUDI.md
@@ -334,6 +334,7 @@ INFO 08-02 17:38:43 hpu_executor.py:91] init_cache_engine took 37.92 GiB of devi
- `VLLM_HPU_LOG_STEP_GRAPH_COMPILATION_ALL`: if `true` - logs graph compilations for every vLLM engine step, even if no compilation occurs. Disabled by default.
- `VLLM_HPU_LOG_STEP_CPU_FALLBACKS`: if `true` - logs CPU fallbacks for each vLLM engine step, but only if any fallback occurs. Disabled by default.
- `VLLM_HPU_LOG_STEP_CPU_FALLBACKS_ALL`: if `true` - logs CPU fallbacks for each vLLM engine step, even if no fallback occurs. Disabled by default.
- `VLLM_T_COMPILE_FULLGRAPH`: if `true` - `torch.compile` raises an error if any graph break happens during compilation. This allows easy detection of existing graph breaks, which usually reduce performance. Disabled by default.

**Performance Tuning Knobs:**

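
The behavior documented by the new `VLLM_T_COMPILE_FULLGRAPH` entry above can be reproduced outside vLLM with a minimal sketch (not part of this diff). The toy function and the default `torch.compile` backend are assumptions for illustration only; the HPU model runner below compiles with `backend='hpu_backend'`:

```python
import torch

def toy_fn(x):
    y = x * 2
    print("side effect")  # a Python print typically forces a TorchDynamo graph break
    return y + 1

# Default behavior: the graph break is tolerated and the unsupported part falls
# back to eager execution, so this call succeeds silently.
torch.compile(toy_fn)(torch.ones(2))

# With fullgraph=True (what VLLM_T_COMPILE_FULLGRAPH=True requests in this PR),
# the same graph break surfaces as an error instead of being skipped.
try:
    torch.compile(toy_fn, fullgraph=True)(torch.ones(2))
except Exception as exc:
    print(f"graph break detected: {type(exc).__name__}")
```
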
31 changes: 25 additions & 6 deletions vllm/worker/hpu_model_runner.py
@@ -245,35 +245,54 @@ def __init__(self, model, vllm_config, layer_names):
self.set_causal_option(self.model)
if not is_fake_hpu() and not htorch.utils.internal.is_lazy(
) and not enforce_eager:
fullgraph = os.getenv('VLLM_T_COMPILE_FULLGRAPH',
'false').strip().lower() in ("1", "true")
if os.getenv('VLLM_REGIONAL_COMPILATION',
'true').lower() == 'true':
self.regional_compilation_layers_list = [
RMSNorm, VocabParallelEmbedding
]
self._regional_compilation(self.model)
self._regional_compilation(self.model, fullgraph)
else:
self.model = torch.compile(self.model,
backend='hpu_backend',
fullgraph=fullgraph,
dynamic=False)

def _regional_compilation(self,
module,
fullgraph,
parent_module=None,
module_name=None):
if isinstance(module, torch.nn.ModuleList):
for children_name, children_module in module.named_children():
self._compile_region(module, children_name, children_module)
self._compile_region(module, fullgraph, children_name,
children_module)
elif any(
isinstance(module, layer)
for layer in self.regional_compilation_layers_list):
self._compile_region(parent_module, module_name, module)
self._compile_region(
parent_module,
fullgraph,
module_name,
module,
)
else:
for children_name, children_module in module.named_children():
self._regional_compilation(children_module, module,
self._regional_compilation(children_module, fullgraph, module,
children_name)

def _compile_region(self, model, name, module):
module = torch.compile(module, backend='hpu_backend', dynamic=False)
def _compile_region(
self,
model,
fullgraph,
name,
module,
):
module = torch.compile(module,
backend='hpu_backend',
fullgraph=fullgraph,
dynamic=False)
setattr(model, name, module)

def _set_attn_bias(self, attn_metadata, batch_size, seq_len, device,
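
As a closing illustration (not part of the diff), the pattern the runner change follows can be sketched standalone: parse the flag from the environment once, then thread it into every `torch.compile` call, whether the model is compiled as a whole or region by region via `setattr`. `TinyModel` and the default backend are hypothetical stand-ins; the actual runner uses `backend='hpu_backend'` and targets `RMSNorm` and `VocabParallelEmbedding` layers:

```python
import os
import torch
import torch.nn as nn

# Hypothetical stand-in for a model whose submodules are compiled region by region.
class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(8, 8) for _ in range(2)])
        self.head = nn.Linear(8, 4)

    def forward(self, x):
        for layer in self.layers:
            x = torch.relu(layer(x))
        return self.head(x)

# Same truthy-string parsing as the runner: "1" or "true" (case-insensitive) enables it.
fullgraph = os.getenv("VLLM_T_COMPILE_FULLGRAPH", "false").strip().lower() in ("1", "true")

def compile_region(parent: nn.Module, name: str, module: nn.Module) -> None:
    # Replace the child in place with its compiled wrapper, mirroring _compile_region.
    setattr(parent, name,
            torch.compile(module, fullgraph=fullgraph, dynamic=False))

model = TinyModel()
for name, child in list(model.layers.named_children()):
    compile_region(model.layers, name, child)
compile_region(model, "head", model.head)

# Compilation happens lazily on the first call; with the flag set, any graph
# break inside a compiled region raises instead of silently splitting the graph.
print(model(torch.randn(2, 8)).shape)
```
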