-
Notifications
You must be signed in to change notification settings - Fork 6
/
explainers.py
149 lines (135 loc) · 4.86 KB
/
explainers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
from emutils.preprocessing import process_data
from cfshap.attribution import (
TreeExplainer,
CompositeExplainer,
)
from cfshap.counterfactuals import (
BisectionProjectionDBCounterfactuals,
KNNCounterfactuals,
compose_diverse_counterfactual_method,
)
from cfshap.background import (
DifferentLabelBackgroundGenerator,
DifferentPredictionBackgroundGenerator,
)
from cfshap.trend import TrendEstimator
# Sample cap passed as ``max_samples`` to the background generators, the
# KNN / projection counterfactual methods and the TreeExplainers built here.
MAXSAMPLES = 10_000
# Neighbour counts k for which a ``knn{k}_qL1`` counterfactual method is built.
NN_Ks = [1, 3, 5, 10, 20, 50, 100, 250, 500, 1000]
# One TrendEstimator instance shared by every TreeExplainer created in this
# module (what the 'mean' strategy does is defined in cfshap.trend, not here).
trend_estimator = TrendEstimator('mean')
def create_counterfactual_explainers(data,
                                     y,
                                     model,
                                     feature_names,
                                     multiscaler,
                                     random_state,
                                     verbose,
                                     feature_trends=None,
                                     *,
                                     projected_ks=(100,)):
    """Build the named counterfactual / background-generation methods.

    The returned dict contains, in this insertion order:

    - ``diff_pred`` / ``diff_label``: a DifferentPredictionBackgroundGenerator /
      DifferentLabelBackgroundGenerator with ``max_samples=MAXSAMPLES``.
    - ``diff_pred_100`` / ``diff_label_100``: the same generators with
      ``max_samples=100``.
    - ``knn{k}_qL1`` for every ``k`` in ``NN_Ks``: a KNNCounterfactuals using
      the cityblock distance over ``multiscaler.get_transformer('quantile')``.
    - ``cone-knn{k}_qL1`` for every ``k`` in both ``NN_Ks`` and
      ``projected_ks``: the corresponding KNN method composed with a
      BisectionProjectionDBCounterfactuals via
      ``compose_diverse_counterfactual_method``. The wrapped KNN object is the
      very same instance stored under ``knn{k}_qL1``.

    Args:
        data: Background data handed to every generator / counterfactual
            method (presumably the NumPy array produced by ``process_data`` in
            ``create_explainers`` -- confirm at other call sites).
        y: Labels; only used by the DifferentLabelBackgroundGenerators.
        model: Model passed to every method.
        feature_names: Unused here; kept so existing call sites still work.
        multiscaler: Object exposing ``get_transformer('quantile')``, whose
            result is used as the KNN scaler.
        random_state: Forwarded to the background generators and to the
            decision-boundary projection.
        verbose: Forwarded to ``compose_diverse_counterfactual_method``.
        feature_trends: Unused here; kept so existing call sites still work.
        projected_ks: Neighbour counts whose KNN method is additionally
            projected onto the decision boundary. Defaults to ``(100,)``, the
            previously hard-coded choice. ``None`` or empty disables projection.

    Returns:
        dict mapping method name to method object.
    """
    projected = set(projected_ks or ())

    cf_methods = dict(
        diff_pred=DifferentPredictionBackgroundGenerator(
            model,
            data,
            random_state=random_state,
            max_samples=MAXSAMPLES,
        ),
        diff_label=DifferentLabelBackgroundGenerator(
            model,
            data,
            y,
            random_state=random_state,
            max_samples=MAXSAMPLES,
        ),
        # Default SHAP (only 100 samples)
        diff_pred_100=DifferentPredictionBackgroundGenerator(
            model,
            data,
            random_state=random_state,
            max_samples=100,
        ),
        diff_label_100=DifferentLabelBackgroundGenerator(
            model,
            data,
            y,
            random_state=random_state,
            max_samples=100,
        ),
    )

    # -------------- KNN ------------------------------------------------
    # Keyed by the integer neighbour count so the projection step below can
    # select methods directly instead of parsing 'knn<k>_qL1' names (parsing
    # raised ValueError for any 'knn' key not following that pattern).
    knn_methods = {
        k: KNNCounterfactuals(
            model=model,
            X=data,
            n_neighbors=k,
            distance='cityblock',
            scaler=multiscaler.get_transformer('quantile'),
            max_samples=MAXSAMPLES,
        )
        for k in NN_Ks
    }
    cf_methods.update({f"knn{k}_qL1": method for k, method in knn_methods.items()})

    # -------------- KNN Projected on DB --------------------------------
    for k, knn_method in knn_methods.items():
        if k not in projected:
            continue
        cf_methods[f"cone-knn{k}_qL1"] = compose_diverse_counterfactual_method(
            BisectionProjectionDBCounterfactuals(
                model=model,
                data=data,
                max_samples=MAXSAMPLES,
                random_state=random_state,
                verbose=0,
            ),
            knn_method,  # same instance as cf_methods[f"knn{k}_qL1"]
            verbose=verbose,
        )
    return cf_methods
def create_explainers(X,
                      y,
                      model,
                      ref_points,
                      feature_names,
                      multiscaler,
                      random_state,
                      feature_trends=None,
                      verbose=True):
    """Build every explainer under comparison, keyed by name.

    The returned dict holds, in insertion order:

    - ``training`` / ``training_100``: interventional TreeExplainers whose
      data is the processed ``X``, with ``max_samples`` of ``MAXSAMPLES`` and
      100 respectively.
    - One CompositeExplainer per entry of ``create_counterfactual_explainers``
      (same key). Each pairs that counterfactual method with an interventional
      TreeExplainer created with ``data=None`` (presumably the background
      comes from the counterfactual method -- confirm in cfshap).

    Args:
        X: Input features; converted to NumPy by ``process_data`` with
            ``names=feature_names, names_is='subset'``.
        y: Labels; converted to NumPy and flattened.
        model: Model to be explained.
        ref_points: Unused here; kept for call-site compatibility.
        feature_names: Feature names used when processing ``X``, also forwarded
            to ``create_counterfactual_explainers``.
        multiscaler: Forwarded to ``create_counterfactual_explainers``.
        random_state: Forwarded to ``create_counterfactual_explainers``.
        feature_trends: Forwarded to ``create_counterfactual_explainers``.
        verbose: Forwarded to the counterfactual factory and to every
            CompositeExplainer.

    Returns:
        dict mapping explainer name to explainer object.
    """
    data = process_data(X, ret_type='np', names=feature_names, names_is='subset')
    labels = process_data(y, ret_type='np').flatten()

    counterfactual_methods = create_counterfactual_explainers(
        data, labels, model, feature_names, multiscaler, random_state, verbose, feature_trends
    )

    def _interventional_tree(background, n_samples):
        # Every TreeExplainer in this function shares these settings.
        return TreeExplainer(
            model,
            data=background,
            feature_perturbation='interventional',
            max_samples=n_samples,
            trend_estimator=trend_estimator,
        )

    explainers = {}
    # Classic background: the (processed) training data itself.
    explainers['training'] = _interventional_tree(data, MAXSAMPLES)
    explainers['training_100'] = _interventional_tree(data, 100)
    # SHAP with counterfactual backgrounds (including diff_pred / diff_label).
    for name, cf_method in counterfactual_methods.items():
        explainers[name] = CompositeExplainer(
            cf_method,
            _interventional_tree(None, MAXSAMPLES),
            verbose=verbose,
        )
    return explainers