
Account for possibility of custom objective function in XGBoost boost_tree() #459

Closed
@smingerson

Description

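Currently, passing a custom objective function to the xgboost engine causes an error later, at prediction time. The error comes from xgb_pred(), which calls switch() on the objective (usually a character string) to post-process the output of predict.xgb.Booster(). When the objective is an R function instead of a string, switch() fails because its first argument must be a length-1 vector.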

library(xgboost)
library(parsnip)
library(workflows)
mod <- boost_tree("regression") %>% 
  set_engine("xgboost",
             objective = function(preds, dtrain) {
               truth <- as.numeric(getinfo(dtrain, "label"))
               error <- truth - preds
               gradient <- -2 * error
               hess <- rep.int(2, length(preds))
               list(grad = gradient, hess = hess)
             }
             )

dt <- data.frame(x = rnorm(15))
dt$y <- dt$x + rnorm(15, 0, .05)

wf <- workflow() %>% 
  add_model(mod) %>% 
  add_formula(y~x)
fitted <- fit(wf, data = dt)
predict(fitted, new_data = dt)
#> Error in switch(object$params$objective, `binary:logitraw` = stats::binomial()$linkinv(res), : EXPR must be a length 1 vector
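A minimal sketch of one possible fix: only dispatch on the objective with switch() when it is a single character string, and otherwise return the raw output of predict.xgb.Booster() unchanged. The helper name `adjust_xgb_preds()` and `custom_obj` are hypothetical; this is not parsnip's actual xgb_pred() code, and only the `binary:logitraw` branch from the error message above is reproduced. Returning raw values for a custom objective assumes they are already on the response scale, which holds for the squared-error objective in the reprex (identity link); a custom classification objective would yield margins that still need their own transformation.

```r
# Hypothetical helper, NOT parsnip's actual xgb_pred(): it only sketches
# the guard needed before switching on the objective.
adjust_xgb_preds <- function(res, objective) {
  # A custom objective is an R function. switch() needs a length-1 EXPR,
  # so passing the function would raise "EXPR must be a length 1 vector".
  # Assumption: raw predictions from a custom objective are already on
  # the response scale, so they are returned unchanged.
  if (!is.character(objective) || length(objective) != 1L) {
    return(res)
  }
  switch(objective,
    # Raw logit margins are mapped to probabilities.
    "binary:logitraw" = stats::binomial()$linkinv(res),
    # Any other built-in objective: leave predictions as they are.
    res
  )
}

# Hypothetical stand-in for a custom objective; only its type matters
# here, it is never called.
custom_obj <- function(preds, dtrain) {
  list(grad = 2 * preds, hess = rep.int(2, length(preds)))
}

adjust_xgb_preds(c(-1, 0, 1), "binary:logitraw")  # inverse logit applied
adjust_xgb_preds(c(-1, 0, 1), custom_obj)         # returned unchanged
```

Checking is.character() before switch() keeps the existing string-based dispatch intact for built-in objectives, so only the custom-function case takes the new path.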

Labels: feature (a feature request or enhancement)
