Skip to content

Commit 7ad86eb

Browse files
committed
Implement loglik and ecdf diff plots
1 parent dbfff9b commit 7ad86eb

File tree

3 files changed

+169
-10
lines changed

3 files changed

+169
-10
lines changed

stanify/builders/stan_block_codegen.py

Lines changed: 145 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,14 @@ def code(self):
8181
class InitialValueStatementsCodegen:
8282
"""
8383
Generates code for statements needed for calculating initial values
84+
85+
Attributes
86+
----------
87+
requires_params : bool
88+
Boolean indicating whether any of the initial value variables requires a stan parameter to calculate. This means
89+
that the initial value can't be calculated at `transformed data`, but instead `transformed parameters`
90+
required_params : set[str]
91+
Set holding the stan parameters which are required for calculating initial values, if any.
8492
"""
8593
_code: IndentedString = field(init=False, default_factory=IndentedString)
8694
requires_params: bool = field(init=False, default=True)
@@ -93,7 +101,6 @@ def code(self) -> str:
93101
def generate(self, v2s_code_handler: Vensim2StanCodeHandler, vensim_model_context: VensimModelContext,
94102
stan_model_context: StanModelContext) -> None:
95103
# Calculate the initial values for the stock variables.
96-
self._code += "\n"
97104
self._code += "// initial stock values\n"
98105

99106
# Create the initial state vector. This will be passed to the ODE solver
@@ -393,6 +400,72 @@ def code(self) -> str:
393400
return str(self._code)
394401

395402

