Skip to content

Commit

Permalink
Cloudpickle MultiprocessingSampler EnvUpdates (#2322)
Browse files Browse the repository at this point in the history
  • Loading branch information
krzentner authored Apr 17, 2022
1 parent c56513f commit 6461a07
Showing 1 changed file with 9 additions and 6 deletions.
15 changes: 9 additions & 6 deletions src/garage/sampler/multiprocessing_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,8 @@ def __init__(

self._agents = self._factory.prepare_worker_messages(
agents, cloudpickle.dumps)
self._envs = self._factory.prepare_worker_messages(envs)
self._envs = self._factory.prepare_worker_messages(
envs, cloudpickle.dumps)
self._to_sampler = mp.Queue(2 * self._factory.n_workers)
self._to_worker = [mp.Queue(1) for _ in range(self._factory.n_workers)]
# If we crash from an exception, with full queues, we would rather not
Expand Down Expand Up @@ -192,7 +193,8 @@ def obtain_samples(self, itr, num_samples, agent_update, env_update=None):
updated_workers = set()
agent_ups = self._factory.prepare_worker_messages(
agent_update, cloudpickle.dumps)
env_ups = self._factory.prepare_worker_messages(env_update)
env_ups = self._factory.prepare_worker_messages(
env_update, cloudpickle.dumps)

with click.progressbar(length=num_samples, label='Sampling') as pbar:
while completed_samples < num_samples:
Expand Down Expand Up @@ -260,7 +262,8 @@ def obtain_exact_episodes(self,
updated_workers = set()
agent_ups = self._factory.prepare_worker_messages(
agent_update, cloudpickle.dumps)
env_ups = self._factory.prepare_worker_messages(env_update)
env_ups = self._factory.prepare_worker_messages(
env_update, cloudpickle.dumps)
episodes = defaultdict(list)

with click.progressbar(length=self._factory.n_workers,
Expand Down Expand Up @@ -337,7 +340,7 @@ def __getstate__(self):
return dict(
factory=self._factory,
agents=[cloudpickle.loads(agent) for agent in self._agents],
envs=self._envs)
envs=[cloudpickle.loads(env) for env in self._envs])

def __setstate__(self, state):
"""Unpickle the state.
Expand Down Expand Up @@ -400,7 +403,7 @@ def run_worker(factory, to_worker, to_sampler, worker_number, agent, env):

inner_worker = factory(worker_number)
inner_worker.update_agent(cloudpickle.loads(agent))
inner_worker.update_env(env)
inner_worker.update_env(cloudpickle.loads(env))

version = 0
streaming_samples = False
Expand All @@ -422,7 +425,7 @@ def run_worker(factory, to_worker, to_sampler, worker_number, agent, env):
# Update env and policy.
agent_update, env_update, version = contents
inner_worker.update_agent(cloudpickle.loads(agent_update))
inner_worker.update_env(env_update)
inner_worker.update_env(cloudpickle.loads(env_update))
streaming_samples = True
elif tag == 'stop':
streaming_samples = False
Expand Down

0 comments on commit 6461a07

Please sign in to comment.