3 files changed: +21 −22 lines changed

@@ -148,16 +148,16 @@ def sync_target_network(self):
     def compute_critic_loss(self, batch):
         """Compute loss for critic."""
 
-        batch_next_state = batch["next_state"]
-        batch_rewards = batch["reward"]
-        batch_terminal = batch["is_state_terminal"]
-        batch_state = batch["state"]
-        batch_actions = batch["action"]
-        batchsize = len(batch_rewards)
-
         with torch.no_grad():
+            batch_state = batch["state"]
+            batch_actions = batch["action"]
+            batch_rewards = batch["reward"]
+            batchsize = len(batch_rewards)
             assert not self.recurrent
+            batch_next_state = batch["next_state"]
+            batch_terminal = batch["is_state_terminal"]
             next_actions = self.target_policy(batch_next_state).sample()
+
             next_q = self.target_q_function((batch_next_state, next_actions))
             target_q = batch_rewards + self.gamma * (
                 1.0 - batch_terminal
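For context, this hunk appears to be a DDPG-style critic update, and the change moves the replay-batch unpacking inside the existing torch.no_grad() block. The sketch below is a hypothetical, stand-alone rendering of the same bootstrapped target, not this file's actual code: the target_policy, target_q_function, and gamma arguments are stand-ins, and the final torch.flatten reshape is assumed because the hunk is cut off at the discount term.

import torch

def ddpg_critic_target(batch, target_policy, target_q_function, gamma):
    """Bootstrapped target r + gamma * (1 - done) * Q'(s', pi'(s')),
    computed without gradient tracking, mirroring the block above."""
    with torch.no_grad():
        batch_rewards = batch["reward"]
        batch_next_state = batch["next_state"]
        batch_terminal = batch["is_state_terminal"]
        # Action proposed by the target policy at the next state.
        next_actions = target_policy(batch_next_state).sample()
        # Target critic evaluates the (next state, next action) pair.
        next_q = target_q_function((batch_next_state, next_actions))
        # Terminal transitions get no bootstrapped value.
        target_q = batch_rewards + gamma * (
            1.0 - batch_terminal
        ) * torch.flatten(next_q)
    return target_q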
@@ -213,17 +213,16 @@ def sync_target_network(self):
 
     def update_q_func(self, batch):
         """Compute loss for a given Q-function."""
-
-        batch_next_state = batch["next_state"]
-        batch_rewards = batch["reward"]
-        batch_terminal = batch["is_state_terminal"]
-        batch_state = batch["state"]
-        batch_actions = batch["action"]
-        batch_discount = batch["discount"]
-
         with torch.no_grad(), pfrl.utils.evaluating(self.policy), pfrl.utils.evaluating(
             self.target_q_func1
         ), pfrl.utils.evaluating(self.target_q_func2):
+            batch_state = batch["state"]
+            batch_actions = batch["action"]
+            batch_next_state = batch["next_state"]
+            batch_rewards = batch["reward"]
+            batch_terminal = batch["is_state_terminal"]
+            batch_discount = batch["discount"]
+
             next_action_distrib = self.policy(batch_next_state)
             next_actions = next_action_distrib.sample()
             next_log_prob = next_action_distrib.log_prob(next_actions)
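This hunk looks like the soft actor-critic (SAC) Q-function update; the quantities it unpacks under torch.no_grad() typically feed an entropy-regularized, clipped double-Q target. The following is a minimal sketch of that target under assumed names (policy, target_q_func1, target_q_func2, temperature), not this file's exact code.

import torch

def sac_q_target(batch, policy, target_q_func1, target_q_func2, temperature):
    """Soft target: r + discount * (1 - done) *
    (min(Q1', Q2') - temperature * log pi(a'|s'))."""
    with torch.no_grad():
        batch_next_state = batch["next_state"]
        batch_rewards = batch["reward"]
        batch_terminal = batch["is_state_terminal"]
        batch_discount = batch["discount"]
        # Sample next actions from the current policy and keep their log-probs.
        next_action_distrib = policy(batch_next_state)
        next_actions = next_action_distrib.sample()
        next_log_prob = next_action_distrib.log_prob(next_actions)
        # Clipped double-Q: pessimistic minimum of the two target critics.
        next_q = torch.min(
            target_q_func1((batch_next_state, next_actions)),
            target_q_func2((batch_next_state, next_actions)),
        )
        # Subtract the entropy term; bootstrap only on non-terminal transitions.
        target_q = batch_rewards + batch_discount * (
            1.0 - batch_terminal
        ) * torch.flatten(next_q - temperature * next_log_prob[..., None])
    return target_q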
@@ -181,18 +181,18 @@ def sync_target_network(self):
     def update_q_func(self, batch):
         """Compute loss for a given Q-function."""
 
-        batch_next_state = batch["next_state"]
-        batch_rewards = batch["reward"]
-        batch_terminal = batch["is_state_terminal"]
-        batch_state = batch["state"]
-        batch_actions = batch["action"]
-        batch_discount = batch["discount"]
-
         with torch.no_grad(), pfrl.utils.evaluating(
             self.target_policy
         ), pfrl.utils.evaluating(self.target_q_func1), pfrl.utils.evaluating(
             self.target_q_func2
         ):
+            batch_state = batch["state"]
+            batch_actions = batch["action"]
+            batch_next_state = batch["next_state"]
+            batch_rewards = batch["reward"]
+            batch_terminal = batch["is_state_terminal"]
+            batch_discount = batch["discount"]
+
             next_actions = self.target_policy_smoothing_func(
                 self.target_policy(batch_next_state).sample()
             )
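The last hunk appears to be a TD3-style Q-function update, where the smoothed target action sampled above feeds a clipped double-Q target. Below is a minimal sketch under assumed names (target_policy, smoothing_func, target_q_func1, target_q_func2), not this file's exact code; the smoothing function is where TD3 adds clipped noise to the target action.

import torch

def td3_q_target(batch, target_policy, smoothing_func,
                 target_q_func1, target_q_func2):
    """Clipped double-Q target with target-policy smoothing:
    r + discount * (1 - done) * min(Q1'(s', a~), Q2'(s', a~))."""
    with torch.no_grad():
        batch_next_state = batch["next_state"]
        batch_rewards = batch["reward"]
        batch_terminal = batch["is_state_terminal"]
        batch_discount = batch["discount"]
        # Smooth the target policy's action (e.g. clipped Gaussian noise).
        next_actions = smoothing_func(target_policy(batch_next_state).sample())
        # Pessimistic minimum over the two target critics.
        next_q = torch.min(
            target_q_func1((batch_next_state, next_actions)),
            target_q_func2((batch_next_state, next_actions)),
        )
        target_q = batch_rewards + batch_discount * (
            1.0 - batch_terminal
        ) * torch.flatten(next_q)
    return target_q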