403+
@dataclass
404+
class Data2DrawsGeneratedQuantitiesLogLikV2SWalker(ModelBlockStatementV2SWalker):
405+
"""
406+
Generates code to generate `_loglik += "dist_lpdf(lhs_variate | arg1, arg2);`.
407+
Since the LHS variable doesn't go to the LHS of the generated code, we first generate code for the LHS normally,
408+
but after that cut the code into a temporary variable(`lhs_code`).
409+
After that, we prepend the contents of the temporary variable as the first special argument of the distribution
410+
function.
411+
412+
Attributes
413+
----------
414+
lhs_code : str
415+
The variable that holds the generated code for the LHS variable. The contents of this variable is then used
416+
during codegen for the lpdf/lpmf distribution function, which adds the value as the first argument.
417+
"""
418+
lhs_code: str = field(init=False, default="")
419+
420+
def walk_Statement(self, node: ast.Statement):
421+
if node.op == "=":
422+
return
423+
self.walk(node.left)
424+
self._code += "loglik += "
425+
self.walk(node.right)
426+
self._code += ";\n"
427+
428+
def walk_FunctionCall(self, node: ast.FunctionCall):
429+
function_name_map = {
430+
"bernoulli": "bernoulli_lpmf",
431+
"binomial": "binomial_lpmf",
432+
"neg_binomial": "neg_binomial_lpmf",
433+
"poisson": "poisson_lpmf",
434+
"normal": "normal_lpdf",
435+
"cauchy": "caucly_lpdf",
436+
"lognormal": "lognormal_lpdf",
437+
"exponential": "exponential_lpdf",
438+
"gamma": "gamma_lpdf",
439+
"weibull": "weibull_lpdf",
440+
"beta": "beta_lpdf"
441+
}
442+
if node.name in function_name_map:
443+
self._code += f"{function_name_map[node.name]}("
444+
445+
# Add the LHS variable to the log probability function
446+
self._code += f"{self.lhs_code} | "
447+
else:
448+
self._code += node.name
449+
450+
for index, arg in enumerate(node.arglist):
451+
self.walk(arg)
452+
if index < len(node.arglist) - 1:
453+
self._code += ", "
454+
self._code += ")"
455+
456+
def walk_Variable(self, node: ast.Variable):
457+
# Check if we need to retrieve the code of the LHS variable
458+
flush_lhs = False
459+
if not self.lhs_code:
460+
flush_lhs = True
461+
462+
super().walk_Variable(node)
463+
464+
if flush_lhs:
465+
self.lhs_code = str(self._code)
466+
self._code.clear()
467+
468+
396469
class ModelBlockCodegen(StanBlockCodegen):
397470
"""
398471
Generates code for the `model` Stan block. This includes prior and likelihood specifications.
@@ -563,6 +636,58 @@ def build_forloops(node: Node, current_subscripts: dict[str, str], nest_level=0)
563636
self._code += walker.code
564637

565638

639+
class Data2DrawsGeneratedQuantitiesBlockCodegen(StanBlockCodegen):
640+
"""
641+
For Data2Draws, the generated quantities block holds only the log likelihood calculations.
642+
"""
643+
def generate(self, v2s_code_handler: Vensim2StanCodeHandler, vensim_model_context: VensimModelContext,
644+
stan_model_context: StanModelContext) -> None:
645+
646+
# Insert the loglik variable
647+
self._code += "real loglik = 0.0;\n"
648+
649+
for statement in v2s_code_handler.program_ast.statements:
650+
if statement.op == "=":
651+
continue
652+
653+
left_variable = statement.left
654+
655+
# If the LHS is a data variable, we ignore the sample statement
656+
if left_variable.not_param:
657+
continue
658+
659+
self.generate_code_for_statements(statement, v2s_code_handler, vensim_model_context, stan_model_context)
660+
661+
def generate_code_for_statements(self, statement: ast.Statement, v2s_code_handler: Vensim2StanCodeHandler,
662+
vensim_model_context: VensimModelContext, stan_model_context: StanModelContext):
663+
if statement.op == "=":
664+
return
665+
666+
left_variable = statement.left
667+
668+
loop_variable_mapping = {} # key is subscript name, value is loop variable
669+
if left_variable.subscripts:
670+
indent_levels = len(left_variable.subscripts)
671+
672+
for nest_level in range(indent_levels):
673+
loop_variable = chr(ord("i") + nest_level)
674+
loop_bound = left_variable.subscripts[nest_level]
675+
self._code += f"for ({loop_variable} in 1:{loop_bound}){{\n"
676+
self._code.indent_level += 1
677+
loop_variable_mapping[loop_bound] = loop_variable
678+
679+
loglik_walker = Data2DrawsGeneratedQuantitiesLogLikV2SWalker(loop_variable_mapping,
680+
stan_model_context.array_dims_subscript_map)
681+
loglik_walker.walk(statement)
682+
self._code += loglik_walker.code
683+
684+
if left_variable.subscripts:
685+
for nest_level in range(len(left_variable.subscripts)):
686+
self._code.indent_level -= 1
687+
self._code += "}\n"
688+
689+
690+
566691
class Data2DrawsDataBlockCodegen(StanBlockCodegen):
567692
def generate(self, v2s_code_handler: Vensim2StanCodeHandler, vensim_model_context: VensimModelContext,
568693
stan_model_context: StanModelContext) -> None:
@@ -614,10 +739,20 @@ def walk_FunctionCall(self, node: ast.FunctionCall):
614739
self._code += ")"
615740

616741

742+
class Draws2DataGeneratedQuantitiesLogLikV2SWalker(Data2DrawsGeneratedQuantitiesLogLikV2SWalker):
743+
"""
744+
This class is the same as `Data2DrawsGeneratedQuantitiesLogLikV2SWalker`, and hence just inherits and does nothing
745+
else.
746+
"""
747+
748+
617749
class Draws2DataGeneratedQuantitiesBlockCodegen(StanBlockCodegen):
618750
def generate(self, v2s_code_handler: Vensim2StanCodeHandler, vensim_model_context: VensimModelContext,
619751
stan_model_context: StanModelContext) -> None:
620752

753+
# Insert the loglik variable
754+
self._code += "real loglik = 0.0;\n"
755+
621756
# Draw the parameters. Sort the sampling statements in topological order
622757
# Variables defined in transformed data, functions block, and stocks are never a parameter
623758
ignored_variables = stan_model_context.transformed_data_variables.union(stan_model_context.timestep_variant_datafunc_variables,
@@ -753,6 +888,11 @@ def generate_code_for_statements(self, statement: ast.Statement, v2s_code_handle
753888
walker.walk(statement)
754889
self._code += walker.code
755890

891+
loglik_walker = Draws2DataGeneratedQuantitiesLogLikV2SWalker(loop_variable_mapping,
892+
stan_model_context.array_dims_subscript_map)
893+
loglik_walker.walk(statement)
894+
self._code += loglik_walker.code
895+
756896
if left_variable.subscripts:
757897
for nest_level in range(len(left_variable.subscripts)):
758898
self._code.indent_level -= 1
@@ -837,6 +977,10 @@ def generate_and_write(self, full_file_path: Path, functions_file_name: str) ->
837977
model_gen.generate(self.v2s_code_handler, self.vensim_model_context, self.stan_model_context)
838978
f.write(model_gen.code)
839979

980+
gq_gen = Data2DrawsGeneratedQuantitiesBlockCodegen("generated quantities")
981+
gq_gen.generate(self.v2s_code_handler, self.vensim_model_context, self.stan_model_context)
982+
f.write(gq_gen.code)
983+
840984

841985
class Draws2DataCodegen(StanFileCodegen):
842986
def generate_and_write(self, full_file_path: Path, functions_file_name: str) -> None:

stanify/builders/utilities.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,12 @@ def add_raw(self, string: str, ignore_indent: bool = False) -> None:
5757
else:
5858
self.__iadd__(string)
5959

60+
def clear(self) -> None:
61+
"""
62+
Flush the currently saved string
63+
"""
64+
self.string = ""
65+
6066
def __str__(self) -> str:
6167
return self.string
6268

stanify/calibrator/plots.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,8 @@ def _calculate_ranks(theta_xr, post_theta_xr):
4545
def _calculate_fractional_ranks(theta_xr, post_theta_xr):
4646
return (1 + np.sum(theta_xr > post_theta_xr)) / (1 + n_post_draws)
4747

48-
prior_input_core_dims = []# if not kwargs else list(kwargs.keys())
49-
post_input_core_dims = ["posterior_draw"]# if not kwargs else ["posterior_draw"] + list(kwargs.keys())
48+
prior_input_core_dims = []
49+
post_input_core_dims = ["posterior_draw"]
5050

5151
if fractional:
5252
ranks = xr.apply_ufunc(_calculate_fractional_ranks, prior_draws, post_draws,
@@ -87,7 +87,7 @@ def plot_rank_hist(sbc_idata: InferenceData, variable_name: str, bins=20, fracti
8787
plt.show()
8888

8989

90-
def plot_ecdf(sbc_idata: InferenceData, variable_name: str, gamma=0.8, **kwargs) -> None:
90+
def plot_ecdf(sbc_idata: InferenceData, variable_name: str, alpha: float = 0.01, diff: bool = False, **kwargs) -> None:
9191
"""
9292
Plot the calculated ECDF of the SBC ranks against expected ECDF envelope.
9393
@@ -98,8 +98,10 @@ def plot_ecdf(sbc_idata: InferenceData, variable_name: str, gamma=0.8, **kwargs)
9898
`stanify.builders.vensim2stan.Vensim2Stan.run_sbc`.
9999
variable_name : str
100100
The variable name to be plotted
101-
gamma : float
102-
The gamma parameter for calculating the expected ECDF envelope
101+
alpha : float
102+
The alpha parameter for calculating the expected ECDF envelope, indicating the confidence level.
103+
diff : bool
104+
Whether to plot the ECDF difference plot. Defaults to `False`.
103105
kwargs : Any
104106
Any additional arguments to be passed to the `InferenceData.isel` method. This is for when the variable is
105107
subscripted and has named dimensions. For example, if a parameter `sigma` has an additional dimension named
@@ -112,12 +114,19 @@ def plot_ecdf(sbc_idata: InferenceData, variable_name: str, gamma=0.8, **kwargs)
112114
def rank_ecdf(x):
113115
return np.sum(fractional_ranks < x) / n_prior_draw
114116

