Low performance of mx.fast.metal_kernel when calling the same kernel > 10000x times #1828
One problem is that you are not evaluating the graph until the very end of the computation, which is really slow since the graph gets really big and uses a lot of memory. My first recommendation is to make sure you read how lazy evaluation works in MLX; check out the documentation.
```python
if n % 10 == 0:
    mx.eval(d_T1)
```
That really sped things up for me.
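Here's a minimal self-contained sketch of the pattern; the toy `add_one` kernel and the step count are just illustrative stand-ins for the FDTD kernels in your repo:

```python
import mlx.core as mx

# Toy kernel standing in for the FDTD update step.
kernel = mx.fast.metal_kernel(
    name="add_one",
    input_names=["a"],
    output_names=["out"],
    source="""
        uint elem = thread_position_in_grid.x;
        out[elem] = a[elem] + 1.0f;
    """,
)

a = mx.zeros((4096,), dtype=mx.float32)
for n in range(12000):
    a = kernel(
        inputs=[a],
        grid=(a.size, 1, 1),
        threadgroup=(256, 1, 1),
        output_shapes=[a.shape],
        output_dtypes=[a.dtype],
    )[0]
    if n % 10 == 0:
        mx.eval(a)  # evaluate periodically so the lazy graph stays small
mx.eval(a)  # materialize the final result
```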
Thanks for taking a look, but I can't see an improvement on my end. I just replaced the

These are the computing times with different points to call `mx.eval`:

No call to `mx.eval`: time in loop 0.075 s, time to recover results 11.4 s, total time 11.56 s.

Did you test with MLX v0.22.0?
Huh, it looks like setting the kernel to zero-initialize its outputs is what's slow. Since I don't think you need to zero-initialize the data, dropping that should solve your problem. I'll look into why it's so much slower, though.
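If the zero initialization is being requested through the kernel call's `init_value` argument, the change might look roughly like this (the toy `copy_kernel` is only for illustration, not the code from the repo):

```python
import mlx.core as mx

copy_kernel = mx.fast.metal_kernel(
    name="copy",
    input_names=["a"],
    output_names=["out"],
    source="""
        uint i = thread_position_in_grid.x;
        out[i] = a[i];
    """,
)

a = mx.ones((4096,), dtype=mx.float32)
args = dict(
    inputs=[a],
    grid=(a.size, 1, 1),
    threadgroup=(256, 1, 1),
    output_shapes=[a.shape],
    output_dtypes=[a.dtype],
)

# Slower: every launch pays for an extra pass that fills the output with zeros.
slow = copy_kernel(**args, init_value=0)[0]

# Faster: skip init_value when the kernel writes every output element anyway.
fast = copy_kernel(**args)[0]

mx.eval(slow, fast)
```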
That did the trick! On my M3 Max it didn't quite reach 2x faster than py-metal-compute, but the total time dropped to 3.31 s with MLX vs 4.6 s with py-metal-compute, so about 30% better than the reference. This is great, thanks so much!
We need some help improving the performance of `mx.fast.metal_kernel`. We have a turn-key tool for treatment planning of transcranial ultrasound stimulation called BabelBrain that uses the GPU extensively to accelerate calculations. The tool is a multi-GPU platform (Metal, CUDA and OpenCL), with Metal being the preferred backend since most of our users in neuroscience are Mac users. We use a lot of GPU kernels for this.
Historically, we used the py-metal-compute library to code all our GPU kernels, and we got great results even for the most demanding kernels (some of them matching the performance of an A6000 GPU).
We started porting from py-metal-compute to MLX's `mx.fast.metal_kernel` because we want, in the long run, to take advantage of all the other functionality MLX offers, especially for the ML extensions we will add to our tools.
Most simple kernels run equally well with both libraries, but we see important performance differences in the most intensive functions, which solve PDEs with Finite-Difference Time-Domain (FDTD) solvers that call the same GPU kernels tens of thousands of times. In such cases, MLX runs 2.5x slower than py-metal-compute for small matrix sizes, and it gets much worse as the matrix size increases. We suspect MLX's lazy-execution model is having an impact on this kind of workload.
The difference in execution model shows up with MLX: the 12000 kernel calls appear to iterate faster, but performance takes a hit when the results are recovered into NumPy arrays. We need to interface with NumPy arrays because this code is part of much larger software for the treatment planning of transcranial ultrasound stimulation. py-metal-compute takes more time to complete the 12000 kernel calls, but the transfer to NumPy arrays is very quick, so its overall wall time is less than half that of MLX.
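A minimal illustration of the behavior we believe we are seeing (a toy expression, not our actual kernels): nothing runs until the result is pulled back to NumPy, so the conversion step absorbs the cost of the whole deferred graph.

```python
import mlx.core as mx
import numpy as np

x = mx.random.uniform(shape=(1024, 1024))
y = x
for _ in range(100):
    y = y * 1.001 + 0.001   # lazy: each step only extends the graph

result = np.array(y)        # conversion to NumPy forces evaluation of everything above
```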
Because BabelBrain is a fairly complex tool, we created a stand-alone repository, BHTE_MLX, that reproduces the issue. The repo is self-contained; all the code is there.
Do you have any suggestions we could explore to tweak `mx.fast.metal_kernel`, such as controlling how the lazy model works? py-metal-compute gives more simplistic control of kernels, where we can control when the command buffer is completed, so we wonder if we could do something similar, but we are not sure how to tune kernel execution in MLX. Any suggestions are welcome. Our FDTD kernels are the last piece we must sort out before migrating completely to MLX.
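For reference, from the documentation it looks like `mx.eval` (dispatch the pending graph) and `mx.synchronize` (wait for the submitted GPU work to finish) may be the closest analogs to completing a command buffer; a minimal sketch of what we imagine:

```python
import mlx.core as mx

x = mx.random.uniform(shape=(2048, 2048))
y = x @ x            # still lazy at this point
mx.eval(y)           # dispatch the pending work, roughly like committing a command buffer
mx.synchronize()     # block until all submitted GPU work has completed
```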