Skip to content

Commit

Permalink
Allow passing only position component for initial state (#39)
Browse files Browse the repository at this point in the history
* Allow passing only position component for initial state

* Make sample_chain robust to zero iterations

* Add additional sample_chains test cases
  • Loading branch information
matt-graham authored Oct 21, 2024
1 parent 930dbd6 commit 330c39b
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 47 deletions.
15 changes: 12 additions & 3 deletions R/chains.R
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
#' stage.
#'
#' @inheritParams sample_metropolis_hastings
#' @param initial_state Initial chain state.
#' @param initial_state Initial chain state. Either a vector specifying just
#' the position component of the chain state or a list output by `chain_state`
#' specifying the full chain state.
#' @param n_warm_up_iteration Number of warm-up (adaptive) chain iterations to
#' run.
#' @param n_main_iteration Number of main (non-adaptive) chain iterations to
Expand Down Expand Up @@ -67,14 +69,21 @@ sample_chain <- function(
trace_warm_up = FALSE) {
progress_available <- requireNamespace("progress", quietly = TRUE)
use_progress_bar <- progress_available && show_progress_bar
if (is.vector(initial_state) && is.atomic(initial_state)) {
state <- chain_state(initial_state)
} else if (is.vector(initial_state) && "position" %in% names(initial_state)) {
state <- initial_state
} else {
stop("initial_state must be a vector or list with an entry named position.")
}
if (is.null(trace_function)) {
trace_function <- default_trace_function(target_distribution)
}
statistic_names <- list("accept_prob")
warm_up_results <- chain_loop(
stage_name = "Warm-up",
n_iteration = n_warm_up_iteration,
state = initial_state,
state = state,
target_distribution = target_distribution,
proposal = proposal,
adapters = adapters,
Expand Down Expand Up @@ -167,7 +176,7 @@ chain_loop <- function(
traces <- NULL
statistics <- NULL
}
for (s in 1:n_iteration) {
for (s in seq_len(n_iteration)) {
state_and_statistics <- sample_metropolis_hastings(
state, target_distribution, proposal
)
Expand Down
4 changes: 3 additions & 1 deletion man/sample_chain.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

20 changes: 20 additions & 0 deletions tests/testthat/helpers.R
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,26 @@ expect_all_different <- function(object, different) {
invisible(act$val)
}

expect_nrow <- function(object, n) {
act <- quasi_label(rlang::enquo(object), arg = "object")
act$nrow <- nrow(object)
expect(
act$nrow == n,
sprintf("%s has %i rows not %i.", act$lab, act$nrow, n)
)
invisible(act$val)
}

expect_ncol <- function(object, n) {
act <- quasi_label(rlang::enquo(object), arg = "object")
act$ncol <- ncol(object)
expect(
act$ncol == n,
sprintf("%s has %i columns not %i.", act$lab, act$ncol, n)
)
invisible(act$val)
}

random_covariance_matrix <- function(dimension) {
temp <- matrix(rnorm(dimension^2), dimension, dimension)
temp %*% t(temp)
Expand Down
108 changes: 65 additions & 43 deletions tests/testthat/test-chains.R
Original file line number Diff line number Diff line change
@@ -1,48 +1,70 @@
for (n_warm_up_iteration in c(1, 100)) {
for (n_main_iteration in c(1, 100)) {
for (trace_warm_up in c(TRUE, FALSE)) {
test_that(
sprintf(
paste0(
"Sampling chain with %i warm-up iterations, %i main iterations, ",
"and trace_warm_up = %i works"
),
n_warm_up_iteration, n_main_iteration, trace_warm_up
),
{
dimension <- 3
target_distribution <- standard_normal_target_distribution()
barker_proposal(target_distribution)
proposal <- barker_proposal(target_distribution)
adapters <- list(
scale_adapter(proposal, initial_scale = 1.)
)
withr::with_seed(default_seed(), {
position <- rnorm(dimension)
})
initial_state <- chain_state(position)
results <- sample_chain(
target_distribution = target_distribution,
proposal = proposal,
initial_state = initial_state,
n_warm_up_iteration = n_warm_up_iteration,
n_main_iteration = n_main_iteration,
adapters = adapters,
trace_warm_up = trace_warm_up
)
expected_results_names <- c("final_state", "traces", "statistics")
if (trace_warm_up) {
expected_results_names <- c(
expected_results_names, "warm_up_traces", "warm_up_statistics"
)
}
expect_named(
results,
expected_results_names,
ignore.order = TRUE,
for (n_warm_up_iteration in c(0, 1, 10)) {
for (n_main_iteration in c(0, 1, 10)) {
for (dimension in c(1, 2)) {
for (trace_warm_up in c(TRUE, FALSE)) {
for (wrapped_initial_state in c(TRUE, FALSE)) {
test_that(
sprintf(
paste0(
"Sampling chain with %i warm-up iterations, %i main iterations",
" dimension %i, wrapped_initial_state = %i ",
"and trace_warm_up = %i works"
),
n_warm_up_iteration,
n_main_iteration,
dimension,
wrapped_initial_state,
trace_warm_up
),
{
target_distribution <- standard_normal_target_distribution()
barker_proposal(target_distribution)
proposal <- barker_proposal(target_distribution)
adapters <- list(
scale_adapter(proposal, initial_scale = 1.)
)
withr::with_seed(default_seed(), {
position <- rnorm(dimension)
})
if (wrapped_initial_state) {
initial_state <- chain_state(position)
} else {
initial_state <- position
}
results <- sample_chain(
target_distribution = target_distribution,
proposal = proposal,
initial_state = initial_state,
n_warm_up_iteration = n_warm_up_iteration,
n_main_iteration = n_main_iteration,
adapters = adapters,
trace_warm_up = trace_warm_up
)
expected_results_names <- c("final_state", "traces", "statistics")
if (trace_warm_up) {
expected_results_names <- c(
expected_results_names, "warm_up_traces", "warm_up_statistics"
)
}
expect_named(
results,
expected_results_names,
ignore.order = TRUE,
)
expect_nrow(results$traces, n_main_iteration)
expect_ncol(results$traces, dimension + 1)
expect_nrow(results$statistics, n_main_iteration)
expect_ncol(results$statistics, 1)
if (trace_warm_up) {
expect_nrow(results$warm_up_traces, n_warm_up_iteration)
expect_ncol(results$warm_up_traces, dimension + 1)
expect_nrow(results$warm_up_statistics, n_warm_up_iteration)
expect_ncol(results$warm_up_statistics, 2)
}
}
)
}
)
}
}
}
}

0 comments on commit 330c39b

Please sign in to comment.