@@ -231,6 +231,16 @@ def forward(self, state):
231
231
return torch .matmul (state , self .weight )
232
232
233
233
234
def calculate_routing_tensors(score, topk, hidden_states_dtype):
    """Compute top-k expert routing weights from raw router scores.

    Softmax is taken in float32 for numerical stability, the ``topk``
    largest probabilities per token are kept and renormalized so they
    sum to 1, then cast back to the hidden-states dtype.

    Args:
        score: router logits, one row per token.
        topk: number of experts to route each token to.
        hidden_states_dtype: dtype to cast the final weights to.

    Returns:
        Tuple of (routing_weights, selected_experts), where
        ``selected_experts`` holds the indices of the chosen experts.
    """
    probs = F.softmax(score, dim=1, dtype=torch.float32)
    topk_weights, topk_ids = probs.topk(topk, dim=-1)
    topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights.to(hidden_states_dtype), topk_ids
234
244
class StaticFusedMOE (torch .nn .Module ):
235
245
236
246
def __init__ (self , num_total_experts ):
@@ -243,12 +253,8 @@ def __init__(self, num_total_experts):
243
253
244
254
def forward (self , hidden_states , w1 , w2 , score , topk ):
245
255
B , D = hidden_states .shape
246
- routing_weights = F .softmax (score , dim = 1 , dtype = torch .float32 )
247
- routing_weights , selected_experts = torch .topk (routing_weights ,
248
- topk ,
249
- dim = - 1 )
250
- routing_weights /= routing_weights .sum (dim = - 1 , keepdim = True )
251
- routing_weights = routing_weights .to (hidden_states .dtype )
256
+ routing_weights , selected_experts = calculate_routing_tensors (
257
+ score , topk , hidden_states .dtype )
252
258
final_hidden_states = torch .zeros ((1 , B , D ),
253
259
dtype = hidden_states .dtype ,
254
260
device = hidden_states .device )
@@ -271,3 +277,33 @@ def forward(self, hidden_states, w1, w2, score, topk):
271
277
final_hidden_states += current_hidden_states_static
272
278
273
279
return final_hidden_states .view (- 1 , D )
280
+
281
+
282
class DynamicFusedMOE(torch.nn.Module):
    """Mixture-of-experts layer dispatching to the HPU fused custom op.

    Unlike the static variant, this delegates the expert computation to
    ``torch.ops.hpu.mixture_of_experts`` instead of looping over experts
    in Python.
    """

    def __init__(self, num_total_experts):
        """
        Args:
            num_total_experts: total number of experts whose weights are
                stacked along dim 0 of ``w1``/``w2`` in ``forward``.
        """
        super().__init__()
        self.num_total_experts = num_total_experts

    def forward(self, hidden_states, w1, w2, score, topk):
        """Route tokens to ``topk`` experts and run the fused HPU MoE op.

        Args:
            hidden_states: input activations, shape (tokens, hidden_dim)
                assumed — TODO confirm against callers.
            w1: stacked per-expert up/gate projection weights (dim 0 = expert).
            w2: stacked per-expert down projection weights (dim 0 = expert).
            score: router logits per token.
            topk: number of experts each token is routed to.

        Returns:
            Tensor of shape (-1, hidden_dim) with expert outputs combined
            by the normalized routing weights.
        """
        htorch.core.mark_step()
        routing_weights, selected_experts = calculate_routing_tensors(
            score, topk, hidden_states.dtype)

        # Pre-processing for the custom op: split the stacked weight
        # tensors into one 2-D tensor per expert.
        experts_range = range(self.num_total_experts)
        w1_list = [w1[i, :, :].squeeze() for i in experts_range]
        w2_list = [w2[i, :, :].squeeze() for i in experts_range]

        final_hidden_states = torch.ops.hpu.mixture_of_experts(
            hidden_states=hidden_states,
            expert_routing_table=selected_experts,
            router_weights=routing_weights,
            w12=w1_list,
            w3=w2_list,
            permuted_weights=True,
            activation="silu",
            experts_min=0,
            # FIX: was hard-coded to 7, which is only correct when
            # num_total_experts == 8; derive the inclusive upper expert
            # index from the configured expert count instead.
            experts_max=self.num_total_experts - 1,
        )

        return final_hidden_states.view(-1, hidden_states.shape[1])
0 commit comments