
Commit: Add java tests

tomasfryda committed Sep 15, 2023
1 parent c350e21 commit 0e630b3
Showing 11 changed files with 1,677 additions and 28 deletions.
h2o-algos/src/main/java/hex/deeplearning/DeepLearningModel.java (7 additions, 3 deletions)
@@ -406,11 +406,13 @@ public Frame scoreContributions(Frame frame, Key<Frame> destination_key, Job<Fra
       throw new UnsupportedOperationException("Only baseline SHAP is supported for this model. Please provide background frame.");

     Log.info("Starting contributions calculation for "+this._key+"...");
+    List<Frame> tmpFrames = new LinkedList<>();
+    Frame adaptedFrame = null;
+    Frame adaptedBgFrame = null;
     try {
-      List<Frame> tmpFrames = new LinkedList<>();
-      Frame adaptedBgFrame = adaptFrameForScore(backgroundFrame, false, tmpFrames);
+      adaptedBgFrame = adaptFrameForScore(backgroundFrame, false, tmpFrames);
       DKV.put(adaptedBgFrame);
-      Frame adaptedFrame = adaptFrameForScore(frame, false, tmpFrames);
+      adaptedFrame = adaptFrameForScore(frame, false, tmpFrames);
       DeepSHAPContributionsWithBackground contributions = new DeepSHAPContributionsWithBackground(
               adaptedFrame,
               adaptedBgFrame,
@@ -428,6 +430,8 @@ public Frame scoreContributions(Frame frame, Key<Frame> destination_key, Job<Fra
       colNames[colNames.length - 1] = "BiasTerm";
       return contributions.runAndGetOutput(j, destination_key, colNames);
     } finally {
+      if (null != adaptedFrame) Frame.deleteTempFrameAndItsNonSharedVecs(adaptedFrame, frame);
+      if (null != adaptedBgFrame) Frame.deleteTempFrameAndItsNonSharedVecs(adaptedBgFrame, backgroundFrame);
       Log.info("Finished contributions calculation for "+this._key+"...");
     }
   }
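
The same resource-handling pattern recurs in each file in this commit: temporary frames are declared before the try block, assigned inside it, and conditionally deleted in the finally block, so they are released even when scoring throws. A minimal sketch of the pattern, with a hypothetical Resource type standing in for water.fvec.Frame and delete() standing in for Frame.deleteTempFrameAndItsNonSharedVecs:

    // Sketch only: Resource, acquire and compute are stand-ins, not H2O API.
    class CleanupPatternSketch {
      static final class Resource {
        void delete() { /* release backing storage */ }
      }
      static Resource acquire() { return new Resource(); }
      static String compute(Resource a, Resource b) { return "result"; }

      static String run() {
        Resource a = null; // declared outside try so the finally block can see them
        Resource b = null;
        try {
          a = acquire();   // either call may throw after partial allocation
          b = acquire();
          return compute(a, b);
        } finally {
          if (null != a) a.delete(); // delete only what was actually created
          if (null != b) b.delete();
        }
      }
    }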
h2o-algos/src/main/java/hex/ensemble/StackedEnsembleModel.java (18 additions, 7 deletions)
@@ -171,8 +171,10 @@ public Frame scoreContributions(Frame frame, Key<Frame> destination_key, Job<Fra
     String[] columns = null;
     baseModelsIdx.add(0);
     Frame fr = new Frame();
-    Scope.enter();
-    Scope.track(fr);
+    Frame levelOneFrame = null;
+    Frame levelOneFrameBg = null;
+    Frame adaptFr = null;
+    Frame adaptFrBg = null;
     try {
       for (Key<Model> bm : _parms._base_models) {
         if (isUsefulBaseModel(bm)) {
@@ -205,6 +207,8 @@ public Frame scoreContributions(Frame frame, Key<Frame> destination_key, Job<Fra
           }
         }
         if (!Arrays.equals(columns, contributions._names)) {
+          Frame.deleteTempFrameAndItsNonSharedVecs(contributions, fr);
+          fr.delete();
           if (Original.equals(options._outputFormat)) {
             throw new IllegalArgumentException("Base model contributions have different columns likely due to models using different categorical encoding. Please use output_format=\"compact\".");
           }
@@ -217,6 +221,7 @@ public Frame scoreContributions(Frame frame, Key<Frame> destination_key, Job<Fra
                   .toArray(String[]::new)
           );
           fr.add(contributions);
+          Frame.deleteTempFrameAndItsNonSharedVecs(contributions, fr);
           baseModelsIdx.add(fr.numCols());
         }
       }
@@ -228,12 +233,12 @@ public Frame scoreContributions(Frame frame, Key<Frame> destination_key, Job<Fra
       columns = Arrays.copyOfRange(columns, 0, columns.length - 3);

       List<Frame> tmpFrames = new ArrayList<>();
-      Frame adaptFr = adaptFrameForScore(frame, false, tmpFrames);
-      Frame levelOneFrame = getLevelOnePredictFrame(frame, adaptFr, null);
+      adaptFr = adaptFrameForScore(frame, false, tmpFrames);
+      levelOneFrame = getLevelOnePredictFrame(frame, adaptFr, null);

       tmpFrames = new ArrayList<>();
-      Frame adaptFrBg = adaptFrameForScore(backgroundFrame, false, tmpFrames);
-      Frame levelOneFrameBg = getLevelOnePredictFrame(backgroundFrame, adaptFrBg, null);
+      adaptFrBg = adaptFrameForScore(backgroundFrame, false, tmpFrames);
+      levelOneFrameBg = getLevelOnePredictFrame(backgroundFrame, adaptFrBg, null);

       Frame metalearnerContrib = ((Model.Contributions) _output._metalearner).scoreContributions(levelOneFrame,
               Key.make(destination_key + "_" + _output._metalearner._key), j,
@@ -242,12 +247,14 @@ public Frame scoreContributions(Frame frame, Key<Frame> destination_key, Job<Fra
                       .setOutputSpace(options._outputSpace)
                       .setOutputPerReference(true),
               levelOneFrameBg);

       metalearnerContrib.setNames(Arrays.stream(metalearnerContrib._names)
               .map(name -> "metalearner_" + name)
               .toArray(String[]::new));

       fr.add(metalearnerContrib);
+      Frame.deleteTempFrameAndItsNonSharedVecs(metalearnerContrib, fr);
+

       Frame indivContribs = new GDeepSHAP(columns, baseModels.toArray(new String[0]),
               fr._names, baseModelsIdx.toArray(new Integer[0]), _parms._metalearner_transform)
@@ -268,7 +275,11 @@ public Frame scoreContributions(Frame frame, Key<Frame> destination_key, Job<Fra
       }
     } finally {
       Log.info("Finished contributions calculation for " + this._key + "...");
+      if (null != levelOneFrame) levelOneFrame.delete();
+      if (null != levelOneFrameBg) levelOneFrameBg.delete();
       Frame.deleteTempFrameAndItsNonSharedVecs(fr, frame);
+      if (null != adaptFr) Frame.deleteTempFrameAndItsNonSharedVecs(adaptFr, frame);
+      if (null != adaptFrBg) Frame.deleteTempFrameAndItsNonSharedVecs(adaptFrBg, backgroundFrame);
     }
   }

h2o-algos/src/main/java/hex/glm/GLMModel.java (7 additions, 6 deletions)
@@ -244,15 +244,15 @@ public Frame scoreContributions(Frame frame, Key<Frame> destination_key, Job<Fra
     }

     List<Frame> tmpFrames = new ArrayList<>();
-
+    Frame adaptedFrame = null;
+    Frame adaptedBgFrame = null;
     if (backgroundFrame == null)
       throw H2O.unimpl("GLM supports contribution calculation only with background frame.");

     Log.info("Starting contributions calculation for " + this._key + "...");
     try {
-      Frame adaptedBgFrame = adaptFrameForScore(backgroundFrame, false, tmpFrames);
-      DKV.put(adaptedBgFrame);
-      Frame adaptedFrame = adaptFrameForScore(frame, false, tmpFrames);
+      adaptedBgFrame = adaptFrameForScore(backgroundFrame, false, tmpFrames);
+      adaptedFrame = adaptFrameForScore(frame, false, tmpFrames);
       DataInfo dinfo = _output._dinfo.clone();
       dinfo._adaptedFrame = adaptedFrame;
       GLMContributionsWithBackground contributions = new GLMContributionsWithBackground(dinfo,
@@ -276,9 +276,10 @@ public Frame scoreContributions(Frame frame, Key<Frame> destination_key, Job<Fra
           }).toArray(String[]::new)
           : _output._coefficient_names, 0, colNames, 0, colNames.length - 1);
       colNames[colNames.length - 1] = "BiasTerm";
-      return contributions.runAndGetOutput(j, destination_key,
-              colNames);
+      return contributions.runAndGetOutput(j, destination_key, colNames);
     } finally {
+      if (null != adaptedFrame) Frame.deleteTempFrameAndItsNonSharedVecs(adaptedFrame, frame);
+      if (null != adaptedBgFrame) Frame.deleteTempFrameAndItsNonSharedVecs(adaptedBgFrame, backgroundFrame);
       Log.info("Finished contributions calculation for " + this._key + "...");
     }

@@ -111,17 +111,19 @@ public Frame scoreContributions(Frame frame, Key<Frame> destination_key, Job<Fra
     }

     Log.info("Starting contributions calculation for " + this._key + "...");
+    Frame adaptedFrame = null;
+    Frame adaptedBgFrame = null;
     try {
       if (options._outputFormat == ContributionsOutputFormat.Compact || _output._domains == null) {
-        Frame adaptFrm = removeSpecialColumns(frame);
-        Frame adaptBackgroundFrm = removeSpecialColumns(backgroundFrame);
+        adaptedFrame = removeSpecialColumns(frame);
+        adaptedBgFrame = removeSpecialColumns(backgroundFrame);

-        final String[] outputNames = ArrayUtils.append(adaptFrm.names(), "BiasTerm");
-        return getScoreContributionsWithBackgroundTask(this, adaptFrm, adaptBackgroundFrm, false, null, options)
+        final String[] outputNames = ArrayUtils.append(adaptedFrame.names(), "BiasTerm");
+        return getScoreContributionsWithBackgroundTask(this, adaptedFrame, adaptedBgFrame, false, null, options)
                 .runAndGetOutput(j, destination_key, outputNames);
       } else {
-        Frame adaptFrm = removeSpecialColumns(frame);
-        Frame adaptBackgroundFrm = removeSpecialColumns(backgroundFrame);
+        adaptedFrame = removeSpecialColumns(frame);
+        adaptedBgFrame = removeSpecialColumns(backgroundFrame);
         assert Parameters.CategoricalEncodingScheme.Enum.equals(_parms._categorical_encoding) : "Unsupported categorical encoding. Only enum is supported.";
         int[] catOffsets = new int[_output._domains.length + 1];

@@ -163,10 +165,12 @@ public Frame scoreContributions(Frame frame, Key<Frame> destination_key, Job<Fra
         }
       }

-      return getScoreContributionsWithBackgroundTask(this, adaptFrm, adaptBackgroundFrm, true, catOffsets, options)
+      return getScoreContributionsWithBackgroundTask(this, adaptedFrame, adaptedBgFrame, true, catOffsets, options)
              .runAndGetOutput(j, destination_key, outputNames);
     } finally {
+      if (null != adaptedFrame) Frame.deleteTempFrameAndItsNonSharedVecs(adaptedFrame, frame);
+      if (null != adaptedBgFrame) Frame.deleteTempFrameAndItsNonSharedVecs(adaptedBgFrame, backgroundFrame);
       Log.info("Finished contributions calculation for " + this._key + "...");
     }
   }
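
When the output format is not Compact, contributions are expanded per categorical level, which is what the catOffsets array above indexes into. A hedged sketch of how such cumulative offsets are typically built; buildCatOffsets is a hypothetical helper, and only the array allocation itself appears in the diff:

    // Hypothetical sketch: cumulative offsets so the contribution for level k of
    // categorical column c lands at index catOffsets[c] + k. Assumes domains[c] == null
    // marks a numeric column, matching H2O's _output._domains convention.
    static int[] buildCatOffsets(String[][] domains) {
      int[] catOffsets = new int[domains.length + 1];
      for (int c = 0; c < domains.length; c++) {
        int width = domains[c] == null ? 1 : domains[c].length;
        catOffsets[c + 1] = catOffsets[c] + width;
      }
      return catOffsets;
    }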
h2o-algos/src/test/java/hex/deeplearning/DeepLearningSHAPTest.java (new file, 211 additions)
@@ -0,0 +1,211 @@
package hex.deeplearning;

import hex.Model;
import hex.deeplearning.DeepLearningModel.DeepLearningParameters;
import org.apache.commons.lang.math.LongRange;
import org.junit.BeforeClass;
import org.junit.Test;
import water.Key;
import water.TestUtil;
import water.fvec.Frame;
import water.fvec.NFSFileVec;
import water.parser.ParseDataset;
import water.rapids.Rapids;
import water.rapids.Val;
import water.rapids.vals.ValFrame;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;

public class DeepLearningSHAPTest extends TestUtil {

  /*
    NOTE: These tests do not check all of the properties required of SHAP values.
    For more confidence after changing the SHAP code, please also run the Python test:
    h2o-py/tests/testdir_misc/pyunit_SHAP.py
  */
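  // What the assertions below do check: for every row, the sum of all contribution
  // columns (including "BiasTerm") matches the model's prediction within the stated
  // tolerance, i.e. SHAP's efficiency/additivity property.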

  @BeforeClass
  public static void setup() {
    stall_till_cloudsize(1);
  }


  @Test
  public void testClassificationCompactSHAP() {
    NFSFileVec nfs = TestUtil.makeNfsFileVec("smalldata/titanic/titanic_expanded.csv");
    Frame fr = ParseDataset.parse(Key.make(), nfs._key);
    Frame bgFr = fr.deepSlice(new LongRange(0, 50).toArray(), null);
    Frame test = fr.deepSlice(new LongRange(51, 101).toArray(), null);
    DeepLearningModel model = null;
    Frame scored = null;
    Frame contribs = null;
    Frame res = null;
    try {
      // Launch Deep Learning
      DeepLearningParameters params = new DeepLearningParameters();
      params._train = fr._key;
      params._epochs = 5;
      params._response_column = "survived";

      model = new DeepLearning(params).trainModel().get();

      assert model != null;
      scored = model.score(test);
      contribs = model.scoreContributions(test, Key.make(), null,
              new Model.Contributions.ContributionsOptions().setOutputFormat(Model.Contributions.ContributionsOutputFormat.Compact),
              bgFr);

      assert fr.numCols() >= contribs.numCols();

      Val val = Rapids.exec("(sumaxis " + contribs._key + " 0 1)");
      assertTrue(val instanceof ValFrame);
      res = val.getFrame();
      assertColsEquals(scored, res, 2, 0, 1e-8);
    } finally {
      fr.delete();
      bgFr.delete();
      test.delete();
      if (null != res) res.delete();
      if (null != scored) scored.delete();
      if (null != contribs) contribs.delete();
      if (null != model) model.delete();
    }
  }


  @Test
  public void testClassificationOriginalSHAP() {
    NFSFileVec nfs = TestUtil.makeNfsFileVec("smalldata/titanic/titanic_expanded.csv");
    Frame fr = ParseDataset.parse(Key.make(), nfs._key);
    Frame bgFr = fr.deepSlice(new LongRange(0, 50).toArray(), null);
    Frame test = fr.deepSlice(new LongRange(51, 101).toArray(), null);
    DeepLearningModel model = null;
    Frame scored = null;
    Frame contribs = null;
    Frame res = null;
    try {
      // Launch Deep Learning
      DeepLearningParameters params = new DeepLearningParameters();
      params._train = fr._key;
      params._epochs = 5;
      params._response_column = "survived";

      model = new DeepLearning(params).trainModel().get();

      assert model != null;
      scored = model.score(test);
      contribs = model.scoreContributions(test, Key.make(), null,
              new Model.Contributions.ContributionsOptions().setOutputFormat(Model.Contributions.ContributionsOutputFormat.Original),
              bgFr);

      assert fr.numCols() < contribs.numCols(); // Titanic has categorical vars

      Val val = Rapids.exec("(sumaxis " + contribs._key + " 0 1)");
      assertTrue(val instanceof ValFrame);
      res = val.getFrame();
      assertColsEquals(scored, res, 2, 0, 1e-8);
    } finally {
      fr.delete();
      bgFr.delete();
      test.delete();
      if (null != res) res.delete();
      if (null != scored) scored.delete();
      if (null != contribs) contribs.delete();
      if (null != model) model.delete();
    }
  }


  @Test
  public void testRegressionCompactSHAP() {
    NFSFileVec nfs = TestUtil.makeNfsFileVec("smalldata/titanic/titanic_expanded.csv");
    Frame fr = ParseDataset.parse(Key.make(), nfs._key);
    Frame bgFr = fr.deepSlice(new LongRange(0, 50).toArray(), null);
    Frame test = fr.deepSlice(new LongRange(51, 101).toArray(), null);
    DeepLearningModel model = null;
    Frame scored = null;
    Frame contribs = null;
    Frame res = null;
    try {
      // Launch Deep Learning
      DeepLearningParameters params = new DeepLearningParameters();
      params._train = fr._key;
      params._epochs = 5;
      params._response_column = "fare";

      model = new DeepLearning(params).trainModel().get();

      assert model != null;
      scored = model.score(test);
      contribs = model.scoreContributions(test, Key.make(), null,
              new Model.Contributions.ContributionsOptions().setOutputFormat(Model.Contributions.ContributionsOutputFormat.Compact),
              bgFr);

      assert fr.numCols() >= contribs.numCols();

      Val val = Rapids.exec("(sumaxis " + contribs._key + " 0 1)");
      assertTrue(val instanceof ValFrame);
      res = val.getFrame();
      assertColsEquals(scored, res, 0, 0, 1e-5);
    } finally {
      fr.delete();
      bgFr.delete();
      test.delete();
      if (null != res) res.delete();
      if (null != scored) scored.delete();
      if (null != contribs) contribs.delete();
      if (null != model) model.delete();
    }
  }


  @Test
  public void testRegressionOriginalSHAP() {
    NFSFileVec nfs = TestUtil.makeNfsFileVec("smalldata/titanic/titanic_expanded.csv");
    Frame fr = ParseDataset.parse(Key.make(), nfs._key);
    Frame bgFr = fr.deepSlice(new LongRange(0, 50).toArray(), null);
    Frame test = fr.deepSlice(new LongRange(51, 101).toArray(), null);
    DeepLearningModel model = null;
    Frame scored = null;
    Frame contribs = null;
    Frame res = null;
    try {
      // Launch Deep Learning
      DeepLearningParameters params = new DeepLearningParameters();
      params._train = fr._key;
      params._epochs = 5;
      params._response_column = "fare";

      model = new DeepLearning(params).trainModel().get();

      assert model != null;
      scored = model.score(test);
      contribs = model.scoreContributions(test, Key.make(), null,
              new Model.Contributions.ContributionsOptions().setOutputFormat(Model.Contributions.ContributionsOutputFormat.Original),
              bgFr);

      assert fr.numCols() < contribs.numCols(); // Titanic has categorical vars

      Val val = Rapids.exec("(sumaxis " + contribs._key + " 0 1)");
      assertTrue(val instanceof ValFrame);
      res = val.getFrame();
      assertColsEquals(scored, res, 0, 0, 1e-5);
    } finally {
      fr.delete();
      bgFr.delete();
      test.delete();
      if (null != res) res.delete();
      if (null != scored) scored.delete();
      if (null != contribs) contribs.delete();
      if (null != model) model.delete();
    }
  }

  private static void assertColsEquals(Frame expected, Frame actual, int colExpected, int colActual, double eps) {
    assertEquals(expected.numRows(), actual.numRows());
    for (int i = 0; i < expected.numRows(); i++) {
      assertEquals("Wrong sum in row " + i, expected.vec(colExpected).at(i), actual.vec(colActual).at(i), eps);
    }
  }
}
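
Assuming the standard Gradle layout implied by the file paths above (module h2o-algos), the new tests can presumably be run from the repository root with:

    ./gradlew :h2o-algos:test --tests "hex.deeplearning.DeepLearningSHAPTest"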