Skip to content

Commit

Permalink
Snapshotter using Torch save and load
Browse files Browse the repository at this point in the history
  • Loading branch information
ziyiwu9494 committed Apr 22, 2021
1 parent 90b6090 commit f235e99
Showing 1 changed file with 30 additions and 5 deletions.
35 changes: 30 additions & 5 deletions src/garage/experiment/snapshotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,12 @@
import errno
import os
import pathlib
import sys

import cloudpickle

# pylint: disable=no-name-in-module

SnapshotConfig = collections.namedtuple(
'SnapshotConfig', ['snapshot_dir', 'snapshot_mode', 'snapshot_gap'])

Expand Down Expand Up @@ -82,6 +85,7 @@ def snapshot_gap(self):
"""
return self._snapshot_gap

# pylint: disable=too-many-branches
def save_itr_params(self, itr, params):
"""Save the parameters if at the right iteration.
Expand All @@ -94,8 +98,12 @@ def save_itr_params(self, itr, params):
"gap_overwrite", "gap_and_last", or "none".
"""
# pylint: disable=import-outside-toplevel
torch = False
if torch in sys.modules:
import torch
file_name = None

# pylint: enable=import-outside-toplevel
if self._snapshot_mode == 'all':
file_name = os.path.join(self._snapshot_dir, 'itr_%d.pkl' % itr)
elif self._snapshot_mode == 'gap_overwrite':
Expand All @@ -113,17 +121,23 @@ def save_itr_params(self, itr, params):
file_name = os.path.join(self._snapshot_dir,
'itr_%d.pkl' % itr)
file_name_last = os.path.join(self._snapshot_dir, 'params.pkl')
with open(file_name_last, 'wb') as file:
cloudpickle.dump(params, file)
if torch:
torch.save(params, file_name_last, pickle_module=cloudpickle)
else:
with open(file_name_last, 'wb') as file:
cloudpickle.dump(params, file)
elif self._snapshot_mode == 'none':
pass
else:
raise ValueError('Invalid snapshot mode {}'.format(
self._snapshot_mode))

if file_name:
with open(file_name, 'wb') as file:
cloudpickle.dump(params, file)
if torch:
torch.save(params, file_name, pickle_module=cloudpickle)
else:
with open(file_name, 'wb') as file:
cloudpickle.dump(params, file)

def load(self, load_dir, itr='last'):
# pylint: disable=no-self-use
Expand All @@ -145,6 +159,12 @@ def load(self, load_dir, itr='last'):
NotAFileError: If the snapshot exists but is not a file.
"""
torch = False
# pylint: disable=import-outside-toplevel
if torch in sys.modules:
import torch
import garage.torch
# pylint: enable=import-outside-toplevel
if isinstance(itr, int) or itr.isdigit():
load_from_file = os.path.join(load_dir, 'itr_{}.pkl'.format(itr))
else:
Expand All @@ -166,6 +186,11 @@ def load(self, load_dir, itr='last'):
if not os.path.isfile(load_from_file):
raise NotAFileError('File not existing: ', load_from_file)

if torch:
device = garage.torch.global_device()
return torch.load(load_from_file,
map_location=device,
pickle_module=cloudpickle)
with open(load_from_file, 'rb') as file:
return cloudpickle.load(file)

Expand Down

0 comments on commit f235e99

Please sign in to comment.