Account for possibility of custom objective function in XGBoost boost_tree() #459

Closed
smingerson opened this issue Apr 2, 2021 · 4 comments · Fixed by #787
Labels
feature a feature request or enhancement

Comments

@smingerson

Currently, passing a custom objective function causes an error downstream at prediction time. The error comes from xgb_pred(), which calls switch() on the objective (usually a character string) to post-process the output of predict.xgb.Booster(). switch() cannot dispatch on a function, so prediction fails.

library(xgboost)
library(parsnip)
library(workflows)
mod <- boost_tree("regression") %>% 
  set_engine("xgboost",
             objective = function(preds, dtrain) {
               truth <- as.numeric(getinfo(dtrain, "label"))
               error <- truth - preds
               gradient <- -2 * error
               hess <- rep.int(2, length(preds))
               list(grad = gradient, hess = hess)
             }
             )

dt <- data.frame(x = rnorm(15))
dt$y <- dt$x + rnorm(15, 0, .05)

wf <- workflow() %>% 
  add_model(mod) %>% 
  add_formula(y~x)
fitted <- fit(wf, data = dt)
predict(fitted, new_data = dt)
#> Error in switch(object$params$objective, `binary:logitraw` = stats::binomial()$linkinv(res), : EXPR must be a length 1 vector
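The failure can be reproduced in base R alone, without parsnip or xgboost. This is a minimal sketch: `dispatch_on_objective()` and `custom_obj` are hypothetical names used for illustration, and the switch() arms only stand in for the transformations xgb_pred() applies. It shows that switch() works for a string objective but raises the same "EXPR must be a length 1 vector" error for a function or for NULL.

```r
# Hypothetical stand-in for the objective dispatch inside xgb_pred().
# switch() requires EXPR to be a length-1 vector, so a function (custom
# objective) or NULL (missing objective) triggers an error.
dispatch_on_objective <- function(objective) {
  tryCatch(
    switch(objective,
           "binary:logitraw" = "logit link applied",
           "no change"),                 # unnamed arm = default for strings
    error = function(e) conditionMessage(e)
  )
}

# Hypothetical custom objective with the (preds, dtrain) signature xgboost expects
custom_obj <- function(preds, dtrain) list(grad = preds, hess = preds)

dispatch_on_objective("binary:logitraw")   # "logit link applied"
dispatch_on_objective("reg:squarederror")  # "no change"
dispatch_on_objective(custom_obj)          # "EXPR must be a length 1 vector"
dispatch_on_objective(NULL)                # "EXPR must be a length 1 vector"
```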
@juliasilge added the feature label (a feature request or enhancement) on Apr 6, 2021
@jcpsantiago

jcpsantiago commented Apr 20, 2021

I also see this error when using parsnip::set_engine("xgboost", params = list(eval_metric = "aucpr")) without setting the objective argument. I ran into it after updating parsnip from 0.1.4 to 0.1.5, when tune::tune_grid started failing. (tidymodels and the other individual packages, i.e. {workflows} and {tune}, were also updated during that time.)

This test passes:

test_that('xgboost alternate objective', {

  spec <-
    boost_tree() %>%
    set_engine("xgboost", objective = "reg:pseudohubererror") %>%
    set_mode("regression")

  xgb_fit <- spec %>% fit(mpg ~ ., data = mtcars)
})

But if objective is not a string (I guess this is why the issue is labelled a feature request rather than a bug), it fails just like the OP's code. Additionally, if one passes anything else to set_engine(), it fails with the same error. Are all the ... arguments collected into the same vector?

library(xgboost)
library(parsnip)
library(workflows)

mod <- boost_tree("classification") %>% 
  set_engine(
    "xgboost", 
    objective = "binary:logistic",
    params = list(eval_metric = "aucpr") # <- added this and changed the data to be a classification problem
  )

dt <- data.frame(
  x = rnorm(15),
  y = rnorm(15) + rnorm(15, 0, .05),
  target = as.factor(rbinom(15, 1, 0.5))
)

wf <- workflow() %>% 
  add_model(mod) %>% 
  add_formula(target ~ x + y)

fitted <- fit(wf, data = dt)
predict(fitted, new_data = dt)
#> Error in switch(object$params$objective, `binary:logitraw` = stats::binomial()$linkinv(res), : EXPR must be a length 1 vector

Created on 2021-04-20 by the reprex package (v2.0.0)

Session info
sessioninfo::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#>  setting  value                       
#>  version  R version 4.0.2 (2020-06-22)
#>  os       macOS  10.16                
#>  system   x86_64, darwin17.0          
#>  ui       X11                         
#>  language (EN)                        
#>  collate  en_US.UTF-8                 
#>  ctype    en_US.UTF-8                 
#>  tz       Europe/Berlin               
#>  date     2021-04-20                  
#> 
#> ─ Packages ───────────────────────────────────────────────────────────────────
#>  ! package     * version date       lib source        
#>    assertthat    0.2.1   2019-03-21 [1] CRAN (R 4.0.2)
#>  P cli           2.4.0   2021-04-05 [?] CRAN (R 4.0.2)
#>  P codetools     0.2-18  2020-11-04 [3] CRAN (R 4.0.2)
#>  P crayon        1.4.1   2021-02-08 [?] CRAN (R 4.0.2)
#>  P data.table    1.14.0  2021-02-21 [?] CRAN (R 4.0.2)
#>  P DBI           1.1.1   2021-01-15 [?] CRAN (R 4.0.2)
#>    digest        0.6.27  2020-10-24 [1] CRAN (R 4.0.2)
#>  P dplyr         1.0.5   2021-03-05 [?] CRAN (R 4.0.2)
#>    ellipsis      0.3.1   2020-05-15 [1] CRAN (R 4.0.2)
#>  P evaluate      0.14    2019-05-28 [?] CRAN (R 4.0.0)
#>  P fansi         0.4.2   2021-01-15 [?] CRAN (R 4.0.2)
#>  P fs            1.5.0   2020-07-31 [?] CRAN (R 4.0.2)
#>  P generics      0.1.0   2020-10-31 [?] CRAN (R 4.0.2)
#>    globals       0.14.0  2020-11-22 [1] CRAN (R 4.0.2)
#>    glue          1.4.2   2020-08-27 [1] CRAN (R 4.0.2)
#>  P hardhat       0.1.5   2020-11-09 [?] CRAN (R 4.0.2)
#>  P highr         0.9     2021-04-16 [?] CRAN (R 4.0.2)
#>  P htmltools     0.5.1.1 2021-01-22 [?] CRAN (R 4.0.2)
#>  P knitr         1.32    2021-04-14 [?] CRAN (R 4.0.2)
#>  P lattice       0.20-41 2020-04-02 [3] CRAN (R 4.0.2)
#>  P lifecycle     1.0.0   2021-02-15 [?] CRAN (R 4.0.2)
#>    magrittr      2.0.1   2020-11-17 [1] CRAN (R 4.0.2)
#>  P Matrix        1.3-2   2021-01-06 [?] CRAN (R 4.0.2)
#>  P parsnip     * 0.1.5   2021-01-19 [?] CRAN (R 4.0.2)
#>  P pillar        1.6.0   2021-04-13 [?] CRAN (R 4.0.2)
#>    pkgconfig     2.0.3   2019-09-22 [1] CRAN (R 4.0.2)
#>    purrr         0.3.4   2020-04-17 [1] CRAN (R 4.0.2)
#>    R6            2.5.0   2020-10-28 [1] CRAN (R 4.0.2)
#>  P reprex        2.0.0   2021-04-02 [?] CRAN (R 4.0.2)
#>  P rlang         0.4.10  2020-12-30 [?] CRAN (R 4.0.2)
#>  P rmarkdown     2.7     2021-02-19 [?] CRAN (R 4.0.2)
#>    rstudioapi    0.13    2020-11-12 [1] CRAN (R 4.0.2)
#>    sessioninfo   1.1.1   2018-11-05 [3] CRAN (R 4.0.2)
#>    stringi       1.5.3   2020-09-09 [1] CRAN (R 4.0.2)
#>    stringr       1.4.0   2019-02-10 [1] CRAN (R 4.0.2)
#>  P tibble        3.1.1   2021-04-18 [?] CRAN (R 4.0.2)
#>  P tidyr         1.1.3   2021-03-03 [?] CRAN (R 4.0.2)
#>    tidyselect    1.1.0   2020-05-11 [1] CRAN (R 4.0.2)
#>  P utf8          1.2.1   2021-03-12 [?] CRAN (R 4.0.2)
#>  P vctrs         0.3.7   2021-03-29 [?] CRAN (R 4.0.2)
#>  P withr         2.4.2   2021-04-18 [?] CRAN (R 4.0.2)
#>  P workflows   * 0.2.2   2021-03-10 [?] CRAN (R 4.0.2)
#>  P xfun          0.22    2021-03-11 [?] CRAN (R 4.0.2)
#>    xgboost     * 1.3.2.1 2021-01-18 [1] CRAN (R 4.0.2)
#>    yaml          2.2.1   2020-02-01 [1] CRAN (R 4.0.2)
#> 
#> [1] /Users/santiago/code/ds-models-fraud/renv/library/R-4.0/x86_64-apple-darwin17.0
#> [2] /private/var/folders/8d/zxgx1qkx44n7_wp6crx3ycsh0000gn/T/Rtmp6h44Di/renv-system-library
#> [3] /Library/Frameworks/R.framework/Versions/4.0/Resources/library
#> 
#>  P ── Loaded and on-disk path mismatch.

To work around it, I had to change my code to:

mod <- boost_tree("classification") %>% 
  set_engine(
    "xgboost",
    params = list(
      eval_metric = "aucpr",
      objective = "binary:logistic" # <- MUST be present
    )
  )

The objective must be declared explicitly inside params whenever params is used; otherwise object$params$objective is NULL. I'm not sure whether this is expected behavior, i.e. whether dropping the default objective is intentional.
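One way to make prediction robust to both cases in this thread (a custom objective function and a NULL objective) is to call switch() only when the objective is a single character string. The sketch below is hypothetical: `post_process_preds()` is not parsnip's code, and it is not necessarily how the linked fix (#787) works. Only the `binary:logitraw` arm, taken from the error message above, is modelled; every other objective returns the predictions unchanged.

```r
# Hypothetical helper, NOT parsnip's actual implementation.
# Dispatch on the objective only when it is a single string; a custom
# objective function or a NULL objective falls through to the raw values.
post_process_preds <- function(res, objective) {
  if (!is.character(objective) || length(objective) != 1L) {
    return(res)
  }
  switch(
    objective,
    "binary:logitraw" = stats::binomial()$linkinv(res),  # log-odds -> probability
    res                                                  # any other string: unchanged
  )
}

raw <- c(-1, 0, 1)
post_process_preds(raw, "binary:logitraw")             # logistic transform of raw
post_process_preds(raw, function(preds, dtrain) NULL)  # raw, unchanged
post_process_preds(raw, NULL)                          # raw, unchanged
```

Returning the untransformed values for a custom objective assumes the caller knows which scale their own objective works on, as in the squared-error regression example at the top of this issue.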

@amazongodman

amazongodman commented Apr 19, 2022

There are similar reports here as well.

tidymodels/butcher#214

@simonpcouch
Contributor

Related to #774.

@github-actions

github-actions bot commented Sep 1, 2022

This issue has been automatically locked. If you believe you have found a related problem, please file a new issue (with a reprex: https://reprex.tidyverse.org) and link to this issue.

@github-actions bot locked and limited conversation to collaborators on Sep 1, 2022