[Performance] PyTorch (MPS) is faster than MLX in backward of convolution layer #1313

Open
arnold-yan opened this issue Aug 8, 2024 · 6 comments

Comments

@arnold-yan

arnold-yan commented Aug 8, 2024

Describe the bug
Recently I profiled neural network layer performance in MLX and compared it with PyTorch. Although the MLX forward pass is consistently faster than PyTorch's, on some chips (M1 Pro, M1 Max) PyTorch is much faster (3x to 6x) for convolution forward + backward. On other chips, such as the M3 Max, MLX is faster than PyTorch.
image

To Reproduce
To reproduce this, I have attached two minimal examples. Each network consists of just a few convolution layers. You can run the two scripts to verify the performance.

time_pytorch_mlx.zip
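
The attached scripts are not reproduced in this thread. As a rough illustration of how an "average time" figure like the ones below might be measured, here is a minimal sketch of a timing helper. The names `average_time`, `step`, `sync`, `warmup`, and `iters` are hypothetical and not taken from the zip; the actual scripts may measure differently.

```python
import time
from typing import Callable, Optional


def average_time(
    step: Callable[[], object],
    sync: Optional[Callable[[], object]] = None,
    warmup: int = 3,
    iters: int = 10,
) -> float:
    """Return the mean wall-clock time, in seconds, of one call to `step`.

    `step` is assumed to run one forward + backward pass.
    `sync` is an optional callable that blocks until queued GPU work finishes.
    """

    def run_once() -> None:
        step()
        if sync is not None:
            sync()

    # Warm-up iterations are excluded from the measurement so that one-time
    # costs (kernel compilation, allocations) do not skew the average.
    for _ in range(warmup):
        run_once()

    start = time.perf_counter()
    for _ in range(iters):
        run_once()
    return (time.perf_counter() - start) / iters
```

Both frameworks can return before the GPU has finished, so the measurement is only meaningful if the work is forced to complete inside the timed region. For MLX, which evaluates lazily, `step` would need to call `mx.eval` on the loss and gradients; for PyTorch on MPS, `torch.mps.synchronize` could be passed as `sync`.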

@awni
Member

awni commented Aug 8, 2024

Same benchmark on an M2 Ultra

average time of Pytorch: 7.20261025428772
average time of MLX: 2.34059739112854

@abdussamettrkr
Contributor

On M2 Pro

  • average time of MLX: 4.472527980804443
  • average time of Pytorch: 10.073836088180542

@alwint3r

On M3 Max

  • average time of MLX: 2.4813356399536133
  • average time of PyTorch: 7.1081931591033936

@awni
Member

awni commented Aug 13, 2024

Thanks for the benchmarks everyone! There is clearly an unexpected performance cliff on M1 machines here as MLX is substantially faster on M2+. We'll need to take a deeper look at that to figure out where it's coming from.

@pyvadev

pyvadev commented Sep 13, 2024

On M1

  • average time of MLX: 30.113215446472168
  • average time of Pytorch: 15.948616743087769

@jrp2014

jrp2014 commented Sep 13, 2024

M3 Max:
average time of MLX: 2.939736843109131
average time of Pytorch: 5.9829957485198975
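
For reference, the ratios implied by the numbers reported in this thread can be tabulated with a short script. The chip labels and values are copied from the comments above; the time unit is whatever the benchmark script prints, which the thread does not state. The helper name `pytorch_over_mlx` is just for illustration.

```python
# (chip, average time of MLX, average time of PyTorch), as reported above.
reported = [
    ("M2 Ultra", 2.34059739112854, 7.20261025428772),
    ("M2 Pro", 4.472527980804443, 10.073836088180542),
    ("M3 Max (alwint3r)", 2.4813356399536133, 7.1081931591033936),
    ("M1", 30.113215446472168, 15.948616743087769),
    ("M3 Max (jrp2014)", 2.939736843109131, 5.9829957485198975),
]


def pytorch_over_mlx(mlx_time: float, pytorch_time: float) -> float:
    """Ratio above 1 means MLX is faster; below 1 means PyTorch is faster."""
    return pytorch_time / mlx_time


ratios = {chip: pytorch_over_mlx(m, p) for chip, m, p in reported}

for chip, ratio in ratios.items():
    faster = "MLX" if ratio > 1 else "PyTorch"
    print(f"{chip}: PyTorch/MLX = {ratio:.2f} ({faster} faster)")
```

On these numbers, MLX is roughly 2x to 3x faster on the M2 and M3 machines, while on the single M1 result MLX takes nearly twice as long as PyTorch.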
