GH-6698: Shapley values support for ensemble models #15734

Merged

46 commits
26045d4  Shap for GLM (tomasfryda, Jun 30, 2023)
2d12105  Add bSHAP for GLM (tomasfryda, Jul 14, 2023)
1558c65  Interventional treeSHAP support (tomasfryda, Jul 31, 2023)
c30281e  GLM related edits (tomasfryda, Aug 8, 2023)
dc07d29  Initial DeepShap (tomasfryda, Aug 22, 2023)
66200b4  DeepSHAP classification (tomasfryda, Aug 22, 2023)
987a1f8  DeepSHAP for tanh and relu with dropout (tomasfryda, Aug 23, 2023)
77da787  DeepSHAP with MaxOut (tomasfryda, Aug 28, 2023)
1ef3872  Add output_space option (tomasfryda, Aug 29, 2023)
2dc9772  Add Stacked Ensemble support (tomasfryda, Aug 30, 2023)
8bd5dae  Fix stacked ensemble with metalearner transform (tomasfryda, Aug 31, 2023)
feb771e  Add automl tests (tomasfryda, Sep 1, 2023)
9bc2425  Add aggregation to get the SHAP not just per reference (tomasfryda, Sep 5, 2023)
a716fb5  Fix tests (tomasfryda, Sep 6, 2023)
e5cbce3  Fix python tests (tomasfryda, Sep 6, 2023)
f115f6b  Fix python tests (tomasfryda, Sep 8, 2023)
942502d  Enable use in R client (tomasfryda, Sep 8, 2023)
6f7c05d  Add R and basic R support (tomasfryda, Sep 12, 2023)
dbe3cbf  Add python explain support (tomasfryda, Sep 12, 2023)
fdc1891  Add explain support and tests (tomasfryda, Sep 14, 2023)
5f55eac  R formatting around equal sign (tomasfryda, Sep 14, 2023)
b825b83  Add java tests (tomasfryda, Sep 15, 2023)
fc04afc  Add minimal documentation (tomasfryda, Sep 15, 2023)
3c10825  Fix different distributions with output_space=True (tomasfryda, Sep 18, 2023)
b98e937  Fix tweedie GLM link in python test (tomasfryda, Sep 19, 2023)
1926964  Check for memory (tomasfryda, Sep 20, 2023)
dafb53e  Fix multi-node issue and improve memory usage (tomasfryda, Sep 27, 2023)
5d5db7c  Incorporate Veronika's suggestions (tomasfryda, Sep 28, 2023)
662c25f  Improve tests (tomasfryda, Sep 29, 2023)
d47b9dd  Make java tests less strict (eps = 1e-8 -> 1e-6) (tomasfryda, Oct 2, 2023)
89f065c  Merge branch 'master' into tomf_GH-6698_-shapley-values-support-for-e… (tomasfryda, Oct 2, 2023)
94d4f6c  Merge branch 'master' into tomf_GH-6698_-shapley-values-support-for-e… (tomasfryda, Oct 3, 2023)
a7fa460  Fix R tests/cran checks (tomasfryda, Oct 3, 2023)
ac0cacc  Merge branch 'master' into tomf_GH-6698_-shapley-values-support-for-e… (tomasfryda, Oct 3, 2023)
6592b0e  Improve parallelization for SE when there is just a small bg set + do… (tomasfryda, Oct 4, 2023)
869bf19  Skip NOPASS tests in the Py3.7 Changed Only stage (tomasfryda, Oct 5, 2023)
0085fb4  Unify error messages when no background frame is provided (tomasfryda, Oct 5, 2023)
c8bcd52  Add tests with pregenerated DeepSHAP contributions using pytorch+shap (tomasfryda, Oct 6, 2023)
7443265  Add caveat about original output format to doc (tomasfryda, Oct 6, 2023)
673cd47  Fix formatting in documentation (tomasfryda, Oct 6, 2023)
112aa22  Remove duplicate 'between' in the doc (tomasfryda, Oct 6, 2023)
1fde594  Merge branch 'master' into tomf_GH-6698_-shapley-values-support-for-e… (tomasfryda, Oct 6, 2023)
b4b7029  ht/added references & readability updates (hannah-tillman, Oct 6, 2023)
fad0f8d  Merge branch 'master' into tomf_GH-6698_-shapley-values-support-for-e… (tomasfryda, Oct 9, 2023)
6380c7d  Make R tests less strict to make tests more stable (tomasfryda, Oct 9, 2023)
218cefa  Fix hex/genmodel/algos/tree/ContributionsPredictorTest.java test by "… (tomasfryda, Oct 9, 2023)
69 changes: 69 additions & 0 deletions h2o-algos/src/main/java/hex/ContributionsMeanAggregator.java
@@ -0,0 +1,69 @@
package hex;

import water.Job;
import water.MRTask;
import water.MemoryManager;
import water.fvec.Chunk;
import water.fvec.NewChunk;

import java.util.stream.Stream;

public class ContributionsMeanAggregator extends MRTask<ContributionsMeanAggregator> {
final int _nBgRows;
double[][] _partialSums;
final int _rowIdxIdx;
final int _nRows;
final int _nCols;
int _startIndex;
final Job _j;

public ContributionsMeanAggregator(Job j, int nRows, int nCols, int nBgRows) {
_j = j;
_nRows = nRows;
_nCols = nCols;
_rowIdxIdx = nCols;
_nBgRows = nBgRows;
_startIndex = 0;
}

public ContributionsMeanAggregator setStartIndex(int startIndex) {
_startIndex = startIndex;
return this;
}

@Override
public void map(Chunk[] cs, NewChunk[] ncs) {
if (isCancelled() || null != _j && _j.stop_requested()) return;
_partialSums = MemoryManager.malloc8d(_nRows, _nCols);
for (int i = 0; i < cs[0]._len; i++) {
final int rowIdx = (int) cs[_rowIdxIdx].at8(i);

for (int j = 0; j < _nCols; j++) {
_partialSums[rowIdx - _startIndex][j] += cs[j].atd(i);
}
}
}

@Override
public void reduce(ContributionsMeanAggregator mrt) {
for (int i = 0; i < _partialSums.length; i++) {
for (int j = 0; j < _partialSums[0].length; j++) {
_partialSums[i][j] += mrt._partialSums[i][j];
}
}
mrt._partialSums = null;
}

@Override
protected void postGlobal() {
NewChunk[] ncs = Stream.of(appendables()).map(vec -> vec.chunkForChunkIdx(0)).toArray(NewChunk[]::new);
for (int i = 0; i < _partialSums.length; i++) {
Contributor:
Can't you just call vec.mean()? I know we may run into problems with categorical columns though.

Contributor Author:

No. I need something like frame.groupby("RowIdx").mean(). The number of rows is nrow(frame)*nrow(background_frame) and the result should have nrow(frame) rows.

for (int j = 0; j < _partialSums[0].length; j++) {
ncs[j].addNum(_partialSums[i][j] / _nBgRows);
}
}
_partialSums = null;
for (NewChunk nc : ncs)
nc.close(0, _fs);
}
}
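
A note on the exchange above: Vec.mean() would collapse an entire column into a single number, whereas this aggregator needs a per-row-index mean over the background rows (the frame.groupby("RowIdx").mean() the author mentions). Below is a minimal, self-contained sketch of that semantics in plain Java; the class and data are illustrative only, not H2O API.

```java
// Illustrative sketch: group the expanded contributions by original row index and average
// over the background rows -- the same semantics ContributionsMeanAggregator implements with MRTask.
import java.util.Arrays;

public class GroupByRowIdxMeanSketch {
    public static void main(String[] args) {
        int nRows = 2;    // rows in the explained frame
        int nBgRows = 3;  // rows in the background frame
        int nCols = 2;    // contribution columns

        // Expanded contributions: one entry per (row, background row) pair,
        // with the original row index carried in the last column (like "RowIdx").
        double[][] expanded = {
                // c0,  c1,  rowIdx
                {1.0, 2.0, 0}, {3.0, 4.0, 0}, {5.0, 6.0, 0},
                {0.5, 1.5, 1}, {2.5, 3.5, 1}, {4.5, 5.5, 1},
        };

        // Accumulate partial sums per row index, then divide by the number of background rows.
        double[][] sums = new double[nRows][nCols];
        for (double[] r : expanded) {
            int rowIdx = (int) r[nCols];
            for (int j = 0; j < nCols; j++) sums[rowIdx][j] += r[j];
        }
        for (double[] row : sums)
            for (int j = 0; j < nCols; j++) row[j] /= nBgRows;

        System.out.println(Arrays.deepToString(sums)); // [[3.0, 4.0], [2.5, 3.5]]
    }
}
```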
253 changes: 253 additions & 0 deletions h2o-algos/src/main/java/hex/ContributionsWithBackgroundFrameTask.java
@@ -0,0 +1,253 @@
package hex;

import water.*;
import water.fvec.*;
import water.util.Log;
import water.util.fp.Function;

import java.util.*;
import java.util.stream.IntStream;

import static water.SplitToChunksApplyCombine.concatFrames;

/***
* Calls map(Chunk[] frame, Chunk[] background, NewChunk[] ncs) by copying the smaller frame across the nodes.
* @param <T>
*/
public abstract class ContributionsWithBackgroundFrameTask<T extends ContributionsWithBackgroundFrameTask<T>> extends MRTask<T> {
transient Frame _frame;
transient Frame _backgroundFrame;
Key<Frame> _frameKey;
Key<Frame> _backgroundFrameKey;

final boolean _aggregate;

boolean _isFrameBigger;

long _startRow;
long _endRow;
Job _job;

public ContributionsWithBackgroundFrameTask(Key<Frame> frKey, Key<Frame> backgroundFrameKey, boolean perReference) {
Contributor:

I don't understand what this class is for?

Contributor Author:

I need to be able to calculate baseline SHAP for each point from frame with each point from background_frame as the reference.

In pseudo-code, this class is approximately responsible for:

result = Frame()
for x in frame:
    for bg in background_frame:
        result.append(baselineSHAP(x, bg))

This yields contributions for all the points against all the references. It's not very useful by itself, but it's required for calculating SHAP for Stacked Ensembles. Another use is explaining an ensemble of models where the component models are proprietary and can't be shared: e.g., one company gives you a fraud score, another a credit score, and you feed these values into your own model; those companies can share the baseline SHAP values with you, and you can then calculate contributions for the input features without knowing the fraud/credit models. (This example is taken from https://www.nature.com/articles/s41467-022-31384-3.)

These baseline shap values are often designated as $\phi_i(model, x^{(e)}, x^{(b)})$ where $x^{(e)}$ is the row that we want to explain and $x^{(b)}$ is the reference.

Note that the commonly used SHAP values are $$\phi_i(model, x^{(e)}) = \frac{1}{|D|}\sum_{x^{(b)} \in D} \phi_i(model, x^{(e)}, x^{(b)})$$ which is called marginal SHAP (interventional SHAP in older papers).

assert null != frKey.get();
assert null != backgroundFrameKey.get();

_frameKey = frKey;
_backgroundFrameKey = backgroundFrameKey;

_frame = frKey.get();
_backgroundFrame = backgroundFrameKey.get();
assert _frame.numRows() > 0 : "Frame has to contain at least one row.";
assert _backgroundFrame.numRows() > 0 : "Background frame has to contain at least one row.";

_isFrameBigger = _frame.numRows() > _backgroundFrame.numRows();
_aggregate = !perReference;
_startRow = -1;
_endRow = -1;
}

protected void loadFrames() {
if (null == _frame)
_frame = _frameKey.get();
if (null == _backgroundFrame)
_backgroundFrame = _backgroundFrameKey.get();
assert _frame != null && _backgroundFrame != null;
}

@Override
public void map(Chunk[] cs, NewChunk[] ncs) {
loadFrames();
Frame smallerFrame = _isFrameBigger ? _backgroundFrame : _frame;
long sfIdx = 0;
long maxSfIdx = smallerFrame.numRows();
if (!_isFrameBigger && _startRow != -1 && _endRow != -1) {
sfIdx = _startRow;
maxSfIdx = _endRow;
}

while (sfIdx < maxSfIdx) {
if (isCancelled() || null != _job && _job.stop_requested()) return;

long finalSfIdx = sfIdx;
Chunk[] sfCs = IntStream
.range(0, smallerFrame.numCols())
.mapToObj(col -> smallerFrame.vec(col).chunkForRow(finalSfIdx))
.toArray(Chunk[]::new);
NewChunk[] ncsSlice = Arrays.copyOf(ncs, ncs.length - 2);
if (_isFrameBigger) {
map(cs, sfCs, ncsSlice);
for (int i = 0; i < cs[0]._len; i++) {
for (int j = 0; j < sfCs[0]._len; j++) {
ncs[ncs.length - 2].addNum(cs[0].start() + i); // row idx
ncs[ncs.length - 1].addNum(sfCs[0].start() + j); // background idx
}
}
} else {
map(sfCs, cs, ncsSlice);
for (int i = 0; i < sfCs[0]._len; i++) {
for (int j = 0; j < cs[0]._len; j++) {
ncs[ncs.length - 2].addNum(sfCs[0].start() + i); // row idx
ncs[ncs.length - 1].addNum(cs[0].start() + j); // background idx
}
}
}
sfIdx += sfCs[0]._len;
}
}

public static double estimateRequiredMemory(int nCols, Frame frame, Frame backgroundFrame) {
return 8 * nCols * frame.numRows() * backgroundFrame.numRows();
}

public static double estimatePerNodeMinimalMemory(int nCols, Frame frame, Frame backgroundFrame){
boolean isFrameBigger = frame.numRows() > backgroundFrame.numRows();
double reqMem = estimateRequiredMemory(nCols, frame, backgroundFrame);
Frame biggerFrame = isFrameBigger ? frame : backgroundFrame;
long[] frESPC = biggerFrame.anyVec().espc();
// Guess the max size of the chunk from the bigger frame as 2 * average chunk
double maxMinChunkSizeInVectorGroup = 2 * 8 * nCols * biggerFrame.numRows() / (double) biggerFrame.anyVec().nChunks();

// Try to compute it exactly
if (null != frESPC) {
long maxFr = 0;
for (int i = 0; i < frESPC.length-1; i++) {
maxFr = Math.max(maxFr, frESPC[i+1]-frESPC[i]);
}
maxMinChunkSizeInVectorGroup = Math.max(maxMinChunkSizeInVectorGroup, 8*nCols*maxFr);
}
long nRowsOfSmallerFrame = isFrameBigger ? backgroundFrame.numRows() : frame.numRows();

// We need the whole smaller frame on each node and one chunk per col of the bigger frame (at minimum)
return Math.max(reqMem / H2O.CLOUD._memary.length, maxMinChunkSizeInVectorGroup + nRowsOfSmallerFrame * nCols * 8);
}

double estimatePerNodeMinimalMemory(int nCols) {
return estimatePerNodeMinimalMemory(nCols, _frame, _backgroundFrame);
}


public static long minMemoryPerNode() {
long minMem = Long.MAX_VALUE;
for (H2ONode h2o : H2O.CLOUD._memary) {
long mem = h2o._heartbeat.get_free_mem(); // in bytes
if (mem < minMem)
minMem = mem;
}
return minMem;
}

public static long totalFreeMemory() {
long mem = 0;
for (H2ONode h2o : H2O.CLOUD._memary) {
mem += h2o._heartbeat.get_free_mem(); // in bytes
}
return mem;
}

public static boolean enoughMinMemory(double estimatedMemory) {
return minMemoryPerNode() > estimatedMemory;
}

abstract protected void map(Chunk[] cs, Chunk[] bgCs, NewChunk[] ncs);


void setChunkRange(int startCIdx, int endCIdx) {
assert !_isFrameBigger;
_startRow = _frame.anyVec().chunkForChunkIdx(startCIdx).start();
_endRow = _frame.anyVec().chunkForChunkIdx(endCIdx).start() + _frame.anyVec().chunkForChunkIdx(endCIdx)._len;
}


// takes care of mapping over the bigger frame
public Frame runAndGetOutput(Job j, Key<Frame> destinationKey, String[] names) {
_job = j;
loadFrames();
double reqMem = estimateRequiredMemory(names.length + 2, _frame, _backgroundFrame);
double reqPerNodeMem = estimatePerNodeMinimalMemory(names.length + 2);

String[] namesWithRowIdx = new String[names.length + 2];
System.arraycopy(names, 0, namesWithRowIdx, 0, names.length);
namesWithRowIdx[names.length] = "RowIdx";
namesWithRowIdx[names.length + 1] = "BackgroundRowIdx";
Key<Frame> individualContributionsKey = _aggregate ? Key.make(destinationKey + "_individual_contribs") : destinationKey;

if (!_aggregate) {
if (!enoughMinMemory(reqPerNodeMem)) {
throw new RuntimeException("Not enough memory. Estimated minimal total memory is " + reqMem + "B. " +
"Estimated minimal per node memory (assuming perfectly balanced datasets) is " + reqPerNodeMem + "B. " +
"Node with minimum memory has " + minMemoryPerNode() + "B. Total available memory is " + totalFreeMemory() + "B."
);
}
Frame indivContribs = withPostMapAction(JobUpdatePostMap.forJob(j))
.doAll(namesWithRowIdx.length, Vec.T_NUM, _isFrameBigger ? _frame : _backgroundFrame)
.outputFrame(individualContributionsKey, namesWithRowIdx, null);

return indivContribs;
} else {
if (!enoughMinMemory(reqPerNodeMem)) {
if (minMemoryPerNode() < 5 * (names.length + 2) * _frame.numRows() * 8) {
throw new RuntimeException("Not enough memory. Estimated minimal total memory is " + reqMem + "B. " +
"Estimated minimal per node memory (assuming perfectly balanced datasets) is " + reqPerNodeMem + "B. " +
"Node with minimum memory has " + minMemoryPerNode() + "B. Total available memory is " + totalFreeMemory() + "B."
);
}
// Split the _frame into subsections, calculate baselines (expand the frame) and then the average (reduce the frame size)
int nChunks = _frame.anyVec().nChunks();
// last iteration we need memory for ~whole aggregated frame + expanded subframe
int nSubFrames = (int) Math.ceil(2*reqMem / (minMemoryPerNode() - 8 * _frame.numRows() * (names.length)));
nSubFrames = nChunks;
int chunksPerIter = (int) Math.max(1, Math.floor(nChunks / nSubFrames));

Log.warn("Not enough memory to calculate SHAP at once. Calculating in " + (nSubFrames) + " iterations.");
_isFrameBigger = false; // ensure we map over the BG frame so we can average over the results properly;
Frame result = null;
List<Frame> subFrames = new LinkedList<Frame>();
try {
for (int i = 0; i < nSubFrames; i++) {
setChunkRange(i * chunksPerIter, Math.min(nChunks - 1, (i + 1) * chunksPerIter - 1));
Frame indivContribs = clone().withPostMapAction(JobUpdatePostMap.forJob(j))
.doAll(namesWithRowIdx.length, Vec.T_NUM, _backgroundFrame)
.outputFrame(Key.make(destinationKey + "_individual_contribs_" + i), namesWithRowIdx, null);

subFrames.add(new ContributionsMeanAggregator(_job,(int) (_endRow - _startRow), names.length, (int) _backgroundFrame.numRows())
.setStartIndex((int) _startRow)
.withPostMapAction(JobUpdatePostMap.forJob(j))
.doAll(names.length, Vec.T_NUM, indivContribs)
.outputFrame(Key.make(destinationKey + "_part_" + i), names, null));
indivContribs.delete();
}

result = concatFrames(subFrames, destinationKey);
Set<String> homes = new HashSet<>();
for (int i = 0; i < result.anyVec().nChunks(); i++) {
for (int k = 0; k < result.numCols(); k++) {
homes.add(result.vec(k).chunkKey(i).home_node().getIpPortString());
}
}
return result;
} finally {
if (null != result) {
for (Frame fr : subFrames) {
Frame.deleteTempFrameAndItsNonSharedVecs(fr, result);
}
} else {
for (Frame fr : subFrames)
fr.delete();
}
}
} else {
Frame indivContribs = withPostMapAction(JobUpdatePostMap.forJob(j))
.doAll(namesWithRowIdx.length, Vec.T_NUM, _isFrameBigger ? _frame : _backgroundFrame)
.outputFrame(individualContributionsKey, namesWithRowIdx, null);
try {
return new ContributionsMeanAggregator(_job, (int) _frame.numRows(), names.length, (int) _backgroundFrame.numRows())
.withPostMapAction(JobUpdatePostMap.forJob(j))
.doAll(names.length, Vec.T_NUM, indivContribs)
.outputFrame(destinationKey, names, null);
} finally {
indivContribs.delete(true);
}
}
}
}
}
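
To make the review discussion above concrete: the task materializes baseline SHAP values $\phi_i(model, x^{(e)}, x^{(b)})$ for every (row, reference) pair, and the aggregation step averages them over the background set to obtain marginal SHAP. The following is a hedged, self-contained sketch of that averaging on a toy additive model; baselineShap, marginalShap, and the data are illustrative stand-ins, not part of the PR.

```java
// Illustrative sketch of marginal SHAP as the mean of baseline SHAP over a background set.
import java.util.Arrays;

public class MarginalShapSketch {
    // Per-reference (baseline) contributions of one explained row `x` against one reference `bg`.
    // Placeholder model: for an additive f(x) = sum_i x_i, the exact baseline SHAP of feature i is x_i - bg_i.
    static double[] baselineShap(double[] x, double[] bg) {
        double[] phi = new double[x.length];
        for (int i = 0; i < x.length; i++) phi[i] = x[i] - bg[i];
        return phi;
    }

    // Marginal (interventional) SHAP: average the per-reference contributions over all background rows.
    static double[] marginalShap(double[] x, double[][] background) {
        double[] phi = new double[x.length];
        for (double[] bg : background) {
            double[] b = baselineShap(x, bg);
            for (int i = 0; i < x.length; i++) phi[i] += b[i] / background.length;
        }
        return phi;
    }

    public static void main(String[] args) {
        double[] x = {2.0, -1.0};
        double[][] background = {{0.0, 0.0}, {1.0, 1.0}};
        System.out.println(Arrays.toString(marginalShap(x, background))); // [1.5, -1.5]
    }
}
```

For a rough sense of why the memory guard in the class above exists: the expanded per-reference frame scales as 8 bytes × columns × rows × background rows, so, for example, 10 contribution columns, 10,000 rows, and 100 background rows already take roughly 80 MB before aggregation.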
31 changes: 24 additions & 7 deletions h2o-algos/src/main/java/hex/DataInfo.java
@@ -7,6 +7,7 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;

import static water.util.ArrayUtils.findLongestCommonPrefix;

@@ -823,12 +824,11 @@ public final String[] coefNames() {
return res;
}

public final int[] coefOriginalColumnIndices() {
if (_coefOriginalIndices != null) return _coefOriginalIndices; // already computed
public final int[] coefOriginalColumnIndices(Frame adaptedFrame) {
int k = 0;
final int n = fullN(); // total number of columns to compute
int[] res = new int[n];
final Vec [] vecs = _adaptedFrame.vecs();
final Vec [] vecs = adaptedFrame.vecs();

// first do all of the expanded categorical names
for(int i = 0; i < _cats; ++i) {
@@ -868,19 +868,32 @@ public final int[] coefOriginalColumnIndices() {
res[k++] = i+_cats;
}
}
_coefOriginalIndices = res;
if (null != _adaptedFrame && Objects.equals(_adaptedFrame._key, adaptedFrame._key))
_coefOriginalIndices = res;
return res;
}

public final int[] coefOriginalColumnIndices() {
if (_coefOriginalIndices != null) return _coefOriginalIndices; // already computed
return coefOriginalColumnIndices(_adaptedFrame);
}

public final String[] coefOriginalNames() {
int[] coefOriginalIndices = coefOriginalColumnIndices();
String[] originalNames = new String[coefOriginalIndices[coefOriginalIndices.length - 1]];
public final String[] coefOriginalNames(Frame adaptedFrame) {
int[] coefOriginalIndices = coefOriginalColumnIndices(adaptedFrame);
String[] originalNames = new String[coefOriginalIndices[coefOriginalIndices.length - 1] + 1]; // needs +1 since indexing is 0-based, so if the max index is N we need N+1 elements
int i = 0, j = 0;
while (i < coefOriginalIndices.length && j < originalNames.length) {
List<Integer> coefOriginalIndicesList = new ArrayList<>(coefOriginalIndices.length);
for (int value : coefOriginalIndices) coefOriginalIndicesList.add(value);
int end = coefOriginalIndicesList.lastIndexOf(coefOriginalIndices[i]);
String prefix = findLongestCommonPrefix(Arrays.copyOfRange(coefNames(), i, end + 1));
if (end > i) { // categorical variable
// Let's hope the levels in this categorical variable don't share a common prefix containing '.'.
// We encode cat. vars as "variable_name.level", so the prefix should end with ".". Make sure that's
// the case, otherwise this can break on categorical variables like "pclass" in the titanic dataset,
// where every level starts with "Class ", which would lead to "pclass.Class " as the original name.
prefix = prefix.substring(0, prefix.lastIndexOf("."));
}
if (".".equals(prefix.substring(prefix.length() - 1))) {
prefix = prefix.substring(0, prefix.length() - 1);
}
@@ -890,6 +903,10 @@ public final String[] coefOriginalNames() {
}
return originalNames;
}

public final String[] coefOriginalNames() {
return coefOriginalNames(_adaptedFrame);
}

// Return permutation matrix mapping input names to adaptedFrame colnames
public int[] mapNames(String[] names) {
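
The comments in the hunk above describe how original column names are recovered from expanded coefficient names ("variable_name.level"). A small standalone sketch of that prefix-trimming idea follows; the helper plays the role water.util.ArrayUtils.findLongestCommonPrefix plays in the diff, and the class and example names are illustrative only.

```java
// Illustrative sketch: recover the original column name for a run of expanded categorical
// coefficients by taking their longest common prefix and trimming it back to the last '.'.
// The titanic-style "pclass" example shows why trimming matters: every level starts with
// "Class ", so the raw common prefix would otherwise be "pclass.Class ".
import java.util.ArrayList;
import java.util.List;

public class CoefOriginalNamesSketch {
    static String longestCommonPrefix(String[] names) {
        String prefix = names[0];
        for (String n : names)
            while (!n.startsWith(prefix)) prefix = prefix.substring(0, prefix.length() - 1);
        return prefix;
    }

    public static void main(String[] args) {
        // Expanded coefficient names for one categorical column ("pclass").
        String[] pclassCoefs = {"pclass.Class 1", "pclass.Class 2", "pclass.Class 3"};

        String prefix = longestCommonPrefix(pclassCoefs);               // "pclass.Class "
        String original = prefix.substring(0, prefix.lastIndexOf('.')); // "pclass"

        List<String> originalNames = new ArrayList<>();
        originalNames.add(original);
        originalNames.add("age"); // numeric columns keep their name as-is
        System.out.println(originalNames); // [pclass, age]
    }
}
```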