-
Notifications
You must be signed in to change notification settings - Fork 312
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add gpu pickleable module to torch #2265
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -369,11 +369,11 @@ def state_dict_to(state_dict, device): | |
"""Move optimizer to a specified device. | ||
|
||
Args: | ||
state_dict (dict): state dictionary to be moved | ||
state_dict (dict): state dictionary to be moved. | ||
device (str): ID of GPU or CPU. | ||
|
||
Returns: | ||
dict: state dictionary moved to device | ||
dict: state dictionary moved to device. | ||
""" | ||
for param in state_dict.values(): | ||
if isinstance(param, torch.Tensor): | ||
|
@@ -383,6 +383,61 @@ def state_dict_to(state_dict, device): | |
return state_dict | ||
|
||
|
||
# pylint: disable=abstract-method | ||
class Module(nn.Module): | ||
"""Wrapper class for Garage PyTorch modules.""" | ||
|
||
def __getstate__(self): | ||
"""Save the current device of the module before saving module state. | ||
|
||
Returns: | ||
dict: State dictionary. | ||
""" | ||
# do we always run experiments on global device? | ||
save_from_device = global_device() | ||
self.to('cpu') | ||
state = self.__dict__.copy() | ||
state['device'] = save_from_device | ||
return state | ||
|
||
def __setstate__(self, state): | ||
"""Restore the module state, moving it back to the original device if possible. | ||
|
||
Args: | ||
state (dict): State dictionary. | ||
|
||
""" | ||
system_device = global_device() | ||
save_from_device = state['device'] | ||
if save_from_device == system_device: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not sure if this if statement is really needed. I think the idea here is to pre-move everything, as an optimization, but I'm not sure if that's actually faster. If you're going to do this, please use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The reason I was doing it this way was because I didn't know if you needed to move any of the other attributes in the dict as well as the module itself, so I'm just moving all internal attributes first and then moving the module itself. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oh, I see. |
||
module_state_to(state, system_device) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think you can use the other state-dict moving function you wrote here (even though this kinda isn't a state dict). There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Will every movable parameter of a nn.module be a tensor or a dict? |
||
# what to do if it doesn't match? | ||
# do I need to set global device to the current device? | ||
del state['device'] | ||
self.__dict__ = state | ||
if save_from_device == system_device: | ||
self.to(system_device) | ||
|
||
|
||
def module_state_to(state, device): | ||
"""Move elements of a module state to a device. | ||
|
||
Notes - are there other types of parameters in a | ||
module state to be moved? are some of them recursive? | ||
|
||
Args: | ||
state (dict): State dictionary. | ||
device (str): ID of GPU or CPU. | ||
|
||
Returns: | ||
dict: moved state dict. | ||
""" | ||
for param in state.values(): | ||
if hasattr(param, 'to'): | ||
param = param.to(device) | ||
return state | ||
|
||
|
||
# pylint: disable=W0223 | ||
class NonLinearity(nn.Module): | ||
"""Wrapper class for non linear function or module. | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Probably should call this 'garage.global_device' or something, so that it definitely doesn't conflict with any sub-field names.