From 3fb4f5073ed1b256dccadaaa6ee66ce213373339 Mon Sep 17 00:00:00 2001 From: ziyiwu9494 Date: Sun, 11 Apr 2021 12:06:52 -0700 Subject: [PATCH] Preserve sac log alpha when moving between CPU and GPU (#2260) --- .github/workflows/ci.yml | 2 +- docs/requirements.txt | 1 + setup.cfg | 1 + src/garage/torch/__init__.py | 28 ++++++++++---- src/garage/torch/_functions.py | 18 +++++++++ src/garage/torch/algos/sac.py | 14 +++++-- tests/garage/torch/algos/test_sac.py | 55 +++++++++++++++++++++++++++ tests/garage/torch/test_functions.py | 56 ++++++++++++++++++++++++++-- 8 files changed, 161 insertions(+), 14 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 68204e13d6..94c4fec136 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -186,7 +186,7 @@ jobs: /bin/bash -c \ '[ ! -f ${MJKEY_PATH} ] || mv ${MJKEY_PATH} ${MJKEY_PATH}.bak && pytest --cov=garage --cov-report=xml --reruns 1 -m \ - "not nightly and not huge and not flaky and not large and not mujoco and not mujoco_long" --durations=20 && + "not gpu and not nightly and not huge and not flaky and not large and not mujoco and not mujoco_long" --durations=20 && for i in {1..5}; do bash <(curl -s https://codecov.io/bash --retry 5) -Z && break if [ $i == 5 ]; then diff --git a/docs/requirements.txt b/docs/requirements.txt index 3e7b8fab54..08bcef6b5e 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -17,3 +17,4 @@ scipy setproctitle tensorflow tensorflow-probability +docutils<0.17 diff --git a/setup.cfg b/setup.cfg index 899b02abb0..9e4d5aecd7 100644 --- a/setup.cfg +++ b/setup.cfg @@ -72,6 +72,7 @@ markers = serial mujoco mujoco_long + gpu [yapf] based_on_style = pep8 diff --git a/src/garage/torch/__init__.py b/src/garage/torch/__init__.py index 001de12142..cbe76695ae 100644 --- a/src/garage/torch/__init__.py +++ b/src/garage/torch/__init__.py @@ -7,14 +7,28 @@ NonLinearity, output_height_2d, output_width_2d, pad_to_last, prefer_gpu, product_of_gaussians, set_gpu_mode, - soft_update_model, torch_to_np, - update_module_params) + soft_update_model, state_dict_to, + torch_to_np, update_module_params) # yapf: enable __all__ = [ - 'compute_advantages', 'as_torch_dict', 'filter_valids', 'flatten_batch', - 'global_device', 'as_torch', 'pad_to_last', 'prefer_gpu', - 'product_of_gaussians', 'set_gpu_mode', 'soft_update_model', 'torch_to_np', - 'update_module_params', 'NonLinearity', 'flatten_to_single_vector', - 'output_width_2d', 'output_height_2d', 'expand_var' + 'compute_advantages', + 'as_torch_dict', + 'filter_valids', + 'flatten_batch', + 'global_device', + 'as_torch', + 'pad_to_last', + 'prefer_gpu', + 'product_of_gaussians', + 'set_gpu_mode', + 'soft_update_model', + 'torch_to_np', + 'update_module_params', + 'NonLinearity', + 'flatten_to_single_vector', + 'output_width_2d', + 'output_height_2d', + 'expand_var', + 'state_dict_to', ] diff --git a/src/garage/torch/_functions.py b/src/garage/torch/_functions.py index 64c0f61d54..3bc8391e6f 100644 --- a/src/garage/torch/_functions.py +++ b/src/garage/torch/_functions.py @@ -327,6 +327,24 @@ def product_of_gaussians(mus, sigmas_squared): return mu, sigma_squared +def state_dict_to(state_dict, device): + """Move optimizer to a specified device. + + Args: + state_dict (dict): state dictionary to be moved + device (str): ID of GPU or CPU. + + Returns: + dict: state dictionary moved to device + """ + for param in state_dict.values(): + if isinstance(param, torch.Tensor): + param.data = param.data.to(device) + elif isinstance(param, dict): + state_dict_to(param, device) + return state_dict + + # pylint: disable=W0223 class NonLinearity(nn.Module): """Wrapper class for non linear function or module. diff --git a/src/garage/torch/algos/sac.py b/src/garage/torch/algos/sac.py index 4d328dc912..4c64caedfe 100644 --- a/src/garage/torch/algos/sac.py +++ b/src/garage/torch/algos/sac.py @@ -10,7 +10,7 @@ from garage import log_performance, obtain_evaluation_episodes, StepType from garage.np.algos import RLAlgorithm -from garage.torch import as_torch_dict, global_device +from garage.torch import as_torch_dict, global_device, state_dict_to # yapf: enable @@ -524,7 +524,15 @@ def to(self, device=None): self._log_alpha = torch.Tensor([self._fixed_alpha ]).log().to(device) else: - self._log_alpha = torch.Tensor([self._initial_log_entropy - ]).to(device).requires_grad_() + self._log_alpha = self._log_alpha.detach().to( + device).requires_grad_() self._alpha_optimizer = self._optimizer([self._log_alpha], lr=self._policy_lr) + self._alpha_optimizer.load_state_dict( + state_dict_to(self._alpha_optimizer.state_dict(), device)) + self._qf1_optimizer.load_state_dict( + state_dict_to(self._qf1_optimizer.state_dict(), device)) + self._qf2_optimizer.load_state_dict( + state_dict_to(self._qf2_optimizer.state_dict(), device)) + self._policy_optimizer.load_state_dict( + state_dict_to(self._policy_optimizer.state_dict(), device)) diff --git a/tests/garage/torch/algos/test_sac.py b/tests/garage/torch/algos/test_sac.py index 9bb86d89ee..856946aae3 100644 --- a/tests/garage/torch/algos/test_sac.py +++ b/tests/garage/torch/algos/test_sac.py @@ -284,3 +284,58 @@ def test_fixed_alpha(): trainer.train(n_epochs=1, batch_size=100, plot=False) assert torch.allclose(torch.Tensor([0.5]), sac._log_alpha.cpu()) assert not sac._use_automatic_entropy_tuning + + +@pytest.mark.gpu +def test_sac_to(): + """Test moving Sac between CPU and GPU.""" + env = normalize(GymEnv('InvertedDoublePendulum-v2', + max_episode_length=100)) + deterministic.set_seed(0) + policy = TanhGaussianMLPPolicy( + env_spec=env.spec, + hidden_sizes=[32, 32], + hidden_nonlinearity=torch.nn.ReLU, + output_nonlinearity=None, + min_std=np.exp(-20.), + max_std=np.exp(2.), + ) + + qf1 = ContinuousMLPQFunction(env_spec=env.spec, + hidden_sizes=[32, 32], + hidden_nonlinearity=F.relu) + + qf2 = ContinuousMLPQFunction(env_spec=env.spec, + hidden_sizes=[32, 32], + hidden_nonlinearity=F.relu) + replay_buffer = PathBuffer(capacity_in_transitions=int(1e6), ) + trainer = Trainer(snapshot_config=snapshot_config) + sampler = LocalSampler(agents=policy, + envs=env, + max_episode_length=env.spec.max_episode_length, + worker_class=FragmentWorker) + sac = SAC(env_spec=env.spec, + policy=policy, + qf1=qf1, + qf2=qf2, + sampler=sampler, + gradient_steps_per_itr=100, + replay_buffer=replay_buffer, + min_buffer_size=1e3, + target_update_tau=5e-3, + discount=0.99, + buffer_batch_size=64, + reward_scale=1., + steps_per_epoch=2) + trainer.setup(sac, env) + if torch.cuda.is_available(): + set_gpu_mode(True) + else: + set_gpu_mode(False) + sac.to() + trainer.setup(algo=sac, env=env) + trainer.train(n_epochs=1, batch_size=100) + log_alpha = torch.clone(sac._log_alpha).cpu() + set_gpu_mode(False) + sac.to() + assert torch.allclose(log_alpha, sac._log_alpha) diff --git a/tests/garage/torch/test_functions.py b/tests/garage/torch/test_functions.py index f41dadcdda..b89b0c0042 100644 --- a/tests/garage/torch/test_functions.py +++ b/tests/garage/torch/test_functions.py @@ -1,14 +1,21 @@ """Module to test garage.torch._functions.""" # yapf: disable +import collections + import numpy as np import pytest import torch +from torch import tensor import torch.nn.functional as F +from garage.envs import GymEnv, normalize +from garage.experiment.deterministic import set_seed from garage.torch import (as_torch_dict, compute_advantages, flatten_to_single_vector, global_device, pad_to_last, - product_of_gaussians, set_gpu_mode, torch_to_np) + product_of_gaussians, set_gpu_mode, state_dict_to, + torch_to_np) import garage.torch._functions as tu +from garage.torch.policies import DeterministicMLPPolicy from tests.fixtures import TfGraphTestCase @@ -56,8 +63,8 @@ def test_as_torch_dict(): """Test if dict whose values are tensors can be converted to np arrays.""" dic = {'a': np.zeros(1), 'b': np.ones(1)} as_torch_dict(dic) - for tensor in dic.values(): - assert isinstance(tensor, torch.Tensor) + for dic_value in dic.values(): + assert isinstance(dic_value, torch.Tensor) def test_product_of_gaussians(): @@ -80,6 +87,49 @@ def test_flatten_to_single_vector(): assert expected.shape == flatten_tensor.shape +@pytest.mark.gpu +def test_state_dict_to(): + """Test state_dict_to""" + set_seed(42) + # Using tensor instead of Tensor so it can be declared on GPU + # pylint: disable=not-callable + expected = collections.OrderedDict([ + ('_module._layers.0.linear.weight', + tensor([[ + 0.13957974, -0.2693157, -0.19351028, 0.09471931, -0.43573233, + 0.03590716, -0.4272097, -0.13935488, -0.35843086, -0.25814268, + 0.03060348 + ], + [ + 0.20623916, -0.1914061, 0.46729338, -0.5437773, + -0.50449526, -0.55039907, 0.0141218, -0.02489783, + 0.26499796, -0.03836302, 0.7235093 + ]], + device='cuda:0')), + ('_module._layers.0.linear.bias', tensor([0., 0.], device='cuda:0')), + ('_module._layers.1.linear.weight', + tensor([[-0.7181905, -0.6284401], [0.10591025, -0.14771031]], + device='cuda:0')), + ('_module._layers.1.linear.bias', tensor([0., 0.], device='cuda:0')), + ('_module._output_layers.0.linear.weight', + tensor([[-0.29133463, 0.58353233]], device='cuda:0')), + ('_module._output_layers.0.linear.bias', tensor([0.], device='cuda:0')) + ]) + # pylint: enable=not-callable + env = normalize(GymEnv('InvertedDoublePendulum-v2')) + policy = DeterministicMLPPolicy(env_spec=env.spec, + hidden_sizes=[2, 2], + hidden_nonlinearity=F.relu, + output_nonlinearity=torch.tanh) + moved_state_dict = state_dict_to(policy.state_dict(), 'cuda') + assert np.all([ + torch.allclose(expected[key], moved_state_dict[key]) + for key in expected.keys() + ]) + assert np.all( + [moved_state_dict[key].is_cuda for key in moved_state_dict.keys()]) + + class TestTorchAlgoUtils(TfGraphTestCase): """Test class for torch algo utility functions.""" # yapf: disable