Skip to content

Commit bf34321

Browse files
committed
fix device mismatch
1 parent d53434f commit bf34321

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

torchbeast/core/models.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def initial_state(self, batch_size=1):
9090

9191
def forward(self, obs, done, core_state):
9292
T, B, C, H, W = obs["canvas"].shape
93-
grid = self._grid(T * B, H, W)
93+
grid = self._grid(T * B, H, W).to(obs["canvas"].device)
9494

9595
notdone = (~done).float()
9696
obs["prev_action"] = obs["prev_action"] * notdone.unsqueeze(dim=2)
@@ -209,7 +209,9 @@ def forward(self, h, actions=None):
209209
return actions, logits
210210

211211
else:
212-
dict_actions = collections.OrderedDict({k: None for k in self._action_order})
212+
dict_actions = collections.OrderedDict(
213+
{k: None for k in self._action_order}
214+
)
213215

214216
for k in self._order:
215217
logit = self.decode[k](h)

0 commit comments

Comments
 (0)