@@ -81,6 +81,14 @@ def code(self):
81
81
class InitialValueStatementsCodegen :
82
82
"""
83
83
Generates code for statements needed for calculating initial values
84
+
85
+ Attributes
86
+ ----------
87
+ requires_params : bool
88
+ Boolean indicating whether any of the initial value variables requires a stan parameter to calculate. This means
89
+ that the initial value can't be calculated at `transformed data`, but instead `transformed parameters`
90
+ required_params : set[str]
91
+ Set holding the stan parameters which are required for calculating initial values, if any.
84
92
"""
85
93
_code : IndentedString = field (init = False , default_factory = IndentedString )
86
94
requires_params : bool = field (init = False , default = True )
@@ -93,7 +101,6 @@ def code(self) -> str:
93
101
def generate (self , v2s_code_handler : Vensim2StanCodeHandler , vensim_model_context : VensimModelContext ,
94
102
stan_model_context : StanModelContext ) -> None :
95
103
# Calculate the initial values for the stock variables.
96
- self ._code += "\n "
97
104
self ._code += "// initial stock values\n "
98
105
99
106
# Create the initial state vector. This will be passed to the ODE solver
@@ -393,6 +400,72 @@ def code(self) -> str:
393
400
return str (self ._code )
394
401
395
402
403
+ @dataclass
404
+ class Data2DrawsGeneratedQuantitiesLogLikV2SWalker (ModelBlockStatementV2SWalker ):
405
+ """
406
+ Generates code to generate `_loglik += "dist_lpdf(lhs_variate | arg1, arg2);`.
407
+ Since the LHS variable doesn't go to the LHS of the generated code, we first generate code for the LHS normally,
408
+ but after that cut the code into a temporary variable(`lhs_code`).
409
+ After that, we prepend the contents of the temporary variable as the first special argument of the distribution
410
+ function.
411
+
412
+ Attributes
413
+ ----------
414
+ lhs_code : str
415
+ The variable that holds the generated code for the LHS variable. The contents of this variable is then used
416
+ during codegen for the lpdf/lpmf distribution function, which adds the value as the first argument.
417
+ """
418
+ lhs_code : str = field (init = False , default = "" )
419
+
420
+ def walk_Statement (self , node : ast .Statement ):
421
+ if node .op == "=" :
422
+ return
423
+ self .walk (node .left )
424
+ self ._code += "loglik += "
425
+ self .walk (node .right )
426
+ self ._code += ";\n "
427
+
428
+ def walk_FunctionCall (self , node : ast .FunctionCall ):
429
+ function_name_map = {
430
+ "bernoulli" : "bernoulli_lpmf" ,
431
+ "binomial" : "binomial_lpmf" ,
432
+ "neg_binomial" : "neg_binomial_lpmf" ,
433
+ "poisson" : "poisson_lpmf" ,
434
+ "normal" : "normal_lpdf" ,
435
+ "cauchy" : "caucly_lpdf" ,
436
+ "lognormal" : "lognormal_lpdf" ,
437
+ "exponential" : "exponential_lpdf" ,
438
+ "gamma" : "gamma_lpdf" ,
439
+ "weibull" : "weibull_lpdf" ,
440
+ "beta" : "beta_lpdf"
441
+ }
442
+ if node .name in function_name_map :
443
+ self ._code += f"{ function_name_map [node .name ]} ("
444
+
445
+ # Add the LHS variable to the log probability function
446
+ self ._code += f"{ self .lhs_code } | "
447
+ else :
448
+ self ._code += node .name
449
+
450
+ for index , arg in enumerate (node .arglist ):
451
+ self .walk (arg )
452
+ if index < len (node .arglist ) - 1 :
453
+ self ._code += ", "
454
+ self ._code += ")"
455
+
456
+ def walk_Variable (self , node : ast .Variable ):
457
+ # Check if we need to retrieve the code of the LHS variable
458
+ flush_lhs = False
459
+ if not self .lhs_code :
460
+ flush_lhs = True
461
+
462
+ super ().walk_Variable (node )
463
+
464
+ if flush_lhs :
465
+ self .lhs_code = str (self ._code )
466
+ self ._code .clear ()
467
+
468
+
396
469
class ModelBlockCodegen (StanBlockCodegen ):
397
470
"""
398
471
Generates code for the `model` Stan block. This includes prior and likelihood specifications.
@@ -563,6 +636,58 @@ def build_forloops(node: Node, current_subscripts: dict[str, str], nest_level=0)
563
636
self ._code += walker .code
564
637
565
638
639
+ class Data2DrawsGeneratedQuantitiesBlockCodegen (StanBlockCodegen ):
640
+ """
641
+ For Data2Draws, the generated quantities block holds only the log likelihood calculations.
642
+ """
643
+ def generate (self , v2s_code_handler : Vensim2StanCodeHandler , vensim_model_context : VensimModelContext ,
644
+ stan_model_context : StanModelContext ) -> None :
645
+
646
+ # Insert the loglik variable
647
+ self ._code += "real loglik = 0.0;\n "
648
+
649
+ for statement in v2s_code_handler .program_ast .statements :
650
+ if statement .op == "=" :
651
+ continue
652
+
653
+ left_variable = statement .left
654
+
655
+ # If the LHS is a data variable, we ignore the sample statement
656
+ if left_variable .not_param :
657
+ continue
658
+
659
+ self .generate_code_for_statements (statement , v2s_code_handler , vensim_model_context , stan_model_context )
660
+
661
+ def generate_code_for_statements (self , statement : ast .Statement , v2s_code_handler : Vensim2StanCodeHandler ,
662
+ vensim_model_context : VensimModelContext , stan_model_context : StanModelContext ):
663
+ if statement .op == "=" :
664
+ return
665
+
666
+ left_variable = statement .left
667
+
668
+ loop_variable_mapping = {} # key is subscript name, value is loop variable
669
+ if left_variable .subscripts :
670
+ indent_levels = len (left_variable .subscripts )
671
+
672
+ for nest_level in range (indent_levels ):
673
+ loop_variable = chr (ord ("i" ) + nest_level )
674
+ loop_bound = left_variable .subscripts [nest_level ]
675
+ self ._code += f"for ({ loop_variable } in 1:{ loop_bound } ){{\n "
676
+ self ._code .indent_level += 1
677
+ loop_variable_mapping [loop_bound ] = loop_variable
678
+
679
+ loglik_walker = Data2DrawsGeneratedQuantitiesLogLikV2SWalker (loop_variable_mapping ,
680
+ stan_model_context .array_dims_subscript_map )
681
+ loglik_walker .walk (statement )
682
+ self ._code += loglik_walker .code
683
+
684
+ if left_variable .subscripts :
685
+ for nest_level in range (len (left_variable .subscripts )):
686
+ self ._code .indent_level -= 1
687
+ self ._code += "}\n "
688
+
689
+
690
+
566
691
class Data2DrawsDataBlockCodegen (StanBlockCodegen ):
567
692
def generate (self , v2s_code_handler : Vensim2StanCodeHandler , vensim_model_context : VensimModelContext ,
568
693
stan_model_context : StanModelContext ) -> None :
@@ -614,10 +739,20 @@ def walk_FunctionCall(self, node: ast.FunctionCall):
614
739
self ._code += ")"
615
740
616
741
742
+ class Draws2DataGeneratedQuantitiesLogLikV2SWalker (Data2DrawsGeneratedQuantitiesLogLikV2SWalker ):
743
+ """
744
+ This class is the same as `Data2DrawsGeneratedQuantitiesLogLikV2SWalker`, and hence just inherits and does nothing
745
+ else.
746
+ """
747
+
748
+
617
749
class Draws2DataGeneratedQuantitiesBlockCodegen (StanBlockCodegen ):
618
750
def generate (self , v2s_code_handler : Vensim2StanCodeHandler , vensim_model_context : VensimModelContext ,
619
751
stan_model_context : StanModelContext ) -> None :
620
752
753
+ # Insert the loglik variable
754
+ self ._code += "real loglik = 0.0;\n "
755
+
621
756
# Draw the parameters. Sort the sampling statements in topological order
622
757
# Variables defined in transformed data, functions block, and stocks are never a parameter
623
758
ignored_variables = stan_model_context .transformed_data_variables .union (stan_model_context .timestep_variant_datafunc_variables ,
@@ -753,6 +888,11 @@ def generate_code_for_statements(self, statement: ast.Statement, v2s_code_handle
753
888
walker .walk (statement )
754
889
self ._code += walker .code
755
890
891
+ loglik_walker = Draws2DataGeneratedQuantitiesLogLikV2SWalker (loop_variable_mapping ,
892
+ stan_model_context .array_dims_subscript_map )
893
+ loglik_walker .walk (statement )
894
+ self ._code += loglik_walker .code
895
+
756
896
if left_variable .subscripts :
757
897
for nest_level in range (len (left_variable .subscripts )):
758
898
self ._code .indent_level -= 1
@@ -837,6 +977,10 @@ def generate_and_write(self, full_file_path: Path, functions_file_name: str) ->
837
977
model_gen .generate (self .v2s_code_handler , self .vensim_model_context , self .stan_model_context )
838
978
f .write (model_gen .code )
839
979
980
+ gq_gen = Data2DrawsGeneratedQuantitiesBlockCodegen ("generated quantities" )
981
+ gq_gen .generate (self .v2s_code_handler , self .vensim_model_context , self .stan_model_context )
982
+ f .write (gq_gen .code )
983
+
840
984
841
985
class Draws2DataCodegen (StanFileCodegen ):
842
986
def generate_and_write (self , full_file_path : Path , functions_file_name : str ) -> None :
0 commit comments