Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add gpu pickleable module to torch #2265

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/garage/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from garage.torch._functions import (as_torch_dict, compute_advantages,
expand_var, filter_valids, flatten_batch,
flatten_to_single_vector, global_device,
NonLinearity, np_to_torch,
Module, NonLinearity, np_to_torch,
output_height_2d, output_width_2d,
pad_to_last, prefer_gpu,
product_of_gaussians, set_gpu_mode,
Expand Down Expand Up @@ -31,4 +31,5 @@
'state_dict_to',
'torch_to_np',
'update_module_params',
'Module',
]
59 changes: 57 additions & 2 deletions src/garage/torch/_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,11 +369,11 @@ def state_dict_to(state_dict, device):
"""Move optimizer to a specified device.

Args:
state_dict (dict): state dictionary to be moved
state_dict (dict): state dictionary to be moved.
device (str): ID of GPU or CPU.

Returns:
dict: state dictionary moved to device
dict: state dictionary moved to device.
"""
for param in state_dict.values():
if isinstance(param, torch.Tensor):
Expand All @@ -383,6 +383,61 @@ def state_dict_to(state_dict, device):
return state_dict


# pylint: disable=abstract-method
class Module(nn.Module):
    """Wrapper class for Garage PyTorch modules.

    Adds pickle support for modules that may live on a GPU: the module
    is serialized from the CPU (so the pickle loads on GPU-less
    machines) and the originating device is recorded so the module can
    be moved back when unpickled on a compatible machine.
    """

    def __getstate__(self):
        """Save the current device of the module before saving its state.

        Returns:
            dict: State dictionary with all parameters on the CPU, plus
                the originating device stored under the key
                ``'garage.global_device'``.
        """
        save_from_device = global_device()
        # Serialize from the CPU so the resulting pickle is loadable on
        # machines without a GPU.
        self.to('cpu')
        state = self.__dict__.copy()
        # Namespaced key so it cannot collide with a genuine sub-field
        # named 'device'.
        state['garage.global_device'] = save_from_device
        # NOTE(review): this leaves the live module on the CPU after
        # pickling; restoring it here is unsafe because ``state`` is a
        # shallow copy and shares the parameter tensors.
        return state

    def __setstate__(self, state):
        """Restore the module, moving it back to its device if possible.

        Args:
            state (dict): State dictionary produced by
                :meth:`__getstate__`.
        """
        save_from_device = state.pop('garage.global_device')
        self.__dict__ = state
        # Only move back when this process targets the same device the
        # module was saved from; otherwise leave everything on the CPU.
        # ``nn.Module.to()`` moves parameters, buffers and sub-modules,
        # so no separate per-attribute pre-move is needed.
        if save_from_device == global_device():
            self.to(save_from_device)


def module_state_to(state, device):
    """Move every movable element of a module state to a device.

    Any value exposing a ``.to()`` method (tensors, parameters,
    sub-modules) is moved; all other values are left untouched.

    Args:
        state (dict): State dictionary.
        device (str): ID of GPU or CPU.

    Returns:
        dict: The same dict, with movable values on ``device``.
    """
    # Assign the moved value back into the dict: for tensors ``.to()``
    # returns a *new* object, so rebinding the loop variable alone
    # would silently drop the move.
    for key, value in state.items():
        if hasattr(value, 'to'):
            state[key] = value.to(device)
    return state


# pylint: disable=W0223
class NonLinearity(nn.Module):
"""Wrapper class for non linear function or module.
Expand Down
5 changes: 2 additions & 3 deletions src/garage/torch/policies/policy.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
"""Base Policy."""
import abc

import torch

from garage.np.policies import Policy as BasePolicy
from garage.torch import Module


class Policy(torch.nn.Module, BasePolicy, abc.ABC):
class Policy(Module, BasePolicy, abc.ABC):
"""Policy base class.

Args:
Expand Down
34 changes: 34 additions & 0 deletions tests/garage/torch/policies/test_categorical_cnn_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import torch

from garage.envs import GymEnv
from garage.torch import global_device, set_gpu_mode
from garage.torch.policies import CategoricalCNNPolicy

from tests.fixtures.envs.dummy import DummyDictEnv, DummyDiscretePixelEnv
Expand Down Expand Up @@ -152,3 +153,36 @@ def test_obs_unflattened(self, hidden_channels, kernel_sizes, strides,
obs = env.observation_space.sample()
action, _ = policy.get_action(env.observation_space.flatten(obs))
env.step(action)


@pytest.mark.gpu
@pytest.mark.parametrize(
    'hidden_channels, kernel_sizes, strides, hidden_sizes', [
        ((3, ), (3, ), (1, ), (4, )),
        ((3, 3), (3, 3), (1, 1), (4, 4)),
        ((3, 3), (3, 3), (2, 2), (4, 4)),
    ])
def test_is_pickleable_on_gpu(hidden_channels, kernel_sizes, strides,
                              hidden_sizes):
    """Test that a policy living on the GPU round-trips through pickle.

    The policy is moved to the global device, pickled with cloudpickle,
    and the unpickled copy must still produce valid actions of the same
    shape.
    """
    # Route subsequent global_device() lookups to the GPU.
    # NOTE(review): gpu mode is never reset to False afterwards — this
    # global state may leak into later tests; consider a try/finally.
    set_gpu_mode(True)
    env = GymEnv(DummyDiscretePixelEnv(), is_image=True)
    policy = CategoricalCNNPolicy(env_spec=env.spec,
                                  image_format='NHWC',
                                  kernel_sizes=kernel_sizes,
                                  hidden_channels=hidden_channels,
                                  strides=strides,
                                  hidden_sizes=hidden_sizes)
    # Move the policy onto the GPU before pickling so the round-trip
    # exercises the GPU save/restore path.
    policy.to(global_device())
    env.reset()
    obs = env.step(1).observation

    output_action_1, _ = policy.get_action(obs)

    # Round-trip: __getstate__/__setstate__ should move the copy back
    # onto the GPU automatically.
    p = cloudpickle.dumps(policy)
    policy_pickled = cloudpickle.loads(p)
    output_action_2, _ = policy_pickled.get_action(obs)

    assert env.action_space.contains(output_action_1)
    assert env.action_space.contains(output_action_2)
    assert output_action_1.shape == output_action_2.shape