@@ -11,6 +11,14 @@
 
 
 class Policy(ABC):
+    epsilon: float | None = None
+
+    def with_epsilon(self, epsilon: float | None = None) -> "Policy":
+        """Set the epsilon value for epsilon-greedy action selection."""
+        assert epsilon is None or 0 <= epsilon <= 1, "Epsilon must be in [0, 1]."
+        self.epsilon = epsilon
+        return self
+
     @abstractmethod
     def log_likelihoods(self, obs: np.ndarray) -> np.ndarray:
         """Compute the log-likelihoods of the actions under the policy for a given set of
@@ -32,39 +40,31 @@ def compute_action_probs(self, obs: np.ndarray) -> np.ndarray:
         log_likelihoods = self.log_likelihoods(obs)
         action_probs = np.exp(log_likelihoods)
+        # epsilon-greedy action selection
+        if self.epsilon is not None and (np.random.rand() < self.epsilon):
+            action_probs = np.ones_like(action_probs) / action_probs.shape[1]
         return action_probs
 
-    def select_action(
-        self, obs: np.ndarray, deterministic: float = False, epsilon: float | None = None
-    ) -> np.ndarray:
+    def select_action(self, obs: np.ndarray, deterministic: bool = False) -> np.ndarray:
         """Select actions under the policy for given observations.
 
         :param obs: the observation(s) for which to select an action, shape (batch_size,
             obs_dim).
         :param deterministic: whether to select actions deterministically.
-        :param epsilon: the epsilon value for epsilon-greedy action selection.
         :return: the selected action(s).
         """
-        assert epsilon is None or 0 <= epsilon <= 1, "Epsilon must be in [0, 1]."
         assert not (
-            deterministic and epsilon is not None
+            deterministic and self.epsilon is not None
         ), "Cannot be deterministic and epsilon-greedy at the same time."
 
         action_probs = self.compute_action_probs(obs)
 
         # deterministic or greedy action selection
-        if deterministic or (epsilon is not None and np.random.rand() > epsilon):
+        if deterministic:
             return np.argmax(action_probs, axis=1)
 
         # action selection based on computed action probabilities
-        # or epsilon-greedy action selection
         else:
-            if epsilon is not None:
-                action_probs = (
-                    epsilon / len(action_probs[0]) * np.ones_like(action_probs)
-                    + (1 - epsilon) * action_probs
-                )
-
             return np.array([np.random.choice(len(probs), p=probs) for probs in action_probs])
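For reviewers, a minimal usage sketch of the new fluent API. The `TwoActionPolicy` subclass and its log-likelihoods are hypothetical stand-ins for illustration, not part of this PR, and the sketch assumes `log_likelihoods` is the only abstract method on `Policy`:

```python
import numpy as np

# Assumes Policy from this PR is importable in the current scope.
class TwoActionPolicy(Policy):
    """Hypothetical concrete policy: prefers action 1 with probability 0.9."""

    def log_likelihoods(self, obs: np.ndarray) -> np.ndarray:
        # Log-probabilities, shape (batch_size, n_actions) = (len(obs), 2).
        probs = np.tile(np.array([0.1, 0.9]), (obs.shape[0], 1))
        return np.log(probs)

obs = np.zeros((4, 3))  # batch of 4 observations, obs_dim = 3

# Fluent configuration: with_epsilon returns self, so it can be chained.
policy = TwoActionPolicy().with_epsilon(0.1)
actions = policy.select_action(obs)  # samples; ~10% of calls use uniform probs

# Deterministic selection requires epsilon to be unset, otherwise the assert fires.
greedy = TwoActionPolicy().select_action(obs, deterministic=True)  # argmax per row
```

With epsilon set, each call to `compute_action_probs` returns the uniform distribution with probability epsilon and the policy's own distribution otherwise, so exploration is decided once per batch rather than per sample as in the removed mixing code.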