
Significant performance improvement on MoE block of SwitchTransformer #30490

Open · wants to merge 3 commits into base: main

Conversation

ranggihwang

What does this PR do?

This PR provides a more performant implementation of SwitchTransformersSparseMLP in the Google SwitchTransformer.
The current implementation iterates over all possible experts, including the inactive ones:

for idx, expert in enumerate(self.experts.values()):
    token_indices = router_mask[:, :, idx].bool()
    next_states[token_indices] = expert(hidden_states[token_indices]).to(next_states.dtype)

This results in serious performance degradation of the SwitchTransformer.

[Screenshot, 2024-04-26] As shown in this figure, the current implementation of the SwitchTransformer runs through inactive experts, unnecessarily increasing latency.

[Screenshot, 2024-04-26] This issue can be particularly severe in models with a larger number of experts, since even more unused experts are traversed.

My custom implementation of SwitchTransformersSparseMLP, in contrast, accesses and computes only the experts that the router actually selected; a rough sketch of the idea follows.
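
For illustration, here is a minimal, self-contained sketch of the idea with toy sizes (the variable names mirror the snippet above, but this is not the exact code of the PR): only experts that received at least one token in the batch are visited.

# Sketch only, not the PR's exact code: dispatch tokens only to experts that the
# router actually selected, instead of looping over every expert.
import torch
import torch.nn as nn

batch, seq_len, d_model, num_experts = 2, 8, 16, 8
experts = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(num_experts)])

hidden_states = torch.randn(batch, seq_len, d_model)
# One-hot top-1 routing mask of shape (batch, seq_len, num_experts), as a router would produce.
router_mask = torch.nn.functional.one_hot(
    torch.randint(num_experts, (batch, seq_len)), num_experts
)
next_states = hidden_states.clone()

# Indices of experts that received at least one routed token in this batch.
active_experts = router_mask.sum(dim=(0, 1)).nonzero(as_tuple=True)[0]
for idx in active_experts.tolist():
    token_indices = router_mask[:, :, idx].bool()
    next_states[token_indices] = experts[idx](hidden_states[token_indices]).to(next_states.dtype)

With a top-1 router, the loop body now runs once per active expert rather than num_experts times, which is where the latency savings come from when only a few experts are hit.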

Advantages

  • This can significantly reduce the latency of the SwitchTransformer and make the model more accessible to a broader range of users.
  • This change achieves greater latency reductions when expert parameters are offloaded to the CPU or SSD.
  • This change addresses the problem of increasing latency proportional to the number of experts.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@amyeroberts (Collaborator)

cc @ArthurZucker @younesbelkada

@younesbelkada (Contributor) left a comment

Nice work and great investigation! Thanks for this! Can you confirm the slow SwitchTransformers tests pass?

Commits added to the branch:
  …ansformers.py
  delete old switchtransformer
  Co-authored-by: Younes Belkada <[email protected]>
@ranggihwang (Author) commented Apr 26, 2024

Is there anything else that I need to do for this PR?

@younesbelkada (Contributor) left a comment

Hi @ranggihwang
Yes, please make sure to run make fixup so that the styling checks pass in our CI. For running the slow tests, can you run:

RUN_SLOW=1 pytest tests/models/switch_transformers/test_modeling_switch_transformers.py

@ranggihwang (Author) commented Apr 27, 2024

@younesbelkada
I've changed the coding style using make fixup, but I'm encountering an error when running this command:

RUN_SLOW=1 pytest tests/models/switch_transformers/test_modeling_switch_transformers.py

I've noticed the same error with the original SwitchTransformers code, so I assume it's not caused by my changes. How can I resolve this issue?

If you need the error log, I can attach it here.

@ArthurZucker (Collaborator) left a comment

Hey! Though we do iterate over all experts, the way the model is trained encourages even load balancing, meaning that on average all experts should be used, no?
Could you share a bit more about which model you are using, and whether you are pretraining?

Let's also revert the linting-related changes that are not supposed to be part of this PR.

@ranggihwang (Author)

@ArthurZucker
That's a great point.
In the SwitchTransformer model, all experts can be used during training, but at inference only a subset of them is active for any given batch; that is the inefficiency the original code misses.
I'm mainly talking about inference, but my custom code can also be used for training, because it is mathematically equivalent to the original code (see the sketch below).
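
As a quick, hypothetical sanity check of the equivalence claim (not part of this PR), one can run both dispatch strategies on the same random inputs and compare the outputs; experts that receive no tokens contribute nothing to the result, so skipping them cannot change it.

# Hypothetical equivalence check, not part of the PR: the all-experts loop and the
# active-experts-only loop produce identical outputs for a top-1 router.
import torch
import torch.nn as nn

torch.manual_seed(0)
batch, seq_len, d_model, num_experts = 2, 8, 16, 8
experts = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(num_experts)])
hidden_states = torch.randn(batch, seq_len, d_model)
router_mask = torch.nn.functional.one_hot(
    torch.randint(num_experts, (batch, seq_len)), num_experts
)

def dispatch(active_only: bool) -> torch.Tensor:
    out = hidden_states.clone()
    if active_only:
        indices = router_mask.sum(dim=(0, 1)).nonzero(as_tuple=True)[0].tolist()
    else:
        indices = range(num_experts)  # original behavior: visit every expert
    for idx in indices:
        mask = router_mask[:, :, idx].bool()
        out[mask] = experts[idx](hidden_states[mask]).to(out.dtype)
    return out

assert torch.allclose(dispatch(active_only=False), dispatch(active_only=True))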

For the linting changes, is it okay for me to revert them to the previous version? This is actually my first time contributing code to Hugging Face, so I'm a bit unsure about what to do next to get the PR accepted for merging.
