From a9e083a084f30e514fd35ea9202483a4bf3fe4ac Mon Sep 17 00:00:00 2001 From: Ziyi Wu Date: Mon, 3 May 2021 20:41:35 -0700 Subject: [PATCH] Add cloudpickle wrapper for compatibility --- setup.py | 2 +- src/garage/experiment/snapshotter.py | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index c4857286c..eac74309b 100644 --- a/setup.py +++ b/setup.py @@ -13,7 +13,7 @@ 'akro>=0.0.8', 'click>=2.0', # Older versions don't work with torch.save - 'cloudpickle>=1.6.0', + 'cloudpickle', 'cma==2.7.0', 'dowel==0.0.3', 'numpy>=1.14.5', diff --git a/src/garage/experiment/snapshotter.py b/src/garage/experiment/snapshotter.py index a3f44ebe6..64dc89267 100644 --- a/src/garage/experiment/snapshotter.py +++ b/src/garage/experiment/snapshotter.py @@ -136,8 +136,12 @@ def save_itr_params(self, itr, params): if file_name: if torch: + + class _pickle_module: + Pickler = cloudpickle.CloudPickler + params['global_device'] = global_device() - torch.save(params, file_name, pickle_module=cloudpickle) + torch.save(params, file_name, pickle_module=_pickle_module) else: with open(file_name, 'wb') as file: cloudpickle.dump(params, file)