-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrouter.py
443 lines (364 loc) · 16.4 KB
/
router.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
from abc import ABC, abstractmethod
import torch
from megatron.core import parallel_state
from megatron.core.tensor_parallel import gather_from_sequence_parallel_region
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.moe.moe_utils import (
MoEAuxLossAutoScaler,
save_to_aux_losses_tracker,
sinkhorn,
switch_load_balancing_loss_func,
topk_softmax_with_capacity,
z_loss_func,
)
from megatron.core.transformer.transformer_config import TransformerConfig
class Router(ABC, MegatronModule):
    """Abstract base for MoE routers.

    Owns the gate projection ``self.weight`` of shape
    ``[num_moe_experts, hidden_size]``. Turning gate logits into routing
    decisions (``routing``) and the full ``forward`` are left to subclasses.
    """

    def __init__(self, config: TransformerConfig) -> None:
        """Build the gate weight from ``config``.

        Args:
            config (TransformerConfig): Configuration object for the Transformer model.
        """
        super().__init__(config)
        self.config = config
        self.num_experts = self.config.num_moe_experts
        self.moe_aux_loss_func = None
        # Populated later via set_layer_number().
        self.layer_number = None

        # The gate is allocated in fp32 (no device given, so on torch's default device),
        # optionally initialized, and only afterwards cast to params_dtype.
        # TODO: Add support for GPU initialization, which requires updating the golden values.
        gate_shape = (self.config.num_moe_experts, self.config.hidden_size)
        self.weight = torch.nn.Parameter(torch.empty(gate_shape, dtype=torch.float32))
        if config.perform_initialization:
            config.init_method(self.weight)
        self.weight.data = self.weight.data.to(dtype=config.params_dtype)
        # Tag the parameter with the config's sequence_parallel flag; the flag is consumed
        # by code outside this module.
        setattr(self.weight, 'sequence_parallel', config.sequence_parallel)

    def gating(self, input: torch.Tensor):
        """Compute per-expert gate logits, i.e. ``input @ weight.T``.

        If the gate weight is still on the CPU, it is first moved (in place,
        through ``.data``) to the current CUDA device.

        Args:
            input (torch.Tensor): Input tensor whose last dimension is hidden_size.

        Returns:
            torch.Tensor: Logits tensor with one entry per expert in the last dimension.
        """
        gate = self.weight
        if gate.device.type == 'cpu':
            # move weights to GPU
            gate.data = gate.data.to(device=torch.cuda.current_device())
        return torch.nn.functional.linear(input, gate)

    @abstractmethod
    def routing(self, logits: torch.Tensor):
        """Turn gate logits into routing decisions.

        Args:
            logits (torch.Tensor): Logits tensor.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: A tuple containing token assignment
            probabilities and mapping.
        """
        raise NotImplementedError("Routing function not implemented.")

    @abstractmethod
    def forward(self, input: torch.Tensor):
        """Run the full router on ``input``.

        Args:
            input (torch.Tensor): Input tensor.
        """
        raise NotImplementedError("Forward function not implemented.")

    def set_layer_number(self, layer_number: int):
        """Store ``layer_number`` on the router (used when recording aux losses)."""
        self.layer_number = layer_number
