alphamontecarlo.py
import numpy as np
import torch
import math


# Definition for a node for the AlphaMCTS algorithm
class AlphaNode:
    """
    The node has a parent, children, a visit count, a value sum, the game and its state,
    and the action taken. It no longer needs a list of expandable moves, since it expands
    in all directions at once.
    """
    def __init__(self, game, args, state, parent=None, action_taken=None, prior=0, visit_count=0):
        self.game = game
        self.args = args
        self.state = state
        self.parent = parent
        self.action_taken = action_taken
        self.prior = prior
        self.children = []
        self.visit_count = visit_count
        self.value_sum = 0

    # Find whether the node is fully expanded or not
    def is_fully_expanded(self):
        return len(self.children) > 0
    # Select the child with the highest UCB score
    def select(self):
        best_child = None
        best_ucb = -np.inf
        for child in self.children:
            ucb = self.get_ucb(child)
            if ucb > best_ucb:
                best_child = child
                best_ucb = ucb
        return best_child

    def get_ucb(self, child):
        """
        Get the UCB score, given by Q(s,a) + C * P(s,a) * sqrt(sum_b N(s,b)) / (1 + N(s,a)).
        We take 1 - q_value because a bad q_value for the child implies a good q_value for the parent.
        (A small numeric sketch of this score follows the class definition.)
        """
        if not child.visit_count:
            q_value = 0
        else:
            # Map the child's mean value from [-1, 1] to [0, 1], then flip it to the parent's perspective
            q_value = 1 - ((child.value_sum / child.visit_count) + 1) / 2
        return q_value + self.args['C'] * (math.sqrt(self.visit_count) / (1 + child.visit_count)) * child.prior
    # Expand the node in all directions at once, weighted by the policy probabilities
    def expand(self, policy):
        for action, prob in enumerate(policy):
            if prob > 0:
                # The child will always think it is player 1. Whenever we need to switch players,
                # we flip the board state instead.
                child_state = self.state.copy()
                child_state = self.game.get_next_state(child_state, action, 1)
                child_state = self.game.change_perspective(child_state, -1)
                # Add the child node with its prior probability
                child = AlphaNode(self.game, self.args, child_state, parent=self, action_taken=action, prior=prob)
                self.children.append(child)

    # Backpropagate the result, flipping the value at each level on the way to the root
    def backpropagate(self, value):
        self.value_sum += value
        self.visit_count += 1
        value = self.game.get_opponent_value(value)
        if self.parent is not None:
            self.parent.backpropagate(value)
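

# --- Illustrative sketch, not part of the original file -----------------------------------------
# A quick numeric check of the PUCT score defined in AlphaNode.get_ucb above, using dummy nodes.
# Everything here is an assumption made for illustration only: the game and state are stubbed out
# with None, and the numbers (C = 2, the priors and the visit counts) are chosen purely to show
# how the value term and the exploration term trade off.
def _demo_ucb_scores():
    demo_args = {'C': 2}
    parent = AlphaNode(game=None, args=demo_args, state=None, visit_count=10)
    # An unvisited child has q_value = 0, so its score is purely the prior-weighted exploration term
    unvisited = AlphaNode(game=None, args=demo_args, state=None, parent=parent, action_taken=0, prior=0.5)
    # A visited child mixes its flipped mean value with an exploration bonus that shrinks as visits grow
    visited = AlphaNode(game=None, args=demo_args, state=None, parent=parent, action_taken=1, prior=0.1, visit_count=4)
    visited.value_sum = -2  # bad for the child, hence good for the parent after the 1 - (...) flip
    print(parent.get_ucb(unvisited))  # 2 * 0.5 * sqrt(10) / (1 + 0)  ~= 3.16
    print(parent.get_ucb(visited))    # 0.75 + 2 * 0.1 * sqrt(10) / 5 ~= 0.88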


# Defining an Alpha MCTS class
class AlphaMCTS:
    def __init__(self, game, args, model):
        self.game = game
        self.args = args
        self.model = model

    # Wrapped in torch.no_grad() so the forward passes during search never track gradients
    @torch.no_grad()
    def search(self, state):
        """
        The search function performs the three steps of Alpha Monte Carlo Tree Search:
        1. Selection
        2. Expansion
        3. Backpropagation
        (A minimal end-to-end usage sketch with a stub game and model appears at the bottom of this file.)
        """
        root = AlphaNode(self.game, self.args, state, visit_count=1)
        # Evaluate the root with the network and mix in Dirichlet noise for exploration
        policy, _ = self.model(
            torch.tensor(self.game.get_encoded_state(state), device=self.model.device).unsqueeze(0)
        )
        policy = torch.softmax(policy, dim=1).squeeze(0).cpu().numpy()
        policy = (1 - self.args['dirichlet_epsilon']) * policy + self.args['dirichlet_epsilon'] * \
            np.random.dirichlet([self.args['dirichlet_alpha']] * self.game.action_size)
        # Mask out invalid moves and renormalise before expanding the root
        valid_moves = self.game.get_valid_moves(state)
        policy *= valid_moves
        policy /= np.sum(policy)
        root.expand(policy)
        for _ in range(self.args['num_searches']):
            node = root
            # Selection: walk down the tree while nodes are fully expanded
            while node.is_fully_expanded():
                node = node.select()
            value, is_terminated = self.game.check_win_and_termination(node.state, node.action_taken)
            value = self.game.get_opponent_value(value)
            if not is_terminated:
                # Encode the leaf state and evaluate it with the network
                policy, value = self.model(
                    torch.tensor(self.game.get_encoded_state(node.state), device=self.model.device).unsqueeze(0)
                )
                policy = torch.softmax(policy, dim=1).squeeze(0).cpu().numpy()
                # Prevent expansion along invalid moves
                valid_moves = self.game.get_valid_moves(node.state)
                policy *= valid_moves
                policy /= np.sum(policy)
                value = value.item()
                # Expansion
                node.expand(policy)
            # Backpropagation
            node.backpropagate(value)

        # Get the probabilities for the different actions from the root's visit counts
        action_probs = np.zeros(self.game.action_size)
        for child in root.children:
            action_probs[child.action_taken] = child.visit_count
        action_probs /= np.sum(action_probs)
        return action_probs
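

# --- Illustrative usage sketch, not part of the original file -----------------------------------
# AlphaMCTS expects a game object and a policy/value network from the rest of the project; both
# are assumptions here. The tiny _ToyGame and _UniformModel below exist only so that the search
# can be exercised end to end, and they implement just the interface the code above relies on:
# action_size, get_encoded_state, get_valid_moves, get_next_state, change_perspective,
# check_win_and_termination and get_opponent_value.
class _ToyGame:
    """Three cells in a row; a player wins by occupying both end cells (indices 0 and 2)."""
    def __init__(self):
        self.action_size = 3

    def get_initial_state(self):
        return np.zeros(self.action_size)

    def get_next_state(self, state, action, player):
        state[action] = player
        return state

    def get_valid_moves(self, state):
        return (state == 0).astype(np.uint8)

    def change_perspective(self, state, player):
        return state * player

    def get_encoded_state(self, state):
        # Three planes: opponent pieces, empty cells, own pieces
        return np.stack((state == -1, state == 0, state == 1)).astype(np.float32)

    def get_opponent_value(self, value):
        return -value

    def check_win_and_termination(self, state, action):
        player = state[action]
        if player != 0 and state[0] == player and state[2] == player:
            return 1, True
        if np.sum(self.get_valid_moves(state)) == 0:
            return 0, True
        return 0, False


class _UniformModel:
    """Stand-in for the trained policy/value network: uniform policy logits, zero value."""
    def __init__(self, action_size):
        self.action_size = action_size
        self.device = torch.device('cpu')

    def __call__(self, encoded_state):
        batch = encoded_state.shape[0]
        return torch.zeros(batch, self.action_size), torch.zeros(batch, 1)


if __name__ == "__main__":
    _demo_ucb_scores()
    game = _ToyGame()
    args = {'C': 2, 'num_searches': 100, 'dirichlet_epsilon': 0.25, 'dirichlet_alpha': 0.3}
    mcts = AlphaMCTS(game, args, _UniformModel(game.action_size))
    # Visit-count distribution over the three actions from the empty board
    print(mcts.search(game.get_initial_state()))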