PyAutoLens JAX GPU Stability
PyAutoLens JAX Stability
The source code no longer imports or uses JAX unless instructed to by the user, so all calculations use regular NumPy by default.
JAX is imported and used by Analysis objects when lens modeling begins, ensuring that fast lens modeling using GPUs is always performed by default.
The design of PyAutoLens will build on this: to perform more general lensing calculations with JAX, users will jit and run the computations themselves. The docs and guides illustrating this are not written yet, but standard NumPy run times are fine for most use cases.
Workspace Restructure
The workspace has been restructured such that the core packages are now the dataset types (imaging, interferometer, etc.):
https://github.com/Jammy2211/autolens_workspace
GPU Modeling Examples
The following Jupyter Notebooks, which run via Google Colab, illustrate lens modeling in under 10 minutes for different science cases:
- imaging/start_here.ipynb: Galaxy scale strong lenses observed with CCD imaging (e.g. Hubble, James Webb).
- interferometer/start_here.ipynb: Galaxy scale strong lenses observed with interferometer data (e.g. ALMA).
- point_source/start_here.ipynb: Galaxy scale strong lenses with a lensed point source (e.g. lensed quasars).
- group/start_here.ipynb: Group scale strong lenses where there are 2-10 lens galaxies.
PyAutoLens JAX Stability Pull Requests
These are described fully in the following two PRs:
A large refactor which passes the NumPy or JAX NumPy import through the code as `xp`.
This means that no JAX arrays are created inside the source code by default, with all calculations defaulting to NumPy, giving the following benefits:
- Unit tests and general code use run faster, as JAX overheads are removed.
- Numba support for efficient CPU use can be easily retained, as NumPy and JAX arrays are never mixed.
- Less ambiguity in sections of code which don't play nicely with JAX arrays (e.g. visualization).
- Allows for an easier, more explicit user interface where users JAX jit functions themselves and pass the namespace to get fast run times.
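The `xp` pattern described above can be sketched as follows. This is a minimal illustration, not PyAutoLens code: the function `deflection_magnitude`, its formula and its arguments are hypothetical, and only the idea of passing the array namespace as `xp` comes from the PR description. The JAX branch is skipped if JAX is not installed.

```python
from functools import partial

import numpy as np


def deflection_magnitude(radius, einstein_radius, xp=np):
    """Toy calculation written against a generic array namespace.

    `xp` is either `numpy` (the default) or `jax.numpy`. The function body
    never imports JAX itself. The formula is illustrative only and is not
    a PyAutoLens mass profile.
    """
    radius = xp.asarray(radius)
    return einstein_radius * radius / xp.sqrt(radius**2 + 0.01)


# Default path: plain NumPy, no JAX import required.
radii = np.linspace(0.1, 2.0, 5)
result_np = deflection_magnitude(radii, einstein_radius=1.2)

# Opt-in path: the user imports JAX, passes its namespace and jits the call.
try:
    import jax
    import jax.numpy as jnp

    # Bind the non-array arguments so only `radius` is traced by jax.jit.
    jitted = jax.jit(partial(deflection_magnitude, einstein_radius=1.2, xp=jnp))
    result_jax = np.asarray(jitted(jnp.asarray(radii)))
except ImportError:
    result_jax = None  # JAX not installed; the NumPy path still works.
```

Because the namespace is an ordinary argument, the same function serves unit tests and Numba-friendly CPU code (NumPy) as well as user-jitted GPU code (JAX), without any JAX arrays being created unless the caller asks for them.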
A recent PR on the child projects made JAX optional for likelihood functions, whereby users pass the JAX namespace as the variable `xp` through the source code.
This PR makes JAX optional at the highest level (e.g. `PyAutoConf` and `PyAutoFit`), including:
- For a non-linear search to use JAX, the `use_jax` input must be passed as `True` to the `Analysis` object.
- The non-linear search will internally work out if it supports JAX natively. Ultimately, the intended behavior is that, for example, a search which uses gradients uses `jax.grad`, one which does not uses `jax.jit`, and one which supports batching uses `jax.vmap`.
- Currently only `Nautilus` uses the `Analysis.use_jax` attribute to set up a `jax.vmap`.
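The intended dispatch described in the list above can be summarized as a small decision function. This is a sketch under stated assumptions: `select_jax_transforms` and its arguments are hypothetical names, not part of PyAutoFit, and the way batching combines with the other two transforms is an assumption, since the notes do not specify it. The function only returns transform names, so it runs without JAX installed.

```python
def select_jax_transforms(use_jax, uses_gradients=False, supports_batching=False):
    """Return the names of the JAX transforms a search would apply.

    Hypothetical helper mirroring the behavior described in the notes:
    nothing without `use_jax`, `jax.grad` for gradient-based searches,
    `jax.jit` otherwise, plus `jax.vmap` when batching is supported
    (the combination rule is an assumption).
    """
    if not use_jax:
        return []  # plain NumPy likelihood, no JAX transforms applied

    transforms = ["jax.grad"] if uses_gradients else ["jax.jit"]
    if supports_batching:
        transforms.append("jax.vmap")
    return transforms


# A gradient-free search without batching only jits the likelihood.
gradient_free = select_jax_transforms(use_jax=True)

# A batched, gradient-free search also vectorizes it.
batched = select_jax_transforms(use_jax=True, supports_batching=True)
```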
There are a few hacky, unclean bits in the autofit model composition, where it determines whether to use JAX based on input type. A more thorough consideration of how JAX should work in autofit will be performed in the future.