Coordinax enables calculations with coordinates in JAX. It is built on Equinox and Quax.
pip install coordinax
Documentation is coming soon. In the meantime, if you've used astropy.coordinates, then coordinax should be fairly intuitive.
import coordinax as cx
import jax.numpy as jnp
from unxt import Quantity
# Build ten 3D Cartesian positions in km. The y and z components are the
# x components shifted by 5 km and 10 km, respectively.
q = cx.CartesianPos3D(
    x=Quantity(jnp.arange(10.0), "km"),
    y=Quantity(jnp.arange(10.0) + 5, "km"),
    z=Quantity(jnp.arange(10.0) + 10, "km"),
)
print(q)
# <CartesianPos3D (x[km], y[km], z[km])
# [[ 0. 5. 10.]
# [ 1. 6. 11.]
# ...
# [ 8. 13. 18.]
# [ 9. 14. 19.]]>

# Re-express the same points in spherical coordinates. In the output below,
# theta is the angle from the +z axis and phi the angle in the x-y plane
# (e.g. the first point (0, 5, 10) km gives theta = arccos(10/11.18)).
q2 = cx.represent_as(q, cx.SphericalPos)
print(q2)
# <SphericalPos (r[km], theta[rad], phi[rad])
# [[11.18 0.464 1.571]
# [12.57 0.505 1.406]
# ...
# [23.601 0.703 1.019]
# [25.259 0.719 0.999]]>
# Build ten 3D Cartesian velocities in m/s. They are numerically identical
# to the positions in `q`, so each velocity points straight along its
# position vector.
p = cx.CartesianVel3D(
    d_x=Quantity(jnp.arange(10.0), "m/s"),
    d_y=Quantity(jnp.arange(10.0) + 5, "m/s"),
    d_z=Quantity(jnp.arange(10.0) + 10, "m/s"),
)
print(p)
# <CartesianVel3D (d_x[m / s], d_y[m / s], d_z[m / s])
# [[ 0. 5. 10.]
# [ 1. 6. 11.]
# ...
# [ 8. 13. 18.]
# [ 9. 14. 19.]]>

# Converting a velocity needs the position it is attached to (`q`), because
# the spherical basis directions change from point to point. Since each
# velocity here is purely radial, d_r equals the speed and the angular rates
# below are zero up to floating-point round-off.
p2 = cx.represent_as(p, cx.SphericalVel, q)
print(p2)
# <SphericalVel (d_r[m / s], d_theta[m rad / (km s)], d_phi[m rad / (km s)])
# [[ 1.118e+01 -3.886e-16 0.000e+00]
# [ 1.257e+01 -1.110e-16 0.000e+00]
# ...
# [ 2.360e+01 0.000e+00 0.000e+00]
# [ 2.526e+01 -2.776e-16 0.000e+00]]>
If you find this library useful in academic work, please cite it.
We welcome contributions!