Refactor RL2 to use EpisodeBatch (#2138)
* Refactor RL2 to use EpisodeBatch

* Fix isort
yeukfu authored Oct 28, 2020
1 parent a63349a commit 4312678
Showing 5 changed files with 55 additions and 66 deletions.
3 changes: 1 addition & 2 deletions examples/torch/maml_trpo_metaworld_ml1_push.py
@@ -8,8 +8,7 @@

from garage import wrap_experiment
from garage.envs import MetaWorldSetTaskEnv
-from garage.experiment import (MetaEvaluator,
-MetaWorldTaskSampler,
+from garage.experiment import (MetaEvaluator, MetaWorldTaskSampler,
SetTaskSampler)
from garage.experiment.deterministic import set_seed
from garage.torch.algos import MAMLTRPO
4 changes: 1 addition & 3 deletions src/garage/tf/algos/reps.py
@@ -10,9 +10,7 @@
from garage import _Default, log_performance, make_optimizer
from garage.np.algos import RLAlgorithm
from garage.sampler import RaySampler
-from garage.tf import (compile_function,
-flatten_inputs,
-graph_inputs,
+from garage.tf import (compile_function, flatten_inputs, graph_inputs,
new_tensor)
from garage.tf.optimizers import LBFGSOptimizer

90 changes: 47 additions & 43 deletions src/garage/tf/algos/rl2.py
@@ -12,7 +12,6 @@

from garage import (EnvSpec, EnvStep, EpisodeBatch, log_multitask_performance,
StepType, Wrapper)
-from garage.np import concat_tensor_dict_list, discount_cumsum
from garage.np.algos import MetaRLAlgorithm
from garage.sampler import DefaultWorker
from garage.tf.algos._rl2npo import RL2NPO
@@ -339,7 +338,7 @@ def train(self, trainer):
if trainer.step_itr % self._n_epochs_per_eval == 0:
if self._meta_evaluator is not None:
self._meta_evaluator.evaluate(self)
-trainer.step_episode = trainer.obtain_samples(
+trainer.step_episode = trainer.obtain_episodes(
trainer.step_itr,
env_update=self._task_sampler.sample(self._meta_batch_size))
last_return = self.train_once(trainer.step_itr,
@@ -348,18 +347,18 @@

return last_return

-def train_once(self, itr, paths):
+def train_once(self, itr, episodes):
"""Perform one step of policy optimization given one batch of samples.
Args:
itr (int): Iteration number.
-paths (list[dict]): A list of collected paths.
+episodes (EpisodeBatch): Batch of episodes.
Returns:
numpy.float64: Average return.
"""
-episodes, average_return = self._process_samples(itr, paths)
+episodes, average_return = self._process_samples(itr, episodes)
logger.log('Optimizing policy...')
self._inner_algo.optimize_policy(episodes)
return average_return
@@ -400,16 +399,17 @@ def adapt_policy(self, exploration_policy, exploration_episodes):
return RL2AdaptedPolicy(exploration_policy._policy)

# pylint: disable=protected-access
-def _process_samples(self, itr, paths):
+def _process_samples(self, itr, episodes):
# pylint: disable=too-many-statements
"""Return processed sample data based on the collected paths.
Args:
itr (int): Iteration number.
-paths (OrderedDict[dict]): A list of collected paths for each
-task. In RL^2, there are n environments/tasks and paths in
-each of them will be concatenated at some point and fed to
-the policy.
+episodes (EpisodeBatch): Original collected episode batch for each
+task. For each episode, episode.agent_infos['batch_idx']
+indicates which task this episode belongs to. In RL^2, there
+are n environments/tasks and paths in each of them will be
+concatenated at some point and fed to the policy.
Returns:
EpisodeBatch: Processed batch of episodes for feeding the inner
@@ -423,24 +423,25 @@ def _process_samples(self, itr, paths):
concatenated_paths = []

paths_by_task = collections.defaultdict(list)
-for path in paths:
-path['returns'] = discount_cumsum(path['rewards'], self._discount)
-path['lengths'] = [len(path['rewards'])]
-if 'batch_idx' in path:
-paths_by_task[path['batch_idx']].append(path)
-elif 'batch_idx' in path['agent_infos']:
-paths_by_task[path['agent_infos']['batch_idx'][0]].append(path)
+for episode in episodes.split():
+if hasattr(episode, 'batch_idx'):
+paths_by_task[episode.batch_idx[0]].append(episode)
+elif 'batch_idx' in episode.agent_infos:
+paths_by_task[episode.agent_infos['batch_idx'][0]].append(
+episode)
else:
raise ValueError(
'Batch idx is required for RL2 but not found, '
'Make sure to use garage.tf.algos.rl2.RL2Worker '
'for sampling')

# all path in paths_by_task[i] are sampled from task[i]
-for _paths in paths_by_task.values():
-concatenated_path = self._concatenate_paths(_paths)
+for episode_list in paths_by_task.values():
+concatenated_path = self._concatenate_paths(episode_list)
concatenated_paths.append(concatenated_path)

+concatenated_episodes = EpisodeBatch.concatenate(*concatenated_paths)

name_map = None
if hasattr(self._task_sampler, '_envs') and hasattr(
self._task_sampler._envs[0]._env, 'all_task_names'):
@@ -450,26 +451,22 @@
name_map = dict(enumerate(names))

undiscounted_returns = log_multitask_performance(
-itr,
-EpisodeBatch.from_list(self._env_spec, paths),
-self._inner_algo._discount,
-name_map=name_map)
+itr, episodes, self._inner_algo._discount, name_map=name_map)

average_return = np.mean(undiscounted_returns)
-episodes = EpisodeBatch.from_list(self._env_spec, concatenated_paths)

-return episodes, average_return
+return concatenated_episodes, average_return

-def _concatenate_paths(self, paths):
+def _concatenate_paths(self, episode_list):
"""Concatenate paths.
The input paths are from different episodes but same task/environment.
In RL^2, paths within each meta batch are all concatenate into a single
path and fed to the policy.
Args:
-paths (dict): Input paths. All paths are from different episodes,
-but the same task/environment.
+episode_list (list[EpisodeBatch]): Input paths. All paths are from
+different episodes, but the same task/environment.
Returns:
dict: Concatenated paths from the same task/environment. Shape of
@@ -479,23 +476,30 @@ def _concatenate_paths(self, paths):
values of shape :math:`[max_episode_length, S^*]`
"""
-observations = np.concatenate([path['observations'] for path in paths])
+env_infos = {
+k: np.concatenate([b.env_infos[k] for b in episode_list])
+for k in episode_list[0].env_infos.keys()
+}
+agent_infos = {
+k: np.concatenate([b.agent_infos[k] for b in episode_list])
+for k in episode_list[0].agent_infos.keys()
+}
actions = np.concatenate([
-self._env_spec.action_space.flatten_n(path['actions'])
-for path in paths
+self._env_spec.action_space.flatten_n(ep.actions)
+for ep in episode_list
])
-valids = np.concatenate(
-[np.ones_like(path['rewards']) for path in paths])
-baselines = np.concatenate(
-[np.zeros_like(path['rewards']) for path in paths])

-concatenated_path = concat_tensor_dict_list(paths)
-concatenated_path['observations'] = observations
-concatenated_path['actions'] = actions
-concatenated_path['valids'] = valids
-concatenated_path['baselines'] = baselines

-return concatenated_path

+return EpisodeBatch(
+env_spec=episode_list[0].env_spec,
+observations=np.concatenate(
+[ep.observations for ep in episode_list]),
+last_observations=episode_list[-1].last_observations,
+actions=actions,
+rewards=np.concatenate([ep.rewards for ep in episode_list]),
+env_infos=env_infos,
+agent_infos=agent_infos,
+step_types=np.concatenate([ep.step_types for ep in episode_list]),
+lengths=np.asarray([sum([ep.lengths[0] for ep in episode_list])]))

@property
def policy(self):
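For readers following the rl2.py change above: the new _process_samples splits the incoming EpisodeBatch into single episodes with split(), then groups them by the batch_idx that the RL2 worker records in agent_infos, so that all episodes sampled from the same task can later be concatenated into one long per-task episode. The snippet below is a minimal, self-contained sketch of just that grouping step; ToyEpisode and group_by_task are hypothetical stand-ins written for illustration and are not part of garage's API.

# Toy illustration of grouping episodes by task via agent_infos['batch_idx'],
# mirroring the loop added in _process_samples above (hypothetical helper names).
import collections
from dataclasses import dataclass, field

import numpy as np


@dataclass
class ToyEpisode:
    """Stand-in for one episode split out of an EpisodeBatch."""
    rewards: np.ndarray
    agent_infos: dict = field(default_factory=dict)


def group_by_task(episodes):
    """Group episodes by the task index recorded by the sampling worker."""
    by_task = collections.defaultdict(list)
    for ep in episodes:
        if 'batch_idx' not in ep.agent_infos:
            raise ValueError('batch_idx is required to assign an episode '
                             'to its task')
        # batch_idx is constant within an episode, so the first entry suffices.
        by_task[int(ep.agent_infos['batch_idx'][0])].append(ep)
    return by_task


# Two episodes sampled from task 0 and one from task 1.
episodes = [
    ToyEpisode(np.ones(3), {'batch_idx': np.zeros(3, dtype=int)}),
    ToyEpisode(np.ones(2), {'batch_idx': np.zeros(2, dtype=int)}),
    ToyEpisode(np.ones(4), {'batch_idx': np.ones(4, dtype=int)}),
]
grouped = group_by_task(episodes)
print({task: len(eps) for task, eps in grouped.items()})  # {0: 2, 1: 1}

In the actual change, each grouped list is then handed to _concatenate_paths, which builds one EpisodeBatch per task using the constructor shown in the diff.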
16 changes: 4 additions & 12 deletions src/garage/tf/algos/te_npo.py
@@ -9,21 +9,13 @@

from garage import InOutSpec, log_performance
from garage.experiment import deterministic
-from garage.np import (discount_cumsum,
-explained_variance_1d,
-rrse,
+from garage.np import (discount_cumsum, explained_variance_1d, rrse,
sliding_window)
from garage.np.algos import RLAlgorithm
from garage.sampler import LocalSampler
-from garage.tf import (center_advs,
-compile_function,
-compute_advantages,
-concat_tensor_list,
-discounted_returns,
-flatten_inputs,
-graph_inputs,
-pad_tensor_dict,
-positive_advs,
+from garage.tf import (center_advs, compile_function, compute_advantages,
+concat_tensor_list, discounted_returns, flatten_inputs,
+graph_inputs, pad_tensor_dict, positive_advs,
stack_tensor_dict_list)
from garage.tf.embeddings import StochasticEncoder
from garage.tf.optimizers import LBFGSOptimizer
8 changes: 2 additions & 6 deletions src/garage/torch/algos/bc.py
@@ -6,12 +6,8 @@
import numpy as np
import torch

-from garage import (_Default,
-EpisodeBatch,
-log_performance,
-make_optimizer,
-obtain_evaluation_episodes,
-TimeStepBatch)
+from garage import (_Default, EpisodeBatch, log_performance, make_optimizer,
+obtain_evaluation_episodes, TimeStepBatch)
from garage.np.algos.rl_algorithm import RLAlgorithm
from garage.np.policies import Policy
from garage.sampler import RaySampler
