Skip to content

Commit 72321f1

Browse files
authored
Merge pull request #410 from mrc-ide/mrc-4307
2 parents 2e793db + 045467a commit 72321f1

25 files changed

+730
-7
lines changed

DESCRIPTION

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
Package: dust
22
Title: Iterate Multiple Realisations of Stochastic Models
3-
Version: 0.14.10
3+
Version: 0.15.0
44
Authors@R: c(person("Rich", "FitzJohn", role = c("aut", "cre"),
55
email = "[email protected]"),
66
person("Alex", "Hill", role = "aut"),
@@ -32,7 +32,7 @@ Imports:
3232
pkgload,
3333
withr
3434
LinkingTo:
35-
cpp11 (>= 0.4.0)
35+
cpp11 (>= 0.4.4)
3636
Suggests:
3737
bench,
3838
brio,

R/compile.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ glue_whisker <- function(template, data) {
137137
dust_template_data <- function(model, config, cuda, reload_data, linking_to,
138138
cpp_std, compiler_options, optimisation_level) {
139139
methods <- function(target) {
140-
nms <- c("alloc", "run", "simulate", "set_index", "n_state",
140+
nms <- c("alloc", "run", "simulate", "run_adjoint", "set_index", "n_state",
141141
"update_state", "state", "time", "reorder", "resample",
142142
"rng_state", "set_rng_state", "set_n_threads",
143143
"set_data", "compare_data", "filter", "set_stochastic_schedule",

R/cpp11.R

Lines changed: 32 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

R/dust.R

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ logistic <- R6::R6Class(
3333
alloc = dust_ode_logistic_alloc,
3434
run = dust_ode_logistic_run,
3535
simulate = dust_ode_logistic_simulate,
36+
run_adjoint = dust_ode_logistic_run_adjoint,
3637
set_index = dust_ode_logistic_set_index,
3738
n_state = dust_ode_logistic_n_state,
3839
update_state = dust_ode_logistic_update_state,
@@ -91,6 +92,10 @@ logistic <- R6::R6Class(
9192
m
9293
},
9394

95+
run_adjoint = function() {
96+
private$methods_$run_adjoint(private$ptr_)
97+
},
98+
9499
set_index = function(index) {
95100
private$methods_$set_index(private$ptr_, index)
96101
private$index_ <- index
@@ -287,6 +292,7 @@ sir <- R6::R6Class(
287292
alloc = dust_cpu_sir_alloc,
288293
run = dust_cpu_sir_run,
289294
simulate = dust_cpu_sir_simulate,
295+
run_adjoint = dust_cpu_sir_run_adjoint,
290296
set_index = dust_cpu_sir_set_index,
291297
n_state = dust_cpu_sir_n_state,
292298
update_state = dust_cpu_sir_update_state,
@@ -345,6 +351,10 @@ sir <- R6::R6Class(
345351
m
346352
},
347353

354+
run_adjoint = function() {
355+
private$methods_$run_adjoint(private$ptr_)
356+
},
357+
348358
set_index = function(index) {
349359
private$methods_$set_index(private$ptr_, index)
350360
private$index_ <- index
@@ -541,6 +551,7 @@ sirs <- R6::R6Class(
541551
alloc = dust_cpu_sirs_alloc,
542552
run = dust_cpu_sirs_run,
543553
simulate = dust_cpu_sirs_simulate,
554+
run_adjoint = dust_cpu_sirs_run_adjoint,
544555
set_index = dust_cpu_sirs_set_index,
545556
n_state = dust_cpu_sirs_n_state,
546557
update_state = dust_cpu_sirs_update_state,
@@ -561,6 +572,7 @@ sirs <- R6::R6Class(
561572
alloc = dust_gpu_sirs_alloc,
562573
run = dust_gpu_sirs_run,
563574
simulate = dust_gpu_sirs_simulate,
575+
run_adjoint = dust_gpu_sirs_run_adjoint,
564576
set_index = dust_gpu_sirs_set_index,
565577
n_state = dust_gpu_sirs_n_state,
566578
update_state = dust_gpu_sirs_update_state,
@@ -615,6 +627,10 @@ sirs <- R6::R6Class(
615627
m
616628
},
617629

630+
run_adjoint = function() {
631+
private$methods_$run_adjoint(private$ptr_)
632+
},
633+
618634
set_index = function(index) {
619635
private$methods_$set_index(private$ptr_, index)
620636
private$index_ <- index
@@ -808,6 +824,7 @@ variable <- R6::R6Class(
808824
alloc = dust_cpu_variable_alloc,
809825
run = dust_cpu_variable_run,
810826
simulate = dust_cpu_variable_simulate,
827+
run_adjoint = dust_cpu_variable_run_adjoint,
811828
set_index = dust_cpu_variable_set_index,
812829
n_state = dust_cpu_variable_n_state,
813830
update_state = dust_cpu_variable_update_state,
@@ -828,6 +845,7 @@ variable <- R6::R6Class(
828845
alloc = dust_gpu_variable_alloc,
829846
run = dust_gpu_variable_run,
830847
simulate = dust_gpu_variable_simulate,
848+
run_adjoint = dust_gpu_variable_run_adjoint,
831849
set_index = dust_gpu_variable_set_index,
832850
n_state = dust_gpu_variable_n_state,
833851
update_state = dust_gpu_variable_update_state,
@@ -882,6 +900,10 @@ variable <- R6::R6Class(
882900
m
883901
},
884902

903+
run_adjoint = function() {
904+
private$methods_$run_adjoint(private$ptr_)
905+
},
906+
885907
set_index = function(index) {
886908
private$methods_$set_index(private$ptr_, index)
887909
private$index_ <- index
@@ -1075,6 +1097,7 @@ volatility <- R6::R6Class(
10751097
alloc = dust_cpu_volatility_alloc,
10761098
run = dust_cpu_volatility_run,
10771099
simulate = dust_cpu_volatility_simulate,
1100+
run_adjoint = dust_cpu_volatility_run_adjoint,
10781101
set_index = dust_cpu_volatility_set_index,
10791102
n_state = dust_cpu_volatility_n_state,
10801103
update_state = dust_cpu_volatility_update_state,
@@ -1133,6 +1156,10 @@ volatility <- R6::R6Class(
11331156
m
11341157
},
11351158

1159+
run_adjoint = function() {
1160+
private$methods_$run_adjoint(private$ptr_)
1161+
},
1162+
11361163
set_index = function(index) {
11371164
private$methods_$set_index(private$ptr_, index)
11381165
private$index_ <- index
@@ -1326,6 +1353,7 @@ walk <- R6::R6Class(
13261353
alloc = dust_cpu_walk_alloc,
13271354
run = dust_cpu_walk_run,
13281355
simulate = dust_cpu_walk_simulate,
1356+
run_adjoint = dust_cpu_walk_run_adjoint,
13291357
set_index = dust_cpu_walk_set_index,
13301358
n_state = dust_cpu_walk_n_state,
13311359
update_state = dust_cpu_walk_update_state,
@@ -1384,6 +1412,10 @@ walk <- R6::R6Class(
13841412
m
13851413
},
13861414

1415+
run_adjoint = function() {
1416+
private$methods_$run_adjoint(private$ptr_)
1417+
},
1418+
13871419
set_index = function(index) {
13881420
private$methods_$set_index(private$ptr_, index)
13891421
private$index_ <- index

R/dust_generator.R

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,14 @@ dust_generator <- R6::R6Class(
176176
simulate = function(time_end) {
177177
},
178178

179+
##' @description
180+
##'
181+
##' Run model with gradient information (if supported). The
182+
##' interface here will change, and documentation written once it
183+
##' stabilises.
184+
run_adjoint = function() {
185+
},
186+
179187
##' @description
180188
##' Set the "index" vector that is used to return a subset of pars
181189
##' after using `run()`. If this is not used then `run()` returns

0 commit comments

Comments
 (0)