Skip to content

Commit 7f87e4b

Browse files
authored
Add support for constructing target distribution from formula (#66)
* Add function for constructing target distribution from formula * Wrap trace function into target distribution arg * Ensure gradient a vector and add to NAMESPACE * Combine BridgeStan interface functions for constructing target distribution and trace function * Allow passing formula or Stan model directly to sample_chain * Use dummy variable declaration to avoid check note * Qualify deriv call with stats package name * Test target distribution from formula function * Use base::inherits in place of methods::is * Remove removed trace_function argument from sample_chain docs Clarify trace_function allowable output types * Test using invalid target distribution with sample_chain raises error * Test passing explicit trace function to sample_chain * Test using sample_chain with Stan model and log density formula works
1 parent 95acd2d commit 7f87e4b

13 files changed

+423
-249
lines changed

NAMESPACE

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,6 @@ export(sample_chain)
1313
export(scale_adapter)
1414
export(shape_adapter)
1515
export(stochastic_approximation_scale_adapter)
16+
export(target_distribution_from_log_density_formula)
1617
export(target_distribution_from_stan_model)
17-
export(trace_function_from_stan_model)
1818
export(variance_shape_adapter)

R/bridges.R

Lines changed: 77 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,23 @@
11
#' Construct target distribution from a BridgeStan `StanModel` object.
22
#'
33
#' @param model Stan model object to use for target (posterior) distribution.
4+
#' @param include_log_density Whether to include an entry `log_density`
5+
#' corresponding to current log density for target distribution in values
6+
#' returned by trace function.
7+
#' @param include_generated_quantities Whether to included generated quantities
8+
#' in Stan model definition in values returned by trace function.
9+
#' @param include_transformed_parameters Whether to include transformed
10+
#' parameters in Stan model definition in values returned by trace function.
411
#'
512
#' @return A list with entries
613
#' * `log_density`: A function to evaluate log density function for target
714
#' distribution given current position vector.
815
#' * `value_and_gradient_log_density`: A function to evaluate value and gradient
916
#' of log density function for target distribution given current position
1017
#' vector, returning as a list with entries `value` and `gradient`.
18+
#' * `trace_function`: A function which given a `chain_state()` object returns a
19+
#' named vector of values to trace during sampling. The constrained parameter
20+
#' values of model will always be included.
1121
#'
1222
#' @export
1323
#'
@@ -18,46 +28,13 @@
1828
#' 876287L, state <- chain_state(stats::rnorm(model$param_unc_num()))
1929
#' )
2030
#' state$log_density(target_distribution)
21-
target_distribution_from_stan_model <- function(model) {
22-
list(
23-
log_density = model$log_density,
24-
value_and_gradient_log_density = function(position) {
25-
value_and_gradient <- model$log_density_gradient(position)
26-
names(value_and_gradient) <- c("value", "gradient")
27-
value_and_gradient
28-
}
29-
)
30-
}
31-
32-
#' Construct trace function from a BridgeStan `StanModel` object.
33-
#'
34-
#' @param model Stan model object to use to generate (constrained) parameters to
35-
#' trace.
36-
#' @param include_log_density Whether to include an entry `log_density`
37-
#' corresponding to current log density for target distribution in values
38-
#' returned by trace function.
39-
#' @param include_generated_quantities Whether to included generated quantities
40-
#' in Stan model definition in values returned by trace function.
41-
#' @param include_transformed_parameters Whether to include transformed
42-
#' parameters in Stan model definition in values returned by trace function.
43-
#'
44-
#' @return A function which given `chain_state` object returns a named vector of
45-
#' values to trace during sampling. The constrained parameter values of model
46-
#' will always be included.
47-
#'
48-
#' @export
49-
#'
50-
#' @examplesIf requireNamespace("bridgestan", quietly = TRUE)
51-
#' model <- example_gaussian_stan_model()
52-
#' trace_function <- trace_function_from_stan_model(model)
53-
#' withr::with_seed(876287L, state <- chain_state(rnorm(model$param_unc_num())))
54-
#' trace_function(state)
55-
trace_function_from_stan_model <- function(
31+
#' target_distribution$trace_function(state)
32+
target_distribution_from_stan_model <- function(
5633
model,
5734
include_log_density = TRUE,
5835
include_generated_quantities = FALSE,
5936
include_transformed_parameters = FALSE) {
60-
function(state) {
37+
trace_function <- function(state) {
6138
position <- state$position()
6239
trace_values <- model$param_constrain(
6340
position, include_transformed_parameters, include_generated_quantities
@@ -68,15 +45,25 @@ trace_function_from_stan_model <- function(
6845
}
6946
trace_values
7047
}
48+
list(
49+
log_density = model$log_density,
50+
value_and_gradient_log_density = function(position) {
51+
value_and_gradient <- model$log_density_gradient(position)
52+
names(value_and_gradient) <- c("value", "gradient")
53+
value_and_gradient
54+
},
55+
trace_function = trace_function
56+
)
7157
}
7258

7359
#' Construct an example BridgeStan `StanModel` object for a Gaussian model.
7460
#'
7561
#' Requires BridgeStan package to be installed. Generative model is assumed to
76-
#' be of the form `y ~ normal(mu, sigma)` for unknown `mu` and `sigma`.
62+
#' be of the form `y ~ normal(mu, sigma)` for unknown `mu ~ normal(0, 3)` and
63+
#' `sigma ~ half_normal(0, 3)`.
7764
#'
7865
#' @param n_data Number of independent data points `y` to generate and condition
79-
#' model against.
66+
#' model against from `normal(0, 1)`.
8067
#' @param seed Integer seed for Stan model.
8168
#'
8269
#' @return BridgeStan StanModel object.
@@ -88,20 +75,23 @@ trace_function_from_stan_model <- function(
8875
#' model$param_names()
8976
example_gaussian_stan_model <- function(n_data = 50, seed = 1234L) {
9077
rlang::check_installed("bridgestan", reason = "to use this function")
91-
model_string <- "data {
92-
int<lower=0> N;
93-
vector[N] y;
78+
model_string <- "
79+
data {
80+
int<lower=0> N;
81+
vector[N] y;
9482
}
9583
parameters {
9684
real mu;
9785
real<lower=0> sigma;
9886
}
9987
model {
88+
mu ~ normal(0, 3);
89+
sigma ~ normal(0, 3);
10090
y ~ normal(mu, sigma);
10191
}"
10292
withr::with_seed(seed, y <- stats::rnorm(n_data))
10393
data_string <- sprintf('{"N": %i, "y": [%s]}', n_data, toString(y))
104-
model_file <- tempfile("gaussian", fileext = ".stan")
94+
model_file <- NULL # to avoid 'no visible binding for global variable' note
10595
withr::with_tempfile("model_file",
10696
{
10797
writeLines(model_string, model_file)
@@ -111,3 +101,48 @@ example_gaussian_stan_model <- function(n_data = 50, seed = 1234L) {
111101
fileext = ".stan"
112102
)
113103
}
104+
105+
#' Construct target distribution from a formula specifying log density.
106+
#'
107+
#' @param log_density_formula Formula for which right-hand side specifies
108+
#' expression for logarithm of (unnormalized) density of target distribution.
109+
#'
110+
#' @return A list with entries
111+
#' * `log_density`: A function to evaluate log density function for target
112+
#' distribution given current position vector.
113+
#' * `value_and_gradient_log_density`: A function to evaluate value and gradient
114+
#' of log density function for target distribution given current position
115+
#' vector, returning as a list with entries `value` and `gradient`.
116+
#'
117+
#' @export
118+
#'
119+
#' @examples
120+
#' target_distribution <- target_distribution_from_log_density_formula(
121+
#' ~ (-(x^2 + y^2) / 8 - (x^2 - y)^2 - (x - 1)^2 / 10)
122+
#' )
123+
#' target_distribution$value_and_gradient_log_density(c(0.1, -0.3))
124+
target_distribution_from_log_density_formula <- function(log_density_formula) {
125+
variables <- all.vars(log_density_formula)
126+
deriv_log_density <- stats::deriv(log_density_formula, variables, func = TRUE)
127+
value_and_gradient_log_density <- function(position) {
128+
names(position) <- variables
129+
value <- rlang::inject(deriv_log_density(!!!position))
130+
gradient <- drop(attr(value, "gradient"))
131+
attr(value, "gradient") <- NULL
132+
list(value = value, gradient = gradient)
133+
}
134+
log_density <- function(position) {
135+
value_and_gradient_log_density(position)$value
136+
}
137+
trace_function <- function(state) {
138+
trace_values <- state$position()
139+
names(trace_values) <- variables
140+
trace_values["log_density"] <- log_density(state$position())
141+
trace_values
142+
}
143+
list(
144+
log_density = log_density,
145+
value_and_gradient_log_density = value_and_gradient_log_density,
146+
trace_function = trace_function
147+
)
148+
}

R/chains.R

Lines changed: 42 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,31 @@
44
#' target distribution and proposal (defaulting to Barker proposal), optionally
55
#' adapting proposal parameters in a warm-up stage.
66
#'
7-
#' @inheritParams sample_metropolis_hastings
7+
#' @param target_distribution Target stationary distribution for chain. One of:
8+
#' * A one-sided formula specifying expression for log density of target
9+
#' distribution which will be passed to
10+
#' [target_distribution_from_log_density_formula()] to construct functions
11+
#' to evaluate log density and its gradient using [deriv()].
12+
#' * A `bridgestan::StanModel` instance (requires `bridgestan` to be
13+
#' installed) specifying target model and data. Will be passed to
14+
#' [target_distribution_from_stan_model()] using default values for optional
15+
#' arguments - to override call [target_distribution_from_stan_model()]
16+
#' directly and pass the returned list as the `target_distribution` argument
17+
#' here.
18+
#' * A list with named entries `log_density` and `gradient_log_density`
19+
#' corresponding to respectively functions for evaluating the logarithm of
20+
#' the (potentially unnormalized) density of the target distribution and its
21+
#' gradient (only required for gradient-based proposals). As an alternative
22+
#' to `gradient_log_density` an entry `value_and_gradient_log_density` may
23+
#' instead be provided which is a function returning both the value and
24+
#' gradient of the logarithm of the (unnormalized) density of the target
25+
#' distribution as a list under the names `value` and `gradient`
26+
#' respectively. The list may also contain a named entry `trace_function`,
27+
#' correspond to a function which given current chain state outputs a named
28+
#' vector or list of variables to trace on each main (non-adaptive) chain
29+
#' iteration. If a `trace_function` entry is not specified, then the default
30+
#' behaviour is to trace the position component of the chain state along
31+
#' with the log density of the target distribution.
832
#' @param initial_state Initial chain state. Either a vector specifying just
933
#' the position component of the chain state or a list output by `chain_state`
1034
#' specifying the full chain state.
@@ -32,8 +56,6 @@
3256
#' coerce the average acceptance rate to a target value using a dual-averaging
3357
#' algorithm, and adapting the shape to an estimate of the covariance of the
3458
#' target distribution.
35-
#' @param trace_function Function which given current chain state outputs list
36-
#' of variables to trace on each main (non-adaptive) chain iteration.
3759
#' @param show_progress_bar Whether to show progress bars during sampling.
3860
#' Requires `progress` package to be installed to have an effect.
3961
#' @param trace_warm_up Whether to record chain traces and adaptation /
@@ -78,7 +100,6 @@ sample_chain <- function(
78100
n_main_iteration,
79101
proposal = barker_proposal(),
80102
adapters = list(scale_adapter(), shape_adapter()),
81-
trace_function = NULL,
82103
show_progress_bar = TRUE,
83104
trace_warm_up = FALSE) {
84105
progress_available <- requireNamespace("progress", quietly = TRUE)
@@ -90,8 +111,24 @@ sample_chain <- function(
90111
} else {
91112
stop("initial_state must be a vector or list with an entry named position.")
92113
}
93-
if (is.null(trace_function)) {
114+
if (inherits(target_distribution, "formula")) {
115+
target_distribution <- target_distribution_from_log_density_formula(
116+
target_distribution
117+
)
118+
} else if (inherits(target_distribution, "StanModel")) {
119+
target_distribution <- target_distribution_from_stan_model(
120+
target_distribution
121+
)
122+
} else if (
123+
!is.list(target_distribution) ||
124+
!("log_density" %in% names(target_distribution))
125+
) {
126+
stop("target_distribution invalid - see documentation for allowable types.")
127+
}
128+
if (is.null(target_distribution$trace_function)) {
94129
trace_function <- default_trace_function(target_distribution)
130+
} else {
131+
trace_function <- target_distribution$trace_function
95132
}
96133
statistic_names <- list("accept_prob")
97134
warm_up_results <- chain_loop(

README.Rmd

Lines changed: 6 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -72,36 +72,22 @@ cat(
7272
```
7373

7474
As a second example, the snippet below demonstrates sampling from a two-dimensional banana shaped distribution based on the [Rosenbrock function](https://en.wikipedia.org/wiki/Rosenbrock_function) and plotting the generated chain samples.
75-
Here we use the default values of the `proposal` and `adapters` arguments to `sample_chain`,
75+
Here we use the default values of the `proposal` and `adapters` arguments to `sample_chain()`,
7676
corresponding respectively to the Barker proposal, and adapters for tuning the proposal scale to coerce the average acceptance rate using a dual-averaging algorithm,
7777
and for tuning the proposal shape based on an estimate of the target distribution covariance matrix.
78+
The `target_distribution` argument to `sample_chain()` is passed a formula specifying the log density of the target distribution, which is passed to `target_distribution_from_log_density_formula()` to construct necessary functions,
79+
using `stats::deriv()` to symbolically compute derivatives.
7880

7981

8082
```{r banana-samples, fig.width=6, fig.height=4}
8183
library(rmcmc)
8284
8385
set.seed(651239L)
84-
target_distribution <- list(
85-
log_density = function(x) -sum(x^2) / 8 - (x[1]^2 - x[2])^2 - (x[1] - 1)^2 / 10,
86-
gradient_log_density = function(x) {
87-
c(
88-
-x[1] / 4 + 4 * x[1] * (x[2] - x[1]^2) - 0.2 * x[1] + 0.2,
89-
-x[2] / 4 + 2 * x[1]^2 - 2 * x[2]
90-
)
91-
}
92-
)
9386
results <- sample_chain(
94-
target_distribution = target_distribution,
87+
target_distribution = ~ (-(x^2 + y^2) / 8 - (x^2 - y)^2 - (x - 1)^2 / 100),
9588
initial_state = rnorm(2),
9689
n_warm_up_iteration = 10000,
97-
n_main_iteration = 10000,
98-
)
99-
plot(
100-
results$traces[, "position1"],
101-
results$traces[, "position2"],
102-
xlab = expression(x[1]),
103-
ylab = expression(x[2]),
104-
col = "#1f77b4",
105-
pch = 20
90+
n_main_iteration = 10000
10691
)
92+
plot(results$traces[, "x"], results$traces[, "y"], col = "#1f77b4", pch = 20)
10793
```

README.md

Lines changed: 12 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -76,39 +76,28 @@ As a second example, the snippet below demonstrates sampling from a
7676
two-dimensional banana shaped distribution based on the [Rosenbrock
7777
function](https://en.wikipedia.org/wiki/Rosenbrock_function) and
7878
plotting the generated chain samples. Here we use the default values of
79-
the `proposal` and `adapters` arguments to `sample_chain`, corresponding
80-
respectively to the Barker proposal, and adapters for tuning the
81-
proposal scale to coerce the average acceptance rate using a
79+
the `proposal` and `adapters` arguments to `sample_chain()`,
80+
corresponding respectively to the Barker proposal, and adapters for
81+
tuning the proposal scale to coerce the average acceptance rate using a
8282
dual-averaging algorithm, and for tuning the proposal shape based on an
83-
estimate of the target distribution covariance matrix.
83+
estimate of the target distribution covariance matrix. The
84+
`target_distribution` argument to `sample_chain()` is passed a formula
85+
specifying the log density of the target distribution, which is passed
86+
to `target_distribution_from_log_density_formula()` to construct
87+
necessary functions, using `stats::deriv()` to symbolically compute
88+
derivatives.
8489

8590
``` r
8691
library(rmcmc)
8792

8893
set.seed(651239L)
89-
target_distribution <- list(
90-
log_density = function(x) -sum(x^2) / 8 - (x[1]^2 - x[2])^2 - (x[1] - 1)^2 / 10,
91-
gradient_log_density = function(x) {
92-
c(
93-
-x[1] / 4 + 4 * x[1] * (x[2] - x[1]^2) - 0.2 * x[1] + 0.2,
94-
-x[2] / 4 + 2 * x[1]^2 - 2 * x[2]
95-
)
96-
}
97-
)
9894
results <- sample_chain(
99-
target_distribution = target_distribution,
95+
target_distribution = ~ (-(x^2 + y^2) / 8 - (x^2 - y)^2 - (x - 1)^2 / 100),
10096
initial_state = rnorm(2),
10197
n_warm_up_iteration = 10000,
102-
n_main_iteration = 10000,
103-
)
104-
plot(
105-
results$traces[, "position1"],
106-
results$traces[, "position2"],
107-
xlab = expression(x[1]),
108-
ylab = expression(x[2]),
109-
col = "#1f77b4",
110-
pch = 20
98+
n_main_iteration = 10000
11199
)
100+
plot(results$traces[, "x"], results$traces[, "y"], col = "#1f77b4", pch = 20)
112101
```
113102

114103
<img src="man/figures/README-banana-samples-1.png" width="100%" />

man/example_gaussian_stan_model.Rd

Lines changed: 3 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
4.44 KB
Loading

0 commit comments

Comments
 (0)