115-
ecdf_xaxis = np.linspace(0, 0.99, 100)
117+
ecdf_xaxis = np.linspace(0, 1 - 1e-7, n_prior_draw)
116118

117-
ecdf_lower = binom.ppf(gamma / 2, n_prior_draw, ecdf_xaxis) / n_prior_draw
118-
ecdf_upper = binom.ppf(1 - gamma / 2, n_prior_draw, ecdf_xaxis) / n_prior_draw
119+
ecdf_lower = binom.ppf(alpha / 2, n_prior_draw, ecdf_xaxis) / n_prior_draw
120+
ecdf_upper = binom.ppf(1 - alpha / 2, n_prior_draw, ecdf_xaxis) / n_prior_draw
119121

120-
plt.plot(ecdf_xaxis, np.vectorize(rank_ecdf)(ecdf_xaxis), "-", ms=2, color="black")
122+
ecdf_values = np.vectorize(rank_ecdf)(ecdf_xaxis)
123+
124+
if diff:
125+
ecdf_lower -= ecdf_xaxis
126+
ecdf_upper -= ecdf_xaxis
127+
ecdf_values -= ecdf_xaxis
128+
129+
plt.plot(ecdf_xaxis, ecdf_values, "-", ms=2, color="black")
121130
plt.plot(ecdf_xaxis, ecdf_lower, "-", color="green")
122131
plt.plot(ecdf_xaxis, ecdf_upper, "-", color="green")
123132
plt.title(f"ECDF for parameter {variable_name}")

0 commit comments

Comments
 (0)