
Low performance of mx.fast.metal_kernel when calling the same kernel > 10,000 times #1828

Closed
spichardo opened this issue Feb 3, 2025 · 4 comments


@spichardo

We need some help improving the performance of mx.fast.metal_kernel. We have a turn-key tool for treatment planning of transcranial ultrasound stimulation called BabelBrain that uses the GPU extensively to accelerate calculations. The tool supports multiple GPU backends (Metal, CUDA, and OpenCL), with Metal being the preferred one since most of our users in the neurosciences are Mac users. We use a lot of GPU kernels for this.

Historically, we used the py-metal-compute library to code all our GPU kernels, and we got great results even for the most demanding kernels (some of them matching the performance of an A6000 GPU).

We started porting from py-metal-compute to MLX's mx.fast.metal_kernel because we want, in the long run, to take advantage of all the other functionality that MLX offers, especially for the ML extensions we will add to our tools.

Most simple kernels run equally well with both libraries, but we see significant performance differences in the most compute-intensive functions, which solve PDEs with Finite-Difference Time-Domain (FDTD) solvers that call the same GPU kernels tens of thousands of times. In such cases, MLX runs 2.5x slower than py-metal-compute for small matrix sizes, and it gets much worse as the matrix size increases. We suspect that the lazy-execution model of MLX is having some impact on this kind of workload.

The difference in execution models shows up with MLX: the 12,000 iterations through the kernel calls appear to run faster, but performance takes a big hit when the results are recovered into NumPy arrays. We need to interface with NumPy arrays because this code is part of much larger software for the treatment planning of transcranial ultrasound stimulation. py-metal-compute takes more time to complete the 12,000 kernel calls, but the transfer to NumPy arrays is very fast, translating into an overall wall time less than half that of MLX.
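To make the pattern concrete, here is a minimal sketch of the kind of loop involved (the kernel body, names, and sizes below are placeholders for illustration; the actual FDTD kernels are in the repo linked below):

```python
import mlx.core as mx
import numpy as np

# Placeholder single-step kernel; the real FDTD kernels live in the BHTE_MLX repo.
source = """
    uint i = thread_position_in_grid.x;
    out[i] = inp[i] + T(1);
"""
kernel = mx.fast.metal_kernel(
    name="fdtd_step",
    input_names=["inp"],
    output_names=["out"],
    source=source,
)

d_T1 = mx.zeros((256, 256), dtype=mx.float32)
for n in range(12000):
    # Each call only records a node in MLX's lazy graph; nothing runs yet.
    d_T1 = kernel(
        inputs=[d_T1],
        template=[("T", mx.float32)],
        grid=(d_T1.size, 1, 1),
        threadgroup=(256, 1, 1),
        output_shapes=[d_T1.shape],
        output_dtypes=[d_T1.dtype],
    )[0]

# The full 12,000-step graph is built, executed, and copied back only here.
result = np.array(d_T1)
```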

Because BabelBrain is a fairly complex tool, we created a stand-alone repository, BHTE_MLX, that reproduces the issue. The repo is self-contained; all the code is there.

Do you have any suggestions we could explore to tweak mx.fast.metal_kernel, such as controlling how the lazy model works? py-metal-compute offers simpler control of kernels, where we can decide when the command buffer is completed, so I wonder if we could do something similar; we are just not sure how to control kernel execution in MLX. Any suggestions are welcome. Our FDTD kernels are the last piece we must sort out before migrating completely to MLX.

@awni
Member

awni commented Feb 4, 2025

One problem is that you are not evaluating the graph until the very end of the computation, which is really slow since the graph gets very large and uses a lot of memory.

My first recommendation is to make sure you read how lazy evaluation works in MLX; check out the documentation.

  • The call to mx.synchronize is basically a no-op in your code (and probably slowing things down). It is really only intended to be used with async_eval, so I would just remove it.
  • Instead, replace it with something like:

```python
if n % 10 == 0:
    mx.eval(d_T1)
```

That really sped things up for me.
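In a loop like the sketch earlier in the thread, the change amounts to something like this (same placeholder names):

```python
for n in range(12000):
    d_T1 = kernel(
        inputs=[d_T1],
        template=[("T", mx.float32)],
        grid=(d_T1.size, 1, 1),
        threadgroup=(256, 1, 1),
        output_shapes=[d_T1.shape],
        output_dtypes=[d_T1.dtype],
    )[0]
    if n % 10 == 0:
        # Evaluate periodically so the lazy graph never grows too large.
        mx.eval(d_T1)
```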

@spichardo
Author

Thanks for taking a look, but I can't see an improvement on my end.

I just replaced the mx.synchronize call with mx.eval every 10 steps, but results remain slow. It seems that forcing mx.eval every X steps helps reduce the latency of getting the final results, but it considerably slows down the for loop.

These are the computing times with different frequencies of calling mx.eval(d_T1). The wall time includes both the loop and recovering all matrices. The time to execute the for loop and the final data transfer changes depending on the setting, but the faster the results are recovered, the slower the for loop becomes, nullifying any improvement. These are my findings on an M3 Max:

| When mx.eval(d_T1) is called | Time loop (s) | Time to recover results (s) | Total time (s) |
|---|---|---|---|
| Never | 0.075 | 11.4 | 11.56 |
| n % 10 == 0 | 11.9 | 0.01 | 11.9 |
| n % 100 == 0 | 11.3 | 0.09 | 11.4 |
| n % 500 == 0 | 10.89 | 0.45 | 11.4 |
| n % 1000 == 0 | 10.5 | 0.9 | 11.4 |
| n % 5000 == 0 | 9.5 | 1.8 | 11.34 |

Did you test with MLX v0.22.0?

@awni
Member

awni commented Feb 4, 2025

Huh... it looks like setting init_value is causing a pretty extreme performance regression. I just removed init_value=0 from both kernel calls and it sped things up by 3x (and it's about 2x faster than the py-metal-compute implementation, including the eval change from above).

Since I don't think you need to zero-initialize the data, that should solve your problem. I'll look into why it's so much slower, though.
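For reference, init_value is the optional argument of the metal_kernel call that initializes the output arrays to a constant before the kernel runs. Using the placeholder call from the sketches above, the change looks like this:

```python
# Common arguments for the placeholder kernel call from the earlier sketches.
args = dict(
    inputs=[d_T1],
    template=[("T", mx.float32)],
    grid=(d_T1.size, 1, 1),
    threadgroup=(256, 1, 1),
    output_shapes=[d_T1.shape],
    output_dtypes=[d_T1.dtype],
)

# Before: each of the ~12,000 calls pays for zero-filling its output first.
d_T1 = kernel(**args, init_value=0)[0]

# After: omit init_value when the kernel overwrites every output element anyway.
d_T1 = kernel(**args)[0]
```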

@spichardo
Author

spichardo commented Feb 4, 2025

That did the trick!

On my M3 Max it didn't get quite 2x faster than py-metal-compute, but the total time dropped to 3.31s with MLX vs 4.6s with py-metal-compute, so about 30% better than the reference. This is great, thanks so much!

awni closed this as completed Feb 4, 2025