I would like to describe a use case for customizing both the JVP and the VJP of FFI calls in astrophysical imaging problems, in response to this note in the FFI documentation (https://docs.jax.dev/en/latest/ffi.html): "JAX doesn’t currently expose a public API for simultaneously customizing both forward-mode and reverse-mode AD, but such an API is on the roadmap, so please open an issue describing you use case if you hit this limitation in practice".
In the models that we have built for imaging black holes, radio galaxies, and our 3D galactic neighborhood, we use variational inference. Two variational inference schemes that we particularly like to use for approximating the true posterior distribution of black hole images (and the like) are MGVI and geoVI (https://arxiv.org/abs/1901.11033, https://arxiv.org/abs/2105.10470). Both use $J^\dagger N J$, with $J$ the Jacobian of our physics-informed model and $N$ some data-dependent matrix, to approximate the posterior uncertainty. Applying $J^\dagger N J$ to a vector takes a forward-mode product $Jv$ followed by a reverse-mode product with $J^\dagger$, so both modes of AD must work through every part of the model. For optimal memory utilization, we often end up putting custom code into our physics-informed models and thus require custom implementations of both $J$ and $J^\dagger$ to run our inference.
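A minimal sketch of applying $J^\dagger N J$ to a vector without ever forming $J$, using only the public `jax.jvp` and `jax.vjp`. The toy `model`, the assumption that $N$ is diagonal (stored as the vector `n_diag`), and the helper name `metric_matvec` are illustrative choices, not MGVI/geoVI code. The model is real-valued, so $J^\dagger = J^T$.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy forward model standing in for a physics-informed model:
# maps 4 latent parameters xi to 3 predicted data points.
def model(xi):
    return jnp.exp(xi[:3]) + jnp.sin(xi[1:])

def metric_matvec(f, xi, n_diag, v):
    """Apply J^T N J to v without materializing the Jacobian J of f at xi.

    Assumes (for this sketch) that N is diagonal with entries n_diag.
    """
    # Forward mode: J v
    _, jv = jax.jvp(f, (xi,), (v,))
    # Reverse mode: J^T (N J v)
    _, f_vjp = jax.vjp(f, xi)
    (out,) = f_vjp(n_diag * jv)
    return out

xi = jnp.array([0.1, -0.2, 0.3, 0.5])
n_diag = jnp.array([1.0, 2.0, 0.5])
v = jnp.array([1.0, 0.0, -1.0, 2.0])
result = metric_matvec(model, xi, n_diag, v)
print(result)
```

If `model` contained an opaque custom call, both the `jax.jvp` and the `jax.vjp` pass would need a custom rule for it, which is exactly the pair of rules this issue asks to customize.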
Our current workhorse for customizing both $J$ and $J^\dagger$ is https://github.com/NIFTy-PPL/JAXbind. It builds on JAX's custom primitive interface. Ideally, we would like to switch to JAX's FFI, but our need to evaluate both $J$ and $J^\dagger$ is blocking that.
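When $J$ and $J^\dagger$ are supplied as two separate custom implementations, they must be exact adjoints of each other, otherwise $J^\dagger N J$ stops being symmetric. A cheap sanity check is the dot-product (adjoint) test $\langle u, Jv\rangle = \langle J^\dagger u, v\rangle$ for random $u, v$. The sketch below assumes a real-valued toy model $f(x) = A\,x^2$ (elementwise square), with hypothetical hand-written `apply_jac` and `apply_jac_adjoint` playing the role of the custom forward and transpose rules; it does not use JAXbind's API.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy model f(x) = A @ x**2 with a fixed matrix A (3 x 4).
A = jnp.array([[1.0, 2.0, 0.0, -1.0],
               [0.5, 0.0, 3.0, 1.0],
               [0.0, 1.0, 1.0, 2.0]])

def model(x):
    return A @ x**2

# Hand-written J v (what a custom forward-mode rule would compute).
def apply_jac(x, v):
    return A @ (2.0 * x * v)

# Hand-written J^T u (what a custom transpose / VJP rule would compute).
def apply_jac_adjoint(x, u):
    return 2.0 * x * (A.T @ u)

def dot_product_test(x, u, v):
    """Return (<u, J v>, <J^T u, v>); they must agree up to rounding."""
    lhs = jnp.vdot(u, apply_jac(x, v))
    rhs = jnp.vdot(apply_jac_adjoint(x, u), v)
    return lhs, rhs

key_x, key_u, key_v = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(key_x, (4,))
u = jax.random.normal(key_u, (3,))
v = jax.random.normal(key_v, (4,))

lhs, rhs = dot_product_test(x, u, v)
# Cross-check the hand-written J v against JAX's own forward mode.
_, jv_ref = jax.jvp(model, (x,), (v,))
print(lhs, rhs, jnp.max(jnp.abs(apply_jac(x, v) - jv_ref)))
```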
I'll stop by to comment that there's no reason why you can't switch to backing your primitives with the XLA FFI! All of the core JAX custom calls have been migrated, and that's exactly what we do there. The point is that `core.Primitive` isn't a public JAX API, so this isn't advertised as a user-facing workflow. If you're already using it, though, nothing stops you from migrating the backend code.
Right, but all I meant was that you can still update your C++ custom calls to use the FFI while keeping the same `Primitive` shenanigans; you'd just need to write a new lowering rule! If you need to customize both modes of AD, you can't currently use the `jax.ffi.ffi_call` frontend, but porting the backend implementation can still be a beneficial first step.
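To make the public-API limitation concrete, here is a sketch in which `opaque_op` is a hypothetical stand-in for a `jax.ffi.ffi_call` (a real FFI call needs a compiled, registered target, so plain JAX is used instead to keep the example runnable). Giving it a hand-written VJP through `jax.custom_vjp` makes reverse mode work, but JAX rejects forward mode on `custom_vjp` functions, so the $Jv$ half of $J^\dagger N J$ is unavailable. The symmetric route, `jax.custom_jvp`, would instead require JAX to transpose the tangent rule, which it cannot do if that rule is itself an opaque call.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for an opaque FFI call; plain JAX so it runs anywhere.
@jax.custom_vjp
def opaque_op(x):
    return jnp.sin(x)

def opaque_op_fwd(x):
    # Primal output plus x saved as the residual for the backward pass.
    return opaque_op(x), x

def opaque_op_bwd(x, g):
    # Hand-written J^T g; in practice this would be another custom call.
    return (jnp.cos(x) * g,)

opaque_op.defvjp(opaque_op_fwd, opaque_op_bwd)

x = jnp.array([0.0, 0.5, 1.0])

# Reverse mode goes through the custom rule.
grad_val = jax.grad(lambda x: jnp.sum(opaque_op(x)))(x)

# Forward mode is rejected for custom_vjp functions.
try:
    jax.jvp(opaque_op, (x,), (jnp.ones_like(x),))
    jvp_supported = True
except Exception as err:
    jvp_supported = False
    print("jvp failed:", err)
```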
Related to NIFTy-PPL/JAXbind#39.