16
16
models_to_test_1 = [sphere_1000D , real_nvp_2D , spline_4D ]
17
17
models_to_test_2 = [sphere_2D , real_nvp_2D , spline_4D ]
18
18
19
+ models_to_test_2D = [
20
+ sphere_2D ,
21
+ real_nvp_2D ,
22
+ md .RealNVPModel (2 , standardize = True ),
23
+ md .RQSplineModel (2 ),
24
+ md .RQSplineModel (2 , standardize = True ),
25
+ ]
26
+
27
+ chain_batching_options = [None , 2 , 10 ]
28
+
19
29
20
30
@pytest .mark .parametrize ("model" , models_to_test_1 )
21
31
def test_constructor (model ):
@@ -59,6 +69,29 @@ def test_set_shift(model):
59
69
assert rho .shift_set == True
60
70
61
71
72
+ @pytest .mark .parametrize ("model" , models_to_test_2 )
73
+ def test_add_chains_sample_batching_error (model ):
74
+
75
+ nchains = 10
76
+ n_samples = 20
77
+ ndim = model .ndim
78
+ num_slices = 300
79
+
80
+ X = np .zeros ((nchains , n_samples , ndim ))
81
+ Y = np .zeros ((nchains , n_samples ))
82
+
83
+ # Add samples to chains
84
+ chain = ch .Chains (ndim )
85
+ chain .add_chains_3d (X , Y )
86
+
87
+ model .fitted = True
88
+
89
+ # Calculate evidence
90
+ cal_ev = cbe .Evidence (nchains , model )
91
+ with pytest .raises (ValueError ):
92
+ cal_ev .add_chains (chain , num_slices = num_slices )
93
+
94
+
62
95
@pytest .mark .parametrize ("model" , models_to_test_1 )
63
96
def test_process_run_with_shift (model ):
64
97
nchains = 10
@@ -111,7 +144,9 @@ def test_process_run_with_shift(model):
111
144
assert np .exp (rho .ln_evidence_inv_var_var ) == pytest .approx (evidence_inv_var_var )
112
145
113
146
114
- def test_add_chains ():
147
+ @pytest .mark .parametrize ("model" , models_to_test_2D )
148
+ @pytest .mark .parametrize ("num_slices" , chain_batching_options )
149
+ def test_add_chains (model , num_slices ):
115
150
nchains = 200
116
151
nsamples = 500
117
152
ndim = 2
@@ -125,22 +160,25 @@ def test_add_chains():
125
160
chain = ch .Chains (ndim )
126
161
chain .add_chains_3d (X , Y )
127
162
128
- # Fit the Hyper_sphere
129
- domain = [ np . array ([ 1e-1 , 1e1 ])]
130
- sphere = mdl . HyperSphere ( ndim , domain )
131
- sphere .fit (chain .samples , chain .ln_posterior )
163
+ if hasattr ( model , "flow" ):
164
+ model . fit ( chain . samples , epochs = 5 )
165
+ else :
166
+ model .fit (chain .samples , chain .ln_posterior )
132
167
133
168
# Calculate evidence
134
- cal_ev = cbe .Evidence (nchains , sphere , cbe .Shifting .MEAN_SHIFT )
135
- cal_ev .add_chains (chain )
169
+ cal_ev = cbe .Evidence (nchains , model , cbe .Shifting .MEAN_SHIFT )
170
+ cal_ev .add_chains (chain , num_slices = num_slices )
136
171
137
172
print ("cal_ev.evidence_inv = {}" .format (np .exp (cal_ev .ln_evidence_inv )))
138
173
139
- assert np .exp (cal_ev .ln_evidence_inv ) == pytest .approx (0.159438606 )
140
- assert np .exp (cal_ev .ln_evidence_inv_var ) == pytest .approx (1.164628268e-07 )
141
- assert np .exp (cal_ev .ln_evidence_inv_var_var ) ** 0.5 == pytest .approx (
142
- 1.142786462e-08
143
- )
174
+ if hasattr (model , "flow" ):
175
+ assert np .exp (cal_ev .ln_evidence_inv ) == pytest .approx (0.159438606 , rel = 0.01 )
176
+ else :
177
+ assert np .exp (cal_ev .ln_evidence_inv ) == pytest .approx (0.159438606 )
178
+ assert np .exp (cal_ev .ln_evidence_inv_var ) == pytest .approx (1.164628268e-07 )
179
+ assert np .exp (cal_ev .ln_evidence_inv_var_var ) ** 0.5 == pytest .approx (
180
+ 1.142786462e-08
181
+ )
144
182
145
183
nsamples1 = 300
146
184
chains1 = ch .Chains (ndim )
@@ -150,14 +188,19 @@ def test_add_chains():
150
188
for i_chain in range (nchains ):
151
189
chains2 .add_chain (X [i_chain , nsamples1 :, :], Y [i_chain , nsamples1 :])
152
190
153
- ev = cbe .Evidence (nchains , sphere , cbe .Shifting .MEAN_SHIFT )
191
+ ev = cbe .Evidence (nchains , model , cbe .Shifting .MEAN_SHIFT )
154
192
# Might have small numerical differences if don't use same mean_shift.
155
- ev .add_chains (chains1 )
156
- ev .add_chains (chains2 )
193
+ ev .add_chains (chains1 , num_slices = num_slices )
194
+ ev .add_chains (chains2 , num_slices = num_slices )
157
195
158
- assert np .exp (ev .ln_evidence_inv ) == pytest .approx (0.159438606 )
159
- assert np .exp (ev .ln_evidence_inv_var ) == pytest .approx (1.164628268e-07 )
160
- assert np .exp (ev .ln_evidence_inv_var_var ) ** 0.5 == pytest .approx (1.142786462e-08 )
196
+ if hasattr (model , "flow" ):
197
+ assert np .exp (ev .ln_evidence_inv ) == pytest .approx (0.159438606 , rel = 0.01 )
198
+ else :
199
+ assert np .exp (ev .ln_evidence_inv ) == pytest .approx (0.159438606 )
200
+ assert np .exp (ev .ln_evidence_inv_var ) == pytest .approx (1.164628268e-07 )
201
+ assert np .exp (ev .ln_evidence_inv_var_var ) ** 0.5 == pytest .approx (
202
+ 1.142786462e-08
203
+ )
161
204
162
205
return
163
206
0 commit comments