Skip to content

Commit eb91bd8

Browse files
authored
Merge pull request #59 from mmschlk/tree-explainer-random-forrest
Bugfix and Support for Random Forests in TreeExplainer
2 parents 4cb0b32 + acb9bae commit eb91bd8

File tree

14 files changed

+348
-342
lines changed

14 files changed

+348
-342
lines changed

shapiq/explainer/tree/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ def reduce_feature_complexity(self) -> None:
109109
Feature '8' is 'renamed' to '2' such that in the internal representation a one-hot vector
110110
(and matrices) of length 3 suffices to represent the feature indices.
111111
"""
112-
if self.n_features_in_tree < self.max_feature_id:
112+
if self.n_features_in_tree < self.max_feature_id + 1:
113113
new_feature_ids = set(range(self.n_features_in_tree))
114114
mapping_old_new = {old_id: new_id for new_id, old_id in enumerate(self.feature_ids)}
115115
mapping_new_old = {new_id: old_id for new_id, old_id in enumerate(self.feature_ids)}

shapiq/explainer/tree/conversion/sklearn.py

Lines changed: 14 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,19 @@
11
"""This module contains functions for converting scikit-learn decision trees to the format used by
22
shapiq."""
33

4-
from typing import Optional, Union
4+
from typing import Optional
55

66
import numpy as np
77
from explainer.tree.base import TreeModel
88

99
from shapiq.utils import safe_isinstance
10-
11-
try:
12-
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
13-
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
14-
except ImportError:
15-
pass
10+
from shapiq.utils.types import Model
1611

1712

1813
def convert_sklearn_forest(
19-
tree_model: Union["RandomForestRegressor", "RandomForestClassifier"],
20-
class_label: int = 0,
21-
output_type: Optional[str] = None,
14+
tree_model: Model,
15+
class_label: Optional[int] = None,
16+
output_type: str = "raw",
2217
) -> list[TreeModel]:
2318
"""Transforms a scikit-learn random forest to the format used by shapiq.
2419
@@ -33,8 +28,6 @@ def convert_sklearn_forest(
3328
The converted random forest model.
3429
"""
3530
scaling = 1.0 / len(tree_model.estimators_)
36-
if not safe_isinstance(tree_model, "sklearn.ensemble.RandomForestClassifier"):
37-
output_type = None
3831
return [
3932
convert_sklearn_tree(
4033
tree, scaling=scaling, class_label=class_label, output_type=output_type
@@ -44,8 +37,8 @@ def convert_sklearn_forest(
4437

4538

4639
def convert_sklearn_tree(
47-
tree_model: Union["DecisionTreeRegressor", "DecisionTreeClassifier"],
48-
class_label: int = 0,
40+
tree_model: Model,
41+
class_label: Optional[int] = None,
4942
scaling: float = 1.0,
5043
output_type: str = "raw",
5144
) -> TreeModel:
@@ -63,14 +56,16 @@ def convert_sklearn_tree(
6356
The converted decision tree model.
6457
"""
6558
tree_values = tree_model.tree_.value.copy() * scaling
59+
# set class label if not given and model is a classifier
60+
if safe_isinstance(tree_model, "sklearn.tree.DecisionTreeClassifier") and class_label is None:
61+
class_label = 1
62+
6663
if class_label is not None:
6764
# turn node values into probabilities
6865
if len(tree_values.shape) == 3:
69-
tree_values = tree_values / np.sum(tree_values, axis=2, keepdims=True)
70-
tree_values = tree_values[:, 0, class_label]
71-
else:
72-
tree_values = tree_values / np.sum(tree_values, axis=1, keepdims=True)
73-
tree_values = tree_values[:, class_label]
66+
tree_values = tree_values[:, 0, :]
67+
tree_values = tree_values / np.sum(tree_values, axis=1, keepdims=True)
68+
tree_values = tree_values[:, class_label]
7469
if output_type != "raw":
7570
# TODO: Add support for logits output type
7671
raise NotImplementedError("Only raw output types are currently supported.")

shapiq/explainer/tree/conversion/xgboost.py

Lines changed: 0 additions & 201 deletions
This file was deleted.

shapiq/explainer/tree/explainer.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""This module contains the TreeExplainer class making use of the TreeSHAPIQ algorithm for
22
computing any-order Shapley Interactions for tree ensembles."""
33
import copy
4-
from typing import Any, Union
4+
from typing import Any, Optional, Union
55

66
import numpy as np
77
from explainer._base import Explainer
@@ -17,17 +17,20 @@ def __init__(
1717
model: Union[dict, TreeModel, Any],
1818
max_order: int = 2,
1919
min_order: int = 1,
20+
class_label: Optional[int] = None,
21+
output_type: str = "raw",
2022
) -> None:
2123
# validate and parse model
22-
validated_model = _validate_model(model) # the parsed and validated model
23-
24+
validated_model = _validate_model(model, class_label=class_label, output_type=output_type)
2425
self._trees: Union[TreeModel, list[TreeModel]] = copy.deepcopy(validated_model)
2526
if not isinstance(self._trees, list):
2627
self._trees = [self._trees]
2728
self._n_trees = len(self._trees)
2829

29-
self._max_order = max_order
30-
self._min_order = min_order
30+
self._max_order: int = max_order
31+
self._min_order: int = min_order
32+
self._class_label: Optional[int] = class_label
33+
self._output_type: str = output_type
3134

3235
# setup explainers for all trees
3336
self._treeshapiq_explainers: list[TreeSHAPIQ] = [

shapiq/explainer/tree/treeshapiq.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""This module contains the tree explainer implementation."""
22
import copy
33
from math import factorial
4+
from typing import Any, Optional, Union
45

56
import numpy as np
67
from approximator import transforms_sii_to_ksii
@@ -30,7 +31,7 @@ class TreeSHAPIQ:
3031

3132
def __init__(
3233
self,
33-
model: TreeModel,
34+
model: Union[dict, TreeModel, Any],
3435
max_order: int = 2,
3536
min_order: int = 1,
3637
interaction_type: str = "k-SII",
@@ -507,24 +508,22 @@ def _get_N_cii(self, interpolated_poly, order) -> np.ndarray[float]:
507508
)
508509
return Ns
509510

510-
def _get_subset_weight_cii(self, t, order) -> float:
511+
def _get_subset_weight_cii(self, t, order) -> Optional[float]:
511512
# TODO: add docstring
512513
if self._interaction_type == "STI":
513514
return self._max_order / (
514515
self._n_features_in_tree * binom(self._n_features_in_tree - 1, t)
515516
)
516-
elif self._interaction_type == "FSI":
517+
if self._interaction_type == "FSI":
517518
return (
518519
factorial(2 * self._max_order - 1)
519520
/ factorial(self._max_order - 1) ** 2
520521
* factorial(self._max_order + t - 1)
521522
* factorial(self._n_features_in_tree - t - 1)
522523
/ factorial(self._n_features_in_tree + self._max_order - 1)
523524
)
524-
elif self._interaction_type == "BZF":
525+
if self._interaction_type == "BZF":
525526
return 1 / (2 ** (self._n_features_in_tree - order))
526-
else:
527-
raise ValueError("Interaction type not supported")
528527

529528
@staticmethod
530529
def _get_N_id(D) -> np.ndarray[float]:

0 commit comments

Comments
 (0)