Skip to content

Commit 890ef20

Browse files
committed
refactored stacked_bar_plot parameter names
1 parent 6f01eee commit 890ef20

File tree

2 files changed

+19
-19
lines changed

2 files changed

+19
-19
lines changed

shapiq/plot/stacked_bar.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@
1515

1616

1717
def stacked_bar_plot(
18-
n_shapley_interaction_values: InteractionValues,
18+
interaction_values: InteractionValues,
1919
feature_names: Optional[list[Any]] = None,
20-
n_sii_max_order: Optional[int] = None,
20+
max_order: Optional[int] = None,
2121
title: Optional[str] = None,
2222
xlabel: Optional[str] = None,
2323
ylabel: Optional[str] = None,
@@ -36,10 +36,10 @@ def stacked_bar_plot(
3636
:align: center
3737
3838
Args:
39-
n_shapley_interaction_values(InteractionValues): n-SII values as InteractionValues object
39+
interaction_values(InteractionValues): n-SII values as InteractionValues object
4040
feature_names: The feature names used for plotting. If no feature names are provided, the
4141
feature indices are used instead. Defaults to ``None``.
42-
n_sii_max_order (int): The order of the n-SII values.
42+
max_order (int): The order of the n-SII values.
4343
title (str): The title of the plot.
4444
xlabel (str): The label of the x-axis.
4545
ylabel (str): The label of the y-axis.
@@ -64,32 +64,32 @@ def stacked_bar_plot(
6464
... )
6565
>>> feature_names = ["a", "b", "c"]
6666
>>> fig, axes = stacked_bar_plot(
67-
... n_shapley_interaction_values=interaction_values,
67+
... interaction_values=interaction_values,
6868
... feature_names=feature_names,
6969
... )
7070
>>> plt.show()
7171
"""
7272
# sanitize inputs
73-
if n_sii_max_order is None:
74-
n_sii_max_order = n_shapley_interaction_values.max_order
73+
if max_order is None:
74+
max_order = interaction_values.max_order
7575

7676
fig, axis = plt.subplots()
7777

7878
# transform data to make plotting easier
7979
values_pos = np.array(
8080
[
81-
n_shapley_interaction_values.get_n_order_values(order)
81+
interaction_values.get_n_order_values(order)
8282
.clip(min=0)
8383
.sum(axis=tuple(range(1, order)))
84-
for order in range(1, n_sii_max_order + 1)
84+
for order in range(1, max_order + 1)
8585
]
8686
)
8787
values_neg = np.array(
8888
[
89-
n_shapley_interaction_values.get_n_order_values(order)
89+
interaction_values.get_n_order_values(order)
9090
.clip(max=0)
9191
.sum(axis=tuple(range(1, order)))
92-
for order in range(1, n_sii_max_order + 1)
92+
for order in range(1, max_order + 1)
9393
]
9494
)
9595
# get the number of features and the feature names
@@ -118,11 +118,11 @@ def stacked_bar_plot(
118118

119119
# add a legend to the plots
120120
legend_elements = []
121-
for order in range(n_sii_max_order):
121+
for order in range(max_order):
122122
legend_elements.append(
123123
Patch(facecolor=COLORS_K_SII[order], edgecolor="black", label=f"Order {order + 1}")
124124
)
125-
axis.legend(handles=legend_elements, loc="upper center", ncol=min(n_sii_max_order, 4))
125+
axis.legend(handles=legend_elements, loc="upper center", ncol=min(max_order, 4))
126126

127127
x_ticks_labels = [feature for feature in feature_names] # might be unnecessary
128128
axis.set_xticks(x)
@@ -137,7 +137,7 @@ def stacked_bar_plot(
137137
# set title and labels if not provided
138138

139139
(
140-
axis.set_title(f"n-SII values up to order ${n_sii_max_order}$")
140+
axis.set_title(f"n-SII values up to order ${max_order}$")
141141
if title is None
142142
else axis.set_title(title)
143143
)

tests/tests_plots/test_stacked_bar.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,17 +20,17 @@ def test_stacked_bar_plot():
2020
)
2121
feature_names = ["a", "b", "c"]
2222
fig, axes = stacked_bar_plot(
23-
n_shapley_interaction_values=interaction_values,
23+
interaction_values=interaction_values,
2424
feature_names=feature_names,
2525
)
2626
assert isinstance(fig, plt.Figure)
2727
assert isinstance(axes, plt.Axes)
2828
plt.close()
2929

3030
fig, axes = stacked_bar_plot(
31-
n_shapley_interaction_values=interaction_values,
31+
interaction_values=interaction_values,
3232
feature_names=feature_names,
33-
n_sii_max_order=2,
33+
max_order=2,
3434
title="Title",
3535
xlabel="X",
3636
ylabel="Y",
@@ -40,9 +40,9 @@ def test_stacked_bar_plot():
4040
plt.close()
4141

4242
fig, axes = stacked_bar_plot(
43-
n_shapley_interaction_values=interaction_values,
43+
interaction_values=interaction_values,
4444
feature_names=None,
45-
n_sii_max_order=2,
45+
max_order=2,
4646
title="Title",
4747
xlabel="X",
4848
ylabel="Y",

0 commit comments

Comments
 (0)