Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Significant performance improvement on MoE block of SwitchTransformer #30490

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -297,9 +297,22 @@ def forward(self, hidden_states):
# can be unchanged from one layer to another. That is why the hidden states are cloned before updating only the seleced ones.

next_states = hidden_states.clone()
for idx, expert in enumerate(self.experts.values()):
token_indices = router_mask[:, :, idx].bool()
next_states[token_indices] = expert(hidden_states[token_indices]).to(next_states.dtype)
# for idx, expert in enumerate(self.experts.values()):
# token_indices = router_mask[:, :, idx].bool()
# next_states[token_indices] = expert(hidden_states[token_indices]).to(next_states.dtype)
ranggihwang marked this conversation as resolved.
Show resolved Hide resolved

# Preformance improvement version of Switch Transformer
# It utilized sparse tensor and only access the activated experts
# This significantly reduces latency proprotional to the number of experts.
router_mask = router_mask.bool()
idx_mask = router_mask.transpose(1,2)
idx_mask = torch.cat(torch.split(idx_mask, 1, dim=0), dim=2)
idx_mask = idx_mask.sum(dim=2)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
idx_mask = idx_mask.sum(dim=2)
idx_mask = idx_mask.sum(dim=1)

idx_mask = idx_mask.squeeze() # length: number of experts / value: number of tokens
idx_mask = torch.nonzero(idx_mask, as_tuple=True)[0].tolist() # length: number of "activated" expert / value: index
for idx in idx_mask:
next_states[router_mask[:, :, idx]] = getattr(self.experts, "expert_{}".format(idx)) \
(hidden_states[router_mask[:, :, idx]])

hidden_states = router_probs * next_states
return hidden_states, (router_logits, expert_index)
Expand Down