Moves all batch info into no_grad
prabhatnagarajan committed Apr 13, 2024
1 parent b29533b commit 2842cb6
Showing 3 changed files with 21 additions and 22 deletions.
14 changes: 7 additions & 7 deletions pfrl/agents/ddpg.py
@@ -148,16 +148,16 @@ def sync_target_network(self):
     def compute_critic_loss(self, batch):
         """Compute loss for critic."""
 
-        batch_next_state = batch["next_state"]
-        batch_rewards = batch["reward"]
-        batch_terminal = batch["is_state_terminal"]
-        batch_state = batch["state"]
-        batch_actions = batch["action"]
-        batchsize = len(batch_rewards)
-
         with torch.no_grad():
+            batch_state = batch["state"]
+            batch_actions = batch["action"]
+            batch_rewards = batch["reward"]
+            batchsize = len(batch_rewards)
             assert not self.recurrent
+            batch_next_state = batch["next_state"]
+            batch_terminal = batch["is_state_terminal"]
             next_actions = self.target_policy(batch_next_state).sample()
+
             next_q = self.target_q_function((batch_next_state, next_actions))
             target_q = batch_rewards + self.gamma * (
                 1.0 - batch_terminal
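The same refactor appears in all three files: the replay-batch tensors are now pulled out of the batch dict inside the torch.no_grad() block that computes the target. The sketch below is mine, not from the repository (q_func, target_q_func, gamma, and the batch contents are stand-ins); it shows the effect of forming the Bellman target under torch.no_grad(): nothing produced there carries autograd history, so the target acts as a constant in the critic loss.

    import torch

    # Stand-in networks and batch; in pfrl these come from the agent and
    # its replay buffer.
    q_func = torch.nn.Linear(4, 1)
    target_q_func = torch.nn.Linear(4, 1)
    gamma = 0.99
    batch = {
        "state": torch.randn(8, 4),
        "next_state": torch.randn(8, 4),
        "reward": torch.randn(8),
        "is_state_terminal": torch.zeros(8),
    }

    with torch.no_grad():
        # No graph is recorded here, so the target is detached by construction.
        next_q = target_q_func(batch["next_state"])
        target_q = batch["reward"] + gamma * (
            1.0 - batch["is_state_terminal"]
        ) * torch.flatten(next_q)

    assert not target_q.requires_grad
    predict_q = torch.flatten(q_func(batch["state"]))
    loss = torch.nn.functional.mse_loss(predict_q, target_q)
    loss.backward()  # gradients flow only into q_func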
15 changes: 7 additions & 8 deletions pfrl/agents/soft_actor_critic.py
@@ -213,17 +213,16 @@ def sync_target_network(self):
 
     def update_q_func(self, batch):
         """Compute loss for a given Q-function."""
-
-        batch_next_state = batch["next_state"]
-        batch_rewards = batch["reward"]
-        batch_terminal = batch["is_state_terminal"]
-        batch_state = batch["state"]
-        batch_actions = batch["action"]
-        batch_discount = batch["discount"]
-
         with torch.no_grad(), pfrl.utils.evaluating(self.policy), pfrl.utils.evaluating(
             self.target_q_func1
         ), pfrl.utils.evaluating(self.target_q_func2):
+            batch_state = batch["state"]
+            batch_actions = batch["action"]
+            batch_next_state = batch["next_state"]
+            batch_rewards = batch["reward"]
+            batch_terminal = batch["is_state_terminal"]
+            batch_discount = batch["discount"]
+
             next_action_distrib = self.policy(batch_next_state)
             next_actions = next_action_distrib.sample()
             next_log_prob = next_action_distrib.log_prob(next_actions)
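Here the with statement also wraps the target computation in pfrl.utils.evaluating. As I understand it, this is a small context manager that temporarily puts a module into eval mode and restores its previous training flag on exit; a sketch of the idea, not the library's exact source:

    import contextlib
    import torch.nn as nn

    @contextlib.contextmanager
    def evaluating(net: nn.Module):
        """Temporarily switch `net` to eval mode (e.g. freezing Dropout and
        BatchNorm statistics), restoring its original mode afterwards."""
        was_training = net.training
        try:
            net.eval()
            yield net
        finally:
            net.train(was_training)

Combined with torch.no_grad(), the policy and both target Q-functions are queried in eval mode and without building a graph while the target is formed.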
14 changes: 7 additions & 7 deletions pfrl/agents/td3.py
@@ -181,18 +181,18 @@ def sync_target_network(self):
     def update_q_func(self, batch):
         """Compute loss for a given Q-function."""
 
-        batch_next_state = batch["next_state"]
-        batch_rewards = batch["reward"]
-        batch_terminal = batch["is_state_terminal"]
-        batch_state = batch["state"]
-        batch_actions = batch["action"]
-        batch_discount = batch["discount"]
-
         with torch.no_grad(), pfrl.utils.evaluating(
             self.target_policy
         ), pfrl.utils.evaluating(self.target_q_func1), pfrl.utils.evaluating(
             self.target_q_func2
         ):
+            batch_state = batch["state"]
+            batch_actions = batch["action"]
+            batch_next_state = batch["next_state"]
+            batch_rewards = batch["reward"]
+            batch_terminal = batch["is_state_terminal"]
+            batch_discount = batch["discount"]
+
             next_actions = self.target_policy_smoothing_func(
                 self.target_policy(batch_next_state).sample()
             )
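The TD3 update pipes the sampled target action through self.target_policy_smoothing_func before evaluating the target Q-functions. Target policy smoothing in TD3 adds clipped noise to the target action so the critic target is averaged over nearby actions; a hypothetical smoothing function in that spirit (the repository's default may use different constants or action bounds):

    import torch

    def target_policy_smoothing_func(batch_action: torch.Tensor) -> torch.Tensor:
        # Clipped Gaussian noise (std 0.2, clipped to +/- 0.5), then clamp the
        # perturbed action back into a [-1, 1] action range.
        noise = torch.clamp(0.2 * torch.randn_like(batch_action), -0.5, 0.5)
        return torch.clamp(batch_action + noise, -1.0, 1.0)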
