-
Notifications
You must be signed in to change notification settings - Fork 0
/
UCB.py
49 lines (43 loc) · 1.8 KB
/
UCB.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
from baselines import BaseAgent
import numpy as np
from constants import LINEAR_BANDIT_COLUMNS
class UCBAgent(BaseAgent):
    """UCB1 agent for a multi-armed bandit with a fixed number of arms.

    Each arm keeps a running mean reward ``Q`` and a pull count ``N``.
    Arms that have never been updated are played first; after that the arm
    with the largest upper confidence bound ``Q + sqrt(c * ln(t) / N)`` is
    chosen, where ``t`` is the total number of updates received.
    """

    def __init__(self, num_actions=3, c=2.0):
        """Create the agent.

        Args:
            num_actions: number of arms; ``act`` returns an index in
                ``range(num_actions)``. Defaults to 3.
            c: coefficient inside the square root of the exploration bonus.
                The default ``2.0`` gives the bonus ``sqrt(2 ln t / N)``.
        """
        self.num_actions = num_actions
        self.c = c
        # Initial estimates. The first update of an arm replaces this value
        # with the observed reward, because the weight on the old value is
        # (N - 1) / N = 0 when N becomes 1.
        self.Q = np.ones(self.num_actions)
        self.N = np.zeros(self.num_actions)
        # Total number of updates received (equals N.sum()).
        self.t = 0

    def act(self, observation):
        """Return the index of the arm to pull.

        ``observation`` is not used by this agent.
        """
        untried = np.flatnonzero(self.N == 0)
        if untried.size:
            # Play every arm at least once before using the bound, so the
            # bonus never divides by a zero count (which would give inf/nan).
            return int(untried[0])
        bonus = np.sqrt(self.c * np.log(self.t) / self.N)
        return int(np.argmax(self.Q + bonus))

    def update(self, observation, action, reward):
        """Record ``reward`` for ``action`` and refresh that arm's mean.

        ``observation`` is not used by this agent.
        """
        self.t += 1
        self.N[action] += 1
        n = self.N[action]
        # Incremental sample mean: Q_n = Q_{n-1} * (n - 1) / n + r / n.
        self.Q[action] = self.Q[action] * ((n - 1) / n) + reward / n
class LinUCBAgent(BaseAgent):
    """Disjoint LinUCB agent: an independent ridge-regression model per arm.

    For arm ``a`` the agent keeps ``A[a] = I + sum(x x^T)`` and
    ``b[a] = sum(reward * x)`` over the contexts ``x`` seen when that arm
    was updated, estimates ``theta[a] = A[a]^{-1} b[a]``, and scores the arm
    as ``theta[a] . x + alpha * sqrt(x^T A[a]^{-1} x)``.
    """

    def __init__(self, num_actions=3, alpha=0.1):
        """Create the agent.

        Args:
            num_actions: number of arms; ``act`` returns an index in
                ``range(num_actions)``. Defaults to 3.
            alpha: multiplier on the confidence-width term of the score.
                Defaults to 0.1.
        """
        self.alpha = alpha
        self.num_actions = num_actions
        self.num_features = len(LINEAR_BANDIT_COLUMNS)
        # One identity matrix per arm (the ridge prior), stacked as
        # (num_actions, num_features, num_features).
        self.A = np.array(
            [np.identity(self.num_features) for _ in range(self.num_actions)]
        )
        self.b = np.zeros((self.num_actions, self.num_features))
        self.theta = np.zeros((self.num_actions, self.num_features))
        # Total number of updates received.
        self.t = 0

    @staticmethod
    def _features(observation):
        """Return the LINEAR_BANDIT_COLUMNS of ``observation`` as a float vector."""
        # NOTE(review): assumes ``observation`` is a single pandas row
        # (Series) so the result is 1-D — confirm against callers.
        # dtype=float avoids an object array when the row mixes column dtypes;
        # np.linalg and the in-place float updates below reject object arrays.
        return observation[LINEAR_BANDIT_COLUMNS].to_numpy(dtype=float)

    def act(self, observation):
        """Return the index of the arm with the highest LinUCB score.

        During the first ``num_actions`` steps (fewer updates than arms) the
        arms are played in index order, arm ``t`` at step ``t``.
        """
        if self.t < self.num_actions:
            return self.t
        x = self._features(observation)
        scores = np.zeros(self.num_actions)
        for a in range(self.num_actions):
            self.theta[a] = np.linalg.solve(self.A[a], self.b[a])
            width = np.sqrt(x @ np.linalg.solve(self.A[a], x))
            scores[a] = self.theta[a] @ x + self.alpha * width
        return int(np.argmax(scores))

    def update(self, observation, action, reward):
        """Fold the context of ``observation`` and ``reward`` into ``action``'s model."""
        self.t += 1
        x = self._features(observation)
        self.A[action] += np.outer(x, x)
        self.b[action] += reward * x
        # Solve the linear system rather than forming inv(A): cheaper, more
        # numerically stable, and consistent with act().
        self.theta[action] = np.linalg.solve(self.A[action], self.b[action])