Skip to content

Commit 467a2c2

Browse files
Add tests for evidence calculation sample batching.
1 parent 89325d5 commit 467a2c2

File tree

1 file changed

+61
-18
lines changed

1 file changed

+61
-18
lines changed

tests/test_evidence.py

Lines changed: 61 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,16 @@
1616
models_to_test_1 = [sphere_1000D, real_nvp_2D, spline_4D]
1717
models_to_test_2 = [sphere_2D, real_nvp_2D, spline_4D]
1818

19+
models_to_test_2D = [
20+
sphere_2D,
21+
real_nvp_2D,
22+
md.RealNVPModel(2, standardize=True),
23+
md.RQSplineModel(2),
24+
md.RQSplineModel(2, standardize=True),
25+
]
26+
27+
chain_batching_options = [None, 2, 10]
28+
1929

2030
@pytest.mark.parametrize("model", models_to_test_1)
2131
def test_constructor(model):
@@ -59,6 +69,29 @@ def test_set_shift(model):
5969
assert rho.shift_set == True
6070

6171

72+
@pytest.mark.parametrize("model", models_to_test_2)
73+
def test_add_chains_sample_batching_error(model):
74+
75+
nchains = 10
76+
n_samples = 20
77+
ndim = model.ndim
78+
num_slices = 300
79+
80+
X = np.zeros((nchains, n_samples, ndim))
81+
Y = np.zeros((nchains, n_samples))
82+
83+
# Add samples to chains
84+
chain = ch.Chains(ndim)
85+
chain.add_chains_3d(X, Y)
86+
87+
model.fitted = True
88+
89+
# Calculate evidence
90+
cal_ev = cbe.Evidence(nchains, model)
91+
with pytest.raises(ValueError):
92+
cal_ev.add_chains(chain, num_slices=num_slices)
93+
94+
6295
@pytest.mark.parametrize("model", models_to_test_1)
6396
def test_process_run_with_shift(model):
6497
nchains = 10
@@ -111,7 +144,9 @@ def test_process_run_with_shift(model):
111144
assert np.exp(rho.ln_evidence_inv_var_var) == pytest.approx(evidence_inv_var_var)
112145

113146

114-
def test_add_chains():
147+
@pytest.mark.parametrize("model", models_to_test_2D)
148+
@pytest.mark.parametrize("num_slices", chain_batching_options)
149+
def test_add_chains(model, num_slices):
115150
nchains = 200
116151
nsamples = 500
117152
ndim = 2
@@ -125,22 +160,25 @@ def test_add_chains():
125160
chain = ch.Chains(ndim)
126161
chain.add_chains_3d(X, Y)
127162

128-
# Fit the Hyper_sphere
129-
domain = [np.array([1e-1, 1e1])]
130-
sphere = mdl.HyperSphere(ndim, domain)
131-
sphere.fit(chain.samples, chain.ln_posterior)
163+
if hasattr(model, "flow"):
164+
model.fit(chain.samples, epochs=5)
165+
else:
166+
model.fit(chain.samples, chain.ln_posterior)
132167

133168
# Calculate evidence
134-
cal_ev = cbe.Evidence(nchains, sphere, cbe.Shifting.MEAN_SHIFT)
135-
cal_ev.add_chains(chain)
169+
cal_ev = cbe.Evidence(nchains, model, cbe.Shifting.MEAN_SHIFT)
170+
cal_ev.add_chains(chain, num_slices=num_slices)
136171

137172
print("cal_ev.evidence_inv = {}".format(np.exp(cal_ev.ln_evidence_inv)))
138173

139-
assert np.exp(cal_ev.ln_evidence_inv) == pytest.approx(0.159438606)
140-
assert np.exp(cal_ev.ln_evidence_inv_var) == pytest.approx(1.164628268e-07)
141-
assert np.exp(cal_ev.ln_evidence_inv_var_var) ** 0.5 == pytest.approx(
142-
1.142786462e-08
143-
)
174+
if hasattr(model, "flow"):
175+
assert np.exp(cal_ev.ln_evidence_inv) == pytest.approx(0.159438606, rel=0.01)
176+
else:
177+
assert np.exp(cal_ev.ln_evidence_inv) == pytest.approx(0.159438606)
178+
assert np.exp(cal_ev.ln_evidence_inv_var) == pytest.approx(1.164628268e-07)
179+
assert np.exp(cal_ev.ln_evidence_inv_var_var) ** 0.5 == pytest.approx(
180+
1.142786462e-08
181+
)
144182

145183
nsamples1 = 300
146184
chains1 = ch.Chains(ndim)
@@ -150,14 +188,19 @@ def test_add_chains():
150188
for i_chain in range(nchains):
151189
chains2.add_chain(X[i_chain, nsamples1:, :], Y[i_chain, nsamples1:])
152190

153-
ev = cbe.Evidence(nchains, sphere, cbe.Shifting.MEAN_SHIFT)
191+
ev = cbe.Evidence(nchains, model, cbe.Shifting.MEAN_SHIFT)
154192
# Might have small numerical differences if don't use same mean_shift.
155-
ev.add_chains(chains1)
156-
ev.add_chains(chains2)
193+
ev.add_chains(chains1, num_slices=num_slices)
194+
ev.add_chains(chains2, num_slices=num_slices)
157195

158-
assert np.exp(ev.ln_evidence_inv) == pytest.approx(0.159438606)
159-
assert np.exp(ev.ln_evidence_inv_var) == pytest.approx(1.164628268e-07)
160-
assert np.exp(ev.ln_evidence_inv_var_var) ** 0.5 == pytest.approx(1.142786462e-08)
196+
if hasattr(model, "flow"):
197+
assert np.exp(ev.ln_evidence_inv) == pytest.approx(0.159438606, rel=0.01)
198+
else:
199+
assert np.exp(ev.ln_evidence_inv) == pytest.approx(0.159438606)
200+
assert np.exp(ev.ln_evidence_inv_var) == pytest.approx(1.164628268e-07)
201+
assert np.exp(ev.ln_evidence_inv_var_var) ** 0.5 == pytest.approx(
202+
1.142786462e-08
203+
)
161204

162205
return
163206

0 commit comments

Comments
 (0)