Commit 6feb897

mgarbacz committed
Add support for passing check_additivity argument
1 parent c82c984 commit 6feb897

6 files changed: 75 additions, 52 deletions


.pre-commit-config.yaml

Lines changed: 4 additions & 3 deletions
@@ -6,7 +6,7 @@ repos:
 entry: black
 language: python
 types: [python]
-language_version: python3.8
+language_version: python3
 args: [--line-length=120]
 - repo: local
 hooks:
@@ -23,12 +23,13 @@ repos:
 entry: flake8
 language: system
 types: [python]
-args: [--max-line-length=120, --docstring-convention=google, "--ignore=D100,D104,D202,D212,D200,E203,E731,W293,D412,D417,W503"]
+args: [--max-line-length=120, --docstring-convention=google, "--ignore=D100,D104,D202,D212,D200,E203,E731,W293,D412,D417,W503,D411"]
 # D100 requires all Python files (modules) to have a "public" docstring even if all functions within have a docstring.
 # D104 requires __init__ files to have a docstring
 # D202 No blank lines allowed after function docstring
 # D212
-# D200
+# D200
+# D411 Missing blank line before section
 # D412 No blank lines allowed between a section header and its content
 # D417 Missing argument descriptions in the docstring # Only ignored because of false positve when using multiline args.
 # E203

probatus/feature_elimination/feature_elimination.py

Lines changed: 26 additions & 5 deletions
@@ -333,7 +333,7 @@ def _report_current_results(
 self.report_df = pd.concat([self.report_df, current_row], axis=0)

 @staticmethod
-def _get_feature_shap_values_per_fold(X, y, clf, train_index, val_index, scorer, verbose=0):
+def _get_feature_shap_values_per_fold(X, y, clf, train_index, val_index, scorer, verbose=0, **shap_kwargs):
 """
 This function calculates the shap values on validation set, and Train and Val score.

@@ -365,6 +365,12 @@ def _get_feature_shap_values_per_fold(X, y, clf, train_index, val_index, scorer,
 - 51 - 100 - shows most important warnings, prints of the feature removal process
 - above 100 - presents all prints and all warnings (including SHAP warnings).

+**shap_kwargs:
+keyword arguments passed to
+[shap.Explainer](https://shap.readthedocs.io/en/latest/generated/shap.Explainer.html#shap.Explainer).
+It also enables `approximate` and `check_additivity` parameters, passed while calculating SHAP values.
+The `approximate=True` causes less accurate, but faster SHAP values calculation, while
+`check_additivity=False` disables the additivity check inside SHAP.
 Returns:
 (np.array, float, float):
 Tuple with the results: Shap Values on validation fold, train score, validation score.
@@ -380,10 +386,10 @@ def _get_feature_shap_values_per_fold(X, y, clf, train_index, val_index, scorer,
 score_val = scorer(clf, X_val, y_val)

 # Compute SHAP values
-shap_values = shap_calc(clf, X_val, verbose=verbose)
+shap_values = shap_calc(clf, X_val, verbose=verbose, **shap_kwargs)
 return shap_values, score_train, score_val

-def fit(self, X, y, columns_to_keep=None, column_names=None):
+def fit(self, X, y, columns_to_keep=None, column_names=None, **shap_kwargs):
 """
 Fits the object with the provided data.

@@ -413,6 +419,13 @@ def fit(self, X, y, columns_to_keep=None, column_names=None):
 feature names. If not provided the existing feature names are used or default feature names are
 generated.

+**shap_kwargs:
+keyword arguments passed to
+[shap.Explainer](https://shap.readthedocs.io/en/latest/generated/shap.Explainer.html#shap.Explainer).
+It also enables `approximate` and `check_additivity` parameters, passed while calculating SHAP values.
+The `approximate=True` causes less accurate, but faster SHAP values calculation, while
+`check_additivity=False` disables the additivity check inside SHAP.
+
 Returns:
 (ShapRFECV): Fitted object.
 """
@@ -502,6 +515,7 @@ def fit(self, X, y, columns_to_keep=None, column_names=None):
 val_index=val_index,
 scorer=self.scorer.scorer,
 verbose=self.verbose,
+**shap_kwargs,
 )
 for train_index, val_index in self.cv.split(current_X, self.y)
 )
@@ -557,7 +571,7 @@ def compute(self):

 return self.report_df

-def fit_compute(self, X, y, columns_to_keep=None, column_names=None):
+def fit_compute(self, X, y, columns_to_keep=None, column_names=None, **shap_kwargs):
 """
 Fits the object with the provided data.

@@ -586,12 +600,19 @@ def fit_compute(self, X, y, columns_to_keep=None, column_names=None):
 feature names. If not provided the existing feature names are used or default feature names are
 generated.

+**shap_kwargs:
+keyword arguments passed to
+[shap.Explainer](https://shap.readthedocs.io/en/latest/generated/shap.Explainer.html#shap.Explainer).
+It also enables `approximate` and `check_additivity` parameters, passed while calculating SHAP values.
+The `approximate=True` causes less accurate, but faster SHAP values calculation, while
+`check_additivity=False` disables the additivity check inside SHAP.
+
 Returns:
 (pd.DataFrame):
 DataFrame containing results of feature elimination from each iteration.
 """

-self.fit(X, y, columns_to_keep=columns_to_keep, column_names=column_names)
+self.fit(X, y, columns_to_keep=columns_to_keep, column_names=column_names, **shap_kwargs)
 return self.compute()

 def get_reduced_features_set(self, num_features):
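Taken together, these changes let callers forward SHAP keyword arguments from `ShapRFECV.fit`/`fit_compute` down to `shap_calc`. A minimal sketch of the new call path; it assumes the `ShapRFECV` constructor arguments (`clf`, `step`, `cv`, `scoring`) documented by probatus and uses purely illustrative data:

# Sketch: forwarding check_additivity through ShapRFECV (assumed constructor signature).
import pandas as pd
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from probatus.feature_elimination import ShapRFECV

X, y = make_classification(n_samples=200, n_features=10, random_state=0)
X = pd.DataFrame(X, columns=[f"f{i}" for i in range(10)])

shap_rfecv = ShapRFECV(clf=RandomForestClassifier(random_state=0), step=0.2, cv=5, scoring="roc_auc")

# check_additivity is picked up by shap_calc via **shap_kwargs and disables SHAP's additivity check.
report = shap_rfecv.fit_compute(X, y, check_additivity=False)
print(report)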

probatus/interpret/model_interpret.py

Lines changed: 9 additions & 31 deletions
@@ -104,17 +104,7 @@ def __init__(self, clf, scoring="roc_auc", verbose=0):
 self.scorer = get_single_scorer(scoring)
 self.verbose = verbose

-def fit(
-self,
-X_train,
-X_test,
-y_train,
-y_test,
-column_names=None,
-class_names=None,
-approximate=False,
-**shap_kwargs,
-):
+def fit(self, X_train, X_test, y_train, y_test, column_names=None, class_names=None, **shap_kwargs):
 """
 Fits the object and calculates the shap values for the provided datasets.

@@ -138,12 +128,12 @@ def fit(
 List of class names e.g. ['neg', 'pos']. If none, the default ['Negative Class', 'Positive Class'] are
 used.

-approximate (boolean, optional):
-if True uses shap approximations - less accurate, but very fast.
-
 **shap_kwargs:
 keyword arguments passed to
 [shap.Explainer](https://shap.readthedocs.io/en/latest/generated/shap.Explainer.html#shap.Explainer).
+It also enables `approximate` and `check_additivity` parameters, passed while calculating SHAP values.
+The `approximate=True` causes less accurate, but faster SHAP values calculation, while
+`check_additivity=False` disables the additivity check inside SHAP.
 """

 self.X_train, self.column_names = preprocess_data(
@@ -171,7 +161,6 @@ def fit(
 clf=self.clf,
 X=self.X_train,
 y=self.y_train,
-approximate=approximate,
 column_names=self.column_names,
 class_names=self.class_names,
 verbose=self.verbose,
@@ -182,7 +171,6 @@ def fit(
 clf=self.clf,
 X=self.X_test,
 y=self.y_test,
-approximate=approximate,
 column_names=self.column_names,
 class_names=self.class_names,
 verbose=self.verbose,
@@ -285,7 +273,6 @@ def fit_compute(
 y_test,
 column_names=None,
 class_names=None,
-approximate=False,
 return_scores=False,
 **shap_kwargs,
 ):
@@ -314,19 +301,19 @@ def fit_compute(
 If none, the default ['Negative Class', 'Positive Class'] are
 used.

-approximate (boolean, optional):
-if True uses shap approximations - less accurate, but very fast.
-
 return_scores (bool, optional):
 Flag indicating whether the method should return
 the train and test score of the model,
 together with the model interpretation report. If true,
 the output of this method is a tuple of DataFrame, float,
 float.

-**shap_kwargs: keyword arguments passed to
+**shap_kwargs:
 keyword arguments passed to
 [shap.Explainer](https://shap.readthedocs.io/en/latest/generated/shap.Explainer.html#shap.Explainer).
+It also enables `approximate` and `check_additivity` parameters, passed while calculating SHAP values.
+The `approximate=True` causes less accurate, but faster SHAP values calculation, while
+`check_additivity=False` disables the additivity check inside SHAP.

 Returns:
 (pd.DataFrame or tuple(pd.DataFrame, float, float)):
@@ -340,20 +327,11 @@ def fit_compute(
 y_test=y_test,
 column_names=column_names,
 class_names=class_names,
-approximate=approximate,
 **shap_kwargs,
 )
 return self.compute()

-def plot(
-self,
-plot_type,
-target_set="test",
-target_columns=None,
-samples_index=None,
-show=True,
-**plot_kwargs,
-):
+def plot(self, plot_type, target_set="test", target_columns=None, samples_index=None, show=True, **plot_kwargs):
 """
 Plots the appropriate SHAP plot.

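Note that `approximate` is no longer a named argument of `fit`/`fit_compute` here; both `approximate` and `check_additivity` now travel through `**shap_kwargs`. A sketch of the updated call, assuming the constructor shown above and illustrative data:

# Sketch: approximate/check_additivity passed as SHAP kwargs to ShapModelInterpreter.
import pandas as pd
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from probatus.interpret import ShapModelInterpreter

X, y = make_classification(n_samples=300, n_features=8, random_state=0)
X = pd.DataFrame(X, columns=[f"f{i}" for i in range(8)])
X_train, X_test, y_train, y_test = train_test_split(X, pd.Series(y), random_state=0)

clf = RandomForestClassifier(random_state=0).fit(X_train, y_train)
interpreter = ShapModelInterpreter(clf)

feature_importance = interpreter.fit_compute(
    X_train, X_test, y_train, y_test,
    approximate=True,        # was a named fit() argument before this commit
    check_additivity=False,  # newly supported, disables SHAP's additivity check
)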

probatus/interpret/shap_dependence.py

Lines changed: 22 additions & 10 deletions
@@ -82,22 +82,33 @@ def __repr__(self):
 """
 return "Shap dependence plotter for {}".format(self.clf.__class__.__name__)

-def fit(self, X, y, column_names=None, class_names=None, precalc_shap=None):
+def fit(self, X, y, column_names=None, class_names=None, precalc_shap=None, **shap_kwargs):
 """
 Fits the plotter to the model and data by computing the shap values.

 If the shap_values are passed, they do not need to be computed.

 Args:
 X (pd.DataFrame): input variables.
+
 y (pd.Series): target variable.
+
 column_names (None, or list of str, optional):
 List of feature names for the dataset. If None, then column names from the X_train dataframe are used.
+
 class_names (None, or list of str, optional):
 List of class names e.g. ['neg', 'pos']. If none, the default ['Negative Class', 'Positive Class'] are
 used.
+
 precalc_shap (Optional, None or np.array):
 Precalculated shap values, If provided they don't need to be computed.
+
+**shap_kwargs:
+keyword arguments passed to
+[shap.Explainer](https://shap.readthedocs.io/en/latest/generated/shap.Explainer.html#shap.Explainer).
+It also enables `approximate` and `check_additivity` parameters, passed while calculating SHAP values.
+The `approximate=True` causes less accurate, but faster SHAP values calculation, while
+`check_additivity=False` disables the additivity check inside SHAP.
 """
 self.X, self.column_names = preprocess_data(X, X_name="X", column_names=column_names, verbose=self.verbose)
 self.y = preprocess_labels(y, y_name="y", index=self.X.index, verbose=self.verbose)
@@ -107,7 +118,7 @@ def fit(self, X, y, column_names=None, class_names=None, precalc_shap=None):
 if self.class_names is None:
 self.class_names = ["Negative Class", "Positive Class"]

-self.shap_vals_df = shap_to_df(self.clf, self.X, precalc_shap=precalc_shap, verbose=self.verbose)
+self.shap_vals_df = shap_to_df(self.clf, self.X, precalc_shap=precalc_shap, verbose=self.verbose, **shap_kwargs)

 self.fitted = True
 return self
@@ -123,7 +134,7 @@ def compute(self):
 self._check_if_fitted()
 return self.shap_vals_df

-def fit_compute(self, X, y, column_names=None, class_names=None, precalc_shap=None):
+def fit_compute(self, X, y, column_names=None, class_names=None, precalc_shap=None, **shap_kwargs):
 """
 Fits the plotter to the model and data by computing the shap values.

@@ -146,17 +157,18 @@ def fit_compute(self, X, y, column_names=None, class_names=None, precalc_shap=No
 precalc_shap (Optional, None or np.array):
 Precalculated shap values, If provided they don't need to be computed.

+**shap_kwargs:
+keyword arguments passed to
+[shap.Explainer](https://shap.readthedocs.io/en/latest/generated/shap.Explainer.html#shap.Explainer).
+It also enables `approximate` and `check_additivity` parameters, passed while calculating SHAP values.
+The `approximate=True` causes less accurate, but faster SHAP values calculation, while
+`check_additivity=False` disables the additivity check inside SHAP.
+
 Returns:
 (pd.DataFrame):
 SHAP Values for X.
 """
-self.fit(
-X,
-y,
-column_names=column_names,
-class_names=class_names,
-precalc_shap=precalc_shap,
-)
+self.fit(X, y, column_names=column_names, class_names=class_names, precalc_shap=precalc_shap, **shap_kwargs)
 return self.compute()

 def plot(
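The dependence plotter follows the same pattern: `fit`/`fit_compute` now accept `**shap_kwargs` and hand them to `shap_to_df`. A sketch, with the caveat that the class and import path (`DependencePlotter` from `probatus.interpret`) are taken from the probatus documentation and are an assumption here:

# Sketch: passing check_additivity through the dependence plotter (assumed class name).
import pandas as pd
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from probatus.interpret import DependencePlotter  # assumed export name

X, y = make_classification(n_samples=200, n_features=6, random_state=0)
X = pd.DataFrame(X, columns=[f"f{i}" for i in range(6)])
y = pd.Series(y)

clf = RandomForestClassifier(random_state=0).fit(X, y)

plotter = DependencePlotter(clf)
shap_vals_df = plotter.fit_compute(X, y, check_additivity=False)  # forwarded to shap_to_df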

probatus/sample_similarity/resemblance_model.py

Lines changed: 9 additions & 2 deletions
@@ -594,7 +594,7 @@ class is 'roc_auc'.

 self.plot_title = "SHAP summary plot"

-def fit(self, X1, X2, column_names=None, class_names=None):
+def fit(self, X1, X2, column_names=None, class_names=None, **shap_kwargs):
 """
 This function assigns to labels to each sample, 0 to first sample, 1 to the second.

@@ -619,13 +619,20 @@ def fit(self, X1, X2, column_names=None, class_names=None):
 List of class names assigned, in this case provided samples e.g. ['sample1', 'sample2']. If none, the
 default ['First Sample', 'Second Sample'] are used.

+**shap_kwargs:
+keyword arguments passed to
+[shap.Explainer](https://shap.readthedocs.io/en/latest/generated/shap.Explainer.html#shap.Explainer).
+It also enables `approximate` and `check_additivity` parameters, passed while calculating SHAP values.
+The `approximate=True` causes less accurate, but faster SHAP values calculation, while
+`check_additivity=False` disables the additivity check inside SHAP.
+
 Returns:
 (SHAPImportanceResemblance):
 Fitted object.
 """
 super().fit(X1=X1, X2=X2, column_names=column_names, class_names=class_names)

-self.shap_values_test = shap_calc(self.clf, self.X_test, verbose=self.verbose)
+self.shap_values_test = shap_calc(self.clf, self.X_test, verbose=self.verbose, **shap_kwargs)
 self.report = calculate_shap_importance(self.shap_values_test, self.column_names)
 return self

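Resemblance modelling gets the same treatment: `fit` forwards `**shap_kwargs` into the `shap_calc` call on the test split. A sketch, assuming the `SHAPImportanceResemblance` constructor takes the classifier as `clf` (per the probatus docs) and using synthetic samples:

# Sketch: disabling SHAP's additivity check when comparing two samples.
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from probatus.sample_similarity import SHAPImportanceResemblance

rng = np.random.default_rng(0)
cols = [f"f{i}" for i in range(5)]
X1 = pd.DataFrame(rng.normal(size=(200, 5)), columns=cols)
X2 = pd.DataFrame(rng.normal(loc=0.5, size=(200, 5)), columns=cols)

resemblance = SHAPImportanceResemblance(clf=RandomForestClassifier(random_state=0))
resemblance.fit(X1, X2, check_additivity=False)  # forwarded to shap_calc

The SHAP-based importance report is then available as before, e.g. via the object's compute step.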

probatus/utils/shap_helpers.py

Lines changed: 5 additions & 1 deletion
@@ -31,6 +31,7 @@ def shap_calc(
 return_explainer=False,
 verbose=0,
 sample_size=100,
+check_additivity=True,
 **shap_kwargs,
 ):
 """
@@ -57,6 +58,9 @@ def shap_calc(
 - 51 - 100 - shows other warnings and prints
 - above 100 - presents all prints and all warnings (including SHAP warnings).

+check_additivity (boolean):
+if False SHAP will disable the additivity check.
+
 **shap_kwargs: kwargs of the shap.Explainer

 Returns:
@@ -80,7 +84,7 @@ def shap_calc(

 explainer = shap.Explainer(model, masker=mask, **shap_kwargs)
 # Calculate Shap values.
-shap_values = explainer.shap_values(X)
+shap_values = explainer.shap_values(X, check_additivity=check_additivity, approximate=approximate)

 if isinstance(shap_values, list) and len(shap_values) == 2:
 warnings.warn(
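At the lowest level, `shap_calc` now exposes `check_additivity` as a first-class argument and passes it, together with `approximate`, into `explainer.shap_values`. A sketch of calling the helper directly, importing it from the module path shown in this diff; the data setup is illustrative:

# Sketch: calling the updated shap_calc helper directly.
import pandas as pd
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from probatus.utils.shap_helpers import shap_calc

X, y = make_classification(n_samples=150, n_features=5, random_state=0)
X = pd.DataFrame(X, columns=[f"f{i}" for i in range(5)])
clf = RandomForestClassifier(random_state=0).fit(X, y)

# check_additivity=True stays the default; False skips SHAP's additivity assertion,
# which helps when small numerical mismatches would otherwise raise an error.
shap_values = shap_calc(clf, X, check_additivity=False)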
