Skip to content

Commit 6957c95

Browse files
committed
Refactor stan builder, add loglik inside stan, share step_size scheme
1 parent 7e3d786 commit 6957c95

File tree

621 files changed

+2519
-48356
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

621 files changed

+2519
-48356
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
{"inv_metric": [0.00890483, 0.0146411, 0.0180988, 0.0173486, 0.0162696, 0.0185366, 0.0186761, 0.0176091, 0.0177719, 0.0194737, 0.0269911, 0.0241012, 0.0270683, 0.0305894, 0.0295884, 0.0285179, 0.0270491, 0.0247198, 0.00094971]}

stanify/builders/stan_block_builder.py

Lines changed: 43 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -482,24 +482,18 @@ def build_block(self, hier_est_param_names):
482482
if self.precision_context.R == 1:
483483
for statement in self.stan_model_context.sample_statements:
484484
if statement.distribution_type != statement.assignment_dist:
485-
code += f"{statement.lhs_expr} ~ {statement.distribution_type}({', '.join([str(arg) for arg in statement.distribution_args])});\n"
485+
code += f"{adj_expr(statement)};\n"
486486
else:
487487
for statement in self.stan_model_context.sample_statements:
488-
param_name = statement.lhs_expr
489-
if param_name in hier_est_param_names:
488+
if statement.lhs_expr in hier_est_param_names:
490489
dist_code = "rep_vector(" + f'{statement.distribution_args[0]}, R), ' + f"{', '.join(statement.distribution_args[1:])}"
491-
code += f"{param_name} ~ {statement.distribution_type}({dist_code});\n"
492-
493-
elif param_name in self.stan_model_context.obs_integ_outcome_vector_names:
494-
code += "for (r in 1:R)\n"
495-
code.indent_level += 1
496-
dist_code = f'{param_name}'[:-4] + "[:, r], " + f"{', '.join(statement.distribution_args[1:])}"
497-
code += f"{param_name}[:, r] ~ {statement.distribution_type}({dist_code});\n"
498-
code.indent_level -= 1
490+
code += f"{statement.lhs_expr} ~ {statement.distribution_type}({dist_code});\n"
499491

