From 665201f946479ef6fdb2b70875eb5ed9b8773cc5 Mon Sep 17 00:00:00 2001 From: Adam Valenta Date: Thu, 14 Sep 2023 19:06:43 +0200 Subject: [PATCH] Add categorical test --- .../test/java/hex/adaboost/AdaBoostTest.java | 34 +++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/h2o-algos/src/test/java/hex/adaboost/AdaBoostTest.java b/h2o-algos/src/test/java/hex/adaboost/AdaBoostTest.java index 79b419683dda..2d63e580f713 100644 --- a/h2o-algos/src/test/java/hex/adaboost/AdaBoostTest.java +++ b/h2o-algos/src/test/java/hex/adaboost/AdaBoostTest.java @@ -1,5 +1,6 @@ package hex.adaboost; +import hex.Model; import hex.genmodel.algos.tree.SharedTreeSubgraph; import hex.glm.GLM; import hex.glm.GLMModel; @@ -182,6 +183,39 @@ public void testBasicTrainAndScore() { } } + @Test + public void testBasicTrainAndScoreCategorical() { + try { + Scope.enter(); + Frame train = parseTestFile("smalldata/prostate/prostate.csv"); + Scope.track(train); + String response = "CAPSULE"; + train.toCategoricalCol(response); + train.toCategoricalCol("RACE"); + train.toCategoricalCol("DPROS"); + train.toCategoricalCol("DCAPS"); + train.toCategoricalCol("GLEASON"); + AdaBoostModel.AdaBoostParameters p = new AdaBoostModel.AdaBoostParameters(); + p._train = train._key; + p._seed = 0xDECAF; + p._n_estimators = 50; + p._response_column = response; + p._categorical_encoding = Model.Parameters.CategoricalEncodingScheme.OneHotExplicit; + + AdaBoost adaBoost = new AdaBoost(p); + AdaBoostModel adaBoostModel = adaBoost.trainModel().get(); + Scope.track_generic(adaBoostModel); + assertNotNull(adaBoostModel); + + System.out.println("train.toTwoDimTable() = " + train.toTwoDimTable()); + + Frame score = adaBoostModel.score(train); + Scope.track(score); + } finally { + Scope.exit(); + } + } + // @Test // public void testBasicTrainAndScoreGLM() { // try {