From fe138b1194201ab70297450c64b440ce055ede1d Mon Sep 17 00:00:00 2001
From: Ziyi Wu
Date: Thu, 15 Apr 2021 01:22:35 -0700
Subject: [PATCH] Add gpu pickleable module to torch

---
 src/garage/torch/__init__.py                  |  3 +-
 src/garage/torch/_functions.py                | 59 ++++++++++++++++++-
 src/garage/torch/policies/policy.py           |  5 +-
 .../policies/test_categorical_cnn_policy.py   | 34 +++++++++++
 4 files changed, 95 insertions(+), 6 deletions(-)

diff --git a/src/garage/torch/__init__.py b/src/garage/torch/__init__.py
index cbe76695ae..60942140bf 100644
--- a/src/garage/torch/__init__.py
+++ b/src/garage/torch/__init__.py
@@ -4,7 +4,7 @@
                                     compute_advantages, expand_var,
                                     filter_valids, flatten_batch,
                                     flatten_to_single_vector, global_device,
-                                    NonLinearity, output_height_2d,
+                                    Module, NonLinearity, output_height_2d,
                                     output_width_2d, pad_to_last, prefer_gpu,
                                     product_of_gaussians, set_gpu_mode,
                                     soft_update_model, state_dict_to,
@@ -31,4 +31,5 @@
     'output_height_2d',
     'expand_var',
     'state_dict_to',
+    'Module',
 ]
diff --git a/src/garage/torch/_functions.py b/src/garage/torch/_functions.py
index 3bc8391e6f..9e8a3aab97 100644
--- a/src/garage/torch/_functions.py
+++ b/src/garage/torch/_functions.py
@@ -331,11 +331,11 @@ def state_dict_to(state_dict, device):
     """Move optimizer to a specified device.
 
     Args:
-        state_dict (dict): state dictionary to be moved
+        state_dict (dict): state dictionary to be moved.
         device (str): ID of GPU or CPU.
 
     Returns:
-        dict: state dictionary moved to device
+        dict: state dictionary moved to device.
     """
     for param in state_dict.values():
         if isinstance(param, torch.Tensor):
@@ -345,6 +345,61 @@ def state_dict_to(state_dict, device):
     return state_dict
 
 
+# pylint: disable=abstract-method
+class Module(nn.Module):
+    """Wrapper class for Garage PyTorch modules."""
+
+    def __getstate__(self):
+        """Save the current device of the module before saving its state.
+
+        Returns:
+            dict: State dictionary.
+        """
+        # Do we always run experiments on the global device?
+        save_from_device = global_device()
+        self.to('cpu')
+        state = self.__dict__.copy()
+        state['device'] = save_from_device
+        return state
+
+    def __setstate__(self, state):
+        """Restore the state, moving it back to the original device if possible.
+
+        Args:
+            state (dict): State dictionary.
+
+        """
+        system_device = global_device()
+        save_from_device = state['device']
+        if save_from_device == system_device:
+            module_state_to(state, system_device)
+        # What should happen if the devices don't match?
+        # Do we need to set the global device to the current device?
+        del state['device']
+        self.__dict__ = state
+        if save_from_device == system_device:
+            self.to(system_device)
+
+
+def module_state_to(state, device):
+    """Move elements of a module state to a device.
+
+    Note: it is unclear whether other kinds of entries in a module
+    state need to be moved, or whether some require recursive handling.
+
+    Args:
+        state (dict): State dictionary.
+        device (str): ID of GPU or CPU.
+
+    Returns:
+        dict: State dictionary moved to the device.
+    """
+    for name, param in state.items():
+        if hasattr(param, 'to'):
+            state[name] = param.to(device)
+    return state
+
+
 # pylint: disable=W0223
 class NonLinearity(nn.Module):
     """Wrapper class for non linear function or module.
diff --git a/src/garage/torch/policies/policy.py b/src/garage/torch/policies/policy.py
index 423065c89d..0bf14e79f3 100644
--- a/src/garage/torch/policies/policy.py
+++ b/src/garage/torch/policies/policy.py
@@ -1,12 +1,11 @@
 """Base Policy."""
 import abc
 
-import torch
-
 from garage.np.policies import Policy as BasePolicy
+from garage.torch import Module
 
 
-class Policy(torch.nn.Module, BasePolicy, abc.ABC):
+class Policy(Module, BasePolicy, abc.ABC):
     """Policy base class.
 
     Args:
diff --git a/tests/garage/torch/policies/test_categorical_cnn_policy.py b/tests/garage/torch/policies/test_categorical_cnn_policy.py
index e7733edf30..a1ed1ceeac 100644
--- a/tests/garage/torch/policies/test_categorical_cnn_policy.py
+++ b/tests/garage/torch/policies/test_categorical_cnn_policy.py
@@ -4,6 +4,7 @@
 import torch
 
 from garage.envs import GymEnv
+from garage.torch import global_device, set_gpu_mode
 from garage.torch.policies import CategoricalCNNPolicy
 
 from tests.fixtures.envs.dummy import DummyDictEnv, DummyDiscretePixelEnv
@@ -152,3 +153,36 @@ def test_obs_unflattened(self, hidden_channels, kernel_sizes, strides,
         obs = env.observation_space.sample()
         action, _ = policy.get_action(env.observation_space.flatten(obs))
         env.step(action)
+
+
+@pytest.mark.gpu
+@pytest.mark.parametrize(
+    'hidden_channels, kernel_sizes, strides, hidden_sizes', [
+        ((3, ), (3, ), (1, ), (4, )),
+        ((3, 3), (3, 3), (1, 1), (4, 4)),
+        ((3, 3), (3, 3), (2, 2), (4, 4)),
+    ])
+def test_is_pickleable_on_gpu(hidden_channels, kernel_sizes, strides,
+                              hidden_sizes):
+    """Test that the policy is pickleable when it lives on the GPU."""
+    set_gpu_mode(True)
+    env = GymEnv(DummyDiscretePixelEnv(), is_image=True)
+    policy = CategoricalCNNPolicy(env_spec=env.spec,
+                                  image_format='NHWC',
+                                  kernel_sizes=kernel_sizes,
+                                  hidden_channels=hidden_channels,
+                                  strides=strides,
+                                  hidden_sizes=hidden_sizes)
+    policy.to(global_device())
+    env.reset()
+    obs = env.step(1).observation
+
+    output_action_1, _ = policy.get_action(obs)
+
+    p = cloudpickle.dumps(policy)
+    policy_pickled = cloudpickle.loads(p)
+    output_action_2, _ = policy_pickled.get_action(obs)
+
+    assert env.action_space.contains(output_action_1)
+    assert env.action_space.contains(output_action_2)
+    assert output_action_1.shape == output_action_2.shape
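
Reviewer note (not part of the patch): below is a minimal sketch of how the new
garage.torch.Module wrapper is expected to round-trip through cloudpickle when
GPU mode is enabled. It assumes the Module class added above, the existing
set_gpu_mode/global_device helpers, a hypothetical ToyModule in place of a real
garage policy, and a machine with a CUDA device; it is an illustration under
those assumptions, not part of the tested code.

    # Sketch only: assumes the Module wrapper from this patch and a CUDA device.
    import cloudpickle
    import torch.nn as nn

    from garage.torch import Module, global_device, set_gpu_mode


    class ToyModule(Module):
        """Hypothetical two-layer network used only for this illustration."""

        def __init__(self):
            super().__init__()
            self._net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(),
                                      nn.Linear(8, 2))

        def forward(self, x):
            return self._net(x)


    set_gpu_mode(True)
    module = ToyModule().to(global_device())

    # __getstate__ moves the module to CPU and records the device it came
    # from; __setstate__ moves the restored copy back to that device when the
    # unpickling process sees the same global device.
    restored = cloudpickle.loads(cloudpickle.dumps(module))
    assert next(restored.parameters()).device == global_device()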