Skip to content

Commit 2fbb2a7

Browse files
authored
Merge pull request #408 from mrc-ide/mrc-4287
Fall back on fp64 for large Poisson in fp32 mode
2 parents b198275 + e20a08b commit 2fbb2a7

File tree

4 files changed

+41
-34
lines changed

4 files changed

+41
-34
lines changed

DESCRIPTION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
Package: dust
22
Title: Iterate Multiple Realisations of Stochastic Models
3-
Version: 0.14.5
3+
Version: 0.14.6
44
Authors@R: c(person("Rich", "FitzJohn", role = c("aut", "cre"),
55
email = "[email protected]"),
66
person("Alex", "Hill", role = "aut"),

inst/include/dust/random/poisson.hpp

Lines changed: 28 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -151,35 +151,39 @@ real_type poisson_cauchy(rng_state_type& rng_state, real_type lambda) {
151151
// Dieter 1980 ("Sampling from Binomial and Poisson Distributions",
152152
// Computing 25 193-208) is meant to be the fastest with a
153153
// constantly changing lambda, but is more complex to implement.
154-
//
155-
// Unfortunately, this is incorrect for single precision with fairly
156-
// large lambda (1e6 or more), giving a mean that is correct but
157-
// inflated variance. The underlying issue is not the cauchy as
158-
// that's correct, and it could just be precision loss?
159-
if (std::is_same<real_type, float>::value && lambda > 1e6) {
160-
throw std::runtime_error("Single precision Poisson with lambda > 1e6 not yet supported");
161-
}
162154
real_type result = 0;
163-
const real_type log_lambda = dust::math::log<real_type>(lambda);
164-
const real_type sqrt_2lambda = dust::math::sqrt<real_type>(2 * lambda);
165-
const real_type magic_val = lambda * log_lambda - dust::math::lgamma<real_type>(1 + lambda);
166-
for (;;) {
167-
real_type comp_dev;
155+
if (std::is_same<real_type, float>::value && lambda > 1e6) {
156+
// This algorithm suffers bias in single precision with large
157+
// lambda (as in var(Poisson(lambda)) / lambda ~ 10 rather than
158+
// 1); it looks like we're passing back too much of the cauchy
159+
// somehow. An alternative normal distribution based rejection
160+
// sampling algorithm does better, to lambda between 1e9 and 1e10,
161+
// then gets stuck in an finite loop probably because of precision
162+
// loss (see mrc-4287 for implementation). So we just fall back on
163+
// double precision here, which gets the job done.
164+
result = poisson_cauchy<double>(rng_state, static_cast<double>(lambda));
165+
} else {
166+
const real_type log_lambda = dust::math::log<real_type>(lambda);
167+
const real_type sqrt_2lambda = dust::math::sqrt<real_type>(2 * lambda);
168+
const real_type magic_val = lambda * log_lambda - dust::math::lgamma<real_type>(1 + lambda);
168169
for (;;) {
169-
comp_dev = cauchy<real_type>(rng_state, 0, 1);
170-
result = sqrt_2lambda * comp_dev + lambda;
171-
if (result >= 0) {
170+
real_type comp_dev;
171+
for (;;) {
172+
comp_dev = cauchy<real_type>(rng_state, 0, 1);
173+
result = sqrt_2lambda * comp_dev + lambda;
174+
if (result >= 0) {
175+
break;
176+
}
177+
}
178+
result = dust::math::trunc<real_type>(result);
179+
const real_type check = static_cast<real_type>(0.9) *
180+
(1 + comp_dev * comp_dev) *
181+
dust::math::exp<real_type>(result * log_lambda - dust::math::lgamma<real_type>(1 + result) - magic_val);
182+
const real_type u = random_real<real_type>(rng_state);
183+
if (u <= check) {
172184
break;
173185
}
174186
}
175-
result = dust::math::trunc<real_type>(result);
176-
const real_type check = static_cast<real_type>(0.9) *
177-
(1 + comp_dev * comp_dev) *
178-
dust::math::exp<real_type>(result * log_lambda - dust::math::lgamma<real_type>(1 + result) - magic_val);
179-
const real_type u = random_real<real_type>(rng_state);
180-
if (u <= check) {
181-
break;
182-
}
183187
}
184188
return result;
185189
}

inst/include/dust/random/version.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44

55
#define DUST_VERSION_MAJOR 0
66
#define DUST_VERSION_MINOR 14
7-
#define DUST_VERSION_PATCH 5
8-
#define DUST_VERSION_STRING "0.14.5"
9-
#define DUST_VERSION_CODE 1405
7+
#define DUST_VERSION_PATCH 6
8+
#define DUST_VERSION_STRING "0.14.6"
9+
#define DUST_VERSION_CODE 1406
1010

1111
#endif

tests/testthat/test-rng.R

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -259,12 +259,6 @@ test_that("Poisson numbers only valid for 0 <= lambda < Inf", {
259259
})
260260

261261

262-
test_that("Single precision poisson numbers only valid to 1e6", {
263-
expect_error(dust_rng$new(1, real_type = "float")$poisson(1, 1e7),
264-
"Single precision Poisson with lambda > 1e6 not yet supported")
265-
})
266-
267-
268262
test_that("Short circuit exit does not update rng state", {
269263
rng <- dust_rng$new(1)
270264
s <- rng$state()
@@ -1412,3 +1406,12 @@ test_that("regression tests of binomial issues", {
14121406
expect_setequal(rng$binomial(10000, 4, 0.2), 0:4)
14131407
expect_setequal(rng$binomial(10000, 10, 0.5), 0:10)
14141408
})
1409+
1410+
1411+
test_that("Very big poisson with single precision now work", {
1412+
n <- 1000000
1413+
lambda <- 1e8
1414+
ans <- dust_rng$new(1, real_type = "float")$poisson(n, lambda)
1415+
expect_equal(mean(ans), lambda, tolerance = 1e-2)
1416+
expect_equal(var(ans), lambda, tolerance = 1e-2)
1417+
})

0 commit comments

Comments
 (0)