Skip to content

Commit 29cedd4

Browse files
authored
Improvements to sample_chain interface (#43)
* Copy initial state to avoid mutating * Don't require passing proposal to adapter initializers * Update adapter initialise syntax in vignette * Remove unused parameter docstring * Cover default target_accept_prob logic in tests * Add test for adapter interface
1 parent 04ef350 commit 29cedd4

File tree

10 files changed

+133
-97
lines changed

10 files changed

+133
-97
lines changed

R/adaptation.R

Lines changed: 27 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,14 @@
11
#' Create object to adapt proposal scale to coerce average acceptance rate.
22
#'
3-
#' @param proposal Proposal object to adapt. Must define an `update` function
4-
#' which accepts a parameter `scale` for setting scale parameter of proposal.
53
#' @param initial_scale Initial value to use for scale parameter. If not set
64
#' explicitly a proposal and dimension dependent default will be used.
75
#' @param target_accept_prob Target value for average accept probability for
86
#' chain. If not set a proposal dependent default will be used.
97
#' @param kappa Decay rate exponent in `[0.5, 1]` for adaptation learning rate.
108
#'
119
#' @return List of functions with entries
12-
#' * `initialize`, a function for initializing adapter state at beginning of
13-
#' chain,
10+
#' * `initialize`, a function for initializing adapter state and proposal
11+
#' parameters at beginning of chain,
1412
#' * `update` a function for updating adapter state and proposal parameters on
1513
#' each chain iteration,
1614
#' * `finalize` a function for performing any final updates to adapter state and
@@ -27,24 +25,22 @@
2725
#' grad_log_density = function(x) -x
2826
#' )
2927
#' proposal <- barker_proposal(target_distribution)
30-
#' adapter <- scale_adapter(
31-
#' proposal,
32-
#' initial_scale = 1., target_accept_prob = 0.4
33-
#' )
28+
#' adapter <- scale_adapter(initial_scale = 1., target_accept_prob = 0.4)
29+
#' adapter$initialize(proposal, chain_state(c(0, 0)))
3430
scale_adapter <- function(
35-
proposal, initial_scale = NULL, target_accept_prob = NULL, kappa = 0.6) {
31+
initial_scale = NULL, target_accept_prob = NULL, kappa = 0.6) {
3632
log_scale <- NULL
37-
if (is.null(target_accept_prob)) {
38-
target_accept_prob <- proposal$default_target_accept_prob()
39-
}
40-
initialize <- function(initial_state) {
33+
initialize <- function(proposal, initial_state) {
4134
if (is.null(initial_scale)) {
4235
initial_scale <- proposal$default_initial_scale(initial_state$dimension())
4336
}
4437
log_scale <<- log(initial_scale)
4538
proposal$update(scale = initial_scale)
4639
}
47-
update <- function(sample_index, state_and_statistics) {
40+
update <- function(proposal, sample_index, state_and_statistics) {
41+
if (is.null(target_accept_prob)) {
42+
target_accept_prob <- proposal$default_target_accept_prob()
43+
}
4844
gamma <- sample_index^(-kappa)
4945
accept_prob <- state_and_statistics$statistics$accept_prob
5046
log_scale <<- log_scale + gamma * (accept_prob - target_accept_prob)
@@ -53,16 +49,14 @@ scale_adapter <- function(
5349
list(
5450
initialize = initialize,
5551
update = update,
56-
finalize = function() {},
52+
finalize = NULL,
5753
state = function() list(log_scale = log_scale)
5854
)
5955
}
6056

6157
#' Create object to adapt proposal with per dimension scales based on estimates
6258
#' of target distribution variances.
6359
#'
64-
#' @param proposal Proposal object to adapt. Must define an `update` function
65-
#' which accepts a parameter `shape` for setting shape parameter of proposal.
6660
#' @param kappa Decay rate exponent in `[0.5, 1]` for adaptation learning rate.
6761
#'
6862
#' @inherit scale_adapter return
@@ -74,15 +68,16 @@ scale_adapter <- function(
7468
#' grad_log_density = function(x) -x
7569
#' )
7670
#' proposal <- barker_proposal(target_distribution)
77-
#' adapter <- variance_adapter(proposal)
78-
variance_adapter <- function(proposal, kappa = 0.6) {
71+
#' adapter <- variance_adapter()
72+
#' adapter$initialize(proposal, chain_state(c(0, 0)))
73+
variance_adapter <- function(kappa = 0.6) {
7974
mean_estimate <- NULL
8075
variance_estimate <- NULL
81-
initialize <- function(initial_state) {
76+
initialize <- function(proposal, initial_state) {
8277
mean_estimate <<- initial_state$position()
8378
variance_estimate <<- rep(1., initial_state$dimension())
8479
}
85-
update <- function(sample_index, state_and_statistics) {
80+
update <- function(proposal, sample_index, state_and_statistics) {
8681
gamma <- sample_index^(-kappa)
8782
position <- state_and_statistics$state$position()
8883
mean_estimate <<- mean_estimate + gamma * (position - mean_estimate)
@@ -124,20 +119,23 @@ variance_adapter <- function(proposal, kappa = 0.6) {
124119
#' grad_log_density = function(x) -x
125120
#' )
126121
#' proposal <- barker_proposal(target_distribution)
127-
#' adapter <- robust_shape_adapter(
128-
#' proposal,
129-
#' initial_scale = 1.,
130-
#' target_accept_prob = 0.4
131-
#' )
122+
#' adapter <- robust_shape_adapter(initial_scale = 1., target_accept_prob = 0.4)
123+
#' adapter$initialize(proposal, chain_state(c(0, 0)))
132124
robust_shape_adapter <- function(
133-
proposal, initial_scale, target_accept_prob = 0.4, kappa = 0.6) {
125+
initial_scale = NULL, target_accept_prob = NULL, kappa = 0.6) {
134126
rlang::check_installed("ramcmc", reason = "to use this function")
135127
shape <- NULL
136-
initialize <- function(initial_state) {
128+
initialize <- function(proposal, initial_state) {
129+
if (is.null(initial_scale)) {
130+
initial_scale <- proposal$default_initial_scale(initial_state$dimension())
131+
}
137132
shape <<- initial_scale * diag(initial_state$dimension())
138133
proposal$update(shape = shape)
139134
}
140-
update <- function(sample_index, state_and_statistics) {
135+
update <- function(proposal, sample_index, state_and_statistics) {
136+
if (is.null(target_accept_prob)) {
137+
target_accept_prob <- proposal$default_target_accept_prob()
138+
}
141139
momentum <- state_and_statistics$proposed_state$momentum()
142140
accept_prob <- state_and_statistics$statistics$accept_prob
143141
shape <<- ramcmc::adapt_S(

R/chains.R

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ sample_chain <- function(
7272
if (is.vector(initial_state) && is.atomic(initial_state)) {
7373
state <- chain_state(initial_state)
7474
} else if (is.vector(initial_state) && "position" %in% names(initial_state)) {
75-
state <- initial_state
75+
state <- initial_state$copy()
7676
} else {
7777
stop("initial_state must be a vector or list with an entry named position.")
7878
}
@@ -163,7 +163,7 @@ chain_loop <- function(
163163
statistic_names) {
164164
progress_bar <- get_progress_bar(use_progress_bar, n_iteration, stage_name)
165165
for (adapter in adapters) {
166-
adapter$initialize(state)
166+
adapter$initialize(proposal, state)
167167
}
168168
if (record_traces_and_statistics) {
169169
trace_names <- names(unlist(trace_function(state)))
@@ -176,25 +176,25 @@ chain_loop <- function(
176176
traces <- NULL
177177
statistics <- NULL
178178
}
179-
for (s in seq_len(n_iteration)) {
179+
for (chain_iteration in seq_len(n_iteration)) {
180180
state_and_statistics <- sample_metropolis_hastings(
181181
state, target_distribution, proposal
182182
)
183183
for (adapter in adapters) {
184-
adapter$update(s + 1, state_and_statistics)
184+
adapter$update(proposal, chain_iteration + 1, state_and_statistics)
185185
}
186186
state <- state_and_statistics$state
187187
if (record_traces_and_statistics) {
188-
traces[s, ] <- unlist(trace_function(state))
188+
traces[chain_iteration, ] <- unlist(trace_function(state))
189189
adapter_states <- lapply(adapters, function(a) a$state())
190-
statistics[s, ] <- unlist(
190+
statistics[chain_iteration, ] <- unlist(
191191
c(state_and_statistics$statistics, adapter_states)
192192
)
193193
}
194194
if (!is.null(progress_bar)) progress_bar$tick()
195195
}
196196
for (adapter in adapters) {
197-
if (!is.null(adapter$finalize)) adapter$finalize()
197+
if (!is.null(adapter$finalize)) adapter$finalize(proposal)
198198
}
199199
list(final_state = state, traces = traces, statistics = statistics)
200200
}

README.Rmd

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ results <- sample_chain(
5959
initial_state = rnorm(dimension),
6060
n_warm_up_iteration = 1000,
6161
n_main_iteration = 1000,
62-
adapters = list(scale_adapter(proposal), variance_adapter(proposal))
62+
adapters = list(scale_adapter(), variance_adapter())
6363
)
6464
mean_accept_prob <- mean(results$statistics[, "accept_prob"])
6565
adapted_shape <- proposal$parameters()$shape

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ results <- sample_chain(
5656
initial_state = rnorm(dimension),
5757
n_warm_up_iteration = 1000,
5858
n_main_iteration = 1000,
59-
adapters = list(scale_adapter(proposal), variance_adapter(proposal))
59+
adapters = list(scale_adapter(), variance_adapter())
6060
)
6161
mean_accept_prob <- mean(results$statistics[, "accept_prob"])
6262
adapted_shape <- proposal$parameters()$shape

man/robust_shape_adapter.Rd

Lines changed: 6 additions & 13 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/scale_adapter.Rd

Lines changed: 5 additions & 15 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/variance_adapter.Rd

Lines changed: 5 additions & 7 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)