Commit 2842cb6

Moves all batch info into no_grad
1 parent b29533b commit 2842cb6

3 files changed: +21 -22 lines changed

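The change is the same in all three agents: the lines that unpack the sampled replay-buffer minibatch move from just before the torch.no_grad() block to just inside it. This should be behavior-preserving, since no_grad() only suppresses autograd recording of tensor operations; plain dict lookups and len() are unaffected, and replay-buffer tensors carry no gradient history to begin with. A quick sanity check of that point, independent of PFRL:

import torch

x = torch.randn(3, requires_grad=True)
batch = {"reward": torch.ones(3)}

with torch.no_grad():
    r = batch["reward"]   # plain dict lookup: unaffected by no_grad
    n = len(r)            # likewise
    y = x * 2             # tensor op: not recorded, so y.requires_grad is False

assert n == 3
assert not y.requires_grad
assert x.requires_grad    # the leaf tensor itself is untouched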

pfrl/agents/ddpg.py

Lines changed: 7 additions & 7 deletions

@@ -148,16 +148,16 @@ def sync_target_network(self):
     def compute_critic_loss(self, batch):
         """Compute loss for critic."""

-        batch_next_state = batch["next_state"]
-        batch_rewards = batch["reward"]
-        batch_terminal = batch["is_state_terminal"]
-        batch_state = batch["state"]
-        batch_actions = batch["action"]
-        batchsize = len(batch_rewards)
-
         with torch.no_grad():
+            batch_state = batch["state"]
+            batch_actions = batch["action"]
+            batch_rewards = batch["reward"]
+            batchsize = len(batch_rewards)
             assert not self.recurrent
+            batch_next_state = batch["next_state"]
+            batch_terminal = batch["is_state_terminal"]
             next_actions = self.target_policy(batch_next_state).sample()
+
             next_q = self.target_q_function((batch_next_state, next_actions))
             target_q = batch_rewards + self.gamma * (
                 1.0 - batch_terminal
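
For orientation, here is a sketch of the whole rewritten method. Everything past the end of the hunk above (the reshape of next_q, the online critic evaluation, and the final regression loss) is not shown in this diff and is assumed from a typical DDPG critic update, so treat those lines as illustrative:

import torch
import torch.nn.functional as F

def compute_critic_loss(self, batch):
    """Compute loss for critic (sketch; the tail of the method is assumed)."""
    with torch.no_grad():
        batch_state = batch["state"]
        batch_actions = batch["action"]
        batch_rewards = batch["reward"]
        batchsize = len(batch_rewards)
        assert not self.recurrent
        batch_next_state = batch["next_state"]
        batch_terminal = batch["is_state_terminal"]
        # Bootstrap target from the target networks, kept out of the autograd graph.
        next_actions = self.target_policy(batch_next_state).sample()
        next_q = self.target_q_function((batch_next_state, next_actions))
        target_q = batch_rewards + self.gamma * (
            1.0 - batch_terminal
        ) * next_q.reshape((batchsize,))

    # Assumed continuation: the online critic runs with gradients enabled and is
    # regressed onto the frozen target.
    predict_q = self.q_function((batch_state, batch_actions)).reshape((batchsize,))
    return F.mse_loss(target_q, predict_q)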

pfrl/agents/soft_actor_critic.py

Lines changed: 7 additions & 8 deletions

@@ -213,17 +213,16 @@ def sync_target_network(self):

     def update_q_func(self, batch):
         """Compute loss for a given Q-function."""
-
-        batch_next_state = batch["next_state"]
-        batch_rewards = batch["reward"]
-        batch_terminal = batch["is_state_terminal"]
-        batch_state = batch["state"]
-        batch_actions = batch["action"]
-        batch_discount = batch["discount"]
-
         with torch.no_grad(), pfrl.utils.evaluating(self.policy), pfrl.utils.evaluating(
             self.target_q_func1
         ), pfrl.utils.evaluating(self.target_q_func2):
+            batch_state = batch["state"]
+            batch_actions = batch["action"]
+            batch_next_state = batch["next_state"]
+            batch_rewards = batch["reward"]
+            batch_terminal = batch["is_state_terminal"]
+            batch_discount = batch["discount"]
+
             next_action_distrib = self.policy(batch_next_state)
             next_actions = next_action_distrib.sample()
             next_log_prob = next_action_distrib.log_prob(next_actions)
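
Besides torch.no_grad(), this block leans on pfrl.utils.evaluating, a context manager that temporarily puts a module into eval mode and restores its previous training flag on exit, so that layers such as dropout and batch norm behave deterministically while the target is computed. A minimal sketch of that behavior (not the library's actual implementation), assuming the standard nn.Module train/eval API:

import contextlib
import torch.nn as nn

@contextlib.contextmanager
def evaluating(module: nn.Module):
    """Temporarily switch `module` to eval mode, restoring its old mode afterwards."""
    was_training = module.training
    try:
        module.eval()
        yield module
    finally:
        module.train(was_training)

# Usage: stochastic layers are deterministic inside the block.
net = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))
with evaluating(net):
    assert not net.training
assert net.training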

pfrl/agents/td3.py

Lines changed: 7 additions & 7 deletions

@@ -181,18 +181,18 @@ def sync_target_network(self):
     def update_q_func(self, batch):
         """Compute loss for a given Q-function."""

-        batch_next_state = batch["next_state"]
-        batch_rewards = batch["reward"]
-        batch_terminal = batch["is_state_terminal"]
-        batch_state = batch["state"]
-        batch_actions = batch["action"]
-        batch_discount = batch["discount"]
-
         with torch.no_grad(), pfrl.utils.evaluating(
             self.target_policy
         ), pfrl.utils.evaluating(self.target_q_func1), pfrl.utils.evaluating(
             self.target_q_func2
         ):
+            batch_state = batch["state"]
+            batch_actions = batch["action"]
+            batch_next_state = batch["next_state"]
+            batch_rewards = batch["reward"]
+            batch_terminal = batch["is_state_terminal"]
+            batch_discount = batch["discount"]
+
             next_actions = self.target_policy_smoothing_func(
                 self.target_policy(batch_next_state).sample()
             )
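
The only agent-specific piece here is target_policy_smoothing_func, TD3's target policy smoothing: clipped Gaussian noise is added to the sampled target action before the target critics are queried. A sketch of a typical smoothing function, with the noise scale, clip range, and action bounds as assumed defaults rather than values taken from this repository:

import torch

def target_policy_smoothing(actions: torch.Tensor,
                            noise_sigma: float = 0.2,
                            noise_clip: float = 0.5,
                            action_low: float = -1.0,
                            action_high: float = 1.0) -> torch.Tensor:
    """Add clipped Gaussian noise to target actions (TD3 target policy smoothing)."""
    noise = torch.clamp(noise_sigma * torch.randn_like(actions), -noise_clip, noise_clip)
    return torch.clamp(actions + noise, action_low, action_high)

# Example: smooth a batch of sampled target actions in [-1, 1].
sampled = torch.tanh(torch.randn(32, 3))
smoothed = target_policy_smoothing(sampled)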

0 commit comments
