Commit 8492294

Implement custom moe op for Mixtral (#303)
Enable the new mixture_of_experts op for the bf16 route. When quantization is used, run with the legacy MoE solution.
2 parents 4183a07 + 5a271c8 commit 8492294

File tree: 2 files changed, +54 / -11 lines

vllm/hpu/ops.py

Lines changed: 42 additions & 6 deletions
@@ -231,6 +231,16 @@ def forward(self, state):
         return torch.matmul(state, self.weight)
 
 
+def calculate_routing_tensors(score, topk, hidden_states_dtype):
+    routing_weights = F.softmax(score, dim=1, dtype=torch.float32)
+    routing_weights, selected_experts = torch.topk(routing_weights,
+                                                   topk,
+                                                   dim=-1)
+    routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+    routing_weights = routing_weights.to(hidden_states_dtype)
+    return routing_weights, selected_experts
+
+
 class StaticFusedMOE(torch.nn.Module):
 
     def __init__(self, num_total_experts):
@@ -243,12 +253,8 @@ def __init__(self, num_total_experts):
 
     def forward(self, hidden_states, w1, w2, score, topk):
         B, D = hidden_states.shape
-        routing_weights = F.softmax(score, dim=1, dtype=torch.float32)
-        routing_weights, selected_experts = torch.topk(routing_weights,
-                                                       topk,
-                                                       dim=-1)
-        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
-        routing_weights = routing_weights.to(hidden_states.dtype)
+        routing_weights, selected_experts = calculate_routing_tensors(
+            score, topk, hidden_states.dtype)
         final_hidden_states = torch.zeros((1, B, D),
                                           dtype=hidden_states.dtype,
                                           device=hidden_states.device)
@@ -271,3 +277,33 @@ def forward(self, hidden_states, w1, w2, score, topk):
             final_hidden_states += current_hidden_states_static
 
         return final_hidden_states.view(-1, D)
+
+
+class DynamicFusedMOE(torch.nn.Module):
+
+    def __init__(self, num_total_experts):
+        super().__init__()
+        self.num_total_experts = num_total_experts
+
+    def forward(self, hidden_states, w1, w2, score, topk):
+        htorch.core.mark_step()
+        routing_weights, selected_experts = calculate_routing_tensors(
+            score, topk, hidden_states.dtype)
+        # pre-processing for custom op inputs
+        experts_range = range(self.num_total_experts)
+        w1_list = [w1[i,:,:].squeeze() for i in experts_range]
+        w2_list = [w2[i,:,:].squeeze() for i in experts_range]
+
+        final_hidden_states = torch.ops.hpu.mixture_of_experts(
+            hidden_states=hidden_states,
+            expert_routing_table=selected_experts,
+            router_weights=routing_weights,
+            w12=w1_list,
+            w3=w2_list,
+            permuted_weights=True,
+            activation="silu",
+            experts_min=0,
+            experts_max=7
+        )
+
+        return final_hidden_states.view(-1, hidden_states.shape[1])
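For reference, the routing math that the new calculate_routing_tensors helper factors out can be exercised standalone on CPU. The snippet below is a minimal sketch that re-implements the same softmax / top-k / renormalize / cast steps inline (importing vllm.hpu.ops directly would typically require the Habana stack, since the module uses htorch). The sizes B=4 tokens, 8 experts, top-2 and the bf16 dtype are illustrative assumptions, not values taken from this commit.

import torch
import torch.nn.functional as F

# Illustrative sizes (assumptions): 4 tokens, 8 experts, top-2 routing, bf16 activations.
B, num_experts, topk = 4, 8, 2
score = torch.randn(B, num_experts)     # router logits, one row per token
hidden_states_dtype = torch.bfloat16

# Same steps as calculate_routing_tensors: softmax in fp32, pick top-k experts,
# renormalize the selected weights, then cast back to the activation dtype.
routing_weights = F.softmax(score, dim=1, dtype=torch.float32)
routing_weights, selected_experts = torch.topk(routing_weights, topk, dim=-1)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
routing_weights = routing_weights.to(hidden_states_dtype)

print(routing_weights.shape, selected_experts.shape)  # both torch.Size([4, 2])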

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 12 additions & 5 deletions
@@ -202,8 +202,14 @@ def __init__(
         self.num_expert_group = num_expert_group
         self.topk_group = topk_group
         if is_hpu():
-            from vllm.hpu.ops import StaticFusedMOE
-            self.hpu_static_fused_moe = StaticFusedMOE(self.num_experts)
+            from vllm.hpu.ops import StaticFusedMOE, DynamicFusedMOE
+            from vllm.model_executor.layers.quantization.inc import INCConfig
+            selected_fused_moe = (
+                StaticFusedMOE
+                if isinstance(quant_config, INCConfig)
+                else DynamicFusedMOE
+            )
+            self.hpu_static_fused_moe = selected_fused_moe(self.num_experts)
 
         if quant_config is None:
             self.quant_method: Optional[QuantizeMethodBase] = (
@@ -254,24 +260,25 @@ def weight_loader(self, param: torch.nn.Parameter,
         shard_size = self.intermediate_size_per_partition
         shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
 
+        from vllm.hpu.ops import StaticFusedMOE
         # w1, gate_proj case: Load into first shard of w13.
         if shard_id == 0:
             param_data[expert_id,
                        0:shard_size, :] = loaded_weight[shard, :]
-            if is_hpu():
+            if is_hpu() and isinstance(self.hpu_static_fused_moe, StaticFusedMOE):
                 self.hpu_static_fused_moe.w13_list[expert_id].set_weight(
                     param_data[expert_id])
         # w3, up_proj case: Load into second shard of w13.
         elif shard_id == 2:
             param_data[expert_id, shard_size:2 *
                        shard_size, :] = loaded_weight[shard, :]
-            if is_hpu():
+            if is_hpu() and isinstance(self.hpu_static_fused_moe, StaticFusedMOE):
                 self.hpu_static_fused_moe.w13_list[expert_id].set_weight(
                     param_data[expert_id])
         # w2, down_proj case: Load into only shard of w2.
         elif shard_id == 1:
             param_data[expert_id, :, :] = loaded_weight[:, shard]
-            if is_hpu():
+            if is_hpu() and isinstance(self.hpu_static_fused_moe, StaticFusedMOE):
                 self.hpu_static_fused_moe.w2_list[expert_id].set_weight(
                     param_data[expert_id])
         else:
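For orientation, the new isinstance guard matters only for the static path: the hunk above shows StaticFusedMOE mirroring each loaded shard into its per-expert w13_list/w2_list modules, while DynamicFusedMOE slices w1/w2 directly in forward, so there is nothing extra to do at load time. The sketch below reads the parameter layout off the loader code shown above; the names w13_weight/w2_weight and the toy sizes are illustrative assumptions, not taken from this commit.

import torch

# Toy sizes (assumptions): 8 experts, hidden size 64, per-partition intermediate size 128.
E, D, I = 8, 64, 128

# Layout the loader fills shard by shard:
#   shard_id 0 (gate_proj) -> rows [0:I)  of the expert's w13 slice
#   shard_id 2 (up_proj)   -> rows [I:2I) of the expert's w13 slice
#   shard_id 1 (down_proj) -> the whole w2 slice, from a column shard of the checkpoint weight
w13_weight = torch.empty(E, 2 * I, D, dtype=torch.bfloat16)
w2_weight = torch.empty(E, D, I, dtype=torch.bfloat16)

gate_proj_shard = torch.randn(I, D, dtype=torch.bfloat16)  # one TP shard of gate_proj
w13_weight[0, 0:I, :] = gate_proj_shard                    # shard_id == 0 for expert 0

# Only the StaticFusedMOE path additionally mirrors the per-expert slice, e.g.
#   self.hpu_static_fused_moe.w13_list[0].set_weight(w13_weight[0])
# which is why the guard now checks isinstance(..., StaticFusedMOE).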
