Skip to content

Commit

Permalink
Beeswarm plot update (#424)
Browse files Browse the repository at this point in the history
  • Loading branch information
martinju authored Nov 28, 2024
1 parent 5a52e6a commit a7efa3e
Show file tree
Hide file tree
Showing 23 changed files with 490 additions and 321 deletions.
38 changes: 31 additions & 7 deletions R/plot.R
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,10 @@
#' Whether to include the average feature value in a group on the y-axis or not.
#' If `FALSE` (default), then no value is shown for the groups. If `TRUE`, then `shapr` includes the mean of the
#' features in each group.
#' @param ... Currently not used.
#' @param beeswarm_cex Numeric.
#' The cex argument of [ggbeeswarm::geom_beeswarm()], controlling the spacing in the beeswarm plots.
#' @param ... Other arguments passed to underlying functions,
#' like [ggbeeswarm::geom_beeswarm()] for `plot_type = "beeswarm"`.
#'
#' @details See the examples below, or `vignette("understanding_shapr", package = "shapr")` for examples of
#' how you should use the function.
Expand Down Expand Up @@ -105,7 +108,7 @@
#' n_MC_samples = 1e2
#' )
#'
#' if (requireNamespace("ggplot2", quietly = TRUE)) {
#' if (requireNamespace("ggplot2", quietly = TRUE) && requireNamespace("ggbeeswarm", quietly = TRUE)) {
#' # The default plotting option is a bar plot of the Shapley values
#' # We draw bar plots for the first 4 observations
#' plot(x, index_x_explain = 1:4)
Expand All @@ -123,6 +126,12 @@
#' # Or a beeswarm plot summarising the Shapley values and feature values for all features
#' plot(x, plot_type = "beeswarm")
#' plot(x, plot_type = "beeswarm", col = c("red", "black")) # we can change colors
#'
#' # Additional arguments can be passed to ggbeeswarm::geom_beeswarm() using the '...' argument.
#' # For instance, sometimes the beeswarm plots overlap too much.
#' # This can be fixed with the 'corral = "wrap"' argument.
#' # See ?ggbeeswarm::geom_beeswarm for more information.
#' plot(x, plot_type = "beeswarm", corral = "wrap")
#' }
#'
#' # Example of scatter and beeswarm plot with factor variables
Expand Down Expand Up @@ -155,7 +164,7 @@
#' n_MC_samples = 1e2
#' )
#'
#' if (requireNamespace("ggplot2", quietly = TRUE)) {
#' if (requireNamespace("ggplot2", quietly = TRUE) && requireNamespace("ggbeeswarm", quietly = TRUE)) {
#' plot(x, plot_type = "scatter")
#' plot(x, plot_type = "beeswarm")
#' }
Expand All @@ -172,6 +181,7 @@ plot.shapr <- function(x,
scatter_features = NULL,
scatter_hist = TRUE,
include_group_feature_means = FALSE,
beeswarm_cex = 1 / length(index_x_explain)^(1 / 4),
...) {
if (!requireNamespace("ggplot2", quietly = TRUE)) {
stop("ggplot2 is not installed. Please run install.packages('ggplot2')")
Expand All @@ -190,6 +200,7 @@ plot.shapr <- function(x,

if (is.null(index_x_explain)) index_x_explain <- seq(x$internal$parameters$n_explain)
if (is.null(top_k_features)) top_k_features <- x$internal$parameters$n_features + 1
if (length(beeswarm_cex) == 0) beeswarm_cex <- 1 / length(index_x_explain)^(1 / 4) # Update if index_x_explain is

is_groupwise <- x$internal$parameters$is_groupwise

Expand Down Expand Up @@ -229,7 +240,7 @@ plot.shapr <- function(x,

# melting Kshap
shap_names <- x$internal$parameters$shap_names
dt_shap <- round(data.table::copy(x$shapley_values_est), digits = digits)
dt_shap <- signif(data.table::copy(x$shapley_values_est))
dt_shap[, id := .I]
dt_shap_long <- data.table::melt(dt_shap, id.vars = "id", value.name = "phi")
dt_shap_long[, sign := factor(sign(phi), levels = c(1, -1), labels = c("Increases", "Decreases"))]
Expand Down Expand Up @@ -283,7 +294,14 @@ plot.shapr <- function(x,
dt_plot <- dt_plot[id %in% index_x_explain]
gg <- make_scatter_plot(dt_plot, scatter_features, scatter_hist, col, factor_features)
} else if (plot_type == "beeswarm") {
gg <- make_beeswarm_plot(dt_plot, col, index_x_explain, x, factor_features)
gg <- make_beeswarm_plot(dt_plot,
col,
index_x_explain,
x,
factor_features,
beeswarm_cex = beeswarm_cex,
...
)
} else { # if bar or waterfall plot
# Only plot the desired observations
dt_plot <- dt_plot[id %in% index_x_explain]
Expand Down Expand Up @@ -552,7 +570,13 @@ process_factor_data <- function(dt, factor_cols) {
}


make_beeswarm_plot <- function(dt_plot, col, index_x_explain, x, factor_cols) {
make_beeswarm_plot <- function(dt_plot,
col,
index_x_explain,
x,
factor_cols,
beeswarm_cex,
...) {
if (!requireNamespace("ggbeeswarm", quietly = TRUE)) {
stop("geom_beeswarm is not installed. Please run install.packages('ggbeeswarm')")
}
Expand Down Expand Up @@ -604,7 +628,7 @@ make_beeswarm_plot <- function(dt_plot, col, index_x_explain, x, factor_cols) {

gg <- ggplot2::ggplot(dt_plot, ggplot2::aes(x = variable, y = phi, color = feature_value_scaled)) +
ggplot2::geom_hline(yintercept = 0, color = "grey70", linewidth = 0.5) +
ggbeeswarm::geom_beeswarm(priority = "random", cex = 1 / length(index_x_explain)^(1 / 4)) +
ggbeeswarm::geom_beeswarm(priority = "random", cex = beeswarm_cex, ...) +
ggplot2::coord_flip() +
ggplot2::theme_classic() +
ggplot2::theme(panel.grid.major.y = ggplot2::element_line(colour = "grey90", linetype = "dashed")) +
Expand Down
Binary file added explanation.rds
Binary file not shown.
134 changes: 134 additions & 0 deletions inst/scripts/bugfix_beeswarm.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@

# Reproduction script for the overlapping beeswarm plot: fits a random forest
# on a synthetic data set, explains the predictions with shapr, and draws the
# beeswarm plot with and without `corral = "wrap"`.

library(tidyverse)
library(data.table)
library(shapr)
library(ggplot2)
library(ggbeeswarm)

# Fix the RNG state so the train/test split below (and anything else that
# draws from R's RNG) is the same on every run of this script.
set.seed(123)

Data_O <- data.table::fread(file = "https://github.com/user-attachments/files/17777841/synthetic_data.csv")
# Remove rows with missing values
Data_O <- Data_O[complete.cases(Data_O), ]
# Handle extremes of target
Data_O <- Data_O %>% filter(actief_in_inst_2022_SCH > 0.60)
Data_O$actief_in_inst_2022_SCH <- sqrt(Data_O$actief_in_inst_2022_SCH)

# Features: expand the selected columns to a numeric design matrix
check <- as.data.frame(model.matrix(~., data = Data_O[, c(3, 32:36, 38, 55:68)]))
check[] <- lapply(check, as.numeric)
check <- as.matrix(check)
check <- check[, -1] # drop the first column produced by model.matrix()

# Outcome variable
y <- as.numeric(Data_O$actief_in_inst_2022_SCH)

# Split dataset into training (70%) and test (30%) sets.
# floor() makes the (previously implicit) truncation of the sample size explicit.
samp <- sample(nrow(Data_O), floor(0.7 * nrow(Data_O)))

Train1 <- as.data.frame(check[samp, ])
Test1 <- as.data.frame(check[-samp, ])

Y_train <- y[samp]
Y_test <- y[-samp]

# Train Random Forest model
rf.fit <- ranger::ranger(
  Y_train ~ .,
  data = Train1,
  mtry = 14,
  max.depth = 3,
  replace = FALSE,
  min.node.size = 40,
  sample.fraction = 0.8,
  respect.unordered.factors = "order",
  importance = "permutation"
)

# SHAPR: use the mean training response as phi0
p <- mean(Y_train)

progressr::handlers(global = TRUE)
explanation <- shapr::explain(
  rf.fit,
  Test1,
  Train1,
  approach = "gaussian",
  max_n_coalitions = 20,
  iterative_args = list(initial_n_coalitions = 20),
  phi0 = p
)

# ggplot2 and ggbeeswarm are attached above, so no requireNamespace() guard is
# needed. Keeping the calls at top level (not inside a braced block) lets each
# returned plot be auto-printed, instead of only the last one.
plot(explanation, plot_type = "scatter")
plot(explanation, plot_type = "beeswarm")

# Cache the explanation so the plotting code below can be rerun without
# refitting the model.
saveRDS(explanation, "explanation.rds")

explanation <- readRDS("explanation.rds")

# Beeswarm plot with the `corral = "wrap"` argument forwarded through `...`
plot(explanation, plot_type = "beeswarm", corral = "wrap")


























# Side-by-side comparison of four beeswarm plot variants built from the same
# plotting inputs.
# NOTE(review): plot_shapr() and the make_beeswarm_plot_old/_new_cex/_new/_paper3()
# helpers are not defined anywhere in this script -- presumably development
# versions that must be defined/sourced in the session first; confirm before running.
# NOTE(review): tmp_list is assumed to be a list with elements dt_plot, col,
# index_x_explain, x and factor_features (the only elements accessed below).
tmp_list <- plot_shapr(explanation, plot_type = "beeswarm")


# Variant labelled "Old" in the combined figure
gg_old <- make_beeswarm_plot_old(dt_plot = tmp_list$dt_plot,
col = tmp_list$col,
index_x_explain = tmp_list$index_x_explain,
x = tmp_list$x,
factor_cols = tmp_list$factor_features)

# Variant labelled "New_cex" in the combined figure
gg_new_cex <- make_beeswarm_plot_new_cex(dt_plot = tmp_list$dt_plot,
col = tmp_list$col,
index_x_explain = tmp_list$index_x_explain,
x = tmp_list$x,
factor_cols = tmp_list$factor_features)

# Variant labelled "New" in the combined figure; the corral.* arguments are
# spelled out explicitly, with the geom_beeswarm defaults noted per argument
gg_new <- make_beeswarm_plot_new(dt_plot = tmp_list$dt_plot,
col = tmp_list$col,
index_x_explain = tmp_list$index_x_explain,
x = tmp_list$x,
factor_cols = tmp_list$factor_features,
corral.corral = "wrap", # Default. Other options: "none" (default in geom_beeswarm), "gutter", "random", "omit"
corral.method = "swarm", # Default (and default in geom_beeswarm). Other options: "compactswarm", "hex", "square", "center"
corral.priority = "random", # Default . Other options: "ascending" (default in geom_beeswarm), "descending", "density"
corral.width = 0.75, # Default. 0.9 is default in geom_beeswarm
corral.cex = 0.75) # Default. 1 is default in geom_beeswarm

# Variant labelled "Paper3" in the combined figure
gg_paper3 <- make_beeswarm_plot_paper3(dt_plot = tmp_list$dt_plot,
col = tmp_list$col,
index_x_explain = tmp_list$index_x_explain,
x = tmp_list$x,
factor_cols = tmp_list$factor_features)

# Arrange the four variants in a single row, labelled in the order given
ggpubr::ggarrange(gg_old, gg_new_cex, gg_new, gg_paper3, labels = c("Old", "New_cex", "New", "Paper3"), nrow = 1, vjust = 2)
17 changes: 14 additions & 3 deletions man/plot.shapr.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

14 changes: 7 additions & 7 deletions tests/testthat/_snaps/plot/bar-plot-default.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading

0 comments on commit a7efa3e

Please sign in to comment.