Skip to content

Commit a978e56

Browse files
committed
add sph compare experiments
1 parent 1d68d2c commit a978e56

File tree

216 files changed

+17232
-20
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

216 files changed

+17232
-20
lines changed

main.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# This is a sample Python script.
2+
3+
# Press ⌃R to execute it or replace it with your code.
4+
# Press Double ⇧ to search everywhere for classes, files, tool windows, actions, and settings.
5+
6+
7+
def print_hi(name):
8+
# Use a breakpoint in the code line below to debug your script.
9+
print(f'Hi, {name}') # Press ⌘F8 to toggle the breakpoint.
10+
11+
12+
# Press the green button in the gutter to run the script.
13+
if __name__ == '__main__':
14+
print_hi('PyCharm')
15+
16+
# See PyCharm help at https://www.jetbrains.com/help/pycharm/

stanify/builders/stan_block_builder.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,6 @@ def build_block(self):
255255
code += "real <lower=0> time_saveper;\n"
256256
code += "array[N] real integration_times;\n"
257257
# TODO @Dashadower
258-
# Q1. where code intentions could be logged (answer to theses qs) e.g. reason for filtering out stan_param is to leave _obs
259258
# Q2. using `stan_type`, could hierarchy arrays of vector coding be improved?
260259
# Q3. which code is better btw R ==1 (excluding to leave _obs) vs else part (including _obs)?
261260

stanify/calibrator/draws_data_mapper.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@
77
import xarray as xr
88
xr.set_options(display_expand_attrs = False)
99
import arviz as az
10-
10+
import typing
11+
if typing.TYPE_CHECKING:
12+
import stanify.stan_model
1113
def trunc4StanNegBinom(real_series):
1214
"""
1315
DataArray type real series
@@ -16,7 +18,7 @@ def trunc4StanNegBinom(real_series):
1618
return int_series
1719

1820

19-
def draws2data(model, idata_kwargs, data_dict):
21+
def draws2data(model, idata_kwargs, data_dict) -> az.InferenceData:
2022
"""
2123
Parameters
2224
----------
@@ -34,7 +36,7 @@ def draws2data(model, idata_kwargs, data_dict):
3436

3537
return draws2data_idata
3638

37-
def data2draws(model, idata_kwargs, data_dict):
39+
def data2draws(model, idata_kwargs, data_dict) -> az.InferenceData:
3840
"""
3941
Parameters
4042
----------
@@ -49,10 +51,9 @@ def data2draws(model, idata_kwargs, data_dict):
4951
# add observed_data to idata_kwargs
5052
observed_data = {k: v for k, v in data_dict.items() if k in model.get_obs_vector_names()}
5153
data2draws_idata = az.from_cmdstanpy(posterior=data2draws_data, observed_data = observed_data, **idata_kwargs)
52-
5354
return data2draws_idata
5455

55-
def draws2data2draws(vensim, setting, precision, numeric, prior, idata_kwargs):
56+
def draws2data2draws(vensim, setting, precision, numeric, prior, idata_kwargs) -> typing.Tuple[az.InferenceData, "stanify.stan_model.vensim2stan"]:
5657
"""
5758
vensim: vensim filepath which provides structral assumption
5859
setting: modeler's selection of which parameter to estimate

stanify/calibrator/visualizer.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,16 +46,16 @@ def plot_qoi(sbc_precision, setting, precision, idata_kwargs, model_name):
4646
fig, axes = plt.subplots(precision['R'], 1, figsize=(30, 20))
4747
for r, ax in zip(range(precision['R']), axes):
4848
sbc_aux = sbc.sel(region=r)
49-
sbc_aux.observed_data[obs_name].plot(hue='prior_draw', x='time', ax=ax, alpha=.3)
49+
sbc_aux.observed_data[obs_name].plot(hue='prior_draw', x='time', ax=ax, alpha=.6, figsize = figsize)
5050
sbc_aux.posterior_predictive[f'{obs_name}_post'].mean(['draw', 'chain']).plot(hue='prior_draw', x='time', ax=ax,
51-
alpha=.6, linestyle='dotted')
51+
alpha=.8, linestyle='dotted', figsize = figsize)
5252
save_fig(model_name, False, f"{obs_name}_ppc")
5353
plt.clf()
5454
else:
5555
for obs_name in idata_kwargs['prior_predictive']:
56-
sbc.observed_data[obs_name].plot(hue='prior_draw', x='time', alpha=.3, figsize = figsize)
56+
sbc.observed_data[obs_name].plot(hue='prior_draw', x='time', alpha=.6, figsize = figsize)
5757
sbc.posterior_predictive[f'{obs_name}_post'].mean(['draw', 'chain']).plot(hue='prior_draw', x='time',
58-
alpha=.6, linestyle='dotted', figsize = figsize)
58+
alpha=1, linestyle='dotted', figsize = figsize)
5959
save_fig(model_name, False, f"{obs_name}_ppc")
6060
plt.clf()
6161

0 commit comments

Comments
 (0)