Skip to content

Commit

Permalink
Avoid Matrix::diagonal and limit progress bar updates (#47)
Browse files Browse the repository at this point in the history
* Use diag rather than Matrix::diagonal

* Limit progress bar updates

* Increase maximum complexity in lintr config
  • Loading branch information
matt-graham authored Oct 25, 2024
1 parent 5d29b85 commit 4321d29
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 5 deletions.
3 changes: 2 additions & 1 deletion .lintr
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
linters: linters_with_defaults(
line_length_linter = line_length_linter(88L),
object_length_linter = object_length_linter(length=50L)
object_length_linter = object_length_linter(length=50L),
cyclocomp_linter = cyclocomp_linter(complexity_limit = 25L)
)
encoding: "UTF-8"
exclusions: list(
Expand Down
10 changes: 9 additions & 1 deletion R/chains.R
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,9 @@ chain_loop <- function(
trace_function,
statistic_names) {
progress_bar <- get_progress_bar(use_progress_bar, n_iteration, stage_name)
# Only show 10% increments in progress bar to avoid progress bar updates being
# a bottleneck when chain iteration rate is high
tick_amount <- max(n_iteration %/% 10, 1)
for (adapter in adapters) {
adapter$initialize(proposal, state)
}
Expand Down Expand Up @@ -191,8 +194,13 @@ chain_loop <- function(
c(state_and_statistics$statistics, adapter_states)
)
}
if (!is.null(progress_bar)) progress_bar$tick()
if (!is.null(progress_bar) && (chain_iteration %% tick_amount == 0)) {
progress_bar$tick(tick_amount)
}
}
# Ensure progress bar shows completed in cases tick_amount not a factor of
# n_iteration
if (!is.null(progress_bar) && !progress_bar$finished) progress_bar$update(1)
for (adapter in adapters) {
if (!is.null(adapter$finalize)) adapter$finalize(proposal)
}
Expand Down
2 changes: 1 addition & 1 deletion R/proposal.R
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ get_shape_matrix <- function(scale, shape) {
stop("Scale should be a non-negative scalar")
}
if (!is.null(shape) && is_non_scalar_vector(shape)) {
shape <- Matrix::Diagonal(x = shape)
shape <- diag(shape)
}
if (is.null(scale) && is.null(shape)) {
stop("One of scale and shape parameters must be set")
Expand Down
4 changes: 2 additions & 2 deletions tests/testthat/test-proposal.R
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ test_that("get_shape_matrix works", {
expect_identical(get_shape_matrix(NULL, 1), 1)
expect_identical(get_shape_matrix(0.5, 3), 1.5)
expect_identical(get_shape_matrix(0.5, 3), 1.5)
expect_identical(get_shape_matrix(2, c(3, 0.5)), Matrix::Diagonal(x = c(6, 1)))
expect_identical(get_shape_matrix(NULL, c(3, 2)), Matrix::Diagonal(x = c(3, 2)))
expect_identical(get_shape_matrix(2, c(3, 0.5)), diag(x = c(6, 1)))
expect_identical(get_shape_matrix(NULL, c(3, 2)), diag(x = c(3, 2)))
expect_identical(get_shape_matrix(0.5, diag(3, 2)), diag(1.5, 2))
expect_identical(get_shape_matrix(NULL, diag(3)), diag(3))
expect_error(get_shape_matrix(NULL, NULL), "must be set")
Expand Down

0 comments on commit 4321d29

Please sign in to comment.