# montecarlo.py
import numpy as np
import math


# Definition for a node for the MCTS algorithm
class Node:
    """
    A node stores its parent, its children, the moves that are still expandable,
    a visit count, a value sum, the game and its state, and the action taken to reach it.
    """
    def __init__(self, game, args, state, parent=None, action_taken=None):
        self.game = game
        self.args = args
        self.state = state
        self.parent = parent
        self.action_taken = action_taken
        self.children = []
        self.expandable_moves = game.get_valid_moves(state)
        self.visit_count = 0
        self.value_sum = 0

    # Check whether the node is fully expanded
    def is_fully_expanded(self):
        return np.sum(self.expandable_moves) == 0 and len(self.children) > 0

    # Select the child with the highest UCB score
    def select(self):
        best_child = None
        best_ucb = -np.inf
        for child in self.children:
            ucb = self.get_ucb(child)
            if ucb > best_ucb:
                best_child = child
                best_ucb = ucb
        return best_child

    def get_ucb(self, child):
        """
        Get the UCB score, given by Q(s, a) + C * sqrt(ln(N) / n), where N is the parent's
        visit count and n is the child's visit count.
        The child's mean value is rescaled from [-1, 1] to [0, 1] and flipped (1 - q),
        because a bad q_value for the child implies a good q_value for the parent.
        """
        q_value = 1 - ((child.value_sum / child.visit_count) + 1) / 2
        return q_value + self.args['C'] * math.sqrt(math.log(self.visit_count) / child.visit_count)
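
    # Worked example (illustrative numbers, not from the original file): with C = 1.41,
    # a parent visited N = 10 times, and a child with value_sum = 3 over n = 5 visits:
    #   q_value = 1 - ((3 / 5) + 1) / 2          = 0.2
    #   ucb     = 0.2 + 1.41 * sqrt(ln(10) / 5)  ≈ 1.16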

    # Expand the node by creating a child for one untried action
    def expand(self):
        action = np.random.choice(np.where(self.expandable_moves == 1)[0])
        self.expandable_moves[action] = 0
        # The child always considers itself to be player 1. Whenever we need to switch
        # players, we flip the board state instead.
        child_state = self.state.copy()
        child_state = self.game.get_next_state(child_state, action, 1)
        child_state = self.game.change_perspective(child_state, -1)
        # Add child Node
        child = Node(self.game, self.args, child_state, self, action)
        self.children.append(child)
        return child

    # Simulate the game with a random rollout and return the result
    def simulate(self):
        value, is_terminated = self.game.check_win_and_termination(self.state, self.action_taken)
        # action_taken was played by the opponent of the player to move at this node,
        # so a terminal value has to be flipped to this node's perspective
        value = self.game.get_opponent_value(value)
        if is_terminated:
            return value
        rollout_state = self.state.copy()
        rollout_player = 1
        while True:
            valid_moves = self.game.get_valid_moves(rollout_state)
            action = np.random.choice(np.where(valid_moves == 1)[0])
            rollout_state = self.game.get_next_state(rollout_state, action, rollout_player)
            value, is_terminated = self.game.check_win_and_termination(rollout_state, action)
            if is_terminated:
                if rollout_player == -1:
                    value = self.game.get_opponent_value(value)
                return value
            rollout_player = self.game.get_opponent(rollout_player)

    # Backpropagate the results
    def backpropagate(self, value):
        self.value_sum += value
        self.visit_count += 1
        value = self.game.get_opponent_value(value)
        if self.parent is not None:
            self.parent.backpropagate(value)


# Definition for the Monte Carlo Tree Search
class MCTS:
    def __init__(self, game, args):
        self.game = game
        self.args = args

    def search(self, state):
        """
        Perform the four steps of a Monte Carlo Tree Search on each iteration:
        1. Selection
        2. Expansion
        3. Simulation
        4. Backpropagation
        """
        root = Node(self.game, self.args, state)
        for _ in range(self.args['num_searches']):
            node = root
            # Selection
            while node.is_fully_expanded():
                node = node.select()
            value, is_terminated = self.game.check_win_and_termination(node.state, node.action_taken)
            value = self.game.get_opponent_value(value)
            # Do expansion and simulation if the node is not terminal
            if not is_terminated:
                # Expansion
                node = node.expand()
                # Simulation
                value = node.simulate()
            # Backpropagation
            node.backpropagate(value)
        # Get the probabilities for the different actions from the root's visit counts
        action_probs = np.zeros(self.game.action_size)
        for child in root.children:
            action_probs[child.action_taken] = child.visit_count
        action_probs /= np.sum(action_probs)
        return action_probs
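

# Usage sketch (illustrative, not part of the original file). The game object passed to
# MCTS is only assumed to expose the methods called above: get_valid_moves, get_next_state,
# change_perspective, check_win_and_termination, get_opponent, get_opponent_value and
# action_size. The minimal Tic-Tac-Toe below and the C / num_searches values are
# placeholder assumptions, not the author's implementation.
if __name__ == "__main__":
    class TicTacToeSketch:
        def __init__(self):
            self.action_size = 9  # 3x3 board, actions are flat cell indices 0..8

        def get_valid_moves(self, state):
            # A move is valid wherever the cell is still empty
            return (state.reshape(-1) == 0).astype(np.uint8)

        def get_next_state(self, state, action, player):
            state = state.copy()
            state[action] = player
            return state

        def change_perspective(self, state, player):
            # With the 1 / -1 / 0 encoding, flipping the perspective is just a sign flip
            return state * player

        def get_opponent(self, player):
            return -player

        def get_opponent_value(self, value):
            return -value

        def check_win_and_termination(self, state, action):
            # Returns (value, terminated); value is from the perspective of the player
            # who made `action`
            if action is not None and self._check_win(state, action):
                return 1, True
            if np.sum(state == 0) == 0:
                return 0, True  # draw: board full, no winner
            return 0, False

        def _check_win(self, state, action):
            board = state.reshape(3, 3)
            row, col = action // 3, action % 3
            player = board[row, col]
            return (
                np.all(board[row, :] == player)
                or np.all(board[:, col] == player)
                or (row == col and np.all(np.diag(board) == player))
                or (row + col == 2 and np.all(np.diag(np.fliplr(board)) == player))
            )

    game = TicTacToeSketch()
    args = {'C': 1.41, 'num_searches': 1000}
    mcts = MCTS(game, args)
    state = np.zeros(9)  # empty board, player 1 to move
    action_probs = mcts.search(state)
    print("visit-count distribution over actions:", action_probs)
    print("greedy action:", int(np.argmax(action_probs)))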