DMFF sets JAX to use float64. Therefore, any code using DMFF has to be refactored if one wants to keep using float32.
Summary
Thank you for the great library!
DMFF Version
--
JAX Version
0.4.20
OpenMM Version
--
Python Version, CUDA Version, GCC Version, Operating System Version, etc.
No response
Details
DMFF sets JAX to use `float64` globally as soon as it is imported.
Therefore, any allocation (e.g. `jnp.zeros`, `jax.random.normal`, `jnp.linspace`, ...) will from then on return `float64` arrays instead of JAX's default `float32`, which is also the most common precision for neural networks in general.
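A minimal sketch of the reported behavior, assuming `dmff` enables JAX's 64-bit ("x64") mode as a side effect of the import, as described above:

```python
import jax.numpy as jnp

print(jnp.zeros(3).dtype)               # float32 -- JAX's usual default
print(jnp.linspace(0.0, 1.0, 5).dtype)  # float32

import dmff  # noqa: F401 -- enables 64-bit mode as a side effect of the import

print(jnp.zeros(3).dtype)               # float64 from here on
print(jnp.linspace(0.0, 1.0, 5).dtype)  # float64
```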
To keep using an existing code base with `float32` instead of `float64` (`float64` would slow down neural-network training considerably), the full code base has to be refactored after importing `dmff`: every single allocation statement has to be found and an explicit `dtype=jnp.float32` has to be added.
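For reference, this is the kind of change that has to be applied at every allocation site:

```python
import jax
import jax.numpy as jnp

# Before: relies on the global default, so this array silently becomes
# float64 once dmff has been imported somewhere in the program.
x = jnp.zeros((128, 64))

# After: the dtype is pinned explicitly and stays float32 regardless of the
# global x64 setting. The same change is needed for jnp.ones, jnp.linspace,
# jnp.arange, jax.random.normal, and so on.
x = jnp.zeros((128, 64), dtype=jnp.float32)

key = jax.random.PRNGKey(0)
noise = jax.random.normal(key, (128, 64), dtype=jnp.float32)
```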
This procedure is quite error-prone: if even a single one of these statements is overlooked, the whole code base, or parts of it, may silently run in `float64` and slow down your algorithms.
Is there any way to make this more user-friendly?
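For what it's worth, the only workaround I can think of is to flip the flag back after the import. This is an untested sketch, not a recommendation, since DMFF itself presumably relies on double precision:

```python
import dmff  # noqa: F401 -- the import switches on x64 mode
import jax
import jax.numpy as jnp

# Untested workaround: restore JAX's float32 default for the rest of the
# program. Caution: DMFF internals that expect float64 may misbehave or
# raise dtype errors after this.
jax.config.update("jax_enable_x64", False)

print(jnp.ones(3).dtype)  # float32 again
```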