Skip to content

Commit 254892d

Browse files
authored
feat: new optimization methods + poetry to hatch + cleaner linear search
- new optimization methods - poetry to hatch - cleaner linear search
2 parents 2ff4457 + 55341af commit 254892d

File tree

17 files changed

+875
-685
lines changed

17 files changed

+875
-685
lines changed

.github/workflows/ci.yaml

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
name: Build, test and publish
2+
3+
on:
4+
push:
5+
branches:
6+
- main
7+
tags:
8+
- "*"
9+
pull_request:
10+
release:
11+
types: [published]
12+
13+
jobs:
14+
build:
15+
runs-on: ubuntu-latest
16+
steps:
17+
- uses: actions/checkout@v4
18+
with:
19+
fetch-depth: 0
20+
- uses: actions/setup-python@v5
21+
name: Install Python
22+
with:
23+
python-version: "3.10"
24+
- name: Build sdist and wheel
25+
run: |
26+
python -m pip install -U pip
27+
python -m pip install -U build
28+
python -m build .
29+
- uses: actions/upload-artifact@v4
30+
with:
31+
path: dist/*
32+
33+
test:
34+
runs-on: ubuntu-latest
35+
needs: [build]
36+
steps:
37+
- uses: actions/checkout@v4
38+
- uses: actions/setup-python@v5
39+
with:
40+
python-version: "3.10"
41+
- name: Install dependencies
42+
run: python -m pip install -U pip
43+
- name: Install package and test dependencies
44+
run: python -m pip install ".[dev]"
45+
- name: Run tests
46+
run: python -m pytest
47+
48+
publish:
49+
environment:
50+
name: pypi
51+
url: https://pypi.org/p/portrait
52+
permissions:
53+
id-token: write
54+
needs: [test]
55+
runs-on: ubuntu-latest
56+
if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags/')
57+
steps:
58+
- uses: actions/download-artifact@v4
59+
with:
60+
name: artifact
61+
path: dist
62+
- uses: pypa/[email protected]

.github/workflows/ci.yml

Lines changed: 0 additions & 29 deletions
This file was deleted.

.github/workflows/publish.yml

Lines changed: 0 additions & 50 deletions
This file was deleted.

docs/index.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,6 @@ Efficient detection of planets transiting quiet or active stars
6161
6262
markdown/install
6363
notebooks/motivation.ipynb
64-
notebooks/star.ipynb
65-
notebooks/templates.ipynb
6664
examples.md
6765
```
6866

@@ -81,5 +79,8 @@ notebooks/tutorials/exocomet.ipynb
8179
:caption: Reference
8280
8381
markdown/how.ipynb
82+
notebooks/star.ipynb
83+
notebooks/templates.ipynb
84+
markdown/hardware.md
8485
markdown/API
8586
```

docs/markdown/hardware.md

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
# Hardware acceleration
2+
3+
When running the linear search, nuance exploits the parallelization capabilities of JAX by using a default mapping strategy depending on the available devices.
4+
5+
## Solving for `(t0, D)`
6+
7+
To solve a particular model (like a transit) with a given epoch `t0` and duration `D`, we define the function
8+
9+
```python
10+
import jax
11+
12+
@jax.jit
13+
def solve(t0, D):
14+
m = model(time, t0, D)
15+
ll, w, v = nu._solve(m)
16+
return w[-1], v[-1, -1], ll
17+
```
18+
19+
where `model` is the [template model](../notebooks/templates.ipynb), and `nu._solve` is the `Nuance._solve` method returning:
20+
21+
- `w[-1]` the template model depth
22+
- `v[-1, -1]` the variance of the template model depth
23+
- `ll` the log-likelihood of the data to the model
24+
25+
## Batching over `(t0s, Ds)`
26+
The goal of the linear search is then to call `solve` for a grid of of epochs `t0s` and durations `Ds`. As `t0s` is usually very large compared to `Ds` (~1000 vs. ~10), the default strategy is to batch the `t0s`:
27+
28+
```python
29+
# we pad to have fixed size batches
30+
t0s_padded = np.pad(t0s, [0, batch_size - (len(t0s) % batch_size) % batch_size])
31+
t0s_batches = np.reshape(
32+
t0s_padded, (len(t0s_padded) // batch_size, batch_size)
33+
)
34+
```
35+
36+
## JAX mapping
37+
38+
In order to solve a given batch in an optimal way, the `batch_size` can be set depending on the devices available:
39+
40+
- If multiple **CPUs** are available, the `batch_size` is chosen as the number of devices (`jax.device_count()`) and we can solve a given batch using
41+
42+
```python
43+
solve_batch = jax.pmap(jax.vmap(solve, in_axes=(None, 0)), in_axes=(0, None))
44+
```
45+
46+
where each batch is `jax.pmap`ed over all available CPUs along the `t0s` axis.
47+
48+
- If a **GPU** is available, the `batch_size` can be larger and the batch is `jax.vmap`ed along `t0s`
49+
50+
```python
51+
solve_batch = jax.vmap(jax.vmap(solve, in_axes=(None, 0)), in_axes=(0, None))
52+
```
53+
54+
Then, the linear search consists in iterating over `t0s_batches`:
55+
56+
```python
57+
results = []
58+
59+
for t0_batch in t0s_batches:
60+
results.append(solve_batch(t0_batch, Ds))
61+
```
62+
63+
```{note}
64+
Of course, one familiar with JAX can use their own mapping strategy to evaluate `solve` over a grid of epochs `t0s` and durations `Ds`.
65+
```
66+
67+
## The full method
68+
69+
The method `nuance.Naunce.linear_search` is then
70+
71+
```python
72+
def linear_search( self, t0s, Ds):
73+
74+
backend = jax.default_backend()
75+
batch_size = {"cpu": DEVICES_COUNT, "gpu": 1000}[backend]
76+
77+
@jax.jit
78+
def solve(t0, D):
79+
m = self.model(self.time, t0, D)
80+
ll, w, v = self._solve(m)
81+
return jnp.array([w[-1], v[-1, -1], ll])
82+
83+
# Batches
84+
t0s_padded = np.pad(t0s, [0, batch_size - (len(t0s) % batch_size) % batch_size])
85+
t0s_batches = np.reshape(
86+
t0s_padded, (len(t0s_padded) // batch_size, batch_size)
87+
)
88+
89+
# Mapping
90+
if backend == "cpu":
91+
solve_batch = jax.pmap(jax.vmap(solve, in_axes=(None, 0)), in_axes=(0, None))
92+
else:
93+
solve_batch = jax.vmap(jax.vmap(solve, in_axes=(None, 0)), in_axes=(0, None))
94+
95+
# Iterate
96+
results = []
97+
98+
for t0_batch in t0s_batches:
99+
results.append(solve_batch(t0_batch, Ds))
100+
101+
...
102+
```

docs/notebooks/exocomet.pdf

-295 KB
Binary file not shown.

docs/notebooks/single.ipynb

Lines changed: 6 additions & 6 deletions
Large diffs are not rendered by default.

docs/notebooks/tutorials/GP_optimization.ipynb

Lines changed: 273 additions & 0 deletions
Large diffs are not rendered by default.

docs/notebooks/tutorials/exocomet.ipynb

Lines changed: 169 additions & 173 deletions
Large diffs are not rendered by default.

docs/notebooks/tutorials/exocomet.pdf

-295 KB
Binary file not shown.

0 commit comments

Comments
 (0)