Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit fc2f85c

Browse files
committedMar 19, 2025·
initial attempt
1 parent 84ca666 commit fc2f85c

File tree

3 files changed

+193
-35
lines changed

3 files changed

+193
-35
lines changed
 

‎integration_tests/test_baselines.py

+9
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,15 @@ async def test_shp():
4242
await asyncio.wait_for(simple_cross_evaluation(5, players=players), timeout=5)
4343

4444

45+
@pytest.mark.asyncio
46+
async def test_shp_in_doubles():
47+
players = [
48+
RandomPlayer(battle_format="gen9randomdoublesbattle"),
49+
SimpleHeuristicsPlayer(battle_format="gen9randomdoublesbattle"),
50+
]
51+
await asyncio.wait_for(simple_cross_evaluation(5, players=players), timeout=5)
52+
53+
4554
@pytest.mark.asyncio
4655
async def test_max_base_power():
4756
players = [RandomPlayer(), MaxBasePowerPlayer()]

‎src/poke_env/player/baselines.py

+174-30
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import random
2-
from typing import List, Optional
2+
from typing import List, Optional, Tuple
33

44
from poke_env.environment.abstract_battle import AbstractBattle
5+
from poke_env.environment.battle import Battle
56
from poke_env.environment.double_battle import DoubleBattle
7+
from poke_env.environment.move import Move
68
from poke_env.environment.move_category import MoveCategory
79
from poke_env.environment.pokemon import Pokemon
810
from poke_env.environment.side_condition import SideCondition
@@ -89,6 +91,30 @@ def move_power_with_double_target(move):
8991
return self.choose_random_move(battle)
9092

9193

