Skip to content

Commit

Permalink
Torch save bugfixes
Browse files Browse the repository at this point in the history
  • Loading branch information
ziyiwu9494 committed Apr 30, 2021
1 parent f235e99 commit 3d4368d
Show file tree
Hide file tree
Showing 8 changed files with 292 additions and 16 deletions.
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
# Please keep alphabetized
'akro>=0.0.8',
'click>=2.0',
'cloudpickle',
# Older versions don't work with torch.save
'cloudpickle>=1.6.0',
'cma==2.7.0',
'dowel==0.0.3',
'numpy>=1.14.5',
Expand Down
20 changes: 12 additions & 8 deletions src/garage/experiment/snapshotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import sys

import cloudpickle
from dowel import logger

# pylint: disable=no-name-in-module

Expand Down Expand Up @@ -100,8 +101,9 @@ def save_itr_params(self, itr, params):
"""
# pylint: disable=import-outside-toplevel
torch = False
if torch in sys.modules:
if 'torch' in sys.modules:
import torch
from garage.torch import global_device
file_name = None
# pylint: enable=import-outside-toplevel
if self._snapshot_mode == 'all':
Expand Down Expand Up @@ -134,6 +136,7 @@ def save_itr_params(self, itr, params):

if file_name:
if torch:
params['global_device'] = global_device()
torch.save(params, file_name, pickle_module=cloudpickle)
else:
with open(file_name, 'wb') as file:
Expand Down Expand Up @@ -161,9 +164,9 @@ def load(self, load_dir, itr='last'):
"""
torch = False
# pylint: disable=import-outside-toplevel
if torch in sys.modules:
if 'torch' in sys.modules:
import torch
import garage.torch
from garage.torch import global_device
# pylint: enable=import-outside-toplevel
if isinstance(itr, int) or itr.isdigit():
load_from_file = os.path.join(load_dir, 'itr_{}.pkl'.format(itr))
Expand All @@ -185,12 +188,13 @@ def load(self, load_dir, itr='last'):

if not os.path.isfile(load_from_file):
raise NotAFileError('File not existing: ', load_from_file)

if torch:
device = garage.torch.global_device()
return torch.load(load_from_file,
map_location=device,
pickle_module=cloudpickle)
device = global_device()
params = torch.load(load_from_file, map_location=device)
origin_device = params['global_device']
del params['global_device']
logger.log(f'Resuming experiment from {origin_device} on {device}')
return params
with open(load_from_file, 'rb') as file:
return cloudpickle.load(file)

Expand Down
9 changes: 5 additions & 4 deletions tests/fixtures/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
"""Test fixtures."""
# yapf: disable
from tests.fixtures.fixtures import (snapshot_config,
TfGraphTestCase,
TfTestCase)
from tests.fixtures.fixtures import (reset_gpu_mode, snapshot_config,
TfGraphTestCase, TfTestCase)

# yapf: enable

__all__ = ['snapshot_config', 'TfGraphTestCase', 'TfTestCase']
__all__ = [
'reset_gpu_mode', 'snapshot_config', 'TfGraphTestCase', 'TfTestCase'
]
6 changes: 6 additions & 0 deletions tests/fixtures/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from garage.experiment import deterministic
from garage.experiment.snapshotter import SnapshotConfig
from garage.torch import set_gpu_mode

from tests.fixtures.logger import NullOutput

Expand Down Expand Up @@ -64,3 +65,8 @@ def teardown_method(self):
del self.graph
del self.sess
gc.collect()


def reset_gpu_mode():
"""Reset mode to CPU after test."""
set_gpu_mode(False)
6 changes: 5 additions & 1 deletion tests/garage/torch/algos/test_mtsac.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from garage.torch.q_functions import ContinuousMLPQFunction
from garage.trainer import Trainer

from tests.fixtures import snapshot_config
from tests.fixtures import reset_gpu_mode, snapshot_config


@pytest.mark.mujoco
Expand Down Expand Up @@ -178,6 +178,7 @@ def test_mtsac_inverted_double_pendulum():
assert ret > 0


@pytest.mark.serial
def test_to():
"""Test the torch function that moves modules to GPU.
Expand Down Expand Up @@ -236,8 +237,10 @@ def test_to():
for param in mtsac.policy.parameters():
assert param.device == device
assert mtsac._log_alpha.device == device
reset_gpu_mode()


@pytest.mark.serial
@pytest.mark.mujoco
def test_fixed_alpha():
"""Test if using fixed_alpha ensures that alpha is non differentiable."""
Expand Down Expand Up @@ -298,3 +301,4 @@ def test_fixed_alpha():
assert torch.allclose(torch.Tensor([0.5] * num_tasks),
mtsac._log_alpha.to('cpu'))
assert not mtsac._use_automatic_entropy_tuning
reset_gpu_mode()
6 changes: 5 additions & 1 deletion tests/garage/torch/algos/test_sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from garage.torch.q_functions import ContinuousMLPQFunction
from garage.trainer import Trainer

from tests.fixtures import snapshot_config
from tests.fixtures import reset_gpu_mode, snapshot_config


class _MockDistribution:
Expand Down Expand Up @@ -177,6 +177,7 @@ def testTemperatureLoss():
assert np.all(np.isclose(loss, expected_loss))


@pytest.mark.serial
@pytest.mark.mujoco
def test_sac_inverted_double_pendulum():
"""Test Sac performance on inverted pendulum."""
Expand Down Expand Up @@ -234,6 +235,7 @@ def test_sac_inverted_double_pendulum():
assert not torch.allclose(torch.Tensor([1.]), sac._log_alpha.to('cpu'))
# check that policy is learning beyond predecided threshold
assert ret > 80
reset_gpu_mode()


@pytest.mark.mujoco
Expand Down Expand Up @@ -286,6 +288,7 @@ def test_fixed_alpha():
assert not sac._use_automatic_entropy_tuning


@pytest.mark.serial
@pytest.mark.gpu
def test_sac_to():
"""Test moving Sac between CPU and GPU."""
Expand Down Expand Up @@ -339,3 +342,4 @@ def test_sac_to():
set_gpu_mode(False)
sac.to()
assert torch.allclose(log_alpha, sac._log_alpha)
reset_gpu_mode()
6 changes: 5 additions & 1 deletion tests/garage/torch/algos/test_td3.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,13 @@
from garage.torch.q_functions import ContinuousMLPQFunction
from garage.trainer import Trainer

from tests.fixtures import snapshot_config, TfGraphTestCase
from tests.fixtures import reset_gpu_mode, snapshot_config, TfGraphTestCase


class TestTD3(TfGraphTestCase):
"""Test class for TD3."""

@pytest.mark.serial
@pytest.mark.mujoco
def test_td3_inverted_double_pendulum(self):
deterministic.set_seed(0)
Expand Down Expand Up @@ -67,7 +68,9 @@ def test_td3_inverted_double_pendulum(self):
td3.to()
trainer.setup(td3, env)
trainer.train(n_epochs=n_epochs, batch_size=sampler_batch_size)
reset_gpu_mode()

@pytest.mark.serial
@pytest.mark.mujoco
def test_pickling(self):
"""Test pickle and unpickle."""
Expand Down Expand Up @@ -116,3 +119,4 @@ def test_pickling(self):
pickled = pickle.dumps(td3)
unpickled = pickle.loads(pickled)
assert unpickled
reset_gpu_mode()
Loading

0 comments on commit 3d4368d

Please sign in to comment.