Skip to content

Commit 79831c3

Browse files
authored
stan-to-r (#973)
1 parent da54618 commit 79831c3

File tree

1 file changed

+210
-0
lines changed

1 file changed

+210
-0
lines changed

inst/dev/stan-to-R.R

Lines changed: 210 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,210 @@
1+
## call from browser inside initialisation
2+
files <- c(
3+
"convolve.stan", "gaussian_process.stan", "pmfs.stan",
4+
"observation_model.stan", "secondary.stan", "params.stan",
5+
"rt.stan", "infections.stan", "delays.stan", "generated_quantities.stan"
6+
)
7+
suppressMessages(
8+
expose_stan_fns(files,
9+
target_dir = system.file("stan/functions", package = "EpiNow2")
10+
)
11+
)
12+
13+
simulate <- function(data,
14+
generation_time = gt_opts(),
15+
delays = delay_opts(),
16+
truncation = trunc_opts(),
17+
rt = rt_opts(),
18+
backcalc = backcalc_opts(),
19+
gp = gp_opts(),
20+
obs = obs_opts(),
21+
forecast = forecast_opts(),
22+
stan = stan_opts(),
23+
inits = NULL) {
24+
25+
seeding_time <- get_seeding_time(delays, generation_time, rt)
26+
27+
reported_cases <- default_fill_missing_obs(data, obs, "confirm")
28+
if (forecast$horizon > 0) {
29+
reported_cases <- add_horizon(
30+
reported_cases, forecast$horizon, forecast$accumulate
31+
)
32+
}
33+
reported_cases <- create_clean_reported_cases(
34+
reported_cases,
35+
filter_leading_zeros = TRUE,
36+
zero_threshold = Inf
37+
)
38+
reported_cases <- data.table::rbindlist(list(
39+
data.table::data.table(
40+
date = seq(
41+
min(reported_cases$date) - seeding_time - backcalc$prior_window,
42+
min(reported_cases$date) - 1,
43+
by = "days"
44+
),
45+
confirm = 0, accumulate = FALSE, breakpoint = 0
46+
),
47+
reported_cases[, .(date, confirm, accumulate, breakpoint)]
48+
))
49+
shifted_cases <- create_shifted_cases(
50+
reported_cases,
51+
seeding_time,
52+
backcalc$prior_window,
53+
forecast$horizon
54+
)
55+
reported_cases <- reported_cases[-(1:backcalc$prior_window)]
56+
57+
# Define stan model parameters
58+
stan_data <- create_stan_data(
59+
reported_cases,
60+
seeding_time = seeding_time,
61+
rt = rt,
62+
gp = gp,
63+
obs = obs,
64+
backcalc = backcalc,
65+
shifted_cases = shifted_cases$confirm,
66+
forecast = forecast
67+
)
68+
69+
stan_data <- c(stan_data, create_stan_delays(
70+
gt = generation_time,
71+
delay = delays,
72+
trunc = truncation,
73+
time_points = stan_data$t - stan_data$seeding_time - stan_data$horizon
74+
))
75+
76+
if (is.null(inits)) {
77+
init <- create_initial_conditions(stan_data)
78+
inits <- init()
79+
} else {
80+
if (stan_data$bp_n == 0) {
81+
inits$bp_sd <- array(numeric(0))
82+
inits$bp_effects <- array(numeric(0))
83+
}
84+
}
85+
for (n in names(inits)) assign(n, inits[[n]])
86+
for (n in names(stan_data)) assign(n, stan_data[[n]])
87+
88+
ot <- t - seeding_time - horizon
89+
ot_h <- ot + horizon
90+
noise_terms <- setup_noise(
91+
ot_h, t, horizon, estimate_r, stationary, future_fixed, fixed_from
92+
)
93+
PHI <- setup_gp(M, L, noise_terms, gp_type == 1, w0)
94+
delay_type_max <- get_delay_type_max(
95+
delay_types, delay_types_p, delay_types_id,
96+
delay_types_groups, delay_max, delay_np_pmf_groups
97+
)
98+
initial_infections_guess <- max(
99+
0,
100+
log(mean(head(cases, ifelse(length(cases) > 7, 7, length(cases)))))
101+
)
102+
if (!fixed) {
103+
alpha <- get_param(
104+
alpha_id, params_fixed_lookup, params_variable_lookup, params_value,
105+
params
106+
)
107+
rescaled_rho <- 2 * get_param(
108+
rho_id, params_fixed_lookup, params_variable_lookup,
109+
params_value, params
110+
) / noise_terms
111+
noise <- update_gp(
112+
PHI, M, L, alpha, rescaled_rho, eta, gp_type, nu
113+
)
114+
} else {
115+
noise <- numeric(0)
116+
}
117+
if (estimate_r) {
118+
gt_rev_pmf <- get_delay_rev_pmf(
119+
gt_id, delay_type_max[gt_id] + 1, delay_types_p, delay_types_id,
120+
delay_types_groups, delay_max, delay_np_pmf,
121+
delay_np_pmf_groups, delay_params, delay_params_groups, delay_dist,
122+
1, 1, 0
123+
)
124+
R0 <- get_param(
125+
R0_id, params_fixed_lookup, params_variable_lookup, params_value, params
126+
)
127+
R <- update_Rt(
128+
ot_h, R0, noise, breakpoints, bp_effects, stationary
129+
)
130+
frac_obs <- get_param(
131+
frac_obs_id, params_fixed_lookup, params_variable_lookup, params_value,
132+
params
133+
)
134+
pop <- get_param(
135+
pop_id, params_fixed_lookup, params_variable_lookup, params_value,
136+
params
137+
)
138+
infections <- generate_infections(
139+
R, seeding_time, gt_rev_pmf, initial_infections, pop,
140+
use_pop, future_time, obs_scale, frac_obs, 1
141+
)
142+
} else {
143+
infections <- deconvolve_infections(
144+
shifted_cases, noise, fixed, backcalc_prior
145+
)
146+
}
147+
delay_rev_pmf <- get_delay_rev_pmf(
148+
delay_id, delay_type_max[delay_id] + 1, delay_types_p, delay_types_id,
149+
delay_types_groups, delay_max, delay_np_pmf,
150+
delay_np_pmf_groups, delay_params, delay_params_groups, delay_dist,
151+
0, 1, 0
152+
)
153+
reports <- convolve_to_report(infections, delay_rev_pmf, seeding_time)
154+
if (week_effect > 1) {
155+
reports <- day_of_week_effect(reports, day_of_week, day_of_week_simplex)
156+
}
157+
if (obs_scale) {
158+
frac_obs <- get_param(
159+
frac_obs_id, params_fixed_lookup, params_variable_lookup, params_value,
160+
params
161+
)
162+
reports <- scale_obs(reports, frac_obs)
163+
}
164+
if (trunc_id) {
165+
trunc_rev_cmf <- get_delay_rev_pmf(
166+
trunc_id, delay_type_max[trunc_id] + 1, delay_types_p, delay_types_id,
167+
delay_types_groups, delay_max, delay_np_pmf,
168+
delay_np_pmf_groups, delay_params, delay_params_groups, delay_dist,
169+
0, 1, 1
170+
)
171+
obs_reports <- truncate_obs(reports[1:ot], trunc_rev_cmf, 0)
172+
} else {
173+
obs_reports <- reports[1:ot]
174+
}
175+
if (any_accumulate) {
176+
obs_reports <- accumulate_reports(obs_reports, accumulate)
177+
}
178+
if (!fixed) {
179+
gaussian_process_lp(eta)
180+
}
181+
delays_lp(
182+
delay_params, delay_params_mean, delay_params_sd, delay_params_groups,
183+
delay_dist, delay_weight
184+
)
185+
params_lp(
186+
params, prior_dist, prior_dist_params, params_lower, params_upper
187+
)
188+
rt_lp(
189+
initial_infections, bp_effects, bp_sd, bp_n,
190+
cases, initial_infections_guess
191+
)
192+
if (likelihood) {
193+
dispersion <- get_param(
194+
dispersion_id, params_fixed_lookup, params_variable_lookup, params_value,
195+
params
196+
)
197+
report_lp(
198+
cases, case_times, obs_reports, dispersion, model_type, obs_weight
199+
)
200+
}
201+
if (!fixed) {
202+
rescaled_rho <- get_param(
203+
rho_id, params_fixed_lookup, params_variable_lookup,
204+
params_value, params
205+
)
206+
x <- seq(1, noise_terms)
207+
rho <- rescaled_rho * 0.5 * (max(x) - 1)
208+
}
209+
return(reports)
210+
}

0 commit comments

Comments
 (0)