15
15
16
16
17
17
def stacked_bar_plot (
18
- n_shapley_interaction_values : InteractionValues ,
18
+ interaction_values : InteractionValues ,
19
19
feature_names : Optional [list [Any ]] = None ,
20
- n_sii_max_order : Optional [int ] = None ,
20
+ max_order : Optional [int ] = None ,
21
21
title : Optional [str ] = None ,
22
22
xlabel : Optional [str ] = None ,
23
23
ylabel : Optional [str ] = None ,
@@ -36,10 +36,10 @@ def stacked_bar_plot(
36
36
:align: center
37
37
38
38
Args:
39
- n_shapley_interaction_values (InteractionValues): n-SII values as InteractionValues object
39
+ interaction_values (InteractionValues): n-SII values as InteractionValues object
40
40
feature_names: The feature names used for plotting. If no feature names are provided, the
41
41
feature indices are used instead. Defaults to ``None``.
42
- n_sii_max_order (int): The order of the n-SII values.
42
+ max_order (int): The order of the n-SII values.
43
43
title (str): The title of the plot.
44
44
xlabel (str): The label of the x-axis.
45
45
ylabel (str): The label of the y-axis.
@@ -64,32 +64,32 @@ def stacked_bar_plot(
64
64
... )
65
65
>>> feature_names = ["a", "b", "c"]
66
66
>>> fig, axes = stacked_bar_plot(
67
- ... n_shapley_interaction_values =interaction_values,
67
+ ... interaction_values =interaction_values,
68
68
... feature_names=feature_names,
69
69
... )
70
70
>>> plt.show()
71
71
"""
72
72
# sanitize inputs
73
- if n_sii_max_order is None :
74
- n_sii_max_order = n_shapley_interaction_values .max_order
73
+ if max_order is None :
74
+ max_order = interaction_values .max_order
75
75
76
76
fig , axis = plt .subplots ()
77
77
78
78
# transform data to make plotting easier
79
79
values_pos = np .array (
80
80
[
81
- n_shapley_interaction_values .get_n_order_values (order )
81
+ interaction_values .get_n_order_values (order )
82
82
.clip (min = 0 )
83
83
.sum (axis = tuple (range (1 , order )))
84
- for order in range (1 , n_sii_max_order + 1 )
84
+ for order in range (1 , max_order + 1 )
85
85
]
86
86
)
87
87
values_neg = np .array (
88
88
[
89
- n_shapley_interaction_values .get_n_order_values (order )
89
+ interaction_values .get_n_order_values (order )
90
90
.clip (max = 0 )
91
91
.sum (axis = tuple (range (1 , order )))
92
- for order in range (1 , n_sii_max_order + 1 )
92
+ for order in range (1 , max_order + 1 )
93
93
]
94
94
)
95
95
# get the number of features and the feature names
@@ -118,11 +118,11 @@ def stacked_bar_plot(
118
118
119
119
# add a legend to the plots
120
120
legend_elements = []
121
- for order in range (n_sii_max_order ):
121
+ for order in range (max_order ):
122
122
legend_elements .append (
123
123
Patch (facecolor = COLORS_K_SII [order ], edgecolor = "black" , label = f"Order { order + 1 } " )
124
124
)
125
- axis .legend (handles = legend_elements , loc = "upper center" , ncol = min (n_sii_max_order , 4 ))
125
+ axis .legend (handles = legend_elements , loc = "upper center" , ncol = min (max_order , 4 ))
126
126
127
127
x_ticks_labels = [feature for feature in feature_names ] # might be unnecessary
128
128
axis .set_xticks (x )
@@ -137,7 +137,7 @@ def stacked_bar_plot(
137
137
# set title and labels if not provided
138
138
139
139
(
140
- axis .set_title (f"n-SII values up to order ${ n_sii_max_order } $" )
140
+ axis .set_title (f"n-SII values up to order ${ max_order } $" )
141
141
if title is None
142
142
else axis .set_title (title )
143
143
)
0 commit comments