Skip to content

Commit

Permalink
Add R test
Browse files Browse the repository at this point in the history
  • Loading branch information
tomasfryda committed Sep 11, 2023
1 parent 5420002 commit d50a58b
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 0 deletions.
1 change: 1 addition & 0 deletions h2o-r/h2o-package/R/explain.R
Original file line number Diff line number Diff line change
Expand Up @@ -1639,6 +1639,7 @@ handle_pdp <- function(newdata, column, target, show_logodds, row_index, models_
}

.check_model_suitability_for_calculation_of_contributions <- function(model, background_frame=NULL) {
if (is.null(model)) stop("Model is NULL.")
is_h2o_model <- .is_h2o_model(model)
if (!is_h2o_model || !(.is_h2o_tree_model(model) || model@algorithm == "generic" ||
(!is.null(background_frame) && tolower(model@algorithm) %in% c("glm", "deeplearning", "stackedensemble")))) {
Expand Down
97 changes: 97 additions & 0 deletions h2o-r/tests/testdir_misc/runit_SHAP.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
setwd(normalizePath(dirname(R.utils::commandArgs(asValues = TRUE)$"f")))
source("../../scripts/h2o-r-test-setup.R")


ALGOS <- c("deeplearning", "drf", "gbm", "glm", "stackedensemble", "xgboost")


aml_models_regression_test <- function() {
df <- h2o.importFile(locate("smalldata/titanic/titanic_expanded.csv"))
df$name <- NULL
dfs <- h2o.splitFrame(df)

train <- dfs[[1]]
test <- dfs[[2]]

nColOriginal <- -1
aml <- h2o.automl(y="fare", training_frame = train, max_models = 12)

for (algo in ALGOS){
print(algo)
model <- h2o.get_best_model(aml, algo)
contr <- h2o.predict_contributions(model, test, output_format = if (algo == "stackedensemble") "compact" else "original", background_frame = train)
nColOriginal <- max(nColOriginal, ncol(contr))
expect_true(all.equal(h2o.predict(model, test), h2o.sum(contr, axis=1, return_frame = TRUE)))
}

for (algo in ALGOS){
print(algo)
model <- h2o.get_best_model(aml, algo)
contr <- h2o.predict_contributions(model, test[1:3,], output_format = "compact", background_frame = train, output_per_reference = TRUE)
expect_true(nColOriginal >= ncol(contr))
eps <- 1e-5
if (algo %in% c("xgboost"))
eps <- 1e-3

contr0 <- contr[contr$RowIdx == 0,]
contr0 <- contr0[order(as.vector(contr0$BackgroundRowIdx)),]

contr1 <- contr[contr$RowIdx == 1,]
contr1 <- contr1[order(as.vector(contr1$BackgroundRowIdx)),]

expect_true(all.equal(contr0$BiasTerm, contr1$BiasTerm))

expect_true(max(abs(as.vector(h2o.predict(model, train)) - as.vector(contr0$BiasTerm))) < eps)
}
}


aml_models_binomial_test <- function() {
df <- h2o.importFile(locate("smalldata/titanic/titanic_expanded.csv"))
df$name <- NULL
dfs <- h2o.splitFrame(df)

train <- dfs[[1]]
test <- dfs[[2]]

nColCompact <- -1

aml <- h2o.automl(y="survived", training_frame = train, max_models = 12)

for (algo in ALGOS){
print(algo)
model <- h2o.get_best_model(aml, algo)
contr <- h2o.predict_contributions(model, test, output_format = "compact", background_frame = train, output_space = TRUE)
nColCompact <- max(nColCompact, ncol(contr))
expect_true(h2o.all(h2o.abs(h2o.predict(model, test)[, 3] - h2o.sum(contr, axis=1, return_frame = TRUE)) < 1e-5))
}

for (algo in ALGOS){
print(algo)
model <- h2o.get_best_model(aml, algo)
eps <- 1e-5
if (algo %in% c("xgboost"))
eps <- 1e-3

contr <- h2o.predict_contributions(model, test[1:3,], output_format = if (algo == "stackedensemble") "compact" else "original", background_frame = train, output_per_reference = TRUE)
expect_true(nColCompact <= ncol(contr))

contr0 <- contr[contr$RowIdx == 0,]
contr0 <- contr0[order(as.vector(contr0$BackgroundRowIdx)),]

contr1 <- contr[contr$RowIdx == 1,]
contr1 <- contr1[order(as.vector(contr1$BackgroundRowIdx)),]

expect_true(all.equal(contr0$BiasTerm, contr1$BiasTerm))

link <- if (algo %in% c("gbm", "xgboost", "glm", "stackedensemble")) binomial()$linkinv else function(x) x

expect_true(max(abs(as.vector(h2o.predict(model, train)[,3]) - link(as.vector(contr0$BiasTerm)))) < eps)
}
}


doSuite("SHAP Tests", makeSuite(
aml_models_regression_test,
aml_models_binomial_test
))

0 comments on commit d50a58b

Please sign in to comment.