Description
First of all, thank you so much to everyone working on this project. I'm currently using it in several projects and it's (mostly) been a very pleasant experience.
But there is one thing in torchrl that I find particularly confusing.
Motivation
I'm confused about whether weight updates in a `LossModule` get propagated back to the original module that was passed in to the `LossModule`. Let's take `DQNLoss` as an example and suppose we have a `TensorDictModule` called `module` that we pass in as the `value_network` argument. As I understand it, the loss module keeps its own copy of the weights of the network passed in as `value_network`. So when we call `backward()` on the loss and step the optimizer, the weights of the original network are not updated.
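To make the question concrete, here is roughly the setup I have in mind. This is just a sketch: the network sizes, the fake batch, and the spec are illustrative, and I'm assuming a recent TorchRL version (where the one-hot spec class is called `OneHot`).

```python
import torch
from tensordict import TensorDict
from torchrl.data import OneHot
from torchrl.modules import MLP, QValueActor
from torchrl.objectives import DQNLoss

n_obs, n_act, batch = 4, 2, 32
spec = OneHot(n_act)
module = QValueActor(
    MLP(in_features=n_obs, out_features=n_act),
    in_keys=["observation"],
    spec=spec,
)
loss_module = DQNLoss(module, action_space=spec, delay_value=False)

# The optimizer is built over the loss module's parameters, not over
# module.parameters() directly -- which is exactly what raises the question.
optimizer = torch.optim.Adam(loss_module.parameters(), lr=1e-3)

# A fake batch of transitions with the keys DQNLoss expects.
data = TensorDict(
    {
        "observation": torch.randn(batch, n_obs),
        "action": torch.nn.functional.one_hot(
            torch.randint(n_act, (batch,)), n_act
        ),
        ("next", "observation"): torch.randn(batch, n_obs),
        ("next", "reward"): torch.randn(batch, 1),
        ("next", "done"): torch.zeros(batch, 1, dtype=torch.bool),
        ("next", "terminated"): torch.zeros(batch, 1, dtype=torch.bool),
    },
    batch_size=[batch],
)

loss_vals = loss_module(data)     # TensorDict with a "loss" entry
loss_vals["loss"].backward()      # gradients land on loss_module's parameters
optimizer.step()
optimizer.zero_grad()
```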
When we give the original module (or rather an epsilon-greedy version of it) to a `SyncDataCollector`, we have to call `collector.update_policy_weights_()` in order for the loss module weights to get copied to the original module. But what if we want to do inference without a collector?
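For reference, the collector-based loop I'm describing looks roughly like this. The environment, exploration wrapper, and hyperparameters are made up for illustration; the point is only where `update_policy_weights_()` sits in the loop.

```python
import torch
from tensordict.nn import TensorDictSequential
from torchrl.collectors import SyncDataCollector
from torchrl.envs import GymEnv
from torchrl.modules import EGreedyModule, MLP, QValueActor
from torchrl.objectives import DQNLoss

env = GymEnv("CartPole-v1")                      # illustrative environment
spec = env.action_spec                           # one-hot action spec
n_obs = env.observation_spec["observation"].shape[-1]

actor = QValueActor(
    MLP(in_features=n_obs, out_features=spec.shape[-1]),
    in_keys=["observation"],
    spec=spec,
)
loss_module = DQNLoss(actor, action_space=spec, delay_value=False)
optimizer = torch.optim.Adam(loss_module.parameters(), lr=1e-3)

# Epsilon-greedy wrapper around the same actor for exploration.
policy = TensorDictSequential(actor, EGreedyModule(spec=spec))
collector = SyncDataCollector(
    env, policy, frames_per_batch=64, total_frames=1_000
)

for data in collector:
    loss_vals = loss_module(data)
    loss_vals["loss"].backward()
    optimizer.step()
    optimizer.zero_grad()
    # Copy the loss module's (updated) weights back into the collector's policy.
    collector.update_policy_weights_()
```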
I've seen snippets like

```python
with self.loss_module.qvalue_network_params["module"]["0"].to_module(self.module):
    # self.module can now be used with the updated weights
    ...
```

but I'm not sure whether that's the proper way to do it.
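For the `DQNLoss` example above the attribute would be `value_network_params` rather than `qvalue_network_params`, so I imagine an inference helper would look something like the sketch below. The helper name and keys are my own, and I don't know if this is the intended pattern, which is exactly why documentation would help.

```python
import torch
from tensordict import TensorDict


@torch.no_grad()
def act_without_collector(loss_module, module, observation):
    """Run the greedy policy with whatever weights the loss module currently holds."""
    td = TensorDict(
        {"observation": observation}, batch_size=observation.shape[:-1]
    )
    # Temporarily swap the loss module's weights into `module`;
    # the original weights are restored when the context exits.
    with loss_module.value_network_params.to_module(module):
        module(td)
    return td["action"]
```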
I've tried looking into the source code for SyncDataCollector
and VanillaWeightUpdater
but I think that this would be a very useful piece of documentation to have.
Solution
A documentation page that briefly covers:
- whether weight updates from the loss module are propagated back to the original module
- an example of how to use a policy for inference (without using a collector)
Alternatives
Maybe I'm not understanding something that is obvious to everyone else. In that case, please point me to a source that explains these things.
Checklist
- I have checked that there is no similar issue in the repo (required)