
Pervasive reshape bugs in train_cg? #12

Open
cooijmanstim opened this issue Apr 5, 2021 · 1 comment

Comments


cooijmanstim commented Apr 5, 2021

Unless I misunderstand what the code is trying to do, the following pattern in train_cg.py is a bug:

trainBatch1 = [[], [], [], [], [], []]  # line 137
...
while j < max_epLength:  # line 149
  ...
  trainBatch1[3].append(s1)  # line 180
  ...
...
# line 236:
last_state = np.reshape(
  np.concatenate(trainBatch1[3], axis=0),
  [batch_size, trace_length, env.ob_space_shape[0],
   env.ob_space_shape[1], env.ob_space_shape[2]])[:,-1,:,:,:]

The issue is that np.concatenate(trainBatch1[3], axis=0) is stored in memory with the time (trace_length) axis first and the batch axis second, and should be reshaped to [trace_length, batch_size, ...] and then transposed to move the batch axis forward. Reshaping straight to [batch_size, trace_length, ...] will silently misinterpret the order in which the elements are stored in memory.
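To make the memory-order point concrete, here is a minimal, self-contained NumPy sketch (not code from the repo; the batch size, trace length, and observation shape are made up) showing that the direct reshape mixes batch elements into one trace, while reshape-then-transpose recovers the intended layout:

import numpy as np

batch_size, trace_length = 3, 4
ob_shape = (2,)  # stand-in for env.ob_space_shape

# Each appended step is a [batch_size, ...] array; element (t, b) is encoded as 10*t + b.
steps = [np.stack([np.full(ob_shape, 10 * t + b) for b in range(batch_size)])
         for t in range(trace_length)]
flat = np.concatenate(steps, axis=0)  # time-major, shape [trace_length * batch_size, 2]

# Buggy: reshape straight to [batch_size, trace_length, ...].
buggy = flat.reshape(batch_size, trace_length, *ob_shape)

# Correct: reshape to [trace_length, batch_size, ...], then move the batch axis forward.
fixed = flat.reshape(trace_length, batch_size, *ob_shape).transpose(1, 0, 2)

print(buggy[0, :, 0])  # [ 0  1  2 10] -- batch elements 0-2 of step 0 mixed with step 1
print(fixed[0, :, 0])  # [ 0 10 20 30] -- the actual trace for batch element 0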

The same buggy append-reshape pattern occurs for basically everything stored in trainBatch0 and trainBatch1, with the offending reshapes scattered across other files that expect [batch, time] storage order. I think the easiest fix would be to establish the desired storage order of trainBatch1 right after the loop over j < max_epLength, e.g.

trainBatch1 = [np.stack(seq, axis=1) for seq in trainBatch1]

and similar for trainBatch0. Now trainBatch1[3] has exactly the shape you want it to have at line 236, so last_state = trainBatch1[3][:, -1, :, :, :] will do. You can still call trainBatch1[i].reshape([batch_size * trace_length, ...]) if you need the batch and time axes flattened, and the flattened array will correctly reshape back to [batch_size, trace_length, ...].
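Put together, the change in train_cg.py would look roughly like the sketch below; it continues the variables from the snippet above, and the flattened-view name obs_flat and its exact downstream use are just illustrative assumptions.

# Right after the `while j < max_epLength` loop: stack each list of per-step
# [batch_size, ...] arrays along a new time axis.
trainBatch0 = [np.stack(seq, axis=1) for seq in trainBatch0]
trainBatch1 = [np.stack(seq, axis=1) for seq in trainBatch1]
# trainBatch1[3] now has shape [batch_size, trace_length, *env.ob_space_shape].

# Line 236 then reduces to slicing out the last time step:
last_state = trainBatch1[3][:, -1, :, :, :]

# Wherever a [batch_size * trace_length, ...] view is needed, a plain reshape
# is now safe, because [batch, time] is the actual storage order:
obs_flat = trainBatch1[3].reshape(batch_size * trace_length, *trainBatch1[3].shape[2:])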

@jakobnicolaus (Collaborator)

Thank you for finding and filing this - looks like a bad bug :(
Please submit a PR for the fix when you have a chance. Thanks a lot!
