
PyAutoLens-JAX GPU


Jammy2211 released this on 20 Oct at 18:27

UPDATE: The latest JAX release of PyAutoLens is now 2025.11.5.1.

This release marks the completion of two years of work implementing JAX (https://docs.jax.dev/en/latest/notebooks/thinking_in_jax.html) in PyAutoLens.

With JAX, any lens modeling analysis can be run on a GPU, with speed-ups of ~50x or more for most analyses (exceptions are listed under Performance Of Other Features).
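The speed-up comes from JAX compiling array code with XLA and running it on whatever accelerator is available. The sketch below illustrates that general pattern, not PyAutoLens itself: the toy Gaussian light profile and the names `gaussian_image` and `chi_squared` are hypothetical stand-ins, and the grid size and parameter values are arbitrary assumptions. `jax.jit` compiles the likelihood once, and `jax.vmap` evaluates many trial parameter sets in a single call, the kind of batched work a GPU parallelizes well.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy model, NOT the PyAutoLens API: a circular Gaussian light
# profile evaluated on a 2D pixel grid.
def gaussian_image(params, y, x):
    """Image of a circular Gaussian with params = (centre_y, centre_x, sigma, intensity)."""
    cy, cx, sigma, intensity = params
    r2 = (y - cy) ** 2 + (x - cx) ** 2
    return intensity * jnp.exp(-0.5 * r2 / sigma**2)

def chi_squared(params, data, noise, y, x):
    """Sum of squared noise-normalized residuals between data and model image."""
    model = gaussian_image(params, y, x)
    return jnp.sum(((data - model) / noise) ** 2)

# jit compiles once with XLA; vmap maps over the leading axis of params only,
# so many parameter sets are evaluated in parallel against the same data.
batched_chi_squared = jax.jit(
    jax.vmap(chi_squared, in_axes=(0, None, None, None, None))
)

# Noise-free mock data on a 100 x 100 grid (units are an arbitrary assumption).
coords = jnp.linspace(-2.0, 2.0, 100)
y, x = jnp.meshgrid(coords, coords, indexing="ij")
true_params = jnp.array([0.1, -0.2, 0.5, 1.0])
noise = 0.05 * jnp.ones_like(y)
data = gaussian_image(true_params, y, x)

# 1,000 trial parameter sets, e.g. as proposed by a non-linear search.
key = jax.random.PRNGKey(0)
trial_params = true_params + 0.01 * jax.random.normal(key, (1000, 4))
chi2 = batched_chi_squared(trial_params, data, noise, y, x)

print(jax.default_backend())  # "gpu" if JAX sees a GPU, otherwise "cpu"
print(chi2.shape)
```

The same code runs unchanged on CPU or GPU; `jax.default_backend()` is a quick way to confirm which one JAX picked up, for example in a Google Colab session.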

Core Release

The core PyAutoLens API does not change significantly; however, existing users should re-download the new autolens workspace, which has new configs and examples:

https://github.com/Jammy2211/autolens_workspace

New users should check out the start_here.ipynb notebook, which can be opened in Google Colab by clicking the hyperlink.

GPU Modeling Examples

The following Jupyter notebooks, which run via Google Colab, illustrate lens modeling in under 10 minutes for different science cases:

Performance Of Other Features

  • Pixelized sources: these run ~5x to 20x faster on modern HPC GPU clusters, with lens modeling times typically ~10 to 20 minutes. Pixelized source performance depends on the available GPU VRAM. A release planned for November 2025 will bring pixelized source modeling on all GPU hardware close to under 10 minutes per lens model.

  • Interferometers with many visibilities: above ~100,000 visibilities, interferometer performance slows down significantly. A release planned for December 2025 will make all interferometer modeling efficient, irrespective of the number of visibilities.

  • CPU performance: for pixelized sources, CPU performance is worse than in previous versions of PyAutoLens, because JAX is not optimized for CPUs. A future release will restore performance to be on par with previous versions; until then, users who want to perform pixelized source modeling without a GPU may wish to use the previous PyAutoLens.
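Why VRAM and visibility count matter can be seen with back-of-envelope arithmetic. The sketch below assumes, purely for illustration, that the pixelized-source and interferometer calculations hold dense matrices in GPU memory; the notes above do not describe PyAutoLens internals, and the matrix shapes here (a 50,000 x 1,000 image-to-source mapping matrix, a 100,000 x 10,000 direct Fourier transform matrix) and the helper `dense_matrix_gb` are hypothetical.

```python
# Illustrative estimate only: assumes dense matrices held in GPU memory.
# The shapes below are hypothetical examples, not PyAutoLens internals.

BYTES_FLOAT64 = 8      # one real double-precision entry
BYTES_COMPLEX128 = 16  # one complex double-precision entry

def dense_matrix_gb(n_rows: int, n_cols: int, bytes_per_entry: int) -> float:
    """Memory in gigabytes (1 GB = 1e9 bytes) of a dense n_rows x n_cols matrix."""
    return n_rows * n_cols * bytes_per_entry / 1e9

# Pixelized source: a mapping matrix of image pixels x source pixels.
mapping_gb = dense_matrix_gb(n_rows=50_000, n_cols=1_000,
                             bytes_per_entry=BYTES_FLOAT64)

# Interferometer: a direct Fourier transform matrix of visibilities x image pixels.
dft_gb = dense_matrix_gb(n_rows=100_000, n_cols=10_000,
                         bytes_per_entry=BYTES_COMPLEX128)

print(f"mapping matrix: {mapping_gb:.1f} GB")  # mapping matrix: 0.4 GB
print(f"DFT matrix: {dft_gb:.1f} GB")          # DFT matrix: 16.0 GB
```

Under these assumptions memory grows linearly with the number of visibilities, so a dense approach quickly outgrows typical GPU VRAM once visibilities reach the hundreds of thousands.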

Strong Lens Galaxy Clusters

This release can perform strong lens cluster calculations and lens modeling on GPU. For those familiar with cluster lensing, this includes performing an image-plane chi-squared multiple image calculation for clusters with hundreds of cluster members, with full support for multi-plane ray tracing of the entire cluster!
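For readers new to these terms, a minimal sketch of the two ingredients follows, written in JAX but not using the PyAutoLens API. It uses the standard multi-plane lens equation, theta_j = theta_1 - sum_{i<j} beta_ij * alpha_i(theta_i) with beta_ij = D_ij D_s / (D_j D_is), where deflections are assumed to be scaled to the final source plane. Every deflector is a singular isothermal sphere (SIS) as a stand-in for real cluster members, and the helpers `sis_deflection`, `multi_plane_trace` and `image_plane_chi_squared` are hypothetical. Real image-plane fitting first solves the lens equation for the model image positions; that step is omitted here, and matching each observed image to its nearest model image is an assumed simplification.

```python
import jax.numpy as jnp

def sis_deflection(theta, centre, einstein_radius):
    """SIS deflection (scaled to the source plane): b * (theta - c) / |theta - c|."""
    d = theta - centre
    r = jnp.sqrt(jnp.sum(d**2, axis=-1, keepdims=True))
    return einstein_radius * d / r

def multi_plane_trace(theta, centres, einstein_radii, beta):
    """Trace image-plane positions theta (N, 2) through every lens plane to the source.

    centres (n_lens, 2) and einstein_radii (n_lens,) give one SIS per plane, ordered
    by redshift. beta has shape (n_lens, n_lens + 1); beta[i, j] = D_ij D_s / (D_j D_is)
    for i < j, and the last column is the source plane (beta = 1).
    """
    n_lens = centres.shape[0]
    positions = [theta]
    deflections = []
    for j in range(1, n_lens + 1):
        # Deflection of plane j-1, evaluated where the rays cross that plane.
        deflections.append(
            sis_deflection(positions[j - 1], centres[j - 1], einstein_radii[j - 1])
        )
        # Recursive multi-plane lens equation for the positions on plane j.
        theta_j = theta - sum(beta[i, j] * deflections[i] for i in range(j))
        positions.append(theta_j)
    return positions[-1]

def image_plane_chi_squared(observed, model, sigma):
    """Pair each observed image with its nearest model image; sum offsets^2 / sigma^2."""
    sep2 = jnp.sum((observed[:, None, :] - model[None, :, :]) ** 2, axis=-1)
    return jnp.sum(jnp.min(sep2, axis=1)) / sigma**2

# Single SIS (b = 1) at the origin: a source at (0.2, 0) has images at x = 1.2 and -0.8.
centres = jnp.array([[0.0, 0.0]])
einstein_radii = jnp.array([1.0])
beta = jnp.array([[0.0, 1.0]])
observed = jnp.array([[1.2, 0.0], [-0.8, 0.0]])

source = multi_plane_trace(observed, centres, einstein_radii, beta)
print(source)  # both images trace back to approximately [0.2, 0.0]

# Model positions offset by 0.01 from each observed image, with sigma = 0.01.
model = jnp.array([[1.21, 0.0], [-0.8, 0.01]])
chi2 = image_plane_chi_squared(observed, model, sigma=0.01)
print(chi2)  # approximately 2.0
```

Adding more lens planes only means extra rows in `centres`, `einstein_radii` and `beta`; because the functions are pure `jax.numpy`, they can also be wrapped in `jax.jit` to run on a GPU.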

Initial profiling shows it runs 50 or more times faster than other strong lens cluster codes running on CPUs. Documentation and examples for cluster modeling are actively being developed but are not yet mature. You can find the most up-to-date examples at the following links:

https://github.com/Jammy2211/autolens_workspace/blob/release/start_here_cluster.ipynb
https://github.com/Jammy2211/autolens_workspace/tree/release/scripts/simulators/cluster
https://github.com/Jammy2211/autolens_workspace/tree/release/scripts/modeling/cluster