@@ -151,35 +151,39 @@ real_type poisson_cauchy(rng_state_type& rng_state, real_type lambda) {
151
151
// Dieter 1980 ("Sampling from Binomial and Poisson Distributions",
152
152
// Computing 25 193-208) is meant to be the fastest with a
153
153
// constantly changing lambda, but is more complex to implement.
154
- //
155
- // Unfortunately, this is incorrect for single precision with fairly
156
- // large lambda (1e6 or more), giving a mean that is correct but
157
- // inflated variance. The underlying issue is not the cauchy as
158
- // that's correct, and it could just be precision loss?
159
- if (std::is_same<real_type, float >::value && lambda > 1e6 ) {
160
- throw std::runtime_error (" Single precision Poisson with lambda > 1e6 not yet supported" );
161
- }
162
154
real_type result = 0 ;
163
- const real_type log_lambda = dust::math::log<real_type>(lambda);
164
- const real_type sqrt_2lambda = dust::math::sqrt<real_type>(2 * lambda);
165
- const real_type magic_val = lambda * log_lambda - dust::math::lgamma<real_type>(1 + lambda);
166
- for (;;) {
167
- real_type comp_dev;
155
+ if (std::is_same<real_type, float >::value && lambda > 1e6 ) {
156
+ // This algorithm suffers bias in single precision with large
157
+ // lambda (as in var(Poisson(lambda)) / lambda ~ 10 rather than
158
+ // 1); it looks like we're passing back too much of the cauchy
159
+ // somehow. An alternative normal distribution based rejection
160
+ // sampling algorithm does better, to lambda between 1e9 and 1e10,
161
+ // then gets stuck in an finite loop probably because of precision
162
+ // loss (see mrc-4287 for implementation). So we just fall back on
163
+ // double precision here, which gets the job done.
164
+ result = poisson_cauchy<double >(rng_state, static_cast <double >(lambda));
165
+ } else {
166
+ const real_type log_lambda = dust::math::log<real_type>(lambda);
167
+ const real_type sqrt_2lambda = dust::math::sqrt<real_type>(2 * lambda);
168
+ const real_type magic_val = lambda * log_lambda - dust::math::lgamma<real_type>(1 + lambda);
168
169
for (;;) {
169
- comp_dev = cauchy<real_type>(rng_state, 0 , 1 );
170
- result = sqrt_2lambda * comp_dev + lambda;
171
- if (result >= 0 ) {
170
+ real_type comp_dev;
171
+ for (;;) {
172
+ comp_dev = cauchy<real_type>(rng_state, 0 , 1 );
173
+ result = sqrt_2lambda * comp_dev + lambda;
174
+ if (result >= 0 ) {
175
+ break ;
176
+ }
177
+ }
178
+ result = dust::math::trunc<real_type>(result);
179
+ const real_type check = static_cast <real_type>(0.9 ) *
180
+ (1 + comp_dev * comp_dev) *
181
+ dust::math::exp<real_type>(result * log_lambda - dust::math::lgamma<real_type>(1 + result) - magic_val);
182
+ const real_type u = random_real<real_type>(rng_state);
183
+ if (u <= check) {
172
184
break ;
173
185
}
174
186
}
175
- result = dust::math::trunc<real_type>(result);
176
- const real_type check = static_cast <real_type>(0.9 ) *
177
- (1 + comp_dev * comp_dev) *
178
- dust::math::exp<real_type>(result * log_lambda - dust::math::lgamma<real_type>(1 + result) - magic_val);
179
- const real_type u = random_real<real_type>(rng_state);
180
- if (u <= check) {
181
- break ;
182
- }
183
187
}
184
188
return result;
185
189
}
0 commit comments