Skip to content

Commit 0d80cc7

Browse files
authored
add tests against shap for extra trees (mmschlk#373)
1 parent 3f84359 commit 0d80cc7

File tree

2 files changed

+92
-1
lines changed

2 files changed

+92
-1
lines changed

tests/fixtures/models.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,13 @@
44

55
import numpy as np
66
import pytest
7-
from sklearn.ensemble import IsolationForest, RandomForestClassifier, RandomForestRegressor
7+
from sklearn.ensemble import (
8+
ExtraTreesClassifier,
9+
ExtraTreesRegressor,
10+
IsolationForest,
11+
RandomForestClassifier,
12+
RandomForestRegressor,
13+
)
814
from sklearn.linear_model import LinearRegression, LogisticRegression
915
from sklearn.model_selection import train_test_split
1016
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
@@ -287,3 +293,24 @@ def if_clf_model(if_clf_dataset) -> IsolationForest:
287293
model = IsolationForest(random_state=42, n_estimators=3)
288294
model.fit(X, y)
289295
return model
296+
297+
298+
# Extra trees model
299+
@pytest.fixture
300+
def et_clf_model(background_clf_dataset) -> Model:
301+
"""Return a simple (classification) extra trees model."""
302+
303+
X, y = background_clf_dataset
304+
model = ExtraTreesClassifier(random_state=42, max_depth=3, n_estimators=3)
305+
model.fit(X, y)
306+
return model
307+
308+
309+
@pytest.fixture
310+
def et_reg_model(background_reg_dataset) -> Model:
311+
"""Return a simple (regression) extra trees model."""
312+
313+
X, y = background_reg_dataset
314+
model = ExtraTreesRegressor(random_state=42, max_depth=3, n_estimators=3)
315+
model.fit(X, y)
316+
return model

tests/tests_explainer/tests_tree_explainer/test_tree_explainer.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -466,3 +466,67 @@ def test_decision_stumps(background_reg_dataset, background_clf_dataset):
466466
continue
467467

468468
assert pred == pytest.approx(efficiency, rel=1e-5)
469+
470+
471+
def test_extra_trees_clf(et_clf_model, background_clf_data):
472+
"""Test the shapiq implementation of TreeSHAP vs. SHAP's implementation for Extra Trees."""
473+
explanation_instance = 1
474+
class_label = 1
475+
476+
# the following code is used to get the shap values from the SHAP implementation
477+
"""
478+
#import shap
479+
# model_copy = copy.deepcopy(et_clf_model)
480+
# explainer_shap = shap.TreeExplainer(model=model_copy)
481+
# baseline_shap = float(explainer_shap.expected_value[class_label])
482+
# x_explain_shap = copy.deepcopy(background_clf_data[explanation_instance].reshape(1, -1))
483+
# sv_shap_all_classes = explainer_shap.shap_values(x_explain_shap)
484+
# sv_shap = sv_shap_all_classes[0][:, class_label]
485+
# print(sv_shap_all_classes, format(baseline_shap, '.20f'))
486+
""" # noqa: ERA001
487+
sv_shap = [0.00207427, 0.00949552, -0.00108266, -0.03825587, -0.02694092, 0.0170296, 0.02046364]
488+
sv_shap = np.asarray(sv_shap)
489+
baseline_shap = 0.34000000000000002
490+
491+
# compute with shapiq
492+
explainer_shapiq = TreeExplainer(
493+
model=et_clf_model, max_order=1, index="SV", class_index=class_label
494+
)
495+
x_explain_shapiq = copy.deepcopy(background_clf_data[explanation_instance])
496+
sv_shapiq = explainer_shapiq.explain(x=x_explain_shapiq)
497+
sv_shapiq_values = sv_shapiq.get_n_order_values(1)
498+
baseline_shapiq = sv_shapiq.baseline_value
499+
500+
assert baseline_shap == pytest.approx(baseline_shapiq, rel=1e-4)
501+
assert np.allclose(sv_shap, sv_shapiq_values, rtol=1e-5)
502+
503+
504+
def test_extra_trees_reg(et_reg_model, background_reg_data):
505+
"""Test the shapiq implementation of TreeSHAP vs. SHAP's implementation for Extra Trees."""
506+
explanation_instance = 1
507+
508+
# the following code is used to get the shap values from the SHAP implementation
509+
"""
510+
# import shap
511+
# model_copy = copy.deepcopy(et_reg_model)
512+
# explainer_shap = shap.TreeExplainer(model=model_copy)
513+
# baseline_shap = float(explainer_shap.expected_value)
514+
# x_explain_shap = copy.deepcopy(background_reg_data[explanation_instance].reshape(1, -1))
515+
# sv_shap_all_classes = explainer_shap.shap_values(x_explain_shap)
516+
# sv_shap = sv_shap_all_classes[0]
517+
# print(sv_shap_all_classes, format(baseline_shap, '.20f'))
518+
""" # noqa: ERA001
519+
sv_shap = [19.28673017, -19.87182634, 0.0, 10.89201698, -9.62498263, 0.35992212, 42.31290091]
520+
sv_shap = np.asarray(sv_shap)
521+
print(sv_shap)
522+
baseline_shap = -2.56682283435175007
523+
524+
# compute with shapiq
525+
explainer_shapiq = TreeExplainer(model=et_reg_model, max_order=1, index="SV")
526+
x_explain_shapiq = copy.deepcopy(background_reg_data[explanation_instance])
527+
sv_shapiq = explainer_shapiq.explain(x=x_explain_shapiq)
528+
sv_shapiq_values = sv_shapiq.get_n_order_values(1)
529+
baseline_shapiq = sv_shapiq.baseline_value
530+
531+
assert baseline_shap == pytest.approx(baseline_shapiq, rel=1e-4)
532+
assert np.allclose(sv_shap, sv_shapiq_values, rtol=1e-5)

0 commit comments

Comments
 (0)