
PyAutoLens JAX GPU Stability


Jammy2211 released this on 18 Nov 15:38

PyAutoLens JAX Stability

The source code no longer imports or uses JAX unless the user instructs it to, so all calculations use regular NumPy by default.

JAX is imported and used by `Analysis` objects only when lens modeling begins, ensuring that fast GPU lens modeling is still performed by default.

Future development of PyAutoLens will build on this design: for more general lensing calculations, users will apply JAX jitting and run the computation themselves. The docs and guides illustrating this are not yet written, but standard NumPy run times are acceptable for most use cases.
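Since those guides are not written yet, the following is only a minimal sketch of what user-side jitting could look like. `sis_deflections` is a hypothetical function written for illustration (it is not part of the PyAutoLens API): it computes singular isothermal sphere deflections, takes the array namespace as an `xp` argument, runs on plain NumPy by default, and is compiled with `jax.jit` only if JAX is installed and the user opts in.

```python
from functools import partial

import numpy as np


def sis_deflections(grid, einstein_radius, xp=np):
    """Hypothetical example: deflection angles of a singular isothermal sphere.

    For an SIS the deflection has constant magnitude equal to the Einstein
    radius and points radially, so alpha = theta_E * grid / |grid|.
    """
    radius = xp.sqrt(xp.sum(grid**2, axis=1, keepdims=True))
    return einstein_radius * grid / radius


grid = np.array([[1.0, 0.0], [0.0, 2.0], [3.0, 4.0]])

# Default path: plain NumPy, JAX is never imported.
deflections = sis_deflections(grid, einstein_radius=1.5)

# Optional path: the user opts in to JAX and jits the function themselves,
# binding jax.numpy as the namespace.
try:
    import jax
    import jax.numpy as jnp

    jitted = jax.jit(partial(sis_deflections, xp=jnp))
    deflections_jax = np.asarray(jitted(jnp.asarray(grid), 1.5))
except ImportError:
    deflections_jax = None  # JAX not installed; the NumPy result still stands
```

The numerical code is identical on both paths; only the namespace passed as `xp` changes, which is what lets the library stay JAX-free until the user asks for it.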

Workspace Restructure

The workspace has been restructured so that the core packages are now the dataset types (imaging, interferometer, etc.):

https://github.com/Jammy2211/autolens_workspace

GPU Modeling Examples

The following Jupyter Notebooks, which run via Google Colab, illustrate lens modeling in under 10 minutes for different science cases:

PyAutoLens JAX Stability Pull Requests

These are described fully in the following two PRs:

#371

A large refactor which passes the `numpy` or `jax.numpy` module through the code as `xp`.

This means that no JAX arrays are created inside the source code by default, with all calculations defaulting to NumPy, giving the following benefits:

- Unit tests and general code use run faster, as JAX overheads are removed.
- Numba support for efficient CPU use can easily be retained, as JAX arrays are no longer mixed in.
- There is less ambiguity in sections of code which don't play nicely with JAX arrays (e.g. visualization).
- It allows for an easier, more explicit user interface, where users JAX-jit functions themselves and pass the namespace to get fast run times.
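The `xp` convention can be illustrated with a small sketch. `gaussian_image` below is a hypothetical light profile written for this note, not a PyAutoLens function; it routes every array operation through `xp`, so when no namespace is passed the result is an ordinary NumPy array and no JAX array is ever created.

```python
import numpy as np


def gaussian_image(grid, intensity, sigma, xp=np):
    """Hypothetical light profile written against the `xp` namespace.

    Every array operation goes through `xp`, so the same code runs on NumPy
    (the default) or on `jax.numpy` if a caller passes it in explicitly.
    """
    radius_squared = xp.sum(grid**2, axis=1)
    return intensity * xp.exp(-0.5 * radius_squared / sigma**2)


grid = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 2.0]])

# No namespace passed: the result is a plain NumPy array, so it can be handed
# to NumPy-only code (e.g. plotting or Numba-compiled routines) unchanged.
image = gaussian_image(grid, intensity=2.0, sigma=1.0)
```

Passing `xp=jax.numpy` instead would return a JAX array from the same function body, which is what makes the user-side jitting route possible.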

#372

A recent PR on the child projects made JAX optional for likelihood functions, whereby users pass the JAX namespace through the source code as the variable `xp`.

This PR makes JAX optional at the highest level (e.g. `PyAutoConf` and `PyAutoFit`), including:

- For a non-linear search to use JAX, the `use_jax` input must be passed as `True` to the `Analysis` object.
- The non-linear search will internally work out whether it supports JAX natively. Ultimately, this will mean that, for example, if gradients are used it uses `jax.grad`, if not it uses `jax.jit`, and if batching is supported it uses `jax.vmap`.
- Currently only `Nautilus` uses the `Analysis.use_jax` attribute to set up a `jax.vmap`.
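The dispatch described in these bullets might be sketched as follows. Everything here is hypothetical: `prepare_likelihood`, its `batched` and `gradient` flags and the toy `gaussian_log_likelihood` are for illustration only and are not PyAutoFit or Nautilus code. The sketch shows JAX being imported lazily, only once `use_jax=True`, with the transformation (`jax.grad`, `jax.vmap`, `jax.jit`) chosen from what the search needs.

```python
from functools import partial

import numpy as np


def gaussian_log_likelihood(params, xp=np):
    """Hypothetical toy likelihood: an isotropic unit Gaussian centred on zero."""
    return -0.5 * xp.sum(params**2)


def prepare_likelihood(log_likelihood, use_jax=False, batched=False, gradient=False):
    """Hypothetical helper mirroring the behaviour described above.

    Without JAX, the likelihood is evaluated eagerly with NumPy. With JAX,
    `jax.numpy` is bound as `xp` and the function is wrapped with
    `jax.grad`, `jax.vmap` and/or `jax.jit` depending on what the search needs.
    """
    if not use_jax:
        if gradient:
            raise ValueError("gradients require use_jax=True in this sketch")
        if batched:
            # Evaluate one parameter vector at a time with NumPy.
            return lambda batch: np.array([log_likelihood(p) for p in batch])
        return log_likelihood

    import jax  # imported lazily, only once the user opts in
    import jax.numpy as jnp

    func = partial(log_likelihood, xp=jnp)
    if gradient:
        func = jax.grad(func)
    if batched:
        func = jax.vmap(func)
    return jax.jit(func)


batch = np.array([[0.0, 0.0], [1.0, 1.0], [2.0, 0.0]])
numpy_batched = prepare_likelihood(gaussian_log_likelihood, batched=True)
log_likelihoods = numpy_batched(batch)
```

Without JAX the batched path simply loops in NumPy; with `use_jax=True, batched=True` the same likelihood would instead be vectorized with `jax.vmap` and compiled once with `jax.jit`.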

There are a few hacky, unclean bits in the autofit model composition, where it determines whether to use JAX based on the input type. A more thorough consideration of how JAX should work in autofit will be carried out in the future.
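One way a type-based check like this could work is sketched below, assuming a hypothetical helper `namespace_for` (this is not autofit's implementation). It inspects `sys.modules` rather than importing JAX, so NumPy-only runs never pay the JAX import cost: if JAX was never imported, no input can be a JAX array.

```python
import sys

import numpy as np


def namespace_for(*arrays):
    """Hypothetical: return jax.numpy if any input is a JAX array, else NumPy."""
    # If JAX was never imported, no input can be a JAX array, so stay on the
    # NumPy path without importing JAX here.
    jax = sys.modules.get("jax")
    if jax is None:
        return np
    # jax.Array also covers tracers seen inside jax.jit / jax.vmap.
    if any(isinstance(a, jax.Array) for a in arrays):
        return jax.numpy
    return np


def total_flux(image):
    """Hypothetical consumer that picks its namespace from the input type."""
    xp = namespace_for(image)
    return xp.sum(image)


flux = total_flux(np.ones((4, 4)))
```

Inferring the namespace from inputs is convenient but implicit, which is part of why an explicit `use_jax` flag on the `Analysis` object is the cleaner long-term interface.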