94+
class PseudoBattle(Battle):
95+
def __init__(self, battle: DoubleBattle, active_id: int, opp_id: int):
96+
self._active_pokemon = battle.active_pokemon[active_id]
97+
self._opponent_active_pokemon = battle.opponent_active_pokemon[opp_id]
98+
self._team = battle.team
99+
self._opponent_team = battle.opponent_team
100+
self._available_moves = battle.available_moves[active_id]
101+
self._available_switches = battle.available_switches[active_id]
102+
self._side_conditions = battle.side_conditions
103+
self._can_mega_evolve = battle.can_mega_evolve[active_id]
104+
self._can_z_move = battle.can_z_move[active_id]
105+
self._can_dynamax = battle.can_dynamax[active_id]
106+
can_tera = battle.can_tera[active_id]
107+
self._can_tera = None if isinstance(can_tera, bool) else can_tera
108+
109+
@property
110+
def active_pokemon(self):
111+
return self._active_pokemon
112+
113+
@property
114+
def opponent_active_pokemon(self):
115+
return self._opponent_active_pokemon
116+
117+
92118
class SimpleHeuristicsPlayer(Player):
93119
ENTRY_HAZARDS = {
94120
"spikes": SideCondition.SPIKES,
@@ -141,6 +167,33 @@ def _should_dynamax(self, battle: AbstractBattle, n_remaining_mons: int):
141167
return True
142168
return False
143169

170+
def _should_terastallize(
171+
self,
172+
battle: Battle,
173+
move: Move,
174+
n_remaining_mons: int,
175+
):
176+
if (
177+
not battle.can_tera
178+
or not battle.active_pokemon
179+
or not battle.opponent_active_pokemon
180+
):
181+
return False
182+
183+
if (
184+
move.base_power >= 80
185+
and battle.active_pokemon.current_hp_fraction == 1
186+
and battle.opponent_active_pokemon.current_hp_fraction == 1
187+
):
188+
return True
189+
if n_remaining_mons == 1:
190+
return True
191+
# Example: if the Pokémon has a defined Tera type and switching its type grants STAB for this move.
192+
if battle.active_pokemon.tera_type == move.type:
193+
return True
194+
195+
return False
196+
144197
def _should_switch_out(self, battle: AbstractBattle):
145198
active = battle.active_pokemon
146199
opponent = battle.opponent_active_pokemon
@@ -178,16 +231,13 @@ def _stat_estimation(self, mon: Pokemon, stat: str):
178231
boost = 2 / (2 - mon.boosts[stat])
179232
return ((2 * mon.base_stats[stat] + 31) + 5) * boost
180233

181-
def choose_move(self, battle: AbstractBattle):
182-
if isinstance(battle, DoubleBattle):
183-
return self.choose_random_doubles_move(battle)
184-
234+
def choose_move_in_1v1(self, battle: Battle) -> Tuple[BattleOrder, float]:
185235
# Main mons shortcuts
186236
active = battle.active_pokemon
187237
opponent = battle.opponent_active_pokemon
188238

189239
if active is None or opponent is None:
190-
return self.choose_random_move(battle)
240+
return self.choose_random_move(battle), 0
191241

192242
# Rough estimation of damage ratio
193243
physical_ratio = self._stat_estimation(active, "atk") / self._stat_estimation(
@@ -216,15 +266,15 @@ def choose_move(self, battle: AbstractBattle):
216266
and self.ENTRY_HAZARDS[move.id]
217267
not in battle.opponent_side_conditions
218268
):
219-
return self.create_order(move)
269+
return self.create_order(move), 0
220270

221271
# ...removal
222272
elif (
223273
battle.side_conditions
224274
and move.id in self.ANTI_HAZARDS_MOVES
225275
and n_remaining_mons >= 2
226276
):
227-
return self.create_order(move)
277+
return self.create_order(move), 0
228278

229279
# Setup moves
230280
if (
@@ -241,32 +291,126 @@ def choose_move(self, battle: AbstractBattle):
241291
)
242292
< 6
243293
):
244-
return self.create_order(move)
245-
246-
move = max(
247-
battle.available_moves,
248-
key=lambda m: m.base_power
249-
* (1.5 if m.type in active.types else 1)
250-
* (
251-
physical_ratio
252-
if m.category == MoveCategory.PHYSICAL
253-
else special_ratio
294+
return self.create_order(move), 0
295+
296+
def get_score(m: Move):
297+
return (
298+
m.base_power
299+
* (1.5 if m.type in active.types else 1)
300+
* (
301+
physical_ratio
302+
if m.category == MoveCategory.PHYSICAL
303+
else special_ratio
304+
)
305+
* m.accuracy
306+
* m.expected_hits
307+
* opponent.damage_multiplier(m)
254308
)
255-
* m.accuracy
256-
* m.expected_hits
257-
* opponent.damage_multiplier(m),
258-
)
259-
return self.create_order(
260-
move, dynamax=self._should_dynamax(battle, n_remaining_mons)
309+
310+
move = max(battle.available_moves, key=lambda m: get_score(m))
311+
return (
312+
self.create_order(
313+
move,
314+
dynamax=self._should_dynamax(battle, n_remaining_mons),
315+
terastallize=self._should_terastallize(
316+
battle, move, n_remaining_mons
317+
),
318+
),
319+
get_score(move),
261320
)
262321

263322
if battle.available_switches:
264323
switches: List[Pokemon] = battle.available_switches
265-
return self.create_order(
266-
max(
267-
switches,
268-
key=lambda s: self._estimate_matchup(s, opponent),
269-
)
324+
return (
325+
self.create_order(
326+
max(
327+
switches,
328+
key=lambda s: self._estimate_matchup(s, opponent),
329+
)
330+
),
331+
0,
270332
)
271333

272-
return self.choose_random_move(battle)
334+
return self.choose_random_move(battle), 0
335+
336+
@staticmethod
337+
def get_double_target_multiplier(battle: DoubleBattle, order: BattleOrder):
338+
can_target_first_opponent = (
339+
battle.opponent_active_pokemon[0]
340+
and not battle.opponent_active_pokemon[0].fainted
341+
)
342+
can_target_second_opponent = (
343+
battle.opponent_active_pokemon[1]
344+
and not battle.opponent_active_pokemon[1].fainted
345+
)
346+
can_double_target = can_target_first_opponent and can_target_second_opponent
347+
return (
348+
1
349+
if not hasattr(order, "order")
350+
or not isinstance(order.order, Move)
351+
or order.order.target in {Target.NORMAL, Target.ANY}
352+
or not can_double_target
353+
else 2
354+
)
355+
356+
def choose_move(self, battle: AbstractBattle):
357+
if not isinstance(battle, DoubleBattle):
358+
return self.choose_move_in_1v1(battle)[0] # type: ignore
359+
orders = []
360+
for active_id in [0, 1]:
361+
possible_orders, scores = zip(
362+
*[
363+
self.choose_move_in_1v1(PseudoBattle(battle, active_id, opp_id))
364+
for opp_id in [0, 1]
365+
]
366+
)
367+
for order in possible_orders:
368+
mon = battle.active_pokemon[active_id]
369+
if (
370+
order is not None
371+
and hasattr(order, "order")
372+
and isinstance(order.order, Move)
373+
and mon is not None
374+
):
375+
target = [o for o in possible_orders].index(order) + 1
376+
possible_targets = battle.get_possible_showdown_targets(
377+
order.order, mon
378+
)
379+
if target not in possible_targets:
380+
target = possible_targets[0]
381+
order.move_target = target
382+
scores = [
383+
scores[i]
384+
* self.get_double_target_multiplier(battle, possible_orders[i])
385+
for i in [0, 1]
386+
]
387+
order = (
388+
max(zip(possible_orders, scores), key=lambda a: a[1])[0]
389+
if battle.force_switch != [[False, True], [True, False]][active_id]
390+
else None
391+
)
392+
orders += [order]
393+
joined_orders = DoubleBattleOrder.join_orders(
394+
[orders[0]] if orders[0] else [],
395+
[orders[1]] if orders[1] else [],
396+
)
397+
if joined_orders:
398+
return joined_orders[0]
399+
else:
400+
return DoubleBattleOrder(orders[0], DefaultBattleOrder())
401+
402+
def teampreview(self, battle: AbstractBattle) -> str:
403+
team = list(battle.team.values())
404+
scored_team = []
405+
for idx, mon in enumerate(team):
406+
# Calculate a simple score based on speed, attack (atk+spa), and defense (def+spd)
407+
attack = mon.base_stats.get("atk", 0) + mon.base_stats.get("spa", 0)
408+
speed = mon.base_stats.get("spe", 0)
409+
defense = (mon.base_stats.get("def", 0) + mon.base_stats.get("spd", 0)) / 2
410+
score = speed + (attack / 2) + (defense / 4)
411+
scored_team.append((idx + 1, score))
412+
sorted_scored = sorted(scored_team, key=lambda x: x[1], reverse=True)
413+
chosen = sorted_scored[:4]
414+
chosen_sorted = sorted(chosen, key=lambda x: x[1], reverse=True)
415+
team_order = "".join(str(index) for index, _ in chosen_sorted)
416+
return f"/team {team_order}"

‎src/poke_env/player/battle_order.py

+10-5
Original file line numberDiff line numberDiff line change
@@ -99,11 +99,16 @@ def join_orders(
9999
DoubleBattleOrder(first_order=first_order, second_order=second_order)
100100
for first_order in first_orders
101101
for second_order in second_orders
102-
if not first_order.mega or not second_order.mega
103-
if not first_order.z_move or not second_order.z_move
104-
if not first_order.dynamax or not second_order.dynamax
105-
if not first_order.terastallize or not second_order.terastallize
106-
if first_order.order != second_order.order
102+
if not (
103+
hasattr(first_order, "order") and hasattr(second_order, "order")
104+
)
105+
or (
106+
not (first_order.mega and second_order.mega)
107+
and not (first_order.z_move and second_order.z_move)
108+
and not (first_order.dynamax and second_order.dynamax)
109+
and not (first_order.terastallize and second_order.terastallize)
110+
and first_order.order != second_order.order
111+
)
107112
]
108113
elif first_orders:
109114
return [DoubleBattleOrder(order, None) for order in first_orders]

0 commit comments

Comments
 (0)
Please sign in to comment.