Skip to content

Commit 7015437

Browse files
authored
Merge pull request #411 from mrc-ide/mrc-4347
mrc-4347: cope with missing data
2 parents 07f4cb9 + f6bfa60 commit 7015437

22 files changed

+188
-79
lines changed

.gitattributes

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
R/cpp11.R linguist-generated=true
2+
src/cpp11.cpp linguist-generated=true

DESCRIPTION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
Package: dust
22
Title: Iterate Multiple Realisations of Stochastic Models
3-
Version: 0.14.7
3+
Version: 0.14.8
44
Authors@R: c(person("Rich", "FitzJohn", role = c("aut", "cre"),
55
email = "[email protected]"),
66
person("Alex", "Hill", role = "aut"),

R/cpp11.R

Lines changed: 12 additions & 12 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

R/dust.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
## Generated by dust (version 0.14.4) - do not edit
1+
## Generated by dust (version 0.14.8) - do not edit
22
logistic <- R6::R6Class(
33
"dust",
44
cloneable = FALSE,

inst/examples/sir.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,11 @@ class sir {
5959

6060
real_type compare_data(const real_type * state, const data_type& data,
6161
rng_state_type& rng_state) {
62-
const real_type incidence_modelled = state[4];
6362
const real_type incidence_observed = data.incidence;
63+
if (std::isnan(data.incidence)) {
64+
return 0;
65+
}
66+
const real_type incidence_modelled = state[4];
6467
const real_type lambda = incidence_modelled +
6568
dust::random::exponential(rng_state, shared->exp_noise);
6669
return dust::density::poisson(incidence_observed, lambda, true);

inst/include/dust/filter.hpp

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,10 +59,6 @@ filter(T * obj,
5959
obj->run(time);
6060
obj->compare_data(weights, d->second);
6161

62-
// TODO: we should cope better with the case where all weights
63-
// are 0; I think that is the behaviour in the model (or rather
64-
// the case where there is no data and so we do not resample)
65-
//
6662
// TODO: we should cope better with the case where one filter
6763
// has become impossible but others continue, but that's hard!
6864
auto wi = weights.begin();

inst/include/dust/filter_tools.hpp

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include <vector>
99

1010
#include "dust/random/random.hpp"
11+
#include "dust/utils.hpp"
1112

1213
namespace dust {
1314
namespace filter {
@@ -17,23 +18,30 @@ void resample_weight(typename std::vector<real_type>::const_iterator w,
1718
size_t n, real_type u, size_t offset,
1819
typename std::vector<size_t>::iterator idx) {
1920
const real_type tot = std::accumulate(w, w + n, static_cast<real_type>(0));
20-
real_type ww = 0.0, uu0 = tot * u / n, du = tot / n;
21-
size_t j = offset;
22-
const size_t end = n + offset;
23-
for (size_t i = 0; i < n; ++i) {
24-
// We could accumulate uu by adding du at each iteration but that
25-
// suffers roundoff error here with floats.
26-
const real_type uu = uu0 + i * du;
27-
// The second clause (i.e., j - offset < n) should never be hit
28-
// but prevents any invalid read if we have pathalogical 'u' that
29-
// is within floating point eps of 1
30-
while (ww < uu && j < end) {
31-
ww += *w;
32-
++w;
33-
++j;
21+
if (tot == 0 && dust::utils::all_zero<real_type>(w, w + n)) {
22+
for (size_t i = 0; i < n; ++i) {
23+
*idx = offset + i;
24+
++idx;
25+
}
26+
} else {
27+
real_type ww = 0.0, uu0 = tot * u / n, du = tot / n;
28+
size_t j = offset;
29+
const size_t end = n + offset;
30+
for (size_t i = 0; i < n; ++i) {
31+
// We could accumulate uu by adding du at each iteration but that
32+
// suffers roundoff error here with floats.
33+
const real_type uu = uu0 + i * du;
34+
// The second clause (i.e., j - offset < n) should never be hit
35+
// but prevents any invalid read if we have pathalogical 'u' that
36+
// is within floating point eps of 1
37+
while (ww < uu && j < end) {
38+
ww += *w;
39+
++w;
40+
++j;
41+
}
42+
*idx = j == 0 ? 0 : j - 1;
43+
++idx;
3444
}
35-
*idx = j == 0 ? 0 : j - 1;
36-
++idx;
3745
}
3846
}
3947

inst/include/dust/gpu/kernels.hpp

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -301,12 +301,16 @@ void find_intervals(const real_type * cum_weights,
301301
real_type start_val = par_idx > 0 ? cum_weights[par_idx * n_particles_each - 1] : 0;
302302
real_type normalising_constant =
303303
cum_weights[(par_idx + 1) * n_particles_each - 1] - start_val;
304-
real_type u_particle = normalising_constant /
305-
static_cast<real_type>(n_particles_each) *
306-
(u[par_idx] + i % n_particles_each);
307-
index[i] = binary_interval_search(
308-
cum_weights + par_idx * n_particles_each,
309-
n_particles_each, u_particle, start_val) + par_idx * n_particles_each;
304+
if (normalising_constant == 0) {
305+
index[i] = i;
306+
} else {
307+
real_type u_particle = normalising_constant /
308+
static_cast<real_type>(n_particles_each) *
309+
(u[par_idx] + i % n_particles_each);
310+
index[i] = binary_interval_search(
311+
cum_weights + par_idx * n_particles_each,
312+
n_particles_each, u_particle, start_val) + par_idx * n_particles_each;
313+
}
310314
#ifdef __NVCC__
311315
}
312316
#else

inst/include/dust/random/version.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44

55
#define DUST_VERSION_MAJOR 0
66
#define DUST_VERSION_MINOR 14
7-
#define DUST_VERSION_PATCH 7
8-
#define DUST_VERSION_STRING "0.14.7"
9-
#define DUST_VERSION_CODE 1407
7+
#define DUST_VERSION_PATCH 8
8+
#define DUST_VERSION_STRING "0.14.8"
9+
#define DUST_VERSION_CODE 1408
1010

1111
#endif

inst/include/dust/utils.hpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,11 @@ T clamp(T x, T min, T max) {
107107
return std::max(std::min(x, max), min);
108108
}
109109

110+
template <typename real_type, typename It>
111+
bool all_zero(It begin, It end) {
112+
return std::all_of(begin, end, [](real_type x) { return x == 0; });
113+
}
114+
110115
}
111116
}
112117

0 commit comments

Comments
 (0)