Skip to content

Commit

Permalink
Improved sampling efficiency and conversion of coalitions to strings (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
LHBO authored Dec 19, 2024
1 parent db81ed7 commit f89ead4
Show file tree
Hide file tree
Showing 116 changed files with 977 additions and 580 deletions.
10 changes: 10 additions & 0 deletions R/RcppExports.R
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,16 @@ sample_features_cpp <- function(m, n_features) {
.Call(`_shapr_sample_features_cpp`, m, n_features)
}

#' Sample coalitions and return them as strings
#'
#' Thin R wrapper that forwards its arguments unchanged to the compiled
#' routine `_shapr_sample_features_cpp_str_paired`.
#'
#' @param m Integer. The number of elements to sample from, i.e., the number
#' of features.
#' @param n_features IntegerVector. The number of features to sample for each
#' feature combination.
#' @param paired_shap_sampling Logical. Should both the sampled coalition S
#' and its complement Sbar be returned?
#' @return A vector of strings/characters, i.e., a CharacterVector, where each
#' string is a space-separated list of integers.
#' @keywords internal
sample_features_cpp_str_paired <- function(m, n_features, paired_shap_sampling = TRUE) {
  .Call(
    `_shapr_sample_features_cpp_str_paired`,
    m,
    n_features,
    paired_shap_sampling
  )
}

#' Get imputed data
#'
#' @param index_xtrain Positive integer. Represents a sequence of row indices from \code{xtrain},
Expand Down
12 changes: 7 additions & 5 deletions R/compute_estimates.R
Original file line number Diff line number Diff line change
Expand Up @@ -310,8 +310,10 @@ bootstrap_shapley_inner <- function(X, n_shapley_values, shap_names, internal, d
boot_sd_array <- array(NA, dim = c(n_explain, n_shapley_values + 1, n_boot_samps))

X_keep <- X_org[c(1, .N), .(id_coalition, coalitions, coalition_size, N)]
X_samp <- X_org[-c(1, .N), .(id_coalition, coalitions, coalition_size, N, shapley_weight, sample_freq)]
X_samp[, coalitions_tmp := sapply(coalitions, paste, collapse = " ")]
X_samp <- X_org[
-c(1, .N),
.(id_coalition, coalitions, coalitions_str, coalition_size, N, shapley_weight, sample_freq)
]

n_coalitions_boot <- X_samp[, sum(sample_freq)]

Expand All @@ -331,12 +333,12 @@ bootstrap_shapley_inner <- function(X, n_shapley_values, shap_names, internal, d

X_boot00_paired <- copy(X_boot00[, .(coalitions, boot_id)])
X_boot00_paired[, coalitions := lapply(coalitions, function(x) seq(n_shapley_values)[-x])]
X_boot00_paired[, coalitions_tmp := sapply(coalitions, paste, collapse = " ")]
X_boot00_paired[, coalitions_str := sapply(coalitions, paste, collapse = " ")]

# Extract the paired coalitions from X_samp
X_boot00_paired <- merge(X_boot00_paired,
X_samp[, .(id_coalition, coalition_size, N, shapley_weight, coalitions_tmp)],
by = "coalitions_tmp"
X_samp[, .(id_coalition, coalition_size, N, shapley_weight, coalitions_str)],
by = "coalitions_str"
)
X_boot0 <- rbind(
X_boot00[, .(boot_id, id_coalition, coalitions, coalition_size, N)],
Expand Down
2 changes: 2 additions & 0 deletions R/prepare_next_iteration.R
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ prepare_next_iteration <- function(internal) {
current_n_coalitions <- internal$iter_list[[iter]]$n_sampled_coalitions + 2 # Used instead of n_coalitions to
# deal with forecast special case
current_coal_samples <- internal$iter_list[[iter]]$coal_samples
current_coal_samples_n_unique <- internal$iter_list[[iter]]$coal_samples_n_unique

if (is.null(fixed_n_coalitions_per_iter)) {
proposal_next_n_coalitions <- current_n_coalitions + ceiling(est_remaining_coalitions * n_coal_next_iter_factor)
Expand Down Expand Up @@ -70,6 +71,7 @@ prepare_next_iteration <- function(internal) {


next_iter_list$prev_coal_samples <- current_coal_samples
next_iter_list$prev_coal_samples_n_unique <- current_coal_samples_n_unique
} else {
next_iter_list <- list()
}
Expand Down
12 changes: 9 additions & 3 deletions R/setup.R
Original file line number Diff line number Diff line change
Expand Up @@ -616,7 +616,7 @@ check_and_set_parameters <- function(internal, type) {

# Check the arguments related to asymmetric and causal Shapley
# Check the causal_ordering, which must happen before checking the causal sampling
internal <- check_and_set_causal_ordering(internal)
if (type == "normal") internal <- check_and_set_causal_ordering(internal)
if (!is.null(internal$parameters$confounding)) internal <- check_and_set_confounding(internal)

# Check the causal sampling
Expand Down Expand Up @@ -798,7 +798,7 @@ check_and_set_asymmetric <- function(internal) {
internal$objects$dt_valid_causal_coalitions[-c(1, .N), shapley_weight_norm := shapley_weight / sum(shapley_weight)]

# Convert the coalitions to strings. Needed when sampling the coalitions in `sample_coalition_table()`.
internal$objects$dt_valid_causal_coalitions[, coalitions_tmp := sapply(coalitions, paste, collapse = " ")]
internal$objects$dt_valid_causal_coalitions[, coalitions_str := sapply(coalitions, paste, collapse = " ")]

return(internal)
}
Expand Down Expand Up @@ -1453,6 +1453,11 @@ set_iterative_parameters <- function(internal, prev_iter_list = NULL) {
iterative_args$initial_n_coalitions <- iterative_args$max_n_coalitions
}

# If paired_shap_sampling is TRUE, we need the number of coalitions to be even
if (internal$parameters$paired_shap_sampling) {
iterative_args$initial_n_coalitions <- ceiling(iterative_args$initial_n_coalitions * 0.5) * 2
}

check_iterative_args(iterative_args)

# Translate any null input
Expand Down Expand Up @@ -1642,7 +1647,8 @@ get_iterative_args_default <- function(internal,
5,
internal$parameters$n_features,
(2^internal$parameters$n_features) / 10
)
),
internal$parameters$max_n_coalitions
)
),
fixed_n_coalitions_per_iter = NULL,
Expand Down
Loading

0 comments on commit f89ead4

Please sign in to comment.