Skip to content

Commit 9bc85f4

Browse files
author
Vincent Moens
committed
[BugFix] Fix compile compatibility of PPO losses
ghstack-source-id: b346033
Pull Request resolved: #2889
1 parent 82f8ec2 commit 9bc85f4

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

torchrl/objectives/ppo.py

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -407,6 +407,7 @@ def __init__(
407407
)
408408
else:
409409
self.critic_coef = None
410+
self._has_critic = bool(self.critic_coef is not None and self.critic_coef > 0)
410411
self.loss_critic_type = loss_critic_type
411412
self.normalize_advantage = normalize_advantage
412413
self.normalize_advantage_exclude_dims = normalize_advantage_exclude_dims
@@ -689,7 +690,7 @@ def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor:
689690
"target_actor_network_params",
690691
"target_critic_network_params",
691692
)
692-
if self.critic_coef is not None:
693+
if self._has_critic:
693694
return self.critic_coef * loss_value, clip_fraction
694695
return loss_value, clip_fraction
695696

@@ -737,7 +738,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
737738
entropy = _sum_td_features(entropy)
738739
td_out.set("entropy", entropy.detach().mean()) # for logging
739740
td_out.set("loss_entropy", -self.entropy_coef * entropy)
740-
if self.critic_coef is not None:
741+
if self._has_critic:
741742
loss_critic, value_clip_fraction = self.loss_critic(tensordict)
742743
td_out.set("loss_critic", loss_critic)
743744
if value_clip_fraction is not None:
@@ -1048,7 +1049,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
10481049
entropy = _sum_td_features(entropy)
10491050
td_out.set("entropy", entropy.detach().mean()) # for logging
10501051
td_out.set("loss_entropy", -self.entropy_coef * entropy)
1051-
if self.critic_coef is not None and self.critic_coef > 0:
1052+
if self._has_critic:
10521053
loss_critic, value_clip_fraction = self.loss_critic(tensordict)
10531054
td_out.set("loss_critic", loss_critic)
10541055
if value_clip_fraction is not None:
@@ -1375,7 +1376,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict:
13751376
entropy = _sum_td_features(entropy)
13761377
td_out.set("entropy", entropy.detach().mean()) # for logging
13771378
td_out.set("loss_entropy", -self.entropy_coef * entropy)
1378-
if self.critic_coef is not None:
1379+
if self._has_critic:
13791380
loss_critic, value_clip_fraction = self.loss_critic(tensordict_copy)
13801381
td_out.set("loss_critic", loss_critic)
13811382
if value_clip_fraction is not None:

0 commit comments

Comments (0)