Skip to content

Commit 5478906

Browse files
authored
Merge pull request #51 from mmschlk/26-add-treeshap-iq-explainer
adds TreeExplainer with TreeSHAP-IQ
2 parents 596acd3 + e5f98bd commit 5478906

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

51 files changed

+2357
-418
lines changed

README.md

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -74,12 +74,14 @@ Explain your models with Shapley interaction values like the k-SII values:
7474
```python
7575
# train a model
7676
from sklearn.ensemble import RandomForestRegressor
77+
7778
model = RandomForestRegressor(n_estimators=50, random_state=42)
7879
model.fit(x_train, y_train)
7980

8081
# explain with k-SII interaction scores
81-
from shapiq import InteractionExplainer
82-
explainer = InteractionExplainer(
82+
from shapiq import TabularExplainer
83+
84+
explainer = TabularExplainer(
8385
model=model.predict,
8486
background_data=x_train,
8587
index="k-SII",
@@ -88,19 +90,19 @@ explainer = InteractionExplainer(
8890
interaction_values = explainer.explain(x_explain, budget=2000)
8991
print(interaction_values)
9092

91-
>>> InteractionValues(
92-
>>> index=k-SII, max_order=2, min_order=1, estimated=True, estimation_budget=2000,
93-
>>> values={
94-
>>> (0,): -91.0403, # main effect for feature 0
95-
>>> (1,): 4.1264, # main effect for feature 1
96-
>>> (2,): -0.4724, # main effect for feature 2
97-
>>> ...
98-
>>> (0, 1): -0.8073, # 2-way interaction for feature 0 and 1
99-
>>> (0, 2): 2.469, # 2-way interaction for feature 0 and 2
100-
>>> ...
101-
>>> (10, 11): 0.4057 # 2-way interaction for feature 10 and 11
102-
>>> }
103-
>>> )
93+
>> > InteractionValues(
94+
>> > index = k - SII, max_order = 2, min_order = 1, estimated = True, estimation_budget = 2000,
95+
>> > values = {
96+
>> > (0,): -91.0403, # main effect for feature 0
97+
>> > (1,): 4.1264, # main effect for feature 1
98+
>> > (2,): -0.4724, # main effect for feature 2
99+
>> > ...
100+
>> > (0, 1): -0.8073, # 2-way interaction for feature 0 and 1
101+
>> > (0, 2): 2.469, # 2-way interaction for feature 0 and 2
102+
>> > ...
103+
>> > (10, 11): 0.4057 # 2-way interaction for feature 10 and 11
104+
>> >}
105+
>> > )
104106
```
105107

106108
## 📊 Visualize your Interactions

shapiq/__init__.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,19 +15,18 @@
1515
from .datasets import load_bike
1616

1717
# explainer classes
18-
from .explainer import InteractionExplainer
18+
from .explainer import TabularExplainer, TreeExplainer
1919

2020
# game classes
2121
from .games import DummyGame
22+
from .interaction_values import InteractionValues
2223

2324
# plotting functions
2425
from .plot import network_plot, stacked_bar_plot
2526

