Skip to content

Commit

Permalink
973: underprediction/overprediction/etc. fails if all observations ar…
Browse files Browse the repository at this point in the history
…e above/below the medians (#974)

* add checks to `crps_sample()`

* use non-copying conversion to matrix
  • Loading branch information
sbfnk authored Nov 4, 2024
1 parent 48ce635 commit 0761099
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 15 deletions.
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@

Minor spelling / mathematical updates to Scoring rule vignette. (#969)

## Package updates

- A bug was fixed where `crps_sample()` could fail in edge cases.

# scoringutils 2.0.0

This update represents a major rewrite of the package and introduces breaking changes. If you want to keep using the older version, you can download it using `remotes::install_github("epiforecasts/[email protected]")`.
Expand Down
40 changes: 26 additions & 14 deletions R/metrics-sample.R
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,10 @@ crps_sample <- function(observed, predicted, separate_results = FALSE, ...) {
)

if (separate_results) {
if (is.null(dim(predicted))) {
## if `predicted` is a vector convert to matrix
dim(predicted) <- c(1, length(predicted))
}
medians <- apply(predicted, 1, median)
dispersion <- scoringRules::crps_sample(
y = medians,
Expand All @@ -313,21 +317,29 @@ crps_sample <- function(observed, predicted, separate_results = FALSE, ...) {
overprediction <- rep(0, length(observed))
underprediction <- rep(0, length(observed))

overprediction[observed < medians] <- scoringRules::crps_sample(
y = observed[observed < medians],
dat = predicted[observed < medians, , drop = FALSE],
...
)
underprediction[observed > medians] <- scoringRules::crps_sample(
y = observed[observed > medians],
dat = predicted[observed > medians, , drop = FALSE],
...
)
if (any(observed < medians)) {
overprediction[observed < medians] <- scoringRules::crps_sample(
y = observed[observed < medians],
dat = predicted[observed < medians, , drop = FALSE],
...
)
}
if (any(observed > medians)) {
underprediction[observed > medians] <- scoringRules::crps_sample(
y = observed[observed > medians],
dat = predicted[observed > medians, , drop = FALSE],
...
)
}

overprediction[overprediction > 0] <-
overprediction[overprediction > 0] - dispersion[overprediction > 0]
underprediction[underprediction > 0] <-
underprediction[underprediction > 0] - dispersion[underprediction > 0]
if (any(overprediction > 0)) {
overprediction[overprediction > 0] <-
overprediction[overprediction > 0] - dispersion[overprediction > 0]
}
if (any(underprediction > 0)) {
underprediction[underprediction > 0] <-
underprediction[underprediction > 0] - dispersion[underprediction > 0]
}

return(list(
crps = crps,
Expand Down
10 changes: 9 additions & 1 deletion tests/testthat/test-metrics-sample.R
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,15 @@ test_that("crps_sample() components correspond to WIS components", {
expect_true(all(abs(c(dcrps - dwis, ocrps - owis, ucrps - uwis)) < 0.01))
})


test_that("crps_sample() works with a single observation", {
expect_no_condition(
crps <- crps_sample(
observed = 2.5, predicted = 1.5:10.5, separate_results = TRUE
)
)
expect_equal(length(crps), 4)
expect_equal(unique(vapply(crps, length, integer(1))), 1)
})

test_that("bias_sample() throws an error when missing observed", {
observed <- rpois(10, lambda = 1:10)
Expand Down

0 comments on commit 0761099

Please sign in to comment.