Add support for constructing target distribution from formula #66

Merged: 13 commits, Jan 6, 2025
2 changes: 1 addition & 1 deletion NAMESPACE
@@ -13,6 +13,6 @@ export(sample_chain)
export(scale_adapter)
export(shape_adapter)
export(stochastic_approximation_scale_adapter)
export(target_distribution_from_log_density_formula)
export(target_distribution_from_stan_model)
export(trace_function_from_stan_model)
export(variance_shape_adapter)
119 changes: 77 additions & 42 deletions R/bridges.R
@@ -1,13 +1,23 @@
#' Construct target distribution from a BridgeStan `StanModel` object.
#'
#' @param model Stan model object to use for target (posterior) distribution.
#' @param include_log_density Whether to include an entry `log_density`
#' corresponding to current log density for target distribution in values
#' returned by trace function.
#' @param include_generated_quantities Whether to include generated quantities
#' in Stan model definition in values returned by trace function.
#' @param include_transformed_parameters Whether to include transformed
#' parameters in Stan model definition in values returned by trace function.
#'
#' @return A list with entries
#' * `log_density`: A function to evaluate log density function for target
#' distribution given current position vector.
#' * `value_and_gradient_log_density`: A function to evaluate value and gradient
#' of log density function for target distribution given current position
#' vector, returning as a list with entries `value` and `gradient`.
#' * `trace_function`: A function which given a `chain_state()` object returns a
#' named vector of values to trace during sampling. The constrained parameter
#' values of model will always be included.
#'
#' @export
#'
@@ -18,46 +28,13 @@
#' 876287L, state <- chain_state(stats::rnorm(model$param_unc_num()))
#' )
#' state$log_density(target_distribution)
target_distribution_from_stan_model <- function(model) {
list(
log_density = model$log_density,
value_and_gradient_log_density = function(position) {
value_and_gradient <- model$log_density_gradient(position)
names(value_and_gradient) <- c("value", "gradient")
value_and_gradient
}
)
}

#' Construct trace function from a BridgeStan `StanModel` object.
#'
#' @param model Stan model object to use to generate (constrained) parameters to
#' trace.
#' @param include_log_density Whether to include an entry `log_density`
#' corresponding to current log density for target distribution in values
#' returned by trace function.
#' @param include_generated_quantities Whether to include generated quantities
#' in Stan model definition in values returned by trace function.
#' @param include_transformed_parameters Whether to include transformed
#' parameters in Stan model definition in values returned by trace function.
#'
#' @return A function which given `chain_state` object returns a named vector of
#' values to trace during sampling. The constrained parameter values of model
#' will always be included.
#'
#' @export
#'
#' @examplesIf requireNamespace("bridgestan", quietly = TRUE)
#' model <- example_gaussian_stan_model()
#' trace_function <- trace_function_from_stan_model(model)
#' withr::with_seed(876287L, state <- chain_state(rnorm(model$param_unc_num())))
#' trace_function(state)
trace_function_from_stan_model <- function(
#' target_distribution$trace_function(state)
target_distribution_from_stan_model <- function(
model,
include_log_density = TRUE,
include_generated_quantities = FALSE,
include_transformed_parameters = FALSE) {
function(state) {
trace_function <- function(state) {
position <- state$position()
trace_values <- model$param_constrain(
position, include_transformed_parameters, include_generated_quantities
@@ -68,15 +45,25 @@ trace_function_from_stan_model <- function(
}
trace_values
}
list(
log_density = model$log_density,
value_and_gradient_log_density = function(position) {
value_and_gradient <- model$log_density_gradient(position)
names(value_and_gradient) <- c("value", "gradient")
value_and_gradient
},
trace_function = trace_function
)
}

#' Construct an example BridgeStan `StanModel` object for a Gaussian model.
#'
#' Requires BridgeStan package to be installed. Generative model is assumed to
#' be of the form `y ~ normal(mu, sigma)` for unknown `mu` and `sigma`.
#' be of the form `y ~ normal(mu, sigma)` for unknown `mu ~ normal(0, 3)` and
#' `sigma ~ half_normal(0, 3)`.
#'
#' @param n_data Number of independent data points `y` to generate and condition
#' model against.
#' model against from `normal(0, 1)`.
#' @param seed Integer seed for Stan model.
#'
#' @return BridgeStan StanModel object.
@@ -88,20 +75,23 @@ trace_function_from_stan_model <- function(
#' model$param_names()
example_gaussian_stan_model <- function(n_data = 50, seed = 1234L) {
rlang::check_installed("bridgestan", reason = "to use this function")
model_string <- "data {
int<lower=0> N;
vector[N] y;
model_string <- "
data {
int<lower=0> N;
vector[N] y;
}
parameters {
real mu;
real<lower=0> sigma;
}
model {
mu ~ normal(0, 3);
sigma ~ normal(0, 3);
y ~ normal(mu, sigma);
}"
withr::with_seed(seed, y <- stats::rnorm(n_data))
data_string <- sprintf('{"N": %i, "y": [%s]}', n_data, toString(y))
model_file <- tempfile("gaussian", fileext = ".stan")
model_file <- NULL # to avoid 'no visible binding for global variable' note
withr::with_tempfile("model_file",
{
writeLines(model_string, model_file)
@@ -111,3 +101,48 @@ example_gaussian_stan_model <- function(n_data = 50, seed = 1234L) {
fileext = ".stan"
)
}

#' Construct target distribution from a formula specifying log density.
#'
#' @param log_density_formula Formula for which right-hand side specifies
#' expression for logarithm of (unnormalized) density of target distribution.
#'
#' @return A list with entries
#' * `log_density`: A function to evaluate log density function for target
#' distribution given current position vector.
#' * `value_and_gradient_log_density`: A function to evaluate value and gradient
#' of log density function for target distribution given current position
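#' vector, returning as a list with entries `value` and `gradient`.
#' * `trace_function`: A function which given a `chain_state()` object returns a
#' named vector of values to trace during sampling, with an entry for each
#' variable in the formula and the log density under the name `log_density`.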
#'
#' @export
#'
#' @examples
#' target_distribution <- target_distribution_from_log_density_formula(
#' ~ (-(x^2 + y^2) / 8 - (x^2 - y)^2 - (x - 1)^2 / 10)
#' )
#' target_distribution$value_and_gradient_log_density(c(0.1, -0.3))
target_distribution_from_log_density_formula <- function(log_density_formula) {
variables <- all.vars(log_density_formula)
deriv_log_density <- stats::deriv(log_density_formula, variables, func = TRUE)
value_and_gradient_log_density <- function(position) {
names(position) <- variables
value <- rlang::inject(deriv_log_density(!!!position))
gradient <- drop(attr(value, "gradient"))
attr(value, "gradient") <- NULL
list(value = value, gradient = gradient)
}
log_density <- function(position) {
value_and_gradient_log_density(position)$value
}
trace_function <- function(state) {
trace_values <- state$position()
names(trace_values) <- variables
trace_values["log_density"] <- log_density(state$position())
trace_values
}
list(
log_density = log_density,
value_and_gradient_log_density = value_and_gradient_log_density,
trace_function = trace_function
)
}
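
Review note: the constructor above relies on `stats::deriv()` with `func = TRUE`, which returns a function whose result carries the gradient in a `"gradient"` attribute. The sketch below reproduces that mechanism in base R and compares it with the hand-derived gradient from the list-based README example that this PR replaces. `do.call()` is used as a stand-in for `rlang::inject()` (an assumption, to keep the sketch free of dependencies), and `value_and_gradient` and `manual_gradient` are hypothetical names.

```r
# Log density expression from the roxygen example, as a one-sided formula.
log_density_formula <- ~ (-(x^2 + y^2) / 8 - (x^2 - y)^2 - (x - 1)^2 / 10)
variables <- all.vars(log_density_formula) # variables in order of appearance
deriv_log_density <- stats::deriv(log_density_formula, variables, func = TRUE)

# Base R stand-in (assumption) for rlang::inject(deriv_log_density(!!!position)).
value_and_gradient <- function(position) {
  value <- do.call(deriv_log_density, as.list(position))
  gradient <- drop(attr(value, "gradient")) # 1 x 2 matrix -> named vector
  attr(value, "gradient") <- NULL
  list(value = value, gradient = gradient)
}

# Hand-derived gradient, as in the README example removed by this PR.
manual_gradient <- function(x) {
  c(
    -x[1] / 4 + 4 * x[1] * (x[2] - x[1]^2) - 0.2 * x[1] + 0.2,
    -x[2] / 4 + 2 * x[1]^2 - 2 * x[2]
  )
}

position <- c(0.1, -0.3)
result <- value_and_gradient(position)
# Compare symbolic and hand-derived gradients up to floating point tolerance.
gradients_agree <- isTRUE(
  all.equal(unname(result$gradient), manual_gradient(position))
)
print(gradients_agree)
```

Because `all.vars()` returns variables in order of first appearance, the position vector is matched to `x` then `y` here; a formula that mentions `y` first would reverse that ordering.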
47 changes: 42 additions & 5 deletions R/chains.R
@@ -4,7 +4,31 @@
#' target distribution and proposal (defaulting to Barker proposal), optionally
#' adapting proposal parameters in a warm-up stage.
#'
#' @inheritParams sample_metropolis_hastings
#' @param target_distribution Target stationary distribution for chain. One of:
#' * A one-sided formula specifying expression for log density of target
#' distribution which will be passed to
#' [target_distribution_from_log_density_formula()] to construct functions
#' to evaluate log density and its gradient using [deriv()].
#' * A `bridgestan::StanModel` instance (requires `bridgestan` to be
#' installed) specifying target model and data. Will be passed to
#' [target_distribution_from_stan_model()] using default values for optional
#' arguments - to override call [target_distribution_from_stan_model()]
#' directly and pass the returned list as the `target_distribution` argument
#' here.
#' * A list with named entries `log_density` and `gradient_log_density`,
#' respectively functions for evaluating the logarithm of the (potentially
#' unnormalized) density of the target distribution and its gradient (the
#' gradient is only required for gradient-based proposals). As an alternative
#' to `gradient_log_density`, an entry `value_and_gradient_log_density` may
#' be provided instead: a function returning both the value and gradient of
#' the logarithm of the (unnormalized) density of the target distribution as
#' a list with entries `value` and `gradient`. The list may also contain a
#' named entry `trace_function`, a function which given the current chain
#' state outputs a named vector or list of variables to trace on each main
#' (non-adaptive) chain iteration. If no `trace_function` entry is specified,
#' the default behaviour is to trace the position component of the chain
#' state along with the log density of the target distribution.
#' @param initial_state Initial chain state. Either a vector specifying just
#' the position component of the chain state or a list output by `chain_state`
#' specifying the full chain state.
@@ -32,8 +56,6 @@
#' coerce the average acceptance rate to a target value using a dual-averaging
#' algorithm, and adapting the shape to an estimate of the covariance of the
#' target distribution.
#' @param trace_function Function which given current chain state outputs list
#' of variables to trace on each main (non-adaptive) chain iteration.
#' @param show_progress_bar Whether to show progress bars during sampling.
#' Requires `progress` package to be installed to have an effect.
#' @param trace_warm_up Whether to record chain traces and adaptation /
@@ -78,7 +100,6 @@ sample_chain <- function(
n_main_iteration,
proposal = barker_proposal(),
adapters = list(scale_adapter(), shape_adapter()),
trace_function = NULL,
show_progress_bar = TRUE,
trace_warm_up = FALSE) {
progress_available <- requireNamespace("progress", quietly = TRUE)
@@ -90,8 +111,24 @@ sample_chain <- function(
} else {
stop("initial_state must be a vector or list with an entry named position.")
}
if (is.null(trace_function)) {
if (inherits(target_distribution, "formula")) {
target_distribution <- target_distribution_from_log_density_formula(
target_distribution
)
} else if (inherits(target_distribution, "StanModel")) {
target_distribution <- target_distribution_from_stan_model(
target_distribution
)
} else if (
!is.list(target_distribution) ||
!("log_density" %in% names(target_distribution))
) {
stop("target_distribution invalid - see documentation for allowable types.")
}
if (is.null(target_distribution$trace_function)) {
trace_function <- default_trace_function(target_distribution)
} else {
trace_function <- target_distribution$trace_function
}
statistic_names <- list("accept_prob")
warm_up_results <- chain_loop(
Expand Down
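
Review note: the dispatch added to `sample_chain()` accepts three kinds of `target_distribution` and rejects anything else. A minimal sketch of the same branch order follows, using a hypothetical helper `classify_target_distribution()` that only reports which branch would be taken; it does not call the package constructors.

```r
# Hypothetical helper (assumption) mirroring the branch order in sample_chain().
classify_target_distribution <- function(target_distribution) {
  if (inherits(target_distribution, "formula")) {
    "formula"
  } else if (inherits(target_distribution, "StanModel")) {
    "StanModel"
  } else if (
    !is.list(target_distribution) ||
      !("log_density" %in% names(target_distribution))
  ) {
    stop("target_distribution invalid - see documentation for allowable types.")
  } else {
    "list"
  }
}

# A one-sided formula is routed to the formula-based constructor.
formula_kind <- classify_target_distribution(~ -(x^2) / 2)
# A list with a log_density entry is used as given.
list_kind <- classify_target_distribution(
  list(log_density = function(x) -sum(x^2) / 2)
)
# A list without a log_density entry is rejected with an error.
error_message <- tryCatch(
  classify_target_distribution(list(gradient_log_density = function(x) -x)),
  error = function(e) conditionMessage(e)
)
print(c(formula_kind, list_kind, error_message))
```

Note that the check only requires `log_density`; a list lacking both gradient entries passes validation here, so any gradient requirement would surface later when a gradient-based proposal is used.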
26 changes: 6 additions & 20 deletions README.Rmd
@@ -72,36 +72,22 @@ cat(
```

As a second example, the snippet below demonstrates sampling from a two-dimensional banana shaped distribution based on the [Rosenbrock function](https://en.wikipedia.org/wiki/Rosenbrock_function) and plotting the generated chain samples.
Here we use the default values of the `proposal` and `adapters` arguments to `sample_chain`,
Here we use the default values of the `proposal` and `adapters` arguments to `sample_chain()`,
corresponding respectively to the Barker proposal, and adapters for tuning the proposal scale to coerce the average acceptance rate using a dual-averaging algorithm,
and for tuning the proposal shape based on an estimate of the target distribution covariance matrix.
The `target_distribution` argument to `sample_chain()` is passed a formula specifying the log density of the target distribution, which is passed to `target_distribution_from_log_density_formula()` to construct necessary functions,
using `stats::deriv()` to symbolically compute derivatives.


```{r banana-samples, fig.width=6, fig.height=4}
library(rmcmc)

set.seed(651239L)
target_distribution <- list(
log_density = function(x) -sum(x^2) / 8 - (x[1]^2 - x[2])^2 - (x[1] - 1)^2 / 10,
gradient_log_density = function(x) {
c(
-x[1] / 4 + 4 * x[1] * (x[2] - x[1]^2) - 0.2 * x[1] + 0.2,
-x[2] / 4 + 2 * x[1]^2 - 2 * x[2]
)
}
)
results <- sample_chain(
target_distribution = target_distribution,
target_distribution = ~ (-(x^2 + y^2) / 8 - (x^2 - y)^2 - (x - 1)^2 / 100),
initial_state = rnorm(2),
n_warm_up_iteration = 10000,
n_main_iteration = 10000,
)
plot(
results$traces[, "position1"],
results$traces[, "position2"],
xlab = expression(x[1]),
ylab = expression(x[2]),
col = "#1f77b4",
pch = 20
n_main_iteration = 10000
)
plot(results$traces[, "x"], results$traces[, "y"], col = "#1f77b4", pch = 20)
```
35 changes: 12 additions & 23 deletions README.md
@@ -76,39 +76,28 @@ As a second example, the snippet below demonstrates sampling from a
two-dimensional banana shaped distribution based on the [Rosenbrock
function](https://en.wikipedia.org/wiki/Rosenbrock_function) and
plotting the generated chain samples. Here we use the default values of
the `proposal` and `adapters` arguments to `sample_chain`, corresponding
respectively to the Barker proposal, and adapters for tuning the
proposal scale to coerce the average acceptance rate using a
the `proposal` and `adapters` arguments to `sample_chain()`,
corresponding respectively to the Barker proposal, and adapters for
tuning the proposal scale to coerce the average acceptance rate using a
dual-averaging algorithm, and for tuning the proposal shape based on an
estimate of the target distribution covariance matrix.
estimate of the target distribution covariance matrix. The
`target_distribution` argument to `sample_chain()` is passed a formula
specifying the log density of the target distribution, which is passed
to `target_distribution_from_log_density_formula()` to construct
necessary functions, using `stats::deriv()` to symbolically compute
derivatives.

``` r
library(rmcmc)

set.seed(651239L)
target_distribution <- list(
log_density = function(x) -sum(x^2) / 8 - (x[1]^2 - x[2])^2 - (x[1] - 1)^2 / 10,
gradient_log_density = function(x) {
c(
-x[1] / 4 + 4 * x[1] * (x[2] - x[1]^2) - 0.2 * x[1] + 0.2,
-x[2] / 4 + 2 * x[1]^2 - 2 * x[2]
)
}
)
results <- sample_chain(
target_distribution = target_distribution,
target_distribution = ~ (-(x^2 + y^2) / 8 - (x^2 - y)^2 - (x - 1)^2 / 100),
initial_state = rnorm(2),
n_warm_up_iteration = 10000,
n_main_iteration = 10000,
)
plot(
results$traces[, "position1"],
results$traces[, "position2"],
xlab = expression(x[1]),
ylab = expression(x[2]),
col = "#1f77b4",
pch = 20
n_main_iteration = 10000
)
plot(results$traces[, "x"], results$traces[, "y"], col = "#1f77b4", pch = 20)
```

<img src="man/figures/README-banana-samples-1.png" width="100%" />
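
Review note: `stats::deriv()` only differentiates expressions built from a fixed set of functions, so the list form documented in `R/chains.R` remains useful for other targets. The sketch below shows a list with an optional `trace_function` entry for a hypothetical two-dimensional standard normal target. The chain state is mocked with only a `position()` accessor (an assumption; real states come from `chain_state()`).

```r
# Hypothetical two-dimensional standard normal target in list form.
log_density <- function(x) -sum(x^2) / 2
target_distribution <- list(
  log_density = log_density,
  gradient_log_density = function(x) -x,
  # Optional entry: name traced values as the formula-based constructor does.
  trace_function = function(state) {
    trace_values <- state$position()
    names(trace_values) <- c("x", "y")
    trace_values["log_density"] <- log_density(state$position())
    trace_values
  }
)

# Stand-in for a chain state (assumption): only position() is mocked.
mock_state <- list(position = function() c(0.1, -0.3))
traced <- target_distribution$trace_function(mock_state)
print(traced)
```

Passing such a list as `target_distribution` to `sample_chain()` should then give traces named `x`, `y` and `log_density` rather than the default position names.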
5 changes: 3 additions & 2 deletions man/example_gaussian_stan_model.Rd

Binary file modified man/figures/README-banana-samples-1.png