
Commit 9c84d6f

adam2392 and PSSF23 authored
[ENH, BUG] Test honest tree performance via quadratic simulation (neurodata#164)
* Test honest tree performance
* Fixes API for calling n_estimators
* Adds additional testing towards fixing the honest tree power performance via quadratic simulation

Signed-off-by: Adam Li <[email protected]>
Co-authored-by: Haoyin Xu <[email protected]>
1 parent 030a064 commit 9c84d6f

File tree: 8 files changed (+185 lines, −98 lines)

.spin/cmds.py

Lines changed: 2 additions & 2 deletions
@@ -41,7 +41,7 @@ def coverage(ctx, slowtest):
 def setup_submodule(forcesubmodule=False):
     """Build scikit-tree using submodules.

-    git submodule set-branch -b submodulev2 sktree/_lib/sklearn
+    git submodule set-branch -b submodulev3 sktree/_lib/sklearn

     git submodule update --recursive --remote

@@ -137,7 +137,7 @@ def setup_submodule(forcesubmodule=False):
 def build(ctx, meson_args, jobs=None, clean=False, forcesubmodule=False, verbose=False):
     """Build scikit-tree using submodules.

-    git submodule update --recursive --remote
+    git submodule update --recursive --remote

     To update submodule wrt latest commits:

doc/whats_new/v0.4.rst

Lines changed: 1 addition & 0 deletions
@@ -15,6 +15,7 @@ Changelog

 - |API| ``FeatureImportanceForest*`` now has a hyperparameter to control the number of permutations done per forest, ``permute_per_forest_fraction``, by `Adam Li`_ (:pr:`145`)
 - |Enhancement| Add dataset generators for regression and classification and hypothesis testing, by `Adam Li`_ (:pr:`169`)
+- |Fix| Fixes a bug where ``FeatureImportanceForest*`` was unable to be run when calling ``statistic`` with ``covariate_index`` defined for MI, AUC metrics, by `Adam Li`_ (:pr:`164`)

 Code and Documentation Contributors
 -----------------------------------

sktree/datasets/meson.build

Lines changed: 1 addition & 1 deletion
@@ -10,4 +10,4 @@ py3.install_sources(
   subdir: 'sktree/datasets'
 )

-subdir('tests')
+subdir('tests')

sktree/stats/forestht.py

Lines changed: 28 additions & 28 deletions
@@ -151,9 +151,6 @@ def n_estimators(self):
         finally:
             return self._get_estimator().n_estimators

-    def _get_estimator(self):
-        pass
-
     def reset(self):
         class_attributes = dir(type(self))
         instance_attributes = dir(self)
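The deleted `_get_estimator` stub silently returned None, so any call that fell through to the base class (for example the `n_estimators` property above) broke. A minimal sketch of the intended contract, assuming an abstract-base-class pattern; the ABC machinery here is illustrative, not the repo's exact code:

from abc import ABC, abstractmethod

class BaseForestHT(ABC):
    """Sketch: hypothesis-testing base class that defers estimator creation."""

    @abstractmethod
    def _get_estimator(self):
        """Return a fresh forest estimator; subclasses must override.

        A bare `def _get_estimator(self): pass` stub (as removed above)
        returns None and breaks callers such as `n_estimators`.
        """

    @property
    def n_estimators(self):
        # Works only once a subclass supplies a real estimator factory.
        return self._get_estimator().n_estimators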
@@ -190,21 +187,12 @@ def _get_estimators_indices(self, stratifier=None, sample_separate=False):
             self._seeds = []
             self._n_permutations = 0

-            num_trees_per_seed = max(
-                int(permute_forest_fraction * len(self.estimator_.estimators_)), 1
-            )
-            for tree_idx, tree in enumerate(self.estimator_.estimators_):
-                if tree_idx == 0 or tree_idx % num_trees_per_seed == 0:
-                    if tree.random_state is None:
-                        seed = rng.integers(low=0, high=np.iinfo(np.int32).max)
-                    else:
-                        seed = tree.random_state
-
-                self._n_permutations += 1
-                self._seeds.append(seed)
-
-            # now that we have the random seeds, we can sample the train/test indices
-            # deterministically
+            for itree in range(self.estimator_.n_estimators):
+                tree = self.estimator_.estimators_[itree]
+                if tree.random_state is None:
+                    self._seeds.append(rng.integers(low=0, high=np.iinfo(np.int32).max))
+                else:
+                    self._seeds.append(tree.random_state)
         seeds = self._seeds

         for idx, tree in enumerate(self.estimator_.estimators_):
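The replacement collects exactly one seed per tree, reusing the tree's own `random_state` when it is set, so train/test index sampling stays reproducible per tree. A standalone sketch of that logic (the tree objects here are stand-ins for the fitted ensemble):

import numpy as np

def collect_tree_seeds(trees, rng):
    """Return one RNG seed per tree, reusing tree.random_state when set."""
    seeds = []
    for tree in trees:
        if tree.random_state is None:
            seeds.append(rng.integers(low=0, high=np.iinfo(np.int32).max))
        else:
            seeds.append(tree.random_state)
    return seeds

class _FakeTree:
    def __init__(self, random_state=None):
        self.random_state = random_state

rng = np.random.default_rng(0)
trees = [_FakeTree(), _FakeTree(random_state=42), _FakeTree()]
print(collect_tree_seeds(trees, rng))  # e.g. [<random int>, 42, <random int>]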
@@ -236,7 +224,7 @@ def _get_estimators_indices(self, stratifier=None, sample_separate=False):
                 random_state=self._seeds,
             )

-            for _ in self.estimator_.estimators_:
+            for _ in range(self.estimator_.n_estimators):
                 yield indices_train, indices_test

     @property
@@ -394,6 +382,25 @@ def statistic(
             self.permuted_estimator_ = self._get_estimator()
             estimator = self.permuted_estimator_

+        if not hasattr(self, "estimator_") or self.estimator_ is None:
+            self.estimator_ = self._get_estimator()
+
+        # Ensure that the estimator_ is fitted at least
+        if not _is_fitted(self.estimator_) and is_classifier(self.estimator_):
+            _unique_y = []
+            for axis in range(y.shape[1]):
+                _unique_y.append(np.unique(y[:, axis]))
+            unique_y = np.hstack(_unique_y)
+            if unique_y.ndim > 1 and unique_y.shape[1] == 1:
+                unique_y = unique_y.ravel()
+            X_dummy = np.zeros((unique_y.shape[0], X.shape[1]))
+            self.estimator_.fit(X_dummy, unique_y)
+        elif not _is_fitted(estimator):
+            if y.ndim > 1 and y.shape[1] == 1:
+                self.estimator_.fit(X[:2], y[:2].ravel())
+            else:
+                self.estimator_.fit(X[:2], y[:2])
+
         # Store a cache of the y variable
         if is_classifier(self._get_estimator()):
             self._y = y.copy()
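This new block is the core of the fix: it guarantees `estimator_` exists and is at least nominally fitted before `statistic` is called with a `covariate_index`, by fitting a classifier on a dummy design matrix containing each class label once. A self-contained sketch of that trick (shapes and names are illustrative, not the repo's exact code):

import numpy as np
from sklearn.ensemble import RandomForestClassifier

y = np.array([0, 1, 1, 0, 2])          # labels observed in the data
X = np.random.rand(5, 3)               # (n_samples, n_features)

unique_y = np.unique(y)                # one row per class
X_dummy = np.zeros((unique_y.shape[0], X.shape[1]))

# Fitting on one zero-row per class registers classes_ (and makes
# sklearn-style fitted checks pass) without any real learning.
clf = RandomForestClassifier(n_estimators=2, random_state=0)
clf.fit(X_dummy, unique_y)
print(clf.classes_)                    # [0 1 2]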
@@ -434,7 +441,7 @@ def statistic(
             )
         self._metric = metric

-        if not is_classifier(self.estimator_) and metric not in REGRESSOR_METRICS:
+        if not is_classifier(estimator) and metric not in REGRESSOR_METRICS:
             raise RuntimeError(
                 f'Metric must be either "mse" or "mae" if using Regression, got {metric}'
             )
@@ -798,7 +805,7 @@ def _statistic(
         indices_train, indices_test = self.train_test_samples_[0]

         X_train, _ = X[indices_train, :], X[indices_test, :]
-        y_train, y_test = y[indices_train, :], y[indices_test, :]
+        y_train, _ = y[indices_train, :], y[indices_test, :]

         if covariate_index is not None:
             # perform permutation of covariates
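For context, `covariate_index` selects the columns whose rows are shuffled to simulate the null hypothesis that those features carry no signal. A hedged sketch of what that permutation looks like; the actual implementation in forestht.py may differ in details:

import numpy as np

def permute_covariates(X, covariate_index, rng):
    """Return a copy of X with the selected columns row-shuffled together."""
    X_perm = X.copy()
    perm = rng.permutation(X.shape[0])
    X_perm[:, covariate_index] = X_perm[perm][:, covariate_index]
    return X_perm

rng = np.random.default_rng(0)
X = np.arange(12, dtype=float).reshape(4, 3)
print(permute_covariates(X, [0, 2], rng))  # columns 0 and 2 shuffled jointly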
@@ -815,10 +822,6 @@ def _statistic(
             y_train = y_train.ravel()
         estimator.fit(X_train, y_train)

-        # set variables to compute metric
-        samples = indices_test
-        y_true_final = y_test
-
         # TODO: probably a more elegant way of doing this
         if self.train_test_split:
             # accumulate the predictions across all trees
@@ -1067,9 +1070,6 @@ def _statistic(
             y_train = y_train.ravel()
         estimator.fit(X_train, y_train)

-        # set variables to compute metric
-        samples = indices_test
-
         # list of tree outputs. Each tree output is (n_samples, n_outputs), or (n_samples,)
         if predict_posteriors:
             # all_proba = Parallel(n_jobs=estimator.n_jobs, verbose=self.verbose)(

sktree/stats/tests/test_coleman.py

Lines changed: 2 additions & 8 deletions
@@ -32,7 +32,6 @@
         {
             "estimator": RandomForestRegressor(
                 max_features="sqrt",
-                random_state=seed,
                 n_estimators=75,
                 n_jobs=-1,
             ),
@@ -47,7 +46,6 @@
         {
             "estimator": RandomForestRegressor(
                 max_features="sqrt",
-                # random_state=seed,
                 n_estimators=125,
                 n_jobs=-1,
             ),
@@ -81,12 +79,11 @@
         {
             "estimator": RandomForestRegressor(
                 max_features="sqrt",
-                # random_state=seed,
                 n_estimators=125,
                 n_jobs=-1,
             ),
             # "random_state": seed,
-            "permute_forest_fraction": 1.0 / 125,
+            "permute_forest_fraction": 0.5,
             "sample_dataset_per_tree": False,
         },
         300,  # n_samples
@@ -151,7 +148,6 @@ def test_linear_model(hypotester, model_kwargs, n_samples, n_repeats, test_size)
         {
             "estimator": RandomForestClassifier(
                 max_features="sqrt",
-                random_state=seed,
                 n_estimators=50,
                 n_jobs=-1,
             ),
@@ -167,7 +163,6 @@ def test_linear_model(hypotester, model_kwargs, n_samples, n_repeats, test_size)
         {
             "estimator": RandomForestClassifier(
                 max_features="sqrt",
-                # random_state=seed,
                 n_estimators=100,
                 n_jobs=-1,
             ),
@@ -200,8 +195,7 @@ def test_linear_model(hypotester, model_kwargs, n_samples, n_repeats, test_size)
         {
             "estimator": RandomForestClassifier(
                 max_features="sqrt",
-                # random_state=seed,
-                n_estimators=100,
+                n_estimators=200,
                 n_jobs=-1,
             ),
             "permute_forest_fraction": 0.5,

sktree/stats/tests/test_forestht.py

Lines changed: 31 additions & 0 deletions
@@ -61,11 +61,19 @@ def test_featureimportance_forest_permute_pertree(sample_dataset_per_tree):
     with pytest.raises(RuntimeError, match="Metric must be"):
         est.statistic(iris_X[:n_samples], iris_y[:n_samples], metric="mi")

+    # covariate index should work with mse
+    est.reset()
+    est.statistic(iris_X[:n_samples], iris_y[:n_samples], covariate_index=[1], metric="mse")
+    with pytest.raises(RuntimeError, match="Metric must be"):
+        est.statistic(iris_X[:n_samples], iris_y[:n_samples], covariate_index=[1], metric="mi")
+
     # covariate index must be an iterable
+    est.reset()
     with pytest.raises(RuntimeError, match="covariate_index must be an iterable"):
         est.statistic(iris_X[:n_samples], iris_y[:n_samples], 0, metric="mi")

     # covariate index must be an iterable of ints
+    est.reset()
     with pytest.raises(RuntimeError, match="Not all covariate_index"):
         est.statistic(iris_X[:n_samples], iris_y[:n_samples], [0, 1.0], metric="mi")

@@ -98,6 +106,29 @@ def test_featureimportance_forest_permute_pertree(sample_dataset_per_tree):
     est.statistic(iris_X[:n_samples], iris_y[:n_samples], metric="mse")


+@pytest.mark.parametrize("covariate_index", [None, [0, 1]])
+def test_featureimportance_forest_statistic_with_covariate_index(covariate_index):
+    """Tests that calling `est.statistic` with covariate_index defined works.
+
+    There should be no issue calling `est.statistic` with covariate_index defined.
+    """
+    n_estimators = 10
+    n_samples = 10
+
+    est = FeatureImportanceForestClassifier(
+        estimator=RandomForestClassifier(
+            n_estimators=n_estimators,
+            random_state=seed,
+        ),
+        permute_forest_fraction=1.0 / n_estimators * 5,
+        test_size=0.7,
+        random_state=seed,
+    )
+    est.statistic(
+        iris_X[:n_samples], iris_y[:n_samples], covariate_index=covariate_index, metric="mi"
+    )
+
+
 @pytest.mark.parametrize("sample_dataset_per_tree", [True, False])
 def test_featureimportance_forest_stratified(sample_dataset_per_tree):
     n_samples = 100
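The new test exercises the PR's headline fix: `statistic` with `covariate_index` defined no longer errors for MI. A usage sketch of the reset-then-statistic pattern the updated tests rely on; the import path and constructor arguments mirror the test code above, and the random data here is an arbitrary stand-in:

import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sktree.stats import FeatureImportanceForestClassifier

X = np.random.rand(50, 4)
y = np.random.randint(0, 2, size=50)

est = FeatureImportanceForestClassifier(
    estimator=RandomForestClassifier(n_estimators=10, random_state=0),
    random_state=0,
)
est.statistic(X, y, metric="mi")                       # full-feature statistic
est.reset()                                            # clear fitted state for reuse
est.statistic(X, y, covariate_index=[1], metric="mi")  # permute only column 1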

sktree/tests/test_honest_forest.py

Lines changed: 120 additions & 2 deletions
@@ -1,13 +1,17 @@
 import numpy as np
 import pytest
-from numpy.testing import assert_array_almost_equal
+from numpy.testing import assert_allclose, assert_array_almost_equal
 from sklearn import datasets
-from sklearn.metrics import accuracy_score, r2_score
+from sklearn.metrics import accuracy_score, r2_score, roc_auc_score
+from sklearn.model_selection import cross_val_score
+from sklearn.tree import DecisionTreeClassifier as skDecisionTreeClassifier
 from sklearn.utils import check_random_state
 from sklearn.utils.estimator_checks import parametrize_with_checks

 from sktree._lib.sklearn.tree import DecisionTreeClassifier
+from sktree.datasets import make_quadratic_classification
 from sktree.ensemble import HonestForestClassifier
+from sktree.stats.utils import _mutual_information
 from sktree.tree import ObliqueDecisionTreeClassifier, PatchObliqueDecisionTreeClassifier

 CLF_CRITERIONS = ("gini", "entropy")
@@ -252,3 +256,117 @@ def test_importances(dtype, criterion):
     est.fit(X, y, sample_weight=scale * sample_weight)
     importances_bis = est.feature_importances_
     assert np.abs(importances - importances_bis).mean() < tolerance
+
+
+def test_honest_forest_with_sklearn_trees():
+    """Test against regression in power-curves discussed in:
+    https://github.com/neurodata/scikit-tree/pull/157."""
+
+    # generate the high-dimensional quadratic data
+    X, y = make_quadratic_classification(1024, 4096, noise=True, seed=0)
+    y = y.squeeze()
+    print(X.shape, y.shape)
+    print(np.sum(y) / len(y))
+
+    clf = HonestForestClassifier(
+        n_estimators=10, tree_estimator=skDecisionTreeClassifier(), random_state=0
+    )
+    honestsk_scores = cross_val_score(clf, X, y, cv=5)
+    print(honestsk_scores)
+
+    clf = HonestForestClassifier(
+        n_estimators=10, tree_estimator=DecisionTreeClassifier(), random_state=0
+    )
+    honest_scores = cross_val_score(clf, X, y, cv=5)
+    print(honest_scores)
+
+    # XXX: surprisingly, when we use the default which uses the fork DecisionTree,
+    # we get different results
+    # clf = HonestForestClassifier(n_estimators=10, random_state=0)
+    # honest_scores = cross_val_score(clf, X, y, cv=5)
+    # print(honest_scores)
+
+    print(honestsk_scores, honest_scores)
+    print(np.mean(honestsk_scores), np.mean(honest_scores))
+    assert_allclose(np.mean(honestsk_scores), np.mean(honest_scores))
+
+
+def test_honest_forest_with_sklearn_trees_with_auc():
+    """Test against regression in power-curves discussed in:
+    https://github.com/neurodata/scikit-tree/pull/157.
+
+    This unit-test tests the equivalent of the AUC using sklearn's DTC
+    vs our forked version of sklearn's DTC as the base tree.
+    """
+    skForest = HonestForestClassifier(
+        n_estimators=10, tree_estimator=skDecisionTreeClassifier(), random_state=0
+    )
+
+    Forest = HonestForestClassifier(
+        n_estimators=10, tree_estimator=DecisionTreeClassifier(), random_state=0
+    )
+
+    max_fpr = 0.1
+    scores = []
+    sk_scores = []
+    for idx in range(10):
+        X, y = make_quadratic_classification(1024, 4096, noise=True, seed=idx)
+        y = y.squeeze()
+
+        skForest.fit(X, y)
+        Forest.fit(X, y)
+
+        # compute AUC
+        y_pred_proba = skForest.predict_proba(X)[:, 1].reshape(-1, 1)
+        sk_mi = roc_auc_score(y, y_pred_proba, max_fpr=max_fpr)
+
+        y_pred_proba = Forest.predict_proba(X)[:, 1].reshape(-1, 1)
+        mi = roc_auc_score(y, y_pred_proba, max_fpr=max_fpr)
+
+        scores.append(mi)
+        sk_scores.append(sk_mi)
+
+    print(scores, sk_scores)
+    print(np.mean(scores), np.mean(sk_scores))
+    print(np.std(scores), np.std(sk_scores))
+    assert_allclose(np.mean(sk_scores), np.mean(scores), atol=0.005)
+
+
+def test_honest_forest_with_sklearn_trees_with_mi():
+    """Test against regression in power-curves discussed in:
+    https://github.com/neurodata/scikit-tree/pull/157.
+
+    This unit-test tests the equivalent of the MI using sklearn's DTC
+    vs our forked version of sklearn's DTC as the base tree.
+    """
+    skForest = HonestForestClassifier(
+        n_estimators=10, tree_estimator=skDecisionTreeClassifier(), random_state=0
+    )
+
+    Forest = HonestForestClassifier(
+        n_estimators=10, tree_estimator=DecisionTreeClassifier(), random_state=0
+    )
+
+    scores = []
+    sk_scores = []
+    for idx in range(10):
+        X, y = make_quadratic_classification(1024, 4096, noise=True, seed=idx)
+        y = y.squeeze()
+
+        skForest.fit(X, y)
+        Forest.fit(X, y)
+
+        # compute MI
+        sk_posterior = skForest.predict_proba(X)
+        sk_score = _mutual_information(y, sk_posterior)
+
+        posterior = Forest.predict_proba(X)
+        score = _mutual_information(y, posterior)
+
+        scores.append(score)
+        sk_scores.append(sk_score)
+
+    print(scores, sk_scores)
+    print(np.mean(scores), np.mean(sk_scores))
+    print(np.std(scores), np.std(sk_scores))
+    assert_allclose(np.mean(sk_scores), np.mean(scores), atol=0.005)
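These tests score each forest by the mutual information between labels and predicted posteriors via `_mutual_information`. A hedged sketch of a standard plug-in estimator of that quantity, I(X; Y) ≈ H(Y) − mean_i H(Y | posterior_i); the repo's implementation in `sktree.stats.utils` may differ in log base or clipping:

import numpy as np

def mutual_information_from_posteriors(y, posteriors, eps=1e-12):
    """Plug-in MI estimate: marginal entropy minus mean conditional entropy."""
    _, counts = np.unique(y, return_counts=True)
    p_y = counts / counts.sum()
    H_y = -np.sum(p_y * np.log(p_y + eps))          # marginal entropy H(Y)
    H_y_given_x = np.mean(
        -np.sum(posteriors * np.log(posteriors + eps), axis=1)
    )                                               # mean conditional entropy
    return H_y - H_y_given_x

y = np.array([0, 1, 0, 1])
posteriors = np.array([[0.9, 0.1], [0.2, 0.8], [0.8, 0.2], [0.1, 0.9]])
print(mutual_information_from_posteriors(y, posteriors))  # ~0.28 nats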
