@@ -407,6 +407,7 @@ def __init__(
407
407
)
408
408
else :
409
409
self .critic_coef = None
410
+ self ._has_critic = bool (self .critic_coef is not None and self .critic_coef > 0 )
410
411
self .loss_critic_type = loss_critic_type
411
412
self .normalize_advantage = normalize_advantage
412
413
self .normalize_advantage_exclude_dims = normalize_advantage_exclude_dims
@@ -689,7 +690,7 @@ def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor:
689
690
"target_actor_network_params" ,
690
691
"target_critic_network_params" ,
691
692
)
692
- if self .critic_coef is not None :
693
+ if self ._has_critic :
693
694
return self .critic_coef * loss_value , clip_fraction
694
695
return loss_value , clip_fraction
695
696
@@ -737,7 +738,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
737
738
entropy = _sum_td_features (entropy )
738
739
td_out .set ("entropy" , entropy .detach ().mean ()) # for logging
739
740
td_out .set ("loss_entropy" , - self .entropy_coef * entropy )
740
- if self .critic_coef is not None :
741
+ if self ._has_critic :
741
742
loss_critic , value_clip_fraction = self .loss_critic (tensordict )
742
743
td_out .set ("loss_critic" , loss_critic )
743
744
if value_clip_fraction is not None :
@@ -1048,7 +1049,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
1048
1049
entropy = _sum_td_features (entropy )
1049
1050
td_out .set ("entropy" , entropy .detach ().mean ()) # for logging
1050
1051
td_out .set ("loss_entropy" , - self .entropy_coef * entropy )
1051
- if self .critic_coef is not None and self . critic_coef > 0 :
1052
+ if self ._has_critic :
1052
1053
loss_critic , value_clip_fraction = self .loss_critic (tensordict )
1053
1054
td_out .set ("loss_critic" , loss_critic )
1054
1055
if value_clip_fraction is not None :
@@ -1375,7 +1376,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict:
1375
1376
entropy = _sum_td_features (entropy )
1376
1377
td_out .set ("entropy" , entropy .detach ().mean ()) # for logging
1377
1378
td_out .set ("loss_entropy" , - self .entropy_coef * entropy )
1378
- if self .critic_coef is not None :
1379
+ if self ._has_critic :
1379
1380
loss_critic , value_clip_fraction = self .loss_critic (tensordict_copy )
1380
1381
td_out .set ("loss_critic" , loss_critic )
1381
1382
if value_clip_fraction is not None :
0 commit comments