Releases: Jammy2211/PyAutoLens
JAX Interferometry
Efficient and scalable implementation of pixelized source reconstructions for interferometer analysis using JAX.
This PR introduces a new approach to source reconstructions for interferometer data that fully exploits the symmetries and sparsity of the non-uniform fast Fourier transform (NUFFT).
A high level summary of the implementation is:
- Pixelized source reconstructions are performed such that the run time and the amount of VRAM used are independent of the number of visibilities.
- Lens modeling run times are fast, with a 1+ million visibility ALMA dataset being modeled in around 1 hour on a GPU!
- Other improvements to interferometer analysis, along with a significant amount of supporting documentation and examples on the autolens_workspace, are now provided.
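Why the run time can be independent of the number of visibilities can be illustrated with a small, dense NumPy sketch. This is not the PyAutoLens implementation (which uses a NUFFT and further optimizations); it assumes a linear pixelized reconstruction where the image is `M @ s`, with `M` a hypothetical (here random) mapping matrix and all function names invented for illustration. The curvature matrix `F = M^T Re(T^H N^-1 T) M` and data vector `D = M^T Re(T^H N^-1 V)` only touch the visibilities through image-plane quantities, so these can be computed once, after which every likelihood evaluation scales with the number of image and source pixels only.

```python
import numpy as np


def transform_matrix(uv, grid):
    """Dense direct Fourier transform: T[k, j] = exp(-2*pi*i * (u_k*x_j + v_k*y_j))."""
    return np.exp(-2j * np.pi * (uv @ grid.T))


def precompute_image_plane_terms(uv, noise, grid, visibilities):
    """One-off step: the only place the K visibilities are used."""
    T = transform_matrix(uv, grid)                                  # (K, P)
    weights = 1.0 / noise**2                                        # diagonal of N^-1
    w_tilde = np.real(T.conj().T @ (weights[:, None] * T))          # (P, P), size independent of K
    dirty_image = np.real(T.conj().T @ (weights * visibilities))    # (P,)
    return w_tilde, dirty_image


def curvature_and_data_vector(w_tilde, dirty_image, mapping):
    """Per-likelihood step: cost depends only on pixel counts, not on K."""
    return mapping.T @ w_tilde @ mapping, mapping.T @ dirty_image


def direct_curvature_and_data_vector(uv, noise, grid, visibilities, mapping):
    """Reference calculation that works in visibility space on every call."""
    A = transform_matrix(uv, grid) @ mapping                        # (K, S)
    weights = 1.0 / noise**2
    F = np.real(A.conj().T @ (weights[:, None] * A))
    D = np.real(A.conj().T @ (weights * visibilities))
    return F, D


rng = np.random.default_rng(0)
n_vis, n_pix, n_src = 500, 36, 9
uv = rng.normal(scale=2.0, size=(n_vis, 2))          # (u, v) baselines, arbitrary units
noise = rng.uniform(0.5, 1.5, size=n_vis)            # per-visibility sigma
grid = rng.uniform(-1.0, 1.0, size=(n_pix, 2))       # image-plane pixel centres (x, y)
mapping = rng.random((n_pix, n_src))                 # hypothetical mapping matrix M
visibilities = rng.normal(size=n_vis) + 1j * rng.normal(size=n_vis)

w_tilde, dirty_image = precompute_image_plane_terms(uv, noise, grid, visibilities)
F_fast, D_fast = curvature_and_data_vector(w_tilde, dirty_image, mapping)
F_ref, D_ref = direct_curvature_and_data_vector(uv, noise, grid, visibilities, mapping)
print(np.allclose(F_fast, F_ref), np.allclose(D_fast, D_ref))
```

Each entry of `w_tilde` equals the sum over k of `cos(2*pi * u_k . (x_i - x_j)) / sigma_k**2`, so it depends only on the offset between pixels i and j; this translation invariance is the sort of symmetry an efficient implementation can exploit to store and apply it compactly.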
Whilst a quantitative comparison has not yet been performed, my intuition is that this code runs significantly faster than the previous PyAutoLens interferometer modeling and the Powell et al. implementation.
Check out the interferometer package of the autolens_workspace for a complete run-through of how to use JAX GPU interferometer analysis!
https://github.com/Jammy2211/autolens_workspace/tree/release/notebooks/interferometer
Delaunay JAX
The adaptive Delaunay mesh using a Hilbert image-mesh now supports fully JAX'd likelihood functions running on GPU, which were disabled in previous releases.
The Delaunay mesh itself is computed on CPU rather than GPU, via a JAX pure_callback. Full details are provided below, but this does not significantly impact performance.
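The pure_callback pattern can be sketched as follows, assuming SciPy's `Delaunay` as a stand-in for the mesh construction and a hypothetical helper `simplex_indices` (not the PyAutoLens API). `jax.pure_callback` lets a jitted function hand its arrays to ordinary host-side NumPy/SciPy code and receive the result back as a JAX array. The output shape must be declared up front, which is why this sketch returns one triangle index per query point rather than the variable-length list of triangles.

```python
import jax
import jax.numpy as jnp
import numpy as np
from scipy.spatial import Delaunay


def _find_simplex_on_cpu(mesh_points, query_points):
    # Runs as ordinary NumPy/SciPy code on the host CPU.
    triangulation = Delaunay(np.asarray(mesh_points))
    return triangulation.find_simplex(np.asarray(query_points)).astype(np.int32)


@jax.jit
def simplex_indices(mesh_points, query_points):
    # Declare the static output shape: one triangle index per query point
    # (-1 for points outside the convex hull of the mesh).
    result_spec = jax.ShapeDtypeStruct((query_points.shape[0],), jnp.int32)
    return jax.pure_callback(_find_simplex_on_cpu, result_spec, mesh_points, query_points)


mesh = jnp.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.2]])
queries = jnp.array([[0.25, 0.25], [2.0, 2.0]])
indices = simplex_indices(mesh, queries)
print(indices)
```

Only the triangulation leaves the device in this sketch; any arithmetic on the returned indices would stay inside the jitted JAX function.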
JAX improvements + Fast CPU Pixelizations support + Delaunay
This release continues to build stability for JAX + GPU support:
https://github.com/Jammy2211/PyAutoLens/releases/tag/2025.11.18.1
This includes many fixes to small errors and bugs, for example more light and mass profiles support JAX after small fixes to the source code.
Fast CPU Pixelizations
Before this release, pixelized source reconstructions could only be computed using JAX, either via GPU or CPU.
There were two important factors in run time:
1. GPU VRAM Limitations
JAX only provides significant acceleration on GPUs with large VRAM (≥16 GB).
To avoid excessive VRAM usage, examples often restrict pixelization meshes (e.g. 20 × 20).
On consumer GPUs with limited memory, JAX may be slower than CPU execution.
2. Sparse Matrix Performance
Pixelized source reconstructions require operations on very large, highly sparse matrices.
- JAX currently lacks sparse-matrix support and must compute using dense matrices, which scale poorly.
This release restores support for PyAutoLens’s previous CPU implementation (via numba), which fully exploits sparsity, providing large speed gains at high image resolution (e.g. pixel_scales <= 0.03).
CPU execution can outperform JAX, even on powerful GPUs, for high-resolution datasets or when many CPU cores are used.
We are actively working on getting better performance from JAX by exploiting sparsity on GPU, but this is proving to be a very challenging problem.
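The benefit of exploiting sparsity can be seen in a short sketch using SciPy's sparse matrices as a stand-in for the numba implementation (which is not shown here). The mapping matrix is hypothetical: each image pixel is assumed to map to 3 source pixels with interpolation weights summing to 1, as for a triangle-based mesh. The curvature matrix `F = M^T N^-1 M` (ignoring PSF convolution for simplicity) comes out identical, while the sparse mapping matrix stores only a tiny fraction of the entries a dense matrix must hold.

```python
import numpy as np
from scipy import sparse

rng = np.random.default_rng(1)
n_image, n_source, links_per_pixel = 2_000, 200, 3

# Hypothetical mapping matrix: each image pixel links to 3 source pixels,
# with interpolation weights that sum to 1 for that image pixel.
rows = np.repeat(np.arange(n_image), links_per_pixel)
cols = rng.integers(0, n_source, size=n_image * links_per_pixel)
weights = rng.dirichlet(np.ones(links_per_pixel), size=n_image).ravel()
mapping = sparse.csr_matrix((weights, (rows, cols)), shape=(n_image, n_source))

# Diagonal inverse noise variance N^-1 for each image pixel.
noise_weights = rng.uniform(0.5, 2.0, size=n_image)
inverse_noise_variance = sparse.diags(noise_weights)

# Curvature matrix F = M^T N^-1 M without ever forming dense image-sized matrices.
curvature_sparse = (mapping.T @ inverse_noise_variance @ mapping).toarray()

# Dense reference calculation, as a library without sparse support would do it.
dense_mapping = mapping.toarray()
curvature_dense = dense_mapping.T @ (noise_weights[:, None] * dense_mapping)

print(np.allclose(curvature_sparse, curvature_dense))
print(f"stored entries: sparse {mapping.nnz:,} vs dense {dense_mapping.size:,}")
```

The gap grows with resolution: halving the pixel scale roughly quadruples the number of image pixels, and the dense matrix grows with the product of image and source pixel counts, while the sparse one grows only with the number of image pixels.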
Delaunay
Support for the Delaunay mesh, previously the main mesh used for pixelized source reconstructions, has been restored in this release, although it currently only works with the numba implementation above and therefore only supports CPU.
JAX support for Delaunay source reconstructions is actively being developed and is expected to be available in the short term.
PyAutoLens JAX GPU Stability
PyAutoLens JAX Stability
The source code no longer imports or uses JAX unless instructed by the user, meaning all calculations use regular NumPy by default.
JAX is imported and used by Analysis objects when lens modeling begins, ensuring that fast lens modeling using GPUs is always performed by default.
The design of PyAutoLens will build on this: to perform more general lensing calculations, users will perform JAX jitting and computation themselves. The docs and guides illustrating this are not written yet, but normal NumPy run times are fine for most use cases.
Workspace Restructure
The workspace has been restructured such that the core packages are now the dataset types (imaging, interferometer, etc.):
https://github.com/Jammy2211/autolens_workspace
GPU Modeling Examples
The following Jupyter Notebooks, which run via Google Colab, illustrate < 10 minute lens modeling for different science cases:
- imaging/start_here.ipynb: Galaxy-scale strong lenses observed with CCD imaging (e.g. Hubble, James Webb).
- interferometer/start_here.ipynb: Galaxy-scale strong lenses observed with interferometer data (e.g. ALMA).
- point_source/start_here.ipynb: Galaxy-scale strong lenses with a lensed point source (e.g. lensed quasars).
- group/start_here.ipynb: Group-scale strong lenses where there are 2-10 lens galaxies.
PyAutoLens JAX Stability Pull Requests
These are described fully in the following two PRs:
Large refactor which passes the numpy or jax numpy import through the code as xp.
This means that no JAX arrays are created inside the source code by default, with all calculations defaulting to NumPy, giving the following benefits:
- Unit tests and general use of the code run faster, as JAX overheads are removed.
- Numba support for efficient CPU use can easily be retained, as JAX and NumPy arrays are no longer mixed.
- There is less ambiguity in sections of code which don't play nicely with JAX arrays (e.g. visualization).
- It will allow an easier, more explicit user interface, where users JAX-jit functions themselves and pass the JAX namespace to get fast run times.
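The `xp` pattern can be sketched with a small, hypothetical function (`sis_convergence` is not a PyAutoLens function): the array namespace is a parameter that defaults to NumPy, so no JAX arrays are created unless the caller explicitly passes `jax.numpy`.

```python
import numpy as np


def sis_convergence(grid, einstein_radius, xp=np):
    """Convergence of a singular isothermal sphere, kappa = theta_E / (2 r).

    `xp` is the array namespace: NumPy by default, or `jax.numpy` when the
    caller wants JAX arrays (e.g. inside a function they jit themselves).
    """
    grid = xp.asarray(grid)
    radius = xp.sqrt(grid[:, 0] ** 2 + grid[:, 1] ** 2)
    return einstein_radius / (2.0 * radius)


# Default call: pure NumPy, no JAX import or tracing overhead.
kappa = sis_convergence([[1.0, 0.0], [0.0, 2.0]], einstein_radius=1.0)
print(kappa)
```

With JAX installed, a user could opt in explicitly, for example with `jax.jit(functools.partial(sis_convergence, xp=jax.numpy))`, while the same source code serves both paths.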
A recent PR on the child projects made JAX optional for likelihood functions, whereby users pass the JAX namespace as the variable `xp` through the source code.
This PR makes JAX optional at the highest level (e.g. `PyAutoConf` and `PyAutoFit`), including:
- For a non-linear search to use JAX, the `use_jax` input must be passed as `True` to the `Analysis` object.
- The non-linear search will internally work out if it supports JAX natively. This will ultimately have behavior where, for example, if gradients are used it uses `jax.grad`, if not it uses `jax.jit`, and if batching is supported it uses `jax.vmap`.
- Currently only `Nautilus` uses the `Analysis.use_jax` attribute to set up a `jax.vmap`.
There are a few hacky, unclean parts of the autofit model composition where it determines whether to use JAX based on the input type. A more thorough consideration of how JAX should work in autofit will be performed in the future.
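The dispatch described above can be sketched with a hypothetical helper; `prepare_log_likelihood` and its `uses_gradients` / `supports_batching` flags are invented for illustration and are not PyAutoFit's API, only `use_jax` mirrors the text. JAX is imported only when explicitly requested, consistent with the NumPy-by-default design.

```python
import numpy as np


def prepare_log_likelihood(
    log_likelihood, use_jax=False, uses_gradients=False, supports_batching=False
):
    """Hypothetical helper: decide how a search wraps its likelihood function."""
    if not use_jax:
        # Plain NumPy path: JAX is never imported.
        return log_likelihood
    import jax  # imported only once JAX is explicitly requested

    if uses_gradients:
        # Gradient-based searches get a compiled gradient function.
        return jax.jit(jax.grad(log_likelihood))
    if supports_batching:
        # Searches that evaluate many points at once get a vectorized likelihood.
        return jax.jit(jax.vmap(log_likelihood))
    return jax.jit(log_likelihood)


def log_likelihood(params):
    # Toy Gaussian log likelihood peaked at (1, 2); written with NumPy, so it
    # is only intended for the use_jax=False path in this sketch.
    return -0.5 * np.sum((params - np.array([1.0, 2.0])) ** 2)


likelihood = prepare_log_likelihood(log_likelihood, use_jax=False)
print(likelihood(np.array([1.0, 3.0])))
```

Keeping the decision inside the search, rather than in user code, means the same `Analysis` object can serve gradient-based, batched and plain samplers.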
PyAutoLens-JAX GPU
UPDATE: Latest JAX version is now 2025.11.5.1
This release marks the completion of two years of work implementing JAX (https://docs.jax.dev/en/latest/notebooks/thinking_in_jax.html) in PyAutoLens.
With JAX, any lens modeling analysis can be run on GPU, with speed-ups of ~x50 or more for all lens modeling.
Core Release
The core PyAutoLens API does not change significantly; however, existing users should redownload the new autolens workspace, which has new configs and examples:
https://github.com/Jammy2211/autolens_workspace
New users should check out the start_here.ipynb notebook, which can be read via Google Colab by clicking the hyperlink.
GPU Modeling Examples
The following Jupyter Notebooks, which run via Google Colab, illustrate < 10 minute lens modeling for different science cases:
- start_here_imaging.ipynb: Galaxy-scale strong lenses observed with CCD imaging (e.g. Hubble, James Webb).
- start_here_interferometer.ipynb: Galaxy-scale strong lenses observed with interferometer data (e.g. ALMA).
- start_here_point_source.ipynb: Galaxy-scale strong lenses with a lensed point source (e.g. lensed quasars).
- start_here_group.ipynb: Group-scale strong lenses where there are 2-10 lens galaxies.
- start_here_multi_wavelength.ipynb: Model multiple images (different wavelength imaging, imaging + interferometer) simultaneously.
Performance Of Other Features
- Pixelized sources: run ~x5 - x20 faster on modern HPC GPU clusters, with lens modeling times typically ~10 - 20 minutes. Pixelized source performance depends on the available GPU VRAM. In November 2025 a release will bring GPU performance of pixelized sources on all GPU hardware towards < 10 minute lens models.
- Interferometer with many visibilities: above ~100,000 visibilities, interferometer performance slows down significantly. In December 2025 a new release will make all interferometer modeling efficient irrespective of the number of visibilities.
- CPU performance: for pixelized sources, CPU performance is worse than the previous PyAutoLens, as JAX is not optimized for CPUs. A future release will restore performance to be on par with previous versions, but users seeking to perform pixelized source modeling without a GPU may wish to use the previous PyAutoLens.
Strong Lens Galaxy Clusters
This release can perform strong lens cluster calculations and lens modeling on GPU. For those familiar with cluster lensing, this includes performing an image-plane chi-squared multiple image calculation for clusters with hundreds of cluster members, with full support for multi-plane ray tracing of the entire cluster!
Initial profiling shows it runs 50 or more times faster than other strong lens cluster codes run on CPU. Documentation and examples for cluster modeling are actively being developed but not yet mature. You can find the most up to date examples at the following links:
https://github.com/Jammy2211/autolens_workspace/blob/release/start_here_cluster.ipynb
https://github.com/Jammy2211/autolens_workspace/tree/release/scripts/simulators/cluster
https://github.com/Jammy2211/autolens_workspace/tree/release/scripts/modeling/cluster
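An image-plane chi-squared compares observed multiple-image positions with the model's predicted image positions (which, for a cluster, come from multi-plane ray tracing). A minimal sketch, assuming a simple nearest-model-image pairing rule and a hypothetical function name; the pairing used by PyAutoLens may differ:

```python
import numpy as np


def image_plane_chi_squared(observed, model, noise):
    """Chi-squared between observed and model multiple-image positions.

    Each observed image is paired with its nearest model image (an assumed,
    simple pairing rule); `noise` is the positional uncertainty per image.
    """
    observed, model = np.asarray(observed), np.asarray(model)
    # Distances between every observed and every model image: shape (n_obs, n_model).
    distances = np.linalg.norm(observed[:, None, :] - model[None, :, :], axis=-1)
    nearest = distances.min(axis=1)
    return np.sum((nearest / np.asarray(noise)) ** 2)


observed = [[1.0, 0.0], [-0.5, 0.0]]
model = [[1.1, 0.0], [-0.5, 0.2], [0.0, 3.0]]
chi_squared = image_plane_chi_squared(observed, model, noise=[0.1, 0.1])
print(chi_squared)
```

Here the two observed images sit 0.1 and 0.2 from their nearest model images, giving (0.1/0.1)^2 + (0.2/0.1)^2 = 5. Because every step is array arithmetic, the same calculation vectorizes naturally over many source galaxies and cluster members.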
May 2025
- Results workflow API, which generates .csv, .png and .fits files of large libraries of results for quick and efficient inspection:
https://github.com/Jammy2211/autolens_workspace/tree/main/notebooks/results/workflow
- Visualization now outputs .fits files corresponding to each subplot, which more concisely contain all information of a fit and are used by the above workflow API.
- Visualization simplified, removing customization of individual image outputs.
- Removed the Analysis summing API, replacing all dataset combinations with the `AnalysisFactor` and `FactorGraphModel` API used for graphical modeling.
- Pixelized source reconstruction output as a .csv file, which can be loaded and interpolated for better source science analysis.
- Double source plane lens modeling now outputs an individual subplot_fit for each plane.
- Latent variable API bug fixes; the API is now used in some test example scripts.
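Loading and interpolating a reconstruction stored as a .csv can be sketched with NumPy and SciPy. The column names (`y`, `x`, `reconstruction`) and the in-memory stand-in file are assumptions for illustration; check the header of the .csv PyAutoLens actually writes. The irregularly sampled source values are linearly interpolated onto a regular grid with `scipy.interpolate.griddata`.

```python
import io

import numpy as np
from scipy.interpolate import griddata

# Stand-in for the reconstruction .csv; the column names are hypothetical.
csv_text = """y,x,reconstruction
0.0,0.0,0.0
0.0,1.0,2.0
1.0,0.0,3.0
1.0,1.0,5.0
0.5,0.5,2.5
"""
table = np.genfromtxt(io.StringIO(csv_text), delimiter=",", names=True)

points = np.column_stack([table["y"], table["x"]])
values = table["reconstruction"]

# Interpolate the irregular source-plane values onto a regular 5 x 5 grid.
y, x = np.mgrid[0.1:0.9:5j, 0.1:0.9:5j]
interpolated = griddata(points, values, (y, x), method="linear")
print(interpolated.shape)
```

Linear interpolation reproduces the triangulated values exactly, whereas `method="cubic"` gives a smoother surface at the cost of possible overshoot near sharp source features.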
January 2025
The main updates are visualization of Delaunay meshes using Delaunay triangles and a significant refactoring of over sampling, primarily to make the code much less complex for the ongoing JAX implementation.
There have also been more improvements to point source modeling, including JAX functionality, which will be documented fully in the near future.
What's Changed
- Feature/disable noise by @Jammy2211 in #324
- feature/delaunay_visual by @Jammy2211 in #323
- feature/inversion_noise_map by @Jammy2211 in #325
- feature/positions_lh_mass_centre by @Jammy2211 in #326
- feature/triangle array typing by @rhayes777 in #328
- feature/array testing by @rhayes777 in #327
- Feature/over sampling refactor by @Jammy2211 in #332
- remove max containing size from solver by @rhayes777 in #329
- feature/andrew implementation by @rhayes777 in #331
Full Changelog: 2024.11.13.2...2025.1.18.7
November 2024 update
Small bug fixes and optimizations for Euclid lens modeling pipeline.
November 2024
Minor release with stability updates and one main feature.
- Extra Galaxies API for modeling multiple galaxies at once: https://github.com/Jammy2211/autolens_workspace/blob/release/notebooks/features/extra_galaxies.ipynb
- Multiwavelength lens modeling with SLaM multi wavelength pipelines: https://github.com/Jammy2211/autolens_workspace/tree/main/scripts/advanced/chaining/slam/multi
- More improvements to Point source solver and Shape solver.
- Sensitivity mapping improvements which will be fully documented in the future.
September 2024
This release updates all projects to support Python 3.12, with support tested for Python 3.9 - 3.12 and 3.11 regarded as most stable.
This includes many project dependency updates:
https://github.com/rhayes777/PyAutoFit/blob/main/requirements.txt
https://github.com/rhayes777/PyAutoFit/blob/main/optional_requirements.txt
https://github.com/Jammy2211/PyAutoGalaxy/blob/main/requirements.txt
https://github.com/Jammy2211/PyAutoGalaxy/blob/main/optional_requirements.txt
https://github.com/Jammy2211/PyAutoLens/blob/main/requirements.txt
https://github.com/Jammy2211/PyAutoLens/blob/main/optional_requirements.txt
Workspace Restructure:
This release has a workspace restructure, which is now grouped at a high level by tasks (e.g. modeling, simulators) rather than datasets:
https://github.com/Jammy2211/autolens_workspace
The readthedocs have been greatly simplified and include a new user guide to help navigate the new workspace:
https://pyautolens.readthedocs.io/en/latest/overview/overview_2_new_user_guide.html
PyAutoLens:
- Point source modeling significantly improved with the triangle tracing method, with image-plane chi-squared now supported: https://github.com/Jammy2211/autolens_workspace/tree/release/notebooks/modeling/point_source
- Shape based point-source modeling for magnification calculations: #300
- Improved Cosmology wrapper to support new `astropy` and easier to use in models: Jammy2211/PyAutoGalaxy#193
- Ellipse Fitting: https://github.com/Jammy2211/autogalaxy_workspace/tree/release/notebooks/advanced/misc/ellipse
PyAutoFit:
https://github.com/rhayes777/PyAutoFit/pulls?q=is%3Apr+is%3Aclosed
- Improvements to HowToFit lectures: rhayes777/PyAutoFit#1022
- Support for NumPy arrays in model composition and prior creation, for example creating an `ndarray` of input `shape` where each value is a free parameter in the search: rhayes777/PyAutoFit#1021
- `optimize` searches renamed to `mle`, for maximum likelihood estimator, with improvements to visualization: rhayes777/PyAutoFit#1029
- Improvement to sensitivity mapping functionality and results: https://github.com/rhayes777/PyAutoFit/pulls?q=is%3Apr+is%3Aclosed
- More improvements to JAX Pytree interface, documentation still to come.