
FFI Custom JVP + VJP Support #27352

Open
Edenhofer opened this issue Mar 22, 2025 · 3 comments
Labels
enhancement New feature or request

Comments

@Edenhofer (Contributor)

I would like to describe a use case for customizing both JVP and VJP via the FFI for astrophysical imaging problems ("JAX doesn’t currently expose a public API for simultaneously customizing both forward-mode and reverse-mode AD, but such an API is on the roadmap, so please open an issue describing you use case if you hit this limitation in practice" https://docs.jax.dev/en/latest/ffi.html).

In the models that we build for imaging black holes, radio galaxies, and our 3D galactic neighborhood, we use variational inference schemes. Two schemes that we particularly like to use for approximating the true posterior distribution of black hole images (and the like) are MGVI and geoVI (https://arxiv.org/abs/1901.11033, https://arxiv.org/abs/2105.10470). They approximate the posterior uncertainty with $J^\dagger N J$, where $J$ is the Jacobian of our physics-informed model and $N$ is some data-dependent matrix. For optimal memory utilization, we often put custom code into our physics-informed models, so we need custom implementations of both $J^\dagger$ and $J$ to run our inference.
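To illustrate why both AD modes are needed, here is a minimal sketch of applying $J^\dagger N J$ to a vector without ever forming $J$: forward mode (`jax.jvp`) gives $J v$, and reverse mode (`jax.vjp`) gives $J^\dagger u$. It assumes a real-valued model (so $J^\dagger = J^T$) and a diagonal $N$; the toy `model`, `metric_vp`, and `n_diag` are hypothetical and are not taken from MGVI, geoVI, or NIFTy.

```python
import jax
import jax.numpy as jnp

def model(x):
    # Hypothetical toy forward model standing in for a physics-informed model.
    return jnp.stack([jnp.sin(x[0]) * x[1], x[0] ** 2 + jnp.exp(x[1]), x[0] * x[1]])

def metric_vp(model, x, n_diag, v):
    """Apply J^T N J to v, with J the Jacobian of `model` at x and N = diag(n_diag)."""
    # Forward mode: J v, one JVP through the model.
    _, jv = jax.jvp(model, (x,), (v,))
    # Reverse mode: J^T (N J v), one VJP through the model.
    _, vjp_fun = jax.vjp(model, x)
    (jtnjv,) = vjp_fun(n_diag * jv)
    return jtnjv

x = jnp.array([0.3, -1.2])
n_diag = jnp.array([2.0, 0.5, 1.0])
v = jnp.array([1.0, 0.0])
result = metric_vp(model, x, n_diag, v)
```

If `model` contains custom code, JAX can only evaluate this product if that code provides both a JVP (for $J$) and a VJP (for $J^\dagger$).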

Our current workhorse for customizing both $J^\dagger$ and $J$ is https://github.com/NIFTy-PPL/JAXbind, which uses JAX's custom primitive interface. Ideally, we would like to switch to JAX's FFI, but our need to evaluate both $J^\dagger$ and $J$ is blocking that.

Related to NIFTy-PPL/JAXbind#39.

@dfm (Collaborator) commented Mar 22, 2025

I'll just stop by to point out that there's no reason you can't back your primitives with the XLA FFI! All of the core JAX custom calls have been migrated, and that's exactly what we do there. The catch is that core.Primitive isn't a public JAX API, so that route isn't advertised as user-facing. If you're already using it, though, nothing is stopping you from migrating the backend code.

@Edenhofer (Contributor, Author)

Thanks for the comment. Yes, @roth-jakob implemented the switch to jax.extend's Primitive in NIFTy-PPL/JAXbind#38.

@dfm (Collaborator) commented Mar 22, 2025

Right, but all I meant was that you can still update your C++ custom calls to use the FFI while keeping the same Primitive shenanigans; you'd just need to write a new lowering rule! If you need to customize both modes of AD, you can't currently use the jax.ffi.ffi_call frontend, but porting the backend implementation can still be a worthwhile first step.
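A minimal sketch of the "keep the Primitive, supply your own rules" route, assuming a hypothetical linear operator `A` whose forward action ($J$) and adjoint ($J^T$) are exposed as two separate primitives. The custom $J$ enters through the JVP rule, and the custom $J^T$ enters as the transpose rule that JAX uses for reverse mode. All names here (`apply_p`, `apply_adj_p`, `hypothetical_apply`, `A`) are illustrative, and `jax.interpreters.ad` / `jax.interpreters.mlir` are internal modules that may change between JAX versions. The lowering below uses `mlir.lower_fun` on a `jnp` reference implementation; an actual FFI port would replace it with a lowering rule targeting the compiled handler (for example one built with `jax.ffi.ffi_lowering`), which is not shown here.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.extend.core import Primitive
from jax.interpreters import ad, mlir

# Hypothetical linear operator J; in practice this would be custom (e.g. C++) code.
A = np.array([[1.0, 2.0], [0.0, -1.0], [3.0, 0.5]], dtype=np.float32)

def _apply_impl(x):
    # Reference implementation of J x (stand-in for a custom forward kernel).
    return jnp.asarray(A) @ x

def _apply_adj_impl(y):
    # Reference implementation of J^T y (stand-in for a custom adjoint kernel).
    return jnp.asarray(A).T @ y

apply_p = Primitive("hypothetical_apply")
apply_adj_p = Primitive("hypothetical_apply_adjoint")

def apply(x):
    return apply_p.bind(x)

def apply_adj(y):
    return apply_adj_p.bind(y)

# Eager evaluation.
apply_p.def_impl(_apply_impl)
apply_adj_p.def_impl(_apply_adj_impl)

# Shape/dtype propagation used during tracing.
apply_p.def_abstract_eval(lambda x: x.update(shape=(A.shape[0],)))
apply_adj_p.def_abstract_eval(lambda y: y.update(shape=(A.shape[1],)))

# Forward mode: both operators are linear, so the JVP applies the same
# operator to the tangent. This is where the custom J is used.
ad.defjvp(apply_p, lambda t, x: apply(t))
ad.defjvp(apply_adj_p, lambda t, y: apply_adj(t))

# Reverse mode: JAX transposes the linear JVP, so registering J^T as the
# transpose of J is where the custom adjoint is used (and vice versa).
def _apply_transpose(ct, x):
    if type(ct) is ad.Zero:
        return (ad.Zero(x.aval),)
    return (apply_adj(ct),)

def _apply_adj_transpose(ct, y):
    if type(ct) is ad.Zero:
        return (ad.Zero(y.aval),)
    return (apply(ct),)

ad.primitive_transposes[apply_p] = _apply_transpose
ad.primitive_transposes[apply_adj_p] = _apply_adj_transpose

# Lowering for jax.jit, here from the jnp reference implementations.
mlir.register_lowering(apply_p, mlir.lower_fun(_apply_impl, multiple_results=False))
mlir.register_lowering(apply_adj_p, mlir.lower_fun(_apply_adj_impl, multiple_results=False))
```

With these rules, `jax.jvp(apply, ...)` evaluates the custom $J$ and `jax.vjp(apply, ...)` evaluates the custom $J^T$, which is the combination that `jax.ffi.ffi_call` alone does not currently allow.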