492+
elif statement.lhs_expr in self.stan_model_context.obs_integ_outcome_vector_names:
493+
code += f"{adj_expr(statement, is_hier=True)};\n"
500494
else:
501495
if statement.distribution_type != statement.assignment_dist:
502-
code += f"{param_name} ~ {statement.distribution_type}({', '.join([str(arg) for arg in statement.distribution_args])});\n"
496+
code += f"{adj_expr(statement)};\n"
503497
code.indent_level -= 1
504498
code += "}\n"
505499
#TODO @Dashadower what is the difference between classes that keep their own code and those that do not (self.code vs. return str(code))
@@ -511,25 +505,25 @@ def __init__(self, precision_context: "PrecisionContext", stan_model_context: "S
511505
self.precision_context = precision_context
512506
self.stan_model_context = stan_model_context
513507
self.vensim_model_context = vensim_model_context
514-
# TODO @Dashadower how to write message, if some target_simulated_vector_names is not in vensim_integ_outcome (inconsistency btw user-defined and vensim syntax)
508+
# TODO @Dashadower how to write a message if some target_sim_vector_names are not in vensim_integ_outcome (inconsistency between user-defined and vensim syntax)
515509
integ_outcome_vector_names = set(self.stan_model_context.target_integ_outcome_vector_names) & set(self.vensim_model_context.integ_outcome_vector_names)
516510

517511
def build_block(self, hier_est_param_names, transformed_parameters_code: str = ""):
518512
self.code = IndentedString()
519513
self.code += "generated quantities{\n"
520514
self.code.indent_level += 1
521-
self.build_param_rng_functions(hier_est_param_names)
515+
self.build_param_pri_pred_functions(hier_est_param_names)
522516
self.code += "\n"
523517
self.code.add_raw(transformed_parameters_code, ignore_indent=True)
524518
self.code += "\n"
525-
self.build_obs_rng_functions()
519+
self.build_data_pri_pred_functions()
526520
self.code.indent_level -= 1
527521
self.code += "}\n"
528522

529523
return str(self.code)
530524

531525

532-
def build_param_rng_functions(self, hier_est_param_names):
526+
def build_param_pri_pred_functions(self, hier_est_param_names):
533527

534528
ignored_variables = set(self.stan_model_context.stan_data.keys()).union(
535529
set(self.vensim_model_context.integ_outcome_vector_names))
@@ -557,40 +551,29 @@ def build_param_rng_functions(self, hier_est_param_names):
557551
if statement.init_state:
558552
param_name = param_name + "__init"
559553
if param_name in hier_est_param_names:
560-
561-
562554
dist_code = "rep_vector(" + f'{statement.distribution_args[0]}, R), ' + f"{', '.join(statement.distribution_args[1:])}"
563555
self.code += f"real {param_name}[R] = {statement.distribution_type}_rng({dist_code});\n"
564556
else:
565-
self.code += f"real {param_name} = {statement.distribution_type}_rng({', '.join(statement.distribution_args)});\n"
557+
self.code += f"real {adj_expr(statement, is_pri_pred=True)};\n"
566558
processed_statements.add(statement)
567559

568-
def build_obs_rng_functions(self):
560+
def build_data_pri_pred_functions(self):
569561
if self.precision_context.R == 1:
570562
self.code += "// Define and assign generated value to observed vector (matching vector)\n"
571-
572563
for statement in self.stan_model_context.sample_statements:
573564
if statement.lhs_expr in self.stan_model_context.obs_integ_outcome_vector_names:
574-
vec_name = statement.lhs_expr
575-
self.code += f"array [N] real {vec_name} = {statement.distribution_type}_rng({', '.join(statement.distribution_args)});\n"
565+
self.code += f"array [N] real {adj_expr(statement, is_pri_pred=True)};\n"
576566
else:
577567
self.code += "// Define observed vector (matching vector)\n"
578568
for statement in self.stan_model_context.sample_statements:
579569
if statement.lhs_expr in self.stan_model_context.obs_integ_outcome_vector_names:
580570
self.code += f"array[N] vector[R] {statement.lhs_expr};\n"
581571

582572
self.code += "// Assign generated value to observed vector (matching vector)\n"
583-
self.code += "for (r in 1:R){\n"
584-
self.code.indent_level += 1
585573
for statement in self.stan_model_context.sample_statements:
586-
#TODO @Dashadower statement.lhs_variable vs .lhs_expr
587574
if statement.lhs_expr in self.stan_model_context.obs_integ_outcome_vector_names:
588-
vec_name = statement.lhs_expr
589-
dist_code = f'{vec_name}'[:-4] + "[:, r], " + f"{', '.join(statement.distribution_args[1:])}"
590-
self.code += f"{vec_name}[:, r] = {statement.distribution_type}_rng({dist_code});\n"
575+
self.code += f"{adj_expr(statement, is_pri_pred=True, is_hier=True)};\n"
591576
# link(alpha) ~ N(0,1); link(alpha) is expr, alpha is var
592-
self.code.indent_level -= 1
593-
self.code += "}\n"
594577

595578
class Data2DrawsStanGQBuilder():
596579
def __init__(self, precision_context: "PrecisionContext", stan_model_context: "StanModelContext", vensim_model_context: "VensimModelContext"):
@@ -599,75 +582,58 @@ def __init__(self, precision_context: "PrecisionContext", stan_model_context: "S
599582
self.vensim_model_context = vensim_model_context
600583

601584

602-
def build_block(self):
585+
def build_block(self, hier_est_param_names):
603586
self.code = IndentedString()
604587
self.code += "generated quantities{\n"
605588
self.code.indent_level += 1
606589
self.build_post_pred_rng_functions()
607590
self.code += "\n"
608-
self.build_loglik_functions()
591+
self.build_loglik_functions(hier_est_param_names)
609592
self.code.indent_level -= 1
610593
self.code += "}\n"
611594

612595
return str(self.code)
613596

614-
def build_loglik_functions(self):
597+
def build_loglik_functions(self, hier_est_param_names):
615598
self.code += "real loglik;\n"
599+
self.code += "real loglik_prior;\n"
600+
for tn in self.stan_model_context.target_integ_outcome_vector_names:
601+
self.code += f"real loglik_{tn};\n"
616602

617-
if self.precision_context.R == 1:
618-
for statement in self.stan_model_context.sample_statements:
619-
if statement.lhs_expr in self.stan_model_context.obs_integ_outcome_vector_names:
620-
param_name = statement.lhs_expr
621-
loc = statement.distribution_args[0]
622-
scale = statement.distribution_args[1]
623-
if statement.distribution_type in ["normal", "lognormal"]:
624-
self.code += f"loglik += {statement.distribution_type}_lpdf({param_name}|{loc}, {scale});\n"
625-
elif statement.distribution_type in ["neg_binom_2"]:
626-
self.code += f"loglik += {statement.distribution_type}_lpmf({param_name}|{loc}, {scale});\n"
627-
else:
628-
self.code += "for (r in 1:R){\n"
629-
self.code.indent_level += 1
630-
for statement in self.stan_model_context.sample_statements:
631-
if statement.lhs_expr in self.stan_model_context.obs_integ_outcome_vector_names:
632-
obs_vec_name = statement.lhs_expr
633-
target_vec_name = obs_vec_name[:-4]
634-
scale = statement.distribution_args[1]
635-
if statement.distribution_type in ["normal", "lognormal"]:
636-
self.code += f"loglik += {statement.distribution_type}_lpdf({obs_vec_name}[:, r]|{target_vec_name}[:, r], {scale});\n"
637-
elif statement.distribution_type in ["neg_binom_2"]:
638-
self.code += f"loglik += {statement.distribution_type}_lpmf({obs_vec_name}[:, r]|{target_vec_name}[:, r], {scale});\n"
639-
self.code.indent_level -= 1
640-
self.code += "}\n"
603+
# add loglik for "matter ~ form"
604+
# self.stan_model_context.all_stan_variables include "matter"
605+
# 1. parameter draw ~ prior distribution
606+
# X. target_sim = parameter draw
607+
# 2. observed data ~ likelihood distribution (target_sim)
608+
609+
#is_lp_pq=True: component-wise (P for estimated parameter, Q for target_simulated)
610+
for statement in self.stan_model_context.sample_statements:
611+
if statement.lhs_expr in list(self.stan_model_context.all_stan_variables): #['adj_frac1[R]', 'adj_frac2', 'm_noise_scale', 'stocked_pping_obs[R]', 'stocked_ping_obs[R]']
612+
# lp contribution from prior function and realized estimated parameter value
613+
if statement.lhs_expr in hier_est_param_names:
614+
self.code += f"{adj_expr(statement, is_lp_tot=True, is_hier=True)};\n" #
615+
self.code += f"{adj_expr(statement, is_lp_prior=True, is_hier=True)};\n"
616+
617+
# lp contribution from likelihood function and observed data
618+
elif statement.lhs_expr in self.stan_model_context.obs_integ_outcome_vector_names:
619+
self.code += f"{adj_expr(statement, is_lp_tot=True, is_hier=(self.precision_context.R > 1))};\n"
620+
self.code += f"{adj_expr(statement, is_lp_q=True, is_hier=(self.precision_context.R > 1))};\n"
641621

642-
def build_post_pred_rng_functions(self):
643622

623+
def build_post_pred_rng_functions(self):
644624
if self.precision_context.R == 1:
645625
self.code += "// Define and assign generated value to posterior predictive vector\n"
646626
for statement in self.stan_model_context.sample_statements:
647627
if statement.lhs_expr in self.stan_model_context.obs_integ_outcome_vector_names:
648-
# TODO @Dashadower how to use the following in the future?
649-
# stan_type = self.stan_model_context.stan_data[statement.lhs_expr].stan_type # stan_type = self.stan_model_context.stan_data[statement.lhs_expr].stan_type KeyError: 'prey_obs'
650-
scale = statement.distribution_args[1]
651-
# if stan_type.startswith("vector"):
652-
# self.code += f"{stan_type} {statement.lhs_expr}_post = to_vector({statement.distribution_type}_rng({', '.join(statement.distribution_args)}));\n"
653-
# else:
654-
# self.code += f"{stan_type} {statement.lhs_expr}_post = {statement.distribution_type}_rng({', '.join(statement.distribution_args)});\n"
655-
self.code += f"array[N] real {statement.lhs_expr}_post = {statement.distribution_type}_rng({', '.join(statement.distribution_args)});\n"
656-
else:
628+
self.code += f"array[N] real {adj_expr(statement, is_post_pred=True)};\n"
629+
630+
elif self.precision_context.R > 1:
657631
self.code += "// Define observed vector (matching vector)\n"
658632
for statement in self.stan_model_context.sample_statements:
659633
if statement.lhs_expr in self.stan_model_context.obs_integ_outcome_vector_names:
660634
self.code += f"array[N] vector[R] {statement.lhs_expr}_post;\n"
661635

662636
self.code += "// Assign generated value to observed vector (matching vector)\n"
663-
self.code += "for (r in 1:R){\n"
664-
self.code.indent_level += 1
665637
for statement in self.stan_model_context.sample_statements:
666-
#TODO @Dashadower statement.lhs_variable vs .lhs_expr, obs_integ_outcome_vector_names is list VS target_// is tuple so the latter doesn't work
667638
if statement.lhs_expr in self.stan_model_context.obs_integ_outcome_vector_names:
668-
scale = statement.distribution_args[1]
669-
#TODO {', '.join(statement.distribution_args)} is better; posterior predictive remove _obs?? (w.o. loc, scale)
670-
self.code += f"{statement.lhs_expr}_post[:, r] = {statement.distribution_type}_rng({statement.lhs_expr[:-4]}[:, r], {scale});\n"
671-
self.code.indent_level -= 1
672-
self.code += "}\n"
673-
639+
self.code += f"{adj_expr(statement, is_post_pred=True, is_hier=True)};\n"

stanify/builders/stan_model.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -181,24 +181,25 @@ def set_prior(self, variable_name: str, distribution_type: str, *args, lower=flo
181181
if name in self.vensim_model_context.variable_names and name not in self.vensim_model_context.integ_outcome_vector_names:
182182
self.stan_model_context.exposed_parameters.update(used_variable_names)
183183

184-
if variable_name in self.vensim_model_context.variable_names and variable_name not in self.vensim_model_context.integ_outcome_vector_names:
184+
if (variable_name in self.vensim_model_context.variable_names) and (variable_name not in self.vensim_model_context.integ_outcome_vector_names):
185+
# adj_frac1_loc is excluded as it is not defined in vensim
185186
self.stan_model_context.exposed_parameters.add(variable_name)
186187

187188

188189
self.stan_model_context.sample_statements.append(SamplingStatement(variable_name, distribution_type, *args, lower=lower, upper=upper, init_state=init_state))
189190

190191

191-
def set_type(self, est_param_names: list, hier_est_param_names: list, target_simulated_vector_names: list, driving_vector_names: list, model_name: str):
192+
def set_type(self, est_param_names: list, hier_est_param_names: list, target_sim_vector_names: list, driving_vector_names: list, model_name: str):
192193
self.est_param_names = est_param_names
193194
self.hier_est_param_names = hier_est_param_names
194-
# TODO @Dashadower how to make target_simulated_vector_names and integ_outcome_vector_names consistent? make class TypeContext? related to lin
195-
self.stan_model_context.target_integ_outcome_vector_names = target_simulated_vector_names
195+
# TODO @Dashadower how to make target_sim_vector_names and integ_outcome_vector_names consistent? make class TypeContext? related to lin
196+
self.stan_model_context.target_integ_outcome_vector_names = target_sim_vector_names
196197
self.stan_model_context.obs_integ_outcome_vector_names = [f'{name}_obs' for name in self.stan_model_context.target_integ_outcome_vector_names]
197198
self.driving_vector_names = driving_vector_names
198199
self.model_name = model_name
199200

200-
#TODO @dashadower for external refernce of target_simulated_vector_names
201-
# is it better to use model.target_simulated_vector_names or define function for consistency?
201+
#TODO @dashadower for external reference of target_sim_vector_names
202+
# is it better to use model.target_sim_vector_names or define function for consistency?
202203
def get_latent_vector_names(self):
203204
return [f'{target}' for target in self.stan_model_context.target_integ_outcome_vector_names]
204205

@@ -211,9 +212,9 @@ def get_latent_obs_vector_names(self):
211212
return self.get_latent_vector_names() + self.get_obs_vector_names()
212213

213214

214-
def update_setting(self, est_param_names: list, target_simulated_vector_names: list, driving_vector_names: list, model_name: str):
215+
def update_setting(self, est_param_names: list, target_sim_vector_names: list, driving_vector_names: list, model_name: str):
215216
self.est_param_names = est_param_names
216-
self.stan_model_context.target_integ_outcome_vector_names = target_simulated_vector_names
217+
self.stan_model_context.target_integ_outcome_vector_names = target_sim_vector_names
217218
self.driving_vector_names = driving_vector_names
218219
self.model_name = model_name
219220
# if self.initial_time in self.integration_times:
@@ -386,7 +387,7 @@ def stanify_data2draws(self):
386387
f.write(StanModelBuilder(self.precision_context, self.stan_model_context).build_block(self.hier_est_param_names))
387388
f.write("\n")
388389

389-
f.write(Data2DrawsStanGQBuilder(self.precision_context,self.stan_model_context, self.vensim_model_context,).build_block())
390+
f.write(Data2DrawsStanGQBuilder(self.precision_context,self.stan_model_context, self.vensim_model_context,).build_block(self.hier_est_param_names))
390391

391392
stan_model = cmdstanpy.CmdStanModel(stan_file=stan_data2draws_path, cpp_options={'STAN_THREADS':'true'})
392393
return stan_model

0 commit comments

Comments
 (0)