@@ -466,3 +466,67 @@ def test_decision_stumps(background_reg_dataset, background_clf_dataset):
466
466
continue
467
467
468
468
assert pred == pytest .approx (efficiency , rel = 1e-5 )
469
+
470
+
471
+ def test_extra_trees_clf (et_clf_model , background_clf_data ):
472
+ """Test the shapiq implementation of TreeSHAP vs. SHAP's implementation for Extra Trees."""
473
+ explanation_instance = 1
474
+ class_label = 1
475
+
476
+ # the following code is used to get the shap values from the SHAP implementation
477
+ """
478
+ #import shap
479
+ # model_copy = copy.deepcopy(et_clf_model)
480
+ # explainer_shap = shap.TreeExplainer(model=model_copy)
481
+ # baseline_shap = float(explainer_shap.expected_value[class_label])
482
+ # x_explain_shap = copy.deepcopy(background_clf_data[explanation_instance].reshape(1, -1))
483
+ # sv_shap_all_classes = explainer_shap.shap_values(x_explain_shap)
484
+ # sv_shap = sv_shap_all_classes[0][:, class_label]
485
+ # print(sv_shap_all_classes, format(baseline_shap, '.20f'))
486
+ """ # noqa: ERA001
487
+ sv_shap = [0.00207427 , 0.00949552 , - 0.00108266 , - 0.03825587 , - 0.02694092 , 0.0170296 , 0.02046364 ]
488
+ sv_shap = np .asarray (sv_shap )
489
+ baseline_shap = 0.34000000000000002
490
+
491
+ # compute with shapiq
492
+ explainer_shapiq = TreeExplainer (
493
+ model = et_clf_model , max_order = 1 , index = "SV" , class_index = class_label
494
+ )
495
+ x_explain_shapiq = copy .deepcopy (background_clf_data [explanation_instance ])
496
+ sv_shapiq = explainer_shapiq .explain (x = x_explain_shapiq )
497
+ sv_shapiq_values = sv_shapiq .get_n_order_values (1 )
498
+ baseline_shapiq = sv_shapiq .baseline_value
499
+
500
+ assert baseline_shap == pytest .approx (baseline_shapiq , rel = 1e-4 )
501
+ assert np .allclose (sv_shap , sv_shapiq_values , rtol = 1e-5 )
502
+
503
+
504
+ def test_extra_trees_reg (et_reg_model , background_reg_data ):
505
+ """Test the shapiq implementation of TreeSHAP vs. SHAP's implementation for Extra Trees."""
506
+ explanation_instance = 1
507
+
508
+ # the following code is used to get the shap values from the SHAP implementation
509
+ """
510
+ # import shap
511
+ # model_copy = copy.deepcopy(et_reg_model)
512
+ # explainer_shap = shap.TreeExplainer(model=model_copy)
513
+ # baseline_shap = float(explainer_shap.expected_value)
514
+ # x_explain_shap = copy.deepcopy(background_reg_data[explanation_instance].reshape(1, -1))
515
+ # sv_shap_all_classes = explainer_shap.shap_values(x_explain_shap)
516
+ # sv_shap = sv_shap_all_classes[0]
517
+ # print(sv_shap_all_classes, format(baseline_shap, '.20f'))
518
+ """ # noqa: ERA001
519
+ sv_shap = [19.28673017 , - 19.87182634 , 0.0 , 10.89201698 , - 9.62498263 , 0.35992212 , 42.31290091 ]
520
+ sv_shap = np .asarray (sv_shap )
521
+ print (sv_shap )
522
+ baseline_shap = - 2.56682283435175007
523
+
524
+ # compute with shapiq
525
+ explainer_shapiq = TreeExplainer (model = et_reg_model , max_order = 1 , index = "SV" )
526
+ x_explain_shapiq = copy .deepcopy (background_reg_data [explanation_instance ])
527
+ sv_shapiq = explainer_shapiq .explain (x = x_explain_shapiq )
528
+ sv_shapiq_values = sv_shapiq .get_n_order_values (1 )
529
+ baseline_shapiq = sv_shapiq .baseline_value
530
+
531
+ assert baseline_shap == pytest .approx (baseline_shapiq , rel = 1e-4 )
532
+ assert np .allclose (sv_shap , sv_shapiq_values , rtol = 1e-5 )
0 commit comments