Skip to content

Commit f076ee4

Browse files
committed
Added possibility to call game based on given player_names
1 parent fe1e46e commit f076ee4

File tree

2 files changed

+38
-3
lines changed

2 files changed

+38
-3
lines changed

shapiq/games/base.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,11 @@ def _check_coalitions(
195195
if tuple_types[0] not in [set(), {int}, {str}]:
196196
raise TypeError("Tuples must contain either integers or strings.")
197197

198-
# convert strings in tuples to integers
198+
# check that string tuples are only used if player names are provided
199+
if self.player_name_lookup is None and tuple_types[0] == {str}:
200+
raise TypeError("Player names have to be provided to evaluate string tuples.")
201+
202+
# convert strings to integers
199203
if tuple_types[0] == {str}:
200204
coalitions = [
201205
tuple([self.player_name_lookup[name] for name in coal]) for coal in coalitions
@@ -211,9 +215,16 @@ def _check_coalitions(
211215
raise TypeError("Elements of tuple must have the same type.")
212216
if tuple_types not in [set(), {int}, {str}]:
213217
raise TypeError("Tuple must contain either integers or strings.")
218+
219+
# check that string tuples are only used if player names are provided
220+
if self.player_name_lookup is None and tuple_types == {str}:
221+
raise TypeError("Player names have to be provided to evaluate string tuples.")
222+
223+
# convert strings to integers
214224
if tuple_types == {str}:
215225
coalitions = tuple([self.player_name_lookup[name] for name in coalitions])
216226

227+
# convert tuple to one-hot encoding
217228
coalitions = transform_coalitions_to_array([coalitions], self.n_players)
218229
return coalitions
219230
elif isinstance(coalitions, np.ndarray):
@@ -232,6 +243,18 @@ def _check_coalitions(
232243
f"the number of players in the game ({self.n_players})."
233244
)
234245
return coalitions
246+
elif isinstance(coalitions, str):
247+
if coalitions == "empty":
248+
return self.empty_coalition.reshape((1, self.n_players))
249+
elif coalitions == "grand":
250+
return self.grand_coalition.reshape((1, self.n_players))
251+
else:
252+
if self.player_name_lookup is None:
253+
raise TypeError("Player names have to be provided to evaluate strings.")
254+
255+
tuple_coal = tuple([self.player_name_lookup[coalitions]])
256+
return transform_coalitions_to_array([tuple_coal], self.n_players)
257+
235258
else:
236259
raise TypeError("Coalitions have to be numpy arrays or lists of tuples or tuple.")
237260

tests/tests_games/test_base_game.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,6 @@ def value_function(self, coalition):
6868
# test wrong datatype in coalition call
6969
with pytest.raises(TypeError):
7070
assert test_game({0, 1, 2}) == 0.0
71-
72-
# test wrong datatype in coalition call
7371
with pytest.raises(TypeError):
7472
assert test_game([(None)]) == 0.0
7573

@@ -78,6 +76,7 @@ def value_function(self, coalition):
7876
assert test_game(test_coalition) == 0.0
7977
assert test_game(()) == 0.0
8078
assert test_game([()]) == 0.0
79+
assert test_game("empty") == 0.0
8180

8281
# test with grand coalition all call variants
8382
test_coalition = test_game.grand_coalition
@@ -86,6 +85,7 @@ def value_function(self, coalition):
8685
assert test_game([tuple(range(0, test_game.n_players))]) == 1.0
8786
assert test_game(tuple(test_game.player_name_lookup.values())) == 1.0
8887
assert test_game([tuple(test_game.player_name_lookup.values())]) == 1.0
88+
assert test_game("grand") == 1.0
8989

9090
# test with single player coalition all call variants
9191
test_coalition = np.array([True] + [False for _ in range(test_game.n_players - 1)])
@@ -94,6 +94,18 @@ def value_function(self, coalition):
9494
assert test_game([tuple([0])]) - 1 / 6 < 10e-7
9595
assert test_game(tuple(("Alice",))) - 1 / 6 < 10e-7
9696
assert test_game([tuple(("Alice",))]) - 1 / 6 < 10e-7
97+
assert test_game("Alice") - 1 / 6 < 10e-7
98+
99+
# test string calls with missing player names
100+
test_game2 = TestGame(n=n_players)
101+
assert test_game2("grand") == 1.0
102+
assert test_game2("empty") == 0.0
103+
with pytest.raises(TypeError):
104+
assert test_game2("Alice") == 0.0
105+
with pytest.raises(TypeError):
106+
assert test_game2(("Bob",)) == 0.0
107+
with pytest.raises(TypeError):
108+
assert test_game2([("Charlie",)]) == 0.0
97109

98110

99111
def test_precompute():

0 commit comments

Comments
 (0)