Significant performance improvement on MoE block of SwitchTransformer #30490

Open · wants to merge 3 commits into base: main
src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py
@@ -12,8 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch BigBirdPegasus model."""

"""PyTorch BigBirdPegasus model."""

import copy
import math
@@ -707,11 +706,9 @@ def bigbird_block_sparse_attention(
attention_probs[:, :, -2 * from_block_size : -from_block_size, :to_block_size] = second_last_attn_weights[
:, :, :, :to_block_size
] # 1st key block (global)
attention_probs[
:, :, -2 * from_block_size : -from_block_size, -3 * to_block_size :
] = second_last_attn_weights[
:, :, :, to_block_size : 4 * to_block_size
] # last three blocks (global + sliding)
attention_probs[:, :, -2 * from_block_size : -from_block_size, -3 * to_block_size :] = (
second_last_attn_weights[:, :, :, to_block_size : 4 * to_block_size]
) # last three blocks (global + sliding)
# random keys
for p1, i1, w1 in zip(range(bsz), rand_attn, second_last_attn_weights):
# p1, i1, w1 corresponds to batch_dim i.e. following operation is done for each sequence in batch
6 changes: 3 additions & 3 deletions src/transformers/models/cohere/modeling_cohere.py
@@ -1032,9 +1032,9 @@ def _update_causal_mask(
offset = 0
mask_shape = attention_mask.shape
mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
causal_mask[
: mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]
] = mask_slice
causal_mask[: mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]] = (
mask_slice
)

if (
self.config._attn_implementation == "sdpa"
8 changes: 4 additions & 4 deletions src/transformers/models/gemma/modeling_gemma.py
@@ -13,7 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Gemma model."""
"""PyTorch Gemma model."""

import math
import warnings
@@ -1018,9 +1018,9 @@ def _update_causal_mask(
offset = 0
mask_shape = attention_mask.shape
mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
causal_mask[
: mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]
] = mask_slice
causal_mask[: mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]] = (
mask_slice
)

if (
self.config._attn_implementation == "sdpa"
src/transformers/models/gptsan_japanese/modeling_gptsan_japanese.py
@@ -12,8 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch GPTSANJapanese model."""

"""PyTorch GPTSANJapanese model."""

import copy
from typing import List, Optional, Tuple, Union
@@ -273,9 +272,22 @@ def forward(self, hidden_states):
# can be unchanged from one layer to another. That is why the hidden states are cloned before updating only the selected ones.

next_states = hidden_states.clone()
for idx, expert in enumerate(self.experts.values()):
token_indices = router_mask[:, :, idx].bool()
next_states[token_indices] = expert(hidden_states[token_indices]).to(next_states.dtype)

# Performance-improved version of the Switch Transformer MoE block.
# It builds a sparse index of activated experts and only accesses those experts.
# This significantly reduces latency, which otherwise grows proportionally with the number of experts.
router_mask = router_mask.bool()
idx_mask = router_mask.transpose(1, 2)
idx_mask = torch.cat(torch.split(idx_mask, 1, dim=0), dim=2)
idx_mask = idx_mask.sum(dim=2)
idx_mask = idx_mask.squeeze() # length: number of experts / value: number of tokens
idx_mask = torch.nonzero(idx_mask, as_tuple=True)[
0
].tolist() # length: number of "activated" experts / value: expert index
for idx in idx_mask:
next_states[router_mask[:, :, idx]] = getattr(self.experts, "expert_{}".format(idx))(
hidden_states[router_mask[:, :, idx]]
)

hidden_states = router_probs * next_states
return hidden_states, (router_logits, expert_index)
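The same dispatch change is applied to both MoE blocks touched by this PR (GPTSANJapanese above and SwitchTransformers below). For readers skimming the diff, here is a minimal self-contained sketch of the idea using a made-up `ToySparseMLP` module with toy sizes rather than the actual model code: the router mask is first reduced to the set of expert indices that actually received tokens, and only those experts are run.

```python
import torch
from torch import nn


class ToySparseMLP(nn.Module):
    """Toy top-1 MoE layer illustrating the activated-expert dispatch (not the model code)."""

    def __init__(self, d_model: int = 16, num_experts: int = 8):
        super().__init__()
        self.router = nn.Linear(d_model, num_experts)
        self.experts = nn.ModuleDict(
            {f"expert_{i}": nn.Linear(d_model, d_model) for i in range(num_experts)}
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # hidden_states: (batch, seq_len, d_model)
        router_logits = self.router(hidden_states)   # (batch, seq_len, num_experts)
        expert_index = router_logits.argmax(dim=-1)  # top-1 routing decision
        router_mask = nn.functional.one_hot(
            expert_index, num_classes=router_logits.size(-1)
        ).bool()                                     # (batch, seq_len, num_experts)

        next_states = hidden_states.clone()
        # Count tokens per expert, then keep only experts that received at least one token.
        tokens_per_expert = router_mask.reshape(-1, router_mask.size(-1)).sum(dim=0)
        activated = torch.nonzero(tokens_per_expert, as_tuple=True)[0].tolist()
        for idx in activated:
            token_mask = router_mask[:, :, idx]      # (batch, seq_len) boolean mask
            next_states[token_mask] = self.experts[f"expert_{idx}"](hidden_states[token_mask])
        return next_states
```

With `ToySparseMLP()(torch.randn(2, 5, 16))`, only the experts selected by the argmax router are ever invoked; experts that received no tokens are skipped entirely.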
src/transformers/models/switch_transformers/modeling_switch_transformers.py
@@ -12,8 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch SwitchTransformers model."""

"""PyTorch SwitchTransformers model."""

import copy
import math
@@ -297,9 +296,22 @@ def forward(self, hidden_states):
# can be unchanged from one layer to another. That is why the hidden states are cloned before updating only the selected ones.

next_states = hidden_states.clone()
for idx, expert in enumerate(self.experts.values()):
token_indices = router_mask[:, :, idx].bool()
next_states[token_indices] = expert(hidden_states[token_indices]).to(next_states.dtype)

# Performance-improved version of the Switch Transformer MoE block.
# It builds a sparse index of activated experts and only accesses those experts.
# This significantly reduces latency, which otherwise grows proportionally with the number of experts.
router_mask = router_mask.bool()
idx_mask = router_mask.transpose(1, 2)
idx_mask = torch.cat(torch.split(idx_mask, 1, dim=0), dim=2)
Comment on lines +304 to +305

Collaborator
Suggested change
idx_mask = router_mask.transpose(1, 2)
idx_mask = torch.cat(torch.split(idx_mask, 1, dim=0), dim=2)
idx_mask = router.reshape(batch*seq_len,num_experts).transpose(0,1)

Equivalent and more understandable.

Collaborator
Adding a hint like `# batch * seq, num_experts` also helps!
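A quick standalone check of the equivalence claim above (assuming `router` in the suggestion refers to `router_mask`, and using arbitrary shapes; this snippet is not part of the diff): both formulations lay out the mask identically, so the per-expert token counts are the same.

```python
import torch

batch, seq_len, num_experts = 4, 7, 8
router_mask = torch.rand(batch, seq_len, num_experts) > 0.5  # arbitrary boolean routing mask

# Current PR version: transpose, then re-stack the batch along the sequence axis.
a = router_mask.transpose(1, 2)                  # (batch, num_experts, seq_len)
a = torch.cat(torch.split(a, 1, dim=0), dim=2)   # (1, num_experts, batch * seq_len)

# Reviewer's suggestion: a single reshape + transpose.
b = router_mask.reshape(batch * seq_len, num_experts).transpose(0, 1)  # (num_experts, batch * seq_len)

assert torch.equal(a.squeeze(0), b)                         # identical element layout
assert torch.equal(a.sum(dim=2).squeeze(0), b.sum(dim=1))   # same per-expert token counts
```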

idx_mask = idx_mask.sum(dim=2)
Collaborator
Suggested change
idx_mask = idx_mask.sum(dim=2)
idx_mask = idx_mask.sum(dim=1)

idx_mask = idx_mask.squeeze() # length: number of experts / value: number of tokens
Collaborator
Suggested change (remove this line)
idx_mask = idx_mask.squeeze() # length: number of experts / value: number of tokens

idx_mask = torch.nonzero(idx_mask, as_tuple=True)[
0
].tolist() # length: number of "activated" experts / value: expert index
for idx in idx_mask:
next_states[router_mask[:, :, idx]] = getattr(self.experts, "expert_{}".format(idx))(
hidden_states[router_mask[:, :, idx]]
)

hidden_states = router_probs * next_states
return hidden_states, (router_logits, expert_index)
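The speedup comes from skipping experts that receive no tokens. Below is one rough way to compare the two dispatch strategies locally; it is an illustrative sketch, not the model code, with toy `nn.Linear` experts, made-up sizes, and a decoding-style single sequence, which is the regime where skipping experts helps most.

```python
import time

import torch
from torch import nn

# Toy setup: many experts, few tokens (a decoding-style case).
d_model, num_experts, batch, seq_len = 64, 128, 1, 8
experts = nn.ModuleDict({f"expert_{i}": nn.Linear(d_model, d_model) for i in range(num_experts)})
hidden_states = torch.randn(batch, seq_len, d_model)
expert_index = torch.randint(num_experts, (batch, seq_len))
router_mask = nn.functional.one_hot(expert_index, num_classes=num_experts).bool()


def dense_dispatch():
    # Old behaviour: every expert is called, even when it received no tokens.
    next_states = hidden_states.clone()
    for idx in range(num_experts):
        token_mask = router_mask[:, :, idx]
        next_states[token_mask] = experts[f"expert_{idx}"](hidden_states[token_mask])
    return next_states


def sparse_dispatch():
    # PR behaviour: only experts with at least one routed token are called.
    next_states = hidden_states.clone()
    tokens_per_expert = router_mask.reshape(-1, num_experts).sum(dim=0)
    for idx in torch.nonzero(tokens_per_expert, as_tuple=True)[0].tolist():
        token_mask = router_mask[:, :, idx]
        next_states[token_mask] = experts[f"expert_{idx}"](hidden_states[token_mask])
    return next_states


assert torch.allclose(dense_dispatch(), sparse_dispatch())  # same output, fewer expert calls
for fn in (dense_dispatch, sparse_dispatch):
    start = time.perf_counter()
    for _ in range(10):
        fn()
    print(fn.__name__, f"{time.perf_counter() - start:.4f}s")
```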
12 changes: 6 additions & 6 deletions tests/models/roc_bert/test_tokenization_roc_bert.py
@@ -73,7 +73,7 @@ def test_full_tokenizer(self):
def test_chinese(self):
tokenizer = RoCBertBasicTokenizer()

self.assertListEqual(tokenizer.tokenize("ah\u535A\u63A8zz"), ["ah", "\u535A", "\u63A8", "zz"])
self.assertListEqual(tokenizer.tokenize("ah\u535a\u63a8zz"), ["ah", "\u535a", "\u63a8", "zz"])

# Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_basic_tokenizer_lower with BasicTokenizer->RoCBertBasicTokenizer
def test_basic_tokenizer_lower(self):
@@ -82,7 +82,7 @@ def test_basic_tokenizer_lower(self):
self.assertListEqual(
tokenizer.tokenize(" \tHeLLo!how \n Are yoU? "), ["hello", "!", "how", "are", "you", "?"]
)
self.assertListEqual(tokenizer.tokenize("H\u00E9llo"), ["hello"])
self.assertListEqual(tokenizer.tokenize("H\u00e9llo"), ["hello"])

# Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_basic_tokenizer_lower_strip_accents_false with BasicTokenizer->RoCBertBasicTokenizer
def test_basic_tokenizer_lower_strip_accents_false(self):
@@ -91,7 +91,7 @@ def test_basic_tokenizer_lower_strip_accents_false(self):
self.assertListEqual(
tokenizer.tokenize(" \tHäLLo!how \n Are yoU? "), ["hällo", "!", "how", "are", "you", "?"]
)
self.assertListEqual(tokenizer.tokenize("H\u00E9llo"), ["h\u00E9llo"])
self.assertListEqual(tokenizer.tokenize("H\u00e9llo"), ["h\u00e9llo"])

# Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_basic_tokenizer_lower_strip_accents_true with BasicTokenizer->RoCBertBasicTokenizer
def test_basic_tokenizer_lower_strip_accents_true(self):
@@ -100,7 +100,7 @@ def test_basic_tokenizer_lower_strip_accents_true(self):
self.assertListEqual(
tokenizer.tokenize(" \tHäLLo!how \n Are yoU? "), ["hallo", "!", "how", "are", "you", "?"]
)
self.assertListEqual(tokenizer.tokenize("H\u00E9llo"), ["hello"])
self.assertListEqual(tokenizer.tokenize("H\u00e9llo"), ["hello"])

# Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_basic_tokenizer_lower_strip_accents_default with BasicTokenizer->RoCBertBasicTokenizer
def test_basic_tokenizer_lower_strip_accents_default(self):
@@ -109,7 +109,7 @@ def test_basic_tokenizer_lower_strip_accents_default(self):
self.assertListEqual(
tokenizer.tokenize(" \tHäLLo!how \n Are yoU? "), ["hallo", "!", "how", "are", "you", "?"]
)
self.assertListEqual(tokenizer.tokenize("H\u00E9llo"), ["hello"])
self.assertListEqual(tokenizer.tokenize("H\u00e9llo"), ["hello"])

# Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_basic_tokenizer_no_lower with BasicTokenizer->RoCBertBasicTokenizer
def test_basic_tokenizer_no_lower(self):
@@ -164,7 +164,7 @@ def test_is_whitespace(self):
self.assertTrue(_is_whitespace("\t"))
self.assertTrue(_is_whitespace("\r"))
self.assertTrue(_is_whitespace("\n"))
self.assertTrue(_is_whitespace("\u00A0"))
self.assertTrue(_is_whitespace("\u00a0"))

self.assertFalse(_is_whitespace("A"))
self.assertFalse(_is_whitespace("-"))