Closed
Description
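Currently, passing a custom objective function causes an error downstream when predicting. This happens in `xgb_pred()`, which calls `switch()` on the objective (usually a character string) to modify the output of `predict.xgb.Booster()`. When the objective is a function rather than a string, `switch()` fails because its `EXPR` argument is not a length-1 vector.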
library(xgboost)
library(parsnip)
library(workflows)

# Custom objective: gradient and hessian of the squared error
# (truth - preds)^2 with respect to the predictions.
squared_error_objective <- function(preds, dtrain) {
  observed <- as.numeric(getinfo(dtrain, "label"))
  residual <- observed - preds
  list(
    grad = -2 * residual,
    hess = rep.int(2, length(preds))
  )
}

# Model specification that hands the objective to the engine as a function
# instead of the usual character string.
spec <- set_engine(
  boost_tree("regression"),
  "xgboost",
  objective = squared_error_objective
)

# Small simulated data set: y is x plus a little Gaussian noise.
sim_data <- data.frame(x = rnorm(15))
sim_data[["y"]] <- sim_data[["x"]] + rnorm(15, mean = 0, sd = 0.05)

wflow <- add_formula(add_model(workflow(), spec), y ~ x)

wf_fit <- fit(wflow, data = sim_data)

# Fitting succeeds; the failure only shows up at prediction time.
predict(wf_fit, new_data = sim_data)
#> Error in switch(object$params$objective, `binary:logitraw` = stats::binomial()$linkinv(res), : EXPR must be a length 1 vector