-
Notifications
You must be signed in to change notification settings - Fork 312
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add gpu pickleable module to torch #2265
Conversation
fe138b1
to
2805fbe
Compare
2805fbe
to
3654cf0
Compare
Codecov Report
@@ Coverage Diff @@
## master #2265 +/- ##
==========================================
- Coverage 91.20% 91.17% -0.03%
==========================================
Files 199 199
Lines 11627 10996 -631
Branches 1557 1392 -165
==========================================
- Hits 10604 10026 -578
+ Misses 754 703 -51
+ Partials 269 267 -2
Continue to review full report at Codecov.
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks pretty good. There's a few minor improvements I've mentioned above. We should also add this new class to q_functions
and value_functions
(as well as probably everything in garage.torch.modules
), but that can happen in another PR.
save_from_device = global_device() | ||
self.to('cpu') | ||
state = self.__dict__.copy() | ||
state['device'] = save_from_device |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Probably should call this 'garage.global_device' or something, so that it definitely doesn't conflict with any sub-field names.
system_device = global_device() | ||
save_from_device = state['device'] | ||
if save_from_device == system_device: | ||
module_state_to(state, system_device) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think you can use the other state-dict moving function you wrote here (even though this kinda isn't a state dict).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Will every movable parameter of a nn.module be a tensor or a dict?
""" | ||
system_device = global_device() | ||
save_from_device = state['device'] | ||
if save_from_device == system_device: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure if this if statement is really needed. I think the idea here is to pre-move everything, as an optimization, but I'm not sure if that's actually faster. If you're going to do this, please use timeit.timeit
to measure the performance difference.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The reason I was doing it this way was because I didn't know if you needed to move any of the other attributes in the dict as well as the module itself, so I'm just moving all internal attributes first and then moving the module itself.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh, I see. .to()
should know how to handle that already (since that's how the module was moved to the device in the first place). If modules need to move something besides the default behavior, they should override .to()
themselves.
A potential fix for #2079, by creating an garage.torch.Module wrapper class which implements pickling with moving to cpu on pickle and back to original device on unpickle if the device exists.