class TopKRouter(Router):
    """Route each token to the top-k experts.

    The load-balancing strategy is chosen by ``moe_router_load_balancing_type``
    and must be one of ``"sinkhorn"``, ``"aux_loss"`` or ``"none"`` (see ``routing``).
    """

    def __init__(self, config: TransformerConfig) -> None:
        """Initialize the top-k router.

        Args:
            config (TransformerConfig): The configuration for the transformer model.
        """
        super().__init__(config=config)
        self.topk = self.config.moe_router_topk
        self.routing_type = self.config.moe_router_load_balancing_type
        # Lazily built sampler for multiplicative input jitter (see apply_input_jitter).
        self.input_jitter = None

    def sinkhorn_load_balancing(self, logits: torch.Tensor):
        """Apply sinkhorn routing to the logits tensor.

        In training mode, experts are selected from the sinkhorn-normalized
        logits (computed in fp32 without gradients), while the returned scores
        are the plain activation of the original logits. In eval mode, experts
        are selected directly from the activated logits.

        Args:
            logits (torch.Tensor): The logits tensor, shape [num_tokens, num_experts].

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: ``(scores, routing_map)``.
            ``scores`` holds the activated logits at the selected experts and
            zeros elsewhere. ``routing_map`` is a boolean mask of the same shape
            that marks the selected experts.

        Raises:
            AssertionError: If ``moe_aux_loss_coeff`` is non-zero.
        """

        def _sinkhorn_activation(logits):
            # Sigmoid for top-1; softmax (computed in fp32, cast back) for top-k > 1.
            if self.topk == 1:
                logits = torch.sigmoid(logits)
            else:  # k > 1
                logits = torch.softmax(logits, dim=-1, dtype=torch.float32).type_as(logits)
            return logits

        assert self.config.moe_aux_loss_coeff == 0, "Sinkhorn routing does not support aux loss."
        if self.training:
            with torch.no_grad():
                # explicit fp32 conversion for stability
                norm_logits = sinkhorn(logits.to(dtype=torch.float32))
                _, indices = torch.topk(norm_logits, k=self.topk, dim=1)
            logits = _sinkhorn_activation(logits)
        else:
            logits = _sinkhorn_activation(logits)
            _, indices = torch.topk(logits, k=self.topk, dim=1)

        # Named `routing_map` rather than `map` to avoid shadowing the builtin.
        routing_map = torch.zeros_like(logits).int().scatter(1, indices, 1).bool()
        scores = logits * routing_map
        return scores, routing_map

    def aux_loss_load_balancing(self, logits: torch.Tensor):
        """Apply loss-based load balancing to the logits tensor.

        Args:
            logits (torch.Tensor): the logits tensor after gating, shape: [num_tokens, num_experts].

        Returns:
            probs (torch.Tensor): The probabilities of token to experts assignment.
            routing_map (torch.Tensor): The mapping of token to experts assignment,
                shape [num_tokens, num_experts].
        """
        probs, routing_map, tokens_per_expert = topk_softmax_with_capacity(
            logits,
            self.topk,
            capacity_factor=self.config.moe_expert_capacity_factor,
            pad_to_capacity=self.config.moe_pad_expert_input_to_capacity,
            drop_policy=self.config.moe_token_drop_policy,
            use_pre_softmax=self.config.moe_router_pre_softmax,
            deterministic_mode=self.config.deterministic_mode,
        )

        if self.training:
            # The balancing loss uses the full fp32 softmax over all experts. Its gradient
            # is attached to the routed probs.
            scores = torch.softmax(logits, dim=-1, dtype=torch.float32)
            probs = self.apply_load_balancing_loss(scores, tokens_per_expert, activation=probs)
        return probs, routing_map

    def apply_load_balancing_loss(
        self,
        probs: torch.Tensor,
        num_local_tokens_per_expert: torch.Tensor,
        activation: torch.Tensor,
    ):
        """Applies auxiliary loss to the MoE layer.

        When ``moe_aux_loss_coeff`` is zero, the loss contributes nothing and its
        logged (un-scaled) value ``aux_loss / coeff`` would divide by zero.
        In that case ``activation`` is returned unchanged.

        Args:
            probs (torch.Tensor): The probs output by the router for each token.
                [num_tokens, num_experts]
            num_local_tokens_per_expert (torch.Tensor): The number of tokens per expert.
                [num_experts]
            activation (torch.Tensor): The activation tensor to attach the gradient function to.

        Returns:
            torch.Tensor: The activation tensor with the attached gradient function.
        """
        moe_aux_loss_coeff = self.config.moe_aux_loss_coeff
        if moe_aux_loss_coeff == 0:
            return activation

        if self.config.moe_token_dispatcher_type == "alltoall_seq":
            sequence_partition_group = parallel_state.get_context_parallel_group()
            # `routing` gathers logits across the TP region for this dispatcher, so the loss
            # is presumably computed redundantly on every TP rank; scale it down accordingly.
            moe_aux_loss_coeff /= parallel_state.get_tensor_model_parallel_world_size()
        else:
            sequence_partition_group = parallel_state.get_tensor_and_context_parallel_group()

        aux_loss = switch_load_balancing_loss_func(
            probs,
            num_local_tokens_per_expert,
            self.topk,
            moe_aux_loss_coeff,
            sequence_partition_group=sequence_partition_group,
        )
        save_to_aux_losses_tracker(
            "load_balancing_loss",
            aux_loss / moe_aux_loss_coeff,
            self.layer_number,
            self.config.num_layers,
            reduce_group=sequence_partition_group,
        )
        activation = MoEAuxLossAutoScaler.apply(activation, aux_loss)
        return activation

    def apply_z_loss(self, logits):
        """Encourages the router's logits to remain small to enhance stability.

        Please refer to the ST-MoE paper (https://arxiv.org/pdf/2202.08906.pdf) for details.
        The loss is applied only in training mode and only when ``moe_z_loss_coeff``
        is set and non-zero. A zero coefficient would add nothing and would make
        the logged value divide by zero.

        Args:
            logits (torch.Tensor): The logits of the router.

        Returns:
            torch.Tensor: The logits after applying the z-loss.
        """
        z_loss_coeff = self.config.moe_z_loss_coeff
        if not self.training or z_loss_coeff is None or z_loss_coeff == 0:
            return logits

        moe_z_loss_coeff = z_loss_coeff / parallel_state.get_tensor_and_context_parallel_world_size()
        z_loss = z_loss_func(logits, moe_z_loss_coeff)
        logits = MoEAuxLossAutoScaler.apply(logits, z_loss)
        save_to_aux_losses_tracker(
            "z_loss", z_loss / moe_z_loss_coeff, self.layer_number, self.config.num_layers
        )
        return logits

    def apply_input_jitter(self, input: torch.Tensor):
        """Add multiplicative noise to the input tensor (training only).

        Refer to https://arxiv.org/abs/2101.03961. Each element is multiplied by
        a sample from Uniform(1 - eps, 1 + eps), where eps is ``moe_input_jitter_eps``.
        Jitter is a training-time regularizer. In eval mode, or when eps is None,
        the input is returned unchanged so that inference stays deterministic.

        Args:
            input (Tensor): Input tensor.

        Returns:
            Tensor: Jittered input.
        """
        eps = self.config.moe_input_jitter_eps
        if eps is None or not self.training:
            return input
        if self.input_jitter is None:
            # Built once, on the device of the first input seen.
            self.input_jitter = torch.distributions.uniform.Uniform(
                torch.tensor(1.0 - eps, device=input.device),
                torch.tensor(1.0 + eps, device=input.device),
            ).rsample
        return input * self.input_jitter(input.shape)

    def routing(self, logits: torch.Tensor):
        """Top-k routing function.

        Applies the z-loss. For the ``alltoall_seq`` dispatcher, it then gathers
        logits across the TP region. Finally it dispatches on
        ``moe_router_load_balancing_type``.

        Args:
            logits (torch.Tensor): Logits tensor after gating.

        Returns:
            probs (torch.Tensor): The probabilities of token to experts assignment.
            routing_map (torch.Tensor): The mapping of token to experts assignment,
                with shape [num_tokens, num_experts].

        Raises:
            ValueError: If the configured routing type is not supported.
        """
        logits = logits.view(-1, self.config.num_moe_experts)

        # Apply Z-Loss
        logits = self.apply_z_loss(logits)

        if self.config.moe_token_dispatcher_type == "alltoall_seq":
            # Gather the logits from the TP region
            logits = gather_from_sequence_parallel_region(logits)

        if self.routing_type == "sinkhorn":
            scores, routing_map = self.sinkhorn_load_balancing(logits)
        elif self.routing_type == "aux_loss":
            scores, routing_map = self.aux_loss_load_balancing(logits)
        elif self.routing_type == "none":
            # A naive top-k routing without load balancing
            scores, routing_map, _ = topk_softmax_with_capacity(
                logits,
                self.topk,
                capacity_factor=self.config.moe_expert_capacity_factor,
                pad_to_capacity=self.config.moe_pad_expert_input_to_capacity,
                drop_policy=self.config.moe_token_drop_policy,
                use_pre_softmax=self.config.moe_router_pre_softmax,
                deterministic_mode=self.config.deterministic_mode,
            )
        else:
            raise ValueError(f"Unsupported MoE routing type: {self.routing_type}")

        return scores, routing_map

    def forward(self, input: torch.Tensor):
        """Forward pass of the router.

        Args:
            input (torch.Tensor): Input tensor; its last dimension is the hidden size.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: ``(scores, routing_map)`` as produced by
            ``routing``.
        """
        self.hidden = input.shape[-1]

        # Apply input jitter
        input = self.apply_input_jitter(input)
        logits = self.gating(input)
        logits = logits.view(-1, self.config.num_moe_experts)

        scores, routing_map = self.routing(logits)

        return scores, routing_map
