Unless I misunderstand what the code is trying to do, the following pattern in train_cg.py is a bug:
```python
trainBatch1 = [[], [], [], [], [], []]  # line 137
...
while j < max_epLength:  # line 149
    ...
    trainBatch1[3].append(s1)  # line 180
    ...
...
# line 236:
last_state = np.reshape(
    np.concatenate(trainBatch1[3], axis=0),
    [batch_size, trace_length, env.ob_space_shape[0],
     env.ob_space_shape[1], env.ob_space_shape[2]])[:, -1, :, :, :]
```
The issue is that the array produced by np.concatenate(trainBatch1[3], axis=0) is laid out in memory with the time (trace_length) axis first and the batch axis second, so it should be reshaped to [trace_length, batch_size, ...] and then transposed to move the batch axis forward. Reshaping straight to [batch_size, trace_length, ...] silently misinterprets the order in which the elements are stored in memory.
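Here is a toy, self-contained illustration of that misinterpretation (the shapes are made up; each per-step array just has the batch axis first, like s1 above):

```python
import numpy as np

batch_size, trace_length = 2, 3
# One array per timestep, batch axis first, as appended in the loop above.
steps = [np.arange(batch_size) + 10 * t for t in range(trace_length)]
cat = np.concatenate(steps, axis=0)  # memory order is time-major: t0b0, t0b1, t1b0, ...

wrong = cat.reshape([batch_size, trace_length])                   # rows mix envs and timesteps
right = cat.reshape([trace_length, batch_size]).transpose(1, 0)   # rows are per-env trajectories

print(wrong)  # [[ 0  1 10]
              #  [11 20 21]]
print(right)  # [[ 0 10 20]
              #  [ 1 11 21]]
```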
The same buggy append-reshape pattern applies to essentially everything stored in trainBatch0 and trainBatch1, with the offending reshapes scattered across other files that expect [batch, time] storage order. I think the easiest fix would be to establish the desired storage order of trainBatch1 right after the loop over j < max_epLength, e.g. as sketched below,
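Something along these lines (just a sketch, assuming every per-step entry appended to trainBatch1[i] is an array with the batch axis first; the names come from the snippet above):

```python
# Right after the `while j < max_epLength` loop: fix the storage order once.
for i in range(len(trainBatch1)):
    stacked = np.stack(trainBatch1[i], axis=0)   # [trace_length, batch_size, ...]
    trainBatch1[i] = np.swapaxes(stacked, 0, 1)  # [batch_size, trace_length, ...]
```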
and similarly for trainBatch0. Now trainBatch1[3] has exactly the shape you want it to have at line 236, so last_state = trainBatch1[3][:, -1, :, :, :] will do. You can still call trainBatch1[i].reshape([batch_size * trace_length, ...]) wherever the batch and time axes need to be flattened, and that flat array will correctly reshape back to [batch_size, trace_length, ...].
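For instance, with the storage order fixed, the flatten/unflatten round trip and the last_state slice both come out right (again a sketch; obs_shape is just the per-frame shape used in the line-236 reshape):

```python
obs_shape = [env.ob_space_shape[0], env.ob_space_shape[1], env.ob_space_shape[2]]
obs = trainBatch1[3]                                    # [batch_size, trace_length, *obs_shape]
flat = obs.reshape([batch_size * trace_length] + obs_shape)  # for ops that want [batch*time, ...]
restored = flat.reshape([batch_size, trace_length] + obs_shape)
assert np.array_equal(restored, obs)                    # the round trip is exact
last_state = obs[:, -1]                                 # what the reshape at line 236 was after
```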