2627
# public utils functions
2728
from .utils import ( # sets.py # tree.py
28-
get_conditional_sample_weights,
2929
get_explicit_subsets,
30-
get_parent_array,
3130
powerset,
3231
safe_isinstance,
3332
split_subsets_budget,
@@ -36,14 +35,17 @@
3635
__all__ = [
3736
# version
3837
"__version__",
38+
# base
39+
"InteractionValues",
3940
# approximators
4041
"ShapIQ",
4142
"PermutationSamplingSII",
4243
"PermutationSamplingSTI",
4344
"RegressionSII",
4445
"RegressionFSI",
4546
# explainers
46-
"InteractionExplainer",
47+
"TabularExplainer",
48+
"TreeExplainer",
4749
# games
4850
"DummyGame",
4951
# plots
@@ -53,8 +55,6 @@
5355
"powerset",
5456
"get_explicit_subsets",
5557
"split_subsets_budget",
56-
"get_conditional_sample_weights",
57-
"get_parent_array",
5858
"safe_isinstance",
5959
# datasets
6060
"load_bike",

shapiq/approximator/_base.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,11 @@
44
from typing import Callable, Optional
55

66
import numpy as np
7-
from approximator._config import AVAILABLE_INDICES
8-
from approximator._interaction_values import InteractionValues
9-
from approximator._utils import _generate_interaction_lookup
7+
from interaction_values import InteractionValues
8+
9+
from shapiq.utils.sets import generate_interaction_lookup
10+
11+
from ._config import AVAILABLE_INDICES
1012

1113
__all__ = [
1214
"Approximator",
@@ -65,7 +67,7 @@ def __init__(
6567
self.max_order: int = max_order
6668
self.min_order: int = self.max_order if self.top_order else 1
6769
self.iteration_cost: int = 1 # default value, can be overwritten by subclasses
68-
self._interaction_lookup = _generate_interaction_lookup(
70+
self._interaction_lookup = generate_interaction_lookup(
6971
self.n, self.min_order, self.max_order
7072
)
7173
self._random_state: Optional[int] = random_state

shapiq/approximator/_interaction_values.py

Lines changed: 0 additions & 145 deletions
This file was deleted.

shapiq/approximator/_utils.py

Lines changed: 0 additions & 21 deletions
This file was deleted.

shapiq/approximator/k_sii.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,11 @@
33
from typing import Optional, Union
44

55
import numpy as np
6-
from approximator._base import Approximator
7-
from approximator._interaction_values import InteractionValues
8-
from approximator._utils import _generate_interaction_lookup
6+
from interaction_values import InteractionValues
97
from scipy.special import bernoulli
108

11-
from shapiq.utils import powerset
9+
from shapiq.approximator._base import Approximator
10+
from shapiq.utils import generate_interaction_lookup, powerset
1211

1312

1413
class KShapleyMixin:
@@ -86,7 +85,7 @@ def transforms_sii_to_ksii(
8685
)
8786
elif n is not None and max_order is not None:
8887
if interaction_lookup is None:
89-
interaction_lookup = _generate_interaction_lookup(n, 1, max_order)
88+
interaction_lookup = generate_interaction_lookup(n, 1, max_order)
9089
return _calculate_ksii_from_sii(sii_values, n, max_order, interaction_lookup)
9190
else:
9291
raise ValueError(
@@ -120,9 +119,12 @@ def _calculate_ksii_from_sii(
120119
nsii_values = np.zeros_like(sii_values)
121120
# all subsets S with 1 <= |S| <= max_order
122121
for subset in powerset(set(range(n)), min_size=1, max_size=max_order):
123-
interaction_index = interaction_lookup[subset]
124122
interaction_size = len(subset)
125-
ksii_value = sii_values[interaction_index]
123+
try:
124+
interaction_index = interaction_lookup[subset]
125+
ksii_value = sii_values[interaction_index]
126+
except KeyError:
127+
continue # a zero value is not scaled # TODO: verify this
126128
# go over all subsets T of length |S| + 1, ..., n that contain S
127129
for T in powerset(set(range(n)), min_size=interaction_size + 1, max_size=max_order):
128130
if set(subset).issubset(T):

shapiq/approximator/permutation/sii.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44

55
import numpy as np
66
from approximator._base import Approximator
7-
from approximator._interaction_values import InteractionValues
87
from approximator.k_sii import KShapleyMixin
8+
from interaction_values import InteractionValues
99
from utils import powerset
1010

1111

shapiq/approximator/permutation/sti.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import numpy as np
77
from approximator._base import Approximator
8-
from approximator._interaction_values import InteractionValues
8+
from interaction_values import InteractionValues
99
from scipy.special import binom
1010
from utils import get_explicit_subsets, powerset
1111

shapiq/approximator/regression/_base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44

55
import numpy as np
66
from approximator._base import Approximator
7-
from approximator._interaction_values import InteractionValues
87
from approximator.sampling import ShapleySamplingMixin
8+
from interaction_values import InteractionValues
99
from scipy.special import bernoulli, binom
1010
from utils import powerset
1111

shapiq/approximator/shapiq/shapiq.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@
55

66
import numpy as np
77
from approximator._base import Approximator
8-
from approximator._interaction_values import InteractionValues
98
from approximator.k_sii import KShapleyMixin
109
from approximator.sampling import ShapleySamplingMixin
10+
from interaction_values import InteractionValues
1111
from utils import powerset
1212

1313
AVAILABLE_INDICES_SHAPIQ = {"SII", "STI", "FSI", "k-SII"}

0 commit comments

Comments
 (0)