class ReLURouter(Router):
    """Route each token to every expert whose ReLU-activated gate logit is non-zero.

    The number of experts per token therefore varies. Sparsity is encouraged by
    an L1 regularization term (``apply_l1_reg``) whose coefficient is read from
    ``config.moe_relu_l1_reg_coeff``.
    """

    def __init__(self, config: TransformerConfig) -> None:
        """Initialize the relu router.

        Args:
            config (TransformerConfig): The configuration for the transformer model.
        """
        super().__init__(config=config)
        # Within this class, `topk` is only used as the normalizer of the switch-style
        # L1 loss; ReLU routing itself does not pick a fixed number of experts.
        self.topk = self.config.moe_router_topk
        # Lazily built sampler for multiplicative input jitter (see apply_input_jitter).
        self.input_jitter = None

    def apply_input_jitter(self, input: torch.Tensor):
        """Add multiplicative noise to the input tensor (training only).

        Refer to https://arxiv.org/abs/2101.03961. Each element is multiplied by
        a sample from Uniform(1 - eps, 1 + eps), where eps is ``moe_input_jitter_eps``.
        Jitter is a training-time regularizer. In eval mode, or when eps is None,
        the input is returned unchanged so that inference stays deterministic.

        Args:
            input (Tensor): Input tensor.

        Returns:
            Tensor: Jittered input.
        """
        eps = self.config.moe_input_jitter_eps
        if eps is None or not self.training:
            return input
        if self.input_jitter is None:
            # Built once, on the device of the first input seen.
            self.input_jitter = torch.distributions.uniform.Uniform(
                torch.tensor(1.0 - eps, device=input.device),
                torch.tensor(1.0 + eps, device=input.device),
            ).rsample
        return input * self.input_jitter(input.shape)

    def l1_reg_load_balancing(self, logits: torch.Tensor):
        """Apply ReLU gating plus load-balancing L1 regularization.

        Args:
            logits (torch.Tensor): Logits tensor after gating, shape: [num_tokens, num_experts].

        Returns:
            probs (torch.Tensor): ReLU of the logits, used as token-to-expert routing
                weights, shape [num_tokens, num_experts].
            routing_map (torch.Tensor): Boolean mask of the non-zero entries of ``probs``,
                shape [num_tokens, num_experts].
        """
        probs = torch.relu(logits)
        routing_map = probs > 0

        if self.training and torch.is_grad_enabled():
            num_local_tokens_per_expert = routing_map.sum(dim=0)
            # Apply l1 regularization
            probs = self.apply_l1_reg(probs, num_local_tokens_per_expert, activation=probs)

            # Record the sparsity of the ReLU output: the fraction of (token, expert) pairs
            # that were NOT routed. It is accumulated onto the shared config object.
            # NOTE(review): presumably read and reset by the training loop - confirm.
            sparsity = 1 - routing_map.sum().float() / routing_map.numel()
            self.config.moe_relu_sparsity += sparsity

        return probs, routing_map

    def apply_l1_reg(
        self,
        probs: torch.Tensor,
        num_local_tokens_per_expert: torch.Tensor,
        activation: torch.Tensor,
    ):
        """Apply load balancing L1 regularization loss to the ReLU output.

        When the L1 coefficient is zero, the term contributes nothing and its
        logged (un-scaled) value ``l1_reg / coeff`` would divide by zero.
        In that case ``activation`` is returned unchanged.

        Args:
            probs (torch.Tensor): The probs output by the router for each token.
                [num_tokens, num_experts]
            num_local_tokens_per_expert (torch.Tensor): The number of tokens per expert.
                [num_experts]
            activation (torch.Tensor): The activation tensor to attach the gradient function to.

        Returns:
            torch.Tensor: The activation tensor with the attached gradient function.
        """
        # The coefficient is stored on the config as a tensor; `.item()` converts it to a float.
        l1_reg_coeff = self.config.moe_relu_l1_reg_coeff.item()
        if l1_reg_coeff == 0:
            return activation

        if self.config.moe_token_dispatcher_type == "alltoall_seq":
            sequence_partition_group = parallel_state.get_context_parallel_group()
            # `routing` gathers logits across the TP region for this dispatcher, so the term
            # is presumably computed redundantly on every TP rank; scale it down accordingly.
            l1_reg_coeff /= parallel_state.get_tensor_model_parallel_world_size()
        else:
            sequence_partition_group = parallel_state.get_tensor_and_context_parallel_group()

        # L1 regularization with load balancing shares its formula with the switch
        # load-balancing loss:
        #   l1_reg = sum((probs_per_expert/num_tokens) *
        #            (tokens_per_expert/(num_tokens*topk))) * num_experts * l1_reg_coeff.
        l1_reg = switch_load_balancing_loss_func(
            probs,
            num_local_tokens_per_expert,
            self.topk,
            l1_reg_coeff,
            sequence_partition_group=sequence_partition_group,
        )
        save_to_aux_losses_tracker(
            "l1_reg_loss",
            l1_reg / l1_reg_coeff,
            self.layer_number,
            self.config.num_layers,
            reduce_group=sequence_partition_group,
        )
        activation = MoEAuxLossAutoScaler.apply(activation, l1_reg)
        return activation

    def routing(self, logits: torch.Tensor):
        """ReLU routing function.

        Args:
            logits (torch.Tensor): Logits tensor after gating, shape: [num_tokens, num_experts].

        Returns:
            probs (torch.Tensor): ReLU routing weights, shape [num_tokens, num_experts].
            routing_map (torch.Tensor): The mapping of token to experts assignment,
                shape [num_tokens, num_experts].
        """
        logits = logits.view(-1, self.config.num_moe_experts)

        if self.config.moe_token_dispatcher_type == "alltoall_seq":
            # Gather the logits from the TP region
            logits = gather_from_sequence_parallel_region(logits)

        scores, routing_map = self.l1_reg_load_balancing(logits)
        return scores, routing_map

    def forward(self, input: torch.Tensor):
        """Forward pass of the router.

        Args:
            input (torch.Tensor): Input tensor; its last dimension is the hidden size.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: ``(scores, routing_map)`` as produced by
            ``routing``.
        """
        self.hidden = input.shape[-1]

        # Apply input jitter
        input = self.apply_input_jitter(input)
        logits = self.gating(input)
        logits = logits.view(-1, self.config.num_moe_experts)

        scores, routing_map = self.routing(logits)

        return scores, routing_map