
DMFF Converts all of the JAX Code to Float64 #194

Open
jacob-jacob-jacob opened this issue Mar 12, 2025 · 0 comments
Labels
wontfix This will not be worked on

Summary

Thank you for the great library!

DMFF sets JAX to use float64. Therefore, any code using DMFF has to be refactored if one wants to keep using float32.

DMFF Version

--

JAX Version

0.4.20

OpenMM Version

--

Python Version, CUDA Version, GCC Version, Operating System Version etc

No response

Details

DMFF sets JAX to use float64 globally.
Therefore, any allocation (jnp.zeros, random.normal, jnp.linspace, ...) will from then on return float64 arrays instead of JAX's default float32, which is also the most common choice for neural networks in general.
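
A minimal sketch of the effect (assuming DMFF enables JAX's global x64 flag at import time, e.g. via `jax.config.update("jax_enable_x64", True)`; the exact mechanism is not confirmed here):

```python
import jax.numpy as jnp

print(jnp.zeros(3).dtype)  # float32 -- JAX's default

import dmff  # noqa: F401  -- importing DMFF flips the global x64 flag as a side effect

print(jnp.zeros(3).dtype)  # float64 -- every new allocation is now double precision
```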

To keep using your old code base with float32 instead of float64 (float64 would slow down neural-network training a lot), the full code base has to be refactored after importing dmff: every single allocation statement has to be found and an explicit dtype=jnp.float32 has to be added.
This procedure is quite error-prone, and if even one of these statements is overlooked, your full code, or parts of it, silently uses float64 and the performance of your algorithms degrades.

Is there any way to make this more user-friendly?
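
As a stop-gap, one untested workaround (assuming DMFF only flips the flag once at import time and requests float64 explicitly wherever its own routines need it) is to reset the flag right after the import:

```python
import jax
import jax.numpy as jnp

import dmff  # noqa: F401  -- enables x64 globally as a side effect

# Restore JAX's float32 default for the rest of the program.
# Caveat: DMFF's own computations may depend on the global flag,
# so they could silently lose precision after this reset.
jax.config.update("jax_enable_x64", False)

print(jnp.zeros(3).dtype)  # float32 again
```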

@jacob-jacob-jacob added the wontfix label Mar 12, 2025