From 6461a071f0155712add1b41316003e90c9c77899 Mon Sep 17 00:00:00 2001 From: "K.R. Zentner" <41180126+krzentner@users.noreply.github.com> Date: Sun, 17 Apr 2022 01:31:26 -0700 Subject: [PATCH] Cloudpickle MultiprocessingSampler EnvUpdates (#2322) --- src/garage/sampler/multiprocessing_sampler.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/src/garage/sampler/multiprocessing_sampler.py b/src/garage/sampler/multiprocessing_sampler.py index 3135b46942..a1fab79399 100644 --- a/src/garage/sampler/multiprocessing_sampler.py +++ b/src/garage/sampler/multiprocessing_sampler.py @@ -78,7 +78,8 @@ def __init__( self._agents = self._factory.prepare_worker_messages( agents, cloudpickle.dumps) - self._envs = self._factory.prepare_worker_messages(envs) + self._envs = self._factory.prepare_worker_messages( + envs, cloudpickle.dumps) self._to_sampler = mp.Queue(2 * self._factory.n_workers) self._to_worker = [mp.Queue(1) for _ in range(self._factory.n_workers)] # If we crash from an exception, with full queues, we would rather not @@ -192,7 +193,8 @@ def obtain_samples(self, itr, num_samples, agent_update, env_update=None): updated_workers = set() agent_ups = self._factory.prepare_worker_messages( agent_update, cloudpickle.dumps) - env_ups = self._factory.prepare_worker_messages(env_update) + env_ups = self._factory.prepare_worker_messages( + env_update, cloudpickle.dumps) with click.progressbar(length=num_samples, label='Sampling') as pbar: while completed_samples < num_samples: @@ -260,7 +262,8 @@ def obtain_exact_episodes(self, updated_workers = set() agent_ups = self._factory.prepare_worker_messages( agent_update, cloudpickle.dumps) - env_ups = self._factory.prepare_worker_messages(env_update) + env_ups = self._factory.prepare_worker_messages( + env_update, cloudpickle.dumps) episodes = defaultdict(list) with click.progressbar(length=self._factory.n_workers, @@ -337,7 +340,7 @@ def __getstate__(self): return dict( factory=self._factory, agents=[cloudpickle.loads(agent) for agent in self._agents], - envs=self._envs) + envs=[cloudpickle.loads(env) for env in self._envs]) def __setstate__(self, state): """Unpickle the state. @@ -400,7 +403,7 @@ def run_worker(factory, to_worker, to_sampler, worker_number, agent, env): inner_worker = factory(worker_number) inner_worker.update_agent(cloudpickle.loads(agent)) - inner_worker.update_env(env) + inner_worker.update_env(cloudpickle.loads(env)) version = 0 streaming_samples = False @@ -422,7 +425,7 @@ def run_worker(factory, to_worker, to_sampler, worker_number, agent, env): # Update env and policy. agent_update, env_update, version = contents inner_worker.update_agent(cloudpickle.loads(agent_update)) - inner_worker.update_env(env_update) + inner_worker.update_env(cloudpickle.loads(env_update)) streaming_samples = True elif tag == 'stop': streaming_samples = False