Skip to content

Commit ada06b6

Browse files
cleaner epsilon-greedy handling
1 parent c7ec5f6 commit ada06b6

File tree

2 files changed

+16
-17
lines changed

2 files changed

+16
-17
lines changed

hopes/policy/policies.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,14 @@
1111

1212

1313
class Policy(ABC):
14+
epsilon: float | None = None
15+
16+
def with_epsilon(self, epsilon: float | None = None) -> "Policy":
17+
"""Set the epsilon value for epsilon-greedy action selection."""
18+
assert epsilon is None or 0 <= epsilon <= 1, "Epsilon must be in [0, 1]."
19+
self.epsilon = epsilon
20+
return self
21+
1422
@abstractmethod
1523
def log_likelihoods(self, obs: np.ndarray) -> np.ndarray:
1624
"""Compute the log-likelihoods of the actions under the policy for a given set of
def compute_action_probs(self, obs: np.ndarray) -> np.ndarray:
    """Compute action probabilities under the policy for given observations.

    :param obs: the observation(s), shape (batch_size, obs_dim).
    :return: action probabilities, shape (batch_size, num_actions).
    """
    log_likelihoods = self.log_likelihoods(obs)
    action_probs = np.exp(log_likelihoods)
    if self.epsilon is not None:
        # Epsilon-greedy: mix the policy distribution with a uniform one.
        # The deterministic mixture eps/n + (1-eps)*p — rather than randomly
        # swapping in an all-uniform matrix with probability eps — keeps the
        # returned values equal to the true behavior-policy propensities,
        # which off-policy estimators rely on; sampling actions from the
        # mixture is distributionally equivalent to two-stage sampling.
        num_actions = action_probs.shape[1]
        action_probs = self.epsilon / num_actions + (1.0 - self.epsilon) * action_probs
    return action_probs
3647

37-
def select_action(self, obs: np.ndarray, deterministic: bool = False) -> np.ndarray:
    """Select actions under the policy for given observations.

    :param obs: the observation(s) for which to select an action, shape
        (batch_size, obs_dim).
    :param deterministic: whether to select actions deterministically
        (argmax over the action probabilities).
    :return: the selected action(s), shape (batch_size,).
    :raises ValueError: if deterministic selection is requested while an
        epsilon value is set — the two modes are mutually exclusive.
    """
    # Raise rather than assert: asserts are stripped under `python -O`.
    if deterministic and self.epsilon is not None:
        raise ValueError("Cannot be deterministic and epsilon-greedy at the same time.")

    action_probs = self.compute_action_probs(obs)

    # deterministic (greedy) action selection
    if deterministic:
        return np.argmax(action_probs, axis=1)

    # stochastic selection from the (possibly epsilon-mixed) probabilities
    return np.array([np.random.choice(len(probs), p=probs) for probs in action_probs])
6969

7070

tests/test_policies.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -154,9 +154,8 @@ def test_select_action_rnd_determ_eps(self):
154154
self.assertTrue(np.var(actions) > 0)
155155

156156
# epsilon-greedy action selection
157-
actions = [
158-
class_pol.select_action(obs=obs, deterministic=False, epsilon=0.5) for _ in range(100)
159-
]
157+
class_pol.with_epsilon(0.5)
158+
actions = [class_pol.select_action(obs=obs, deterministic=False) for _ in range(100)]
160159
self.assertTrue(np.var(actions) > 0)
161160

162161
def assert_log_probs(self, log_probs: np.ndarray, expected_shape: tuple):

0 commit comments

Comments
 (0)