GH-6698: Shapley values support for ensemble models #15734
hex/ContributionsMeanAggregator.java (new file)
@@ -0,0 +1,69 @@
package hex;

import water.Job;
import water.MRTask;
import water.MemoryManager;
import water.fvec.Chunk;
import water.fvec.NewChunk;

import java.util.stream.Stream;

/**
 * Averages per-reference contributions over the background frame: partial sums are keyed by the
 * RowIdx column and divided by the number of background rows in postGlobal().
 */
public class ContributionsMeanAggregator extends MRTask<ContributionsMeanAggregator> {
  final int _nBgRows;
  double[][] _partialSums;
  final int _rowIdxIdx;
  final int _nRows;
  final int _nCols;
  int _startIndex;
  final Job _j;

  public ContributionsMeanAggregator(Job j, int nRows, int nCols, int nBgRows) {
    _j = j;
    _nRows = nRows;
    _nCols = nCols;
    _rowIdxIdx = nCols;
    _nBgRows = nBgRows;
    _startIndex = 0;
  }

  public ContributionsMeanAggregator setStartIndex(int startIndex) {
    _startIndex = startIndex;
    return this;
  }

  @Override
  public void map(Chunk[] cs, NewChunk[] ncs) {
    if (isCancelled() || null != _j && _j.stop_requested()) return;
    _partialSums = MemoryManager.malloc8d(_nRows, _nCols);
    for (int i = 0; i < cs[0]._len; i++) {
      final int rowIdx = (int) cs[_rowIdxIdx].at8(i);

      for (int j = 0; j < _nCols; j++) {
        _partialSums[rowIdx - _startIndex][j] += cs[j].atd(i);
      }
    }
  }

  @Override
  public void reduce(ContributionsMeanAggregator mrt) {
    for (int i = 0; i < _partialSums.length; i++) {
      for (int j = 0; j < _partialSums[0].length; j++) {
        _partialSums[i][j] += mrt._partialSums[i][j];
      }
    }
    mrt._partialSums = null;
  }

  @Override
  protected void postGlobal() {
    NewChunk[] ncs = Stream.of(appendables()).map(vec -> vec.chunkForChunkIdx(0)).toArray(NewChunk[]::new);
    for (int i = 0; i < _partialSums.length; i++) {
      for (int j = 0; j < _partialSums[0].length; j++) {
        ncs[j].addNum(_partialSums[i][j] / _nBgRows);
      }
    }
    _partialSums = null;
    for (NewChunk nc : ncs)
      nc.close(0, _fs);
  }
}
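For readers who prefer a non-distributed picture, here is a minimal single-node sketch (not part of the PR; the class and variable names are made up) of what this MRTask computes: sum the per-reference contribution rows by their row index, then divide by the number of background rows.

final class MeanByRowIdxSketch {
  /** Equivalent of frame.groupby("RowIdx").mean() for a dense array of per-reference rows. */
  static double[][] meanContributions(double[][] perReference, int[] rowIdx,
                                      int nRows, int nCols, int nBgRows) {
    double[][] sums = new double[nRows][nCols];
    for (int r = 0; r < perReference.length; r++) { // one entry per (row, background row) pair
      for (int c = 0; c < nCols; c++) {
        sums[rowIdx[r]][c] += perReference[r][c];   // accumulate per original row
      }
    }
    for (int i = 0; i < nRows; i++) {
      for (int c = 0; c < nCols; c++) {
        sums[i][c] /= nBgRows;                      // average over the background frame
      }
    }
    return sums;
  }
}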
hex/ContributionsWithBackgroundFrameTask.java (new file)
@@ -0,0 +1,253 @@
package hex;

import water.*;
import water.fvec.*;
import water.util.Log;
import water.util.fp.Function;

import java.util.*;
import java.util.stream.IntStream;

import static water.SplitToChunksApplyCombine.concatFrames;

/**
 * Calls map(Chunk[] frame, Chunk[] background, NewChunk[] ncs) by copying the smaller of the two
 * frames across the nodes, so that every chunk of the bigger frame is paired with every chunk of
 * the smaller one.
 * @param <T> concrete task type
 */
public abstract class ContributionsWithBackgroundFrameTask<T extends ContributionsWithBackgroundFrameTask<T>> extends MRTask<T> {
  transient Frame _frame;
  transient Frame _backgroundFrame;
  Key<Frame> _frameKey;
  Key<Frame> _backgroundFrameKey;

  final boolean _aggregate;

  boolean _isFrameBigger;

  long _startRow;
  long _endRow;
  Job _job;

  public ContributionsWithBackgroundFrameTask(Key<Frame> frKey, Key<Frame> backgroundFrameKey, boolean perReference) {
    assert null != frKey.get();
    assert null != backgroundFrameKey.get();

    _frameKey = frKey;
    _backgroundFrameKey = backgroundFrameKey;

    _frame = frKey.get();
    _backgroundFrame = backgroundFrameKey.get();
    assert _frame.numRows() > 0 : "Frame has to contain at least one row.";
    assert _backgroundFrame.numRows() > 0 : "Background frame has to contain at least one row.";

    _isFrameBigger = _frame.numRows() > _backgroundFrame.numRows();
    _aggregate = !perReference;
    _startRow = -1;
    _endRow = -1;
  }

[Inline review discussion on this class]

Reviewer: I don't understand what this class is for?

Author: I need to be able to calculate baseline SHAP for each point of the frame against each point (reference) of the background frame. In pseudo-code, this class is approximately responsible for:

    result = Frame()
    for x in frame:
        for bg in background_frame:
            result.append(baselineSHAP(x, bg))

This yields contributions for all the points against all the references. It's not very useful by itself, but it's required for calculating the SHAP values for Stacked Ensembles. Another use is explaining an ensemble of models where the individual models are proprietary and can't be shared: from one company you get a fraud score, from another a credit score, and you feed these values into your own model; those companies can share the baseline SHAP values with you, and you can then calculate contributions for the input features without knowing the fraud/credit models. (This example is taken from https://www.nature.com/articles/s41467-022-31384-3.) The commonly used SHAP values are the averages of these baseline (per-reference) SHAP values over the background frame.
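Stated as a formula (a restatement of the comment above, not text from the PR): with B the background frame and phi_j(x, b) the baseline SHAP value of feature j for row x against reference b, the SHAP values reported to the user are the per-row averages

    phi_j(x) = (1 / |B|) * sum_{b in B} phi_j(x, b)

which is exactly the mean that ContributionsMeanAggregator computes by grouping on RowIdx and dividing by the number of background rows.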
  protected void loadFrames() {
    if (null == _frame)
      _frame = _frameKey.get();
    if (null == _backgroundFrame)
      _backgroundFrame = _backgroundFrameKey.get();
    assert _frame != null && _backgroundFrame != null;
  }

  @Override
  public void map(Chunk[] cs, NewChunk[] ncs) {
    loadFrames();
    // For every chunk of the frame this task maps over, walk the other frame (or just its
    // [_startRow, _endRow) slice) chunk by chunk and delegate to the subclass's
    // map(frame, background, ncs); the last two output columns record the
    // (row, background row) pair for each output row.
    Frame smallerFrame = _isFrameBigger ? _backgroundFrame : _frame;
    long sfIdx = 0;
    long maxSfIdx = smallerFrame.numRows();
    if (!_isFrameBigger && _startRow != -1 && _endRow != -1) {
      sfIdx = _startRow;
      maxSfIdx = _endRow;
    }

    while (sfIdx < maxSfIdx) {
      if (isCancelled() || null != _job && _job.stop_requested()) return;

      long finalSfIdx = sfIdx;
      Chunk[] sfCs = IntStream
              .range(0, smallerFrame.numCols())
              .mapToObj(col -> smallerFrame.vec(col).chunkForRow(finalSfIdx))
              .toArray(Chunk[]::new);
      NewChunk[] ncsSlice = Arrays.copyOf(ncs, ncs.length - 2);
      if (_isFrameBigger) {
        map(cs, sfCs, ncsSlice);
        for (int i = 0; i < cs[0]._len; i++) {
          for (int j = 0; j < sfCs[0]._len; j++) {
            ncs[ncs.length - 2].addNum(cs[0].start() + i);   // row idx
            ncs[ncs.length - 1].addNum(sfCs[0].start() + j); // background idx
          }
        }
      } else {
        map(sfCs, cs, ncsSlice);
        for (int i = 0; i < sfCs[0]._len; i++) {
          for (int j = 0; j < cs[0]._len; j++) {
            ncs[ncs.length - 2].addNum(sfCs[0].start() + i); // row idx
            ncs[ncs.length - 1].addNum(cs[0].start() + j);   // background idx
          }
        }
      }
      sfIdx += sfCs[0]._len;
    }
  }
  public static double estimateRequiredMemory(int nCols, Frame frame, Frame backgroundFrame) {
    return 8 * nCols * frame.numRows() * backgroundFrame.numRows();
  }

  public static double estimatePerNodeMinimalMemory(int nCols, Frame frame, Frame backgroundFrame) {
    boolean isFrameBigger = frame.numRows() > backgroundFrame.numRows();
    double reqMem = estimateRequiredMemory(nCols, frame, backgroundFrame);
    Frame biggerFrame = isFrameBigger ? frame : backgroundFrame;
    long[] frESPC = biggerFrame.anyVec().espc();
    // Guess the max size of the chunk from the bigger frame as 2 * average chunk
    double maxMinChunkSizeInVectorGroup = 2 * 8 * nCols * biggerFrame.numRows() / (double) biggerFrame.anyVec().nChunks();

    // Try to compute it exactly
    if (null != frESPC) {
      long maxFr = 0;
      for (int i = 0; i < frESPC.length - 1; i++) {
        maxFr = Math.max(maxFr, frESPC[i + 1] - frESPC[i]);
      }
      maxMinChunkSizeInVectorGroup = Math.max(maxMinChunkSizeInVectorGroup, 8 * nCols * maxFr);
    }
    long nRowsOfSmallerFrame = isFrameBigger ? backgroundFrame.numRows() : frame.numRows();

    // We need the whole smaller frame on each node and one chunk per col of the bigger frame (at minimum)
    return Math.max(reqMem / H2O.CLOUD._memary.length, maxMinChunkSizeInVectorGroup + nRowsOfSmallerFrame * nCols * 8);
  }
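  // Worked example with illustrative (made-up) numbers, not from the PR: for nCols = 22
  // output columns (20 contributions + RowIdx + BackgroundRowIdx), a 10,000-row frame and
  // a 100-row background frame, estimateRequiredMemory() gives
  //   8 * 22 * 10,000 * 100 = 176,000,000 bytes (~176 MB)
  // for the fully expanded per-reference result. estimatePerNodeMinimalMemory() then takes
  // the larger of (that total / number of nodes) and (one chunk of the bigger frame plus a
  // full copy of the smaller frame), which is what the memory checks below compare against
  // the free memory reported by each node.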
  double estimatePerNodeMinimalMemory(int nCols) {
    return estimatePerNodeMinimalMemory(nCols, _frame, _backgroundFrame);
  }

  public static long minMemoryPerNode() {
    long minMem = Long.MAX_VALUE;
    for (H2ONode h2o : H2O.CLOUD._memary) {
      long mem = h2o._heartbeat.get_free_mem(); // in bytes
      if (mem < minMem)
        minMem = mem;
    }
    return minMem;
  }

  public static long totalFreeMemory() {
    long mem = 0;
    for (H2ONode h2o : H2O.CLOUD._memary) {
      mem += h2o._heartbeat.get_free_mem(); // in bytes
    }
    return mem;
  }

  public static boolean enoughMinMemory(double estimatedMemory) {
    return minMemoryPerNode() > estimatedMemory;
  }

  abstract protected void map(Chunk[] cs, Chunk[] bgCs, NewChunk[] ncs);

  void setChunkRange(int startCIdx, int endCIdx) {
    assert !_isFrameBigger;
    _startRow = _frame.anyVec().chunkForChunkIdx(startCIdx).start();
    _endRow = _frame.anyVec().chunkForChunkIdx(endCIdx).start() + _frame.anyVec().chunkForChunkIdx(endCIdx)._len;
  }
  // takes care of mapping over the bigger frame
  public Frame runAndGetOutput(Job j, Key<Frame> destinationKey, String[] names) {
    _job = j;
    loadFrames();
    double reqMem = estimateRequiredMemory(names.length + 2, _frame, _backgroundFrame);
    double reqPerNodeMem = estimatePerNodeMinimalMemory(names.length + 2);

    String[] namesWithRowIdx = new String[names.length + 2];
    System.arraycopy(names, 0, namesWithRowIdx, 0, names.length);
    namesWithRowIdx[names.length] = "RowIdx";
    namesWithRowIdx[names.length + 1] = "BackgroundRowIdx";
    Key<Frame> individualContributionsKey = _aggregate ? Key.make(destinationKey + "_individual_contribs") : destinationKey;

    if (!_aggregate) {
      if (!enoughMinMemory(reqPerNodeMem)) {
        throw new RuntimeException("Not enough memory. Estimated minimal total memory is " + reqMem + "B. " +
                "Estimated minimal per node memory (assuming perfectly balanced datasets) is " + reqPerNodeMem + "B. " +
                "Node with minimum memory has " + minMemoryPerNode() + "B. Total available memory is " + totalFreeMemory() + "B."
        );
      }
      Frame indivContribs = withPostMapAction(JobUpdatePostMap.forJob(j))
              .doAll(namesWithRowIdx.length, Vec.T_NUM, _isFrameBigger ? _frame : _backgroundFrame)
              .outputFrame(individualContributionsKey, namesWithRowIdx, null);

      return indivContribs;
    } else {
      if (!enoughMinMemory(reqPerNodeMem)) {
        if (minMemoryPerNode() < 5 * (names.length + 2) * _frame.numRows() * 8) {
          throw new RuntimeException("Not enough memory. Estimated minimal total memory is " + reqMem + "B. " +
                  "Estimated minimal per node memory (assuming perfectly balanced datasets) is " + reqPerNodeMem + "B. " +
                  "Node with minimum memory has " + minMemoryPerNode() + "B. Total available memory is " + totalFreeMemory() + "B."
          );
        }
        // Split the _frame into subsections, calculate the baselines (expand the frame), and then the average (reduce the frame size)
        int nChunks = _frame.anyVec().nChunks();
        // In the last iteration we need memory for ~the whole aggregated frame plus the expanded subframe
        int nSubFrames = (int) Math.ceil(2 * reqMem / (minMemoryPerNode() - 8 * _frame.numRows() * (names.length)));
        nSubFrames = nChunks;
        int chunksPerIter = (int) Math.max(1, Math.floor(nChunks / nSubFrames));

        Log.warn("Not enough memory to calculate SHAP at once. Calculating in " + (nSubFrames) + " iterations.");
        _isFrameBigger = false; // ensure we map over the BG frame so we can average over the results properly
        Frame result = null;
        List<Frame> subFrames = new LinkedList<Frame>();
        try {
          for (int i = 0; i < nSubFrames; i++) {
            setChunkRange(i * chunksPerIter, Math.min(nChunks - 1, (i + 1) * chunksPerIter - 1));
            Frame indivContribs = clone().withPostMapAction(JobUpdatePostMap.forJob(j))
                    .doAll(namesWithRowIdx.length, Vec.T_NUM, _backgroundFrame)
                    .outputFrame(Key.make(destinationKey + "_individual_contribs_" + i), namesWithRowIdx, null);

            subFrames.add(new ContributionsMeanAggregator(_job, (int) (_endRow - _startRow), names.length, (int) _backgroundFrame.numRows())
                    .setStartIndex((int) _startRow)
                    .withPostMapAction(JobUpdatePostMap.forJob(j))
                    .doAll(names.length, Vec.T_NUM, indivContribs)
                    .outputFrame(Key.make(destinationKey + "_part_" + i), names, null));
            indivContribs.delete();
          }

          result = concatFrames(subFrames, destinationKey);
          Set<String> homes = new HashSet<>();
          for (int i = 0; i < result.anyVec().nChunks(); i++) {
            for (int k = 0; k < result.numCols(); k++) {
              homes.add(result.vec(k).chunkKey(i).home_node().getIpPortString());
            }
          }
          return result;
        } finally {
          if (null != result) {
            for (Frame fr : subFrames) {
              Frame.deleteTempFrameAndItsNonSharedVecs(fr, result);
            }
          } else {
            for (Frame fr : subFrames)
              fr.delete();
          }
        }
      } else {
        Frame indivContribs = withPostMapAction(JobUpdatePostMap.forJob(j))
                .doAll(namesWithRowIdx.length, Vec.T_NUM, _isFrameBigger ? _frame : _backgroundFrame)
                .outputFrame(individualContributionsKey, namesWithRowIdx, null);
        try {
          return new ContributionsMeanAggregator(_job, (int) _frame.numRows(), names.length, (int) _backgroundFrame.numRows())
                  .withPostMapAction(JobUpdatePostMap.forJob(j))
                  .doAll(names.length, Vec.T_NUM, indivContribs)
                  .outputFrame(destinationKey, names, null);
        } finally {
          indivContribs.delete(true);
        }
      }
    }
  }
}
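To make the contract of the abstract map(...) concrete, here is a minimal sketch of a subclass. It is not part of the PR: the class name DifferenceContributionsTask, its toy "contribution", and the call-site identifiers (job, frameKey, backgroundKey) are assumptions; only the constructor, the abstract map signature, and runAndGetOutput come from the code above.

package hex;

import water.Key;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;

// Hypothetical subclass: the "contribution" is just the difference of the first column values
// of the row and the reference, standing in for a real baseline SHAP computation.
public class DifferenceContributionsTask extends ContributionsWithBackgroundFrameTask<DifferenceContributionsTask> {

  public DifferenceContributionsTask(Key<Frame> frKey, Key<Frame> bgKey, boolean perReference) {
    super(frKey, bgKey, perReference);
  }

  @Override
  protected void map(Chunk[] cs, Chunk[] bgCs, NewChunk[] ncs) {
    // Contract expected by the parent task: for every frame row i (outer) and background
    // row j (inner), append exactly one value to each output column.
    for (int i = 0; i < cs[0]._len; i++) {
      for (int j = 0; j < bgCs[0]._len; j++) {
        double contribution = cs[0].atd(i) - bgCs[0].atd(j);
        for (int col = 0; col < ncs.length; col++) {
          ncs[col].addNum(contribution);
        }
      }
    }
  }
}

// Hypothetical call site (job, frameKey, and backgroundKey are assumed to already exist):
//   Frame perReference = new DifferenceContributionsTask(frameKey, backgroundKey, true)
//       .runAndGetOutput(job, Key.make("per_reference_contribs"), new String[]{"Contribution"});
//   // perReference has nrow(frame) * nrow(background) rows plus RowIdx/BackgroundRowIdx columns.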
[Inline review discussion on the aggregation]

Reviewer: Can't you just call vec.mean()? I know we may run into problems with categorical columns though.

Author: No. I need something like frame.groupby("RowIdx").mean(). The number of rows is nrow(frame) * nrow(background_frame) and the result should have nrow(frame) rows.
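For concreteness (the numbers are illustrative, not from the PR): with nrow(frame) = 1,000 and nrow(background_frame) = 100, the per-reference output of ContributionsWithBackgroundFrameTask has

    nrow(per_reference) = nrow(frame) * nrow(background_frame) = 1,000 * 100 = 100,000 rows

and ContributionsMeanAggregator reduces it back to nrow(frame) = 1,000 rows by averaging the 100 rows that share each RowIdx.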