Account for possibility of custom objective function in XGBoost boost_tree()
#459
I also see this error when using … This test passes:

```r
spec <-
  boost_tree() %>%
  set_engine("xgboost", objective = "reg:pseudohubererror") %>%
  set_mode("regression")

xgb_fit <- spec %>% fit(mpg ~ ., data = mtcars)
```

but this one fails at prediction time:

```r
library(xgboost)
library(parsnip)
library(workflows)

mod <- boost_tree("classification") %>%
  set_engine(
    "xgboost",
    objective = "binary:logistic",
    params = list(eval_metric = "aucpr") # <- added this and changed the data to be a classification problem
  )

dt <- data.frame(
  x = rnorm(15),
  y = rnorm(15) + rnorm(15, 0, .05),
  target = as.factor(rbinom(15, 1, 0.5))
)

wf <- workflow() %>%
  add_model(mod) %>%
  add_formula(target ~ x + y)

fitted <- fit(wf, data = dt)
predict(fitted, new_data = dt)
#> Error in switch(object$params$objective, `binary:logitraw` = stats::binomial()$linkinv(res), : EXPR must be a length 1 vector
```

Created on 2021-04-20 by the reprex package (v2.0.0)

Session info

sessioninfo::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#> setting value
#> version R version 4.0.2 (2020-06-22)
#> os macOS 10.16
#> system x86_64, darwin17.0
#> ui X11
#> language (EN)
#> collate en_US.UTF-8
#> ctype en_US.UTF-8
#> tz Europe/Berlin
#> date 2021-04-20
#>
#> ─ Packages ───────────────────────────────────────────────────────────────────
#> ! package * version date lib source
#> assertthat 0.2.1 2019-03-21 [1] CRAN (R 4.0.2)
#> P cli 2.4.0 2021-04-05 [?] CRAN (R 4.0.2)
#> P codetools 0.2-18 2020-11-04 [3] CRAN (R 4.0.2)
#> P crayon 1.4.1 2021-02-08 [?] CRAN (R 4.0.2)
#> P data.table 1.14.0 2021-02-21 [?] CRAN (R 4.0.2)
#> P DBI 1.1.1 2021-01-15 [?] CRAN (R 4.0.2)
#> digest 0.6.27 2020-10-24 [1] CRAN (R 4.0.2)
#> P dplyr 1.0.5 2021-03-05 [?] CRAN (R 4.0.2)
#> ellipsis 0.3.1 2020-05-15 [1] CRAN (R 4.0.2)
#> P evaluate 0.14 2019-05-28 [?] CRAN (R 4.0.0)
#> P fansi 0.4.2 2021-01-15 [?] CRAN (R 4.0.2)
#> P fs 1.5.0 2020-07-31 [?] CRAN (R 4.0.2)
#> P generics 0.1.0 2020-10-31 [?] CRAN (R 4.0.2)
#> globals 0.14.0 2020-11-22 [1] CRAN (R 4.0.2)
#> glue 1.4.2 2020-08-27 [1] CRAN (R 4.0.2)
#> P hardhat 0.1.5 2020-11-09 [?] CRAN (R 4.0.2)
#> P highr 0.9 2021-04-16 [?] CRAN (R 4.0.2)
#> P htmltools 0.5.1.1 2021-01-22 [?] CRAN (R 4.0.2)
#> P knitr 1.32 2021-04-14 [?] CRAN (R 4.0.2)
#> P lattice 0.20-41 2020-04-02 [3] CRAN (R 4.0.2)
#> P lifecycle 1.0.0 2021-02-15 [?] CRAN (R 4.0.2)
#> magrittr 2.0.1 2020-11-17 [1] CRAN (R 4.0.2)
#> P Matrix 1.3-2 2021-01-06 [?] CRAN (R 4.0.2)
#> P parsnip * 0.1.5 2021-01-19 [?] CRAN (R 4.0.2)
#> P pillar 1.6.0 2021-04-13 [?] CRAN (R 4.0.2)
#> pkgconfig 2.0.3 2019-09-22 [1] CRAN (R 4.0.2)
#> purrr 0.3.4 2020-04-17 [1] CRAN (R 4.0.2)
#> R6 2.5.0 2020-10-28 [1] CRAN (R 4.0.2)
#> P reprex 2.0.0 2021-04-02 [?] CRAN (R 4.0.2)
#> P rlang 0.4.10 2020-12-30 [?] CRAN (R 4.0.2)
#> P rmarkdown 2.7 2021-02-19 [?] CRAN (R 4.0.2)
#> rstudioapi 0.13 2020-11-12 [1] CRAN (R 4.0.2)
#> sessioninfo 1.1.1 2018-11-05 [3] CRAN (R 4.0.2)
#> stringi 1.5.3 2020-09-09 [1] CRAN (R 4.0.2)
#> stringr 1.4.0 2019-02-10 [1] CRAN (R 4.0.2)
#> P tibble 3.1.1 2021-04-18 [?] CRAN (R 4.0.2)
#> P tidyr 1.1.3 2021-03-03 [?] CRAN (R 4.0.2)
#> tidyselect 1.1.0 2020-05-11 [1] CRAN (R 4.0.2)
#> P utf8 1.2.1 2021-03-12 [?] CRAN (R 4.0.2)
#> P vctrs 0.3.7 2021-03-29 [?] CRAN (R 4.0.2)
#> P withr 2.4.2 2021-04-18 [?] CRAN (R 4.0.2)
#> P workflows * 0.2.2 2021-03-10 [?] CRAN (R 4.0.2)
#> P xfun 0.22 2021-03-11 [?] CRAN (R 4.0.2)
#> xgboost * 1.3.2.1 2021-01-18 [1] CRAN (R 4.0.2)
#> yaml 2.2.1 2020-02-01 [1] CRAN (R 4.0.2)
#>
#> [1] /Users/santiago/code/ds-models-fraud/renv/library/R-4.0/x86_64-apple-darwin17.0
#> [2] /private/var/folders/8d/zxgx1qkx44n7_wp6crx3ycsh0000gn/T/Rtmp6h44Di/renv-system-library
#> [3] /Library/Frameworks/R.framework/Versions/4.0/Resources/library
#>
#> P ── Loaded and on-disk path mismatch.

To fix it, I had to change my code to:

```r
mod <- boost_tree("classification") %>%
  set_engine(
    "xgboost",
    params = list(
      eval_metric = "aucpr",
      objective = "binary:logistic" # <- MUST be present
    )
  )
```

The objective must be explicitly declared if …
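The workaround makes sure the objective sits inside `params` whenever `params` is supplied. Below is a minimal base-R sketch of that rule. It uses a hypothetical helper, `merge_objective()`, which is not a parsnip or xgboost function. It only shows the merge logic and makes no claim about how parsnip should implement it.

```r
# Hypothetical helper (not part of parsnip or xgboost): copy a separately
# supplied objective into the params list, unless params already has one.
merge_objective <- function(params = list(), objective = NULL) {
  # [[ ]] does exact name matching; $ would allow partial matches on lists
  if (!is.null(objective) && is.null(params[["objective"]])) {
    params[["objective"]] <- objective
  }
  params
}

# A top-level objective is folded into params next to eval_metric
p <- merge_objective(list(eval_metric = "aucpr"), "binary:logistic")
str(p)

# An objective already present in params is kept as-is
merge_objective(list(objective = "reg:pseudohubererror"), "binary:logistic")[["objective"]]
```

Passing a list built this way as `params` gives the same shape as the corrected `set_engine()` call in the comment.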
There are similar reports here as well.

Related to #774.
Currently, passing a custom objective function causes an error downstream at prediction time. The error comes from `xgb_pred()`, which calls `switch()` on the objective (usually a character string) to transform the output of `predict.xgb.Booster()`. A custom objective is a function rather than a string, so `switch()` cannot dispatch on it.
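The failure can be reproduced in base R, without parsnip or xgboost. The sketch below does three things:

- It defines a custom squared-error objective in the `function(preds, dtrain)` form that xgboost's R interface accepts. The function is never called, so xgboost is not needed to run the sketch.
- It shows that `switch()` rejects this function with the same message as in the reprex.
- It guards a hypothetical `transform_preds()` helper that stands in for the `switch()` in `xgb_pred()`. This helper is not parsnip's actual code.

Returning raw predictions for non-string objectives is just one possible policy.

```r
# Custom squared-error objective in the form xgboost's R interface expects:
# it takes the current predictions and the training xgb.DMatrix and returns
# the gradient and hessian. Only defined here, never called.
squared_error_obj <- function(preds, dtrain) {
  labels <- xgboost::getinfo(dtrain, "label")
  list(grad = preds - labels, hess = rep(1, length(preds)))
}

# switch() needs a length-1 vector, so a function (or NULL, when the
# objective is missing) fails before any branch is considered.
switch_error <- tryCatch(
  switch(squared_error_obj, `binary:logitraw` = "logit"),
  error = function(e) conditionMessage(e)
)
print(switch_error)

# Hypothetical stand-in for the objective-specific post-processing:
# only dispatch with switch() when the objective is a single string.
transform_preds <- function(res, objective) {
  if (!is.character(objective) || length(objective) != 1L) {
    # Custom function or missing objective: return predictions unchanged
    return(res)
  }
  switch(objective,
    `binary:logitraw` = stats::binomial()$linkinv(res),
    res # any other built-in objective: leave predictions as-is
  )
}

transform_preds(c(-1, 0, 1), squared_error_obj) # unchanged
transform_preds(0, "binary:logitraw")           # 0.5
transform_preds(c(-1, 0, 1), NULL)              # unchanged
```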