Commit 32176fe

[torch.compile] support moe models (#9632)
Signed-off-by: youkaichao <[email protected]>
1 parent 4e2d95e commit 32176fe

File tree: 12 files changed (+217, -78 lines)

benchmarks/kernels/benchmark_moe.py

Lines changed: 17 additions & 16 deletions
@@ -88,22 +88,23 @@ def prepare(i: int):
         input_gating.copy_(gating_output[i])
 
     def run():
-        fused_moe(
-            x,
-            w1,
-            w2,
-            input_gating,
-            topk,
-            renormalize=True,
-            inplace=True,
-            override_config=config,
-            use_fp8_w8a8=use_fp8_w8a8,
-            use_int8_w8a16=use_int8_w8a16,
-            w1_scale=w1_scale,
-            w2_scale=w2_scale,
-            a1_scale=a1_scale,
-            a2_scale=a2_scale,
-        )
+        from vllm.model_executor.layers.fused_moe import override_config
+        with override_config(config):
+            fused_moe(
+                x,
+                w1,
+                w2,
+                input_gating,
+                topk,
+                renormalize=True,
+                inplace=True,
+                use_fp8_w8a8=use_fp8_w8a8,
+                use_int8_w8a16=use_int8_w8a16,
+                w1_scale=w1_scale,
+                w2_scale=w2_scale,
+                a1_scale=a1_scale,
+                a2_scale=a2_scale,
+            )
 
     # JIT compilation & warmup
     run()
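
Note: the benchmark now sets the kernel configuration through the new override_config context manager rather than passing override_config= into fused_moe directly. A minimal sketch of the new API, assuming a CUDA device and made-up shapes and tiling values (real configs come from the benchmark sweep):

    import torch
    from vllm.model_executor.layers.fused_moe import fused_moe, override_config

    E, K, N, M, topk = 8, 128, 256, 4, 2  # experts, hidden, intermediate, tokens, top-k
    x = torch.randn(M, K, device="cuda", dtype=torch.float16)
    w1 = torch.randn(E, 2 * N, K, device="cuda", dtype=torch.float16)  # gate/up proj
    w2 = torch.randn(E, K, N, device="cuda", dtype=torch.float16)      # down proj
    gating = torch.randn(M, E, device="cuda", dtype=torch.float16)

    # Hypothetical Triton tiling config; every fused_moe call inside the block uses it.
    config = {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32,
              "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4}

    with override_config(config):
        out = fused_moe(x, w1, w2, gating, topk, renormalize=True)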

tests/compile/test_basic_correctness.py

Lines changed: 2 additions & 2 deletions
@@ -13,11 +13,11 @@
 @pytest.mark.parametrize(
     "model, model_args, pp_size, tp_size, attn_backend, method, fullgraph",
     [
-        ("meta-llama/Llama-3.2-1B", [], 2, 2, "FLASH_ATTN", "generate", True),
+        ("meta-llama/Llama-3.2-1B", [], 2, 2, "FLASHINFER", "generate", True),
         ("nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Dyn-Per-Token-2048-Samples",
          ["--quantization", "compressed-tensors"
           ], 1, 1, "FLASH_ATTN", "generate", True),
-        ("google/gemma-2-2b-it", [], 1, 2, "FLASHINFER", "generate", True),
+        ("ibm/PowerMoE-3b", [], 1, 2, "FLASH_ATTN", "generate", True),
         # TODO: add multi-modality test for llava
         ("llava-hf/llava-1.5-7b-hf", [], 2, 1, "FLASHINFER", "generate", False)
     ])

tests/kernels/test_awq_marlin.py

Lines changed: 10 additions & 11 deletions
@@ -5,11 +5,10 @@
 import pytest
 import torch
 
+import vllm.model_executor.layers.fused_moe  # noqa
 from tests.kernels.utils import (compute_max_diff, stack_and_dev, torch_moe,
                                  torch_moe_single)
 from vllm import _custom_ops as ops
-from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
-    fused_marlin_moe, single_marlin_moe)
 from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
 from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
     awq_marlin_quantize)
@@ -81,7 +80,7 @@ def test_fused_marlin_moe_awq(
     score = torch.randn((m, e), device="cuda", dtype=dtype)
 
     topk_weights, topk_ids = fused_topk(a, score, topk, False)
-    marlin_output = fused_marlin_moe(
+    marlin_output = torch.ops.vllm.fused_marlin_moe(
         a,
         qweight1,
         qweight2,
@@ -150,14 +149,14 @@ def test_single_marlin_moe_multiply_awq(
 
     score = torch.randn((m, e), device="cuda", dtype=dtype)
 
-    marlin_output = single_marlin_moe(a,
-                                      qweight,
-                                      scales,
-                                      score,
-                                      topk,
-                                      renormalize=False,
-                                      w_zeros=zp,
-                                      num_bits=num_bits)
+    marlin_output = torch.ops.vllm.single_marlin_moe(a,
+                                                     qweight,
+                                                     scales,
+                                                     score,
+                                                     topk,
+                                                     renormalize=False,
+                                                     w_zeros=zp,
+                                                     num_bits=num_bits)
 
     torch_output = torch_moe_single(a, w_ref.transpose(1, 2), score, topk)
 

tests/kernels/test_moe.py

Lines changed: 3 additions & 4 deletions
@@ -7,12 +7,11 @@
 from transformers import MixtralConfig
 from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
 
+import vllm.model_executor.layers.fused_moe  # noqa
 from tests.kernels.utils import (compute_max_diff, opcheck, stack_and_dev,
                                  torch_moe, torch_moe_single)
 from vllm import _custom_ops as ops
 from vllm.model_executor.layers.fused_moe import fused_moe
-from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
-    fused_marlin_moe, single_marlin_moe)
 from vllm.model_executor.layers.fused_moe.fused_moe import (
     fused_topk, moe_align_block_size)
 from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
@@ -193,7 +192,7 @@ def test_fused_marlin_moe(
         topk,
         renormalize=False,
     )
-    marlin_output = fused_marlin_moe(
+    marlin_output = torch.ops.vllm.fused_marlin_moe(
         a,
         qweight1,
         qweight2,
@@ -309,7 +308,7 @@ def test_single_marlin_moe_multiply(
     sort_indices = stack_and_dev(sort_indices_l)
 
     score = torch.randn((m, e), device="cuda", dtype=dtype)
-    marlin_output = single_marlin_moe(
+    marlin_output = torch.ops.vllm.single_marlin_moe(
         a,
         qweight,
         scales,

vllm/model_executor/layers/fused_moe/__init__.py

Lines changed: 24 additions & 4 deletions
@@ -1,23 +1,43 @@
+from contextlib import contextmanager
+from typing import Any, Dict, Optional
+
 from vllm.model_executor.layers.fused_moe.layer import (
     FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
 from vllm.triton_utils import HAS_TRITON
 
+_config: Optional[Dict[str, Any]] = None
+
+
+@contextmanager
+def override_config(config):
+    global _config
+    old_config = _config
+    _config = config
+    yield
+    _config = old_config
+
+
+def get_config() -> Optional[Dict[str, Any]]:
+    return _config
+
+
 __all__ = [
     "FusedMoE",
     "FusedMoEMethodBase",
     "FusedMoeWeightScaleSupported",
+    "override_config",
+    "get_config",
 ]
 
 if HAS_TRITON:
-    from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
-        fused_marlin_moe, single_marlin_moe)
+    # import to register the custom ops
+    import vllm.model_executor.layers.fused_moe.fused_marlin_moe  # noqa
+    import vllm.model_executor.layers.fused_moe.fused_moe  # noqa
     from vllm.model_executor.layers.fused_moe.fused_moe import (
         fused_experts, fused_moe, fused_topk, get_config_file_name,
         grouped_topk)
 
     __all__ += [
-        "fused_marlin_moe",
-        "single_marlin_moe",
         "fused_moe",
         "fused_topk",
         "fused_experts",

vllm/model_executor/layers/fused_moe/fused_marlin_moe.py

Lines changed: 42 additions & 9 deletions
@@ -1,6 +1,6 @@
 """Fused MoE utilities for GPTQ."""
 import functools
-from typing import Any, Dict, Optional
+from typing import Optional
 
 import torch
 
@@ -18,6 +18,7 @@ def get_scalar_type(num_bits: int, has_zp: bool):
     return scalar_types.uint4b8 if num_bits == 4 else scalar_types.uint8b128
 
 
+@torch.library.custom_op("vllm::single_marlin_moe", mutates_args=[])
 def single_marlin_moe(
     hidden_states: torch.Tensor,
     w: torch.Tensor,
@@ -28,7 +29,6 @@ def single_marlin_moe(
     g_idx: Optional[torch.Tensor] = None,
     sort_indices: Optional[torch.Tensor] = None,
     w_zeros: Optional[torch.Tensor] = None,
-    override_config: Optional[Dict[str, Any]] = None,
     num_bits: int = 8,
     is_k_full: bool = True,
 ) -> torch.Tensor:
@@ -49,8 +49,6 @@ def single_marlin_moe(
     - topk (int): The number of top-k experts to select.
     - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
     - w_zeros (Optional[torch.Tensor]): Optional zero points to be used for w.
-    - override_config (Optional[Dict[str, Any]]): Optional override
-        for the kernel configuration.
     - num_bits (bool): The number of bits in expert weights quantization.
 
     Returns:
@@ -79,7 +77,6 @@ def single_marlin_moe(
         w.shape,
         topk_ids.shape[1],
         None,
-        override_config=override_config,
         is_marlin=True)
     config = get_config_func(M)
 
@@ -122,6 +119,24 @@ def single_marlin_moe(
     return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1)
 
 
+@single_marlin_moe.register_fake
+def _(
+    hidden_states: torch.Tensor,
+    w: torch.Tensor,
+    scales: torch.Tensor,
+    gating_output: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+    g_idx: Optional[torch.Tensor] = None,
+    sort_indices: Optional[torch.Tensor] = None,
+    w_zeros: Optional[torch.Tensor] = None,
+    num_bits: int = 8,
+    is_k_full: bool = True,
+) -> torch.Tensor:
+    return torch.empty_like(hidden_states)
+
+
+@torch.library.custom_op("vllm::fused_marlin_moe", mutates_args=[])
 def fused_marlin_moe(
     hidden_states: torch.Tensor,
     w1: torch.Tensor,
@@ -137,7 +152,6 @@ def fused_marlin_moe(
     sort_indices2: Optional[torch.Tensor] = None,
     w1_zeros: Optional[torch.Tensor] = None,
     w2_zeros: Optional[torch.Tensor] = None,
-    override_config: Optional[Dict[str, Any]] = None,
     num_bits: int = 8,
     is_k_full: bool = True,
 ) -> torch.Tensor:
@@ -161,8 +175,6 @@ def fused_marlin_moe(
         permutation.
     - topk_weights (torch.Tensor): Top-k weights.
     - topk_ids (torch.Tensor): Indices of topk-k elements.
-    - override_config (Optional[Dict[str, Any]]): Optional override
-        for the kernel configuration.
     - w1_zeros (Optional[torch.Tensor]): Optional zero points to be used for w1.
     - w2_zeros (Optional[torch.Tensor]): Optional zero points to be used for w2.
     - num_bits (bool): The number of bits in expert weights quantization.
@@ -209,7 +221,6 @@ def fused_marlin_moe(
         w2.shape,
         topk_ids.shape[1],
         None,
-        override_config=override_config,
         is_marlin=True,
     )
     config = get_config_func(M)
@@ -311,3 +322,25 @@ def fused_marlin_moe(
 
     return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
                      dim=1)
+
+
+@fused_marlin_moe.register_fake
+def _(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    w1_scale: torch.Tensor,
+    w2_scale: torch.Tensor,
+    gating_output: torch.Tensor,
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    g_idx1: Optional[torch.Tensor] = None,
+    g_idx2: Optional[torch.Tensor] = None,
+    sort_indices1: Optional[torch.Tensor] = None,
+    sort_indices2: Optional[torch.Tensor] = None,
+    w1_zeros: Optional[torch.Tensor] = None,
+    w2_zeros: Optional[torch.Tensor] = None,
+    num_bits: int = 8,
+    is_k_full: bool = True,
+) -> torch.Tensor:
+    return torch.empty_like(hidden_states)
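
Note: registering the Marlin MoE kernels as torch.library custom ops with fake (meta) implementations is what lets torch.compile trace MoE models with fullgraph=True: the compiled graph records an opaque call to torch.ops.vllm.fused_marlin_moe / single_marlin_moe, and the register_fake body supplies the output shape and dtype without running the kernel. Returning torch.empty_like(hidden_states) matches the real ops, which reduce over the selected experts and produce an output shaped like hidden_states. A self-contained toy sketch of the same pattern (the op name and function are made up; requires a PyTorch version with torch.library.custom_op, roughly 2.4+):

    import torch

    @torch.library.custom_op("demo::scale_rows", mutates_args=())
    def scale_rows(x: torch.Tensor, factor: float) -> torch.Tensor:
        # Real implementation: opaque to torch.compile, executed eagerly.
        return x * factor

    @scale_rows.register_fake
    def _(x: torch.Tensor, factor: float) -> torch.Tensor:
        # Shape/dtype-only implementation used while tracing and compiling.
        return torch.empty_like(x)

    @torch.compile(fullgraph=True)
    def f(x: torch.Tensor) -> torch.Tensor:
        # The graph contains a single opaque call to torch.ops.demo.scale_rows;
        # the fake impl above keeps fullgraph tracing from breaking on it.
        return torch.ops.demo.scale_rows(x, 2.0) + 1.0

    print(f(torch.randn(4, 8)))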
