
MPS support #790

Open · wants to merge 4 commits into main
Conversation

@maximegmd (Author)

Context

  • For testing purposes, it can be useful to run directly on a local Mac.

Changelog

  • Checks for BF16 support on the MPS device (see the sketch after this list).
  • Added a configuration targeting MPS; changes to paths were required because macOS resolves symlinks in /tmp as /private/ActualPath.
  • Set the optimizer to Adam instead of AdamW to fit in memory on 64 GB devices.
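A minimal sketch of the first two points; the helper name mps_bf16_supported is my own and may differ from torchtune's actual utility:

```python
import os
import torch

def mps_bf16_supported() -> bool:
    """Probe BF16 support by attempting a small allocation on the MPS device."""
    if not torch.backends.mps.is_available():
        return False
    try:
        torch.zeros(1, dtype=torch.bfloat16, device="mps")
        return True
    except (TypeError, RuntimeError):
        return False

# The path quirk from the second bullet: on macOS, /tmp is a symlink into
# /private, so a resolved path comes back with the /private prefix.
print(os.path.realpath("/tmp"))  # "/private/tmp" on macOS, "/tmp" elsewhere
```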

Test plan

  • Ran a full-finetune training job on Mistral 7B.
  • The current test jobs are very CUDA-specific; maybe this could be changed as well?
  • We may need to integrate Mac runners into the pipeline.
  • Some dependencies such as bitsandbytes are not yet Mac compatible.


pytorch-bot bot commented Apr 18, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/790

Note: Links to docs will display an error until the docs builds have been completed.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label (authors need to sign the CLA before a PR can be reviewed) on Apr 18, 2024.
@joecummings (Contributor)

@maximegmd This is awesome! Can you post some loss curves for the finetune you ran?

@maximegmd (Author)

> @maximegmd This is awesome! Can you post some loss curves for the finetune you ran?

I will complete a run over the weekend; losses looked fine, but the Llama 3 release changed my priorities ^^

@kartikayk (Contributor) left a comment

Thanks so much for making this change; supporting MPS has been on our TODO list!

I'm a bit confused about how this is working: the device param is used to fetch the device via this utility function, which in turn depends on this function. We seemingly never actually return mps as the device, so how is this working? I think this will just default to CPU.

@maximegmd (Author)

device is not None when this function is called, so it just passes 'mps' to torch.device(), which is the expected PyTorch device name.

But you are correct that there is room for improvement: we could automatically return mps when device is not manually specified in the config (roughly as sketched below).
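A rough sketch of the resolution logic being discussed; the function name and fallback order are assumptions, not torchtune's exact implementation:

```python
from typing import Optional

import torch

def get_device(device: Optional[str] = None) -> torch.device:
    # An explicit device string such as "mps" is handed straight to
    # torch.device(); auto-detection only runs when device is None.
    if device is None:
        # Currently the fallback is CUDA-or-CPU; checking
        # torch.backends.mps.is_available() here is the improvement
        # mentioned above for Apple silicon.
        device = "cuda" if torch.cuda.is_available() else "cpu"
    return torch.device(device)
```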

@kartikayk (Contributor)

Oh, good point; I totally glossed over the fact that device = torch.device(device) is outside the if block. Yup, sounds good. What's the iter/sec you're getting with this on a Mac?

@maximegmd (Author)

If I recall correctly, it was around 20 s/it, but I suspect I was swapping a bit, so I can probably improve the speed. The main issue is bitsandbytes not supporting MPS, so the optimizer state uses quite a bit of memory (see the rough estimate below).
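A back-of-the-envelope estimate (my own numbers, not from the PR) of why the optimizer state is heavy without bitsandbytes' 8-bit Adam:

```python
# Plain Adam keeps two fp32 state tensors (exp_avg, exp_avg_sq) per parameter.
n_params = 7e9           # Mistral 7B, full finetune
state_bytes = 2 * 4      # two fp32 tensors per parameter
print(f"{n_params * state_bytes / 2**30:.0f} GiB")  # ~52 GiB for Adam state alone
```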

I will try to push a Llama 3 config tomorrow with some numbers, now that my Llama 3 finetune is running :)

@maximegmd (Author)

Here is a training run on Gemma 2B. Sadly, the laptop went to sleep right before the end, but this is a 14-hour run, so it should be representative enough.
log_1715964526.txt
