Skip to content

Commit 19406d4

Browse files
Change Gaussian example to 20D.
1 parent 9726ce2 commit 19406d4

File tree

1 file changed

+7
-5
lines changed

1 file changed

+7
-5
lines changed

examples/gaussian_nondiagcov.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,8 @@ def run_example(
101101
if flow_type == "RealNVP":
102102
epochs_num = 5
103103
elif flow_type == "RQSpline":
104-
epochs_num = 3
104+
# epochs_num = 3
105+
epochs_num = 100
105106

106107
# Beginning of path where plots will be saved
107108
save_name_start = "examples/plots/" + flow_type
@@ -112,7 +113,7 @@ def run_example(
112113

113114
# Spline params
114115
n_layers = 5
115-
n_bins = 5
116+
n_bins = 16
116117
hidden_size = [32, 32]
117118
spline_range = (-10.0, 10.0)
118119

@@ -314,6 +315,7 @@ def run_example(
314315

315316
plt.show()
316317

318+
# Save out realisations for violin plot.
317319
evidence_inv_summary[i_realisation, 0] = ev.evidence_inv
318320
evidence_inv_summary[i_realisation, 1] = ev.evidence_inv_var
319321
evidence_inv_summary[i_realisation, 2] = ev.evidence_inv_var_var
@@ -347,11 +349,11 @@ def run_example(
347349
hm.logs.setup_logging()
348350

349351
# Define parameters.
350-
ndim = 5
352+
ndim = 21
351353
nchains = 100
352354
samples_per_chain = 5000
353-
flow_str = "RealNVP"
354-
# flow_str = "RQSpline"
355+
# flow_str = "RealNVP"
356+
flow_str = "RQSpline"
355357
np.random.seed(10) # used for initializing covariance matrix
356358

357359
hm.logs.info_log("Non-diagonal Covariance Gaussian example")

0 commit comments

Comments
 (0)