Skip to content

Commit

Permalink
Enabling the MSEv criterion for the Python wrapper (#384)
Browse files Browse the repository at this point in the history
  • Loading branch information
LHBO authored Apr 11, 2024
1 parent c480102 commit 54d8551
Show file tree
Hide file tree
Showing 7 changed files with 92 additions and 32 deletions.
12 changes: 10 additions & 2 deletions python/examples/keras_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
epochs=10,
validation_data=(dfx_test, dfy_test))
## Shapr
df_shapley, pred_explain, internal, timing = explain(
df_shapley, pred_explain, internal, timing, MSEv = explain(
model = model,
x_train = dfx_train,
x_explain = dfx_test,
Expand All @@ -49,4 +49,12 @@
4 0.018697
5 0.026814
"""
"""

# Look at the (overall) MSEv
MSEv["MSEv"]

"""
MSEv MSEv_sd
1 0.000312 0.00014
"""
25 changes: 23 additions & 2 deletions python/examples/pytorch_custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,33 @@ def forward(self, x):
optim.zero_grad()

## Shapr
df_shapley, pred_explain, internal, timing = explain(
df_shapley, pred_explain, internal, timing, MSEv = explain(
model = model,
x_train = dfx_train,
x_explain = dfx_test,
approach = 'empirical',
predict_model = lambda m, x: m(torch.from_numpy(x.values).float()).cpu().detach().numpy(),
prediction_zero = dfy_train.mean().item(),
)
print(df_shapley)
print(df_shapley)
"""
none MedInc HouseAge AveRooms AveBedrms Population AveOccup \
1 2.205947 2.313935 5.774470 5.425240 4.194669 1.712164 3.546001
2 2.205947 4.477620 5.467266 2.904239 3.046492 1.484807 5.631292
3 2.205946 4.028013 1.168401 5.229893 1.719724 2.134012 3.426378
4 2.205948 4.230376 8.639265 1.138520 3.776463 3.786978 4.253034
5 2.205947 3.923747 1.483737 1.113199 4.963213 -3.645875 4.950775
Latitude Longitude
1 1.102239 2.906469
2 4.966465 2.178510
3 3.503413 2.909760
4 3.413727 3.795563
5 3.011126 4.016985
"""

MSEv["MSEv"]
"""
MSEv MSEv_sd
1 27.046126 7.253933
"""
4 changes: 2 additions & 2 deletions python/examples/sklearn_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
model.fit(dfx_train, dfy_train.values.flatten())

## Shapr
df_shapley, pred_explain, internal, timing = explain(
df_shapley, pred_explain, internal, timing, MSEv = explain(
model = model,
x_train = dfx_train,
x_explain = dfx_test,
Expand All @@ -33,4 +33,4 @@
4 -0.179275
5 -0.136463
"""
"""
8 changes: 7 additions & 1 deletion python/examples/sklearn_regressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
model.fit(dfx_train, dfy_train.values.flatten())

## Shapr
df_shapley, pred_explain, internal, timing = explain(
df_shapley, pred_explain, internal, timing, MSEv = explain(
model = model,
x_train = dfx_train,
x_explain = dfx_test,
Expand All @@ -34,6 +34,12 @@
5 0.099410 0.315230
"""

MSEv["MSEv"]
"""
MSEv MSEv_sd
1 0.534141 0.247984
"""

# Now do this for grouping as well

group = {'A': ['MedInc','HouseAge','AveRooms'],
Expand Down
30 changes: 18 additions & 12 deletions python/examples/xgboost_booster.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
model = xgb.train(params={}, num_boost_round=20, dtrain=dtrain,)

## Shapr
df_shapley, pred_explain, internal, timing = explain(
df_shapley, pred_explain, internal, timing, MSEv = explain(
model = model,
x_train = dfx_train,
x_explain = dfx_test,
Expand All @@ -20,17 +20,23 @@

"""
none MedInc HouseAge AveRooms AveBedrms Population AveOccup \
1 2.205937 -0.701774 0.105231 -0.062693 0.119937 -0.054787 -0.296760
2 2.205938 -0.518543 0.062408 -0.447457 -0.232334 -0.015201 0.080512
3 2.205938 0.313422 0.558150 0.064590 -0.603992 0.052804 1.035786
4 2.205938 0.473781 -0.093358 0.046108 0.105684 -0.153729 -0.151584
5 2.205938 -0.084333 -0.099879 -0.119632 0.217170 0.162796 0.253543
1 2.205937 -0.660639 0.085530 -0.103341 0.123983 -0.056817 -0.282401
2 2.205938 -0.521214 0.073131 -0.506383 -0.184693 -0.011323 0.133417
3 2.205938 0.513277 0.715094 0.044417 -0.220822 0.049277 1.096243
4 2.205938 0.382929 -0.092198 -0.016058 0.150373 -0.137098 -0.260570
5 2.205938 -0.424637 -0.060884 -0.136637 0.153806 0.100997 -0.020819
Latitude Longitude
1 -0.553604 -0.223937
2 -0.258331 -0.097497
3 0.175916 0.410554
4 0.080829 0.187370
5 0.129439 0.382075
1 -0.525114 -0.222924
2 -0.211306 -0.063592
3 0.099492 0.392385
4 0.101910 0.206964
5 0.090441 0.314439
"""

"""
MSEv["MSEv"]

"""
MSEv MSEv_sd
1 0.815449 0.459069
"""
31 changes: 19 additions & 12 deletions python/examples/xgboost_regressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
model.fit(dfx_train, dfy_train.values.flatten())

## Shapr
df_shapley, pred_explain, internal, timing = explain(
df_shapley, pred_explain, internal, timing, MSEv = explain(
model = model,
x_train = dfx_train,
x_explain = dfx_test,
Expand All @@ -20,17 +20,24 @@

"""
none MedInc HouseAge AveRooms AveBedrms Population AveOccup \
1 2.205937 -0.697653 0.103323 -0.066003 0.115853 -0.057640 -0.292739
2 2.205938 -0.521995 0.064876 -0.445466 -0.230454 -0.019290 0.080655
3 2.205938 0.307681 0.563008 0.062743 -0.626912 0.050450 1.050069
4 2.205938 0.479900 -0.100035 0.030474 0.104301 -0.154396 -0.148057
5 2.205938 -0.088568 -0.101495 -0.121637 0.213535 0.169194 0.253711
1 2.205937 -0.655245 0.079722 -0.096497 0.126559 -0.056841 -0.287298
2 2.205938 -0.512414 0.077358 -0.504863 -0.176676 -0.005584 0.128635
3 2.205938 0.510828 0.719958 0.039504 -0.225118 0.044157 1.116464
4 2.205938 0.381191 -0.098956 -0.022961 0.145486 -0.139457 -0.241768
5 2.205938 -0.427220 -0.059622 -0.135028 0.158339 0.084157 -0.017783
Latitude Longitude
1 -0.573533 -0.237709
2 -0.265165 -0.090790
3 0.181936 0.412604
4 0.078605 0.186957
5 0.126759 0.376471
1 -0.579708 -0.231051
2 -0.212670 -0.051049
3 0.103866 0.405895
4 0.062271 0.201911
5 0.078978 0.307480
"""
"""

MSEv["MSEv"]

"""
MSEv MSEv_sd
1 0.825758 0.465439
"""
14 changes: 13 additions & 1 deletion python/shaprpy/explain.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def explain(
keep_samp_for_vS: bool = False,
predict_model: Callable = None,
get_model_specs: Callable = None,
MSEv_uniform_comb_weights: bool = True,
timing: bool = True,
verbose: int | None = 0,
):
Expand Down Expand Up @@ -86,6 +87,10 @@ def explain(
If `None` (the default) internal functions are used for natively supported model classes, and the checking is
disabled for unsupported model classes.
Can also be used to override the default function for natively supported model classes.
MSEv_uniform_comb_weights: Logical. If `True` (default), then the function weights the combinations
uniformly when computing the MSEv criterion. If `False`, then the function uses the Shapley kernel weights to
weight the combinations when computing the MSEv criterion. Note that the Shapley kernel weights are replaced by
the sampling frequency when not all combinations are considered.
timing: Indicates whether the timing of the different parts of the explain call should be saved and returned.
verbose: An integer specifying the level of verbosity. If `0` (default), `shapr` will stay silent.
If `1`, it will print information about performance. If `2`, some additional information will be printed out.
Expand All @@ -98,6 +103,11 @@ def explain(
A numpy.Array with the predictions on `x_explain`.
dict
A dictionary of additional information.
dict
A dictionary of elapsed time information if `timing` is set to `True`.
dict
A dictionary of the MSEv evaluation criterion scores: averaged over both the explicands and coalitions,
only over the explicands, and only over the coalitions.
'''

timing_list = {
Expand Down Expand Up @@ -126,6 +136,7 @@ def explain(
seed = seed,
keep_samp_for_vS = keep_samp_for_vS,
feature_specs = rfeature_specs,
MSEv_uniform_comb_weights = MSEv_uniform_comb_weights,
timing = timing,
verbose = verbose,
is_python=True,
Expand Down Expand Up @@ -162,7 +173,8 @@ def explain(
df_shapley = r2py(base.as_data_frame(routput.rx2('shapley_values')))
pred_explain = r2py(routput.rx2('pred_explain'))
internal = recurse_r_tree(routput.rx2('internal'))
return df_shapley, pred_explain, internal, timing
MSEv = recurse_r_tree(routput.rx2('MSEv'))
return df_shapley, pred_explain, internal, timing, MSEv


def compute_vS(rinternal, model, predict_model):
Expand Down

0 comments on commit 54d8551

Please sign in to comment.