-
Notifications
You must be signed in to change notification settings - Fork 1
/
battleship_gym_v7.py
129 lines (90 loc) · 3.67 KB
/
battleship_gym_v7.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import gym, gym.spaces, gym.utils, gym.utils.seeding
import numpy as np
from battleship import Board
from static_board_v1 import ship_config
BOARD_DIM = 10  # side length of the square board (board is BOARD_DIM x BOARD_DIM)

# Reward-function parameters, consumed by BattleshipEnvClass.step.
PERSISTENCE_PENALTY = 0    # base reward added on every non-repeated shot
HIT_REWARD = 0.5           # bonus when the torpedo hits a ship
REPEATED_PENALTY = -0.2    # the whole reward for re-targeting an already-probed cell
RADIUS = 2                 # reach of the neighbourhood used for the proximal reward
PROXIMAL_REWARD = 0.2      # bonus per already-known hit within RADIUS of the shot
SCORE_REWARD = 100         # bonus on the shot that ends the game
class BattleshipEnvClass(gym.Env):
    """Battleship environment, version 7 (successor to v6).

    Action space:
        ``Discrete(board_dim ** 2)``. Action ``a`` probes grid cell
        ``(i, j) = (a % board_dim, a // board_dim)``, i.e. actions enumerate
        the grid in column-major order (``grid.flatten(order="F")``), not the
        default row-major ``grid.flatten()``.

    Observation space:
        ``(board_dim, board_dim)`` int32 array summarising what is known about
        each cell: ``0`` unknown, ``1`` hit, ``-1`` miss.

    Reward (computed in ``step``):
        * ``REPEATED_PENALTY`` alone when the chosen cell was already probed;
        * otherwise ``PERSISTENCE_PENALTY``, plus ``HIT_REWARD`` on a hit,
          plus ``PROXIMAL_REWARD`` for every known hit among the cell's
          neighbours within ``RADIUS`` (see ``_neighbors``), plus
          ``SCORE_REWARD`` once the board reports game over.
    """

    def __init__(self):
        self.board_dim = BOARD_DIM
        # One discrete action per grid cell; see _action_to_cell for the mapping.
        self.action_space = gym.spaces.Discrete(self.board_dim * self.board_dim)
        # Per-cell knowledge: {0: unknown, 1: hit, -1: miss}.
        self.observation_space = gym.spaces.Box(
            low=-1, high=1, shape=(self.board_dim, self.board_dim), dtype=np.int32
        )
        self.reset()

    def _action_to_cell(self, action):
        """Map a flat action index to its ``(i, j)`` grid cell.

        Args:
            action: integer-like index in ``[0, board_dim ** 2)``.

        Returns:
            Tuple ``(action % board_dim, action // board_dim)``.

        Raises:
            ValueError: if ``action`` is outside ``[0, board_dim ** 2)``.
        """
        action = int(action)
        n_cells = self.board_dim * self.board_dim
        if not 0 <= action < n_cells:
            # Without this check a negative action silently wraps around
            # (e.g. -1 -> (9, -1), i.e. the last column) and a too-large one
            # surfaces as an opaque IndexError from numpy.
            raise ValueError(f"action must be in [0, {n_cells}), got {action}")
        return action % self.board_dim, action // self.board_dim

    def step(self, action):
        """Fire a torpedo at the cell encoded by ``action``.

        Args:
            action: flat cell index in ``[0, board_dim ** 2)``.

        Returns:
            ``(observation, reward, done, info)``. ``observation`` is a copy of
            the internal state, so arrays returned by successive calls do not
            alias (and are not mutated by) later steps.

        Raises:
            ValueError: if ``action`` is out of range, or if
                ``board.torpedo`` returns something other than 0 or 1.
        """
        i, j = self._action_to_cell(action)
        if self.state[i, j] != 0:
            # Cell already torpedoed: penalise and leave the environment unchanged.
            return self.state.copy(), REPEATED_PENALTY, self.done, {}

        # ADVANCE ENVIRONMENT -- produce next state, check done condition.
        hit = self.board.torpedo((i, j))
        if hit == 0:
            self.state[i, j] = -1
        elif hit == 1:
            self.state[i, j] = 1
        else:
            raise ValueError(f"Invalid return from board.torpedo(), {hit}")
        self.done = self.board.check_gameover()

        # REWARD CALCULATION
        reward = PERSISTENCE_PENALTY
        if hit:
            reward += HIT_REWARD
        # Proximal reward: bonus for each already-known hit near this shot.
        # _neighbors never includes (i, j) itself, so this shot's own result
        # does not count towards it.
        for r, c in self._neighbors(i, j, RADIUS, self.board_dim):
            if self.state[r, c] == 1:
                reward += PROXIMAL_REWARD
        if self.done:
            reward += SCORE_REWARD
        return self.state.copy(), reward, self.done, {}

    def render(self):
        """No-op: rendering is not implemented for this environment."""
        pass

    def reset(self):
        """Start a new game on a fresh board; return a copy of the blank observation."""
        # NOTE(review): the module imports ``ship_config`` from static_board_v1
        # but the literal 'default' is passed here -- confirm which is intended.
        self.board = Board(dim=self.board_dim, ship_config='default', vis=False, playmode=False)
        self.state = np.zeros((self.board_dim, self.board_dim), dtype=np.int32)
        self.done = False
        return self.state.copy()

    def seed(self, seed=None):
        """Seed ``self.np_random`` and return ``[seed]`` (Gym seeding convention).

        NOTE(review): ``self.np_random`` is not passed to ``Board`` in
        ``reset``, so this may not affect ship placement -- confirm.
        """
        self.np_random, seed = gym.utils.seeding.np_random(seed)
        return [seed]

    def _neighbors(self, i, j, radius, dim):
        """Return in-bounds cells on the 8 rays (orthogonal and diagonal) from (i, j).

        For each distance ``1..radius`` the cells are listed in the order
        down, right, up, left, then the four diagonals. Cells outside
        ``[0, dim)`` on either axis are dropped. ``(i, j)`` itself is never
        included.
        """
        offsets = ((1, 0), (0, 1), (-1, 0), (0, -1),
                   (1, 1), (1, -1), (-1, 1), (-1, -1))
        candidates = [(i + rad * di, j + rad * dj)
                      for rad in range(1, radius + 1)
                      for di, dj in offsets]
        return [(r, c) for r, c in candidates if 0 <= r < dim and 0 <= c < dim]

    def _overwrite_board(self, board):
        """Replace the Board object used by ``step``.

        Does not touch ``self.state`` or ``self.done``; callers must keep them
        consistent with the new board themselves.
        """
        self.board = board