From 2842cb666e49d05c470e2b2f7e7beec538de4ea3 Mon Sep 17 00:00:00 2001
From: Prabhat Nagarajan
Date: Fri, 12 Apr 2024 23:21:02 -0600
Subject: [PATCH] Moves all batch info into no_grad

---
 pfrl/agents/ddpg.py              | 14 +++++++-------
 pfrl/agents/soft_actor_critic.py | 15 +++++++--------
 pfrl/agents/td3.py               | 14 +++++++-------
 3 files changed, 21 insertions(+), 22 deletions(-)

diff --git a/pfrl/agents/ddpg.py b/pfrl/agents/ddpg.py
index 08c0748da..aa367cfc8 100644
--- a/pfrl/agents/ddpg.py
+++ b/pfrl/agents/ddpg.py
@@ -148,16 +148,16 @@ def sync_target_network(self):
 
     def compute_critic_loss(self, batch):
         """Compute loss for critic."""
-        batch_next_state = batch["next_state"]
-        batch_rewards = batch["reward"]
-        batch_terminal = batch["is_state_terminal"]
-        batch_state = batch["state"]
-        batch_actions = batch["action"]
-        batchsize = len(batch_rewards)
-
         with torch.no_grad():
+            batch_state = batch["state"]
+            batch_actions = batch["action"]
+            batch_rewards = batch["reward"]
+            batchsize = len(batch_rewards)
             assert not self.recurrent
+            batch_next_state = batch["next_state"]
+            batch_terminal = batch["is_state_terminal"]
             next_actions = self.target_policy(batch_next_state).sample()
+
             next_q = self.target_q_function((batch_next_state, next_actions))
             target_q = batch_rewards + self.gamma * (
                 1.0 - batch_terminal
diff --git a/pfrl/agents/soft_actor_critic.py b/pfrl/agents/soft_actor_critic.py
index 75e8ce98a..ecec8c69f 100644
--- a/pfrl/agents/soft_actor_critic.py
+++ b/pfrl/agents/soft_actor_critic.py
@@ -213,17 +213,16 @@ def sync_target_network(self):
 
     def update_q_func(self, batch):
         """Compute loss for a given Q-function."""
-
-        batch_next_state = batch["next_state"]
-        batch_rewards = batch["reward"]
-        batch_terminal = batch["is_state_terminal"]
-        batch_state = batch["state"]
-        batch_actions = batch["action"]
-        batch_discount = batch["discount"]
-
         with torch.no_grad(), pfrl.utils.evaluating(self.policy), pfrl.utils.evaluating(
             self.target_q_func1
         ), pfrl.utils.evaluating(self.target_q_func2):
+            batch_state = batch["state"]
+            batch_actions = batch["action"]
+            batch_next_state = batch["next_state"]
+            batch_rewards = batch["reward"]
+            batch_terminal = batch["is_state_terminal"]
+            batch_discount = batch["discount"]
+
             next_action_distrib = self.policy(batch_next_state)
             next_actions = next_action_distrib.sample()
             next_log_prob = next_action_distrib.log_prob(next_actions)
diff --git a/pfrl/agents/td3.py b/pfrl/agents/td3.py
index dc913f56d..bee0b15b1 100644
--- a/pfrl/agents/td3.py
+++ b/pfrl/agents/td3.py
@@ -181,18 +181,18 @@ def sync_target_network(self):
 
     def update_q_func(self, batch):
         """Compute loss for a given Q-function."""
-        batch_next_state = batch["next_state"]
-        batch_rewards = batch["reward"]
-        batch_terminal = batch["is_state_terminal"]
-        batch_state = batch["state"]
-        batch_actions = batch["action"]
-        batch_discount = batch["discount"]
-
         with torch.no_grad(), pfrl.utils.evaluating(
             self.target_policy
         ), pfrl.utils.evaluating(self.target_q_func1), pfrl.utils.evaluating(
             self.target_q_func2
         ):
+            batch_state = batch["state"]
+            batch_actions = batch["action"]
+            batch_next_state = batch["next_state"]
+            batch_rewards = batch["reward"]
+            batch_terminal = batch["is_state_terminal"]
+            batch_discount = batch["discount"]
+
             next_actions = self.target_policy_smoothing_func(
                 self.target_policy(batch_next_state).sample()
             )
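
For context, here is a minimal standalone sketch of the pattern this patch applies: unpacking the replay batch and computing the bootstrapped target entirely under `torch.no_grad()`, so the target networks never build an autograd graph. The networks, shapes, and batch contents below are hypothetical stand-ins for illustration, not PFRL's actual modules or replay format.

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins for the target policy and target Q-function.
gamma = 0.99
target_policy = nn.Linear(4, 2)      # maps state (dim 4) to action (dim 2)
target_q_function = nn.Linear(6, 1)  # maps concatenated (state, action) to a scalar Q-value

# A fake replay batch with the same keys the agents read from `batch`.
batch = {
    "state": torch.randn(32, 4),
    "action": torch.randn(32, 2),
    "reward": torch.randn(32),
    "next_state": torch.randn(32, 4),
    "is_state_terminal": torch.zeros(32),
}

# Batch unpacking and target computation both live inside no_grad,
# mirroring the structure after this patch.
with torch.no_grad():
    batch_rewards = batch["reward"]
    batch_next_state = batch["next_state"]
    batch_terminal = batch["is_state_terminal"]

    next_actions = target_policy(batch_next_state)
    next_q = target_q_function(torch.cat([batch_next_state, next_actions], dim=1))
    target_q = batch_rewards + gamma * (1.0 - batch_terminal) * next_q.reshape(-1)

# The target carries no gradient history, so the critic loss backpropagates
# only through the online Q-function.
assert not target_q.requires_grad
```

Since dictionary indexing and `len()` do not create autograd history, moving the unpacking inside the `no_grad` block does not change any gradients; it only groups the target computation in one place.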