diff --git a/setup.py b/setup.py
index 0a1e0ff317..c4857286c4 100644
--- a/setup.py
+++ b/setup.py
@@ -12,7 +12,8 @@
     # Please keep alphabetized
     'akro>=0.0.8',
     'click>=2.0',
-    'cloudpickle',
+    # Older versions don't work with torch.save
+    'cloudpickle>=1.6.0',
     'cma==2.7.0',
     'dowel==0.0.3',
     'numpy>=1.14.5',
diff --git a/src/garage/experiment/snapshotter.py b/src/garage/experiment/snapshotter.py
index dd5812e232..a3f44ebe67 100644
--- a/src/garage/experiment/snapshotter.py
+++ b/src/garage/experiment/snapshotter.py
@@ -6,6 +6,7 @@
 import sys
 
 import cloudpickle
+from dowel import logger  # pylint: disable=no-name-in-module
 
@@ -100,8 +101,9 @@ def save_itr_params(self, itr, params):
         """
         # pylint: disable=import-outside-toplevel
         torch = False
-        if torch in sys.modules:
+        if 'torch' in sys.modules:
             import torch
+            from garage.torch import global_device
         file_name = None
         # pylint: enable=import-outside-toplevel
         if self._snapshot_mode == 'all':
@@ -134,6 +136,7 @@ def save_itr_params(self, itr, params):
         if file_name:
             if torch:
+                params['global_device'] = global_device()
                 torch.save(params, file_name, pickle_module=cloudpickle)
             else:
                 with open(file_name, 'wb') as file:
                     cloudpickle.dump(params, file)
@@ -161,9 +164,9 @@ def load(self, load_dir, itr='last'):
         """
         torch = False
         # pylint: disable=import-outside-toplevel
-        if torch in sys.modules:
+        if 'torch' in sys.modules:
             import torch
-            import garage.torch
+            from garage.torch import global_device
         # pylint: enable=import-outside-toplevel
         if isinstance(itr, int) or itr.isdigit():
             load_from_file = os.path.join(load_dir, 'itr_{}.pkl'.format(itr))
@@ -185,12 +188,13 @@ def load(self, load_dir, itr='last'):
 
         if not os.path.isfile(load_from_file):
             raise NotAFileError('File not existing: ', load_from_file)
 
-        if torch:
-            device = garage.torch.global_device()
-            return torch.load(load_from_file,
-                              map_location=device,
-                              pickle_module=cloudpickle)
+        if torch:
+            device = global_device()
+            params = torch.load(load_from_file, map_location=device)
+            origin_device = params['global_device']
+            del params['global_device']
+            logger.log(f'Resuming experiment from {origin_device} on {device}')
+            return params
 
         with open(load_from_file, 'rb') as file:
             return cloudpickle.load(file)
diff --git a/tests/fixtures/__init__.py b/tests/fixtures/__init__.py
index 07d875604b..8ec18ac96a 100644
--- a/tests/fixtures/__init__.py
+++ b/tests/fixtures/__init__.py
@@ -1,9 +1,10 @@
 """Test fixtures."""
 # yapf: disable
-from tests.fixtures.fixtures import (snapshot_config,
-                                     TfGraphTestCase,
-                                     TfTestCase)
+from tests.fixtures.fixtures import (reset_gpu_mode, snapshot_config,
+                                     TfGraphTestCase, TfTestCase)
 # yapf: enable
 
-__all__ = ['snapshot_config', 'TfGraphTestCase', 'TfTestCase']
+__all__ = [
+    'reset_gpu_mode', 'snapshot_config', 'TfGraphTestCase', 'TfTestCase'
+]
diff --git a/tests/fixtures/fixtures.py b/tests/fixtures/fixtures.py
index edda5962b0..15bd98ce72 100644
--- a/tests/fixtures/fixtures.py
+++ b/tests/fixtures/fixtures.py
@@ -9,6 +9,7 @@
 
 from garage.experiment import deterministic
 from garage.experiment.snapshotter import SnapshotConfig
+from garage.torch import set_gpu_mode
 
 from tests.fixtures.logger import NullOutput
 
@@ -64,3 +65,8 @@ def teardown_method(self):
         del self.graph
         del self.sess
         gc.collect()
+
+
+def reset_gpu_mode():
+    """Reset mode to CPU after test."""
+    set_gpu_mode(False)
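
The snapshotter now records the device a snapshot was written on and remaps tensors onto the current device when loading. A minimal sketch of that round trip, assuming only torch and cloudpickle are installed; the '/tmp/itr_0.pkl' path and the params dict are illustrative stand-ins, not garage code:

    import cloudpickle
    import torch

    # Save: the origin device is stored alongside the parameters so the
    # loader can report where the snapshot came from.
    params = {'weights': torch.zeros(2, 2),
              'global_device': torch.device('cpu')}
    torch.save(params, '/tmp/itr_0.pkl', pickle_module=cloudpickle)

    # Load: map_location remaps every tensor onto the current device, and
    # the recorded origin device is popped before the params are returned.
    device = torch.device('cpu')  # stand-in for garage.torch.global_device()
    restored = torch.load('/tmp/itr_0.pkl', map_location=device)
    origin_device = restored.pop('global_device')
    print(f'Resuming experiment from {origin_device} on {device}')
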
diff --git a/tests/garage/torch/algos/test_mtsac.py b/tests/garage/torch/algos/test_mtsac.py
index 9995551bb1..8a048ced70 100644
--- a/tests/garage/torch/algos/test_mtsac.py
+++ b/tests/garage/torch/algos/test_mtsac.py
@@ -15,7 +15,7 @@
 from garage.torch.q_functions import ContinuousMLPQFunction
 from garage.trainer import Trainer
 
-from tests.fixtures import snapshot_config
+from tests.fixtures import reset_gpu_mode, snapshot_config
 
 
 @pytest.mark.mujoco
@@ -178,6 +178,7 @@ def test_mtsac_inverted_double_pendulum():
     assert ret > 0
 
 
+@pytest.mark.serial
 def test_to():
     """Test the torch function that moves modules to GPU.
 
@@ -236,8 +237,10 @@ def test_to():
     for param in mtsac.policy.parameters():
         assert param.device == device
     assert mtsac._log_alpha.device == device
+    reset_gpu_mode()
 
 
+@pytest.mark.serial
 @pytest.mark.mujoco
 def test_fixed_alpha():
     """Test if using fixed_alpha ensures that alpha is non differentiable."""
@@ -298,3 +301,4 @@ def test_fixed_alpha():
     assert torch.allclose(torch.Tensor([0.5] * num_tasks),
                           mtsac._log_alpha.to('cpu'))
     assert not mtsac._use_automatic_entropy_tuning
+    reset_gpu_mode()
diff --git a/tests/garage/torch/algos/test_sac.py b/tests/garage/torch/algos/test_sac.py
index 856946aae3..fb87d05f8f 100644
--- a/tests/garage/torch/algos/test_sac.py
+++ b/tests/garage/torch/algos/test_sac.py
@@ -16,7 +16,7 @@
 from garage.torch.q_functions import ContinuousMLPQFunction
 from garage.trainer import Trainer
 
-from tests.fixtures import snapshot_config
+from tests.fixtures import reset_gpu_mode, snapshot_config
 
 
 class _MockDistribution:
@@ -177,6 +177,7 @@ def testTemperatureLoss():
     assert np.all(np.isclose(loss, expected_loss))
 
 
+@pytest.mark.serial
 @pytest.mark.mujoco
 def test_sac_inverted_double_pendulum():
     """Test Sac performance on inverted pendulum."""
@@ -234,6 +235,7 @@ def test_sac_inverted_double_pendulum():
     assert not torch.allclose(torch.Tensor([1.]), sac._log_alpha.to('cpu'))
     # check that policy is learning beyond predecided threshold
     assert ret > 80
+    reset_gpu_mode()
 
 
 @pytest.mark.mujoco
@@ -286,6 +288,7 @@ def test_fixed_alpha():
     assert not sac._use_automatic_entropy_tuning
 
 
+@pytest.mark.serial
 @pytest.mark.gpu
 def test_sac_to():
     """Test moving Sac between CPU and GPU."""
@@ -339,3 +342,4 @@ def test_sac_to():
     set_gpu_mode(False)
     sac.to()
     assert torch.allclose(log_alpha, sac._log_alpha)
+    reset_gpu_mode()
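
Tests that flip the process-wide GPU mode are marked @pytest.mark.serial and call reset_gpu_mode() before returning, so one test cannot leak GPU state into the next. The same cleanup could be expressed as a pytest fixture; a sketch, where gpu_cleanup is a hypothetical name and not part of this diff:

    import pytest

    from tests.fixtures import reset_gpu_mode


    @pytest.fixture
    def gpu_cleanup():
        """Run the test body, then force the global device back to CPU."""
        yield
        reset_gpu_mode()
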
diff --git a/tests/garage/torch/algos/test_td3.py b/tests/garage/torch/algos/test_td3.py
index 524649a1d6..74aecb1170 100644
--- a/tests/garage/torch/algos/test_td3.py
+++ b/tests/garage/torch/algos/test_td3.py
@@ -15,12 +15,13 @@
 from garage.torch.q_functions import ContinuousMLPQFunction
 from garage.trainer import Trainer
 
-from tests.fixtures import snapshot_config, TfGraphTestCase
+from tests.fixtures import reset_gpu_mode, snapshot_config, TfGraphTestCase
 
 
 class TestTD3(TfGraphTestCase):
     """Test class for TD3."""
 
+    @pytest.mark.serial
     @pytest.mark.mujoco
     def test_td3_inverted_double_pendulum(self):
         deterministic.set_seed(0)
@@ -67,7 +68,9 @@ def test_td3_inverted_double_pendulum(self):
         td3.to()
         trainer.setup(td3, env)
         trainer.train(n_epochs=n_epochs, batch_size=sampler_batch_size)
+        reset_gpu_mode()
 
+    @pytest.mark.serial
     @pytest.mark.mujoco
     def test_pickling(self):
         """Test pickle and unpickle."""
@@ -116,3 +119,4 @@ def test_pickling(self):
         pickled = pickle.dumps(td3)
         unpickled = pickle.loads(pickled)
         assert unpickled
+        reset_gpu_mode()
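
The new test file below adds four save/resume tests covering every CPU/GPU combination; they differ only in the device set before training and before restoring. Their shared skeleton, as a sketch (run_resume_cycle is a hypothetical helper; the actual tests inline these steps):

    from garage.torch import set_gpu_mode
    from garage.trainer import Trainer


    def run_resume_cycle(snapshot_config, snapshot_dir, sac, env,
                         save_on_gpu, resume_on_gpu):
        """Train with snapshots on one device, then resume on another."""
        set_gpu_mode(save_on_gpu)
        sac.to()
        trainer = Trainer(snapshot_config=snapshot_config)
        trainer.setup(algo=sac, env=env)
        trainer.train(n_epochs=10, batch_size=100)
        # A fresh Trainer restores the snapshot under the new device mode.
        set_gpu_mode(resume_on_gpu)
        trainer = Trainer(snapshot_config)
        trainer.restore(snapshot_dir)
        trainer.resume(n_epochs=20)
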
diff --git a/tests/garage/torch/algos/test_torch_resume.py b/tests/garage/torch/algos/test_torch_resume.py
new file mode 100644
index 0000000000..43277ab758
--- /dev/null
+++ b/tests/garage/torch/algos/test_torch_resume.py
@@ -0,0 +1,252 @@
+"""Tests that saving and resuming a model works,
+including moving between CPU and GPU devices."""
+
+import tempfile
+
+import numpy as np
+import pytest
+import torch
+from torch.nn import functional as F
+
+from garage.envs import GymEnv, normalize
+from garage.experiment import deterministic, SnapshotConfig
+from garage.replay_buffer import PathBuffer
+from garage.sampler import FragmentWorker, LocalSampler
+from garage.torch import set_gpu_mode
+from garage.torch.algos import SAC
+from garage.torch.policies import TanhGaussianMLPPolicy
+from garage.torch.q_functions import ContinuousMLPQFunction
+from garage.trainer import Trainer
+
+
+@pytest.mark.mujoco
+def test_torch_cpu_resume_cpu():
+    """Test saving on CPU and resuming on CPU."""
+    temp_dir = tempfile.TemporaryDirectory()
+    snapshot_config = SnapshotConfig(snapshot_dir=temp_dir.name,
+                                     snapshot_mode='last',
+                                     snapshot_gap=1)
+    env = normalize(GymEnv('InvertedDoublePendulum-v2',
+                           max_episode_length=100))
+    deterministic.set_seed(0)
+    policy = TanhGaussianMLPPolicy(
+        env_spec=env.spec,
+        hidden_sizes=[32, 32],
+        hidden_nonlinearity=torch.nn.ReLU,
+        output_nonlinearity=None,
+        min_std=np.exp(-20.),
+        max_std=np.exp(2.),
+    )
+
+    qf1 = ContinuousMLPQFunction(env_spec=env.spec,
+                                 hidden_sizes=[32, 32],
+                                 hidden_nonlinearity=F.relu)
+
+    qf2 = ContinuousMLPQFunction(env_spec=env.spec,
+                                 hidden_sizes=[32, 32],
+                                 hidden_nonlinearity=F.relu)
+    replay_buffer = PathBuffer(capacity_in_transitions=int(1e6), )
+    trainer = Trainer(snapshot_config=snapshot_config)
+    sampler = LocalSampler(agents=policy,
+                           envs=env,
+                           max_episode_length=env.spec.max_episode_length,
+                           worker_class=FragmentWorker)
+    sac = SAC(env_spec=env.spec,
+              policy=policy,
+              qf1=qf1,
+              qf2=qf2,
+              sampler=sampler,
+              gradient_steps_per_itr=100,
+              replay_buffer=replay_buffer,
+              min_buffer_size=1e3,
+              target_update_tau=5e-3,
+              discount=0.99,
+              buffer_batch_size=64,
+              reward_scale=1.,
+              steps_per_epoch=2)
+    # The lambda exercises cloudpickle when saving, since plain pickle
+    # cannot serialize lambdas.
+    sac.has_lambda = lambda x: x + 1
+    trainer.setup(sac, env)
+    set_gpu_mode(False)
+    sac.to()
+    trainer.setup(algo=sac, env=env)
+    trainer.train(n_epochs=10, batch_size=100)
+    trainer = Trainer(snapshot_config)
+    trainer.restore(temp_dir.name)
+    trainer.resume(n_epochs=20)
+    temp_dir.cleanup()
+
+
+@pytest.mark.gpu
+@pytest.mark.mujoco
+def test_torch_cpu_resume_gpu():
+    """Test saving on CPU and resuming on GPU."""
+    temp_dir = tempfile.TemporaryDirectory()
+    snapshot_config = SnapshotConfig(snapshot_dir=temp_dir.name,
+                                     snapshot_mode='last',
+                                     snapshot_gap=1)
+    env = normalize(GymEnv('InvertedDoublePendulum-v2',
+                           max_episode_length=100))
+    deterministic.set_seed(0)
+    policy = TanhGaussianMLPPolicy(
+        env_spec=env.spec,
+        hidden_sizes=[32, 32],
+        hidden_nonlinearity=torch.nn.ReLU,
+        output_nonlinearity=None,
+        min_std=np.exp(-20.),
+        max_std=np.exp(2.),
+    )
+
+    qf1 = ContinuousMLPQFunction(env_spec=env.spec,
+                                 hidden_sizes=[32, 32],
+                                 hidden_nonlinearity=F.relu)
+
+    qf2 = ContinuousMLPQFunction(env_spec=env.spec,
+                                 hidden_sizes=[32, 32],
+                                 hidden_nonlinearity=F.relu)
+    replay_buffer = PathBuffer(capacity_in_transitions=int(1e6), )
+    trainer = Trainer(snapshot_config=snapshot_config)
+    sampler = LocalSampler(agents=policy,
+                           envs=env,
+                           max_episode_length=env.spec.max_episode_length,
+                           worker_class=FragmentWorker)
+    sac = SAC(env_spec=env.spec,
+              policy=policy,
+              qf1=qf1,
+              qf2=qf2,
+              sampler=sampler,
+              gradient_steps_per_itr=100,
+              replay_buffer=replay_buffer,
+              min_buffer_size=1e3,
+              target_update_tau=5e-3,
+              discount=0.99,
+              buffer_batch_size=64,
+              reward_scale=1.,
+              steps_per_epoch=2)
+    sac.has_lambda = lambda x: x + 1
+    trainer.setup(sac, env)
+    set_gpu_mode(False)
+    sac.to()
+    trainer.setup(algo=sac, env=env)
+    trainer.train(n_epochs=10, batch_size=100)
+    trainer = Trainer(snapshot_config)
+    set_gpu_mode(True)
+    trainer.restore(temp_dir.name)
+    trainer.resume(n_epochs=20)
+    temp_dir.cleanup()
+
+
+@pytest.mark.gpu
+@pytest.mark.mujoco
+def test_torch_gpu_resume_cpu():
+    """Test saving on GPU and resuming on CPU."""
+    temp_dir = tempfile.TemporaryDirectory()
+    snapshot_config = SnapshotConfig(snapshot_dir=temp_dir.name,
+                                     snapshot_mode='last',
+                                     snapshot_gap=1)
+    env = normalize(GymEnv('InvertedDoublePendulum-v2',
+                           max_episode_length=100))
+    deterministic.set_seed(0)
+    policy = TanhGaussianMLPPolicy(
+        env_spec=env.spec,
+        hidden_sizes=[32, 32],
+        hidden_nonlinearity=torch.nn.ReLU,
+        output_nonlinearity=None,
+        min_std=np.exp(-20.),
+        max_std=np.exp(2.),
+    )
+
+    qf1 = ContinuousMLPQFunction(env_spec=env.spec,
+                                 hidden_sizes=[32, 32],
+                                 hidden_nonlinearity=F.relu)
+
+    qf2 = ContinuousMLPQFunction(env_spec=env.spec,
+                                 hidden_sizes=[32, 32],
+                                 hidden_nonlinearity=F.relu)
+    replay_buffer = PathBuffer(capacity_in_transitions=int(1e6), )
+    trainer = Trainer(snapshot_config=snapshot_config)
+    sampler = LocalSampler(agents=policy,
+                           envs=env,
+                           max_episode_length=env.spec.max_episode_length,
+                           worker_class=FragmentWorker)
+    sac = SAC(env_spec=env.spec,
+              policy=policy,
+              qf1=qf1,
+              qf2=qf2,
+              sampler=sampler,
+              gradient_steps_per_itr=100,
+              replay_buffer=replay_buffer,
+              min_buffer_size=1e3,
+              target_update_tau=5e-3,
+              discount=0.99,
+              buffer_batch_size=64,
+              reward_scale=1.,
+              steps_per_epoch=2)
+    sac.has_lambda = lambda x: x + 1
+    trainer.setup(sac, env)
+    set_gpu_mode(True)
+    sac.to()
+    trainer.setup(algo=sac, env=env)
+    trainer.train(n_epochs=10, batch_size=100)
+    set_gpu_mode(False)
+    trainer = Trainer(snapshot_config)
+    trainer.restore(temp_dir.name)
+    trainer.resume(n_epochs=20)
+    temp_dir.cleanup()
+
+
+@pytest.mark.gpu
+@pytest.mark.mujoco
+def test_torch_gpu_resume_gpu():
+    """Test saving on GPU and resuming on GPU."""
+    temp_dir = tempfile.TemporaryDirectory()
+    snapshot_config = SnapshotConfig(snapshot_dir=temp_dir.name,
+                                     snapshot_mode='last',
+                                     snapshot_gap=1)
+    env = normalize(GymEnv('InvertedDoublePendulum-v2',
+                           max_episode_length=100))
+    deterministic.set_seed(0)
+    policy = TanhGaussianMLPPolicy(
+        env_spec=env.spec,
+        hidden_sizes=[32, 32],
+        hidden_nonlinearity=torch.nn.ReLU,
+        output_nonlinearity=None,
+        min_std=np.exp(-20.),
+        max_std=np.exp(2.),
+    )
+
+    qf1 = ContinuousMLPQFunction(env_spec=env.spec,
+                                 hidden_sizes=[32, 32],
+                                 hidden_nonlinearity=F.relu)
+
+    qf2 = ContinuousMLPQFunction(env_spec=env.spec,
+                                 hidden_sizes=[32, 32],
+                                 hidden_nonlinearity=F.relu)
+    replay_buffer = PathBuffer(capacity_in_transitions=int(1e6), )
+    trainer = Trainer(snapshot_config=snapshot_config)
+    sampler = LocalSampler(agents=policy,
+                           envs=env,
+                           max_episode_length=env.spec.max_episode_length,
+                           worker_class=FragmentWorker)
+    sac = SAC(env_spec=env.spec,
+              policy=policy,
+              qf1=qf1,
+              qf2=qf2,
+              sampler=sampler,
+              gradient_steps_per_itr=100,
+              replay_buffer=replay_buffer,
+              min_buffer_size=1e3,
+              target_update_tau=5e-3,
+              discount=0.99,
+              buffer_batch_size=64,
+              reward_scale=1.,
+              steps_per_epoch=2)
+    sac.has_lambda = lambda x: x + 1
+    trainer.setup(sac, env)
+    set_gpu_mode(True)
+    sac.to()
+    trainer.setup(algo=sac, env=env)
+    trainer.train(n_epochs=10, batch_size=100)
+    trainer = Trainer(snapshot_config)
+    trainer.restore(temp_dir.name)
+    trainer.resume(n_epochs=20)
+    temp_dir.cleanup()
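
End to end, this means a snapshot taken on one device can be resumed on another. A sketch of resuming in a fresh process, assuming a snapshot directory at the hypothetical path '/data/local/experiment':

    from garage.experiment import SnapshotConfig
    from garage.torch import set_gpu_mode
    from garage.trainer import Trainer

    snapshot_dir = '/data/local/experiment'  # hypothetical path
    set_gpu_mode(True)  # resume on GPU even if the snapshot was saved on CPU
    trainer = Trainer(SnapshotConfig(snapshot_dir=snapshot_dir,
                                     snapshot_mode='last',
                                     snapshot_gap=1))
    trainer.restore(snapshot_dir)
    trainer.resume(n_epochs=20)
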