
Commit d5b8853

Retrace returns for off-policy RL.
1 parent: b60667a

4 files changed, +50 -34 lines
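For context: Retrace(lambda) (Munos et al., 2016) makes lambda-returns usable with off-policy experience by decaying each backup step with a truncated importance ratio instead of a fixed constant. A minimal sketch of the per-step trace coefficient, written independently of this repository (the function name is mine, not part of the codebase):

def retrace_coefficient(target_prob, behavior_prob, lam):
    # c_t = lam * min(1, pi(a_t|s_t) / mu(a_t|s_t)); clipping at 1 keeps the
    # variance of the correction bounded while staying safe off-policy.
    return lam * min(1.0, target_prob / behavior_prob)

The diffs below thread exactly these pi/mu ratios from the actor into the critic's return computation.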


phillip/RL.py

Lines changed: 10 additions & 9 deletions
@@ -121,23 +121,24 @@ def process_experiences(f, keys):
     delayed = live.copy()
     delayed.update(process_experiences(lambda t: t[:,self.config.delay:], ['state', 'reward']))
 
-    policy_args = live
-    critic_args = delayed
+    policy_args = live.copy()
+    critic_args = delayed.copy()
 
     print("Creating train ops")
 
     train_ops = []
 
-    if self.train_policy or self.train_critic:
+    if self.train_policy:
+      probs = self.policy.probs(**policy_args)
+      critic_args.update(**probs)
+
       train_critic, targets, advantages = self.critic(**critic_args)
-
-    if self.train_critic:
       train_ops.append(train_critic)
+
+      probs.update(advantages=advantages)
+      train_policy = self.policy.train(**probs)
+      train_ops.append(train_policy)
 
-    if self.train_policy:
-      policy_args.update(advantages=tf.stop_gradient(advantages), targets=targets)
-      train_ops.append(self.policy.train(**policy_args))
-
     if self.train_model:
       train_ops.append(self.model.train(**delayed))
 
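Read as a whole, the new wiring is: the actor computes its action probabilities and pi/mu ratios first, the critic consumes those ratios to build off-policy-corrected advantages, and the advantages flow back into the actor's train op. A condensed, annotated restatement of the hunk above (the dict keys come from Actor.probs in the ac.py diff below):

probs = self.policy.probs(**policy_args)
# probs == {'target_probs': ..., 'target_log_probs': ..., 'ratios': ...}
critic_args.update(**probs)                      # the critic now sees the ratios
train_critic, targets, advantages = self.critic(**critic_args)
probs.update(advantages=advantages)              # the actor reuses the critic's advantages
train_policy = self.policy.train(**probs)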

phillip/ac.py

Lines changed: 29 additions & 21 deletions
@@ -50,39 +50,47 @@ def __init__(self, embedGame, embedAction, global_step, rlConfig, **kwargs):
 
     self.actor = net
 
-  def train(self, state, prev_action, action, prob, advantages, **unused):
+  def train(self, target_log_probs, advantages, **unused):
+    train_log_probs = target_log_probs[:,:-1] # last state has no advantage
+    actor_gain = tf.reduce_mean(tf.mul(train_log_probs, advantages - self.entropy_scale))
+
+    actor_params = self.actor.getVariables()
+
+    def metric(log_p1, log_p2):
+      return tf.reduce_mean(tf.squared_difference(log_p1, log_p2))
+
+    return self.optimizer.optimize(-actor_gain, actor_params, target_log_probs, metric)
+
+  def probs(self, state, prev_action, action, prob, **unused):
     embedded_state = self.embedGame(state)
     embedded_prev_action = self.embedAction(prev_action)
     history = RL.makeHistory(embedded_state, embedded_prev_action, self.rlConfig.memory)
-
-    actor_probs = self.actor(history)
-    log_actor_probs = tf.log(actor_probs)
 
+    actions = self.embedAction(action[:,self.rlConfig.memory:])
+
+    actor_probs = self.actor(history)
+    real_actor_probs = tfl.batch_dot(actions, actor_probs)
+
+    """
     entropy = - tfl.batch_dot(actor_probs, log_actor_probs)
     entropy_avg = tfl.power_mean(self.entropy_power, entropy)
     tf.scalar_summary('entropy_avg', entropy_avg)
     tf.scalar_summary('entropy_min', tf.reduce_min(entropy))
     tf.histogram_summary('entropy', entropy)
-
-    actions = self.embedAction(action[:,self.rlConfig.memory:])
-
-    real_actor_probs = tfl.batch_dot(actions, actor_probs)
-    prob_ratios = prob[:,self.rlConfig.memory:] / real_actor_probs
-    tf.scalar_summary('kl', tf.reduce_mean(tf.log(prob_ratios)))
-
-    real_log_actor_probs = tfl.batch_dot(actions, log_actor_probs)
-    train_log_actor_probs = real_log_actor_probs[:,:-1] # last state has no advantage
-    actor_gain = tf.reduce_mean(tf.mul(train_log_actor_probs, tf.stop_gradient(advantages)))
-    #tf.scalar_summary('actor_gain', actor_gain)
+    """
 
-    actor_loss = - (actor_gain + self.entropy_scale * entropy_avg)
+    tf.scalar_summary('entropy_avg', -tf.reduce_mean(tf.log(prob)))
 
-    actor_params = self.actor.getVariables()
-
-    def metric(p1, p2):
-      return tf.reduce_mean(tfl.kl(p1, p2))
+    behavior_probs = prob[:,self.rlConfig.memory:]
+    ratios = real_actor_probs / behavior_probs
+
+    tf.scalar_summary('kl', -tf.reduce_mean(tf.log(ratios)))
 
-    return self.optimizer.optimize(actor_loss, actor_params, log_actor_probs, metric)
+    return dict(
+      target_probs = real_actor_probs,
+      target_log_probs = tf.log(real_actor_probs),
+      ratios = ratios
+    )
 
   def getPolicy(self, state, **unused):
     return self.actor(state)
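One detail worth spelling out: in the new train(), the entropy bonus is folded into the advantage by subtracting entropy_scale before multiplying by the log-probabilities. Since mean(log_pi * (A - beta)) = mean(A * log_pi) - beta * mean(log_pi), and -mean(log_pi) over the sampled actions is an estimate of the policy's entropy (exact only when the data is on-policy), maximizing the gain still rewards both higher advantage and higher entropy. A small NumPy check of that identity, with made-up values:

import numpy as np

log_probs = np.log(np.array([0.2, 0.5, 0.1, 0.7]))  # log pi(a_t|s_t) at the sampled actions
advantages = np.array([0.3, -0.1, 0.8, 0.05])
beta = 0.01                                          # plays the role of entropy_scale

folded = np.mean(log_probs * (advantages - beta))
split = np.mean(log_probs * advantages) - beta * np.mean(log_probs)
assert np.isclose(folded, split)                     # same objective, written two ways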

phillip/critic.py

Lines changed: 9 additions & 2 deletions
@@ -10,6 +10,7 @@ class Critic(Default):
     Option('critic_learning_rate', type=float, default=1e-4),
     Option('gae_lambda', type=float, default=1., help="Generalized Advantage Estimation"),
     Option('fix_scopes', type=bool, default=False),
+    Option('retrace', type=bool, default=True, help="Retrace(lambda) - correct for off-policy behavior"),
   ]
 
   _members = [
@@ -41,7 +42,7 @@ def __init__(self, embedGame, embedAction, scope='critic', **kwargs):
 
     self.variables = self.net.getVariables()
 
-  def __call__(self, state, prev_action, reward, **unused):
+  def __call__(self, state, prev_action, reward, ratios, **unused):
     embedded_state = self.embedGame(state)
     embedded_prev_action = self.embedAction(prev_action)
     history = makeHistory(embedded_state, embedded_prev_action, self.rlConfig.memory)
@@ -53,7 +54,13 @@ def __call__(self, state, prev_action, reward, **unused):
     rewards = reward[:,self.rlConfig.memory:]
    deltaVs = rewards + self.rlConfig.discount * values[:,1:] - trainVs
 
-    advantages = tfl.discount2(deltaVs, self.rlConfig.discount * self.gae_lambda)
+    if self.retrace:
+      discounts = tf.minimum(1., ratios[:,:-1])
+    else:
+      discounts = tf.ones_like(trainVs)
+
+    discounts *= self.rlConfig.discount * self.gae_lambda
+    advantages = tfl.discount2(deltaVs, discounts)
 
     targets = trainVs + advantages
     # targets = tfl.discount2(rewards, self.rlConfig.discount, lastV)
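The discounts tensor above is where the Retrace(lambda) correction happens: instead of decaying every backup step by a constant discount * gae_lambda, each step is additionally scaled by min(1, ratio), which cuts the trace wherever the current policy assigns much less probability to the taken action than the behavior policy did. A NumPy sketch of the construction with made-up shapes (the one-step alignment of ratios[:,:-1] against deltaVs is my reading of the diff; the backward fold itself is sketched after the tf_lib.py diff below):

import numpy as np

gamma, gae_lambda = 0.99, 1.0
ratios = np.array([[1.3, 0.6, 1.1, 0.9]])    # pi/mu for 4 steps, batch of 1
deltaVs = np.array([[0.2, -0.4, 0.1]])       # TD errors, one step shorter

discounts = np.minimum(1.0, ratios[:, :-1]) * gamma * gae_lambda
# the advantages then satisfy A[:, t] = deltaVs[:, t] + discounts[:, t] * A[:, t+1]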

phillip/tf_lib.py

Lines changed: 2 additions & 2 deletions
@@ -367,7 +367,7 @@ def discount(values, gamma, initial=None):
 
   return tf.pack(values, axis=1)
 
-def discount2(values, gamma, initial=None):
+def discount2(values, gammas, initial=None):
   """Compute returns from rewards.
 
   Uses tf.while_loop instead of unrolling in python.
@@ -382,7 +382,7 @@ def discount2(values, gamma, initial=None):
   """
 
   def body(i, prev, returns):
-    next = values[:,i] + gamma * prev
+    next = values[:,i] + gammas[:,i] * prev
     next.set_shape(prev.get_shape())
 
     returns = returns.write(i, next)
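Since discount2 now takes per-step discounts instead of a scalar gamma, a NumPy reference of the intended semantics may help when reading callers. This is a sketch of my reading of the while_loop body above, assuming the loop walks backward over time (as "compute returns from rewards" implies) and that initial is an optional bootstrap value for the step after the last one:

import numpy as np

def discount2_reference(values, gammas, initial=None):
    # values, gammas: [batch, T]; initial: [batch] bootstrap, defaults to zero
    batch, T = values.shape
    prev = np.zeros(batch) if initial is None else initial
    returns = np.empty_like(values)
    for i in reversed(range(T)):
        prev = values[:, i] + gammas[:, i] * prev
        returns[:, i] = prev
    return returns

With gammas filled with a single constant, this reduces to the previous scalar-gamma behavior.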
