Skip to content

Commit

Permalink
Merge pull request #980 from epiforecasts/fix-inplace-renaming
Browse files Browse the repository at this point in the history
Fix inplace renaming of columns
  • Loading branch information
nikosbosse authored Jan 8, 2025
2 parents e73e2b8 + 72eb030 commit 1629541
Show file tree
Hide file tree
Showing 12 changed files with 123 additions and 56 deletions.
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ Minor spelling / mathematical updates to Scoring rule vignette. (#969)

- A bug was fixed where `crps_sample()` could fail in edge cases.
- Implemented a new forecast class, `forecast_ordinal` with appropriate metrics. Ordinal forecasts are a form of categorical forecasts. The main difference between ordinal and nominal forecasts is that the outcome is ordered, rather than unordered.
- Refactored the way that columns get internally renamed in `as_forecast_<type>()` functions (#980)

# scoringutils 2.0.0

Expand Down
7 changes: 6 additions & 1 deletion R/class-forecast-binary.R
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,12 @@ as_forecast_binary <- function(data,
forecast_unit = NULL,
observed = NULL,
predicted = NULL) {
data <- as_forecast_generic(data, forecast_unit, observed, predicted)
data <- as_forecast_generic(
data,
forecast_unit,
observed = observed,
predicted = predicted
)
data <- new_forecast(data, "forecast_binary")
assert_forecast(data)
return(data)
Expand Down
14 changes: 7 additions & 7 deletions R/class-forecast-nominal.R
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,13 @@ as_forecast_nominal <- function(data,
observed = NULL,
predicted = NULL,
predicted_label = NULL) {
assert_character(predicted_label, len = 1, null.ok = TRUE)
assert_subset(predicted_label, names(data), empty.ok = TRUE)
if (!is.null(predicted_label)) {
setnames(data, old = predicted_label, new = "predicted_label")
}

data <- as_forecast_generic(data, forecast_unit, observed, predicted)
data <- as_forecast_generic(
data,
forecast_unit,
observed = observed,
predicted = predicted,
predicted_label = predicted_label
)
data <- new_forecast(data, "forecast_nominal")
assert_forecast(data)
return(data)
Expand Down
14 changes: 7 additions & 7 deletions R/class-forecast-ordinal.R
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,13 @@ as_forecast_ordinal <- function(data,
observed = NULL,
predicted = NULL,
predicted_label = NULL) {
assert_character(predicted_label, len = 1, null.ok = TRUE)
assert_subset(predicted_label, names(data), empty.ok = TRUE)
if (!is.null(predicted_label)) {
setnames(data, old = predicted_label, new = "predicted_label")
}

data <- as_forecast_generic(data, forecast_unit, observed, predicted)
data <- as_forecast_generic(
data,
forecast_unit,
observed = observed,
predicted = predicted,
predicted_label = predicted_label
)
data <- new_forecast(data, "forecast_ordinal")
assert_forecast(data)
return(data)
Expand Down
7 changes: 6 additions & 1 deletion R/class-forecast-point.R
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,12 @@ as_forecast_point.default <- function(data,
observed = NULL,
predicted = NULL,
...) {
data <- as_forecast_generic(data, forecast_unit, observed, predicted)
data <- as_forecast_generic(
data,
forecast_unit,
observed = observed,
predicted = predicted
)
data <- new_forecast(data, "forecast_point")
assert_forecast(data)
return(data)
Expand Down
14 changes: 7 additions & 7 deletions R/class-forecast-quantile.R
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,13 @@ as_forecast_quantile.default <- function(data,
predicted = NULL,
quantile_level = NULL,
...) {
assert_character(quantile_level, len = 1, null.ok = TRUE)
assert_subset(quantile_level, names(data), empty.ok = TRUE)
if (!is.null(quantile_level)) {
setnames(data, old = quantile_level, new = "quantile_level")
}

data <- as_forecast_generic(data, forecast_unit, observed, predicted)
data <- as_forecast_generic(
data,
forecast_unit,
observed = observed,
predicted = predicted,
quantile_level = quantile_level
)
data <- new_forecast(data, "forecast_quantile")
assert_forecast(data)
return(data)
Expand Down
14 changes: 7 additions & 7 deletions R/class-forecast-sample.R
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,13 @@ as_forecast_sample <- function(data,
observed = NULL,
predicted = NULL,
sample_id = NULL) {
assert_character(sample_id, len = 1, null.ok = TRUE)
assert_subset(sample_id, names(data), empty.ok = TRUE)
if (!is.null(sample_id)) {
setnames(data, old = sample_id, new = "sample_id")
}

data <- as_forecast_generic(data, forecast_unit, observed, predicted)
data <- as_forecast_generic(
data,
forecast_unit,
observed = observed,
predicted = predicted,
sample_id = sample_id
)
data <- new_forecast(data, "forecast_sample")
assert_forecast(data)
return(data)
Expand Down
34 changes: 19 additions & 15 deletions R/class-forecast.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,29 @@
#' It renames the required columns, where appropriate, and sets the forecast
#' unit.
#' @inheritParams as_forecast_doc_template
#' @param ... Named arguments that are used to rename columns. The names of the
#' arguments are the names of the columns that should be renamed. The values
#' are the new names.
#' @keywords as_forecast
as_forecast_generic <- function(data,
forecast_unit = NULL,
observed = NULL,
predicted = NULL) {
# check inputs - general
...) {
data <- ensure_data.table(data)
assert_character(observed, len = 1, null.ok = TRUE)
assert_subset(observed, names(data), empty.ok = TRUE)

assert_character(predicted, len = 1, null.ok = TRUE)
assert_subset(predicted, names(data), empty.ok = TRUE)

# rename columns - general
if (!is.null(observed)) {
setnames(data, old = observed, new = "observed")
}
if (!is.null(predicted)) {
setnames(data, old = predicted, new = "predicted")
oldnames <- list(...)
newnames <- names(oldnames)
provided <- !sapply(oldnames, is.null)

lapply(seq_along(oldnames), function(i) {
var <- oldnames[[i]]
varname <- names(oldnames)[i]
assert_character(var, len = 1, null.ok = TRUE, .var.name = varname)
assert_subset(var, names(data), empty.ok = TRUE, .var.name = varname)
})

oldnames <- unlist(oldnames[provided])
newnames <- unlist(newnames[provided])
if (!is.null(oldnames) && length(oldnames) > 0) {
setnames(data, old = oldnames, new = newnames)
}

# set forecast unit (error handling is done in `set_forecast_unit()`)
Expand Down
15 changes: 4 additions & 11 deletions man/as_forecast_generic.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

39 changes: 39 additions & 0 deletions man/assert_input_categorical.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions scoringutils.Rproj
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
Version: 1.0
ProjectId: 008f911e-df6e-4218-825c-db1095ac43c4

RestoreWorkspace: No
SaveWorkspace: No
Expand Down
19 changes: 19 additions & 0 deletions tests/testthat/test-class-forecast-quantile.R
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,25 @@ test_that("as_forecast_quantiles issue 557 fix", {
expect_equal(any(is.na(out$interval_coverage_deviation)), FALSE)
})

test_that("as_forecast_quantile doesn't modify column names in place", {
quantile_data <- data.table(
my_quantile = c(0.25, 0.5),
forecast_value = c(1, 2),
observed_value = c(5, 5)
)
pre <- names(quantile_data)

quantile_forecast <- quantile_data %>%
as_forecast_quantile(
predicted = "forecast_value",
observed = "observed_value",
quantile_level = "my_quantile"
)

post <- names(quantile_data)
expect_equal(pre, post)
})


# ==============================================================================
# is_forecast_quantile()
Expand Down

0 comments on commit 1629541

Please sign in to comment.