
Commit 7fc517e

[ENH] Efficient out-of-bag computation per honest tree (neurodata#200)
* Add oob_samples_ property to honest forest
* Update submodule
* Update return types to be in line with scikit-learn

Signed-off-by: Adam Li <[email protected]>
1 parent 4703a82 commit 7fc517e

35 files changed: +492, -157 lines

.github/workflows/pr_checks.yml

Lines changed: 3 additions & 3 deletions
```diff
@@ -38,14 +38,14 @@ jobs:
           exit 0
         fi
         all_changelogs=$(cat ./doc/whats_new/v*.rst)
-        if [[ "$all_changelogs" =~ :pr:\`$PR_NUMBER\` ]]
+        if [[ "$all_changelogs" =~ :pr:\`#$PR_NUMBER\` ]]
         then
           echo "Changelog has been updated."
           # If the pull request is milestoned check the correspondent changelog
           if exist -f ./doc/whats_new/v${TAGGED_MILESTONE:0:4}.rst
           then
             expected_changelog=$(cat ./doc/whats_new/v${TAGGED_MILESTONE:0:4}.rst)
-            if [[ "$expected_changelog" =~ :pr:\`$PR_NUMBER\` ]]
+            if [[ "$expected_changelog" =~ :pr:\`#$PR_NUMBER\` ]]
             then
               echo "Changelog and milestone correspond."
             else
@@ -56,7 +56,7 @@ jobs:
           fi
         fi
       else
-        echo "A Changelog entry is missing."
+        echo "A Changelog entry is missing for :pr:\`#$PR_NUMBER\`"
         echo ""
         echo "Please add an entry to the changelog at 'doc/whats_new/v*.rst'"
         echo "to document your change assuming that the PR will be merged"
```

.github/workflows/style.yml

Lines changed: 2 additions & 2 deletions
```diff
@@ -29,10 +29,10 @@ jobs:
     steps:
       - name: Checkout repository
         uses: actions/checkout@v4
-      - name: Setup Python 3.10
+      - name: Setup Python 3.11
         uses: actions/setup-python@v5
         with:
-          python-version: "3.10"
+          python-version: "3.11"
           architecture: "x64"

       - name: Install packages for Ubuntu
```

.pre-commit-config.yaml

Lines changed: 7 additions & 1 deletion
```diff
@@ -5,6 +5,12 @@ repos:
       - id: black
         args: [--quiet]

+  - repo: https://github.com/MarcoGorelli/cython-lint
+    rev: v0.16.0
+    hooks:
+      - id: cython-lint
+      - id: double-quote-cython-strings
+
   # Ruff sktree
   - repo: https://github.com/astral-sh/ruff-pre-commit
     rev: v0.1.6
@@ -65,7 +71,7 @@ repos:
       - id: rstcheck
         additional_dependencies:
           - tomli
-        files: ^doc/.*\.(rst|inc)$
+        files: ^(?!doc/use\.rst$).*\.(rst|inc)$

 ci:
   autofix_prs: false
```

doc/conf.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -246,6 +246,7 @@
     "TreeBuilder",
     "joint_rank",
     "n_dim",
+    "n_samples_bootstrap",
 }

 # validation
```

doc/use.rst

Lines changed: 1 addition & 1 deletion
```diff
@@ -8,4 +8,4 @@ to learn everything you need!

 .. rstcheck: ignore-next-code-block
 .. include:: auto_examples/index.rst
-   :start-after: :orphan:
+   :start-after: :orphan:
```

The removed and added lines differ only in whitespace.

doc/whats_new/v0.6.rst

Lines changed: 3 additions & 0 deletions
```diff
@@ -29,6 +29,9 @@ Changelog
   has a generative model based on Trunk and banded covariance, :func:`sktree.datasets.approximate_clf_mutual_information` and
   :func:`sktree.datasets.approximate_clf_mutual_information_with_monte_carlo` to
   approximate mutual information either numerically or via Monte-Carlo, by `Adam Li`_ and `Haoyin Xu`_ (:pr:`#199`).
+- |Enhancement| :class:`sktree.HonestForestClassifier` now has a fitted
+  property ``oob_samples_``, which reproduces the sample indices per tree that are out
+  of bag, by `Adam Li`_ (:pr:`#200`).


 Code and Documentation Contributors
```
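A minimal sketch of how the new property might be used, assuming the :class:`sktree.HonestForestClassifier` API described in the changelog; the data, the ``bootstrap=True`` argument, and the per-tree index-array return shape are illustrative assumptions, not taken from the commit:

```python
# Minimal sketch; data and parameters are illustrative assumptions.
import numpy as np
from sktree import HonestForestClassifier

rng = np.random.default_rng(0)
X = rng.standard_normal((100, 4))
y = rng.integers(0, 2, size=100)

# bootstrap=True is assumed here so that each tree has out-of-bag samples.
est = HonestForestClassifier(n_estimators=10, bootstrap=True, random_state=0)
est.fit(X, y)

# Assuming oob_samples_ is an iterable of per-tree index arrays, as the
# changelog entry describes ("sample indices per tree that are out of bag").
for tree_idx, oob_idx in enumerate(est.oob_samples_):
    print(f"tree {tree_idx}: {len(oob_idx)} out-of-bag samples")
```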

examples/calibration/plot_overlapping_gaussians.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -65,7 +65,7 @@
     (
         "IRF",
         CalibratedClassifierCV(
-            base_estimator=RandomForestClassifier(
+            estimator=RandomForestClassifier(
                 n_estimators=n_estimators // clf_cv,
                 max_features=max_features,
                 n_jobs=n_jobs,
@@ -77,7 +77,7 @@
     (
         "SigRF",
         CalibratedClassifierCV(
-            base_estimator=RandomForestClassifier(
+            estimator=RandomForestClassifier(
                 n_estimators=n_estimators // clf_cv,
                 max_features=max_features,
                 n_jobs=n_jobs,
```
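Context for this rename: scikit-learn deprecated the ``base_estimator`` keyword of ``CalibratedClassifierCV`` in favor of ``estimator`` (deprecation began in the 1.2 series), so the swap keeps the example compatible with recent releases. A minimal sketch with an illustrative dataset and parameters, not taken from the example script:

```python
# Minimal sketch of the renamed keyword; dataset and parameters are illustrative.
from sklearn.calibration import CalibratedClassifierCV
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

X, y = make_classification(n_samples=200, random_state=0)

# Recent scikit-learn releases accept `estimator`; `base_estimator` is the
# deprecated spelling this commit moves away from.
clf = CalibratedClassifierCV(
    estimator=RandomForestClassifier(n_estimators=50, random_state=0),
    method="sigmoid",
    cv=3,
)
clf.fit(X, y)
print(clf.predict_proba(X[:5]))
```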

pyproject.toml

Lines changed: 0 additions & 1 deletion
```diff
@@ -293,7 +293,6 @@ ignore_roles = [
 ]
 report_level = "WARNING"
 ignore = ["SEVERE/4"]
-paths = ["doc/use.rst"]

 [tool.ruff]
 extend-exclude = [
```

sktree/_lib/sklearn_fork

Submodule sklearn_fork updated 81 files

sktree/datasets/hyppo.py

Lines changed: 37 additions & 12 deletions
```diff
@@ -50,6 +50,7 @@ def make_quadratic_classification(n_samples: int, n_features: int, noise=False,
 def make_trunk_classification(
     n_samples,
     n_dim=10,
+    n_informative=10,
     m_factor: int = -1,
     rho: int = 0,
     band_type: str = "ma",
@@ -76,6 +77,9 @@ def make_trunk_classification(
     n_dim : int, optional
         The dimensionality of the dataset and the number of
         unique labels, by default 10.
+    n_informative : int, optional
+        The number of informative dimensions. The remaining
+        ``n_dim - n_informative`` dimensions are uniform noise.
     m_factor : int, optional
         The multiplicative factor to apply to the mean-vector of the first
         distribution to obtain the mean-vector of the second distribution.
@@ -108,25 +112,30 @@ def make_trunk_classification(
     ----------
     .. footbibliography::
     """
+    if n_dim < n_informative:
+        raise ValueError(
+            f"Number of informative dimensions {n_informative} must be less than number "
+            f"of dimensions, {n_dim}"
+        )
     rng = np.random.default_rng(seed=seed)

-    mu_1 = np.array([1 / np.sqrt(i) for i in range(1, n_dim + 1)])
+    mu_1 = np.array([1 / np.sqrt(i) for i in range(1, n_informative + 1)])
     mu_0 = m_factor * mu_1

     if rho != 0:
         if band_type == "ma":
-            cov = _moving_avg_cov(n_dim, rho)
+            cov = _moving_avg_cov(n_informative, rho)
         elif band_type == "ar":
-            cov = _autoregressive_cov(n_dim, rho)
+            cov = _autoregressive_cov(n_informative, rho)
         else:
             raise ValueError(f'Band type {band_type} must be one of "ma", or "ar".')
     else:
-        cov = np.identity(n_dim)
+        cov = np.identity(n_informative)

     if mix < 0 or mix > 1:
         raise ValueError("Mix must be between 0 and 1.")

-    if n_dim > 1000:
+    if n_informative > 1000:
         method = "cholesky"
     else:
         method = "svd"
@@ -139,13 +148,29 @@ def make_trunk_classification(
             )
         )
     else:
+        mixture_idx = rng.choice(
+            [0, 1], n_samples // 2, replace=True, shuffle=True, p=[mix, 1 - mix]
+        )
+        X_mixture = np.zeros((n_samples // 2, len(mu_1)))
+        for idx in range(n_samples // 2):
+            if mixture_idx[idx] == 1:
+                X_sample = rng.multivariate_normal(mu_1, cov, 1, method=method)
+            else:
+                X_sample = rng.multivariate_normal(mu_0, cov, 1, method=method)
+            X_mixture[idx, :] = X_sample
+
         X = np.vstack(
             (
-                rng.multivariate_normal(np.zeros(n_dim), cov, n_samples // 2, method=method),
-                (1 - mix) * rng.multivariate_normal(mu_1, cov, n_samples // 2, method=method)
-                + mix * rng.multivariate_normal(mu_0, cov, n_samples // 2, method=method),
+                rng.multivariate_normal(
+                    np.zeros(n_informative), cov, n_samples // 2, method=method
+                ),
+                X_mixture,
             )
         )
+
+    if n_dim > n_informative:
+        X = np.hstack((X, rng.uniform(low=0, high=1, size=(n_samples, n_dim - n_informative))))
+
     y = np.concatenate((np.zeros(n_samples // 2), np.ones(n_samples // 2)))

     if return_params:
@@ -208,19 +233,19 @@ def approximate_clf_mutual_information(
     # this implicitly assumes that the signal of interest is between -10 and 10
     scale = 10
     n_dims = [cov.shape[1] for cov in covs]
-    lims = [[-scale, scale]] * n_dims
+    lims = [[-scale, scale]] * max(n_dims)

     # Compute entropy and X and Y.
     def func(*args):
         x = np.array(args)
         p = 0
         for k in range(len(means)):
-            p += class_probs[k] * multivariate_normal(seed=seed).pdf(x, means[k], covs[k])
+            p += class_probs[k] * multivariate_normal.pdf(x, means[k], covs[k])
         return -p * np.log(p) / np.log(base)

     # numerically integrate H(X)
-    opts = dict(limit=1000)
-    H_X, int_err = nquad(func, lims)
+    # opts = dict(limit=1000)
+    H_X, int_err = nquad(func, lims)

     # Compute MI.
     H_XY = 0
```
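A minimal usage sketch of the extended generator, assuming the signature shown in the diff above and that the default return is ``(X, y)`` when ``return_params`` is left unset; the parameter values are illustrative:

```python
# Minimal sketch, assuming make_trunk_classification returns (X, y) by default;
# parameter values are illustrative.
from sktree.datasets import make_trunk_classification

# 6 total dimensions, 2 of which carry the Trunk signal; the remaining
# 4 columns are filled with uniform noise by the new n_informative logic.
X, y = make_trunk_classification(n_samples=100, n_dim=6, n_informative=2, seed=1234)
print(X.shape)  # expected: (100, 6)
print(y.shape)  # expected: (100,)
```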
