Add gpu pickleable module to torch
ziyiwu9494 committed Apr 15, 2021
1 parent 3fb4f50 commit fe138b1
Showing 4 changed files with 95 additions and 6 deletions.
3 changes: 2 additions & 1 deletion src/garage/torch/__init__.py
@@ -4,7 +4,7 @@
     compute_advantages, expand_var,
     filter_valids, flatten_batch,
     flatten_to_single_vector, global_device,
-    NonLinearity, output_height_2d,
+    Module, NonLinearity, output_height_2d,
     output_width_2d, pad_to_last, prefer_gpu,
     product_of_gaussians, set_gpu_mode,
     soft_update_model, state_dict_to,
@@ -31,4 +31,5 @@
     'output_height_2d',
     'expand_var',
     'state_dict_to',
+    'Module',
 ]
59 changes: 57 additions & 2 deletions src/garage/torch/_functions.py
@@ -331,11 +331,11 @@ def state_dict_to(state_dict, device):
     """Move optimizer to a specified device.
 
     Args:
-        state_dict (dict): state dictionary to be moved
+        state_dict (dict): state dictionary to be moved.
         device (str): ID of GPU or CPU.
 
     Returns:
-        dict: state dictionary moved to device
+        dict: state dictionary moved to device.
 
     """
     for param in state_dict.values():
         if isinstance(param, torch.Tensor):
@@ -345,6 +345,61 @@ def state_dict_to(state_dict, device):
     return state_dict
 
 
+# pylint: disable=abstract-method
+class Module(nn.Module):
+    """Wrapper class for Garage PyTorch modules."""
+
+    def __getstate__(self):
+        """Save the current device of the module before saving module state.
+
+        Returns:
+            dict: State dictionary.
+
+        """
+        # do we always run experiments on global device?
+        save_from_device = global_device()
+        self.to('cpu')
+        state = self.__dict__.copy()
+        state['device'] = save_from_device
+        return state
+
+    def __setstate__(self, state):
+        """Restore the module state, moving it back to the original device if possible.
+
+        Args:
+            state (dict): State dictionary.
+
+        """
+        system_device = global_device()
+        save_from_device = state['device']
+        if save_from_device == system_device:
+            module_state_to(state, system_device)
+        # what to do if it doesn't match?
+        # do I need to set global device to the current device?
+        del state['device']
+        self.__dict__ = state
+        if save_from_device == system_device:
+            self.to(system_device)
+
+
+def module_state_to(state, device):
+    """Move elements of a module state to a device.
+
+    Notes - are there other types of parameters in a
+    module state to be moved? are some of them recursive?
+
+    Args:
+        state (dict): State dictionary.
+        device (str): ID of GPU or CPU.
+
+    Returns:
+        dict: moved state dict.
+
+    """
+    for key, param in state.items():
+        if hasattr(param, 'to'):
+            # assign back so the state dict actually holds the moved value
+            state[key] = param.to(device)
+    return state
+
+
 # pylint: disable=W0223
 class NonLinearity(nn.Module):
     """Wrapper class for non linear function or module.
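For context, a minimal usage sketch of the new wrapper (not part of the commit): Module.__getstate__ moves the module to CPU and records its device before pickling, and __setstate__ moves it back when the saved device matches the current global device. ToyModule is a hypothetical subclass used only for illustration, and the sketch assumes a CUDA device is available.

import cloudpickle
import torch.nn as nn

from garage.torch import Module, global_device, set_gpu_mode


class ToyModule(Module):
    """Any garage.torch.Module subclass pickles the same way."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)


set_gpu_mode(True)                        # use the GPU as the global device
module = ToyModule().to(global_device())

# __getstate__ moves the module to CPU and stores the source device, so the
# pickled payload does not depend on where the module currently lives.
payload = cloudpickle.dumps(module)

# __setstate__ moves the restored module back to the saved device because it
# matches the current global device in this process.
restored = cloudpickle.loads(payload)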
5 changes: 2 additions & 3 deletions src/garage/torch/policies/policy.py
@@ -1,12 +1,11 @@
 """Base Policy."""
 import abc
 
-import torch
-
 from garage.np.policies import Policy as BasePolicy
+from garage.torch import Module
 
 
-class Policy(torch.nn.Module, BasePolicy, abc.ABC):
+class Policy(Module, BasePolicy, abc.ABC):
     """Policy base class.
 
     Args:
34 changes: 34 additions & 0 deletions tests/garage/torch/policies/test_categorical_cnn_policy.py
@@ -4,6 +4,7 @@
 import torch
 
 from garage.envs import GymEnv
+from garage.torch import global_device, set_gpu_mode
 from garage.torch.policies import CategoricalCNNPolicy
 
 from tests.fixtures.envs.dummy import DummyDictEnv, DummyDiscretePixelEnv
@@ -152,3 +153,36 @@ def test_obs_unflattened(self, hidden_channels, kernel_sizes, strides,
         obs = env.observation_space.sample()
         action, _ = policy.get_action(env.observation_space.flatten(obs))
         env.step(action)
+
+
+@pytest.mark.gpu
+@pytest.mark.parametrize(
+    'hidden_channels, kernel_sizes, strides, hidden_sizes', [
+        ((3, ), (3, ), (1, ), (4, )),
+        ((3, 3), (3, 3), (1, 1), (4, 4)),
+        ((3, 3), (3, 3), (2, 2), (4, 4)),
+    ])
+def test_is_pickleable_on_gpu(hidden_channels, kernel_sizes, strides,
+                              hidden_sizes):
+    """Test if policy is pickleable when on GPU."""
+    set_gpu_mode(True)
+    env = GymEnv(DummyDiscretePixelEnv(), is_image=True)
+    policy = CategoricalCNNPolicy(env_spec=env.spec,
+                                  image_format='NHWC',
+                                  kernel_sizes=kernel_sizes,
+                                  hidden_channels=hidden_channels,
+                                  strides=strides,
+                                  hidden_sizes=hidden_sizes)
+    policy.to(global_device())
+    env.reset()
+    obs = env.step(1).observation
+
+    output_action_1, _ = policy.get_action(obs)
+
+    p = cloudpickle.dumps(policy)
+    policy_pickled = cloudpickle.loads(p)
+    output_action_2, _ = policy_pickled.get_action(obs)
+
+    assert env.action_space.contains(output_action_1)
+    assert env.action_space.contains(output_action_2)
+    assert output_action_1.shape == output_action_2.shape
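Usage note (not part of the commit): since the new test carries the gpu marker, it can be selected or deselected with pytest's -m option, assuming the marker is registered in the project's pytest configuration and a CUDA device is available:

pytest -m gpu tests/garage/torch/policies/test_categorical_cnn_policy.py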
