Skip to content

Commit 233159f

Browse files
committed
Tests on checkerboard
1 parent 2806697 commit 233159f

5 files changed

Lines changed: 120 additions & 15 deletions

File tree

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
name: checkerboard-fm
2+
project: flowmaps
3+
4+
dim_flow: 2
5+
dim_conditioning: 0
6+
dim_network: 2
7+
batch_size: 256
8+
9+
weighting: 'one'
10+
loss_type: 'fm_ot'
11+
scaling_type: 'none'
12+
13+
learning_rate: 0.001
14+
15+
hidden_dims: [32, 64, 128, 256, 128, 64, 32]
16+
17+
data:
18+
target: sbisim.data.benchmarks.density_estimation.ToyDataset
19+
_file: configs/datasets/2d/checkerboard.yaml
20+
21+
strategy:
22+
target: sbisim.strategy.ConditionalFlowMatching
23+
params:
24+
dim_flow: ${dim_flow}
25+
dim_conditioning: ${dim_conditioning}
26+
weighting: ${weighting}
27+
loss_type: ${loss_type}
28+
scaling_type: ${scaling_type}
29+
time_alpha: 4
30+
bandwidth: 0.01
31+
prior:
32+
target: sbisim.strategy.distributions.get_normal_gaussian_prior
33+
params:
34+
shape: [ 2 ]
35+
sampler:
36+
target: sbisim.strategy.improved_inference.PriorODESolver
37+
params:
38+
prior:
39+
target: sbisim.strategy.distributions.get_normal_gaussian_prior
40+
params:
41+
shape: [ 2 ]
42+
solver_name: euler
43+
init_stepsize: 0.01
44+
schedule:
45+
target: sbisim.strategy.get_paths
46+
params:
47+
paths:
48+
name: optimal_transport
49+
params:
50+
sigma_min: 1e-4
51+
model:
52+
target: sbisim.flows.models.fmpe.SBIResidualNet
53+
params:
54+
hidden_dims: ${hidden_dims}
55+
out_dim: ${dim_flow}
56+
in_dim: ${dim_network}
57+
activation_fn: gelu
58+
59+
num_steps: 10000
60+
training:
61+
num_epochs: 15
62+
num_steps_per_epoch: ${num_steps}
63+
batch_size: ${batch_size}
64+
early_stopping:
65+
patience: 50
66+
67+
test:
68+
active: True
69+
tests:
70+
scatter:
71+
target: sbisim.callbacks.TwoDimensionalPlot
72+
params:
73+
savedir: ${runtime.logdir}
74+
75+
76+
patience: 3
77+
optimization:
78+
_file: configs/optimization/adam_reduce_lr.yaml
79+
80+
callbacks:
81+
checkpoint:
82+
target: sbisim.callbacks.checkpoint.Checkpoint
83+
params:
84+
savedir: ${runtime.logdir}
85+
key: val_loss

requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,5 @@ objax
1414
corner
1515
orbax
1616
pandas
17-
astropy
17+
astropy
18+
seaborn

sbisim/callbacks/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
from .c2st import C2ST
2-
from .visualization import BenchmarkScatterPlot
2+
from .visualization import BenchmarkScatterPlot
3+
from .two_dimensional import TwoDimensionalPlot

sbisim/callbacks/two_dimensional.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,28 +9,30 @@
99

1010
from ..data.density_estimation.golden_ratio import fibonacci_ratio, get_centers
1111

12-
# from source.data.dataloader import DataLoader
1312
from ..strategy import Strategy
14-
1513
from matplotlib import pyplot as plt
16-
1714
import jax.numpy as jnp
1815
from tqdm import tqdm
19-
2016
import seaborn as sns
2117

18+
19+
class DataLoader:
20+
pass
21+
22+
2223
class TwoDimensionalPlot(Callback):
2324

2425
name: str = '2d_plot'
2526
save_every: int = 10
2627
num_samples: int = 3000
2728
num_samples_gt: int = 3000
2829

29-
def __init__(self, save_every: int = 10, num_samples: int = 3000):
30+
def __init__(self, save_every: int = 10, num_samples: int = 3000, savedir: str = None):
3031
super(self.__class__, self).__init__()
3132

3233
self.save_every = save_every
3334
self.num_samples = num_samples
35+
self.savedir = savedir
3436

3537

3638
def __call__(self, logs: dict, rng: jr.PRNGKey, *args, **kwargs):
@@ -59,8 +61,12 @@ def _call(self, logs: dict, rng: jr.PRNGKey, strategy: Strategy,
5961

6062
ax.scatter(x[:self.num_samples_gt], y[:self.num_samples_gt])
6163

64+
if self.savedir is not None:
65+
plt.savefig(f"{self.savedir}/pictures/samples_gt.png")
66+
6267
logs[f'samples_gt'] = wandb.Image(fig)
6368

69+
6470
plt.close(fig)
6571

6672
sample_fn = strategy.sample
@@ -75,6 +81,15 @@ def _call(self, logs: dict, rng: jr.PRNGKey, strategy: Strategy,
7581
fig, ax = plt.subplots()
7682
ax.scatter(x, y)
7783

84+
epoch = logs.get("epoch", 0)
85+
if self.savedir is not None:
86+
87+
import os
88+
# create directory if it does not exist
89+
os.makedirs(f"{self.savedir}/pictures", exist_ok=True)
90+
91+
plt.savefig(f"{self.savedir}/pictures/samples_{epoch}.png")
92+
7893
logs[f'samples'] = wandb.Image(fig)
7994

8095
plt.close(fig)
@@ -84,6 +99,9 @@ def _call(self, logs: dict, rng: jr.PRNGKey, strategy: Strategy,
8499
def on_train_begin(self, *args, **kwargs):
85100
return self.__call__(*args, init=True, **kwargs)
86101

102+
def on_test(self, *args, **kwargs):
103+
return self.__call__(*args, init=False, **kwargs)
104+
87105
def on_epoch_end(self, logs: dict, rng: jr.PRNGKey, *args, **kwargs):
88106

89107
if "epoch" in logs and logs["epoch"] % self.save_every == 0:

sbisim/trainer.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,14 @@ def init_strategy(self):
8484

8585
print("Number of parameters: ", wandb.run.summary["num_params"])
8686

87+
# Check if JAX is using GPU
88+
from jax.lib import xla_bridge
89+
platform = xla_bridge.get_backend().platform
90+
if platform == 'gpu':
91+
print("JAX is using GPU")
92+
else:
93+
print(f"JAX is using {platform}")
94+
8795
def pack_(self):
8896
self.strategy.bind(self.opt.get_params())
8997
return {'rng': self.callback_rng, 'strategy': self.strategy,
@@ -212,14 +220,6 @@ def train_epoch(self, logs, rng):
212220
step_ += 1
213221
if step_ >= num_val_batches:
214222
break
215-
216-
# Check if JAX is using GPU
217-
from jax.lib import xla_bridge
218-
platform = xla_bridge.get_backend().platform
219-
if platform == 'gpu':
220-
print("JAX is using GPU")
221-
else:
222-
print(f"JAX is using {platform}")
223223

224224
avg_loss /= step_
225225

0 commit comments

Comments
 (0)