
Training a recurrent policy #4

Open
erschmidt opened this issue Feb 2, 2018 · 4 comments

Comments

@erschmidt

I am still struggling with the implementation of a recurrent policy. The trick from #1 worked, and I can now run my RNN GAIL network. But no matter what I try, the mean reward actually decreases over time.

I am currently using the same ValueNet and Advantage Estimation as in the repository.

Do I have to change something in trpo_step in order to make RNN policies work?

Thank you so much!

@Khrylx
Owner

Khrylx commented Feb 2, 2018

The value function also needs to be an RNN, or you can pass the RNN output from the policy to the value net.
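A minimal sketch of the second option, assuming a shared LSTM whose features feed both a policy head and a value head (the class and attribute names are illustrative, not code from this repository):

```python
import torch
import torch.nn as nn

class RecurrentActorCritic(nn.Module):
    """Policy and value function share one LSTM; the value head reuses the LSTM features."""
    def __init__(self, obs_dim, action_dim, hidden_dim=64):
        super().__init__()
        self.lstm = nn.LSTM(obs_dim, hidden_dim)
        self.action_mean = nn.Linear(hidden_dim, action_dim)  # policy head
        self.value_head = nn.Linear(hidden_dim, 1)            # value head on the same RNN features

    def forward(self, obs_seq, hidden=None):
        # obs_seq: (seq_len, batch, obs_dim)
        feat, hidden = self.lstm(obs_seq, hidden)
        return self.action_mean(feat), self.value_head(feat), hidden

# usage: one forward pass over a padded batch of sequences
model = RecurrentActorCritic(obs_dim=10, action_dim=3)
action_mean, value, hidden = model(torch.randn(7, 3, 10))
```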

@sandeepnRES

Any help regarding how to use an RNN? Should it be run a single step at a time, or over a whole episode (with the time steps given as a sequence)? And how does backpropagation take place, all at once at the end of the episode? Do you have any code for this available?

@Khrylx
Owner

Khrylx commented Jan 9, 2019

I don't have any code for RNN yet. But I can imagine how it can be done.

Suppose the agent collected a batch of three episodes of length 5, 6, and 7, so the total length is 18. You need to pad these episodes to a common length of 7, giving an input of size 7 x 3 x d, where d is your input dim. You pass it through an LSTM, which gives an output of size 7 x 3 x h, where h is the output dim. You then reshape that into 21 x h and use indexing to pick out the rows that correspond to the real (non-padded) steps of each episode, which leaves you with the desired 18 x h output. Then you can just pass this to an MLP or any other network you want. All of these operations can be done in PyTorch.
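A rough PyTorch sketch of that pipeline, using the 5/6/7-step example above (the padding and masking details are my reading of the comment, not code from this repository):

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence

d, h = 10, 64                       # input dim and LSTM output dim (illustrative)
lengths = [5, 6, 7]                 # episode lengths, 18 steps in total
episodes = [torch.randn(T, d) for T in lengths]

# Pad to a common length -> (max_len, batch, d) = (7, 3, d)
padded = pad_sequence(episodes)     # zero-pads the shorter episodes

lstm = nn.LSTM(d, h)
out, _ = lstm(padded)               # (7, 3, h)

# Flatten to (7*3, h) = (21, h); in time-major layout, row index = t * batch + b
flat = out.reshape(-1, h)

# Build a mask marking the real (non-padded) steps and select them -> (18, h)
max_len, batch = padded.shape[0], padded.shape[1]
mask = torch.zeros(max_len, batch, dtype=torch.bool)
for b, T in enumerate(lengths):
    mask[:T, b] = True
valid = flat[mask.reshape(-1)]      # (18, h), ready for an MLP head
```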

Khrylx closed this as completed Jan 9, 2019
Khrylx reopened this Jan 9, 2019
@sandeepnRES

sandeepnRES commented Jan 14, 2019

Okay, but should functions like get_log_probability, which is used in the PPO step, be updated? They run forward propagation, so a new hidden vector is computed after every backpropagation. Or should the new hidden vector be ignored and the old hidden vector be used (the one obtained during data collection by the agent)?
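To make the two alternatives concrete, here is a small illustrative sketch (hypothetical names and shapes; the thread does not settle which variant is correct):

```python
import torch
import torch.nn as nn

obs_dim, act_dim, h = 10, 3, 64
lstm = nn.LSTM(obs_dim, h)
mean_head = nn.Linear(h, act_dim)
log_std = torch.zeros(act_dim)

obs_seq = torch.randn(7, 3, obs_dim)   # padded observation sequences from collection
actions = torch.randn(7, 3, act_dim)   # actions taken during collection

# Variant 1: re-run the RNN at update time, so hidden vectors (and log-probs)
# are recomputed with the current weights after every backpropagation step.
feat_new, _ = lstm(obs_seq)
dist_new = torch.distributions.Normal(mean_head(feat_new), log_std.exp())
log_prob_recomputed = dist_new.log_prob(actions).sum(-1)

# Variant 2: keep the hidden vectors produced during data collection fixed
# and only re-evaluate the output head on those stored features.
with torch.no_grad():
    feat_stored, _ = lstm(obs_seq)     # stand-in for features saved by the agent
dist_old = torch.distributions.Normal(mean_head(feat_stored), log_std.exp())
log_prob_fixed = dist_old.log_prob(actions).sum(-1)
```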
