
Optimising encoder twice during CURL? #20

Open
wassname opened this issue Feb 20, 2021 · 9 comments

@wassname

wassname commented Feb 20, 2021

Thanks for sharing your code, it's great to be able to go through the implementation.

Maybe I'm misunderstanding this, but it seems that if you intend self.cpc_optimizer to optimise only W, then

self.cpc_optimizer = torch.optim.Adam(
    self.CURL.parameters(), lr=encoder_lr
)

should be

self.cpc_optimizer = torch.optim.Adam(
    self.CURL.parameters(recurse=False), lr=encoder_lr
)

or

self.cpc_optimizer = torch.optim.Adam(
    [self.CURL.W], lr=encoder_lr
)

The code I'm referring to is here, and the torch docs for Module.parameters are here. I'm comparing it to section 4.7 of your paper.
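To make the difference concrete, here is a minimal sketch with a hypothetical ToyCURL module that, like the repo's CURL module is assumed to do, keeps a reference to the shared encoder as a submodule alongside its own W. By default parameters() walks submodules, so it also returns the encoder's weights; parameters(recurse=False) returns only the parameters registered directly on the module.

```python
import torch
import torch.nn as nn


class ToyCURL(nn.Module):
    """Hypothetical stand-in for the repo's CURL module: it holds a reference
    to a shared encoder (registered as a submodule) plus its own matrix W."""

    def __init__(self, encoder: nn.Module, z_dim: int = 4):
        super().__init__()
        self.encoder = encoder                           # shared with the critic
        self.W = nn.Parameter(torch.rand(z_dim, z_dim))  # bilinear similarity matrix


encoder = nn.Linear(8, 4)  # stand-in for the critic's conv encoder
curl = ToyCURL(encoder)

all_params = list(curl.parameters())                # W, encoder.weight, encoder.bias
own_params = list(curl.parameters(recurse=False))   # W only

print(len(all_params), len(own_params))  # 3 1
```

Passing all_params to an optimizer would therefore also step the encoder, which is the double update described in this issue.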

As it stands, it seems that the encoder is optimised twice: once by encoder_optimizer and again by cpc_optimizer.

Or am I missing something?
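A minimal illustration of why this matters: if the same parameter is registered in two optimizers and both call step() after a single backward pass, it is updated twice with the same gradient. The sketch uses plain SGD only to keep the arithmetic easy to follow (the repo uses Adam), and a single scalar stands in for an encoder weight.

```python
import torch

w = torch.nn.Parameter(torch.tensor(1.0))  # pretend this is an encoder weight

opt_a = torch.optim.SGD([w], lr=0.1)  # plays the role of encoder_optimizer
opt_b = torch.optim.SGD([w], lr=0.1)  # plays the role of cpc_optimizer, which also holds w

opt_a.zero_grad()
opt_b.zero_grad()
loss = w ** 2      # d(loss)/dw = 2w = 2.0
loss.backward()
opt_a.step()       # w is approximately 1.0 - 0.1 * 2.0 = 0.8
opt_b.step()       # the same gradient is applied again: w is approximately 0.6
print(w.item())
```

With Adam the effect is not exactly a doubled step size, since each optimizer keeps its own moment estimates, but the parameter still receives two updates per call.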

@LostXine

Hi @wassname,
Thank you for pointing this out; I believe you are correct. From my very limited testing, optimizing the encoder only once per call didn't significantly affect performance on cheetah run, but it would be very helpful if someone could test it in more environments.
Thanks.

@IpadLi

IpadLi commented Nov 20, 2021

Hi,

It seems that the actor's encoder is never updated, either by a loss or by a soft (EMA) update, except at initialisation:

# tie encoders between actor and critic, and CURL and critic
self.actor.encoder.copy_conv_weights_from(self.critic.encoder)

Only the critic/critic_target encoder is updated (by critic_loss and the CPC loss).

Is there any reason for not updating the actor's encoder?

@tejassp2002

Hi @IpadLi, I wondered about this a while back and emailed @MishaLaskin about it.
This is the question I asked:

Why are you not updating the shared encoder with the actor loss? Is there any specific reason for this?

@MishaLaskin's reply:

I found that doing this resulted in more stable learning from pixels, but it is also an empirical design choice and can be changed.
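For reference, the usual way to keep the actor loss away from a shared encoder is to feed the actor detached encoder features, so only the actor's MLP receives gradients. A minimal sketch with hypothetical linear stand-ins for the encoder and actor head; I believe the SAC+AE-derived code does this through a detach flag on the encoder's forward pass, but that is my reading rather than something confirmed in this thread.

```python
import torch
import torch.nn as nn

encoder = nn.Linear(8, 4)     # hypothetical stand-in for the shared conv encoder
actor_head = nn.Linear(4, 2)  # hypothetical stand-in for the actor's MLP

obs = torch.randn(5, 8)
features = encoder(obs).detach()                 # cut the graph here
actor_loss = actor_head(features).pow(2).mean()  # placeholder for the real actor loss
actor_loss.backward()

print(encoder.weight.grad is None)         # True: the actor loss never reaches the encoder
print(actor_head.weight.grad is not None)  # True: the actor MLP still gets gradients
```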

@IpadLi

IpadLi commented Nov 20, 2021

Hi @tejassp2002, thanks a lot.

@Sobbbbbber

Hi, can we merge the update_critic and update_cpc functions by adding critic_loss and cpc_loss together?
That way we would only need two optimizers.
Is that feasible?

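# cpc_optimizer holds only W; the shared encoder is updated through critic_optimizer
self.cpc_optimizer = torch.optim.Adam([self.CURL.W], lr=encoder_lr)
self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=critic_lr, betas=(critic_beta, 0.999))

# clear stale gradients before the combined backward pass
self.critic_optimizer.zero_grad()
self.cpc_optimizer.zero_grad()
loss = critic_loss + cpc_loss
loss.backward()
self.critic_optimizer.step()
self.cpc_optimizer.step()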
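A toy check of this proposal, using hypothetical linear stand-ins for the encoder and Q head and a placeholder bilinear CPC loss (the real one uses augmented anchor/positive views and the momentum target encoder; here the "positives" are just detached copies of the same features). Summing the two losses gives the shared encoder the sum of both gradients in one backward pass, and each optimizer then steps once, so the encoder is updated only once per call.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
encoder = nn.Linear(8, 4)           # hypothetical shared encoder (owned by the critic)
q_head = nn.Linear(4, 1)            # hypothetical critic head
W = nn.Parameter(torch.rand(4, 4))  # CURL-style bilinear matrix
obs = torch.randn(6, 8)


def losses():
    z = encoder(obs)
    critic_loss = q_head(z).pow(2).mean()  # placeholder for the TD loss
    logits = z @ W @ z.detach().T          # toy anchor-vs-positive similarities
    labels = torch.arange(logits.shape[0])
    cpc_loss = nn.functional.cross_entropy(logits, labels)
    return critic_loss, cpc_loss


# Gradient of each loss on its own, w.r.t. the shared encoder weight
critic_loss, _ = losses()
g_critic = torch.autograd.grad(critic_loss, encoder.weight)[0]
_, cpc_loss = losses()
g_cpc = torch.autograd.grad(cpc_loss, encoder.weight)[0]

critic_optimizer = torch.optim.Adam(list(encoder.parameters()) + list(q_head.parameters()), lr=1e-3)
cpc_optimizer = torch.optim.Adam([W], lr=1e-3)

critic_optimizer.zero_grad()
cpc_optimizer.zero_grad()
critic_loss, cpc_loss = losses()
(critic_loss + cpc_loss).backward()

# The summed loss gives the encoder the sum of both gradients
print(torch.allclose(encoder.weight.grad, g_critic + g_cpc, atol=1e-5))  # True

critic_optimizer.step()  # encoder + critic head: one Adam step
cpc_optimizer.step()     # W only: one Adam step
```

One difference from running the two updates sequentially: here both gradients are computed from the same encoder weights, whereas the original code recomputes the CPC gradient after the critic update.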

@KarlXing

KarlXing commented Mar 5, 2022


SAC+AE (https://arxiv.org/pdf/1910.01741.pdf) suggests using only the critic's gradient (not the actor's) to update the encoder. Since this repo is based on the SAC+AE implementation (as stated in the README), I think CURL simply follows that choice.

@yufeiwang63

yufeiwang63 commented Apr 16, 2022

Hi, thanks for posting the reply from the author!
Yet I don't think the reply answers the question: even if we don't update the encoder with the actor loss, why aren't the actor encoder weights copied from the critic encoder after each update of the critic encoder by the critic loss and CPC loss?
It is a bit strange to me that two different encoders are used for the actor and the critic, whereas the paper seems to indicate there is only one shared encoder. Moreover, the actor encoder's weights are never updated after initialization, so essentially only the actor's MLP is being trained.

Update: Sorry, the tie_weight function actually makes the actor encoder and critic encoder share the same weights.

@RayYoh

RayYoh commented Mar 16, 2023

Hello! Does that mean the actor encoder's weights stay the same as the critic encoder's after the critic encoder is updated?

@LiuZhenxian123

Yes, the actor and critic indeed share the same encoder.
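Why the weights stay in sync: assuming the repo's tie_weights helper assigns the critic's conv Parameter objects to the actor's layers (as the SAC+AE code appears to do), both modules hold the very same tensors, so an optimizer step on the critic is immediately visible to the actor. A minimal sketch with two hypothetical conv layers:

```python
import torch
import torch.nn as nn

critic_conv = nn.Conv2d(3, 8, kernel_size=3)
actor_conv = nn.Conv2d(3, 8, kernel_size=3)

# Tie: the actor's layer now points at the *same* Parameter objects as the critic's
actor_conv.weight = critic_conv.weight
actor_conv.bias = critic_conv.bias

before = critic_conv.weight.detach().clone()
opt = torch.optim.SGD(critic_conv.parameters(), lr=0.1)  # only the critic's optimizer
loss = critic_conv(torch.randn(1, 3, 5, 5)).pow(2).mean()
opt.zero_grad()
loss.backward()
opt.step()

print(actor_conv.weight is critic_conv.weight)             # True: same object
print(torch.equal(actor_conv.weight, critic_conv.weight))  # True, even after the update
print(torch.equal(before, actor_conv.weight))              # False: the actor saw the update
```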
