Improved sampling efficiency and conversion of coalitions to strings #426

Merged on Dec 19, 2024 (57 commits)
Commits
f94a8ee
Started working on the change from integers to string. A lot of smal…
LHBO Dec 3, 2024
c17ff3c
Change to whitespace separated strings in Rcpp
LHBO Dec 11, 2024
94d5f60
Ensure even number of coalitions when paired_shap_sampling is TRUE
LHBO Dec 11, 2024
f5ea67c
Update next iteration parameters
LHBO Dec 11, 2024
5d74a3a
Rcpp
LHBO Dec 11, 2024
3491954
Updates to main shapley setup which now introduces strings as the def…
LHBO Dec 11, 2024
168aee4
Started to simplify code
LHBO Dec 11, 2024
a13fb96
Removed `coalitions_tmp` as it is no longer needed as `coalitions_str…
LHBO Dec 11, 2024
8f76e8a
Wrong name in `dt_valid_causal_coalitions`.
LHBO Dec 11, 2024
7dde5a4
Change from comma separated to whitespace separated coalitions in the …
LHBO Dec 11, 2024
19a7cca
Rewritten the sampling procedure for asymmetric Shapley value to the …
LHBO Dec 11, 2024
cf57301
Added that timing inserts NA for missing values and do not count thos…
LHBO Dec 11, 2024
0e55361
Fix bug related to
LHBO Dec 13, 2024
fc1d705
Add reference to data.table
LHBO Dec 13, 2024
2dbbd48
Remove causal ordering stuff from forecast.
LHBO Dec 13, 2024
c69f43a
Removed old code
LHBO Dec 13, 2024
38e9b21
lintr + styler
LHBO Dec 16, 2024
884b258
Change max_n_coalitions due to model rank deficiency.
LHBO Dec 16, 2024
db81ed7
Bug in `explain_forecast()` (#425)
LHBO Dec 17, 2024
f97058f
Started working on the change from integers to string. A lot of smal…
LHBO Dec 3, 2024
62fc8f0
Change to whitespace separated strings in Rcpp
LHBO Dec 11, 2024
db6462e
Ensure even number of coalitions when paired_shap_sampling is TRUE
LHBO Dec 11, 2024
d5bc8ed
Update next iteration parameters
LHBO Dec 11, 2024
ed8b1e3
Rcpp
LHBO Dec 11, 2024
8ea705c
Updates to main shapley setup which now introduces strings as the def…
LHBO Dec 11, 2024
6f62690
Started to simplify code
LHBO Dec 11, 2024
16d76ed
Removed `coalitions_tmp` as it is no longer needed as `coalitions_str…
LHBO Dec 11, 2024
4f97041
Wrong name in `dt_valid_causal_coalitions`.
LHBO Dec 11, 2024
c97ba7b
Change from comma separated to whitespace separated coalitions in the …
LHBO Dec 11, 2024
3b96f3e
Rewritten the sampling procedure for asymmetric Shapley value to the …
LHBO Dec 11, 2024
2ee8678
Added that timing inserts NA for missing values and do not count thos…
LHBO Dec 11, 2024
8d1f2b4
Fix bug related to
LHBO Dec 13, 2024
3feb670
Add reference to data.table
LHBO Dec 13, 2024
d063b26
Remove causal ordering stuff from forecast.
LHBO Dec 13, 2024
89a35fa
Removed old code
LHBO Dec 13, 2024
6766c6e
lintr + styler
LHBO Dec 16, 2024
78c48cd
Change max_n_coalitions due to model rank deficiency.
LHBO Dec 16, 2024
42f4e74
Merge branch 'Lars/String_coalitions' of github.com:NorskRegnesentral…
LHBO Dec 17, 2024
a334a4f
Merge branch 'Lars/String_coalitions' of github.com:NorskRegnesentral…
LHBO Dec 17, 2024
b2cb3be
Merge branch 'Lars/String_coalitions' of github.com:NorskRegnesentral…
LHBO Dec 17, 2024
6af3c1e
Delete commented code
LHBO Dec 18, 2024
e0ec963
Updated manuals with the new parameters and manuals for new functions
LHBO Dec 18, 2024
51aff62
Add documentation to sample_coalition_table
LHBO Dec 18, 2024
ae6be2c
Moved documentation
LHBO Dec 18, 2024
472e815
Updated regular-setup test files
LHBO Dec 18, 2024
333d113
regular-output
LHBO Dec 18, 2024
19724a8
Forecast setup
LHBO Dec 18, 2024
8be683c
Asym-caus-output
LHBO Dec 18, 2024
bb7dc17
forecast output
LHBO Dec 18, 2024
90f659e
iterative output
LHBO Dec 18, 2024
7f22f6c
Add comment
LHBO Dec 18, 2024
dbc6bee
regression
LHBO Dec 18, 2024
2a487a9
Added bugfix to forecast such that we remember the sampled coalitions …
LHBO Dec 19, 2024
d1172f7
Added global variables to zzz
LHBO Dec 19, 2024
0a05c4d
update forecast test rds files
LHBO Dec 19, 2024
aa73f75
Updated documentation
LHBO Dec 19, 2024
098c7de
Fix warning with argument
LHBO Dec 19, 2024
10 changes: 10 additions & 0 deletions R/RcppExports.R
@@ -241,6 +241,16 @@ sample_features_cpp <- function(m, n_features) {
.Call(`_shapr_sample_features_cpp`, m, n_features)
}

#' We here return a vector of strings/characters, i.e., a CharacterVector,
#' where each string is a space-separated list of integers.
#' @param m Integer The number of elements to sample from, i.e., the number of features.
#' @param n_features IntegerVector The number of features to sample for each feature combination.
#' @param paired_shap_sampling Logical Should we return both the sampled coalition S and its complement Sbar.
#' @keywords internal
sample_features_cpp_str_paired <- function(m, n_features, paired_shap_sampling = TRUE) {
.Call(`_shapr_sample_features_cpp_str_paired`, m, n_features, paired_shap_sampling)
}
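
The string format described in the documentation above (each coalition as a space-separated list of integers, optionally paired with its complement) can be illustrated in plain R. This is only a sketch of the format, not the Rcpp implementation; the helper names `coalition_to_str()`, `str_to_coalition()`, and `complement_coalition()` are hypothetical and do not exist in shapr.

```r
# Hypothetical helpers for illustration only (not part of shapr).
coalition_to_str <- function(coalition) paste(coalition, collapse = " ")

str_to_coalition <- function(coalition_str) {
  # Split on single spaces and convert back to integers.
  as.integer(strsplit(coalition_str, " ", fixed = TRUE)[[1]])
}

complement_coalition <- function(coalition, m) setdiff(seq_len(m), coalition)

m <- 5L                      # number of features
S <- c(1L, 3L, 4L)           # a sampled coalition
S_str <- coalition_to_str(S)                               # "1 3 4"
Sbar_str <- coalition_to_str(complement_coalition(S, m))   # "2 5"
```

Storing coalitions as strings makes them usable as simple keys for counting duplicates and joining tables, while `str_to_coalition()` shows that the integer representation can be recovered when needed.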

#' Get imputed data
#'
#' @param index_xtrain Positive integer. Represents a sequence of row indices from \code{xtrain},
14 changes: 10 additions & 4 deletions R/check_convergence.R
@@ -13,14 +13,20 @@ check_convergence <- function(internal) {
paired_shap_sampling <- internal$parameters$paired_shap_sampling
n_shapley_values <- internal$parameters$n_shapley_values

n_sampled_coalitions <- internal$iter_list[[iter]]$n_sampled_coalitions
exact <- internal$iter_list[[iter]]$exact

shap_names <- internal$parameters$shap_names
shap_names_with_none <- c("none", shap_names)

dt_shapley_est <- internal$iter_list[[iter]]$dt_shapley_est
dt_shapley_sd <- internal$iter_list[[iter]]$dt_shapley_sd

n_sampled_coalitions <- internal$iter_list[[iter]]$n_coalitions - 2 # Subtract the zero and full predictions
if (!all.equal(names(dt_shapley_est), names(dt_shapley_sd))) {
stop("The column names of the dt_shapley_est and dt_shapley_df are not equal.")
}
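
A note on the name check above: `all.equal()` returns `TRUE` when its arguments match, but a character vector describing the differences when they do not, so negating its result directly fails with an "invalid argument type" error before `stop()` is reached. The R documentation recommends wrapping it in `isTRUE()`. The message also mentions `dt_shapley_df`, which looks like a typo for `dt_shapley_sd`. A minimal sketch of a safer guard, where the function name `check_same_names()` is hypothetical:

```r
# Hypothetical helper for illustration; not part of shapr.
check_same_names <- function(dt_est, dt_sd) {
  # isTRUE() turns the "description of differences" result into FALSE.
  if (!isTRUE(all.equal(names(dt_est), names(dt_sd)))) {
    stop("The column names of dt_shapley_est and dt_shapley_sd are not equal.")
  }
  invisible(TRUE)
}

est <- data.frame(none = 0, x1 = 1)
sd_ok <- data.frame(none = 0, x1 = 0.1)
sd_bad <- data.frame(none = 0, x2 = 0.1)

res_ok <- check_same_names(est, sd_ok)
res_bad <- tryCatch(
  check_same_names(est, sd_bad),
  error = function(e) conditionMessage(e)
)
```

`identical(names(dt_est), names(dt_sd))` would be an equally valid choice here, since the comparison is between two character vectors.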

max_sd <- dt_shapley_sd[, max(.SD, na.rm = TRUE), .SDcols = -1, by = .I]$V1 # Max per prediction
max_sd <- dt_shapley_sd[, max(.SD, na.rm = TRUE), .SDcols = shap_names_with_none, by = .I]$V1 # Max per prediction
max_sd0 <- max_sd * sqrt(n_sampled_coalitions) # Scales UP the sd as it scales at this rate

dt_shapley_est0 <- copy(dt_shapley_est)
@@ -33,8 +39,8 @@ check_convergence <- function(internal) {
} else {
converged_exact <- FALSE
if (!is.null(convergence_tol)) {
dt_shapley_est0[, maxval := max(.SD, na.rm = TRUE), .SDcols = -c(1, 2), by = .I]
dt_shapley_est0[, minval := min(.SD, na.rm = TRUE), .SDcols = -c(1, 2), by = .I]
dt_shapley_est0[, maxval := max(.SD, na.rm = TRUE), .SDcols = shap_names, by = .I]
dt_shapley_est0[, minval := min(.SD, na.rm = TRUE), .SDcols = shap_names, by = .I]
dt_shapley_est0[, max_sd0 := max_sd0]
dt_shapley_est0[, req_samples := (max_sd0 / ((maxval - minval) * convergence_tol))^2]
dt_shapley_est0[, conv_measure := max_sd0 / ((maxval - minval) * sqrt(n_sampled_coalitions))]
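
The arithmetic behind `req_samples` and `conv_measure` can be checked with a small worked example for a single prediction. The numbers below are made up for illustration; only the formulas are taken from the lines above.

```r
# Made-up inputs for one prediction (assumptions for illustration).
n_sampled_coalitions <- 100
convergence_tol <- 0.02
max_sd <- 0.05   # largest standard deviation across the Shapley values
maxval <- 3      # largest estimated Shapley value
minval <- -1     # smallest estimated Shapley value

# Scale the sd up by sqrt(n), as in the diff.
max_sd0 <- max_sd * sqrt(n_sampled_coalitions)                       # 0.5

# Estimated number of samples needed to reach the tolerance.
req_samples <- (max_sd0 / ((maxval - minval) * convergence_tol))^2   # 39.0625

# Current sd relative to the range of the Shapley values.
conv_measure <- max_sd0 / ((maxval - minval) * sqrt(n_sampled_coalitions))  # 0.0125
```

In this example the two quantities agree: `conv_measure` is below `convergence_tol`, and `req_samples` is below the 100 coalitions already sampled.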
14 changes: 8 additions & 6 deletions R/compute_estimates.R
@@ -285,7 +285,7 @@ bootstrap_shapley <- function(internal, dt_vS, n_boot_samps = 100, seed = 123) {
dt_vS_this <- dt_vS[, dt_cols, with = FALSE]
result[[i]] <- bootstrap_shapley_inner(X, n_shapley_values, shap_names, internal, dt_vS_this, n_boot_samps, seed)
}
result <- rbindlist(result, fill = TRUE)
result <- cbind(internal$parameters$output_labels, rbindlist(result, fill = TRUE))
} else {
X <- internal$iter_list[[iter]]$X
n_shapley_values <- internal$parameters$n_shapley_values
@@ -310,8 +310,10 @@ bootstrap_shapley_inner <- function(X, n_shapley_values, shap_names, internal, d
boot_sd_array <- array(NA, dim = c(n_explain, n_shapley_values + 1, n_boot_samps))

X_keep <- X_org[c(1, .N), .(id_coalition, coalitions, coalition_size, N)]
X_samp <- X_org[-c(1, .N), .(id_coalition, coalitions, coalition_size, N, shapley_weight, sample_freq)]
X_samp[, coalitions_tmp := sapply(coalitions, paste, collapse = " ")]
X_samp <- X_org[
-c(1, .N),
.(id_coalition, coalitions, coalitions_str, coalition_size, N, shapley_weight, sample_freq)
]

n_coalitions_boot <- X_samp[, sum(sample_freq)]

@@ -331,12 +333,12 @@ bootstrap_shapley_inner <- function(X, n_shapley_values, shap_names, internal, d

X_boot00_paired <- copy(X_boot00[, .(coalitions, boot_id)])
X_boot00_paired[, coalitions := lapply(coalitions, function(x) seq(n_shapley_values)[-x])]
X_boot00_paired[, coalitions_tmp := sapply(coalitions, paste, collapse = " ")]
X_boot00_paired[, coalitions_str := sapply(coalitions, paste, collapse = " ")]

# Extract the paired coalitions from X_samp
X_boot00_paired <- merge(X_boot00_paired,
X_samp[, .(id_coalition, coalition_size, N, shapley_weight, coalitions_tmp)],
by = "coalitions_tmp"
X_samp[, .(id_coalition, coalition_size, N, shapley_weight, coalitions_str)],
by = "coalitions_str"
)
X_boot0 <- rbind(
X_boot00[, .(boot_id, id_coalition, coalitions, coalition_size, N)],
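
The pairing step in this hunk builds the complement of each bootstrapped coalition, converts it to a string, and uses that string as the join key against `X_samp`. A minimal base-R sketch of the same idea follows; it uses `match()` on plain vectors instead of the `data.table` merge in the diff, and the toy coalitions and ids are invented.

```r
n_shapley_values <- 4

# Toy stand-in for the sampled coalitions and their ids (assumed values).
coalitions <- list(1L, c(2L, 3L, 4L), c(1L, 2L), c(3L, 4L))
id_coalition <- 2:5
coalitions_str <- sapply(coalitions, paste, collapse = " ")

# Complement of each coalition, as in the diff: seq(n)[-x].
paired <- lapply(coalitions, function(x) seq(n_shapley_values)[-x])
paired_str <- sapply(paired, paste, collapse = " ")

# Look up the id of each complement via its string key.
paired_id <- id_coalition[match(paired_str, coalitions_str)]  # 3 2 5 4
```

Because `coalitions_str` is now kept as a column, the lookup no longer needs a temporary string column to be rebuilt on the sampled side.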
5 changes: 4 additions & 1 deletion R/prepare_next_iteration.R
@@ -21,8 +21,10 @@ prepare_next_iteration <- function(internal) {

est_remaining_coalitions <- internal$iter_list[[iter]]$est_remaining_coalitions
n_coal_next_iter_factor <- internal$iter_list[[iter]]$n_coal_next_iter_factor
current_n_coalitions <- internal$iter_list[[iter]]$n_coalitions
current_n_coalitions <- internal$iter_list[[iter]]$n_sampled_coalitions + 2 # Used instead of n_coalitions to
# deal with forecast special case
current_coal_samples <- internal$iter_list[[iter]]$coal_samples
current_coal_samples_n_unique <- internal$iter_list[[iter]]$coal_samples_n_unique

if (is.null(fixed_n_coalitions_per_iter)) {
proposal_next_n_coalitions <- current_n_coalitions + ceiling(est_remaining_coalitions * n_coal_next_iter_factor)
@@ -69,6 +71,7 @@ prepare_next_iteration <- function(internal) {


next_iter_list$prev_coal_samples <- current_coal_samples
next_iter_list$prev_coal_samples_n_unique <- current_coal_samples_n_unique
} else {
next_iter_list <- list()
}
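
The proposal for the next iteration's coalition count can be traced with a short worked example. The input values are invented; the `+ 2` mirrors the diff, where the two non-sampled entries (called the zero and full predictions in `check_convergence.R`) are added back to `n_sampled_coalitions`.

```r
# Invented values for illustration.
n_sampled_coalitions <- 48
est_remaining_coalitions <- 37
n_coal_next_iter_factor <- 0.5

current_n_coalitions <- n_sampled_coalitions + 2          # 50

proposal_next_n_coalitions <- current_n_coalitions +
  ceiling(est_remaining_coalitions * n_coal_next_iter_factor)  # 50 + 19 = 69
```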
9 changes: 4 additions & 5 deletions R/print_iter.R
@@ -79,10 +79,9 @@ print_iter <- function(internal) {
}

if ("shapley" %in% verbose) {
n_explain <- internal$parameters$n_explain

dt_shapley_est <- internal$iter_list[[iter]]$dt_shapley_est[, -1]
dt_shapley_sd <- internal$iter_list[[iter]]$dt_shapley_sd[, -1]
shap_names_with_none <- c("none", internal$parameters$shap_names)
dt_shapley_est <- internal$iter_list[[iter]]$dt_shapley_est[, shap_names_with_none, with = FALSE]
dt_shapley_sd <- internal$iter_list[[iter]]$dt_shapley_sd[, shap_names_with_none, with = FALSE]

# Printing the current Shapley values
matrix1 <- format(round(dt_shapley_est, 3), nsmall = 2, justify = "right")
@@ -99,7 +98,7 @@ print_iter <- function(internal) {
print_dt <- as.data.table(matrix1)
} else {
msg <- paste0(msg, "estimated Shapley values (sd)")
print_dt <- as.data.table(matrix(paste(matrix1, " (", matrix2, ") ", sep = ""), nrow = n_explain))
print_dt <- as.data.table(matrix(paste(matrix1, " (", matrix2, ") ", sep = ""), nrow = nrow(matrix1)))
}

cli::cli_h3(msg)
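
The "estimate (sd)" table is built by pasting two formatted tables element-wise and reshaping the result. The sketch below reproduces that pattern with plain numeric matrices rather than the `data.table` objects used in `print_iter()`, and the values are invented. It shows why taking the row count from the formatted matrix itself is a safe choice: `paste()` drops the dimensions and returns values in column-major order, so `nrow(matrix1)` is what restores the original layout.

```r
# Invented estimates and standard deviations (2 predictions x 2 columns).
est <- matrix(c(1.2345, 2, -0.5, 0.25), nrow = 2,
              dimnames = list(NULL, c("none", "x1")))
sdv <- matrix(c(0, 0, 0.0123, 0.2), nrow = 2,
              dimnames = list(NULL, c("none", "x1")))

matrix1 <- format(round(est, 3), nsmall = 2, justify = "right")
matrix2 <- format(round(sdv, 2), nsmall = 2, justify = "right")

# paste() returns a plain character vector; matrix() restores the shape.
out <- matrix(paste(matrix1, " (", matrix2, ") ", sep = ""), nrow = nrow(matrix1))
colnames(out) <- colnames(matrix1)
```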
54 changes: 29 additions & 25 deletions R/setup.R
@@ -524,7 +524,6 @@ get_extra_parameters <- function(internal, type) {
internal$parameters$n_groups <- length(group)
internal$parameters$group_names <- names(group)
internal$parameters$group <- group
internal$parameters$n_shapley_values <- internal$parameters$n_groups

if (type == "forecast") {
if (internal$parameters$group_lags) {
@@ -543,8 +542,9 @@ get_extra_parameters <- function(internal, type) {
internal$parameters$n_groups <- NULL
internal$parameters$group_names <- NULL
internal$parameters$shap_names <- internal$parameters$feature_names
internal$parameters$n_shapley_values <- internal$parameters$n_features
}
internal$parameters$n_shapley_values <- length(internal$parameters$shap_names)


# Get the number of unique approaches
internal$parameters$n_approaches <- length(internal$parameters$approach)
@@ -616,7 +616,7 @@ check_and_set_parameters <- function(internal, type) {

# Check the arguments related to asymmetric and causal Shapley
# Check the causal_ordering, which must happen before checking the causal sampling
internal <- check_and_set_causal_ordering(internal)
if (type == "normal") internal <- check_and_set_causal_ordering(internal)
if (!is.null(internal$parameters$confounding)) internal <- check_and_set_confounding(internal)

# Check the causal sampling
@@ -798,7 +798,7 @@ check_and_set_asymmetric <- function(internal) {
internal$objects$dt_valid_causal_coalitions[-c(1, .N), shapley_weight_norm := shapley_weight / sum(shapley_weight)]

# Convert the coalitions to strings. Needed when sampling the coalitions in `sample_coalition_table()`.
internal$objects$dt_valid_causal_coalitions[, coalitions_tmp := sapply(coalitions, paste, collapse = " ")]
internal$objects$dt_valid_causal_coalitions[, coalitions_str := sapply(coalitions, paste, collapse = " ")]

return(internal)
}
@@ -898,36 +898,36 @@ adjust_max_n_coalitions <- function(internal) {
}
} else { # group wise
# Set max_n_coalitions to upper bound
if (is.null(max_n_coalitions) || max_n_coalitions > 2^n_groups) {
max_n_coalitions <- 2^n_groups
if (is.null(max_n_coalitions) || max_n_coalitions > 2^n_shapley_values) {
max_n_coalitions <- 2^n_shapley_values
message(
paste0(
"Success with message:\n",
"max_n_coalitions is NULL or larger than or 2^n_groups = ", 2^n_groups, ", \n",
"and is therefore set to 2^n_groups = ", 2^n_groups, ".\n"
"max_n_coalitions is NULL or larger than or 2^n_groups = ", 2^n_shapley_values, ", \n",
"and is therefore set to 2^n_groups = ", 2^n_shapley_values, ".\n"
)
)
}
# Set max_n_coalitions to lower bound
if (isFALSE(is.null(max_n_coalitions)) && max_n_coalitions < min(10, n_groups + 1)) {
if (n_groups <= 3) {
max_n_coalitions <- 2^n_groups
if (isFALSE(is.null(max_n_coalitions)) && max_n_coalitions < min(10, n_shapley_values + 1)) {
if (n_shapley_values <= 3) {
max_n_coalitions <- 2^n_shapley_values
message(
paste0(
"Success with message:\n",
"n_groups is smaller than or equal to 3, meaning there are so few unique coalitions (", 2^n_groups, ") ",
"that we should use all to get reliable results.\n",
"max_n_coalitions is therefore set to 2^n_groups = ", 2^n_groups, ".\n"
"n_groups is smaller than or equal to 3, meaning there are so few unique coalitions (",
2^n_shapley_values, ") that we should use all to get reliable results.\n",
"max_n_coalitions is therefore set to 2^n_groups = ", 2^n_shapley_values, ".\n"
)
)
} else {
max_n_coalitions <- min(10, n_groups + 1)
max_n_coalitions <- min(10, n_shapley_values + 1)
message(
paste0(
"Success with message:\n",
"max_n_coalitions is smaller than max(10, n_groups + 1 = ", n_groups + 1, "),",
"max_n_coalitions is smaller than max(10, n_groups + 1 = ", n_shapley_values + 1, "),",
"which will result in unreliable results.\n",
"It is therefore set to ", max(10, n_groups + 1), ".\n"
"It is therefore set to ", max(10, n_shapley_values + 1), ".\n"
)
)
}
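
The bounds applied in the group-wise branch above can be summarised as a small pure function. This is a sketch of the logic as it appears in the visible lines only (the messages are omitted), and the function name `clamp_max_n_coalitions()` is hypothetical.

```r
# Hypothetical helper mirroring the visible branch; not part of shapr.
clamp_max_n_coalitions <- function(max_n_coalitions, n_shapley_values) {
  upper <- 2^n_shapley_values
  lower <- min(10, n_shapley_values + 1)

  # Upper bound: NULL or too large means "use all coalitions".
  if (is.null(max_n_coalitions) || max_n_coalitions > upper) {
    return(upper)
  }
  # Lower bound: too few coalitions gives unreliable results.
  if (max_n_coalitions < lower) {
    if (n_shapley_values <= 3) {
      return(upper)  # so few coalitions exist that all are used
    }
    return(lower)
  }
  max_n_coalitions
}

clamp_max_n_coalitions(NULL, 4)  # 16
clamp_max_n_coalitions(100, 4)   # 16
clamp_max_n_coalitions(3, 5)     # 6
clamp_max_n_coalitions(2, 3)     # 8
clamp_max_n_coalitions(12, 5)    # 12
```

Expressing both bounds through `n_shapley_values` means the same rule applies whether the Shapley values correspond to features or to groups.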
@@ -943,6 +943,7 @@ check_max_n_coalitions_fc <- function(internal) {
max_n_coalitions <- internal$parameters$max_n_coalitions
n_features <- internal$parameters$n_features
n_groups <- internal$parameters$n_groups
n_shapley_values <- internal$parameters$n_shapley_values

type <- internal$parameters$type

@@ -953,7 +954,7 @@ check_max_n_coalitions_fc <- function(internal) {
xreg <- internal$data$xreg

if (!is_groupwise) {
if (max_n_coalitions <= n_features) {
if (max_n_coalitions <= n_shapley_values) {
stop(paste0(
"`max_n_coalitions` (", max_n_coalitions, ") has to be greater than the number of ",
"components to decompose the forecast onto:\n",
@@ -962,7 +963,7 @@ check_max_n_coalitions_fc <- function(internal) {
))
}
} else {
if (max_n_coalitions <= n_groups) {
if (max_n_coalitions <= n_shapley_values) {
stop(paste0(
"`max_n_coalitions` (", max_n_coalitions, ") has to be greater than the number of ",
"components to decompose the forecast onto:\n",
@@ -1168,18 +1169,15 @@ check_and_set_iterative <- function(internal) {

set_exact <- function(internal) {
max_n_coalitions <- internal$parameters$max_n_coalitions
n_features <- internal$parameters$n_features
n_groups <- internal$parameters$n_groups
is_groupwise <- internal$parameters$is_groupwise
n_shapley_values <- internal$parameters$n_shapley_values
iterative <- internal$parameters$iterative
asymmetric <- internal$parameters$asymmetric
max_n_coalitions_causal <- internal$parameters$max_n_coalitions_causal

if (isFALSE(iterative) &&
(
(isTRUE(asymmetric) && max_n_coalitions == max_n_coalitions_causal) ||
(isFALSE(is_groupwise) && max_n_coalitions == 2^n_features) ||
(isTRUE(is_groupwise) && max_n_coalitions == 2^n_groups)
(max_n_coalitions == 2^n_shapley_values)
)
) {
exact <- TRUE
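
The simplified exactness condition can be written as a standalone predicate. This is a sketch based only on the condition visible in this hunk; the function name `is_exact()` and its argument defaults are assumptions for illustration.

```r
# Hypothetical predicate mirroring the condition in set_exact().
is_exact <- function(max_n_coalitions, n_shapley_values,
                     iterative = FALSE, asymmetric = FALSE,
                     max_n_coalitions_causal = NULL) {
  isFALSE(iterative) && (
    (isTRUE(asymmetric) && max_n_coalitions == max_n_coalitions_causal) ||
      max_n_coalitions == 2^n_shapley_values
  )
}

is_exact(16, 4)                    # TRUE: all 2^4 coalitions are used
is_exact(10, 4)                    # FALSE: only a subset is used
is_exact(16, 4, iterative = TRUE)  # FALSE: the iterative procedure is on
is_exact(6, 4, asymmetric = TRUE, max_n_coalitions_causal = 6)  # TRUE
```

Using `n_shapley_values` removes the need for separate feature-wise and group-wise branches, since it is defined as `length(shap_names)` in both cases.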
@@ -1455,6 +1453,11 @@ set_iterative_parameters <- function(internal, prev_iter_list = NULL) {
iterative_args$initial_n_coalitions <- iterative_args$max_n_coalitions
}

# If paired_shap_sampling is TRUE, we need the number of coalitions to be even
if (internal$parameters$paired_shap_sampling) {
iterative_args$initial_n_coalitions <- ceiling(iterative_args$initial_n_coalitions * 0.5) * 2
}
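
The rounding used above always moves an odd count up to the next even number and leaves even counts unchanged, so every sampled coalition can be accompanied by its complement. A quick check, where the wrapper name `make_even()` is hypothetical:

```r
# Hypothetical wrapper around the expression used in the diff.
make_even <- function(n) ceiling(n * 0.5) * 2

make_even(c(5, 6, 7))  # 6 6 8
```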

check_iterative_args(iterative_args)

# Translate any null input
@@ -1644,7 +1647,8 @@ get_iterative_args_default <- function(internal,
5,
internal$parameters$n_features,
(2^internal$parameters$n_features) / 10
)
),
internal$parameters$max_n_coalitions
)
),
fixed_n_coalitions_per_iter = NULL,