Skip to content

Commit 4494843

Browse files
committed
adds interaction_type as parameter to TreeExplainer fixes #64 and closes #64
1 parent 87fa1ca commit 4494843

File tree

3 files changed

+15
-4
lines changed

3 files changed

+15
-4
lines changed

shapiq/approximator/k_sii.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -112,8 +112,11 @@ def _calculate_ksii_from_sii(
112112
`min_order`, and `max_order` parameters. Defaults to `None`.
113113
114114
Returns:
115-
The nSII values.
115+
The k-SII values.
116116
"""
117+
if interaction_lookup is None:
118+
interaction_lookup = generate_interaction_lookup(n, 1, max_order)
119+
117120
# compute nSII values from SII values
118121
bernoulli_numbers = bernoulli(max_order)
119122
nsii_values = np.zeros_like(sii_values)
@@ -128,8 +131,11 @@ def _calculate_ksii_from_sii(
128131
# go over all subsets T of length |S| + 1, ..., n that contain S
129132
for T in powerset(set(range(n)), min_size=interaction_size + 1, max_size=max_order):
130133
if set(subset).issubset(T):
131-
effect_index = interaction_lookup[T] # get the index of T
132-
effect_value = sii_values[effect_index] # get the effect of T
134+
try:
135+
effect_index = interaction_lookup[T] # get the index of T
136+
effect_value = sii_values[effect_index] # get the effect of T
137+
except KeyError:
138+
effect_value = 0 # if T is not in the interaction_lookup # TODO: verify this
133139
bernoulli_factor = bernoulli_numbers[len(T) - interaction_size]
134140
ksii_value += bernoulli_factor * effect_value
135141
nsii_values[interaction_index] = ksii_value

shapiq/explainer/tree/explainer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ def __init__(
1717
model: Union[dict, TreeModel, Any],
1818
max_order: int = 2,
1919
min_order: int = 1,
20+
interaction_type: str = "k-SII",
2021
class_label: Optional[int] = None,
2122
output_type: str = "raw",
2223
) -> None:
@@ -34,7 +35,7 @@ def __init__(
3435

3536
# setup explainers for all trees
3637
self._treeshapiq_explainers: list[TreeSHAPIQ] = [
37-
TreeSHAPIQ(model=_tree, max_order=self._max_order, interaction_type="SII")
38+
TreeSHAPIQ(model=_tree, max_order=self._max_order, interaction_type=interaction_type)
3839
for _tree in self._trees
3940
]
4041

tests/tests_approximators/test_approximator_ksii_estimation.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,3 +70,7 @@ def test_nsii_estimation(sii_approximator, ksii_approximator):
7070
assert isinstance(transformed, np.ndarray)
7171
with pytest.raises(ValueError):
7272
_ = transforms_sii_to_ksii(sii_estimates.values)
73+
# check with interaction_lookup = None
74+
sii_estimates.interaction_lookup = None
75+
transformed = transforms_sii_to_ksii(sii_estimates)
76+
assert transformed.index == "k-SII"

0 commit comments

Comments
 (0)