File tree Expand file tree Collapse file tree 3 files changed +21
-22
lines changed Expand file tree Collapse file tree 3 files changed +21
-22
lines changed Original file line number Diff line number Diff line change @@ -148,16 +148,16 @@ def sync_target_network(self):
148148    def  compute_critic_loss (self , batch ):
149149        """Compute loss for critic.""" 
150150
151-         batch_next_state  =  batch ["next_state" ]
152-         batch_rewards  =  batch ["reward" ]
153-         batch_terminal  =  batch ["is_state_terminal" ]
154-         batch_state  =  batch ["state" ]
155-         batch_actions  =  batch ["action" ]
156-         batchsize  =  len (batch_rewards )
157- 
158151        with  torch .no_grad ():
152+             batch_state  =  batch ["state" ]
153+             batch_actions  =  batch ["action" ]
154+             batch_rewards  =  batch ["reward" ]
155+             batchsize  =  len (batch_rewards )
159156            assert  not  self .recurrent 
157+             batch_next_state  =  batch ["next_state" ]
158+             batch_terminal  =  batch ["is_state_terminal" ]
160159            next_actions  =  self .target_policy (batch_next_state ).sample ()
160+             
161161            next_q  =  self .target_q_function ((batch_next_state , next_actions ))
162162            target_q  =  batch_rewards  +  self .gamma  *  (
163163                1.0  -  batch_terminal 
Original file line number Diff line number Diff line change @@ -213,17 +213,16 @@ def sync_target_network(self):
213213
214214    def  update_q_func (self , batch ):
215215        """Compute loss for a given Q-function.""" 
216- 
217-         batch_next_state  =  batch ["next_state" ]
218-         batch_rewards  =  batch ["reward" ]
219-         batch_terminal  =  batch ["is_state_terminal" ]
220-         batch_state  =  batch ["state" ]
221-         batch_actions  =  batch ["action" ]
222-         batch_discount  =  batch ["discount" ]
223- 
224216        with  torch .no_grad (), pfrl .utils .evaluating (self .policy ), pfrl .utils .evaluating (
225217            self .target_q_func1 
226218        ), pfrl .utils .evaluating (self .target_q_func2 ):
219+             batch_state  =  batch ["state" ]
220+             batch_actions  =  batch ["action" ]
221+             batch_next_state  =  batch ["next_state" ]
222+             batch_rewards  =  batch ["reward" ]
223+             batch_terminal  =  batch ["is_state_terminal" ]
224+             batch_discount  =  batch ["discount" ]
225+ 
227226            next_action_distrib  =  self .policy (batch_next_state )
228227            next_actions  =  next_action_distrib .sample ()
229228            next_log_prob  =  next_action_distrib .log_prob (next_actions )
Original file line number Diff line number Diff line change @@ -181,18 +181,18 @@ def sync_target_network(self):
181181    def  update_q_func (self , batch ):
182182        """Compute loss for a given Q-function.""" 
183183
184-         batch_next_state  =  batch ["next_state" ]
185-         batch_rewards  =  batch ["reward" ]
186-         batch_terminal  =  batch ["is_state_terminal" ]
187-         batch_state  =  batch ["state" ]
188-         batch_actions  =  batch ["action" ]
189-         batch_discount  =  batch ["discount" ]
190- 
191184        with  torch .no_grad (), pfrl .utils .evaluating (
192185            self .target_policy 
193186        ), pfrl .utils .evaluating (self .target_q_func1 ), pfrl .utils .evaluating (
194187            self .target_q_func2 
195188        ):
189+             batch_state  =  batch ["state" ]
190+             batch_actions  =  batch ["action" ]
191+             batch_next_state  =  batch ["next_state" ]
192+             batch_rewards  =  batch ["reward" ]
193+             batch_terminal  =  batch ["is_state_terminal" ]
194+             batch_discount  =  batch ["discount" ]
195+ 
196196            next_actions  =  self .target_policy_smoothing_func (
197197                self .target_policy (batch_next_state ).sample ()
198198            )
    
 
   
 
     
   
   
          
     
  
    
     
 
    
      
     
 
     
    You can’t perform that action at this time.
  
 
    
  
     
    
      
        
     
 
       
      
     
   
 
    
    
  
 
  
 
     
    
0 commit comments