Refactor n_estimators to nlearners
valenad1 committed Sep 19, 2023
1 parent 719a356 commit 1bdcf39
Showing 3 changed files with 23 additions and 25 deletions.
14 changes: 7 additions & 7 deletions h2o-algos/src/main/java/hex/adaboost/AdaBoost.java
@@ -31,7 +31,7 @@
*/
public class AdaBoost extends ModelBuilder<AdaBoostModel, AdaBoostModel.AdaBoostParameters, AdaBoostModel.AdaBoostOutput> {
private static final Logger LOG = Logger.getLogger(AdaBoost.class);
- private static final int MAX_ESTIMATORS = 100_000;
+ private static final int MAX_LEARNERS = 100_000;

private AdaBoostModel _model;
private String _weightsName = "weights";
@@ -59,9 +59,9 @@ public boolean haveMojo() {
@Override
public void init(boolean expensive) {
super.init(expensive);
- if(_parms._n_estimators < 1 || _parms._n_estimators > MAX_ESTIMATORS)
+ if(_parms._nlearners < 1 || _parms._nlearners > MAX_LEARNERS)
error("n_estimators", "Parameter n_estimators must be in interval [1, "
- + MAX_ESTIMATORS + "] but it is " + _parms._n_estimators);
+ + MAX_LEARNERS + "] but it is " + _parms._nlearners);
if (_parms._weak_learner == AdaBoostModel.Algorithm.AUTO) {
_parms._weak_learner = AdaBoostModel.Algorithm.DRF;
}
@@ -96,8 +96,8 @@ public void computeImpl() {
}

private void buildAdaboost() {
- _model._output.alphas = new double[(int)_parms._n_estimators];
- _model._output.models = new Key[(int)_parms._n_estimators];
+ _model._output.alphas = new double[(int)_parms._nlearners];
+ _model._output.models = new Key[(int)_parms._nlearners];

Frame _trainWithWeights;
if (_parms._weights_column == null) {
@@ -112,7 +112,7 @@ private void buildAdaboost() {
_trainWithWeights = _parms.train();
}

- for (int n = 0; n < _parms._n_estimators; n++) {
+ for (int n = 0; n < _parms._nlearners; n++) {
Timer timer = new Timer();
ModelBuilder job = chooseWeakLearner(_trainWithWeights);
job._parms._seed += n;
@@ -233,7 +233,7 @@ public TwoDimTable createModelSummaryTable() {
"");
int row = 0;
int col = 0;
- table.set(row, col++, _parms._n_estimators);
+ table.set(row, col++, _parms._nlearners);
table.set(row, col++, _parms._learn_rate);
table.set(row, col++, _parms._weak_learner.toString());
table.set(row, col, _parms._seed);
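Read together, the AdaBoost.java hunks rename the public knob without changing behavior: init() now validates _nlearners against MAX_LEARNERS, buildAdaboost() sizes its alphas/models arrays and its boosting loop from the same field, and the model summary table reports it. A minimal caller-side sketch of the renamed parameter, assuming the usual H2O ModelBuilder trainModel().get() flow and a parsed training Frame named train (as in the tests below); the response column name here is hypothetical:

    AdaBoostModel.AdaBoostParameters p = new AdaBoostModel.AdaBoostParameters();
    p._train = train._key;          // key of a parsed training Frame
    p._response_column = "label";   // hypothetical column name
    p._nlearners = 50;              // renamed from _n_estimators in this commit
    p._seed = 0xDECAF;
    // init() rejects values outside [1, MAX_LEARNERS], i.e. [1, 100_000].
    AdaBoostModel model = new AdaBoost(p).trainModel().get(); // assumed ModelBuilder flow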
6 changes: 3 additions & 3 deletions h2o-algos/src/main/java/hex/adaboost/AdaBoostModel.java
@@ -92,7 +92,7 @@ public static class AdaBoostParameters extends Model.Parameters {
/**
* Number of weak learners to train. Defaults to 50.
*/
- public int _n_estimators;
+ public int _nlearners;

/**
* Choose a weak learner type. Defaults to DRF.
@@ -121,12 +121,12 @@ public String javaName() {

@Override
public long progressUnits() {
- return _n_estimators;
+ return _nlearners;
}

public AdaBoostParameters() {
super();
- _n_estimators = 50;
+ _nlearners = 50;
_weak_learner = Algorithm.AUTO;
_learn_rate = 0.5;
}
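A short sketch of what these AdaBoostModel.java hunks imply after the rename; each assertion only restates a default or the progressUnits() logic visible in the diff above:

    AdaBoostModel.AdaBoostParameters p = new AdaBoostModel.AdaBoostParameters();
    assert p._nlearners == 50;                               // constructor default
    assert p._weak_learner == AdaBoostModel.Algorithm.AUTO;  // resolved to DRF in init()
    assert p._learn_rate == 0.5;
    assert p.progressUnits() == 50;                          // one progress unit per weak learner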
28 changes: 13 additions & 15 deletions h2o-algos/src/test/java/hex/adaboost/AdaBoostTest.java
@@ -3,8 +3,6 @@
import hex.Model;
import hex.genmodel.algos.tree.SharedTreeSubgraph;
import hex.tree.drf.DRFModel;
- import hex.tree.gbm.GBM;
- import hex.tree.gbm.GBMModel;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
@@ -49,7 +47,7 @@ public void testBasicTrain() {
AdaBoostModel.AdaBoostParameters p = new AdaBoostModel.AdaBoostParameters();
p._train = train._key;
p._seed = 0xDECAF;
- p._n_estimators = 50;
+ p._nlearners = 50;
p._response_column = response;

AdaBoost adaBoost = new AdaBoost(p);
@@ -89,7 +87,7 @@ public void testBasicTrainGLM() {
AdaBoostModel.AdaBoostParameters p = new AdaBoostModel.AdaBoostParameters();
p._train = train._key;
p._seed = 0xDECAF;
- p._n_estimators = 50;
+ p._nlearners = 50;
p._weak_learner = AdaBoostModel.Algorithm.GLM;
p._response_column = response;

@@ -113,7 +111,7 @@ public void testBasicTrainLarge() {
AdaBoostModel.AdaBoostParameters p = new AdaBoostModel.AdaBoostParameters();
p._train = train._key;
p._seed = 0xDECAF;
- p._n_estimators = 50;
+ p._nlearners = 50;
p._response_column = response;

AdaBoost adaBoost = new AdaBoost(p);
@@ -136,7 +134,7 @@ public void testBasicTrainAndScore() {
AdaBoostModel.AdaBoostParameters p = new AdaBoostModel.AdaBoostParameters();
p._train = train._key;
p._seed = 0xDECAF;
- p._n_estimators = 50;
+ p._nlearners = 50;
p._response_column = response;

AdaBoost adaBoost = new AdaBoost(p);
@@ -168,7 +166,7 @@ public void testBasicTrainAndScoreCategorical() {
AdaBoostModel.AdaBoostParameters p = new AdaBoostModel.AdaBoostParameters();
p._train = train._key;
p._seed = 0xDECAF;
- p._n_estimators = 50;
+ p._nlearners = 50;
p._response_column = response;
p._categorical_encoding = Model.Parameters.CategoricalEncodingScheme.OneHotExplicit;

@@ -197,7 +195,7 @@ public void testBasicTrainAndScoreLarge() {
AdaBoostModel.AdaBoostParameters p = new AdaBoostModel.AdaBoostParameters();
p._train = train._key;
p._seed = 0xDECAF;
- p._n_estimators = 50;
+ p._nlearners = 50;
p._response_column = response;

AdaBoost adaBoost = new AdaBoost(p);
@@ -226,7 +224,7 @@ public void testBasicTrainAirlines() {
AdaBoostModel.AdaBoostParameters p = new AdaBoostModel.AdaBoostParameters();
p._train = train._key;
p._seed = 0xDECAF;
- p._n_estimators = 50;
+ p._nlearners = 50;
p._response_column = response;

AdaBoost adaBoost = new AdaBoost(p);
@@ -255,7 +253,7 @@ public void testBasicTrainHiggs() {
AdaBoostModel.AdaBoostParameters p = new AdaBoostModel.AdaBoostParameters();
p._train = train._key;
p._seed = 0xDECAF;
- p._n_estimators = 50;
+ p._nlearners = 50;
p._response_column = response;

AdaBoost adaBoost = new AdaBoost(p);
@@ -330,7 +328,7 @@ public void testBasicTrainAndScoreWithExternalWeightsColumn() {
AdaBoostModel.AdaBoostParameters p = new AdaBoostModel.AdaBoostParameters();
p._train = train._key;
p._seed = 0xDECAF;
- p._n_estimators = 10;
+ p._nlearners = 10;
p._response_column = response;

AdaBoost adaBoostReference = new AdaBoost(p);
@@ -374,7 +372,7 @@ public void testBasicTrainAndScoreWithCustomWeightsColumn() {
AdaBoostModel.AdaBoostParameters p = new AdaBoostModel.AdaBoostParameters();
p._train = train._key;
p._seed = 0xDECAF;
- p._n_estimators = 10;
+ p._nlearners = 10;
p._response_column = response;

AdaBoost adaBoostReference = new AdaBoost(p);
@@ -419,7 +417,7 @@ public void testBasicTrainAndScoreWithDuplicatedWeightsColumn() {
AdaBoostModel.AdaBoostParameters p = new AdaBoostModel.AdaBoostParameters();
p._train = train._key;
p._seed = 0xDECAF;
- p._n_estimators = 10;
+ p._nlearners = 10;
p._response_column = response;
p._ignore_const_cols = false;

@@ -447,7 +445,7 @@ public void testBasicTrainAndScoreGLM() {
AdaBoostModel.AdaBoostParameters p = new AdaBoostModel.AdaBoostParameters();
p._train = train._key;
p._seed = 0xDECAF;
- p._n_estimators = 50;
+ p._nlearners = 50;
p._weak_learner = AdaBoostModel.Algorithm.GLM;
p._response_column = response;

@@ -473,7 +471,7 @@ public void testBasicTrainAndScoreGBM() {
AdaBoostModel.AdaBoostParameters p = new AdaBoostModel.AdaBoostParameters();
p._train = train._key;
p._seed = 0xDECAF;
- p._n_estimators = 50;
+ p._nlearners = 50;
p._weak_learner = AdaBoostModel.Algorithm.GBM;
p._response_column = response;

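For downstream code written against the old field name, the migration is a mechanical one-line rename, sketched here on a hypothetical parameters object p:

    // before this commit:
    // p._n_estimators = 50;
    // after this commit:
    p._nlearners = 50;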
