From d2180b52603ae480046065451120ae5322149624 Mon Sep 17 00:00:00 2001 From: qlan3 Date: Tue, 4 Jun 2024 01:20:03 -0600 Subject: [PATCH] RLC2024 version --- .gitignore | 4 +- README.md | 91 ++- agents/A2C.py | 457 +++++++------- agents/A2C2.py | 50 -- agents/A2Ccollect.py | 154 +++++ agents/A2Cstar.py | 49 ++ agents/BaseAgent.py | 47 +- agents/CollectA2C.py | 188 ------ agents/CollectPPO.py | 473 --------------- agents/MetaA2C.py | 516 +++++++--------- agents/MetaA2Cstar.py | 189 ++++++ agents/MetaPPO.py | 919 ++++++++++++----------------- agents/MetaPPOstar.py | 385 ++++++++++++ agents/MetapA2C.py | 34 ++ agents/MetapPPO.py | 29 + agents/PPO.py | 703 +++++++++++----------- agents/PPOstar.py | 306 ++++++++++ agents/RNNIndentity.py | 207 ------- agents/SLCollect.py | 194 ++++++ agents/StarA2C.py | 237 -------- agents/__init__.py | 20 +- analysis_brax.py | 85 +-- analysis_grid.py | 211 +++---- analysis_identity.py | 74 --- components/gradients.py | 21 +- components/losses.py | 194 ------ components/network.py | 96 +-- components/optim.py | 365 ++++++------ components/ppo_networks.py | 99 ---- components/running_statistics.py | 20 +- components/star.py | 508 ++++++++++++++++ components/star_gradients.py | 94 +++ configs/a2c_catch.json | 21 + configs/a2c_grid.json | 21 + configs/ant_collect.json | 37 -- configs/ant_lopt.json | 47 -- configs/ant_meta.json | 55 -- configs/ant_ppo.json | 36 -- configs/bdl_a2c.json | 32 - configs/bdl_collect.json | 28 - configs/bdl_identity.json | 26 - configs/bdl_lopt.json | 38 -- configs/bdl_meta.json | 50 -- configs/bdl_star.json | 49 -- configs/bdl_star_lopt.json | 38 -- configs/collect_bdl.json | 22 + configs/collect_mnist.json | 17 + configs/fetch_collect.json | 37 -- configs/fetch_lopt.json | 47 -- configs/fetch_ppo.json | 36 -- configs/grasp_collect.json | 37 -- configs/grasp_lopt.json | 47 -- configs/grasp_ppo.json | 36 -- configs/grid_a2c.json | 43 -- configs/grid_collect.json | 43 -- configs/grid_meta.json | 61 -- configs/halfcheetah_collect.json | 37 -- configs/halfcheetah_lopt.json | 47 -- configs/halfcheetah_ppo.json | 36 -- configs/humanoid_collect.json | 37 -- configs/humanoid_lopt.json | 47 -- configs/humanoid_meta.json | 55 -- configs/humanoid_ppo.json | 36 -- configs/lopt_l2l_ant.json | 85 +++ configs/lopt_l2l_bdl.json | 66 +++ configs/lopt_l2l_catch.json | 46 ++ configs/lopt_l2l_humanoid.json | 85 +++ configs/lopt_lin_ant.json | 85 +++ configs/lopt_lin_bdl.json | 66 +++ configs/lopt_lin_catch.json | 46 ++ configs/lopt_lin_humanoid.json | 85 +++ configs/lopt_rl_ant.json | 85 +++ configs/lopt_rl_bdl.json | 66 +++ configs/lopt_rl_catch.json | 46 ++ configs/lopt_rl_grid_ant.json | 131 ++++ configs/lopt_rl_grid_humanoid.json | 131 ++++ configs/lopt_rl_grid_pendulum.json | 131 ++++ configs/lopt_rl_grid_walker2d.json | 131 ++++ configs/lopt_rl_humanoid.json | 85 +++ configs/lopt_rl_sdl.json | 66 +++ configs/lopt_rlp_ant.json | 45 ++ configs/lopt_rlp_bdl.json | 36 ++ configs/lopt_rlp_humanoid.json | 45 ++ configs/lopt_rlp_sdl.json | 36 ++ configs/lopt_star_ant.json | 86 +++ configs/lopt_star_bdl.json | 67 +++ configs/lopt_star_catch.json | 47 ++ configs/lopt_star_humanoid.json | 86 +++ configs/meta_l2l_ant.json | 43 ++ configs/meta_l2l_bdl.json | 35 ++ configs/meta_l2l_catch.json | 35 ++ configs/meta_l2l_humanoid.json | 43 ++ configs/meta_lin_ant.json | 43 ++ configs/meta_lin_bdl.json | 35 ++ configs/meta_lin_catch.json | 35 ++ configs/meta_lin_humanoid.json | 43 ++ configs/meta_rl_ant.json | 43 ++ configs/meta_rl_bdl.json | 35 ++ 
configs/meta_rl_catch.json | 35 ++ configs/meta_rl_grid.json | 37 ++ configs/meta_rl_humanoid.json | 43 ++ configs/meta_rl_sdl.json | 35 ++ configs/meta_rlp_ant.json | 43 ++ configs/meta_rlp_bdl.json | 35 ++ configs/meta_rlp_humanoid.json | 43 ++ configs/meta_rlp_sdl.json | 35 ++ configs/meta_star_ant.json | 44 ++ configs/meta_star_bdl.json | 36 ++ configs/meta_star_catch.json | 36 ++ configs/meta_star_humanoid.json | 44 ++ configs/ppo_ant.json | 30 + configs/ppo_humanoid.json | 30 + configs/ppo_pendulum.json | 30 + configs/ppo_walker2d.json | 30 + configs/pusher_collect.json | 37 -- configs/pusher_lopt.json | 47 -- configs/pusher_ppo.json | 36 -- configs/reacher_collect.json | 37 -- configs/reacher_lopt.json | 47 -- configs/reacher_ppo.json | 36 -- configs/sds_a2c.json | 32 - configs/sds_lopt.json | 68 --- configs/sds_meta.json | 50 -- configs/sds_star.json | 49 -- configs/sds_star_lopt.json | 38 -- configs/ur5e_collect.json | 37 -- configs/ur5e_lopt.json | 47 -- configs/ur5e_ppo.json | 36 -- download.py | 13 + envs/catch.py | 66 +-- envs/gridworld.py | 95 +-- envs/random_walk.py | 56 -- envs/spaces.py | 6 +- envs/utils.py | 91 +-- experiment.py | 15 +- main.py | 17 +- requirements.txt | 22 +- run.sh | 86 +++ utils/dataloader.py | 51 ++ utils/helper.py | 91 +-- utils/logger.py | 2 +- utils/plotter.py | 266 ++++++--- utils/sweeper.py | 438 +++++++------- 143 files changed, 7158 insertions(+), 5832 deletions(-) delete mode 100644 agents/A2C2.py create mode 100644 agents/A2Ccollect.py create mode 100644 agents/A2Cstar.py delete mode 100644 agents/CollectA2C.py delete mode 100644 agents/CollectPPO.py create mode 100644 agents/MetaA2Cstar.py create mode 100644 agents/MetaPPOstar.py create mode 100644 agents/MetapA2C.py create mode 100644 agents/MetapPPO.py create mode 100644 agents/PPOstar.py delete mode 100644 agents/RNNIndentity.py create mode 100644 agents/SLCollect.py delete mode 100644 agents/StarA2C.py delete mode 100644 analysis_identity.py delete mode 100644 components/losses.py delete mode 100644 components/ppo_networks.py create mode 100644 components/star.py create mode 100644 components/star_gradients.py create mode 100644 configs/a2c_catch.json create mode 100644 configs/a2c_grid.json delete mode 100644 configs/ant_collect.json delete mode 100644 configs/ant_lopt.json delete mode 100644 configs/ant_meta.json delete mode 100644 configs/ant_ppo.json delete mode 100644 configs/bdl_a2c.json delete mode 100644 configs/bdl_collect.json delete mode 100644 configs/bdl_identity.json delete mode 100644 configs/bdl_lopt.json delete mode 100644 configs/bdl_meta.json delete mode 100644 configs/bdl_star.json delete mode 100644 configs/bdl_star_lopt.json create mode 100644 configs/collect_bdl.json create mode 100644 configs/collect_mnist.json delete mode 100644 configs/fetch_collect.json delete mode 100644 configs/fetch_lopt.json delete mode 100644 configs/fetch_ppo.json delete mode 100644 configs/grasp_collect.json delete mode 100644 configs/grasp_lopt.json delete mode 100644 configs/grasp_ppo.json delete mode 100644 configs/grid_a2c.json delete mode 100644 configs/grid_collect.json delete mode 100644 configs/grid_meta.json delete mode 100644 configs/halfcheetah_collect.json delete mode 100644 configs/halfcheetah_lopt.json delete mode 100644 configs/halfcheetah_ppo.json delete mode 100644 configs/humanoid_collect.json delete mode 100644 configs/humanoid_lopt.json delete mode 100644 configs/humanoid_meta.json delete mode 100644 configs/humanoid_ppo.json create mode 100644 
configs/lopt_l2l_ant.json create mode 100644 configs/lopt_l2l_bdl.json create mode 100644 configs/lopt_l2l_catch.json create mode 100644 configs/lopt_l2l_humanoid.json create mode 100644 configs/lopt_lin_ant.json create mode 100644 configs/lopt_lin_bdl.json create mode 100644 configs/lopt_lin_catch.json create mode 100644 configs/lopt_lin_humanoid.json create mode 100644 configs/lopt_rl_ant.json create mode 100644 configs/lopt_rl_bdl.json create mode 100644 configs/lopt_rl_catch.json create mode 100644 configs/lopt_rl_grid_ant.json create mode 100644 configs/lopt_rl_grid_humanoid.json create mode 100644 configs/lopt_rl_grid_pendulum.json create mode 100644 configs/lopt_rl_grid_walker2d.json create mode 100644 configs/lopt_rl_humanoid.json create mode 100644 configs/lopt_rl_sdl.json create mode 100644 configs/lopt_rlp_ant.json create mode 100644 configs/lopt_rlp_bdl.json create mode 100644 configs/lopt_rlp_humanoid.json create mode 100644 configs/lopt_rlp_sdl.json create mode 100644 configs/lopt_star_ant.json create mode 100644 configs/lopt_star_bdl.json create mode 100644 configs/lopt_star_catch.json create mode 100644 configs/lopt_star_humanoid.json create mode 100644 configs/meta_l2l_ant.json create mode 100644 configs/meta_l2l_bdl.json create mode 100644 configs/meta_l2l_catch.json create mode 100644 configs/meta_l2l_humanoid.json create mode 100644 configs/meta_lin_ant.json create mode 100644 configs/meta_lin_bdl.json create mode 100644 configs/meta_lin_catch.json create mode 100644 configs/meta_lin_humanoid.json create mode 100644 configs/meta_rl_ant.json create mode 100644 configs/meta_rl_bdl.json create mode 100644 configs/meta_rl_catch.json create mode 100644 configs/meta_rl_grid.json create mode 100644 configs/meta_rl_humanoid.json create mode 100644 configs/meta_rl_sdl.json create mode 100644 configs/meta_rlp_ant.json create mode 100644 configs/meta_rlp_bdl.json create mode 100644 configs/meta_rlp_humanoid.json create mode 100644 configs/meta_rlp_sdl.json create mode 100644 configs/meta_star_ant.json create mode 100644 configs/meta_star_bdl.json create mode 100644 configs/meta_star_catch.json create mode 100644 configs/meta_star_humanoid.json create mode 100644 configs/ppo_ant.json create mode 100644 configs/ppo_humanoid.json create mode 100644 configs/ppo_pendulum.json create mode 100644 configs/ppo_walker2d.json delete mode 100644 configs/pusher_collect.json delete mode 100644 configs/pusher_lopt.json delete mode 100644 configs/pusher_ppo.json delete mode 100644 configs/reacher_collect.json delete mode 100644 configs/reacher_lopt.json delete mode 100644 configs/reacher_ppo.json delete mode 100644 configs/sds_a2c.json delete mode 100644 configs/sds_lopt.json delete mode 100644 configs/sds_meta.json delete mode 100644 configs/sds_star.json delete mode 100644 configs/sds_star_lopt.json delete mode 100644 configs/ur5e_collect.json delete mode 100644 configs/ur5e_lopt.json delete mode 100644 configs/ur5e_ppo.json create mode 100644 download.py delete mode 100644 envs/random_walk.py create mode 100644 run.sh create mode 100644 utils/dataloader.py diff --git a/.gitignore b/.gitignore index 4ea4106..54d796b 100644 --- a/.gitignore +++ b/.gitignore @@ -7,9 +7,7 @@ logfile *output* *DS_Store* *todo* -run*.sh -script.sh -.trunk +data ### Python ### # Byte-compiled / optimized / DLL files diff --git a/README.md b/README.md index f08aa81..173daf1 100755 --- a/README.md +++ b/README.md @@ -1,6 +1,7 @@ # Optim4RL -Optim4RL is a framework of learning to optimize for reinforcement learning, 
introduced in our paper [Learning to Optimize for Reinforcement Learning](https://arxiv.org/abs/2302.01470). +This is the official implementation of *Optim4RL*, a learning to optimize framework for reinforcement learning, introduced in our RLC 2024 paper [Learning to Optimize for Reinforcement Learning](https://arxiv.org/abs/2302.01470). + **Table of Contents** @@ -14,41 +15,29 @@ Optim4RL is a framework of learning to optimize for reinforcement learning, intr - [Acknowledgement](#acknowledgement) - [Disclaimer](#disclaimer) -## Installation - -1. Install `learned_optimization`: - - ```bash - git clone https://github.com/google/learned_optimization.git - cd learned_optimization - pip install -e . - ``` -2. Install [JAX](https://github.com/google/jax): +## Installation -- TPU: +1. Install [JAX](https://github.com/google/jax) 0.4.19: See [Installing JAX](https://jax.readthedocs.io/en/latest/installation.html) for details. For example, ```bash - pip install "jax[tpu]>=0.3.23" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html + pip install --upgrade "jax[cuda12_pip]==0.4.19" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html ``` - -- GPU: See [JAX GPU (CUDA) installation](https://github.com/google/jax#pip-installation-gpu-cuda) for details. An example: + +2. Install other packages: see `requirements.txt`. ```bash - pip install "jax[cuda11_cudnn82]>=0.3.23" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html + pip install -r requirements.txt ``` -3. Install [Brax](https://github.com/google/brax): +3. Install `learned_optimization`: - ```bash - pip install git+https://github.com/google/brax.git - ``` - -4. Install other packages: see `requirements.txt`. + ```bash + git clone https://github.com/google/learned_optimization.git + cd learned_optimization + pip install -e . && cd .. + ``` - ```bash - pip install -r requirements.txt - ``` ## Usage @@ -56,13 +45,13 @@ Optim4RL is a framework of learning to optimize for reinforcement learning, intr All hyperparameters, including parameters for grid search, are stored in a configuration file in the directory `configs`. To run an experiment, a configuration index is first used to generate a configuration dict corresponding to this specific configuration index. Then we run an experiment defined by this configuration dict. All results, including log files, are saved in the directory `logs`. Please refer to the code for details. -For example, run the experiment with configuration file `sds_a2c.json` and configuration index `1`: +For example, run the experiment with configuration file `a2c_catch.json` and configuration index `1`: ```bash -python main.py --config_file ./configs/sds_a2c.json --config_idx 1 +python main.py --config_file ./configs/a2c_catch.json --config_idx 1 ``` -To do a grid search, we first calculate the number of total combinations in a configuration file (e.g. `sds_a2c.json`): +To do a grid search, we first calculate the number of total combinations in a configuration file (e.g. `a2c_catch.json`): ```bash python utils/sweeper.py @@ -70,49 +59,50 @@ python utils/sweeper.py The output will be: -`The number of total combinations in sds_a2c.json: 12` +`The number of total combinations in a2c_catch.json: 2` -Then we run through all configuration indexes from `1` to `12`. The simplest way is using a bash script: +Then we run through all configuration indexes from `1` to `2`. 
The simplest way is using a bash script: ```bash -for index in {1..12} +for index in {1..2} do - python main.py --config_file ./configs/sds_a2c.json --config_idx $index + python main.py --config_file ./configs/a2c_catch.json --config_idx $index done ``` [Parallel](https://www.gnu.org/software/parallel/) is usually a better choice to schedule a large number of jobs: ```bash -parallel --eta --ungroup python main.py --config_file ./configs/sds_a2c.json --config_idx {1} ::: $(seq 1 12) +parallel --eta --ungroup python main.py --config_file ./configs/a2c_catch.json --config_idx {1} ::: $(seq 1 2) ``` Any configuration index with the same remainder (divided by the number of total combinations) should have the same configuration dict (except the random seed if `generate_random_seed` is `True`). So for multiple runs, we just need to add the number of total combinations to the configuration index. For example, 5 runs for configuration index `1`: ```bash -for index in 1 13 25 37 49 +for index in 1 3 5 7 9 do - python main.py --config_file ./configs/sds_a2c.json --config_idx $index + python main.py --config_file ./configs/a2c_catch.json --config_idx $index done ``` Or a simpler way: ```bash -parallel --eta --ungroup python main.py --config_file ./configs/sds_a2c.json --config_idx {1} ::: $(seq 1 12 60) +parallel --eta --ungroup python main.py --config_file ./configs/a2c_catch.json --config_idx {1} ::: $(seq 1 2 10) ``` +Please check `run.sh` for the details of all experiments. + + ### Experiment -- Benchmark classical optimizers: run `*_a2c.json` or `*_ppo.json`. -- Collect agent gradients and parameter updates during training: run `*_collect.json`. -- Approximate the identity function with RNNs, given agent gradients as input: - 1. Collect agent gradients and parameter updates by running `bdl_collect.json`. - 2. Run `bdl_identity.json`. +- Benchmark classical optimizers: run `a2c_*.json` or `ppo_*.json`. +- Collect agent gradients and parameter updates during training: run `collect_*.json`. - Meta-learn optimizers and test them: - 1. Train optimizers by running `*_meta.json` or `*_star.json`. The meta-parameters at different training stages will be saved in corresponding log directories. Note that for some experiments, more than 1 GPU/TPU is needed due to a large GPU/TPU memory requirement. - 2. Use the paths of saved meta-parameters as the values for `param_load_path` in `*_lopt.json`. - 3. Run `*_lopt.json` to test learned optimizers with various meta-parameters. See `sds_lopt.json` for an example. + 1. Train optimizers by running `meta_*.json`. The meta-parameters at different training stages will be saved in corresponding log directories. Note that for some experiments, more than 1 GPU/TPU (e.g., 4) is needed due to a large GPU/TPU memory requirement. For example, check `meta_rl_catch.json`. + 2. Use the paths of saved meta-parameters as the values for `param_load_path` in `lopt_*.json`. + 3. Run `lopt_*.json` to test learned optimizers with various meta-parameters. For example, check `lopt_rl_catch.json`. + ### Analysis @@ -122,25 +112,29 @@ To analyze the experimental results, just run: python analysis_*.py ``` -Inside `analysis_*.py`, `unfinished_index` will print out the configuration indexes of unfinished jobs based on the existence of the result file. `memory_info` will print out the memory usage information and generate a histogram to show the distribution of memory usages in the directory `logs/sds_a2c/0`. 
Similarly, `time_info` will print out the time information and generate a histogram to show the time distribution in the directory `logs/sds_a2c/0`. Finally, `analyze` will generate `csv` files that store training and test results. Please check `analysis_*.py` for more details. More functions are available in `utils/plotter.py`. +Inside `analysis_*.py`, `unfinished_index` will print out the configuration indexes of unfinished jobs based on the existence of the result file. `memory_info` will print out the memory usage information and generate a histogram to show the distribution of memory usages in the directory `logs/a2c_catch/0`. Similarly, `time_info` will print out the time information and generate a histogram to show the time distribution in the directory `logs/a2c_catch/0`. Finally, `analyze` will generate `csv` files that store training and test results. Please check `analysis_*.py` for more details. More functions are available in `utils/plotter.py`. + ## Citation If you find this work useful to your research, please cite our paper. ```bibtex -@article{lan2023learning, +@inproceedings{lan2024learning, title={Learning to Optimize for Reinforcement Learning}, author={Lan, Qingfeng and Mahmood, A. Rupam and Yan, Shuicheng and Xu, Zhongwen}, - journal={arXiv preprint arXiv:2302.01470}, - year={2023} + booktitle={Reinforcement Learning Conference}, + year={2024}, + url={https://openreview.net/forum?id=JQuEXGj2r1} } ``` + ## License `Optim4RL` is distributed under the terms of the [Apache2](https://www.apache.org/licenses/LICENSE-2.0) license. + ## Acknowledgement We thank the following projects which provide great references: @@ -150,6 +144,7 @@ We thank the following projects which provide great references: - [learned_optimization](https://github.com/google/learned_optimization) - [Explorer](https://github.com/qlan3/Explorer) + ## Disclaimer This is not an official Sea Limited or Garena Online Private Limited product. \ No newline at end of file diff --git a/agents/A2C.py b/agents/A2C.py index bee4fdd..674f7b7 100644 --- a/agents/A2C.py +++ b/agents/A2C.py @@ -1,4 +1,4 @@ -# Copyright 2022 Garena Online Private Limited. +# Copyright 2024 Garena Online Private Limited. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,256 +12,237 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import flax +import rlax +import optax +import numpy as np +import pandas as pd from copy import deepcopy +from functools import partial from typing import Any -import chex import jax import jax.numpy as jnp -import numpy as np -import optax -import pandas as pd -import rlax -from jax import lax, random, tree_util +from jax import jit, lax, random, tree_util +from utils.helper import jitted_split from agents.BaseAgent import BaseAgent +from components.optim import set_optim from components.network import ActorVCriticNet -from components.optim import set_optimizer -@chex.dataclass(frozen=True) +@flax.struct.dataclass class TimeStep: - obs: chex.Array - action: chex.Array - reward: chex.Array - done: chex.Array + obs: jax.Array + action: jax.Array + reward: jax.Array + done: jax.Array -@chex.dataclass -class TrainingState: - agent_param: Any - agent_optim_state: optax.OptState +@flax.struct.dataclass +class MyTrainState: + agent_param: flax.core.FrozenDict + agent_optim_state: Any class A2C(BaseAgent): - """ - Implementation of A2C for gridworlds, compatible with classical optimizers and learned optimizers (LinearOptim, Optim4RL, and L2LGD2). - """ - - def __init__(self, cfg): - super().__init__(cfg) - # Set agent optimizer - self.seed, optim_seed = random.split(self.seed) - # Set learning_rates for tasks - agent_optimizer_cfg = deepcopy(cfg["agent_optimizer"]["kwargs"]) - if self.agent_name in ["MetaA2C"]: - if isinstance(agent_optimizer_cfg["learning_rate"], list): - self.learning_rates = agent_optimizer_cfg["learning_rate"] - elif isinstance(agent_optimizer_cfg["learning_rate"], float): - self.learning_rates = [ - cfg["agent_optimizer"]["kwargs"]["learning_rate"] - ] * len(self.env_names) - else: - raise TypeError("Only List[float] or float is allowed") - agent_optimizer_cfg["learning_rate"] = 1.0 - self.agent_optimizer = set_optimizer( - cfg["agent_optimizer"]["name"], agent_optimizer_cfg, optim_seed - ) - # Make some utility tools - self.reshape = lambda x: x.reshape( - (self.core_count, self.batch_size) + x.shape[1:] - ) - - def create_agent_nets(self): - agent_nets = [] - for i, env_name in enumerate(self.env_names): - agent_net = ActorVCriticNet( - action_size=self.action_sizes[i], env_name=env_name - ) - agent_nets.append(agent_net) - return agent_nets - - def move_one_step(self, carry_in, step_seed): - agent_param, env_state = carry_in - step_seed, action_seed = random.split(step_seed) - obs = self.env.render_obs(env_state) - logits, _ = self.agent_net.apply(agent_param, obs[None,]) - # Select an action - action = random.categorical(key=action_seed, logits=logits[0]) - # Move one step in env - env_state, reward, done = self.env.step(step_seed, env_state, action) - carry_out = [agent_param, env_state] - return carry_out, TimeStep(obs=obs, action=action, reward=reward, done=done) - - def move_rollout_steps(self, agent_param, env_state, step_seed): - carry_in = [agent_param, env_state] - # Move for rollout_steps - step_seeds = random.split(step_seed, self.rollout_steps) - carry_out, rollout = lax.scan( - f=self.move_one_step, init=carry_in, xs=step_seeds - ) - env_state = carry_out[1] - return env_state, rollout - - def compute_agent_loss(self, agent_param, env_state, step_seed): - # Move for rollout_steps - env_state, rollout = self.move_rollout_steps(agent_param, env_state, step_seed) - last_obs = self.env.render_obs(env_state) - all_obs = jnp.concatenate([rollout.obs, jnp.expand_dims(last_obs, 0)], axis=0) - logits, v = self.agent_net.apply(agent_param, all_obs) - # Compute 
multi-step temporal difference error - td_error = rlax.td_lambda( - v_tm1=v[:-1], - r_t=rollout.reward, - discount_t=self.discount * (1.0 - rollout.done), - v_t=v[1:], - lambda_=self.cfg["agent"]["gae_lambda"], - stop_target_gradients=True, - ) - # Compute critic loss - critic_loss = self.cfg["agent"]["critic_loss_weight"] * jnp.mean(td_error**2) - # Compute actor loss - actor_loss = rlax.policy_gradient_loss( - logits_t=logits[:-1], - a_t=rollout.action, - adv_t=td_error, - w_t=jnp.ones_like(td_error), - use_stop_gradient=True, - ) - entropy_loss = self.cfg["agent"]["entropy_weight"] * rlax.entropy_loss( - logits_t=logits[:-1], w_t=jnp.ones_like(td_error) - ) - total_loss = actor_loss + critic_loss + entropy_loss - return total_loss, (env_state, rollout) - - def learn(self, carry_in): - training_state, env_state, seed = carry_in - seed, step_seed = random.split(seed) - # Generate one rollout and compute the gradient - agent_grad, (env_state, rollout) = jax.grad( - self.compute_agent_loss, has_aux=True - )(training_state.agent_param, env_state, step_seed) - # Reduce mean gradients across batch an cores - agent_grad = lax.pmean(agent_grad, axis_name="batch") - agent_grad = lax.pmean(agent_grad, axis_name="core") - # Compute the updates of model parameters - agent_param_update, agent_optim_state = self.agent_optimizer.update( - agent_grad, training_state.agent_optim_state - ) - # Update model parameters - agent_param = optax.apply_updates( - training_state.agent_param, agent_param_update - ) - training_state = training_state.replace( - agent_param=agent_param, agent_optim_state=agent_optim_state - ) - carry_out = [training_state, env_state, seed] - logs = dict(done=rollout.done, reward=rollout.reward) - return carry_out, logs - - def train_iterations(self, carry_in): - # Vectorize the learn function across batch - batched_learn = jax.vmap( - self.learn, - in_axes=([None, 0, 0],), - out_axes=([None, 0, 0], 0), - axis_name="batch", - ) - - # Repeat the training for many iterations - def train_one_iteration(carry, _): - return batched_learn(carry) - - carry_out, logs = lax.scan( - f=train_one_iteration, init=carry_in, length=self.iterations, xs=None - ) - return carry_out, logs - - def train(self): - seed = self.seed - for i, env_name in enumerate(self.env_names): - self.logger.info( - f"<{self.config_idx}> Environment {i+1}/{len(self.env_names)}: {env_name}" - ) - # Generate random seeds for env and agent - seed, env_seed, agent_seed = random.split(seed, 3) - # Set environment and agent network - self.env, self.agent_net = self.envs[i], self.agent_nets[i] - # Initialize agent parameter and optimizer state - dummy_obs = self.env.render_obs(self.env.reset(env_seed))[None, :] - agent_param = self.agent_net.init(agent_seed, dummy_obs) - training_state = TrainingState( - agent_param=agent_param, - agent_optim_state=self.agent_optimizer.init(agent_param), - ) - # Intialize env_states over cores and batch - seed, *env_seeds = random.split(seed, self.core_count * self.batch_size + 1) - env_states = jax.vmap(self.env.reset)(jnp.stack(env_seeds)) - env_states = tree_util.tree_map(self.reshape, env_states) - seed, *step_seeds = random.split( - seed, self.core_count * self.batch_size + 1 - ) - step_seeds = self.reshape(jnp.stack(step_seeds)) - # Replicate the training process over multiple cores - pmap_train_iterations = jax.pmap( - self.train_iterations, - in_axes=([None, 0, 0],), - out_axes=([None, 0, 0], 0), - axis_name="core", - ) - carry_in = [training_state, env_states, step_seeds] - carry_out, logs 
= pmap_train_iterations(carry_in) - # Process and save logs - self.process_logs(env_name, logs) - - def process_logs(self, env_name, logs): - # Move logs to CPU, with shape {[core_count, iterations, batch_size, *]} - logs = jax.device_get(logs) - # Reshape to {[iterations, core_count, batch_size, *]} - for k in logs.keys(): - logs[k] = logs[k].swapaxes(0, 1) - # Compute episode return - episode_return, step_list = self.get_episode_return( - logs["done"], logs["reward"] - ) - result = { - "Env": env_name, - "Agent": self.agent_name, - "Step": step_list * self.macro_step, - "Return": episode_return, - } - # Save logs - self.save_logs(env_name, result) - - def get_episode_return(self, done_list, reward_list): - # Input shape: [iterations, core_count, batch_size, rollout_steps*(inner_updates+1)] - # Reshape to: [batch_size, core_count, iterations*rollout_steps*(inner_updates+1)] - done_list = done_list.swapaxes(0, 2) - done_list = done_list.reshape(done_list.shape[:2] + (-1,)) - reward_list = reward_list.swapaxes(0, 2) - reward_list = reward_list.reshape(reward_list.shape[:2] + (-1,)) - # Compute return - for j in range(1, reward_list.shape[-1]): - reward_list[:, :, j] = reward_list[:, :, j] + reward_list[:, :, j - 1] * ( - 1 - done_list[:, :, j - 1] - ) - return_list = reward_list * done_list - # Shape: [batch_size, core_count, iterations, rollout_steps*(inner_updates+1)] - return_list = return_list.reshape(return_list.shape[:2] + (self.iterations, -1)) - done_list = done_list.reshape(done_list.shape[:2] + (self.iterations, -1)) - # Average over batch, core, and rollout, to shape [iterations] - return_list = return_list.sum(axis=(0, 1, 3)) - done_list = done_list.sum(axis=(0, 1, 3)) - # Get return logs - step_list, episode_return = [], [] - for i in range(self.iterations): - if done_list[i] != 0: - episode_return.append(return_list[i] / done_list[i]) - step_list.append(i) - return np.array(episode_return), np.array(step_list) - - def save_logs(self, env_name, result): - result = pd.DataFrame(result) - result["Env"] = result["Env"].astype("category") - result["Agent"] = result["Agent"].astype("category") - result.to_feather(self.log_path(env_name)) + ''' + Implementation of Actor Critic. 
+ ''' + def __init__(self, cfg): + super().__init__(cfg) + # Set agent optimizer + self.seed, optim_seed = jitted_split(self.seed) + # Set learning_rates for tasks + agent_optim_cfg = deepcopy(cfg['agent_optim']['kwargs']) + if ('MetaA2C' in self.agent_name or 'MetapA2C' in self.agent_name) and 'star' not in self.agent_name: + if isinstance(agent_optim_cfg['learning_rate'], list): + self.learning_rates = agent_optim_cfg['learning_rate'].copy() + elif isinstance(agent_optim_cfg['learning_rate'], float): + self.learning_rates = [agent_optim_cfg['learning_rate']] * self.task_num + else: + raise TypeError('Only List[float] or float is allowed') + agent_optim_cfg['learning_rate'] = 1.0 + self.agent_optim = set_optim(cfg['agent_optim']['name'], agent_optim_cfg, optim_seed) + # Make some utility tools + self.reshape = lambda x: x.reshape((self.core_count, self.batch_size) + x.shape[1:]) + + def create_agent_nets(self): + agent_nets = [] + for i, env_name in enumerate(self.env_names): + agent_net = ActorVCriticNet( + action_size=self.action_sizes[i], + env_name=env_name + ) + agent_nets.append(agent_net) + return agent_nets + + @partial(jit, static_argnames=['self', 'i']) + def move_one_step(self, carry_in, step_seed, i): + agent_param, env_state = carry_in + step_seed, action_seed = jitted_split(step_seed) + obs = self.envs[i].render_obs(env_state) + logits, _ = self.agent_nets[i].apply(agent_param, obs[None,]) + # Select an action + action = random.categorical(key=action_seed, logits=logits[0]) + # Move one step in env + env_state, reward, done = self.envs[i].step(step_seed, env_state, action) + carry_out = (agent_param, env_state) + return carry_out, TimeStep(obs=obs, action=action, reward=reward, done=done) + + @partial(jit, static_argnames=['self', 'i']) + def move_rollout_steps(self, agent_param, env_state, step_seed, i): + carry_in = (agent_param, env_state) + # Move for rollout_steps + step_seeds = jitted_split(step_seed, self.rollout_steps) + carry_out, rollout = lax.scan(f=partial(self.move_one_step, i=i), init=carry_in, xs=step_seeds) + env_state = carry_out[1] + return env_state, rollout + + @partial(jit, static_argnames=['self', 'i']) + def compute_loss(self, agent_param, env_state, step_seed, i): + # Move for rollout_steps + env_state, rollout = self.move_rollout_steps(agent_param, env_state, step_seed, i) + last_obs = self.envs[i].render_obs(env_state) + all_obs = jnp.concatenate([rollout.obs, jnp.expand_dims(last_obs, 0)], axis=0) + logits, v = self.agent_nets[i].apply(agent_param, all_obs) + # Compute multi-step temporal difference error + td_error = rlax.td_lambda( + v_tm1 = v[:-1], + r_t = rollout.reward, + discount_t = self.discount * (1.0-rollout.done), + v_t = v[1:], + lambda_ = self.cfg['agent']['gae_lambda'], + stop_target_gradients = True + ) + # Compute critic loss + critic_loss = self.cfg['agent']['critic_loss_weight'] * jnp.mean(td_error**2) + # Compute actor loss + actor_loss = rlax.policy_gradient_loss( + logits_t = logits[:-1], + a_t = rollout.action, + adv_t = td_error, + w_t = jnp.ones_like(td_error), + use_stop_gradient = True + ) + entropy_loss = self.cfg['agent']['entropy_weight'] * rlax.entropy_loss( + logits_t = logits[:-1], + w_t=jnp.ones_like(td_error) + ) + total_loss = actor_loss + critic_loss + entropy_loss + return total_loss, (env_state, rollout) + + @partial(jit, static_argnames=['self', 'i']) + def learn(self, carry_in, i): + training_state, env_state, seed = carry_in + seed, step_seed = jitted_split(seed) + # Generate one rollout and compute the 
gradient + agent_grad, (env_state, rollout) = jax.grad(self.compute_loss, has_aux=True)(training_state.agent_param, env_state, step_seed, i) + # Reduce mean gradients across batch an cores + agent_grad = lax.pmean(agent_grad, axis_name='batch') + agent_grad = lax.pmean(agent_grad, axis_name='core') + # Compute the updates of model parameters + param_update, new_optim_state = self.agent_optim.update(agent_grad, training_state.agent_optim_state) + # Update model parameters + new_param = optax.apply_updates(training_state.agent_param, param_update) + training_state = training_state.replace( + agent_param = new_param, + agent_optim_state = new_optim_state + ) + carry_out = (training_state, env_state, seed) + logs = dict(done=rollout.done, reward=rollout.reward) + return carry_out, logs + + @partial(jit, static_argnames=['self', 'i']) + def train_iterations(self, carry_in, i): + # Vectorize the learn function across batch + batched_learn = jax.vmap( + self.learn, + in_axes=((None, 0, 0), None), + out_axes=((None, 0, 0), 0), + axis_name='batch' + ) + # Repeat the training for many iterations + train_one_iteration = lambda carry, _: batched_learn(carry, i) + carry_out, logs = lax.scan(f=train_one_iteration, init=carry_in, length=self.iterations, xs=None) + return carry_out, logs + + def train(self): + seed = self.seed + for i in range(self.task_num): + self.logger.info(f'<{self.config_idx}> Task {i+1}/{self.task_num}: {self.env_names[i]}') + # Generate random seeds for env and agent + seed, env_seed, agent_seed = jitted_split(seed, 3) + # Initialize agent parameter and optimizer state + dummy_obs = self.envs[i].render_obs(self.envs[i].reset(env_seed))[None,] + agent_param = self.agent_nets[i].init(agent_seed, dummy_obs) + training_state = MyTrainState( + agent_param = agent_param, + agent_optim_state = self.agent_optim.init(agent_param) + ) + # Intialize env_states over cores and batch + seed, *env_seeds = jitted_split(seed, self.core_count * self.batch_size + 1) + env_states = jax.vmap(self.envs[i].reset)(jnp.stack(env_seeds)) + env_states = tree_util.tree_map(self.reshape, env_states) + seed, *step_seeds = jitted_split(seed, self.core_count * self.batch_size + 1) + step_seeds = self.reshape(jnp.stack(step_seeds)) + # Replicate the training process over multiple cores + pmap_train_iterations = jax.pmap( + self.train_iterations, + in_axes = ((None, 0, 0), None), + out_axes = ((None, 0, 0), 0), + axis_name = 'core', + static_broadcasted_argnums = (1) + ) + carry_in = (training_state, env_states, step_seeds) + carry_out, logs = pmap_train_iterations(carry_in, i) + # Process and save logs + self.process_logs(self.env_names[i], logs) + + def process_logs(self, env_name, logs): + # Move logs to CPU, with shape {[core_count, iterations, batch_size, *]} + logs = jax.device_get(logs) + # Reshape to {[iterations, core_count, batch_size, *]} + for k in logs.keys(): + logs[k] = logs[k].swapaxes(0, 1) + # Compute episode return + episode_return, step_list = self.get_episode_return(logs['done'], logs['reward']) + result = { + 'Env': env_name, + 'Agent': self.agent_name, + 'Step': step_list*self.macro_step, + 'Return': episode_return + } + # Save logs + self.save_logs(env_name, result) + + def get_episode_return(self, done_list, reward_list): + # Input shape: [iterations, core_count, batch_size, rollout_steps*(inner_updates+1)] + # Reshape to: [batch_size, core_count, iterations*rollout_steps*(inner_updates+1)] + done_list = done_list.swapaxes(0, 2) + done_list = done_list.reshape(done_list.shape[:2]+ (-1,)) + 
reward_list = reward_list.swapaxes(0, 2) + reward_list = reward_list.reshape(reward_list.shape[:2]+ (-1,)) + # Compute return + for j in range(1, reward_list.shape[-1]): + reward_list[:,:,j] = reward_list[:,:,j] + reward_list[:,:,j-1] * (1-done_list[:,:,j-1]) + return_list = reward_list * done_list + # Shape: [batch_size, core_count, iterations, rollout_steps*(inner_updates+1)] + return_list = return_list.reshape(return_list.shape[:2]+ (self.iterations, -1)) + done_list = done_list.reshape(done_list.shape[:2]+ (self.iterations, -1)) + # Average over batch, core, and rollout, to shape [iterations] + return_list = return_list.sum(axis=(0,1,3)) + done_list = done_list.sum(axis=(0,1,3)) + # Get return logs + step_list, episode_return = [], [] + for i in range(self.iterations): + if done_list[i] != 0: + episode_return.append(return_list[i]/done_list[i]) + step_list.append(i) + return np.array(episode_return), np.array(step_list) + + def save_logs(self, env_name, result): + result = pd.DataFrame(result) + result['Env'] = result['Env'].astype('category') + result['Agent'] = result['Agent'].astype('category') + result.to_feather(self.log_path(env_name)) \ No newline at end of file diff --git a/agents/A2C2.py b/agents/A2C2.py deleted file mode 100644 index c200104..0000000 --- a/agents/A2C2.py +++ /dev/null @@ -1,50 +0,0 @@ -# Copyright 2022 Garena Online Private Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import jax -from jax import lax, random - -from agents.A2C import A2C - - -class A2C2(A2C): - """ - Implementation of A2C for gridworlds, only compatible with STAR. - """ - - def __init__(self, cfg): - super().__init__(cfg) - - def learn(self, carry_in): - training_state, env_state, seed = carry_in - seed, step_seed = random.split(seed) - # Generate one rollout and compute the gradient - (agent_loss, (env_state, rollout)), agent_grad = jax.value_and_grad( - self.compute_agent_loss, has_aux=True - )(training_state.agent_param, env_state, step_seed) - # Reduce mean gradients across batch an cores - agent_grad = lax.pmean(agent_grad, axis_name="batch") - agent_grad = lax.pmean(agent_grad, axis_name="core") - # Update model parameters - agent_optim_state = self.agent_optimizer.update( - agent_grad, training_state.agent_optim_state, agent_loss - ) - # Set new training_state - training_state = training_state.replace( - agent_param=training_state.agent_optim_state.params, - agent_optim_state=agent_optim_state, - ) - carry_out = [training_state, env_state, seed] - logs = dict(done=rollout.done, reward=rollout.reward) - return carry_out, logs diff --git a/agents/A2Ccollect.py b/agents/A2Ccollect.py new file mode 100644 index 0000000..932b7dc --- /dev/null +++ b/agents/A2Ccollect.py @@ -0,0 +1,154 @@ +# Copyright 2024 Garena Online Private Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import optax +import numpy as np +from functools import partial +import matplotlib.pyplot as plt + +import jax +import jax.numpy as jnp +from jax import jit, lax, tree_util + +from utils.helper import jitted_split, pytree2array +from agents.A2C import A2C, MyTrainState + + + +class A2Ccollect(A2C): + ''' + Collect agent gradient in A2C. + ''' + def __init__(self, cfg): + super().__init__(cfg) + + @partial(jit, static_argnames=['self', 'i']) + def learn(self, carry_in, i): + training_state, env_state, seed = carry_in + seed, step_seed = jitted_split(seed) + # Generate one rollout and compute the gradient + agent_grad, (env_state, rollout) = jax.grad(self.compute_loss, has_aux=True)(training_state.agent_param, env_state, step_seed, i) + # Reduce mean gradients across batch an cores + agent_grad = lax.pmean(agent_grad, axis_name='batch') + agent_grad = lax.pmean(agent_grad, axis_name='core') + # Compute the updates of model parameters + param_update, new_optim_state = self.agent_optim.update(agent_grad, training_state.agent_optim_state) + # Update model parameters + new_param = optax.apply_updates(training_state.agent_param, param_update) + training_state = training_state.replace( + agent_param = new_param, + agent_optim_state = new_optim_state + ) + carry_out = (training_state, env_state, seed) + # Choose part of agent_grad due to memory limit + agent_grad = pytree2array(agent_grad) + idxs = jnp.array(range(0, len(agent_grad), self.cfg['agent']['data_reduce'])) + agent_grad = agent_grad[idxs] + logs = dict(done=rollout.done, reward=rollout.reward) + return carry_out, (logs, agent_grad) + + @partial(jit, static_argnames=['self', 'i']) + def train_iterations(self, carry_in, i): + # Vectorize the learn function across batch + batched_learn = jax.vmap( + self.learn, + in_axes=((None, 0, 0), None), + out_axes=((None, 0, 0), (0, None)), + axis_name='batch' + ) + # Repeat the training for many iterations + train_one_iteration = lambda carry, _: batched_learn(carry, i) + carry_out, logs = lax.scan(f=train_one_iteration, init=carry_in, length=self.iterations, xs=None) + return carry_out, logs + + def train(self): + seed = self.seed + for i in range(self.task_num): + self.logger.info(f'<{self.config_idx}> Task {i+1}/{self.task_num}: {self.env_names[i]}') + # Generate random seeds for env and agent + seed, env_seed, agent_seed = jitted_split(seed, 3) + # Initialize agent parameter and optimizer state + dummy_obs = self.envs[i].render_obs(self.envs[i].reset(env_seed))[None,] + agent_param = self.agent_nets[i].init(agent_seed, dummy_obs) + training_state = MyTrainState( + agent_param = agent_param, + agent_optim_state = self.agent_optim.init(agent_param) + ) + # Intialize env_states over cores and batch + seed, *env_seeds = jitted_split(seed, self.core_count * self.batch_size + 1) + env_states = jax.vmap(self.envs[i].reset)(jnp.stack(env_seeds)) + env_states = tree_util.tree_map(self.reshape, env_states) + seed, *step_seeds = jitted_split(seed, self.core_count * self.batch_size + 1) + step_seeds = self.reshape(jnp.stack(step_seeds)) + # Replicate the training process over multiple cores + 
pmap_train_iterations = jax.pmap( + self.train_iterations, + in_axes = ((None, 0, 0), None), + out_axes = ((None, 0, 0), 0), + axis_name = 'core', + static_broadcasted_argnums = (1) + ) + carry_in = (training_state, env_states, step_seeds) + carry_out, logs = pmap_train_iterations(carry_in, i) + # Process and save logs + return_logs, agent_grad = logs + return_logs['agent_grad'] = agent_grad + self.process_logs(self.env_names[i], return_logs) + + def process_logs(self, env_name, logs): + # Move logs to CPU, with shape {[core_count, iterations, batch_size, *]} + logs = jax.device_get(logs) + # Reshape to {[iterations, core_count, batch_size, *]} + for k in logs.keys(): + logs[k] = logs[k].swapaxes(0, 1) + # Compute episode return + episode_return, step_list = self.get_episode_return(logs['done'], logs['reward']) + result = { + 'Env': env_name, + 'Agent': self.agent_name, + 'Step': step_list*self.macro_step, + 'Return': episode_return + } + # Save logs + self.save_logs(env_name, result) + # Save agent_grad: (num_param, optimization_steps) + self.logger.info(f"# of agent_param collected for {env_name}: {logs['agent_grad'].shape[0]}") + np.savez(self.cfg['logs_dir']+'data.npz', x=logs['agent_grad']) + # Print some grad statistics + grad = logs['agent_grad'].reshape(-1) + log_abs_grad = np.log10(np.abs(grad)+1e-8) + self.logger.info(f'g: min = {grad.min():.4f}, max = {grad.max():.4f}, mean = {grad.mean():.4f}') + self.logger.info(f'log(|g|+1e-8): min = {log_abs_grad.min():.4f}, max = {log_abs_grad.max():.4f}, mean = {log_abs_grad.mean():.4f}') + # Plot grad + fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5), tight_layout=True) + ax1.hist(grad, bins=40, density=False) + ax1.set_yscale('log') + ax1.set_xlabel('$g$', fontsize=18) + ax1.set_ylabel('log(counts)', fontsize=18) + ax1.grid(True) + # Plot log(|grad|) + ax2.hist(log_abs_grad, bins=list(np.arange(-9, 5, 0.5)), density=True) + ax2.set_xlim(-9, 5) + ax2.set_xticks(list(np.arange(-9, 5, 1))) + ax2.set_xlabel('$\log(|g|+10^{-8})$', fontsize=18) + ax2.set_ylabel('Probability density', fontsize=18) + ax2.grid(True) + # Adjust figure layout + plt.tick_params(axis='both', which='major', labelsize=14) + fig.tight_layout() + # Save figure + plt.savefig(self.cfg['logs_dir']+'grad.png') + plt.clf() + plt.cla() + plt.close() \ No newline at end of file diff --git a/agents/A2Cstar.py b/agents/A2Cstar.py new file mode 100644 index 0000000..df2bda4 --- /dev/null +++ b/agents/A2Cstar.py @@ -0,0 +1,49 @@ +# Copyright 2024 Garena Online Private Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import jax +from jax import jit, lax + +from functools import partial + +from utils.helper import jitted_split +from agents.A2C import A2C + + +class A2Cstar(A2C): + ''' + Implementation of Actor Critic for Star optimizer only. 
+ ''' + def __init__(self, cfg): + super().__init__(cfg) + + @partial(jit, static_argnames=['self', 'i']) + def learn(self, carry_in, i): + training_state, env_state, seed = carry_in + seed, step_seed = jitted_split(seed) + # Generate one rollout and compute the gradient + (agent_loss, (env_state, rollout)), agent_grad = jax.value_and_grad(self.compute_loss, has_aux=True)(training_state.agent_param, env_state, step_seed, i) + # Reduce mean gradients across batch an cores + agent_grad = lax.pmean(agent_grad, axis_name='batch') + agent_grad = lax.pmean(agent_grad, axis_name='core') + # Update model parameters + new_optim_state = self.agent_optim.update(agent_grad, training_state.agent_optim_state, agent_loss) + # Set new training_state + training_state = training_state.replace( + agent_param = new_optim_state.params, + agent_optim_state = new_optim_state + ) + carry_out = (training_state, env_state, seed) + logs = dict(done=rollout.done, reward=rollout.reward) + return carry_out, logs \ No newline at end of file diff --git a/agents/BaseAgent.py b/agents/BaseAgent.py index cdb4798..0a426f0 100644 --- a/agents/BaseAgent.py +++ b/agents/BaseAgent.py @@ -1,4 +1,4 @@ -# Copyright 2022 Garena Online Private Limited. +# Copyright 2024 Garena Online Private Limited. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,20 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -import jax -import pickle import numpy as np -import jax.numpy as jnp from copy import deepcopy -from jax import lax, random, tree_util + +import jax +from jax import random from utils.logger import Logger from envs.utils import make_env from envs.spaces import Box, Discrete -from gymnax.environments.spaces import Box as gymnax_Box -from gymnax.environments.spaces import Discrete as gymnax_Discrete - class BaseAgent(object): def __init__(self, cfg): @@ -42,12 +38,12 @@ def __init__(self, cfg): del cfg['env']['name'] # Make envs self.envs = [] - for i in range(len(self.env_names)): - env_name = self.env_names[i] + self.task_num = len(self.env_names) + for i in range(self.task_num): env_cfg = deepcopy(cfg['env']) if 'reward_scaling' in env_cfg.keys() and isinstance(env_cfg['reward_scaling'], list): env_cfg['reward_scaling'] = env_cfg['reward_scaling'][i] - self.envs.append(make_env(env_name, env_cfg)) + self.envs.append(make_env(self.env_names[i], env_cfg)) # Get action_types, action_sizes, and state_sizes self.get_env_info() # Create agent networks @@ -78,33 +74,16 @@ def get_env_info(self): self.action_types, self.action_sizes, self.state_sizes = [], [], [] for env in self.envs: # Get state info - if isinstance(env.observation_space, Discrete) or isinstance(env.observation_space, gymnax_Discrete): - self.state_sizes.append(env.observation_space.n) + if isinstance(env.obs_space, Discrete): + self.state_sizes.append(env.obs_space.n) else: # Box, MultiBinary - self.state_sizes.append(int(np.prod(env.observation_space.shape))) + self.state_sizes.append(int(np.prod(env.obs_space.shape))) # Get action info - if isinstance(env.action_space, Discrete) or isinstance(env.action_space, gymnax_Discrete): + if isinstance(env.action_space, Discrete): self.action_types.append('DISCRETE') self.action_sizes.append(env.action_space.n) - elif isinstance(env.action_space, Box) or isinstance(env.action_space, gymnax_Box): + elif isinstance(env.action_space, Box): self.action_types.append('CONTINUOUS') 
self.action_sizes.append(env.action_space.shape[0]) else: - raise ValueError('Unknown action type.') - - def pytree2array(self, values): - leaves = tree_util.tree_leaves(lax.stop_gradient(values)) - a = jnp.concatenate(leaves, axis=None) - return a - - def save_model_param(self, model_param, filepath): - f = open(filepath, 'wb') - pickle.dump(model_param, f) - f.close() - - def load_model_param(self, filepath): - f = open(filepath, 'rb') - model_param = pickle.load(f) - model_param = tree_util.tree_map(jnp.array, model_param) - f.close() - return model_param \ No newline at end of file + raise ValueError('Unknown action type.') \ No newline at end of file diff --git a/agents/CollectA2C.py b/agents/CollectA2C.py deleted file mode 100644 index c20ecc7..0000000 --- a/agents/CollectA2C.py +++ /dev/null @@ -1,188 +0,0 @@ -# Copyright 2022 Garena Online Private Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import jax -import jax.numpy as jnp -import matplotlib.pyplot as plt -import numpy as np -import optax -from jax import lax, random, tree_util - -from agents.A2C import A2C, TrainingState - - -class CollectA2C(A2C): - """ - Collect agent gradients and parameter updates during training A2C in gridworlds. - """ - - def __init__(self, cfg): - super().__init__(cfg) - - def learn(self, carry_in): - training_state, env_state, seed = carry_in - seed, step_seed = random.split(seed) - # Generate one rollout and compute the gradient - agent_grad, (env_state, rollout) = jax.grad( - self.compute_agent_loss, has_aux=True - )(training_state.agent_param, env_state, step_seed) - # Reduce mean gradients across batch an cores - agent_grad = lax.pmean(agent_grad, axis_name="batch") - agent_grad = lax.pmean(agent_grad, axis_name="core") - # Compute the updates of model parameters - agent_param_update, agent_optim_state = self.agent_optimizer.update( - agent_grad, training_state.agent_optim_state - ) - # Update model parameters - agent_param = optax.apply_updates( - training_state.agent_param, agent_param_update - ) - training_state = training_state.replace( - agent_param=agent_param, agent_optim_state=agent_optim_state - ) - carry_out = [training_state, env_state, seed] - # Pick agent_grad and agent_param_update for a few parameter - agent_grad = self.pytree2array(agent_grad) - agent_param_update = ( - self.pytree2array(agent_param_update) - / self.cfg["agent_optimizer"]["kwargs"]["learning_rate"] - ) - idxs = jnp.array(range(0, len(agent_grad), self.cfg["agent"]["data_reduce"])) - agent_grad, agent_param_update = agent_grad[idxs], agent_param_update[idxs] - return_logs = dict(done=rollout.done, reward=rollout.reward) - grad_logs = dict(agent_grad=agent_grad, agent_param_update=agent_param_update) - return carry_out, (return_logs, grad_logs) - - def train_iterations(self, carry_in): - # Vectorize the learn function across batch - batched_learn = jax.vmap( - self.learn, - in_axes=([None, 0, 0],), - out_axes=([None, 0, 0], (0, None)), - axis_name="batch", - ) - - # Repeat the training for many iterations 
- def train_one_iteration(carry, _): - return batched_learn(carry) - - carry_out, logs = lax.scan( - f=train_one_iteration, init=carry_in, length=self.iterations, xs=None - ) - return carry_out, logs - - def train(self): - seed = self.seed - for i, env_name in enumerate(self.env_names): - self.logger.info( - f"<{self.config_idx}> Environment {i+1}/{len(self.env_names)}: {env_name}" - ) - # Generate random seeds for env and agent - seed, env_seed, agent_seed = random.split(seed, 3) - # Set environment and agent network - self.env, self.agent_net = self.envs[i], self.agent_nets[i] - # Initialize agent parameter and optimizer state - dummy_obs = self.env.render_obs(self.env.reset(env_seed))[None, :] - agent_param = self.agent_net.init(agent_seed, dummy_obs) - self.logger.info( - f"# of agent_param for {env_name}: {self.pytree2array(agent_param).size}" - ) - training_state = TrainingState( - agent_param=agent_param, - agent_optim_state=self.agent_optimizer.init(agent_param), - ) - # Intialize env_states over cores and batch - seed, *env_seeds = random.split(seed, self.core_count * self.batch_size + 1) - env_states = jax.vmap(self.env.reset)(jnp.stack(env_seeds)) - env_states = tree_util.tree_map(self.reshape, env_states) - seed, *step_seeds = random.split( - seed, self.core_count * self.batch_size + 1 - ) - step_seeds = self.reshape(jnp.stack(step_seeds)) - # Replicate the training process over multiple cores - pmap_train_iterations = jax.pmap( - self.train_iterations, - in_axes=([None, 0, 0],), - out_axes=([None, 0, 0], (0, None)), - axis_name="core", - ) - carry_in = [training_state, env_states, step_seeds] - carry_out, logs = pmap_train_iterations(carry_in) - # Process and save logs - return_logs, grad_logs = logs - return_logs["agent_grad"] = grad_logs["agent_grad"] - return_logs["agent_param_update"] = grad_logs["agent_param_update"] - self.process_logs(env_name, return_logs) - - def process_logs(self, env_name, logs): - # Move logs to CPU, with shape {[core_count, iterations, batch_size, *]} - logs = jax.device_get(logs) - # Reshape to {[iterations, core_count, batch_size, *]} - for k in logs.keys(): - logs[k] = logs[k].swapaxes(0, 1) - # Compute episode return - episode_return, step_list = self.get_episode_return( - logs["done"], logs["reward"] - ) - result = { - "Env": env_name, - "Agent": self.agent_name, - "Step": step_list * self.macro_step, - "Return": episode_return, - } - # Save logs - self.save_logs(env_name, result) - # Save agent_grad and agent_param_update into a npz file: (num_param, optimization_steps) - self.logger.info( - f"# of agent_param collected for {env_name}: {logs['agent_grad'].shape[0]}" - ) - x = logs["agent_grad"] - y = logs["agent_param_update"] - np.savez(self.cfg["logs_dir"] + "data.npz", x=x, y=y) - # Plot - grad = x.reshape(-1) - abs_update = np.abs(y).reshape(-1) - # Plot log(|g|) - abs_grad = np.abs(grad) - self.logger.info( - f"|g|: min = {abs_grad.min():.4f}, max = {abs_grad.max():.4f}, mean = {abs_grad.mean():.4f}" - ) - log_abs_grad = np.log10(abs_grad + 1e-16) - self.logger.info( - f"log(|g|+1e-16): min = {log_abs_grad.min():.4f}, max = {log_abs_grad.max():.4f}, mean = {log_abs_grad.mean():.4f}" - ) - num, bins, patches = plt.hist(log_abs_grad, bins=20) - plt.xlabel(r"$\log(|g|+10^{-16})$") - plt.ylabel("Counts in the bin") - plt.grid(True) - plt.savefig(self.cfg["logs_dir"] + "grad.png") - plt.clf() - plt.cla() - plt.close() - # Plot log(|update|) - self.logger.info( - f"|update|: min = {abs_update.min():.4f}, max = {abs_update.max():.4f}, mean = 
{abs_update.mean():.4f}" - ) - log_abs_update = np.log10(abs_update + 1e-16) - self.logger.info( - f"log(|update|): min = {log_abs_update.min():.4f}, max = {log_abs_update.max():.4f}, mean = {log_abs_update.mean():.4f}" - ) - num, bins, patches = plt.hist(log_abs_update, bins=20) - plt.xlabel(r"$\log(|\Delta \theta|+10^{-16})$") - plt.ylabel("Counts in the bin") - plt.grid(True) - plt.savefig(self.cfg["logs_dir"] + "update.png") - plt.clf() - plt.cla() - plt.close() diff --git a/agents/CollectPPO.py b/agents/CollectPPO.py deleted file mode 100644 index 0505619..0000000 --- a/agents/CollectPPO.py +++ /dev/null @@ -1,473 +0,0 @@ -# Copyright 2022 Garena Online Private Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Copyright 2022 The Brax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import functools -import time -from typing import Tuple - -import flax -import jax -import jax.numpy as jnp -import matplotlib.pyplot as plt -import numpy as np -import optax -import pandas as pd -from brax import envs -from brax import jumpy as jp -from brax.envs import wrappers -from brax.training import acting, types -from brax.training.acme import specs -from jax import lax -from jax.tree_util import tree_leaves, tree_map - -import components.losses as ppo_losses -from components import gradients, ppo_networks, running_statistics -from components.optim import set_optimizer -from utils.logger import Logger - -InferenceParams = Tuple[running_statistics.NestedMeanStd, types.Params] - - -@flax.struct.dataclass -class TrainingState: - """Contains training state for the learner.""" - - optimizer_state: optax.OptState - params: ppo_losses.PPONetworkParams - normalizer_param: running_statistics.RunningStatisticsState - env_step: jnp.ndarray - - -def pytree2array(values): - leaves = tree_leaves(lax.stop_gradient(values)) - a = jnp.concatenate(leaves, axis=None) - return a - - -class CollectPPO(object): - """ - Collect agent gradients and parameter updates during training PPO in Brax. 
- """ - - def __init__(self, cfg): - self.cfg = cfg - self.config_idx = cfg["config_idx"] - self.logger = Logger(cfg["logs_dir"]) - self.log_path = cfg["logs_dir"] + "result_Test.feather" - self.result = [] - # Set environment - self.env_name = cfg["env"]["name"] - self.agent_name = cfg["agent"]["name"] - self.train_steps = int(cfg["env"]["train_steps"]) - self.env = envs.get_environment(env_name=self.env_name) - self.state = self.env.reset(rng=jp.random_prngkey(seed=self.cfg["seed"])) - # Timing - self.start_time = time.time() - self._PMAP_AXIS_NAME = "i" - - def save_progress(self, step_count, metrics): - episode_return = float(jax.device_get(metrics["eval/episode_reward"])) - result_dict = { - "Env": self.env_name, - "Agent": self.agent_name, - "Step": step_count, - "Return": episode_return, - } - self.result.append(result_dict) - # Save result to files - result = pd.DataFrame(self.result) - result["Env"] = result["Env"].astype("category") - result["Agent"] = result["Agent"].astype("category") - result.to_feather(self.log_path) - # Show log - speed = step_count / (time.time() - self.start_time) - eta = (self.train_steps - step_count) / speed / 60 if speed > 0 else -1 - return episode_return, speed, eta - - def train(self): - # Env - env = self.env - num_timesteps = self.train_steps - episode_length = self.cfg["env"]["episode_length"] - action_repeat = self.cfg["env"]["action_repeat"] - reward_scaling = self.cfg["env"]["reward_scaling"] - num_envs = self.cfg["env"]["num_envs"] - num_evals = self.cfg["env"]["num_evals"] - num_eval_envs = 128 - normalize_observations = self.cfg["env"]["normalize_obs"] - # Agent - network_factory = ppo_networks.make_ppo_networks - gae_lambda = self.cfg["agent"]["gae_lambda"] - unroll_length = self.cfg["agent"]["rollout_steps"] - num_minibatches = self.cfg["agent"]["num_minibatches"] - clip_ratio = self.cfg["agent"]["clip_ratio"] - update_epochs = self.cfg["agent"]["update_epochs"] - entropy_cost = self.cfg["agent"]["entropy_weight"] - normalize_advantage = True - # Optimization - batch_size = self.cfg["batch_size"] - discounting = self.cfg["discount"] - max_devices_per_host = self.cfg["max_devices_per_host"] - # Others - seed = self.cfg["seed"] - eval_env = None - deterministic_eval = False - progress_fn = self.save_progress - logs = dict(agent_grad=[], agent_param_update=[]) - - """PPO training.""" - process_id = jax.process_index() - process_count = jax.process_count() - total_device_count = jax.device_count() - local_device_count = jax.local_device_count() - if max_devices_per_host is not None and max_devices_per_host > 0: - local_devices_to_use = min(local_device_count, max_devices_per_host) - else: - local_devices_to_use = local_device_count - self.logger.info( - f"Total device: {total_device_count}, Process: {process_count} (ID {process_id})" - ) - self.logger.info( - f"Local device: {local_device_count}, Devices to be used: {local_devices_to_use}" - ) - device_count = local_devices_to_use * process_count - assert num_envs % device_count == 0 - assert batch_size * num_minibatches % num_envs == 0 - - # The number of environment steps executed for every training step. - env_step_per_training_step = ( - batch_size * unroll_length * num_minibatches * action_repeat - ) - num_evals = max(num_evals, 1) - # The number of training_step calls per training_epoch call. 
- num_training_steps_per_epoch = num_timesteps // ( - num_evals * env_step_per_training_step - ) - self.logger.info( - f"num_minibatches={num_minibatches}, update_epochs={update_epochs}, num_training_steps_per_epoch={num_training_steps_per_epoch}, num_evals={num_evals}" - ) - - # Prepare keys - # key_networks should be global so that - # the initialized networks are the same for different processes. - key = jax.random.PRNGKey(seed) - global_key, local_key = jax.random.split(key) - local_key = jax.random.fold_in(local_key, process_id) - local_key, key_env, eval_key = jax.random.split(local_key, 3) - key_policy, key_value, key_optim = jax.random.split(global_key, 3) - del key, global_key - key_envs = jax.random.split(key_env, num_envs // process_count) - key_envs = jnp.reshape( - key_envs, (local_devices_to_use, -1) + key_envs.shape[1:] - ) - - # Set training and evaluation env - env = wrappers.wrap_for_training( - env, episode_length=episode_length, action_repeat=action_repeat - ) - reset_fn = jax.jit(jax.vmap(env.reset)) - env_states = reset_fn(key_envs) - if eval_env is None: - eval_env = env - else: - eval_env = wrappers.wrap_for_training( - eval_env, episode_length=episode_length, action_repeat=action_repeat - ) - - # Set optimizer - optimizer = set_optimizer( - self.cfg["optimizer"]["name"], self.cfg["optimizer"]["kwargs"], key_optim - ) - - # Set PPO network - if normalize_observations: - normalize = running_statistics.normalize - else: - normalize = lambda x, y: x - ppo_network = network_factory( - env.observation_size, - env.action_size, - preprocess_observations_fn=normalize, - ) - make_policy = ppo_networks.make_inference_fn(ppo_network) - init_params = ppo_losses.PPONetworkParams( - policy=ppo_network.policy_network.init(key_policy), - value=ppo_network.value_network.init(key_value), - ) - training_state = TrainingState( - optimizer_state=optimizer.init(init_params), - params=init_params, - normalizer_param=running_statistics.init_state( - specs.Array((env.observation_size,), jnp.float32) - ), - env_step=0, - ) - - # Set loss function - loss_fn = functools.partial( - ppo_losses.compute_ppo_loss, - ppo_network=ppo_network, - entropy_cost=entropy_cost, - discounting=discounting, - reward_scaling=reward_scaling, - gae_lambda=gae_lambda, - clip_ratio=clip_ratio, - normalize_advantage=normalize_advantage, - ) - gradient_update_fn = gradients.gradient_update_fn( - loss_fn, optimizer, pmap_axis_name=self._PMAP_AXIS_NAME, has_aux=True - ) - - def convert_data(x: jnp.ndarray, key: types.PRNGKey): - x = jax.random.permutation(key, x) - x = jnp.reshape(x, (num_minibatches, -1) + x.shape[1:]) - return x - - def minibatch_step( - carry, - data: types.Transition, - normalizer_param: running_statistics.RunningStatisticsState, - ): - optimizer_state, params, key = carry - key, key_loss = jax.random.split(key) - ( - (loss, _), - params, - optimizer_state, - grads, - params_update, - ) = gradient_update_fn( - params, - normalizer_param, - data, - key_loss, - optimizer_state=optimizer_state, - ) - grads = pytree2array(grads) - params_update = pytree2array(params_update) - idxs = jnp.array(range(0, len(grads), self.cfg["agent"]["data_reduce"])) - grads, params_update = grads[idxs], params_update[idxs] - return (optimizer_state, params, key), (grads, params_update) - - def sgd_step( - carry, - unused_t, - data: types.Transition, - normalizer_param: running_statistics.RunningStatisticsState, - ): - optimizer_state, params, key = carry - key, key_perm, key_grad = jax.random.split(key, 3) - 
shuffled_data = tree_map( - functools.partial(convert_data, key=key_perm), data - ) - (optimizer_state, params, key_grad), (grads, params_update) = lax.scan( - f=functools.partial(minibatch_step, normalizer_param=normalizer_param), - init=(optimizer_state, params, key_grad), - xs=shuffled_data, - length=num_minibatches, - ) - return (optimizer_state, params, key), (grads, params_update) - - def training_step( - carry: Tuple[TrainingState, envs.State, types.PRNGKey], unused_t - ) -> Tuple[Tuple[TrainingState, envs.State, types.PRNGKey], types.Metrics]: - training_state, state, key = carry - key_sgd, key_generate_unroll, new_key = jax.random.split(key, 3) - policy = make_policy( - (training_state.normalizer_param, training_state.params.policy) - ) - - # Set rollout function - def rollout(carry, unused_t): - current_state, current_key = carry - current_key, next_key = jax.random.split(current_key) - next_state, data = acting.generate_unroll( - env, - current_state, - policy, - current_key, - unroll_length, - extra_fields=("truncation",), - ) - return (next_state, next_key), data - - # Rollout for `batch_size * num_minibatches * unroll_length` steps - (state, _), data = lax.scan( - f=rollout, - init=(state, key_generate_unroll), - xs=None, - length=batch_size * num_minibatches // num_envs, - ) - # shape = (batch_size * num_minibatches // num_envs, unroll_length, num_envs) - data = tree_map(lambda x: jnp.swapaxes(x, 1, 2), data) - # shape = (batch_size * num_minibatches // num_envs, num_envs, unroll_length) - data = tree_map(lambda x: jnp.reshape(x, (-1,) + x.shape[2:]), data) - # shape = (batch_size * num_minibatches, unroll_length) - assert data.discount.shape[1:] == (unroll_length,) - # Update normalization params and normalize observations. - normalizer_param = running_statistics.update( - training_state.normalizer_param, - data.observation, - pmap_axis_name=self._PMAP_AXIS_NAME, - ) - # SGD steps - (optimizer_state, params, key_sgd), (grads, params_update) = lax.scan( - f=functools.partial( - sgd_step, data=data, normalizer_param=normalizer_param - ), - init=(training_state.optimizer_state, training_state.params, key_sgd), - xs=None, - length=update_epochs, - ) - # Set the new training state - new_training_state = TrainingState( - optimizer_state=optimizer_state, - params=params, - normalizer_param=normalizer_param, - env_step=training_state.env_step + env_step_per_training_step, - ) - return (new_training_state, state, new_key), (grads, params_update) - - def training_epoch( - training_state: TrainingState, state: envs.State, key: types.PRNGKey - ) -> Tuple[TrainingState, envs.State, types.Metrics]: - (training_state, state, _), (grads, params_update) = lax.scan( - f=training_step, - init=(training_state, state, key), - xs=None, - length=num_training_steps_per_epoch, - ) - return training_state, state, (grads, params_update) - - pmap_training_epoch = jax.pmap( - training_epoch, - in_axes=(None, 0, 0), - out_axes=(None, 0, None), - devices=jax.local_devices()[:local_devices_to_use], - axis_name=self._PMAP_AXIS_NAME, - ) - - # Set evaluator - evaluator = acting.Evaluator( - eval_env, - functools.partial(make_policy, deterministic=deterministic_eval), - num_eval_envs=num_eval_envs, - episode_length=episode_length, - action_repeat=action_repeat, - key=eval_key, - ) - - # Run an initial evaluation - i, current_step = 0, 0 - if process_id == 0 and num_evals > 1: - metrics = evaluator.run_evaluation( - (training_state.normalizer_param, training_state.params.policy), - training_metrics={}, - ) - 
episode_return, speed, eta = progress_fn(0, metrics) - self.logger.info( - f"<{self.config_idx}> Iteration {i}/{num_evals}, Step {current_step}, Return={episode_return:.2f}, Speed={speed:.2f} (steps/s), ETA={eta:.2f} (mins)" - ) - - # Start training - for i in range(1, num_evals + 1): - epoch_key, local_key = jax.random.split(local_key) - epoch_keys = jax.random.split(epoch_key, local_devices_to_use) - # Train for one epoch - training_state, env_states, (grads, params_update) = pmap_training_epoch( - training_state, env_states, epoch_keys - ) - current_step = int(training_state.env_step) - # Run evaluation - if process_id == 0: - metrics = evaluator.run_evaluation( - (training_state.normalizer_param, training_state.params.policy), - training_metrics={}, - ) - episode_return, speed, eta = progress_fn(current_step, metrics) - self.logger.info( - f"<{self.config_idx}> Iteration {i}/{num_evals}, Step {current_step}, Return={episode_return:.2f}, Speed={speed:.2f} (steps/s), ETA={eta:.2f} (mins)" - ) - # Save grads, params_update - logs["agent_grad"].append(grads) - logs["agent_param_update"].append(params_update) - self.process_logs(logs) - - def process_logs(self, logs): - # Stack to shape {[iterations, num_training_steps_per_epoch, update_epochs, num_minibatches, *]} - logs = {k: jnp.stack(v, axis=0) for k, v in logs.items()} - # Reshape to {[iterations * num_training_steps_per_epoch * update_epochs * num_minibatches, *]} - logs = tree_map(lambda x: jnp.reshape(x, (-1, x.shape[-1])), logs) - # Move logs to CPU, with shape {[core_count, iterations, batch_size, *]} - logs = jax.device_get(logs) - # Shape to: (num_param, optimization_steps) - # Save agent_grad and agent_param_update into a npz file - x = logs["agent_grad"].swapaxes(0, 1) - y = ( - logs["agent_param_update"].swapaxes(0, 1) - / self.cfg["optimizer"]["kwargs"]["learning_rate"] - ) - self.logger.info(f"x: {x.shape}, y: {y.shape}") - np.savez(self.cfg["logs_dir"] + "data.npz", x=x, y=y) - - # Plot - grad = x.reshape(-1) - abs_update = np.abs(y).reshape(-1) - # Plot log(|g|) - abs_grad = np.abs(grad) - self.logger.info( - f"|g|: min = {abs_grad.min():.4f}, max = {abs_grad.max():.4f}, mean = {abs_grad.mean():.4f}" - ) - log_abs_grad = np.log10(abs_grad + 1e-16) - self.logger.info( - f"log(|g|+1e-16): min = {log_abs_grad.min():.4f}, max = {log_abs_grad.max():.4f}, mean = {log_abs_grad.mean():.4f}" - ) - num, bins, patches = plt.hist(log_abs_grad, bins=20) - plt.xlabel(r"$\log(|g|+10^{-16})$") - plt.ylabel("Counts in the bin") - plt.grid(True) - plt.savefig(self.cfg["logs_dir"] + "grad.png") - plt.clf() - plt.cla() - plt.close() - # Plot log(|update|) - self.logger.info( - f"|update|: min = {abs_update.min():.4f}, max = {abs_update.max():.4f}, mean = {abs_update.mean():.4f}" - ) - log_abs_update = np.log10(abs_update + 1e-16) - self.logger.info( - f"log(|update|): min = {log_abs_update.min():.4f}, max = {log_abs_update.max():.4f}, mean = {log_abs_update.mean():.4f}" - ) - num, bins, patches = plt.hist(log_abs_update, bins=20) - plt.xlabel(r"$\log(|\Delta \theta|+10^{-16})$") - plt.ylabel("Counts in the bin") - plt.grid(True) - plt.savefig(self.cfg["logs_dir"] + "update.png") - plt.clf() - plt.cla() - plt.close() diff --git a/agents/MetaA2C.py b/agents/MetaA2C.py index 6d28c21..26c6018 100644 --- a/agents/MetaA2C.py +++ b/agents/MetaA2C.py @@ -1,4 +1,4 @@ -# Copyright 2022 Garena Online Private Limited. +# Copyright 2024 Garena Online Private Limited. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,308 +12,232 @@ # See the License for the specific language governing permissions and # limitations under the License. -import functools import time +import optax +import numpy as np +from functools import partial import jax import jax.numpy as jnp -import numpy as np -import optax -import rlax -from jax import lax, random, tree_util +from jax import jit, lax, tree_util -from agents.A2C import A2C, TrainingState -from components.optim import set_meta_optimizer -from utils.helper import tree_transpose +from components.optim import set_optim +from utils.helper import jitted_split, tree_transpose, save_model_param +from agents.A2C import A2C, MyTrainState class MetaA2C(A2C): + ''' + Implementation of Meta A2C + ''' + def __init__(self, cfg): + super().__init__(cfg) + # Set meta optimizer + self.seed, optim_seed = jitted_split(self.seed) + self.cfg['meta_optim']['kwargs'].setdefault('max_norm', -1) + self.max_norm = self.cfg['meta_optim']['kwargs']['max_norm'] + del self.cfg['meta_optim']['kwargs']['max_norm'] + self.meta_optim = set_optim(self.cfg['meta_optim']['name'], cfg['meta_optim']['kwargs'], optim_seed) + # Set reset_indexes + if isinstance(cfg['agent']['reset_interval'], int): + self.reset_intervals = [cfg['agent']['reset_interval']] * self.task_num + elif isinstance(cfg['agent']['reset_interval'], list): + self.reset_intervals = cfg['agent']['reset_interval'].copy() + else: + raise TypeError('Only List[int] or int is allowed') + self.reset_indexes = [None]*self.task_num + for i in range(self.task_num): + reset_indexes = [int(x) for x in jnp.linspace(0, self.reset_intervals[i]-1, num=self.num_envs)] + self.reset_indexes[i] = self.reshape(jnp.array(reset_indexes)) + + def abs_sq(self, x: jax.Array) -> jax.Array: + """Returns the squared norm of a (maybe complex) array. + Copy from https://github.com/deepmind/optax/blob/master/optax/_src/numerics.py """ - Meta-train a learned optimizer during traing A2C in gridworlds, compatible with LinearOptim, Optim4RL, and L2LGD2. 
- """ - - def __init__(self, cfg): - super().__init__(cfg) - # Set meta optimizer - self.seed, optim_seed = random.split(self.seed) - if "max_norm" in self.cfg["meta_optimizer"]["kwargs"].keys(): - self.max_norm = self.cfg["meta_optimizer"]["kwargs"]["max_norm"] - del self.cfg["meta_optimizer"]["kwargs"]["max_norm"] - else: - self.max_norm = -1 - self.meta_optimizer = set_meta_optimizer( - self.cfg["meta_optimizer"]["name"], - self.cfg["meta_optimizer"]["kwargs"], - optim_seed, - ) - # Set reset_indexes - if isinstance(cfg["agent"]["reset_interval"], int): - self.reset_intervals = [cfg["agent"]["reset_interval"]] * len( - self.env_names - ) - elif isinstance(cfg["agent"]["reset_interval"], list): - self.reset_intervals = cfg["agent"]["reset_interval"] - else: - raise TypeError("Only List[int] or int is allowed") - self.reset_indexes = dict() - for i, env_name in enumerate(self.env_names): - reset_indexes = [ - int(x) - for x in jnp.linspace(0, self.reset_intervals[i] - 1, num=self.num_envs) - ] - self.reset_indexes[env_name] = self.reshape(jnp.array(reset_indexes)) - - def compute_meta_loss(self, agent_param, env_state, step_seed): - # Move for rollout_steps - env_state, rollout = self.move_rollout_steps(agent_param, env_state, step_seed) - last_obs = self.env.render_obs(env_state) - all_obs = jnp.concatenate([rollout.obs, jnp.expand_dims(last_obs, 0)], axis=0) - logits, v = self.agent_net.apply(agent_param, all_obs) - # Compute multi-step temporal difference error - td_error = rlax.td_lambda( - v_tm1=v[:-1], - r_t=rollout.reward, - discount_t=self.discount * (1.0 - rollout.done), - v_t=v[1:], - lambda_=self.cfg["agent"]["gae_lambda"], - stop_target_gradients=True, - ) - # Compute actor loss - actor_loss = rlax.policy_gradient_loss( - logits_t=logits[:-1], - a_t=rollout.action, - adv_t=td_error, - w_t=jnp.ones_like(td_error), - use_stop_gradient=True, - ) - return actor_loss, env_state - - def agent_update(self, carry_in, _): - """Perform a step of inner update to the agent.""" - meta_param, training_state, env_state, seed, lr = carry_in - seed, step_seed = random.split(seed) - # Generate one rollout and compute agent gradient - agent_grad, (env_state, rollout) = jax.grad( - self.compute_agent_loss, has_aux=True - )(training_state.agent_param, env_state, step_seed) - # Update agent parameters - agent_param_update, agent_optim_state = self.agent_optimizer.update_with_param( - meta_param, agent_grad, training_state.agent_optim_state, lr - ) - agent_param = optax.apply_updates( - training_state.agent_param, agent_param_update - ) - # Set new training_state - training_state = training_state.replace( - agent_param=agent_param, agent_optim_state=agent_optim_state - ) - carry_out = [meta_param, training_state, env_state, seed, lr] - return carry_out, None - - def agent_update_and_meta_loss(self, meta_param, carry_in): - """Update agent param and compute meta loss with the last rollout.""" - # Perform inner updates - carry_in = [meta_param] + carry_in - carry_out, _ = lax.scan( - f=self.agent_update, init=carry_in, length=self.inner_updates, xs=None - ) - meta_param, training_state, env_state, step_seed, lr = carry_out - # Use the last rollout as the validation data to compute meta loss - meta_loss, env_state = self.compute_meta_loss( - training_state.agent_param, env_state, step_seed - ) - carry_out = [training_state, env_state] - return meta_loss, carry_out + if not isinstance(x, (np.ndarray, jnp.ndarray)): + raise ValueError(f"`abs_sq` accepts only NDarrays, got: {x}.") + return (x.conj() * 
x).real - def learn(self, carry_in): - """Two level updates for meta_param (outer update) and agent_param (inner update).""" - training_state, env_state, seed, lr = carry_in - # Perform inner updates and compute meta gradient. - seed, step_seed = random.split(seed) - carry_in = [training_state, env_state, step_seed, lr] - meta_param = training_state.agent_optim_state.optim_param - meta_grad, carry_out = jax.grad(self.agent_update_and_meta_loss, has_aux=True)( - meta_param, carry_in - ) - training_state, env_state = carry_out - # Reduce mean gradient across batch an cores - meta_grad = lax.pmean(meta_grad, axis_name="batch") - meta_grad = lax.pmean(meta_grad, axis_name="core") - carry_out = [meta_grad, training_state, env_state, seed, lr] - return carry_out - - def get_training_state(self, seed, obs): - agent_param = self.agent_net.init(seed, obs) - training_state = TrainingState( - agent_param=agent_param, - agent_optim_state=self.agent_optimizer.init(agent_param), - ) - return training_state - - def reset_agent_training( - self, - training_state, - env_state, - reset_index, - seed, - optim_param, - iter_num, - agent_reset_interval, - obs, - ): - # Select the new one if iter_num % agent_reset_interval == reset_index - def f_select(n_s, o_s): - return lax.select(iter_num % agent_reset_interval == reset_index, n_s, o_s) - - # Generate a new training_state and env_state - new_training_state = self.get_training_state(seed, obs) - new_env_state = self.env.reset(seed) - # Select the new training_state - training_state = tree_util.tree_map( - f_select, new_training_state, training_state - ) - env_state = tree_util.tree_map(f_select, new_env_state, env_state) - # Update optim_param - agent_optim_state = training_state.agent_optim_state - agent_optim_state = agent_optim_state.replace(optim_param=optim_param) - training_state = training_state.replace(agent_optim_state=agent_optim_state) - return training_state, env_state - - def abs_sq(self, x): - """Returns the squared norm of a (maybe complex) array. - Copied from https://github.com/deepmind/optax/blob/master/optax/_src/numerics.py - """ - if not isinstance(x, (np.ndarray, jnp.ndarray)): - raise ValueError(f"`abs_sq` accepts only NDarrays, got: {x}.") - return (x.conj() * x).real - - def global_norm(self, updates): - """ - Compute the global norm across a nested structure of tensors. - Copied from https://github.com/deepmind/optax/blob/master/optax/_src/linear_algebra.py - """ - return jnp.sqrt( - sum(jnp.sum(self.abs_sq(x)) for x in tree_util.tree_leaves(updates)) - ) - - def train(self): - seed = self.seed - # Initialize pmap_train_one_iteration and carries: hidden_state, agent_param, agent_optim_state, env_states, step_seeds - carries = dict() - pmap_train_one_iterations = dict() - pvmap_reset_agent_training = dict() - for i, env_name in enumerate(self.env_names): - # Generate random seeds for env and agent - seed, env_seed, agent_seed = random.split(seed, num=3) - # Set environment and agent network - self.env, self.agent_net = self.envs[i], self.agent_nets[i] - # Initialize agent parameter and optimizer - dummy_obs = self.env.render_obs(self.env.reset(env_seed))[None, :] - pvmap_reset_agent_training[env_name] = jax.pmap( - jax.vmap( - functools.partial(self.reset_agent_training, obs=dummy_obs), - in_axes=(0, 0, 0, 0, None, None, None), - ), - in_axes=(0, 0, 0, 0, None, None, None), - ) - # We initialize core_count*batch_size different agent parameters and optimizer states. 
- pvmap_get_training_state = jax.pmap( - jax.vmap(self.get_training_state, in_axes=(0, None)), in_axes=(0, None) - ) - agent_seed, *agent_seeds = random.split( - agent_seed, self.core_count * self.batch_size + 1 - ) - agent_seeds = self.reshape(jnp.stack(agent_seeds)) - training_states = pvmap_get_training_state(agent_seeds, dummy_obs) - # Intialize env_states over cores and batch - seed, *env_seeds = random.split(seed, self.core_count * self.batch_size + 1) - env_states = jax.vmap(self.env.reset)(jnp.stack(env_seeds)) - env_states = tree_util.tree_map(self.reshape, env_states) - seed, *step_seeds = random.split( - seed, self.core_count * self.batch_size + 1 - ) - step_seeds = self.reshape(jnp.stack(step_seeds)) - # Save in carries dict - carries[env_name] = [ - training_states, - env_states, - step_seeds, - self.learning_rates[i], - ] - # Replicate the training process over multiple cores - batched_learn = jax.vmap( - self.learn, - in_axes=([0, 0, 0, None],), - out_axes=[None, 0, 0, 0, None], - axis_name="batch", - ) - pmap_train_one_iterations[env_name] = jax.pmap( - batched_learn, - in_axes=([0, 0, 0, None],), - out_axes=[None, 0, 0, 0, None], - axis_name="core", - ) - - self.meta_param = self.agent_optimizer.optim_param - self.meta_optim_state = self.meta_optimizer.init(self.meta_param) - # Train for self.iterations for each env - for t in range(1, self.iterations + 1): - meta_grads = [] - start_time = time.time() - for i, env_name in enumerate(self.env_names): - # Set environment and agent network - self.env, self.agent_net = self.envs[i], self.agent_nets[i] - # Reset agent training: agent_param, hidden_state, env_state - # and update meta parameter (i.e. optim_param) - training_states, env_states = carries[env_name][0], carries[env_name][1] - seed, *reset_seeds = random.split( - seed, self.core_count * self.batch_size + 1 - ) - reset_seeds = self.reshape(jnp.stack(reset_seeds)) - training_states, env_states = pvmap_reset_agent_training[env_name]( - training_states, - env_states, - self.reset_indexes[env_name], - reset_seeds, - self.meta_param, - t - 1, - self.reset_intervals[i], - ) - carries[env_name][0], carries[env_name][1] = training_states, env_states - # Train for one iteration - carry_in = carries[env_name] - carry_out = pmap_train_one_iterations[env_name](carry_in) - # Update carries - carries[env_name] = carry_out[1:] - # Gather meta grad and process - meta_grad = carry_out[0] - if self.max_norm > 0: - g_norm = self.global_norm(meta_grad) - meta_grad = tree_util.tree_map( - lambda x: (x / g_norm.astype(x.dtype)) * self.max_norm, - meta_grad, - ) - meta_grads.append(meta_grad) - # Update meta paramter - meta_grad = tree_transpose(meta_grads) - meta_grad = tree_util.tree_map(lambda x: jnp.mean(x, axis=0), meta_grad) - # Update meta parameter - meta_param_update, self.meta_optim_state = self.meta_optimizer.update( - meta_grad, self.meta_optim_state - ) - self.meta_param = optax.apply_updates(self.meta_param, meta_param_update) - # Show log - if t % self.cfg["display_interval"] == 0: - step_count = t * self.macro_step - speed = self.macro_step / (time.time() - start_time) - eta = (self.train_steps - step_count) / speed / 60 if speed > 0 else -1 - self.logger.info( - f"<{self.config_idx}> Step {step_count}/{self.train_steps} Iteration {t}/{self.iterations}: Speed={speed:.2f} (steps/s), ETA={eta:.2f} (mins)" - ) - # Save meta param - if (self.cfg["save_param"] > 0 and t % self.cfg["save_param"] == 0) or ( - t == self.iterations - ): - self.save_model_param( - self.meta_param, 
self.cfg["logs_dir"] + f"param{t}.pickle" - ) + def global_norm(self, updates): + """ + Compute the global norm across a nested structure of tensors. + Copy from https://github.com/deepmind/optax/blob/master/optax/_src/linear_algebra.py + """ + return jnp.sqrt(sum(jnp.sum(self.abs_sq(x)) for x in tree_util.tree_leaves(updates))) + + @partial(jit, static_argnames=['self', 'i']) + def agent_update(self, carry_in, i): + '''Perform a step of inner update to the agent.''' + meta_param, training_state, env_state, seed, lr = carry_in + seed, step_seed = jitted_split(seed) + # Generate one rollout and compute agent gradient + agent_grad, (env_state, rollout) = jax.grad(self.compute_loss, has_aux=True)(training_state.agent_param, env_state, step_seed, i) + # Update agent parameters + param_update, new_optim_state = self.agent_optim.update_with_param( + meta_param, agent_grad, training_state.agent_optim_state, lr + ) + new_param = optax.apply_updates(training_state.agent_param, param_update) + # Set new training_state + training_state = training_state.replace( + agent_param = new_param, + agent_optim_state = new_optim_state + ) + carry_out = (meta_param, training_state, env_state, seed, lr) + return carry_out, i + + @partial(jit, static_argnames=['self', 'i']) + def agent_update_and_meta_loss(self, meta_param, carry_in, i): + '''Update agent param and compute meta loss with the last rollout.''' + # Perform inner updates + carry = (meta_param,) + carry_in + carry, _ = lax.scan( + f = lambda carry, _: self.agent_update(carry, i), + init = carry, + length = self.inner_updates, + xs = None + ) + meta_param, training_state, env_state, step_seed, lr = carry + # Use the last rollout as the validation data to compute meta loss + meta_loss, (env_state, rollout) = self.compute_loss(training_state.agent_param, env_state, step_seed, i) + return meta_loss, (training_state, env_state) + + @partial(jit, static_argnames=['self', 'i']) + def learn(self, carry_in, i): + '''Two level updates for meta_param (outer update) and agent_param (inner update).''' + meta_param, training_state, env_state, seed, lr = carry_in + # Perform inner updates and compute meta gradient. 
+ seed, step_seed = jitted_split(seed) + carry_in = (training_state, env_state, step_seed, lr) + meta_grad, (training_state, env_state) = jax.grad(self.agent_update_and_meta_loss, has_aux=True)(meta_param, carry_in, i) + # Reduce mean gradient across batch an cores + meta_grad = lax.pmean(meta_grad, axis_name='batch') + meta_grad = lax.pmean(meta_grad, axis_name='core') + carry_out = (meta_grad, training_state, env_state, seed, lr) + return carry_out + + @partial(jit, static_argnames=['self', 'i']) + def get_training_state(self, seed, dummy_obs, i): + agent_param = self.agent_nets[i].init(seed, dummy_obs) + training_state = MyTrainState( + agent_param = agent_param, + agent_optim_state = self.agent_optim.init(agent_param) + ) + return training_state + + @partial(jit, static_argnames=['self', 'i']) + def reset_agent_training(self, training_state, env_state, reset_index, seed, optim_param, iter_num, dummy_obs, i): + # Select the new one if iter_num % agent_reset_interval == reset_index + f_select = lambda n_s, o_s: lax.select(iter_num % self.reset_intervals[i] == reset_index, n_s, o_s) + # Generate a new training_state and env_state + new_training_state = self.get_training_state(seed, dummy_obs, i) + new_env_state = self.envs[i].reset(seed) + # Select the new training_state + training_state = tree_util.tree_map(f_select, new_training_state, training_state) + env_state = tree_util.tree_map(f_select, new_env_state, env_state) + # Update optim_param + agent_optim_state = training_state.agent_optim_state + agent_optim_state = agent_optim_state.replace(optim_param=optim_param) + training_state = training_state.replace(agent_optim_state=agent_optim_state) + return training_state, env_state + + def train(self): + seed = self.seed + # Initialize pmap_train_one_iteration and carries (hidden_state, agent_param, agent_optim_state, env_states, step_seeds) + carries = [None] * self.task_num + dummy_obs = [None] * self.task_num + pmap_train_one_iterations = [None] * self.task_num + pvmap_reset_agent_training = [None] * self.task_num + for i in range(self.task_num): + # Generate random seeds for env and agent + seed, env_seed, agent_seed = jitted_split(seed, num=3) + # Initialize agent parameter and optimizer + dummy_obs[i] = self.envs[i].render_obs(self.envs[i].reset(env_seed))[None,] + pvmap_reset_agent_training[i] = jax.pmap( + jax.vmap( + self.reset_agent_training, + in_axes = (0, 0, 0, 0, None, None, None, None), + out_axes= (0, 0), + axis_name = 'batch' + ), + in_axes = (0, 0, 0, 0, None, None, None, None), + out_axes= (0, 0), + axis_name = 'core', + static_broadcasted_argnums = (7) + ) + # We initialize core_count*batch_size different agent parameters and optimizer states. 
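+ # Nesting vmap (axis 'batch') inside pmap (axis 'core') maps get_training_state over
+ # seeds of shape (core_count, batch_size, 2), so every leaf of the returned pytree
+ # gets leading axes (core_count, batch_size); the task index i is static, so each
+ # task is compiled separately. Illustrative shape check (not part of this patch):
+ #   leaves = tree_util.tree_leaves(training_states)
+ #   assert all(leaf.shape[:2] == (self.core_count, self.batch_size) for leaf in leaves)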
+ pvmap_get_training_state = jax.pmap( + jax.vmap( + self.get_training_state, + in_axes = (0, None, None), + out_axes = (0), + axis_name = 'batch' + ), + in_axes = (0, None, None), + out_axes = (0), + axis_name = 'core', + static_broadcasted_argnums = (2) + ) + agent_seed, *agent_seeds = jitted_split(agent_seed, self.core_count * self.batch_size + 1) + agent_seeds = self.reshape(jnp.stack(agent_seeds)) + training_states = pvmap_get_training_state(agent_seeds, dummy_obs[i], i) + # Intialize env_states over cores and batch + seed, *env_seeds = jitted_split(seed, self.core_count * self.batch_size + 1) + env_states = jax.vmap(self.envs[i].reset)(jnp.stack(env_seeds)) + env_states = tree_util.tree_map(self.reshape, env_states) + seed, *step_seeds = jitted_split(seed, self.core_count * self.batch_size + 1) + step_seeds = self.reshape(jnp.stack(step_seeds)) + # Save in carries dict + carries[i] = (training_states, env_states, step_seeds, self.learning_rates[i]) + # Replicate the training process over multiple cores + pmap_train_one_iterations[i] = jax.pmap( + jax.vmap( + self.learn, + in_axes = ((None, 0, 0, 0, None), None), + out_axes = (None, 0, 0, 0, None), + axis_name = 'batch' + ), + in_axes = ((None, 0, 0, 0, None), None), + out_axes = (None, 0, 0, 0, None), + axis_name = 'core', + static_broadcasted_argnums = (1) + ) + + self.meta_param = self.agent_optim.optim_param + self.meta_optim_state = self.meta_optim.init(self.meta_param) + # Train for self.iterations for each env + for t in range(1, self.iterations+1): + start_time = time.time() + meta_grads = [] + for i in range(self.task_num): + # Reset agent training: agent_param, hidden_state, env_state + # and update meta parameter (i.e. optim_param) + training_states, env_states = carries[i][0], carries[i][1] + seed, *reset_seeds = jitted_split(seed, self.core_count * self.batch_size + 1) + reset_seeds = self.reshape(jnp.stack(reset_seeds)) + training_states, env_states = pvmap_reset_agent_training[i](training_states, env_states, self.reset_indexes[i], reset_seeds, self.meta_param, t-1, dummy_obs[i], i) + carries[i] = list(carries[i]) + carries[i][0], carries[i][1] = training_states, env_states + carries[i] = tuple(carries[i]) + # Train for one iteration + carry_out = pmap_train_one_iterations[i]((self.meta_param,)+carries[i], i) + # Gather meta grad and update carries + meta_grad, carries[i] = carry_out[0], carry_out[1:] + if self.max_norm > 0: + g_norm = self.global_norm(meta_grad) + meta_grad = tree_util.tree_map(lambda x: (x / g_norm.astype(x.dtype)) * self.max_norm, meta_grad) + meta_grads.append(meta_grad) + # Update meta paramter + meta_grad = tree_transpose(meta_grads) + meta_grad = tree_util.tree_map(lambda x: jnp.mean(x, axis=0), meta_grad) + # Update meta parameter + meta_param_update, self.meta_optim_state = self.meta_optim.update(meta_grad, self.meta_optim_state) + self.meta_param = optax.apply_updates(self.meta_param, meta_param_update) + # Show log + if t % self.cfg['display_interval'] == 0: + step_count = t * self.macro_step + speed = self.macro_step / (time.time() - start_time) + eta = (self.train_steps - step_count) / speed / 60 if speed>0 else -1 + self.logger.info(f'<{self.config_idx}> Step {step_count}/{self.train_steps} Iteration {t}/{self.iterations}: Speed={speed:.2f} (steps/s), ETA={eta:.2f} (mins)') + # Save meta param + if t == self.iterations: + save_model_param(self.meta_param, self.cfg['logs_dir']+'param.pickle') \ No newline at end of file diff --git a/agents/MetaA2Cstar.py b/agents/MetaA2Cstar.py new file mode 
100644 index 0000000..7816d4d --- /dev/null +++ b/agents/MetaA2Cstar.py @@ -0,0 +1,189 @@ +# Copyright 2024 Garena Online Private Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import jax +import jax.numpy as jnp +from jax import jit, lax, tree_util + +import time +import flax +import optax +from typing import Any +from functools import partial + +from utils.helper import jitted_split, tree_transpose, save_model_param +from agents.MetaA2C import MetaA2C + + +@flax.struct.dataclass +class StarTrainingState: + agent_param: Any + agent_optim_param: Any + agent_optim_state: optax.OptState + + +class MetaA2Cstar(MetaA2C): + ''' + Implementation of Meta A2C with STAR optmizer only + ''' + def __init__(self, cfg): + super().__init__(cfg) + assert self.task_num == 1, 'Only single task training is supported in MetaA2Cstar for now.' + + @partial(jit, static_argnames=['self', 'i']) + def agent_update(self, carry_in, i): + '''Perform a step of inner update to the agent.''' + meta_param, training_state, env_state, seed, lr = carry_in + seed, step_seed = jitted_split(seed) + # Generate one rollout and compute agent gradient + (agent_loss, (env_state, rollout)), agent_grad = jax.value_and_grad(self.compute_loss, has_aux=True)(training_state.agent_param, env_state, step_seed, i) + # Update agent parameters + agent_optim_state = self.agent_optim.update_with_param( + meta_param, agent_grad, training_state.agent_optim_state, agent_loss + ) + # Set new training_state + training_state = training_state.replace( + agent_param = agent_optim_state.params, + agent_optim_state = agent_optim_state + ) + carry_out = (meta_param, training_state, env_state, seed, lr) + return carry_out, i + + @partial(jit, static_argnames=['self', 'i']) + def get_training_state(self, seed, dummy_obs, i): + agent_param = self.agent_nets[i].init(seed, dummy_obs) + training_state = StarTrainingState( + agent_param = agent_param, + agent_optim_param = self.agent_optim.get_optim_param(), + agent_optim_state = self.agent_optim.init(agent_param) + ) + return training_state + + @partial(jit, static_argnames=['self', 'i']) + def reset_agent_training(self, training_state, env_state, reset_index, seed, optim_param, iter_num, dummy_obs, i): + # Select the new one if iter_num % agent_reset_interval == reset_index + f_select = lambda n_s, o_s: lax.select(iter_num % self.reset_intervals[i] == reset_index, n_s, o_s) + # Generate a new training_state and env_state + new_training_state = self.get_training_state(seed, dummy_obs, i) + new_env_state = self.envs[i].reset(seed) + # Select the new training_state + training_state = tree_util.tree_map(f_select, new_training_state, training_state) + env_state = tree_util.tree_map(f_select, new_env_state, env_state) + # Update optim_param + training_state = training_state.replace(agent_optim_param=optim_param) + return training_state, env_state + + def train(self): + seed = self.seed + # Initialize pmap_train_one_iteration and carries (hidden_state, agent_param, agent_optim_state, env_states, step_seeds) 
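+ # Relative to MetaA2C.train, the STAR variant differs in three ways (see the
+ # methods above): agent_update takes the new parameters from agent_optim_state.params
+ # because the STAR optimizer applies the update itself, the learning-rate slot of the
+ # carry is unused (set to -1 below), and after each meta update the learned optimizer
+ # is re-synchronised with self.agent_optim.reset_optimizer(self.meta_param).
+ # Sketch of one inner step (illustrative, simplified names):
+ #   (loss, aux), grad = jax.value_and_grad(loss_fn, has_aux=True)(params, ...)
+ #   optim_state = agent_optim.update_with_param(meta_param, grad, optim_state, loss)
+ #   params = optim_state.params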
+ carries = [None] * self.task_num + dummy_obs = [None] * self.task_num + pmap_train_one_iterations = [None] * self.task_num + pvmap_reset_agent_training = [None] * self.task_num + for i in range(self.task_num): + # Generate random seeds for env and agent + seed, env_seed, agent_seed = jitted_split(seed, num=3) + # Initialize agent parameter and optimizer + dummy_obs[i] = self.envs[i].render_obs(self.envs[i].reset(env_seed))[None,] + pvmap_reset_agent_training[i] = jax.pmap( + jax.vmap( + self.reset_agent_training, + in_axes = (0, 0, 0, 0, None, None, None, None), + out_axes= (0, 0), + axis_name = 'batch' + ), + in_axes = (0, 0, 0, 0, None, None, None, None), + out_axes= (0, 0), + axis_name = 'core', + static_broadcasted_argnums = (7) + ) + # We initialize core_count*batch_size different agent parameters and optimizer states. + pvmap_get_training_state = jax.pmap( + jax.vmap( + self.get_training_state, + in_axes = (0, None, None), + out_axes = (0), + axis_name = 'batch' + ), + in_axes = (0, None, None), + out_axes = (0), + axis_name = 'core', + static_broadcasted_argnums = (2) + ) + agent_seed, *agent_seeds = jitted_split(agent_seed, self.core_count * self.batch_size + 1) + agent_seeds = self.reshape(jnp.stack(agent_seeds)) + training_states = pvmap_get_training_state(agent_seeds, dummy_obs[i], i) + # Intialize env_states over cores and batch + seed, *env_seeds = jitted_split(seed, self.core_count * self.batch_size + 1) + env_states = jax.vmap(self.envs[i].reset)(jnp.stack(env_seeds)) + env_states = tree_util.tree_map(self.reshape, env_states) + seed, *step_seeds = jitted_split(seed, self.core_count * self.batch_size + 1) + step_seeds = self.reshape(jnp.stack(step_seeds)) + # Save in carries dict + carries[i] = (training_states, env_states, step_seeds, -1) + # Replicate the training process over multiple cores + pmap_train_one_iterations[i] = jax.pmap( + jax.vmap( + self.learn, + in_axes = ((None, 0, 0, 0, None), None), + out_axes = (None, 0, 0, 0, None), + axis_name = 'batch' + ), + in_axes = ((None, 0, 0, 0, None), None), + out_axes = (None, 0, 0, 0, None), + axis_name = 'core', + static_broadcasted_argnums = (1) + ) + + self.meta_param = self.agent_optim.get_optim_param() + self.meta_optim_state = self.meta_optim.init(self.meta_param) + # Train for self.iterations for each env + for t in range(1, self.iterations+1): + start_time = time.time() + meta_grads = [] + for i in range(self.task_num): + # Reset agent training: agent_param, hidden_state, env_state + # and update meta parameter (i.e. 
optim_param) + training_states, env_states = carries[i][0], carries[i][1] + seed, *reset_seeds = jitted_split(seed, self.core_count * self.batch_size + 1) + reset_seeds = self.reshape(jnp.stack(reset_seeds)) + training_states, env_states = pvmap_reset_agent_training[i](training_states, env_states, self.reset_indexes[i], reset_seeds, self.meta_param, t-1, dummy_obs[i], i) + carries[i] = list(carries[i]) + carries[i][0], carries[i][1] = training_states, env_states + carries[i] = tuple(carries[i]) + # Train for one iteration + carry_out = pmap_train_one_iterations[i]((self.meta_param,)+carries[i], i) + # Gather meta grad and update carries + meta_grad, carries[i] = carry_out[0], carry_out[1:] + if self.max_norm > 0: + g_norm = self.global_norm(meta_grad) + meta_grad = tree_util.tree_map(lambda x: (x / g_norm.astype(x.dtype)) * self.max_norm, meta_grad) + meta_grads.append(meta_grad) + # Update meta paramter + meta_grad = tree_transpose(meta_grads) + meta_grad = tree_util.tree_map(lambda x: jnp.mean(x, axis=0), meta_grad) + # Update meta parameter + meta_param_update, self.meta_optim_state = self.meta_optim.update(meta_grad, self.meta_optim_state) + self.meta_param = optax.apply_updates(self.meta_param, meta_param_update) + # Reset agent_optim with new meta_param + self.agent_optim.reset_optimizer(self.meta_param) + # Show log + if t % self.cfg['display_interval'] == 0: + step_count = t * self.macro_step + speed = self.macro_step / (time.time() - start_time) + eta = (self.train_steps - step_count) / speed / 60 if speed>0 else -1 + self.logger.info(f'<{self.config_idx}> Step {step_count}/{self.train_steps} Iteration {t}/{self.iterations}: Speed={speed:.2f} (steps/s), ETA={eta:.2f} (mins)') + # Save meta param + if t == self.iterations: + save_model_param(self.meta_param, self.cfg['logs_dir']+'param.pickle') \ No newline at end of file diff --git a/agents/MetaPPO.py b/agents/MetaPPO.py index ce4dc12..0debe9c 100644 --- a/agents/MetaPPO.py +++ b/agents/MetaPPO.py @@ -1,4 +1,4 @@ -# Copyright 2022 Garena Online Private Limited. +# Copyright 2023 The Brax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,565 +12,386 @@ # See the License for the specific language governing permissions and # limitations under the License. -# Copyright 2022 The Brax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +"""Proximal policy optimization training. 
+See: https://arxiv.org/pdf/1707.06347.pdf +""" -import functools -import pickle import time -from typing import Tuple +import flax +import optax +import functools +from typing import Tuple, Any -import chex import jax +from jax import lax import jax.numpy as jnp -import optax +from jax.tree_util import tree_map + from brax import envs -from brax import jumpy as jp -from brax.envs import wrappers from brax.training import acting, types +from brax.training.types import PRNGKey +from brax.training.agents.ppo import losses as ppo_losses +from brax.training.agents.ppo import networks as ppo_networks from brax.training.acme import specs -from jax import lax -from jax.tree_util import tree_map - -import components.losses as ppo_losses -from components import gradients, ppo_networks, running_statistics -from components.optim import OptimState, set_meta_optimizer, set_optimizer -from utils.logger import Logger -InferenceParams = Tuple[running_statistics.NestedMeanStd, types.Params] +from components import running_statistics +from components.optim import set_optim +from components import gradients +from utils.helper import jitted_split, save_model_param, pytree2array +from agents.PPO import PPO -@chex.dataclass +@flax.struct.dataclass class TrainingState: - """Contains training_state for the learner.""" - - agent_optim_state: OptimState - agent_param: ppo_losses.PPONetworkParams - - -class MetaPPO(object): - """ - Meta-train a learned optimizer during traing PPO in Brax, compatible with LinearOptim, Optim4RL, and L2LGD2. - """ - - def __init__(self, cfg): - self.cfg = cfg - self.config_idx = cfg["config_idx"] - self.logger = Logger(cfg["logs_dir"]) - self.log_path = cfg["logs_dir"] + "result_Test.feather" - self.result = [] - # Set environment - self.env_name = cfg["env"]["name"] - self.agent_name = cfg["agent"]["name"] - self.train_steps = int(cfg["env"]["train_steps"]) - self.env = envs.get_environment(env_name=self.env_name) - self.env_state = self.env.reset(rng=jp.random_prngkey(seed=self.cfg["seed"])) - self._PMAP_AXIS_NAME = "i" - # Agent reset interval - self.agent_reset_interval = self.cfg["agent"]["reset_interval"] - - def save_model_param(self, model_param, filepath): - f = open(filepath, "wb") - pickle.dump(model_param, f) - f.close() - - def train(self): - # Env - env = self.env - num_timesteps = self.train_steps - episode_length = self.cfg["env"]["episode_length"] - action_repeat = self.cfg["env"]["action_repeat"] - reward_scaling = self.cfg["env"]["reward_scaling"] - num_envs = self.cfg["env"]["num_envs"] - normalize_observations = self.cfg["env"]["normalize_obs"] - # Agent - network_factory = ppo_networks.make_ppo_networks - gae_lambda = self.cfg["agent"]["gae_lambda"] - unroll_length = self.cfg["agent"]["rollout_steps"] - num_minibatches = self.cfg["agent"]["num_minibatches"] - clip_ratio = self.cfg["agent"]["clip_ratio"] - update_epochs = self.cfg["agent"]["update_epochs"] - entropy_cost = self.cfg["agent"]["entropy_weight"] - normalize_advantage = True - # Meta learning - inner_updates = self.cfg["agent"]["inner_updates"] - # Others - batch_size = self.cfg["batch_size"] - discounting = self.cfg["discount"] - max_devices_per_host = self.cfg["max_devices_per_host"] - seed = self.cfg["seed"] - - """PPO training.""" - process_id = jax.process_index() - process_count = jax.process_count() - total_device_count = jax.device_count() - local_device_count = jax.local_device_count() - if max_devices_per_host is not None and max_devices_per_host > 0: - local_devices_to_use = 
min(local_device_count, max_devices_per_host) - else: - local_devices_to_use = local_device_count - self.logger.info( - f"Total device: {total_device_count}, Process: {process_count} (ID {process_id})" - ) - self.logger.info( - f"Local device: {local_device_count}, Devices to be used: {local_devices_to_use}" - ) - device_count = local_devices_to_use * process_count - assert num_envs % device_count == 0 - assert batch_size * num_minibatches % num_envs == 0 - self.core_reshape = lambda x: x.reshape((local_devices_to_use,) + x.shape[1:]) - - # The number of environment steps executed for every training step. - env_step_per_training_step = ( - batch_size * unroll_length * num_minibatches * action_repeat - ) - meta_env_step_per_training_step = ( - max(batch_size // num_envs, 1) - * num_envs - * unroll_length - * 1 - * action_repeat - ) - total_env_step_per_training_step = ( - env_step_per_training_step * inner_updates + meta_env_step_per_training_step - ) - # The number of training_step calls per training_epoch call. - self.iterations = num_timesteps // total_env_step_per_training_step - - self.logger.info(f"env_step_per_training_step = {env_step_per_training_step}") - self.logger.info( - f"meta_env_step_per_training_step = {meta_env_step_per_training_step}" - ) - self.logger.info( - f"total_env_step_per_training_step = {total_env_step_per_training_step}" + """Contains training state for the learner.""" + agent_optim_state: Any + agent_param: ppo_losses.PPONetworkParams + normalizer_param: running_statistics.RunningStatisticsState + + +class MetaPPO(PPO): + ''' + PPO for Brax with meta learned optimizer. + ''' + def __init__(self, cfg): + super().__init__(cfg) + # Agent reset interval + self.agent_reset_interval = self.cfg['agent']['reset_interval'] + reset_indexes = [int(x) for x in jnp.linspace(0, self.agent_reset_interval-1, num=self.local_devices_to_use)] + self.reset_indexes = self.core_reshape(jnp.array(reset_indexes)) + + def train(self): + # Env + environment = self.env + num_timesteps = self.train_steps + episode_length = self.cfg['env']['episode_length'] + action_repeat = self.cfg['env']['action_repeat'] + reward_scaling = self.cfg['env']['reward_scaling'] + num_envs = self.cfg['env']['num_envs'] + normalize_observations = self.cfg['env']['normalize_obs'] + # Agent + network_factory = ppo_networks.make_ppo_networks + gae_lambda = self.cfg['agent']['gae_lambda'] + unroll_length = self.cfg['agent']['rollout_steps'] + num_minibatches = self.cfg['agent']['num_minibatches'] + clipping_epsilon = self.cfg['agent']['clipping_epsilon'] + update_epochs = self.cfg['agent']['update_epochs'] + entropy_cost = self.cfg['agent']['entropy_weight'] + normalize_advantage = True + # Meta learning + inner_updates = self.cfg['agent']['inner_updates'] + # Others + batch_size = self.cfg['batch_size'] + discounting = self.cfg['discount'] + seed = self.cfg['seed'] + + """PPO training.""" + device_count = self.local_devices_to_use * self.process_count + assert num_envs % device_count == 0 + assert batch_size * num_minibatches % num_envs == 0 + # The number of environment steps executed for every training step. + env_step_per_training_step = batch_size * unroll_length * num_minibatches * action_repeat + meta_env_step_per_training_step = max(batch_size // num_envs, 1) * num_envs * unroll_length * 1 * action_repeat + total_env_step_per_training_step = env_step_per_training_step * inner_updates + + # The number of training_step calls per training_epoch call. 
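+ # Worked example with hypothetical settings (not taken from any config in this
+ # patch): batch_size=256, unroll_length=5, num_minibatches=32, action_repeat=1,
+ # inner_updates=4 gives env_step_per_training_step = 256*5*32*1 = 40960 and
+ # total_env_step_per_training_step = 40960*4 = 163840, so num_timesteps=16_384_000
+ # yields self.iterations = 100.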
+ self.iterations = num_timesteps // total_env_step_per_training_step + self.logger.info(f'meta_env_step_per_training_step = {meta_env_step_per_training_step}') + self.logger.info(f'total_env_step_per_training_step = {total_env_step_per_training_step}') + self.logger.info(f'total iterations = {self.iterations}') + + # Prepare keys + # key_networks should be global so that + # the initialized networks are the same for different processes. + key = jax.random.PRNGKey(seed) + global_key, local_key = jitted_split(key) + local_key = jax.random.fold_in(local_key, self.process_id) + local_key, key_env, key_reset = jitted_split(local_key, 3) + key_agent_param, key_agent_optim, key_meta_optim = jitted_split(global_key, 3) + del key, global_key + key_envs = jitted_split(key_env, num_envs // self.process_count) + # Reshape to (local_devices_to_use, num_envs // process_count, 2) + key_envs = jnp.reshape(key_envs, (self.local_devices_to_use, -1) + key_envs.shape[1:]) + + # Set training and evaluation env + env = envs.training.wrap( + environment, + episode_length = episode_length, + action_repeat = action_repeat, + randomization_fn = None + ) + reset_fn = jax.pmap( + env.reset, + axis_name = self._PMAP_AXIS_NAME + ) + env_states = reset_fn(key_envs) + obs_shape = env_states.obs.shape + + # Set agent and meta optimizer + agent_optim = set_optim(self.cfg['agent_optim']['name'], self.cfg['agent_optim']['kwargs'], key_agent_optim) + meta_optim = set_optim(self.cfg['meta_optim']['name'], self.cfg['meta_optim']['kwargs'], key_meta_optim) + + # Set PPO network + if normalize_observations: + normalize = running_statistics.normalize + else: + normalize = lambda x, y: x + ppo_network = network_factory( + obs_shape[-1], + env.action_size, + preprocess_observations_fn = normalize, + policy_hidden_layer_sizes = (32,) * 4, + value_hidden_layer_sizes = (64,) * 5, + ) + make_policy = ppo_networks.make_inference_fn(ppo_network) + + # Set training states + def get_training_state(key): + key_policy, key_value = jitted_split(key) + agent_param = ppo_losses.PPONetworkParams( + policy = ppo_network.policy_network.init(key_policy), + value = ppo_network.value_network.init(key_value) + ) + training_state = TrainingState( + agent_optim_state = agent_optim.init(agent_param), + agent_param = agent_param, + normalizer_param = running_statistics.init_state(specs.Array(obs_shape[-1:], jnp.dtype('float32'))) + ) + return training_state + key_agents = jitted_split(key_agent_param, self.local_devices_to_use) + training_states = jax.pmap( + get_training_state, + axis_name = self._PMAP_AXIS_NAME + )(key_agents) + + # Set meta param and meta optim state + meta_param = agent_optim.optim_param + meta_optim_state = meta_optim.init(meta_param) + + # Set loss function + agent_loss_fn = functools.partial( + ppo_losses.compute_ppo_loss, + ppo_network=ppo_network, + entropy_cost=entropy_cost, + discounting=discounting, + reward_scaling=reward_scaling, + gae_lambda=gae_lambda, + clipping_epsilon=clipping_epsilon, + normalize_advantage=normalize_advantage + ) + meta_loss_fn = agent_loss_fn + + # Set pmap_axis_name to None so we don't average agent grad over cores + agent_grad_update_fn = gradients.gradient_update_fn_with_optim_param(agent_loss_fn, agent_optim, pmap_axis_name=None, has_aux=True) + + def convert_data(x: jnp.ndarray, key: PRNGKey): + x = jax.random.permutation(key, x) + x = jnp.reshape(x, (num_minibatches, -1) + x.shape[1:]) + return x + + def minibatch_step( + carry, data: types.Transition, + normalizer_param: 
running_statistics.RunningStatisticsState + ): + meta_param, optim_state, agent_param, key = carry + key, key_loss = jitted_split(key) + (loss, _), agent_param, optim_state = agent_grad_update_fn( + agent_param, + normalizer_param, + data, + key_loss, + optim_param = meta_param, + optimizer_state = optim_state + ) + return (meta_param, optim_state, agent_param, key), None + + def sgd_step( + carry, + unused_t, + data: types.Transition, + normalizer_param: running_statistics.RunningStatisticsState + ): + meta_param, optim_state, agent_param, key = carry + key, key_perm, key_grad = jitted_split(key, 3) + shuffled_data = tree_map(functools.partial(convert_data, key=key_perm), data) + (meta_param, optim_state, agent_param, key_grad), _ = lax.scan( + f = functools.partial(minibatch_step, normalizer_param=normalizer_param), + init = (meta_param, optim_state, agent_param, key_grad), + xs = shuffled_data, + length = num_minibatches + ) + return (meta_param, optim_state, agent_param, key), None + + def training_step( + carry: Tuple[flax.core.FrozenDict, TrainingState, envs.State, PRNGKey], + unused_t + ) -> Tuple[Tuple[flax.core.FrozenDict, TrainingState, envs.State, PRNGKey], Any]: + meta_param, training_state, env_state, key = carry + key_sgd, key_generate_unroll, new_key = jitted_split(key, 3) + policy = make_policy((training_state.normalizer_param, training_state.agent_param.policy)) + # Set rollout function + def rollout(carry, unused_t): + current_state, current_key = carry + current_key, next_key = jitted_split(current_key) + next_state, data = acting.generate_unroll( + env, + current_state, + policy, + current_key, + unroll_length, + extra_fields = ('truncation',) ) - self.logger.info(f"total iterations = {self.iterations}") - - # Prepare keys - # key_networks should be global so that - # the initialized networks are the same for different processes. 
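As a side note on the convert_data helper defined above: it shuffles the leading batch axis and then splits it into num_minibatches chunks for lax.scan to iterate over. A standalone shape sketch with small, hypothetical sizes:

    import jax
    import jax.numpy as jnp

    num_minibatches = 4
    x = jnp.arange(8 * 3).reshape(8, 3)                       # (batch_size * num_minibatches, feature)
    x = jax.random.permutation(jax.random.PRNGKey(0), x)      # shuffle along the leading axis
    x = jnp.reshape(x, (num_minibatches, -1) + x.shape[1:])   # (num_minibatches, per-minibatch, feature)
    print(x.shape)                                            # (4, 2, 3)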
- key = jax.random.PRNGKey(seed) - global_key, local_key = jax.random.split(key) - local_key = jax.random.fold_in(local_key, process_id) - local_key, key_env, key_reset = jax.random.split(local_key, 3) - key_agent_param, key_agent_optim, key_meta_optim = jax.random.split( - global_key, 3 - ) - del key, global_key - key_envs = jax.random.split(key_env, num_envs // process_count) - key_envs = jnp.reshape( - key_envs, (local_devices_to_use, -1) + key_envs.shape[1:] - ) - - # Set training and evaluation env - env = wrappers.wrap_for_training( - env, episode_length=episode_length, action_repeat=action_repeat - ) - reset_fn = jax.jit(jax.vmap(env.reset)) - env_states = reset_fn(key_envs) - - # Set agent and meta optimizer - agent_optimizer = set_optimizer( - self.cfg["agent_optimizer"]["name"], - self.cfg["agent_optimizer"]["kwargs"], - key_agent_optim, - ) - meta_optimizer = set_meta_optimizer( - self.cfg["meta_optimizer"]["name"], - self.cfg["meta_optimizer"]["kwargs"], - key_meta_optim, - ) - - # Set PPO network - if normalize_observations: - normalize = running_statistics.normalize - else: - normalize = lambda x, y: x - normalizer_param = running_statistics.init_state( - specs.Array((env.observation_size,), jnp.float32) - ) - ppo_network = network_factory( - env.observation_size, - env.action_size, - preprocess_observations_fn=normalize, - policy_hidden_layer_sizes=(32,) * 4, - value_hidden_layer_sizes=(64,) * 5, - ) - make_policy = ppo_networks.make_inference_fn(ppo_network) - - # Set training states - def get_training_state(key): - key_policy, key_value = jax.random.split(key) - agent_param = ppo_losses.PPONetworkParams( - policy=ppo_network.policy_network.init(key_policy), - value=ppo_network.value_network.init(key_value), - ) - training_state = TrainingState( - agent_param=agent_param, - agent_optim_state=agent_optimizer.init(agent_param), - ) - return training_state - - key_agents = jax.random.split(key_agent_param, local_devices_to_use) - training_states = jax.pmap(get_training_state)(key_agents) - - # Set meta param and meta optim state - meta_param = agent_optimizer.optim_param - meta_optim_state = meta_optimizer.init(meta_param) - - # Set loss function - agent_loss_fn = functools.partial( - ppo_losses.compute_ppo_loss, - ppo_network=ppo_network, - entropy_cost=entropy_cost, - discounting=discounting, - reward_scaling=reward_scaling, - gae_lambda=gae_lambda, - clip_ratio=clip_ratio, - normalize_advantage=normalize_advantage, - ) - - # Set pmap_axis_name to None so we don't average agent grad over cores - agent_grad_update_fn = gradients.gradient_update_fn_with_optim_param( - agent_loss_fn, agent_optimizer, pmap_axis_name=None, has_aux=True - ) - - meta_loss_fn = functools.partial( - ppo_losses.compute_ppo_loss, - ppo_network=ppo_network, - entropy_cost=entropy_cost, - discounting=discounting, - reward_scaling=reward_scaling, - gae_lambda=gae_lambda, - clip_ratio=clip_ratio, - normalize_advantage=normalize_advantage, + return (next_state, next_key), data + # Rollout for `batch_size * num_minibatches * unroll_length` steps + (env_state, _), data = lax.scan( + f = rollout, + init = (env_state, key_generate_unroll), + length = batch_size * num_minibatches // num_envs, + xs = None + ) + # shape = (batch_size * num_minibatches // num_envs, unroll_length, num_envs, ...) + data = tree_map(lambda x: jnp.swapaxes(x, 1, 2), data) + # shape = (batch_size * num_minibatches // num_envs, num_envs, unroll_length, ...) 
+ data = tree_map(lambda x: jnp.reshape(x, (-1,) + x.shape[2:]), data) + # shape = (batch_size * num_minibatches, unroll_length, ...) + assert data.discount.shape[1:] == (unroll_length,) + # Update agent_param normalization + normalizer_param = running_statistics.update( + training_state.normalizer_param, + data.observation, + pmap_axis_name = self._PMAP_AXIS_NAME + ) + # SGD steps + (meta_param, agent_optim_state, agent_param, key_sgd), _ = lax.scan( + f = functools.partial(sgd_step, data=data, normalizer_param=normalizer_param), + init = (meta_param, training_state.agent_optim_state, training_state.agent_param, key_sgd), + length = update_epochs, + xs = None + ) + # Set the new training_state + new_training_state = TrainingState( + agent_optim_state = agent_optim_state, + agent_param = agent_param, + normalizer_param = normalizer_param + ) + return (meta_param, new_training_state, env_state, new_key), None + + def agent_update_and_meta_loss( + meta_param: flax.core.FrozenDict, + training_state: TrainingState, + env_state: envs.State, + key: PRNGKey + ) -> Tuple[jnp.ndarray, Tuple[flax.core.FrozenDict, TrainingState, envs.State, PRNGKey]]: + """Agent learning: update agent params""" + (meta_param, training_state, env_state, key), _ = lax.scan( + f = training_step, + init = (meta_param, training_state, env_state, key), + length = inner_updates, + xs = None + ) + """Meta learning: update meta params""" + # Gather data for meta learning + key_meta, key_generate_unroll, new_key = jitted_split(key, 3) + policy = make_policy((training_state.normalizer_param, training_state.agent_param.policy)) + # Set rollout function + def rollout(carry, unused_t): + current_state, current_key = carry + current_key, next_key = jitted_split(current_key) + next_state, data = acting.generate_unroll( + env, + current_state, + policy, + current_key, + unroll_length, + extra_fields = ('truncation',) ) - - def convert_data(x: jnp.ndarray, key: types.PRNGKey): - x = jax.random.permutation(key, x) - x = jnp.reshape(x, (num_minibatches, -1) + x.shape[1:]) - return x - - def minibatch_step( - carry, - data: types.Transition, - normalizer_param: running_statistics.RunningStatisticsState, - ): - meta_param, optim_state, agent_param, key = carry - key, key_loss = jax.random.split(key) - (loss, _), agent_param, optim_state = agent_grad_update_fn( - agent_param, - normalizer_param, - data, - key_loss, - optim_param=meta_param, - optimizer_state=optim_state, - ) - return (meta_param, optim_state, agent_param, key), None - - def sgd_step( - carry, - unused_t, - data: types.Transition, - normalizer_param: running_statistics.RunningStatisticsState, - ): - meta_param, optim_state, agent_param, key = carry - key, key_perm, key_grad = jax.random.split(key, 3) - shuffled_data = tree_map( - functools.partial(convert_data, key=key_perm), data - ) - (meta_param, optim_state, agent_param, key_grad), _ = lax.scan( - f=functools.partial(minibatch_step, normalizer_param=normalizer_param), - init=(meta_param, optim_state, agent_param, key_grad), - xs=shuffled_data, - length=num_minibatches, - ) - return (meta_param, optim_state, agent_param, key), None - - def training_step( - carry: Tuple[ - chex.ArrayTree, - running_statistics.RunningStatisticsState, - TrainingState, - envs.State, - types.PRNGKey, - ], - unused_t, - ) -> Tuple[ - Tuple[ - chex.ArrayTree, - running_statistics.RunningStatisticsState, - TrainingState, - envs.State, - types.PRNGKey, - ], - types.Metrics, - ]: - meta_param, normalizer_param, training_state, env_state, key = 
carry - key_sgd, key_generate_unroll, new_key = jax.random.split(key, 3) - policy = make_policy((normalizer_param, training_state.agent_param.policy)) - - # Set rollout function - def rollout(carry, unused_t): - current_state, current_key = carry - current_key, next_key = jax.random.split(current_key) - next_state, data = acting.generate_unroll( - env, - current_state, - policy, - current_key, - unroll_length, - extra_fields=("truncation",), - ) - return (next_state, next_key), data - - # Rollout for `batch_size * num_minibatches * unroll_length` steps - (env_state, _), data = lax.scan( - f=rollout, - init=(env_state, key_generate_unroll), - xs=None, - length=batch_size * num_minibatches // num_envs, - ) - # shape = (batch_size * num_minibatches // num_envs, unroll_length, num_envs, ...) - data = tree_map(lambda x: jnp.swapaxes(x, 1, 2), data) - # shape = (batch_size * num_minibatches // num_envs, num_envs, unroll_length, ...) - data = tree_map(lambda x: jnp.reshape(x, (-1,) + x.shape[2:]), data) - # shape = (batch_size * num_minibatches, unroll_length, ...) - assert data.discount.shape[1:] == (unroll_length,) - # Update normalization params and normalize observations. - normalizer_param = running_statistics.update( - normalizer_param, data.observation, pmap_axis_name=self._PMAP_AXIS_NAME - ) - # SGD steps - (meta_param, agent_optim_state, agent_param, key_sgd), _ = lax.scan( - f=functools.partial( - sgd_step, data=data, normalizer_param=normalizer_param - ), - init=( - meta_param, - training_state.agent_optim_state, - training_state.agent_param, - key_sgd, - ), - xs=None, - length=update_epochs, - ) - # Set the new training_state - training_state = training_state.replace( - agent_optim_state=agent_optim_state, agent_param=agent_param - ) - return ( - meta_param, - normalizer_param, - training_state, - env_state, - new_key, - ), None - - def agent_update_and_meta_loss( - meta_param, normalizer_param, training_state, env_state, key - ): - """Agent learning: update agent params""" - ( - meta_param, - normalizer_param, - training_state, - env_state, - key, - ), _ = lax.scan( - f=training_step, - init=(meta_param, normalizer_param, training_state, env_state, key), - length=inner_updates, - xs=None, - ) - """Meta learning: update meta params""" - # Gather data for meta learning - key_meta, key_generate_unroll, new_key = jax.random.split(key, 3) - policy = make_policy((normalizer_param, training_state.agent_param.policy)) - - # Set rollout function - def rollout(carry, unused_t): - current_state, current_key = carry - current_key, next_key = jax.random.split(current_key) - next_state, data = acting.generate_unroll( - env, - current_state, - policy, - current_key, - unroll_length, - extra_fields=("truncation",), - ) - return (next_state, next_key), data - - # Rollout for `batch_size * unroll_length` steps - (env_state, _), data = lax.scan( - f=rollout, - init=(env_state, key_generate_unroll), - xs=None, - length=max(batch_size // num_envs, 1), - ) - data = tree_map(lambda x: jnp.swapaxes(x, 1, 2), data) - data = tree_map(lambda x: jnp.reshape(x, (-1,) + x.shape[2:]), data) - assert data.discount.shape[1:] == (unroll_length,) - # Update normalization params and normalize observations. 
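Both the removed and the added training_step refresh observation-normalization statistics before the SGD phase. A minimal usage sketch of that API, under the assumption that the vendored components.running_statistics mirrors brax.training.acme.running_statistics (the brax module is used here so the snippet is self-contained):

    import jax.numpy as jnp
    from brax.training.acme import running_statistics, specs

    norm_param = running_statistics.init_state(specs.Array((3,), jnp.dtype('float32')))
    obs_batch = jnp.ones((16, 3))                                     # a batch of 3-dim observations
    norm_param = running_statistics.update(norm_param, obs_batch)     # refresh running mean/std
    normalized = running_statistics.normalize(obs_batch, norm_param)  # standardized observations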
- normalizer_param = running_statistics.update( - normalizer_param, data.observation, pmap_axis_name=self._PMAP_AXIS_NAME - ) - # Compute meta loss - meta_loss, _ = meta_loss_fn( - params=training_state.agent_param, - normalizer_params=normalizer_param, - data=data, - rng=key_meta, - ) - return meta_loss, ( - meta_param, - normalizer_param, - training_state, - env_state, - key, - ) - - def meta_training_step( - meta_param, - meta_optim_state, - normalizer_param, - training_state, - env_state, - key, - ): - # Compute meta_grad - meta_grad, ( - meta_param, - normalizer_param, - training_state, - env_state, - key, - ) = jax.grad(agent_update_and_meta_loss, has_aux=True)( - meta_param, normalizer_param, training_state, env_state, key - ) - meta_grad = lax.pmean(meta_grad, axis_name=self._PMAP_AXIS_NAME) - # Update meta_param - meta_param_update, meta_optim_state = meta_optimizer.update( - meta_grad, meta_optim_state - ) - meta_param = optax.apply_updates(meta_param, meta_param_update) - # Update training_state: optim_param - agent_optim_state = training_state.agent_optim_state - agent_optim_state = agent_optim_state.replace(optim_param=meta_param) - training_state = training_state.replace(agent_optim_state=agent_optim_state) - return ( - meta_param, - meta_optim_state, - normalizer_param, - training_state, - env_state, - key, - ) - - pmap_meta_training_iteration = jax.pmap( - meta_training_step, - in_axes=(None, None, None, 0, 0, 0), - out_axes=(None, None, None, 0, 0, 0), - devices=jax.local_devices()[:local_devices_to_use], - axis_name=self._PMAP_AXIS_NAME, - ) - - # Setup agent training reset - reset_indexes = [ - int(x) - for x in jnp.linspace( - 0, self.agent_reset_interval - 1, num=local_devices_to_use - ) - ] - self.reset_indexes = self.core_reshape(jnp.array(reset_indexes)) - - def reset_agent_training(training_state, env_state, reset_index, key, iter_num): - # Select the new one if iter_num % agent_reset_interval == reset_index - def f_select(n_s, o_s): - return lax.select( - iter_num % self.agent_reset_interval == reset_index, n_s, o_s - ) - - # Generate a new training_state and env_state - key_env, key_agent = jax.random.split(key, 2) - new_training_state = get_training_state(key_agent) - key_envs = jax.random.split( - key_env, num_envs // process_count // local_devices_to_use - ) - new_env_state = jax.jit(jax.vmap(env.reset))(key_envs[None, :]) - new_env_state = tree_map(lambda x: x[0], new_env_state) - # Select the new training_state - training_state = tree_map(f_select, new_training_state, training_state) - env_state = tree_map(f_select, new_env_state, env_state) - return training_state, env_state - - pmap_reset_training_state = jax.pmap( - reset_agent_training, in_axes=(0, 0, 0, 0, None) - ) - - # Start training - step_key, local_key = jax.random.split(local_key) - step_keys = jax.random.split(step_key, local_devices_to_use) - for t in range(1, self.iterations + 1): - self.start_time = time.time() - # Reset agent training: agent_param, hidden_state, env_state - key_reset, *reset_keys = jax.random.split( - key_reset, local_devices_to_use + 1 - ) - reset_keys = jnp.stack(reset_keys) - training_states, env_states = pmap_reset_training_state( - training_states, env_states, self.reset_indexes, reset_keys, t - 1 - ) - # Train for one iteration - ( - meta_param, - meta_optim_state, - normalizer_param, - training_state, - env_states, - step_keys, - ) = pmap_meta_training_iteration( - meta_param, - meta_optim_state, - normalizer_param, - training_states, - env_states, - step_keys, - ) - 
# Show log - if ( - t % self.cfg["display_interval"] == 0 or t == self.iterations - ) and process_id == 0: - speed = total_env_step_per_training_step / ( - time.time() - self.start_time - ) - eta = ( - (self.iterations - t) - * total_env_step_per_training_step - / speed - / 60 - if speed > 0 - else -1 - ) - self.logger.info( - f"<{self.config_idx}> Iteration {t}/{self.iterations}, Speed={speed:.2f} (steps/s), ETA={eta:.2f} (mins)" - ) - # Save meta param - if (self.cfg["save_param"] > 0 and t % self.cfg["save_param"] == 0) or ( - t == self.iterations - ): - self.save_model_param( - meta_param, self.cfg["logs_dir"] + f"param{t}.pickle" - ) + return (next_state, next_key), data + # Rollout for `batch_size * unroll_length` steps + (env_state, _), data = lax.scan( + f = rollout, + init = (env_state, key_generate_unroll), + length = max(batch_size // num_envs, 1), + xs = None + ) + data = tree_map(lambda x: jnp.swapaxes(x, 1, 2), data) + data = tree_map(lambda x: jnp.reshape(x, (-1,) + x.shape[2:]), data) + assert data.discount.shape[1:] == (unroll_length,) + # Compute meta loss + meta_loss, _ = meta_loss_fn( + params = training_state.agent_param, + normalizer_params = training_state.normalizer_param, + data = data, + rng = key_meta + ) + return meta_loss, (meta_param, training_state, env_state, new_key) + + def meta_training_step(meta_param, meta_optim_state, training_state, env_state, key): + # Compute meta_grad + meta_grad, (meta_param, training_state, env_state, key) = jax.grad(agent_update_and_meta_loss, has_aux=True)( + meta_param, training_state, env_state, key + ) + meta_grad = lax.pmean(meta_grad, axis_name=self._PMAP_AXIS_NAME) + # Update meta_param + meta_param_update, meta_optim_state = meta_optim.update(meta_grad, meta_optim_state) + meta_param = optax.apply_updates(meta_param, meta_param_update) + # Update training_state: optim_param + agent_optim_state = training_state.agent_optim_state + agent_optim_state = agent_optim_state.replace(optim_param=meta_param) + training_state = training_state.replace(agent_optim_state=agent_optim_state) + return meta_param, meta_optim_state, training_state, env_state, key + + pmap_meta_training_iteration = jax.pmap( + meta_training_step, + in_axes = (None, None, 0, 0, 0), + out_axes = (None, None, 0, 0, 0), + devices = jax.local_devices()[:self.local_devices_to_use], + axis_name = self._PMAP_AXIS_NAME + ) + + # Setup agent training reset + def reset_agent_training(training_state, env_state, reset_index, key, iter_num): + # Select the new one if iter_num % agent_reset_interval == reset_index + f_select = lambda n_s, o_s: lax.select(iter_num % self.agent_reset_interval == reset_index, n_s, o_s) + # Generate a new training_state and env_state + key_env, key_agent = jitted_split(key, 2) + new_training_state = get_training_state(key_agent) + key_envs = jitted_split(key_env, num_envs // self.process_count // self.local_devices_to_use) + new_env_state = jax.jit(env.reset)(key_envs) + # Select the new training_state + training_state = tree_map(f_select, new_training_state, training_state) + env_state = tree_map(f_select, new_env_state, env_state) + return training_state, env_state + + pmap_reset_training_state = jax.pmap( + reset_agent_training, + in_axes = (0, 0, 0, 0, None), + out_axes = (0, 0), + axis_name = self._PMAP_AXIS_NAME + ) + + # Start training + step_key, local_key = jitted_split(local_key) + step_keys = jitted_split(step_key, self.local_devices_to_use) + step_keys = self.core_reshape(jnp.stack(step_keys)) + for t in range(1, 
self.iterations+1): + self.start_time = time.time() + # Reset agent training: agent_param, hidden_state, env_state + key_reset, *reset_keys = jitted_split(key_reset, self.local_devices_to_use+1) + reset_keys = self.core_reshape(jnp.stack(reset_keys)) + training_states, env_states = pmap_reset_training_state(training_states, env_states, self.reset_indexes, reset_keys, t-1) + # Train for one iteration + meta_param, meta_optim_state, training_states, env_states, step_keys = pmap_meta_training_iteration( + meta_param, meta_optim_state, training_states, env_states, step_keys + ) + # Check NaN error + if jnp.any(jnp.isnan(pytree2array(meta_param))): + self.logger.info('NaN error detected!') + break + # Show log + if (t % self.cfg['display_interval'] == 0 or t == self.iterations) and self.process_id == 0: + speed = total_env_step_per_training_step / (time.time() - self.start_time) + eta = (self.iterations - t) * total_env_step_per_training_step / speed / 60 if speed>0 else -1 + self.logger.info(f'<{self.config_idx}> Iteration {t}/{self.iterations}, Speed={speed:.2f} (steps/s), ETA={eta:.2f} (mins)') + # Save meta param + if t == self.iterations: + save_model_param(meta_param, self.cfg['logs_dir']+'param.pickle') \ No newline at end of file diff --git a/agents/MetaPPOstar.py b/agents/MetaPPOstar.py new file mode 100644 index 0000000..9890459 --- /dev/null +++ b/agents/MetaPPOstar.py @@ -0,0 +1,385 @@ +# Copyright 2024 Garena Online Private Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time +import flax +import optax +import functools +from typing import Tuple, Any + +import jax +from jax import lax +import jax.numpy as jnp +from jax.tree_util import tree_map + +from brax import envs +from brax.training import acting, types +from brax.training.types import PRNGKey +from brax.training.agents.ppo import losses as ppo_losses +from brax.training.agents.ppo import networks as ppo_networks +from brax.training.acme import specs + +from components import running_statistics +from components.optim import set_optim +from components import star_gradients +from utils.helper import jitted_split, save_model_param, pytree2array +from agents.PPO import PPO +from agents.MetaPPO import TrainingState + + +class MetaPPOstar(PPO): + ''' + PPO for Brax with meta learned optimizer. 
+ ''' + def __init__(self, cfg): + super().__init__(cfg) + # Agent reset interval + self.agent_reset_interval = self.cfg['agent']['reset_interval'] + reset_indexes = [int(x) for x in jnp.linspace(0, self.agent_reset_interval-1, num=self.local_devices_to_use)] + self.reset_indexes = self.core_reshape(jnp.array(reset_indexes)) + + def train(self): + # Env + environment = self.env + num_timesteps = self.train_steps + episode_length = self.cfg['env']['episode_length'] + action_repeat = self.cfg['env']['action_repeat'] + reward_scaling = self.cfg['env']['reward_scaling'] + num_envs = self.cfg['env']['num_envs'] + normalize_observations = self.cfg['env']['normalize_obs'] + # Agent + network_factory = ppo_networks.make_ppo_networks + gae_lambda = self.cfg['agent']['gae_lambda'] + unroll_length = self.cfg['agent']['rollout_steps'] + num_minibatches = self.cfg['agent']['num_minibatches'] + clipping_epsilon = self.cfg['agent']['clipping_epsilon'] + update_epochs = self.cfg['agent']['update_epochs'] + entropy_cost = self.cfg['agent']['entropy_weight'] + normalize_advantage = True + # Meta learning + inner_updates = self.cfg['agent']['inner_updates'] + # Others + batch_size = self.cfg['batch_size'] + discounting = self.cfg['discount'] + seed = self.cfg['seed'] + + """PPO training.""" + device_count = self.local_devices_to_use * self.process_count + assert num_envs % device_count == 0 + assert batch_size * num_minibatches % num_envs == 0 + # The number of environment steps executed for every training step. + env_step_per_training_step = batch_size * unroll_length * num_minibatches * action_repeat + meta_env_step_per_training_step = max(batch_size // num_envs, 1) * num_envs * unroll_length * 1 * action_repeat + total_env_step_per_training_step = env_step_per_training_step * inner_updates + + # The number of training_step calls per training_epoch call. + self.iterations = num_timesteps // total_env_step_per_training_step + self.logger.info(f'meta_env_step_per_training_step = {meta_env_step_per_training_step}') + self.logger.info(f'total_env_step_per_training_step = {total_env_step_per_training_step}') + self.logger.info(f'total iterations = {self.iterations}') + + # Prepare keys + # key_networks should be global so that + # the initialized networks are the same for different processes. 
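The key-handling comment above (global vs. local keys) follows the standard multi-process JAX pattern: network-init keys come from a global key so every process builds identical parameters, while rollout keys are folded with the process index so collected data differs per process. A self-contained sketch using jax.random.split directly (the patch's jitted_split helper is assumed to be a jitted wrapper with the same semantics):

    import jax
    import jax.numpy as jnp

    key = jax.random.PRNGKey(0)
    global_key, local_key = jax.random.split(key)                   # identical on every process
    local_key = jax.random.fold_in(local_key, jax.process_index())  # now differs per process
    key_envs = jax.random.split(local_key, 8)                       # e.g. 8 env keys on this process
    key_envs = jnp.reshape(key_envs, (2, -1) + key_envs.shape[1:])  # (local_devices_to_use, envs per device, 2)
    print(key_envs.shape)                                           # (2, 4, 2)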
+ key = jax.random.PRNGKey(seed) + global_key, local_key = jitted_split(key) + local_key = jax.random.fold_in(local_key, self.process_id) + local_key, key_env, key_reset = jitted_split(local_key, 3) + key_agent_param, key_agent_optim, key_meta_optim = jitted_split(global_key, 3) + del key, global_key + key_envs = jitted_split(key_env, num_envs // self.process_count) + # Reshape to (local_devices_to_use, num_envs // process_count, 2) + key_envs = jnp.reshape(key_envs, (self.local_devices_to_use, -1) + key_envs.shape[1:]) + + # Set training and evaluation env + env = envs.training.wrap( + environment, + episode_length = episode_length, + action_repeat = action_repeat, + randomization_fn = None + ) + reset_fn = jax.pmap( + env.reset, + axis_name = self._PMAP_AXIS_NAME + ) + env_states = reset_fn(key_envs) + obs_shape = env_states.obs.shape + + # Set agent and meta optimizer + agent_optim = set_optim(self.cfg['agent_optim']['name'], self.cfg['agent_optim']['kwargs'], key_agent_optim) + meta_optim = set_optim(self.cfg['meta_optim']['name'], self.cfg['meta_optim']['kwargs'], key_meta_optim) + + # Set PPO network + if normalize_observations: + normalize = running_statistics.normalize + else: + normalize = lambda x, y: x + ppo_network = network_factory( + obs_shape[-1], + env.action_size, + preprocess_observations_fn = normalize, + policy_hidden_layer_sizes = (32,) * 4, + value_hidden_layer_sizes = (64,) * 5, + ) + make_policy = ppo_networks.make_inference_fn(ppo_network) + + # Set training states + def get_training_state(key): + key_policy, key_value = jitted_split(key) + agent_param = ppo_losses.PPONetworkParams( + policy = ppo_network.policy_network.init(key_policy), + value = ppo_network.value_network.init(key_value) + ) + training_state = TrainingState( + agent_optim_state = agent_optim.init(agent_param), + agent_param = agent_param, + normalizer_param = running_statistics.init_state(specs.Array(obs_shape[-1:], jnp.dtype('float32'))) + ) + return training_state + key_agents = jitted_split(key_agent_param, self.local_devices_to_use) + training_states = jax.pmap( + get_training_state, + axis_name = self._PMAP_AXIS_NAME + )(key_agents) + + # Set meta param and meta optim state + meta_param = agent_optim.get_optim_param() + meta_optim_state = meta_optim.init(meta_param) + + # Set loss function + agent_loss_fn = functools.partial( + ppo_losses.compute_ppo_loss, + ppo_network=ppo_network, + entropy_cost=entropy_cost, + discounting=discounting, + reward_scaling=reward_scaling, + gae_lambda=gae_lambda, + clipping_epsilon=clipping_epsilon, + normalize_advantage=normalize_advantage + ) + meta_loss_fn = agent_loss_fn + + # Set pmap_axis_name to None so we don't average agent grad over cores + agent_grad_update_fn = star_gradients.gradient_update_fn_with_optim_param(agent_loss_fn, agent_optim, pmap_axis_name=None, has_aux=True) + + def convert_data(x: jnp.ndarray, key: PRNGKey): + x = jax.random.permutation(key, x) + x = jnp.reshape(x, (num_minibatches, -1) + x.shape[1:]) + return x + + def minibatch_step( + carry, data: types.Transition, + normalizer_param: running_statistics.RunningStatisticsState + ): + meta_param, optim_state, agent_param, key = carry + key, key_loss = jitted_split(key) + (loss, _), agent_param, optim_state = agent_grad_update_fn( + agent_param, + normalizer_param, + data, + key_loss, + optim_param = meta_param, + optimizer_state = optim_state + ) + return (meta_param, optim_state, agent_param, key), None + + def sgd_step( + carry, + unused_t, + data: types.Transition, + 
normalizer_param: running_statistics.RunningStatisticsState + ): + meta_param, optim_state, agent_param, key = carry + key, key_perm, key_grad = jitted_split(key, 3) + shuffled_data = tree_map(functools.partial(convert_data, key=key_perm), data) + (meta_param, optim_state, agent_param, key_grad), _ = lax.scan( + f = functools.partial(minibatch_step, normalizer_param=normalizer_param), + init = (meta_param, optim_state, agent_param, key_grad), + xs = shuffled_data, + length = num_minibatches + ) + return (meta_param, optim_state, agent_param, key), None + + def training_step( + carry: Tuple[flax.core.FrozenDict, TrainingState, envs.State, PRNGKey], + unused_t + ) -> Tuple[Tuple[flax.core.FrozenDict, TrainingState, envs.State, PRNGKey], Any]: + meta_param, training_state, env_state, key = carry + key_sgd, key_generate_unroll, new_key = jitted_split(key, 3) + policy = make_policy((training_state.normalizer_param, training_state.agent_param.policy)) + # Set rollout function + def rollout(carry, unused_t): + current_state, current_key = carry + current_key, next_key = jitted_split(current_key) + next_state, data = acting.generate_unroll( + env, + current_state, + policy, + current_key, + unroll_length, + extra_fields = ('truncation',) + ) + return (next_state, next_key), data + # Rollout for `batch_size * num_minibatches * unroll_length` steps + (env_state, _), data = lax.scan( + f = rollout, + init = (env_state, key_generate_unroll), + length = batch_size * num_minibatches // num_envs, + xs = None + ) + # shape = (batch_size * num_minibatches // num_envs, unroll_length, num_envs, ...) + data = tree_map(lambda x: jnp.swapaxes(x, 1, 2), data) + # shape = (batch_size * num_minibatches // num_envs, num_envs, unroll_length, ...) + data = tree_map(lambda x: jnp.reshape(x, (-1,) + x.shape[2:]), data) + # shape = (batch_size * num_minibatches, unroll_length, ...) 
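The shape comments above compress several steps; the same rollout reshaping, spelled out on a tiny hypothetical pytree, looks like this:

    import jax.numpy as jnp
    from jax.tree_util import tree_map

    scan_steps, unroll_length, num_envs, obs_dim = 4, 5, 8, 3        # scan_steps = batch_size * num_minibatches // num_envs
    data = {'observation': jnp.zeros((scan_steps, unroll_length, num_envs, obs_dim))}
    data = tree_map(lambda x: jnp.swapaxes(x, 1, 2), data)           # (4, 8, 5, 3)
    data = tree_map(lambda x: jnp.reshape(x, (-1,) + x.shape[2:]), data)
    print(data['observation'].shape)                                 # (32, 5, 3) == (batch_size * num_minibatches, unroll_length, obs_dim)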
+ assert data.discount.shape[1:] == (unroll_length,) + # Update agent_param normalization + normalizer_param = running_statistics.update( + training_state.normalizer_param, + data.observation, + pmap_axis_name = self._PMAP_AXIS_NAME + ) + # SGD steps + (meta_param, agent_optim_state, agent_param, key_sgd), _ = lax.scan( + f = functools.partial(sgd_step, data=data, normalizer_param=normalizer_param), + init = (meta_param, training_state.agent_optim_state, training_state.agent_param, key_sgd), + length = update_epochs, + xs = None + ) + # Set the new training_state + new_training_state = TrainingState( + agent_optim_state = agent_optim_state, + agent_param = agent_param, + normalizer_param = normalizer_param + ) + return (meta_param, new_training_state, env_state, new_key), None + + def agent_update_and_meta_loss( + meta_param: flax.core.FrozenDict, + training_state: TrainingState, + env_state: envs.State, + key: PRNGKey + ) -> Tuple[jnp.ndarray, Tuple[flax.core.FrozenDict, TrainingState, envs.State, PRNGKey]]: + """Agent learning: update agent params""" + (meta_param, training_state, env_state, key), _ = lax.scan( + f = training_step, + init = (meta_param, training_state, env_state, key), + length = inner_updates, + xs = None + ) + """Meta learning: update meta params""" + # Gather data for meta learning + key_meta, key_generate_unroll, new_key = jitted_split(key, 3) + policy = make_policy((training_state.normalizer_param, training_state.agent_param.policy)) + # Set rollout function + def rollout(carry, unused_t): + current_state, current_key = carry + current_key, next_key = jitted_split(current_key) + next_state, data = acting.generate_unroll( + env, + current_state, + policy, + current_key, + unroll_length, + extra_fields = ('truncation',) + ) + return (next_state, next_key), data + # Rollout for `batch_size * unroll_length` steps + (env_state, _), data = lax.scan( + f = rollout, + init = (env_state, key_generate_unroll), + length = max(batch_size // num_envs, 1), + xs = None + ) + data = tree_map(lambda x: jnp.swapaxes(x, 1, 2), data) + data = tree_map(lambda x: jnp.reshape(x, (-1,) + x.shape[2:]), data) + assert data.discount.shape[1:] == (unroll_length,) + # Compute meta loss + meta_loss, _ = meta_loss_fn( + params = training_state.agent_param, + normalizer_params = training_state.normalizer_param, + data = data, + rng = key_meta + ) + return meta_loss, (meta_param, training_state, env_state, new_key) + + def meta_training_step(meta_param, meta_optim_state, training_state, env_state, key): + # Compute meta_grad + meta_grad, (meta_param, training_state, env_state, key) = jax.grad(agent_update_and_meta_loss, has_aux=True)( + meta_param, training_state, env_state, key + ) + meta_grad = lax.pmean(meta_grad, axis_name=self._PMAP_AXIS_NAME) + # Update meta_param + meta_param_update, meta_optim_state = meta_optim.update(meta_grad, meta_optim_state) + meta_param = optax.apply_updates(meta_param, meta_param_update) + # Update optim_param with meta_param outside this function + return meta_param, meta_optim_state, training_state, env_state, key + + pmap_meta_training_iteration = jax.pmap( + meta_training_step, + in_axes = (None, None, 0, 0, 0), + out_axes = (None, None, 0, 0, 0), + devices = jax.local_devices()[:self.local_devices_to_use], + axis_name = self._PMAP_AXIS_NAME + ) + + # Setup agent training reset + def reset_agent_training(training_state, env_state, reset_index, key, iter_num): + # Select the new one if iter_num % agent_reset_interval == reset_index + f_select = lambda n_s, o_s: 
lax.select(iter_num % self.agent_reset_interval == reset_index, n_s, o_s) + # Generate a new training_state and env_state + key_env, key_agent = jitted_split(key, 2) + new_training_state = get_training_state(key_agent) + key_envs = jitted_split(key_env, num_envs // self.process_count // self.local_devices_to_use) + new_env_state = jax.jit(env.reset)(key_envs) + # Select the new training_state + training_state = tree_map(f_select, new_training_state, training_state) + env_state = tree_map(f_select, new_env_state, env_state) + return training_state, env_state + + pmap_reset_training_state = jax.pmap( + reset_agent_training, + in_axes = (0, 0, 0, 0, None), + out_axes = (0, 0), + axis_name = self._PMAP_AXIS_NAME + ) + + # Start training + step_key, local_key = jitted_split(local_key) + step_keys = jitted_split(step_key, self.local_devices_to_use) + step_keys = self.core_reshape(jnp.stack(step_keys)) + for t in range(1, self.iterations+1): + self.start_time = time.time() + # Reset agent training: agent_param, hidden_state, env_state + key_reset, *reset_keys = jitted_split(key_reset, self.local_devices_to_use+1) + reset_keys = self.core_reshape(jnp.stack(reset_keys)) + training_states, env_states = pmap_reset_training_state(training_states, env_states, self.reset_indexes, reset_keys, t-1) + # Train for one iteration + meta_param, meta_optim_state, training_states, env_states, step_keys = pmap_meta_training_iteration( + meta_param, meta_optim_state, training_states, env_states, step_keys + ) + # Check NaN error + if jnp.any(jnp.isnan(pytree2array(meta_param))): + self.logger.info('NaN error detected!') + break + # Reset agent_optim with new meta_param + agent_optim.reset_optimizer(meta_param) + # Show log + if (t % self.cfg['display_interval'] == 0 or t == self.iterations) and self.process_id == 0: + speed = total_env_step_per_training_step / (time.time() - self.start_time) + eta = (self.iterations - t) * total_env_step_per_training_step / speed / 60 if speed>0 else -1 + self.logger.info(f'<{self.config_idx}> Iteration {t}/{self.iterations}, Speed={speed:.2f} (steps/s), ETA={eta:.2f} (mins)') + # Save meta param + if t == self.iterations: + save_model_param(meta_param, self.cfg['logs_dir']+'param.pickle') \ No newline at end of file diff --git a/agents/MetapA2C.py b/agents/MetapA2C.py new file mode 100644 index 0000000..c9ae57c --- /dev/null +++ b/agents/MetapA2C.py @@ -0,0 +1,34 @@ +# Copyright 2024 Garena Online Private Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import jax.numpy as jnp + +from agents.MetaA2C import MetaA2C + + +class MetapA2C(MetaA2C): + ''' + Implementation of Meta A2C *without* Pipeline Training: + We set reset_interval = 1 and reset_index = -1 such that + iter_num % reset_interval != reset_index, thus no reset. 
+ ''' + def __init__(self, cfg): + super().__init__(cfg) + # Reset all reset_indexes to -1 + del self.reset_intervals, self.reset_indexes + self.reset_intervals = [1] * len(self.env_names) + self.reset_indexes = [None] * self.task_num + for i in range(self.task_num): + reset_indexes = [-1] * self.num_envs + self.reset_indexes[i] = self.reshape(jnp.array(reset_indexes)) \ No newline at end of file diff --git a/agents/MetapPPO.py b/agents/MetapPPO.py new file mode 100644 index 0000000..d756f39 --- /dev/null +++ b/agents/MetapPPO.py @@ -0,0 +1,29 @@ +# Copyright 2024 Garena Online Private Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import jax.numpy as jnp + +from agents.MetaPPO import MetaPPO + + +class MetapPPO(MetaPPO): + ''' + PPO for Brax with meta learned optimizer without Pipeline Training. + ''' + def __init__(self, cfg): + super().__init__(cfg) + # Reset all reset_indexes to -1 + self.agent_reset_interval = -1 + reset_indexes = [1]*self.local_devices_to_use + self.reset_indexes = self.core_reshape(jnp.array(reset_indexes)) \ No newline at end of file diff --git a/agents/PPO.py b/agents/PPO.py index 1e17d94..f6c71c0 100644 --- a/agents/PPO.py +++ b/agents/PPO.py @@ -1,18 +1,4 @@ -# Copyright 2022 Garena Online Private Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Copyright 2022 The Brax Authors. +# Copyright 2023 The Brax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
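A quick check of the no-pipeline trick used by MetapA2C above (MetapPPO uses different constants to the same effect): the interval and index are chosen so that iter_num % reset_interval never equals reset_index, hence the lax.select in reset_agent_training always keeps the current state. Minimal sketch for the documented MetapA2C setting:

    # reset_interval = 1 and reset_index = -1, so iter_num % 1 == 0 never equals -1 and no reset fires.
    reset_interval, reset_index = 1, -1
    assert all((t % reset_interval) != reset_index for t in range(1000))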
@@ -30,380 +16,353 @@ See: https://arxiv.org/pdf/1707.06347.pdf """ -import functools import time +import flax +import optax +import functools +import numpy as np +import pandas as pd from typing import Tuple -import flax import jax +from jax import lax import jax.numpy as jnp -import numpy as np -import optax -import pandas as pd +from jax.tree_util import tree_map + from brax import envs -from brax import jumpy as jp -from brax.envs import wrappers from brax.training import acting, gradients, types +from brax.training.types import Params, PRNGKey, Metrics +from brax.training.agents.ppo import losses as ppo_losses +from brax.training.agents.ppo import networks as ppo_networks from brax.training.acme import specs -from jax import lax -from jax.tree_util import tree_map -import components.losses as ppo_losses -from components import ppo_networks, running_statistics -from components.optim import set_optimizer +from components import running_statistics +from components.optim import set_optim +from utils.helper import jitted_split from utils.logger import Logger -InferenceParams = Tuple[running_statistics.NestedMeanStd, types.Params] +InferenceParams = Tuple[running_statistics.NestedMeanStd, Params] @flax.struct.dataclass class TrainingState: - """Contains training state for the learner.""" - - agent_optim_state: optax.OptState - agent_param: ppo_losses.PPONetworkParams - normalizer_param: running_statistics.RunningStatisticsState - env_step: jnp.ndarray + """Contains training state for the learner.""" + agent_optim_state: optax.OptState + agent_param: ppo_losses.PPONetworkParams + normalizer_param: running_statistics.RunningStatisticsState + env_step: jnp.ndarray class PPO(object): - """ - Implementation of PPO for Brax, compatible with classical optimizers and learned optimizers (LinearOptim, Optim4RL, and L2LGD2). 
- """ - - def __init__(self, cfg): - self.cfg = cfg - self.config_idx = cfg["config_idx"] - self.logger = Logger(cfg["logs_dir"]) - self.log_path = cfg["logs_dir"] + "result_Test.feather" - self.result = [] - # Set environment - self.env_name = cfg["env"]["name"] - self.agent_name = cfg["agent"]["name"] - self.train_steps = int(cfg["env"]["train_steps"]) - self.env = envs.get_environment(env_name=self.env_name) - self.state = self.env.reset(rng=jp.random_prngkey(seed=self.cfg["seed"])) - # Timing - self.start_time = time.time() - self._PMAP_AXIS_NAME = "i" - - def save_progress(self, step_count, metrics): - episode_return = float(jax.device_get(metrics["eval/episode_reward"])) - result_dict = { - "Env": self.env_name, - "Agent": self.agent_name, - "Step": step_count, - "Return": episode_return, - } - self.result.append(result_dict) - # Save result to files - result = pd.DataFrame(self.result) - result["Env"] = result["Env"].astype("category") - result["Agent"] = result["Agent"].astype("category") - result.to_feather(self.log_path) - # Show log - speed = self.macro_step / (time.time() - self.start_time) - eta = (self.train_steps - step_count) / speed / 60 if speed > 0 else -1 - return episode_return, speed, eta - - def train(self): - # Env - env = self.env - num_timesteps = self.train_steps - episode_length = self.cfg["env"]["episode_length"] - action_repeat = self.cfg["env"]["action_repeat"] - reward_scaling = self.cfg["env"]["reward_scaling"] - num_envs = self.cfg["env"]["num_envs"] - num_evals = self.cfg["env"]["num_evals"] - num_eval_envs = 128 - normalize_observations = self.cfg["env"]["normalize_obs"] - # Agent - network_factory = ppo_networks.make_ppo_networks - gae_lambda = self.cfg["agent"]["gae_lambda"] - unroll_length = self.cfg["agent"]["rollout_steps"] - num_minibatches = self.cfg["agent"]["num_minibatches"] - clip_ratio = self.cfg["agent"]["clip_ratio"] - update_epochs = self.cfg["agent"]["update_epochs"] - entropy_cost = self.cfg["agent"]["entropy_weight"] - normalize_advantage = True - # Others - batch_size = self.cfg["batch_size"] - discounting = self.cfg["discount"] - max_devices_per_host = self.cfg["max_devices_per_host"] - seed = self.cfg["seed"] - eval_env = None - deterministic_eval = False - progress_fn = self.save_progress - - """PPO training.""" - process_id = jax.process_index() - process_count = jax.process_count() - total_device_count = jax.device_count() - local_device_count = jax.local_device_count() - if max_devices_per_host is not None and max_devices_per_host > 0: - local_devices_to_use = min(local_device_count, max_devices_per_host) - else: - local_devices_to_use = local_device_count - self.logger.info( - f"Total device: {total_device_count}, Process: {process_count} (ID {process_id})" - ) - self.logger.info( - f"Local device: {local_device_count}, Devices to be used: {local_devices_to_use}" - ) - device_count = local_devices_to_use * process_count - assert num_envs % device_count == 0 - assert batch_size * num_minibatches % num_envs == 0 - - # The number of environment steps executed for every training step. - env_step_per_training_step = ( - batch_size * unroll_length * num_minibatches * action_repeat - ) - num_evals = max(num_evals, 1) - # The number of training_step calls per training_epoch call. 
- num_training_steps_per_epoch = num_timesteps // ( - num_evals * env_step_per_training_step - ) - self.macro_step = num_training_steps_per_epoch * env_step_per_training_step - - # Prepare keys - # key_networks should be global so that - # the initialized networks are the same for different processes. - key = jax.random.PRNGKey(seed) - global_key, local_key = jax.random.split(key) - local_key = jax.random.fold_in(local_key, process_id) - local_key, key_env, eval_key = jax.random.split(local_key, 3) - key_policy, key_value, key_optim = jax.random.split(global_key, 3) - del key, global_key - key_envs = jax.random.split(key_env, num_envs // process_count) - key_envs = jnp.reshape( - key_envs, (local_devices_to_use, -1) + key_envs.shape[1:] + ''' + PPO for Brax. + ''' + def __init__(self, cfg): + self.cfg = cfg + self.config_idx = cfg['config_idx'] + self.logger = Logger(cfg['logs_dir']) + self.log_path = cfg['logs_dir'] + f'result_Test.feather' + self.result = [] + # Set environment + self.env_name = cfg['env']['name'] + self.agent_name = cfg['agent']['name'] + self.train_steps = int(cfg['env']['train_steps']) + backends = ['generalized', 'positional', 'spring'] + self.env = envs.get_environment(env_name=self.env_name, backend=backends[2]) + self.state = jax.jit(self.env.reset)(rng=jax.random.PRNGKey(seed=self.cfg['seed'])) + # Timing + self.start_time = time.time() + self._PMAP_AXIS_NAME = 'i' + + self.process_id = jax.process_index() + self.process_count = jax.process_count() + total_device_count = jax.device_count() + local_device_count = jax.local_device_count() + max_devices_per_host = self.cfg['max_devices_per_host'] + if max_devices_per_host is not None and max_devices_per_host > 0: + self.local_devices_to_use = min(local_device_count, max_devices_per_host) + else: + self.local_devices_to_use = local_device_count + self.core_reshape = lambda x: x.reshape((self.local_devices_to_use,) + x.shape[1:]) + self.logger.info(f'Total device: {total_device_count}, Process: {self.process_count} (ID {self.process_id})') + self.logger.info(f'Local device: {local_device_count}, Devices to be used: {self.local_devices_to_use}') + + def save_progress(self, step_count, metrics): + episode_return = float(jax.device_get(metrics['eval/episode_reward'])) + result_dict = { + 'Env': self.env_name, + 'Agent': self.agent_name, + 'Step': step_count, + 'Return': episode_return + } + self.result.append(result_dict) + # Save result to files + result = pd.DataFrame(self.result) + result['Env'] = result['Env'].astype('category') + result['Agent'] = result['Agent'].astype('category') + result.to_feather(self.log_path) + # Show log + speed = self.macro_step / (time.time() - self.start_time) + eta = (self.train_steps - step_count) / speed / 60 if speed>0 else -1 + return episode_return, speed, eta + + def train(self): + # Env + environment = self.env + num_timesteps = self.train_steps + episode_length = self.cfg['env']['episode_length'] + action_repeat = self.cfg['env']['action_repeat'] + reward_scaling = self.cfg['env']['reward_scaling'] + num_envs = self.cfg['env']['num_envs'] + num_evals = self.cfg['env']['num_evals'] + normalize_observations = self.cfg['env']['normalize_obs'] + # Agent + network_factory = ppo_networks.make_ppo_networks + gae_lambda = self.cfg['agent']['gae_lambda'] + unroll_length = self.cfg['agent']['rollout_steps'] + num_minibatches = self.cfg['agent']['num_minibatches'] + clipping_epsilon = self.cfg['agent']['clipping_epsilon'] + update_epochs = self.cfg['agent']['update_epochs'] + entropy_cost = 
self.cfg['agent']['entropy_weight'] + normalize_advantage = True + # Others + batch_size = self.cfg['batch_size'] + discounting = self.cfg['discount'] + seed = self.cfg['seed'] + deterministic_eval = False + progress_fn = self.save_progress + + """PPO training.""" + device_count = self.local_devices_to_use * self.process_count + assert num_envs % device_count == 0 + assert batch_size * num_minibatches % num_envs == 0 + # The number of environment steps executed for every training step. + env_step_per_training_step = batch_size * unroll_length * num_minibatches * action_repeat + num_evals = max(num_evals, 1) + # The number of training_step calls per training_epoch call. + num_training_steps_per_epoch = num_timesteps // (num_evals * env_step_per_training_step) + self.macro_step = num_training_steps_per_epoch * env_step_per_training_step + + # Prepare keys + # key_networks should be global so that + # the initialized networks are the same for different processes. + key = jax.random.PRNGKey(seed) + global_key, local_key = jitted_split(key) + local_key = jax.random.fold_in(local_key, self.process_id) + local_key, key_env, eval_key = jitted_split(local_key, 3) + key_policy, key_value, key_optim = jitted_split(global_key, 3) + del key, global_key + key_envs = jitted_split(key_env, num_envs // self.process_count) + # Reshape to (local_devices_to_use, num_envs//process_count, 2) + key_envs = jnp.reshape(key_envs, (self.local_devices_to_use, -1) + key_envs.shape[1:]) + + # Set training and evaluation env + env = envs.training.wrap( + environment, + episode_length = episode_length, + action_repeat = action_repeat, + randomization_fn = None + ) + eval_env = envs.training.wrap( + environment, + episode_length = episode_length, + action_repeat = action_repeat, + randomization_fn = None + ) + reset_fn = jax.pmap( + env.reset, + axis_name = self._PMAP_AXIS_NAME + ) + env_states = reset_fn(key_envs) + obs_shape = env_states.obs.shape # (local_devices_to_use, num_envs//process_count, ...) 
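For the rewritten PPO.train above, the evaluation-epoch accounting parallels the MetaPPO iteration count, except that it is driven by num_evals rather than inner updates. A worked example with hypothetical numbers (not taken from any config in this patch):

    num_timesteps, num_evals = 3 * 10**7, 10                        # hypothetical values
    batch_size, unroll_length, num_minibatches, action_repeat = 256, 5, 32, 1

    env_step_per_training_step = batch_size * unroll_length * num_minibatches * action_repeat  # 40,960
    num_training_steps_per_epoch = num_timesteps // (num_evals * env_step_per_training_step)   # 73
    macro_step = num_training_steps_per_epoch * env_step_per_training_step                     # 2,990,080 env steps between evaluations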
+ + # Set optimizer + optimizer = set_optim(self.cfg['optim']['name'], self.cfg['optim']['kwargs'], key_optim) + + # Set PPO network + if normalize_observations: + normalize = running_statistics.normalize + else: + normalize = lambda x, y: x + ppo_network = network_factory( + obs_shape[-1], + env.action_size, + preprocess_observations_fn = normalize, + ) + make_policy = ppo_networks.make_inference_fn(ppo_network) + agent_param = ppo_losses.PPONetworkParams( + policy = ppo_network.policy_network.init(key_policy), + value = ppo_network.value_network.init(key_value) + ) + training_state = TrainingState( + agent_optim_state = optimizer.init(agent_param), + agent_param = agent_param, + normalizer_param = running_statistics.init_state(specs.Array(obs_shape[-1:], jnp.dtype('float32'))), + env_step = 0 + ) + + # Set loss function + loss_fn = functools.partial( + ppo_losses.compute_ppo_loss, + ppo_network=ppo_network, + entropy_cost=entropy_cost, + discounting=discounting, + reward_scaling=reward_scaling, + gae_lambda=gae_lambda, + clipping_epsilon=clipping_epsilon, + normalize_advantage=normalize_advantage + ) + gradient_update_fn = gradients.gradient_update_fn(loss_fn, optimizer, pmap_axis_name=self._PMAP_AXIS_NAME, has_aux=True) + + def convert_data(x: jnp.ndarray, key: PRNGKey): + x = jax.random.permutation(key, x) + x = jnp.reshape(x, (num_minibatches, -1) + x.shape[1:]) + return x + + def minibatch_step( + carry, data: types.Transition, + normalizer_param: running_statistics.RunningStatisticsState + ): + agent_optim_state, agent_param, key = carry + key, key_loss = jitted_split(key) + (loss, _), agent_param, agent_optim_state = gradient_update_fn( + agent_param, + normalizer_param, + data, + key_loss, + optimizer_state = agent_optim_state + ) + return (agent_optim_state, agent_param, key), None + + def sgd_step( + carry, + unused_t, + data: types.Transition, + normalizer_param: running_statistics.RunningStatisticsState + ): + agent_optim_state, agent_param, key = carry + key, key_perm, key_grad = jitted_split(key, 3) + shuffled_data = tree_map(functools.partial(convert_data, key=key_perm), data) + (agent_optim_state, agent_param, key_grad), _ = lax.scan( + f = functools.partial(minibatch_step, normalizer_param=normalizer_param), + init = (agent_optim_state, agent_param, key_grad), + xs = shuffled_data, + length = num_minibatches + ) + return (agent_optim_state, agent_param, key), None + + def training_step( + carry: Tuple[TrainingState, envs.State, PRNGKey], + unused_t + ) -> Tuple[Tuple[TrainingState, envs.State, PRNGKey], Metrics]: + training_state, state, key = carry + key_sgd, key_generate_unroll, new_key = jitted_split(key, 3) + policy = make_policy((training_state.normalizer_param, training_state.agent_param.policy)) + # Set rollout function + def rollout(carry, unused_t): + current_state, current_key = carry + current_key, next_key = jitted_split(current_key) + next_state, data = acting.generate_unroll( + env, + current_state, + policy, + current_key, + unroll_length, + extra_fields = ('truncation',) ) - - # Set training and evaluation env - env = wrappers.wrap_for_training( - env, episode_length=episode_length, action_repeat=action_repeat - ) - reset_fn = jax.jit(jax.vmap(env.reset)) - env_states = reset_fn(key_envs) - if eval_env is None: - eval_env = env - else: - eval_env = wrappers.wrap_for_training( - eval_env, episode_length=episode_length, action_repeat=action_repeat - ) - - # Set optimizer - optimizer = set_optimizer( - self.cfg["optimizer"]["name"], 
self.cfg["optimizer"]["kwargs"], key_optim - ) - - # Set PPO network - if normalize_observations: - normalize = running_statistics.normalize - else: - normalize = lambda x, y: x - ppo_network = network_factory( - env.observation_size, - env.action_size, - preprocess_observations_fn=normalize, - ) - make_policy = ppo_networks.make_inference_fn(ppo_network) - agent_param = ppo_losses.PPONetworkParams( - policy=ppo_network.policy_network.init(key_policy), - value=ppo_network.value_network.init(key_value), + return (next_state, next_key), data + # Rollout for `batch_size * num_minibatches * unroll_length` steps + (state, _), data = lax.scan( + f = rollout, + init = (state, key_generate_unroll), + length = batch_size * num_minibatches // num_envs, + xs = None + ) + # shape = (batch_size * num_minibatches // num_envs, unroll_length, num_envs, ...) + data = tree_map(lambda x: jnp.swapaxes(x, 1, 2), data) + # shape = (batch_size * num_minibatches // num_envs, num_envs, unroll_length, ...) + data = tree_map(lambda x: jnp.reshape(x, (-1,) + x.shape[2:]), data) + # shape = (batch_size * num_minibatches, unroll_length, ...) + assert data.discount.shape[1:] == (unroll_length,) + # Update agent_param normalization + normalizer_param = running_statistics.update( + training_state.normalizer_param, + data.observation, + pmap_axis_name = self._PMAP_AXIS_NAME + ) + # SGD steps + (agent_optim_state, agent_param, key_sgd), _ = lax.scan( + f = functools.partial(sgd_step, data=data, normalizer_param=normalizer_param), + init = (training_state.agent_optim_state, training_state.agent_param, key_sgd), + length = update_epochs, + xs = None + ) + # Set the new training_state + new_training_state = TrainingState( + agent_optim_state = agent_optim_state, + agent_param = agent_param, + normalizer_param = normalizer_param, + env_step = training_state.env_step + env_step_per_training_step + ) + return (new_training_state, state, new_key), None + + def training_epoch( + training_state: TrainingState, + state: envs.State, + key: PRNGKey + ) -> Tuple[TrainingState, envs.State, Metrics]: + (training_state, state, key), _ = lax.scan( + f = training_step, + init = (training_state, state, key), + length = num_training_steps_per_epoch, + xs = None + ) + return training_state, state + + pmap_training_epoch = jax.pmap( + training_epoch, + in_axes = (None, 0, 0), + out_axes = (None, 0), + devices = jax.local_devices()[:self.local_devices_to_use], + axis_name = self._PMAP_AXIS_NAME + ) + + # Set evaluator + evaluator = acting.Evaluator( + eval_env, + functools.partial(make_policy, deterministic=deterministic_eval), + num_eval_envs = 128, + episode_length = episode_length, + action_repeat = action_repeat, + key = eval_key + ) + + # Run an initial evaluation + i, current_step = 0, 0 + if self.process_id == 0 and num_evals > 1: + metrics = evaluator.run_evaluation( + (training_state.normalizer_param, training_state.agent_param.policy), + training_metrics={} + ) + episode_return, _, _ = progress_fn(0, metrics) + self.logger.info(f'<{self.config_idx}> Iteration {i}/{num_evals}, Step {current_step}, Return={episode_return:.2f}') + + # Start training + for i in range(1, num_evals+1): + self.start_time = time.time() + epoch_key, local_key = jitted_split(local_key) + epoch_keys = jitted_split(epoch_key, self.local_devices_to_use) + # Train for one epoch + training_state, env_states = pmap_training_epoch(training_state, env_states, epoch_keys) + current_step = int(training_state.env_step) + # Run evaluation + if self.process_id == 0: + metrics 
= evaluator.run_evaluation( + (training_state.normalizer_param, training_state.agent_param.policy), + training_metrics={} ) - training_state = TrainingState( - agent_optim_state=optimizer.init(agent_param), - agent_param=agent_param, - normalizer_param=running_statistics.init_state( - specs.Array((env.observation_size,), jnp.float32) - ), - env_step=0, - ) - - # Set loss function - loss_fn = functools.partial( - ppo_losses.compute_ppo_loss, - ppo_network=ppo_network, - entropy_cost=entropy_cost, - discounting=discounting, - reward_scaling=reward_scaling, - gae_lambda=gae_lambda, - clip_ratio=clip_ratio, - normalize_advantage=normalize_advantage, - ) - gradient_update_fn = gradients.gradient_update_fn( - loss_fn, optimizer, pmap_axis_name=self._PMAP_AXIS_NAME, has_aux=True - ) - - def convert_data(x: jnp.ndarray, key: types.PRNGKey): - x = jax.random.permutation(key, x) - x = jnp.reshape(x, (num_minibatches, -1) + x.shape[1:]) - return x - - def minibatch_step( - carry, - data: types.Transition, - normalizer_param: running_statistics.RunningStatisticsState, - ): - agent_optim_state, agent_param, key = carry - key, key_loss = jax.random.split(key) - (loss, _), agent_param, agent_optim_state = gradient_update_fn( - agent_param, - normalizer_param, - data, - key_loss, - optimizer_state=agent_optim_state, - ) - return (agent_optim_state, agent_param, key), None - - def sgd_step( - carry, - unused_t, - data: types.Transition, - normalizer_param: running_statistics.RunningStatisticsState, - ): - agent_optim_state, agent_param, key = carry - key, key_perm, key_grad = jax.random.split(key, 3) - shuffled_data = tree_map( - functools.partial(convert_data, key=key_perm), data - ) - (agent_optim_state, agent_param, key_grad), _ = lax.scan( - f=functools.partial(minibatch_step, normalizer_param=normalizer_param), - init=(agent_optim_state, agent_param, key_grad), - xs=shuffled_data, - length=num_minibatches, - ) - return (agent_optim_state, agent_param, key), None - - def training_step( - carry: Tuple[TrainingState, envs.State, types.PRNGKey], unused_t - ) -> Tuple[Tuple[TrainingState, envs.State, types.PRNGKey], types.Metrics]: - training_state, state, key = carry - key_sgd, key_generate_unroll, new_key = jax.random.split(key, 3) - policy = make_policy( - (training_state.normalizer_param, training_state.agent_param.policy) - ) - - # Set rollout function - def rollout(carry, unused_t): - current_state, current_key = carry - current_key, next_key = jax.random.split(current_key) - next_state, data = acting.generate_unroll( - env, - current_state, - policy, - current_key, - unroll_length, - extra_fields=("truncation",), - ) - return (next_state, next_key), data - - # Rollout for `batch_size * num_minibatches * unroll_length` steps - (state, _), data = lax.scan( - f=rollout, - init=(state, key_generate_unroll), - xs=None, - length=batch_size * num_minibatches // num_envs, - ) - # shape = (batch_size * num_minibatches // num_envs, unroll_length, num_envs) - data = tree_map(lambda x: jnp.swapaxes(x, 1, 2), data) - # shape = (batch_size * num_minibatches // num_envs, num_envs, unroll_length) - data = tree_map(lambda x: jnp.reshape(x, (-1,) + x.shape[2:]), data) - # shape = (batch_size * num_minibatches, unroll_length) - assert data.discount.shape[1:] == (unroll_length,) - # Update normalization agent_param and normalize observations. 
- normalizer_param = running_statistics.update( - training_state.normalizer_param, - data.observation, - pmap_axis_name=self._PMAP_AXIS_NAME, - ) - # SGD steps - (agent_optim_state, agent_param, key_sgd), _ = lax.scan( - f=functools.partial( - sgd_step, data=data, normalizer_param=normalizer_param - ), - init=( - training_state.agent_optim_state, - training_state.agent_param, - key_sgd, - ), - xs=None, - length=update_epochs, - ) - # Set the new training state - new_training_state = TrainingState( - agent_optim_state=agent_optim_state, - agent_param=agent_param, - normalizer_param=normalizer_param, - env_step=training_state.env_step + env_step_per_training_step, - ) - return (new_training_state, state, new_key), None - - def training_epoch( - training_state: TrainingState, state: envs.State, key: types.PRNGKey - ) -> Tuple[TrainingState, envs.State, types.Metrics]: - (training_state, state, key), _ = lax.scan( - f=training_step, - init=(training_state, state, key), - xs=None, - length=num_training_steps_per_epoch, - ) - return training_state, state - - pmap_training_epoch = jax.pmap( - training_epoch, - in_axes=(None, 0, 0), - out_axes=(None, 0), - devices=jax.local_devices()[:local_devices_to_use], - axis_name=self._PMAP_AXIS_NAME, - ) - - # Set evaluator - evaluator = acting.Evaluator( - eval_env, - functools.partial(make_policy, deterministic=deterministic_eval), - num_eval_envs=num_eval_envs, - episode_length=episode_length, - action_repeat=action_repeat, - key=eval_key, - ) - - # Run an initial evaluation - i, current_step = 0, 0 - if process_id == 0 and num_evals > 1: - metrics = evaluator.run_evaluation( - (training_state.normalizer_param, training_state.agent_param.policy), - training_metrics={}, - ) - episode_return, _, _ = progress_fn(0, metrics) - self.logger.info( - f"<{self.config_idx}> Iteration {i}/{num_evals}, Step {current_step}, Return={episode_return:.2f}" - ) - - # Start training - for i in range(1, num_evals + 1): - self.start_time = time.time() - epoch_key, local_key = jax.random.split(local_key) - epoch_keys = jax.random.split(epoch_key, local_devices_to_use) - # Train for one epoch - training_state, env_states = pmap_training_epoch( - training_state, env_states, epoch_keys - ) - current_step = int(training_state.env_step) - # Run evaluation - if process_id == 0: - metrics = evaluator.run_evaluation( - ( - training_state.normalizer_param, - training_state.agent_param.policy, - ), - training_metrics={}, - ) - episode_return, speed, eta = progress_fn(current_step, metrics) - self.logger.info( - f"<{self.config_idx}> Iteration {i}/{num_evals}, Step {current_step}, Return={episode_return:.2f}, Speed={speed:.2f} (steps/s), ETA={eta:.2f} (mins)" - ) - if np.isnan(episode_return): - self.logger.info("NaN error detected!") - break + episode_return, speed, eta = progress_fn(current_step, metrics) + self.logger.info(f'<{self.config_idx}> Iteration {i}/{num_evals}, Step {current_step}, Return={episode_return:.2f}, Speed={speed:.2f} (steps/s), ETA={eta:.2f} (mins)') + if np.isnan(episode_return): + self.logger.info('NaN error detected!') + break \ No newline at end of file diff --git a/agents/PPOstar.py b/agents/PPOstar.py new file mode 100644 index 0000000..b1a5911 --- /dev/null +++ b/agents/PPOstar.py @@ -0,0 +1,306 @@ +# Copyright 2024 Garena Online Private Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time +import functools +import numpy as np +from typing import Tuple + +import jax +from jax import lax +import jax.numpy as jnp +from jax.tree_util import tree_map + +from brax import envs +from brax.training import acting, types +from brax.training.types import PRNGKey, Metrics +from brax.training.agents.ppo import losses as ppo_losses +from brax.training.agents.ppo import networks as ppo_networks +from brax.training.acme import specs + +from components import running_statistics +from components import star_gradients +from components.optim import set_optim +from utils.helper import jitted_split + +from agents.PPO import PPO, TrainingState + + +class PPOstar(PPO): + ''' + PPO for Brax for STAR optimizer only. + ''' + def __init__(self, cfg): + super().__init__(cfg) + + def train(self): + # Env + environment = self.env + num_timesteps = self.train_steps + episode_length = self.cfg['env']['episode_length'] + action_repeat = self.cfg['env']['action_repeat'] + reward_scaling = self.cfg['env']['reward_scaling'] + num_envs = self.cfg['env']['num_envs'] + num_evals = self.cfg['env']['num_evals'] + normalize_observations = self.cfg['env']['normalize_obs'] + # Agent + network_factory = ppo_networks.make_ppo_networks + gae_lambda = self.cfg['agent']['gae_lambda'] + unroll_length = self.cfg['agent']['rollout_steps'] + num_minibatches = self.cfg['agent']['num_minibatches'] + clipping_epsilon = self.cfg['agent']['clipping_epsilon'] + update_epochs = self.cfg['agent']['update_epochs'] + entropy_cost = self.cfg['agent']['entropy_weight'] + normalize_advantage = True + # Others + batch_size = self.cfg['batch_size'] + discounting = self.cfg['discount'] + seed = self.cfg['seed'] + deterministic_eval = False + progress_fn = self.save_progress + + """PPO training.""" + device_count = self.local_devices_to_use * self.process_count + assert num_envs % device_count == 0 + assert batch_size * num_minibatches % num_envs == 0 + # The number of environment steps executed for every training step. + env_step_per_training_step = batch_size * unroll_length * num_minibatches * action_repeat + num_evals = max(num_evals, 1) + # The number of training_step calls per training_epoch call. + num_training_steps_per_epoch = num_timesteps // (num_evals * env_step_per_training_step) + self.macro_step = num_training_steps_per_epoch * env_step_per_training_step + + # Prepare keys + # key_networks should be global so that + # the initialized networks are the same for different processes. 
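To make the comment above concrete, here is a minimal, self-contained sketch of the global-vs-local key convention used in the lines that follow. It is an illustration only: `make_keys` and its `process_id` argument are stand-ins introduced here (the hunk itself uses the repo's `jitted_split` helper and `self.process_id`), and plain `jax.random.split` is used in their place.

```python
import jax

def make_keys(seed: int, process_id: int):
    key = jax.random.PRNGKey(seed)
    global_key, local_key = jax.random.split(key)          # same on every process
    local_key = jax.random.fold_in(local_key, process_id)  # decorrelated per process
    # Network-init keys come from the shared global key, so all processes
    # initialize identical networks.
    key_policy, key_value, key_optim = jax.random.split(global_key, 3)
    return (key_policy, key_value, key_optim), local_key

# Two processes share the network keys but get distinct local (rollout/eval) keys:
net_keys_0, local_0 = make_keys(42, process_id=0)
net_keys_1, local_1 = make_keys(42, process_id=1)
assert all((a == b).all() for a, b in zip(net_keys_0, net_keys_1))
assert not (local_0 == local_1).all()
```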
+ key = jax.random.PRNGKey(seed) + global_key, local_key = jitted_split(key) + local_key = jax.random.fold_in(local_key, self.process_id) + local_key, key_env, eval_key = jitted_split(local_key, 3) + key_policy, key_value, key_optim = jitted_split(global_key, 3) + del key, global_key + key_envs = jitted_split(key_env, num_envs // self.process_count) + # Reshape to (local_devices_to_use, num_envs//process_count, 2) + key_envs = jnp.reshape(key_envs, (self.local_devices_to_use, -1) + key_envs.shape[1:]) + + # Set training and evaluation env + env = envs.training.wrap( + environment, + episode_length = episode_length, + action_repeat = action_repeat, + randomization_fn = None + ) + eval_env = envs.training.wrap( + environment, + episode_length = episode_length, + action_repeat = action_repeat, + randomization_fn = None + ) + reset_fn = jax.pmap( + env.reset, + axis_name = self._PMAP_AXIS_NAME + ) + env_states = reset_fn(key_envs) + obs_shape = env_states.obs.shape # (local_devices_to_use, num_envs//process_count, ...) + + # Set optimizer + optimizer = set_optim(self.cfg['optim']['name'], self.cfg['optim']['kwargs'], key_optim) + + # Set PPO network + if normalize_observations: + normalize = running_statistics.normalize + else: + normalize = lambda x, y: x + ppo_network = network_factory( + obs_shape[-1], + env.action_size, + preprocess_observations_fn = normalize, + ) + make_policy = ppo_networks.make_inference_fn(ppo_network) + agent_param = ppo_losses.PPONetworkParams( + policy = ppo_network.policy_network.init(key_policy), + value = ppo_network.value_network.init(key_value) + ) + training_state = TrainingState( + agent_optim_state = optimizer.init(agent_param), + agent_param = agent_param, + normalizer_param = running_statistics.init_state(specs.Array(obs_shape[-1:], jnp.dtype('float32'))), + env_step = 0 + ) + + # Set loss function + loss_fn = functools.partial( + ppo_losses.compute_ppo_loss, + ppo_network=ppo_network, + entropy_cost=entropy_cost, + discounting=discounting, + reward_scaling=reward_scaling, + gae_lambda=gae_lambda, + clipping_epsilon=clipping_epsilon, + normalize_advantage=normalize_advantage + ) + gradient_update_fn = star_gradients.gradient_update_fn(loss_fn, optimizer, pmap_axis_name=self._PMAP_AXIS_NAME, has_aux=True) + + def convert_data(x: jnp.ndarray, key: PRNGKey): + x = jax.random.permutation(key, x) + x = jnp.reshape(x, (num_minibatches, -1) + x.shape[1:]) + return x + + def minibatch_step( + carry, data: types.Transition, + normalizer_param: running_statistics.RunningStatisticsState + ): + agent_optim_state, agent_param, key = carry + key, key_loss = jitted_split(key) + (loss, _), agent_param, agent_optim_state = gradient_update_fn( + agent_param, + normalizer_param, + data, + key_loss, + optimizer_state = agent_optim_state + ) + return (agent_optim_state, agent_param, key), None + + def sgd_step( + carry, + unused_t, + data: types.Transition, + normalizer_param: running_statistics.RunningStatisticsState + ): + agent_optim_state, agent_param, key = carry + key, key_perm, key_grad = jitted_split(key, 3) + shuffled_data = tree_map(functools.partial(convert_data, key=key_perm), data) + (agent_optim_state, agent_param, key_grad), _ = lax.scan( + f = functools.partial(minibatch_step, normalizer_param=normalizer_param), + init = (agent_optim_state, agent_param, key_grad), + xs = shuffled_data, + length = num_minibatches + ) + return (agent_optim_state, agent_param, key), None + + def training_step( + carry: Tuple[TrainingState, envs.State, PRNGKey], + unused_t + ) -> 
Tuple[Tuple[TrainingState, envs.State, PRNGKey], Metrics]: + training_state, state, key = carry + key_sgd, key_generate_unroll, new_key = jitted_split(key, 3) + policy = make_policy((training_state.normalizer_param, training_state.agent_param.policy)) + # Set rollout function + def rollout(carry, unused_t): + current_state, current_key = carry + current_key, next_key = jitted_split(current_key) + next_state, data = acting.generate_unroll( + env, + current_state, + policy, + current_key, + unroll_length, + extra_fields = ('truncation',) + ) + return (next_state, next_key), data + # Rollout for `batch_size * num_minibatches * unroll_length` steps + (state, _), data = lax.scan( + f = rollout, + init = (state, key_generate_unroll), + length = batch_size * num_minibatches // num_envs, + xs = None + ) + # shape = (batch_size * num_minibatches // num_envs, unroll_length, num_envs, ...) + data = tree_map(lambda x: jnp.swapaxes(x, 1, 2), data) + # shape = (batch_size * num_minibatches // num_envs, num_envs, unroll_length, ...) + data = tree_map(lambda x: jnp.reshape(x, (-1,) + x.shape[2:]), data) + # shape = (batch_size * num_minibatches, unroll_length, ...) + assert data.discount.shape[1:] == (unroll_length,) + # Update agent_param normalization + normalizer_param = running_statistics.update( + training_state.normalizer_param, + data.observation, + pmap_axis_name = self._PMAP_AXIS_NAME + ) + # SGD steps + (agent_optim_state, agent_param, key_sgd), _ = lax.scan( + f = functools.partial(sgd_step, data=data, normalizer_param=normalizer_param), + init = (training_state.agent_optim_state, training_state.agent_param, key_sgd), + length = update_epochs, + xs = None + ) + # Set the new training_state + new_training_state = TrainingState( + agent_optim_state = agent_optim_state, + agent_param = agent_param, + normalizer_param = normalizer_param, + env_step = training_state.env_step + env_step_per_training_step + ) + return (new_training_state, state, new_key), None + + def training_epoch( + training_state: TrainingState, + state: envs.State, + key: PRNGKey + ) -> Tuple[TrainingState, envs.State, Metrics]: + (training_state, state, key), _ = lax.scan( + f = training_step, + init = (training_state, state, key), + length = num_training_steps_per_epoch, + xs = None + ) + return training_state, state + + pmap_training_epoch = jax.pmap( + training_epoch, + in_axes = (None, 0, 0), + out_axes = (None, 0), + devices = jax.local_devices()[:self.local_devices_to_use], + axis_name = self._PMAP_AXIS_NAME + ) + + # Set evaluator + evaluator = acting.Evaluator( + eval_env, + functools.partial(make_policy, deterministic=deterministic_eval), + num_eval_envs = 128, + episode_length = episode_length, + action_repeat = action_repeat, + key = eval_key + ) + + # Run an initial evaluation + i, current_step = 0, 0 + if self.process_id == 0 and num_evals > 1: + metrics = evaluator.run_evaluation( + (training_state.normalizer_param, training_state.agent_param.policy), + training_metrics={} + ) + episode_return, _, _ = progress_fn(0, metrics) + self.logger.info(f'<{self.config_idx}> Iteration {i}/{num_evals}, Step {current_step}, Return={episode_return:.2f}') + + # Start training + for i in range(1, num_evals+1): + self.start_time = time.time() + epoch_key, local_key = jitted_split(local_key) + epoch_keys = jitted_split(epoch_key, self.local_devices_to_use) + # Train for one epoch + training_state, env_states = pmap_training_epoch(training_state, env_states, epoch_keys) + current_step = int(training_state.env_step) + # Run 
evaluation + if self.process_id == 0: + metrics = evaluator.run_evaluation( + (training_state.normalizer_param, training_state.agent_param.policy), + training_metrics={} + ) + episode_return, speed, eta = progress_fn(current_step, metrics) + self.logger.info(f'<{self.config_idx}> Iteration {i}/{num_evals}, Step {current_step}, Return={episode_return:.2f}, Speed={speed:.2f} (steps/s), ETA={eta:.2f} (mins)') + if np.isnan(episode_return): + self.logger.info('NaN error detected!') + break \ No newline at end of file diff --git a/agents/RNNIndentity.py b/agents/RNNIndentity.py deleted file mode 100644 index 980dbde..0000000 --- a/agents/RNNIndentity.py +++ /dev/null @@ -1,207 +0,0 @@ -# Copyright 2022 Garena Online Private Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import pickle -import time - -import jax -import jax.numpy as jnp -import numpy as np -import optax -import pandas as pd -from jax import lax, random, tree_util - -from components import network -from components.optim import set_optimizer -from utils.logger import Logger - - -class RNNIndentity(object): - """ " - Train an RNN to approximate the identity function with agent gradients as input. - """ - - def __init__(self, cfg): - self.agent_name = cfg["meta_net"]["name"] - self.cfg = cfg - self.config_idx = cfg["config_idx"] - self.logger = Logger(cfg["logs_dir"]) - self.epoch = int(self.cfg["epoch"]) - self.log_path = cfg["logs_dir"] + "result_Train.feather" - self.clip_ratio = 0.1 - # Create model - self.model = self.create_meta_net() - # Set optimizer - self.optimizer = set_optimizer( - cfg["optimizer"]["name"], cfg["optimizer"]["kwargs"], None - ) - # Load data - self.seed = random.PRNGKey(self.cfg["seed"]) - if len(self.cfg["datapath"]) > 0: - self.batches = self.load_data(self.cfg["seq_len"], self.cfg["datapath"]) - else: - self.batches = self.load_uniform_data(self.cfg["seq_len"]) - - def create_meta_net(self): - self.cfg["meta_net"]["mlp_dims"] = tuple(self.cfg["meta_net"]["mlp_dims"]) - return getattr(network, self.cfg["meta_net"]["name"])(**self.cfg["meta_net"]) - - def load_data(self, seq_len, datapath): - npzfile = np.load(datapath) - xs = npzfile["x"] # shape=(batch_size, len) - self.batch_size, self.num_batch = xs.shape[0], xs.shape[1] // seq_len - self.logger.info( - f"dataset size: {xs.shape}, batch_size: {self.batch_size}, num_batch: {self.num_batch}" - ) - batches = [] # shape=(num_batch, batch_size, seq_len) - for i in range(self.num_batch): - start = i * seq_len - x = xs[:, start : start + seq_len] - batches.append([x, x]) - batches = np.array(batches) - return jax.device_put(batches) - - def load_uniform_data(self, seq_len): - self.batch_size, len, max_steps = 395, 3000, 500 - self.num_batch = len // seq_len - # Generate random data in [-1,1] - xs = [] - for i in range(len // max_steps): - seed, self.seed = random.split(self.seed) - x = random.uniform(seed, (self.batch_size, max_steps), minval=i-1, maxval=i+1) - xs.append(x) - xs = jnp.concatenate(xs, axis=-1) - self.logger.info(f"dataset size: 
{xs.shape}, batch_size: {self.batch_size}, num_batch: {self.num_batch}") - batches = [] # shape=(num_batch, batch_size, seq_len) - for i in range(self.num_batch): - start = i * seq_len - x = xs[:, start:start+seq_len] - batches.append([x,x]) - batches = np.array(batches) - return jax.device_put(batches) - - def compute_loss(self, param, hidden_state, batch): - x, y = batch[0], batch[1] - hidden_state, pred_y = lax.scan( - f=lambda hidden, x_in: self.model.apply(param, hidden, x_in), - init=hidden_state, - xs=x, - ) - loss = jnp.mean(jnp.square(pred_y - y)) - # Compute the accuracy of pred_y in y*(1+/-clip_ratio) - mask = ( - (y >= 0) - & (pred_y >= (1 - self.clip_ratio) * y) - & (pred_y <= (1 + self.clip_ratio) * y) - ) - mask = mask | (y < 0) & (pred_y <= (1 - self.clip_ratio) * y) & ( - pred_y >= (1 + self.clip_ratio) * y - ) - perf = jnp.mean(mask) - return loss, (lax.stop_gradient(perf), hidden_state) - - def train_step(self, param, hidden_state, optim_state, batch): - (loss, (perf, hidden_state)), grad = jax.value_and_grad( - self.compute_loss, has_aux=True - )(param, hidden_state, batch) - # Reduce mean gradient and mean loss across batch - grad = lax.pmean(grad, axis_name="batch") - loss = lax.pmean(loss, axis_name="batch") - perf = lax.pmean(perf, axis_name="batch") - param_update, optim_state = self.optimizer.update(grad, optim_state) - param = optax.apply_updates(param, param_update) - return param, hidden_state, optim_state, loss, perf - - def train(self): - # Initialize model parameter - model_seed, self.seed = random.split(self.seed) - dummy_input = jnp.array([0.0]) - dummy_hidden_state = self.model.init_hidden_state(dummy_input) - param = self.model.init(model_seed, dummy_hidden_state, dummy_input) - - # Set optimizer state - optim_state = self.optimizer.init(param) - # Start training - batched_train_step = jax.vmap( - self.train_step, - in_axes=(None, 0, None, 1), - out_axes=(None, 0, None, None, None), - axis_name="batch", - ) - loss_list, perf_list = [], [] - start_time = time.time() - self.best_perf = 0.0 - - def f_loop(carry, batch): - param, hidden_state, optim_state = carry - param, hidden_state, optim_state, loss, perf = batched_train_step( - param, hidden_state, optim_state, batch - ) - carry = (param, hidden_state, optim_state) - logs = dict(loss=loss, perf=perf) - return carry, logs - - dummy_input = jnp.zeros((self.batch_size,)) - for i in range(1, self.epoch + 1): - hidden_state = self.model.init_hidden_state(dummy_input) - carry, logs = lax.scan( - f=f_loop, init=(param, hidden_state, optim_state), xs=self.batches - ) - param, hidden_state, optim_state = carry - epoch_loss = jnp.mean(logs["loss"]) - epoch_perf = jnp.mean(logs["perf"]) - loss_list.append(epoch_loss) - perf_list.append(epoch_perf) - if epoch_perf > self.best_perf: - self.best_perf = epoch_perf - if self.cfg["save_param"]: - self.save_model_param(param, self.cfg["logs_dir"] + "param.pickle") - if i % self.cfg["display_interval"] == 0: - speed = (time.time() - start_time) / i - eta = (self.epoch - i) * speed / 60 if speed > 0 else -1 - self.logger.info( - f"<{self.config_idx}> Epoch {i}/{self.epoch}: Loss={epoch_loss:.8f}, Perf={epoch_perf:.8f}, Speed={speed:.2f} (s/epoch), ETA={eta:.2f} (mins)" - ) - if self.best_perf < 0.01 and i >= 4: - self.logger.info( - f"Early stop at epoch {i} due to bad best performance: {self.best_perf:.8f}." 
- ) - break - self.logger.info(f"Best performance: {self.best_perf:.8f}.") - self.save_logs(loss_list, perf_list) - - def save_logs(self, loss_list, perf_list): - loss_list = np.array(jax.device_get(loss_list)) - perf_list = np.array(jax.device_get(perf_list)) - result = { - "Agent": self.agent_name, - "Epoch": np.array(range(len(loss_list))), - "Loss": loss_list, - "Perf": perf_list, - } - result = pd.DataFrame(result) - result["Agent"] = result["Agent"].astype("category") - result.to_feather(self.log_path) - - def save_model_param(self, model_param, filepath): - f = open(filepath, "wb") - pickle.dump(model_param, f) - f.close() - - def load_model_param(self, filepath): - f = open(filepath, "rb") - model_param = pickle.load(f) - model_param = tree_util.tree_map(jnp.array, model_param) - f.close() - return model_param diff --git a/agents/SLCollect.py b/agents/SLCollect.py new file mode 100644 index 0000000..4f13fad --- /dev/null +++ b/agents/SLCollect.py @@ -0,0 +1,194 @@ +# Copyright 2024 Garena Online Private Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import jax +from jax import random +import jax.numpy as jnp + +import time +import optax +import numpy as np +import pandas as pd +import matplotlib.pyplot as plt +from functools import partial +from flax.training.train_state import TrainState + +from components import network +from utils.logger import Logger +from utils.helper import pytree2array +from components.optim import set_optim +from utils.dataloader import load_data + + +class SLCollect(object): + """ + Classification task. 
+ """ + def __init__(self, cfg): + self.cfg = cfg + self.config_idx = cfg['config_idx'] + self.logger = Logger(cfg['logs_dir']) + self.task = cfg['task'] + self.model_name = self.cfg['model']['name'] + self.seed = random.PRNGKey(self.cfg['seed']) + self.log_path = { + 'Train': cfg['logs_dir'] + 'result_Train.feather', + 'Test': cfg['logs_dir'] + 'result_Test.feather' + } + self.results = {'Train': [], 'Test': []} + try: + self.output_dim = cfg['model']['kwargs']['output_dim'] + except: + self.output_dim = 10 + + def createNN(self, model, model_cfg): + NN = getattr(network, model)(**model_cfg) + return NN + + def train(self): + self.logger.info(f'Load dataset: {self.task}') + self.seed, data_seed = random.split(self.seed) + self.data = load_data(dataset=self.task, seed=data_seed, batch_size=self.cfg['batch_size']) + for mode in ['Train', 'Test']: + self.logger.info(f'Datasize [{mode}]: {len(self.data[mode]["y"])}') + self.logger.info('Create train state ...') + self.logger.info('Create train state: build neural network model') + model = self.createNN(self.model_name, self.cfg['model']['kwargs']) + self.seed, nn_seed, optim_seed = random.split(self.seed, 3) + params = model.init(nn_seed, self.data['dummy_input']) + self.logger.info('Create train state: set optimzer') + optim = set_optim(self.cfg['optimizer']['name'], self.cfg['optimizer']['kwargs'], optim_seed) + self.state = TrainState.create( + apply_fn = jax.jit(model.apply), + params = params, + tx = optim + ) + self.loss_fn = jax.jit(self.compute_loss) + + mode='Train' + nan_error = False + data_size = len(self.data[mode]['x']) + batch_num = data_size // self.cfg['batch_size'] + + self.logger.info('Start training ...') + all_grad = [] + for epoch in range(1, self.cfg['epochs']+1): + epoch_start_time = time.time() + """Train for a single epoch.""" + self.seed, seed = random.split(self.seed) + perms = random.permutation(seed, data_size) + perms = perms[:batch_num * self.cfg['batch_size']] # Skip incomplete batch + perms = perms.reshape((batch_num, self.cfg['batch_size'])) + epoch_loss, epoch_perf = [], [] + for perm in perms: + batch = { + 'x': self.data[mode]['x'][perm, ...], + 'y': self.data[mode]['y'][perm, ...] 
+ } + # Forward: compute loss, performance, and gradient + (loss, perf), grads = jax.value_and_grad(self.loss_fn, has_aux=True)( + self.state.params, + self.state, + batch + ) + # Backward: update train state + self.state = self.update_state(self.state, grads) + # Log + loss = float(jax.device_get(loss)) + perf = float(jax.device_get(perf)) + grad = pytree2array(grads) + idxs = jnp.array(range(0, len(grad), self.cfg['agent']['data_reduce'])) + grad = jax.device_get(grad[idxs]) + # Check NaN error + if np.isnan(loss) or np.isnan(perf): + nan_error = True + self.logger.info("NaN error detected!") + break + epoch_loss.append(loss) + epoch_perf.append(perf) + all_grad.append(grad) + if nan_error: + break + epoch_loss = np.mean(epoch_loss) + epoch_perf = np.mean(epoch_perf) + # Save training result + self.save_results(mode, epoch, epoch_loss, epoch_perf) + # Display speed + if (epoch % self.cfg['display_interval'] == 0) or (epoch == self.cfg['epochs']): + speed = time.time() - epoch_start_time + eta = (self.cfg['epochs'] - epoch) * speed / 60 if speed > 0 else -1 + self.logger.info(f'Speed={speed:.2f} (s/epoch), ETA={eta:.2f} (mins)') + self.logger.info(f'<{self.config_idx}> {self.task} {self.model_name} [{mode}]: Epoch={epoch}, Loss={loss:.4f}, Perf={perf:.4f}') + self.process_logs(all_grad) + + @partial(jax.jit, static_argnums=0) + def compute_loss(self, params, state, batch): + logits = state.apply_fn(params, batch['x']) + one_hot = jax.nn.one_hot(batch['y'], self.output_dim) + loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot)) + perf = jnp.mean(jnp.argmax(logits, -1) == batch['y']) + return loss, perf + + @partial(jax.jit, static_argnums=0) + def update_state(self, state, grads): + return state.apply_gradients(grads=grads) + + def save_results(self, mode, epoch, loss, perf): + """Save and display result.""" + result_dict = { + 'Task': self.task, + 'Model': self.model_name, + 'Epoch': epoch, + 'Loss': loss, + 'Perf': perf + } + self.results[mode].append(result_dict) + results = pd.DataFrame(self.results[mode]) + results['Task'] = results['Task'].astype('category') + results['Model'] = results['Model'].astype('category') + results.to_feather(self.log_path[mode]) + + def process_logs(self, agent_grad): + # Shape to: (num_param, optimization_steps) + agent_grad = np.array(agent_grad) + x = np.stack(agent_grad, axis=1) + # Save grad + self.logger.info(f"# of param collected: {x.shape[0]}") + np.savez(self.cfg['logs_dir']+'data.npz', x=x) + # Plot log(abs(grad)) + grad = x.reshape(-1) + log_abs_grad = np.log10(np.abs(grad)+1e-8) + self.logger.info(f'g: min = {grad.min():.4f}, max = {grad.max():.4f}, mean = {grad.mean():.4f}') + self.logger.info(f'log(|g|+1e-8): min = {log_abs_grad.min():.4f}, max = {log_abs_grad.max():.4f}, mean = {log_abs_grad.mean():.4f}') + fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5), tight_layout=True) + ax1.hist(grad, bins=40, density=False) + ax1.set_yscale('log') + ax1.set_xlabel('$g$', fontsize=18) + ax1.set_ylabel('log(counts)', fontsize=18) + ax1.grid(True) + # Plot log(|grad|) + ax2.hist(log_abs_grad, bins=list(np.arange(-9, 5, 0.5)), density=True) + ax2.set_xlim(-9, 5) + ax2.set_xticks(list(np.arange(-9, 5, 1))) + ax2.set_xlabel('$\log(|g|+10^{-8})$', fontsize=18) + ax2.set_ylabel('Probability density', fontsize=18) + ax2.grid(True) + # Adjust figure layout + plt.tick_params(axis='both', which='major', labelsize=14) + fig.tight_layout() + # Save figure + plt.savefig(self.cfg['logs_dir']+'grad.png') + plt.clf() + plt.cla() + 
plt.close() \ No newline at end of file diff --git a/agents/StarA2C.py b/agents/StarA2C.py deleted file mode 100644 index fef4b9c..0000000 --- a/agents/StarA2C.py +++ /dev/null @@ -1,237 +0,0 @@ -# Copyright 2022 Garena Online Private Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import functools -import time -from typing import Any - -import chex -import jax -import jax.numpy as jnp -import optax -from jax import lax, random, tree_util - -from agents.MetaA2C import MetaA2C -from utils.helper import tree_transpose - - -@chex.dataclass -class TrainingState: - agent_param: Any - agent_optim_param: Any - agent_optim_state: optax.OptState - - -class StarA2C(MetaA2C): - """ - Meta-train STAR during traing A2C in gridworlds. - """ - - def __init__(self, cfg): - super().__init__(cfg) - assert ( - len(self.env_names) == 1 - ), "Only single task training is supported in StarA2C." - - def agent_update(self, carry_in, _): - """Perform a step of inner update to the agent.""" - meta_param, training_state, env_state, seed, lr = carry_in - seed, step_seed = random.split(seed) - # Generate one rollout and compute agent gradient - (agent_loss, (env_state, rollout)), agent_grad = jax.value_and_grad( - self.compute_agent_loss, has_aux=True - )(training_state.agent_param, env_state, step_seed) - # Update agent parameters - agent_optim_state = self.agent_optimizer.update_with_param( - meta_param, agent_grad, training_state.agent_optim_state, agent_loss - ) - # Set new training_state - training_state = training_state.replace( - agent_param=training_state.agent_optim_state.params, - agent_optim_state=agent_optim_state, - ) - carry_out = [meta_param, training_state, env_state, seed, lr] - return carry_out, None - - def learn(self, carry_in): - """Two level updates for meta_param (outer update) and agent_param (inner update).""" - training_state, env_state, seed, lr = carry_in - # Perform inner updates and compute meta gradient. 
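For readers of the removed StarA2C code, a toy sketch of the two-level update the comment above refers to: the meta-parameter is differentiated through an inner gradient step on the agent loss. Here the meta-parameter is a scalar learned step size and `agent_loss` is a linear-regression stand-in, so this is an illustration of the pattern, not the deleted `agent_update_and_meta_loss` implementation (which passes the full STAR optimizer parameters through the inner update).

```python
import jax
import jax.numpy as jnp

def agent_loss(w, x, y):
    return jnp.mean((x @ w - y) ** 2)

def meta_loss(log_lr, w, x, y):
    g = jax.grad(agent_loss)(w, x, y)
    w_new = w - jnp.exp(log_lr) * g      # inner update, differentiable in log_lr
    return agent_loss(w_new, x, y)       # outer objective evaluated after the update

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (8, 3))
w_true = jnp.array([1.0, -2.0, 0.5])
y = x @ w_true
w = jnp.zeros(3)
# d(outer loss)/d(meta-parameter), obtained by backpropagating through the inner step
meta_grad = jax.grad(meta_loss)(jnp.log(0.1), w, x, y)
```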
- seed, step_seed = random.split(seed) - carry_in = [training_state, env_state, step_seed, lr] - meta_param = training_state.agent_optim_param - meta_grad, carry_out = jax.grad(self.agent_update_and_meta_loss, has_aux=True)( - meta_param, carry_in - ) - training_state, env_state = carry_out - # Reduce mean gradient across batch an cores - meta_grad = lax.pmean(meta_grad, axis_name="batch") - meta_grad = lax.pmean(meta_grad, axis_name="core") - carry_out = [meta_grad, training_state, env_state, seed, lr] - return carry_out - - def get_training_state(self, seed, obs): - agent_param = self.agent_net.init(seed, obs) - training_state = TrainingState( - agent_param=agent_param, - agent_optim_param=self.agent_optimizer.get_optim_param(), - agent_optim_state=self.agent_optimizer.init(agent_param), - ) - return training_state - - def reset_agent_training( - self, - training_state, - env_state, - reset_index, - seed, - optim_param, - iter_num, - agent_reset_interval, - obs, - ): - # Select the new one if iter_num % agent_reset_interval == reset_index - def f_select(n_s, o_s): - return lax.select(iter_num % agent_reset_interval == reset_index, n_s, o_s) - - # Generate a new training_state and env_state - new_training_state = self.get_training_state(seed, obs) - new_env_state = self.env.reset(seed) - # Select the new training_state - training_state = tree_util.tree_map( - f_select, new_training_state, training_state - ) - env_state = tree_util.tree_map(f_select, new_env_state, env_state) - # Update optim_param - training_state = training_state.replace(agent_optim_param=optim_param) - return training_state, env_state - - def train(self): - seed = self.seed - # Initialize pmap_train_one_iteration and carries: hidden_state, agent_param, agent_optim_state, env_states, step_seeds - carries = dict() - pmap_train_one_iterations = dict() - pvmap_reset_agent_training = dict() - for i, env_name in enumerate(self.env_names): - # Generate random seeds for env and agent - seed, env_seed, agent_seed = random.split(seed, num=3) - # Set environment and agent network - self.env, self.agent_net = self.envs[i], self.agent_nets[i] - # Initialize agent parameter and optimizer - dummy_obs = self.env.render_obs(self.env.reset(env_seed))[None, :] - pvmap_reset_agent_training[env_name] = jax.pmap( - jax.vmap( - functools.partial(self.reset_agent_training, obs=dummy_obs), - in_axes=(0, 0, 0, 0, None, None, None), - ), - in_axes=(0, 0, 0, 0, None, None, None), - ) - # We initialize core_count*batch_size different agent parameters and optimizer states. 
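A standalone sketch of the `pmap`-over-`vmap` initialisation pattern used in the removed code above: one parameter set per (device, batch) pair, built from a reshaped grid of PRNG keys. `init_fn` and the toy parameter shape are assumptions standing in for `get_training_state` and the agent network.

```python
import jax
import jax.numpy as jnp

def init_fn(key, obs):
    # Stand-in for agent_net.init: a single dense layer's parameters.
    w = jax.random.normal(key, (obs.shape[-1], 4))
    return {'w': w}

n_dev = jax.local_device_count()
batch_size = 2
dummy_obs = jnp.zeros((1, 8))

keys = jax.random.split(jax.random.PRNGKey(0), n_dev * batch_size)
keys = keys.reshape((n_dev, batch_size) + keys.shape[1:])   # (devices, batch, 2)

# vmap maps over the batch axis of keys; pmap maps over devices; obs is broadcast.
pvmap_init = jax.pmap(jax.vmap(init_fn, in_axes=(0, None)), in_axes=(0, None))
params = pvmap_init(keys, dummy_obs)
assert params['w'].shape == (n_dev, batch_size, 8, 4)
```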
- pvmap_get_training_state = jax.pmap( - jax.vmap(self.get_training_state, in_axes=(0, None)), in_axes=(0, None) - ) - agent_seed, *agent_seeds = random.split( - agent_seed, self.core_count * self.batch_size + 1 - ) - agent_seeds = self.reshape(jnp.stack(agent_seeds)) - training_states = pvmap_get_training_state(agent_seeds, dummy_obs) - # Intialize env_states over cores and batch - seed, *env_seeds = random.split(seed, self.core_count * self.batch_size + 1) - env_states = jax.vmap(self.env.reset)(jnp.stack(env_seeds)) - env_states = tree_util.tree_map(self.reshape, env_states) - seed, *step_seeds = random.split( - seed, self.core_count * self.batch_size + 1 - ) - step_seeds = self.reshape(jnp.stack(step_seeds)) - # Save in carries dict - carries[env_name] = [training_states, env_states, step_seeds, -1] - # Replicate the training process over multiple cores - batched_learn = jax.vmap( - self.learn, - in_axes=([0, 0, 0, None],), - out_axes=[None, 0, 0, 0, None], - axis_name="batch", - ) - pmap_train_one_iterations[env_name] = jax.pmap( - batched_learn, - in_axes=([0, 0, 0, None],), - out_axes=[None, 0, 0, 0, None], - axis_name="core", - ) - - self.meta_param = self.agent_optimizer.get_optim_param() - self.meta_optim_state = self.meta_optimizer.init(self.meta_param) - # Train for self.iterations for each env - for t in range(1, self.iterations + 1): - meta_grads = [] - start_time = time.time() - for i, env_name in enumerate(self.env_names): - # Set environment and agent network - self.env, self.agent_net = self.envs[i], self.agent_nets[i] - # Reset agent training: agent_param, hidden_state, env_state - # and update meta parameter (i.e. optim_param) - training_states, env_states = carries[env_name][0], carries[env_name][1] - seed, *reset_seeds = random.split( - seed, self.core_count * self.batch_size + 1 - ) - reset_seeds = self.reshape(jnp.stack(reset_seeds)) - training_states, env_states = pvmap_reset_agent_training[env_name]( - training_states, - env_states, - self.reset_indexes[env_name], - reset_seeds, - self.meta_param, - t, - self.reset_intervals[i], - ) - carries[env_name][0], carries[env_name][1] = training_states, env_states - # Train for one iteration - carry_in = carries[env_name] - carry_out = pmap_train_one_iterations[env_name](carry_in) - # Update carries - carries[env_name] = carry_out[1:] - # Gather meta grad and process - meta_grad = carry_out[0] - if self.max_norm > 0: - g_norm = self.global_norm(meta_grad) - meta_grad = tree_util.tree_map( - lambda x: (x / g_norm.astype(x.dtype)) * self.max_norm, - meta_grad, - ) - meta_grads.append(meta_grad) - # Update meta paramter - meta_grad = tree_transpose(meta_grads) - meta_grad = tree_util.tree_map(lambda x: jnp.mean(x, axis=0), meta_grad) - # Update meta parameter - meta_param_update, self.meta_optim_state = self.meta_optimizer.update( - meta_grad, self.meta_optim_state - ) - self.meta_param = optax.apply_updates(self.meta_param, meta_param_update) - # Reset agent_optimizer with new meta_param - self.agent_optimizer.reset_optimizer(self.meta_param) - # Show log - if t % self.cfg["display_interval"] == 0: - step_count = t * self.macro_step - speed = self.macro_step / (time.time() - start_time) - eta = (self.train_steps - step_count) / speed / 60 if speed > 0 else -1 - self.logger.info( - f"<{self.config_idx}> Step {step_count}/{self.train_steps} Iteration {t}/{self.iterations}: Speed={speed:.2f} (steps/s), ETA={eta:.2f} (mins)" - ) - # Save meta param - if (self.cfg["save_param"] > 0 and t % self.cfg["save_param"] == 0) or ( - t 
== self.iterations - ): - self.save_model_param( - self.meta_param, self.cfg["logs_dir"] + f"param{t}.pickle" - ) diff --git a/agents/__init__.py b/agents/__init__.py index 9d7af38..ea14149 100644 --- a/agents/__init__.py +++ b/agents/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2022 Garena Online Private Limited. +# Copyright 2024 Garena Online Private Limited. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,15 +13,17 @@ # limitations under the License. from .BaseAgent import BaseAgent +from .SLCollect import SLCollect -from .PPO import PPO from .A2C import A2C -from .A2C2 import A2C2 -from .RNNIndentity import RNNIndentity - -from .CollectPPO import CollectPPO -from .CollectA2C import CollectA2C +from .A2Cstar import A2Cstar +from .A2Ccollect import A2Ccollect +from .MetaA2C import MetaA2C +from .MetapA2C import MetapA2C +from .MetaA2Cstar import MetaA2Cstar +from .PPO import PPO +from .PPOstar import PPOstar from .MetaPPO import MetaPPO -from .MetaA2C import MetaA2C -from .StarA2C import StarA2C \ No newline at end of file +from .MetapPPO import MetapPPO +from .MetaPPOstar import MetaPPOstar \ No newline at end of file diff --git a/analysis_brax.py b/analysis_brax.py index 4ccd451..267de90 100644 --- a/analysis_brax.py +++ b/analysis_brax.py @@ -1,4 +1,4 @@ -# Copyright 2022 Garena Online Private Limited. +# Copyright 2024 Garena Online Private Limited. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,8 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os +import math +import numpy as np +from scipy.stats import bootstrap +from collections import namedtuple + from utils.plotter import Plotter -from utils.sweeper import memory_info, time_info, unfinished_index +from utils.sweeper import unfinished_index, time_info, memory_info def get_process_result_dict(result, config_idx, mode='Test'): @@ -21,17 +27,28 @@ def get_process_result_dict(result, config_idx, mode='Test'): 'Env': result['Env'][0], 'Agent': result['Agent'][0], 'Config Index': config_idx, - 'Return (mean)': result['Return'][-100:].mean(skipna=False) if mode=='Train' else result['Return'][-5:].mean(skipna=False) + 'Return (mean)': result['Return'][-100:].mean(skipna=False) if mode=='Train' else result['Return'][-2:].mean(skipna=False) } return result_dict -def get_csv_result_dict(result, config_idx, mode='Test'): +def get_csv_result_dict(result, config_idx, mode='Train', ci=90, method='percentile'): + return_mean = result['Return (mean)'].values.tolist() + if len(return_mean) > 1: + CI = bootstrap( + (result['Return (mean)'].values.tolist(),), + np.mean, confidence_level=ci/100, + method=method + ).confidence_interval + else: + CI = namedtuple('ConfidenceInterval', ['low', 'high'])(low=return_mean[0], high=return_mean[0]) result_dict = { 'Env': result['Env'][0], 'Agent': result['Agent'][0], 'Config Index': config_idx, 'Return (mean)': result['Return (mean)'].mean(skipna=False), - 'Return (se)': result['Return (mean)'].sem(ddof=0) + 'Return (se)': result['Return (mean)'].sem(ddof=0), + 'Return (bootstrap_mean)': (CI.high + CI.low) / 2, + f'Return (ci={ci})': (CI.high - CI.low) / 2, } return result_dict @@ -45,14 +62,15 @@ def get_csv_result_dict(result, config_idx, mode='Test'): 'hue_label': 'Agent', 'show': False, 'imgType': 'png', - 'ci': 'se', + 'estimator': 'mean', + 'ci': 
('ci', 90), 'x_format': None, 'y_format': None, 'xlim': {'min': None, 'max': None}, 'ylim': {'min': None, 'max': None}, 'EMA': True, 'loc': 'upper left', - 'sweep_keys': ['optimizer/name', 'optimizer/kwargs/learning_rate'], + 'sweep_keys': ['optim/name', 'optim/kwargs/learning_rate'], 'sort_by': ['Return (mean)', 'Return (se)'], 'ascending': [False, True], 'runs': 1 @@ -64,43 +82,30 @@ def analyze(exp, runs=1): plotter = Plotter(cfg) sweep_keys_dict = dict( - ppo = ['optimizer/name', 'optimizer/kwargs/learning_rate', 'optimizer/kwargs/gradient_clip'], - collect = ['agent_optimizer/name', 'agent_optimizer/kwargs/learning_rate'], - lopt = ['optimizer/name', 'optimizer/kwargs/learning_rate', 'optimizer/kwargs/param_load_path', 'optimizer/kwargs/gradient_clip'], - meta = ['agent_optimizer/name', 'agent_optimizer/kwargs/learning_rate', 'meta_optimizer/kwargs/learning_rate', 'meta_optimizer/kwargs/max_norm'], - online = ['agent_optimizer/name', 'agent_optimizer/kwargs/learning_rate', 'meta_optimizer/kwargs/learning_rate', 'meta_optimizer/kwargs/max_norm'], + ppo = ['optim/name', 'optim/kwargs/learning_rate'], + lopt = ['optim/name', 'optim/kwargs/learning_rate', 'optim/kwargs/param_load_path'], + meta = ['agent/reset_interval', 'agent_optim/name', 'agent_optim/kwargs/learning_rate', 'meta_optim/kwargs/learning_rate', 'meta_optim/kwargs/grad_clip', 'meta_optim/kwargs/grad_norm'] ) - algo = exp.split('_')[-1].rstrip('0123456789') + algo = exp.split('_')[0].rstrip('0123456789') plotter.sweep_keys = sweep_keys_dict[algo] mode = 'Test' - plotter.csv_results(mode, get_csv_result_dict, get_process_result_dict) + plotter.csv_merged_results(mode, get_csv_result_dict, get_process_result_dict) plotter.plot_results(mode, indexes='all') if __name__ == "__main__": - """Collect""" - # exp, runs = 'ant_collect', 1 - # exp, runs = 'fetch_collect', 1 - # exp, runs = 'grasp_collect', 1 - # exp, runs = 'halfcheetah_collect', 1 - # exp, runs = 'humanoid_collect', 1 - # exp, runs = 'humanoidstandup_collect', 1 - # exp, runs = 'pusher_collect', 1 - # exp, runs = 'reacher_collect', 1 - # exp, runs = 'ur5e_collect', 1 - """PPO""" - # exp, runs = 'ant_ppo', 10 - # exp, runs = 'fetch_ppo', 10 - # exp, runs = 'grasp_ppo', 10 - # exp, runs = 'halfcheetah_ppo', 10 - # exp, runs = 'humanoid_ppo', 10 - # exp, runs = 'humanoidstandup_ppo', 10 - # exp, runs = 'pusher_ppo', 10 - # exp, runs = 'reacher_ppo', 10 - # exp, runs = 'ur5e_ppo', 10 - """Lopt""" - exp, runs = 'ant_lopt', 1 - unfinished_index(exp, runs=runs) - memory_info(exp, runs=runs) - time_info(exp, runs=runs) - analyze(exp, runs=runs) \ No newline at end of file + meta_ant_list = ['meta_rl_ant', 'meta_rlp_ant', 'meta_lin_ant', 'meta_l2l_ant', 'meta_star_ant'] + meta_humanoid_list = ['meta_rl_humanoid', 'meta_rlp_humanoid', 'meta_lin_humanoid', 'meta_l2l_humanoid', 'meta_star_humanoid'] + + ppo_list = ['ppo_ant', 'ppo_humanoid', 'ppo_pendulum', 'ppo_walker2d'] + lopt_ant_list = ['lopt_rl_ant', 'lopt_rlp_ant', 'lopt_lin_ant', 'lopt_l2l_ant', 'lopt_star_ant'] + lopt_humanoid_list = ['lopt_rl_humanoid', 'lopt_rlp_humanoid', 'lopt_lin_humanoid', 'lopt_l2l_humanoid', 'lopt_star_humanoid'] + lopt_rl_grid_brax = ['lopt_rl_grid_ant', 'lopt_rl_grid_humanoid', 'lopt_rl_grid_pendulum', 'lopt_rl_grid_walker2d'] + + exp_list, runs = meta_ant_list, 1 + exp_list, runs = lopt_ant_list, 10 + for exp in exp_list: + unfinished_index(exp, runs=runs) + memory_info(exp, runs=runs) + time_info(exp, runs=runs) + analyze(exp, runs=runs) \ No newline at end of file diff --git 
a/analysis_grid.py b/analysis_grid.py index e949e6c..35715c1 100644 --- a/analysis_grid.py +++ b/analysis_grid.py @@ -1,4 +1,4 @@ -# Copyright 2022 Garena Online Private Limited. +# Copyright 2024 Garena Online Private Limited. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,135 +12,118 @@ # See the License for the specific language governing permissions and # limitations under the License. -from utils.plotter import Plotter -from utils.sweeper import memory_info, time_info, unfinished_index +import os +import math +import numpy as np +from scipy.stats import bootstrap +from collections import namedtuple +from utils.plotter import Plotter +from utils.sweeper import unfinished_index, time_info, memory_info -def get_process_result_dict(result, config_idx, mode="Train"): - result_dict = { - "Env": result["Env"][0], - "Agent": result["Agent"][0], - "Config Index": config_idx, - "Return (mean)": result["Return"][-20:].mean(skipna=False), - } - return result_dict +def get_process_result_dict(result, config_idx, mode='Train'): + result_dict = { + 'Env': result['Env'][0], + 'Agent': result['Agent'][0], + 'Config Index': config_idx, + 'Return (mean)': result['Return'][-20:].mean(skipna=False) + } + return result_dict -def get_csv_result_dict(result, config_idx, mode="Train"): - result_dict = { - "Env": result["Env"][0], - "Agent": result["Agent"][0], - "Config Index": config_idx, - "Return (mean)": result["Return (mean)"].mean(skipna=False), - "Return (se)": result["Return (mean)"].sem(ddof=0), - } - return result_dict +def get_csv_result_dict(result, config_idx, mode='Train', ci=90, method='percentile'): + return_mean = result['Return (mean)'].values.tolist() + if len(return_mean) > 1: + CI = bootstrap( + (result['Return (mean)'].values.tolist(),), + np.mean, confidence_level=ci/100, + method=method + ).confidence_interval + else: + CI = namedtuple('ConfidenceInterval', ['low', 'high'])(low=return_mean[0], high=return_mean[0]) + result_dict = { + 'Env': result['Env'][0], + 'Agent': result['Agent'][0], + 'Config Index': config_idx, + 'Return (mean)': result['Return (mean)'].mean(skipna=False), + 'Return (se)': result['Return (mean)'].sem(ddof=0), + 'Return (bootstrap_mean)': (CI.high + CI.low) / 2, + f'Return (ci={ci})': (CI.high - CI.low) / 2, + } + return result_dict cfg = { - "exp": "exp_name", - "merged": True, - "x_label": "Step", - "y_label": "Return", - "rolling_score_window": -1, - "hue_label": "Agent", - "show": False, - "imgType": "png", - "ci": "sd", - "x_format": None, - "y_format": None, - "xlim": {"min": None, "max": None}, - "ylim": {"min": None, "max": None}, - "EMA": True, - "loc": "upper left", - "sweep_keys": ["agent_optimizer/name", "agent_optimizer/kwargs/learning_rate"], - "sort_by": ["Return (mean)", "Return (se)"], - "ascending": [False, True], - "runs": 1, + 'exp': 'exp_name', + 'merged': True, + 'x_label': 'Step', + 'y_label': 'Return', + 'rolling_score_window': -1, + 'hue_label': 'Agent', + 'show': False, + 'imgType': 'png', + 'estimator': 'mean', + 'ci': ('ci', 90), + 'EMA': True, + 'loc': 'upper left', + 'sweep_keys': ['meta_optim/kwargs/learning_rate', 'inner_updates', 'meta_net/hidden_size', 'meta_net/inner_scale', 'meta_net/input_scale', 'grad_clip', 'meta_param_path'], + 'sort_by': ['Return (mean)', 'Return (se)'], + 'ascending': [False, True], + 'runs': 1 } - def analyze(exp, runs=1): - cfg["exp"] = exp - cfg["runs"] = runs - plotter = Plotter(cfg) + cfg['exp'] = exp + 
cfg['runs'] = runs + plotter = Plotter(cfg) - modes = [] - if "bdl" in exp: - modes.append("big_dense_long") - if "bss" in exp: - modes.append("big_sparse_short") - if "sdl" in exp: - modes.append("small_dense_long") - if "sds" in exp: - modes.append("small_dense_short") - if "short" in exp: - modes = [ - "small_sparse_short", - "small_dense_short", - "big_sparse_short", - "big_dense_short", - ] - if "long" in exp: - modes = [ - "small_sparse_long", - "small_dense_long", - "big_sparse_long", - "big_dense_long", - ] - if "grid" in exp: - modes = [ - "small_sparse_short", - "small_sparse_long", - "small_dense_short", - "small_dense_long", - "big_sparse_short", - "big_sparse_long", - "big_dense_short", - "big_dense_long", - ] + modes = [] + if 'catch' in exp: + modes.append('catch') + if 'sds' in exp: + modes.append("small_dense_short") + if 'sdl' in exp: + modes.append("small_dense_long") + if 'bss' in exp: + modes.append("big_sparse_short") + if 'bds' in exp: + modes.append("big_dense_short") + if 'bsl' in exp: + modes.append("big_sparse_long") + if 'bdl' in exp: + modes.append("big_dense_long") + if 'grid' in exp: + modes = ["small_dense_short", "small_dense_long", "big_sparse_short", "big_sparse_long", "big_dense_short", "big_dense_long"] - sweep_keys_dict = dict( - a2c=["agent_optimizer/name", "agent_optimizer/kwargs/learning_rate"], - collect=[ - "agent_optimizer/name", - "agent_optimizer/kwargs/learning_rate", - "env/reward_scaling", - ], - lopt=[ - "agent_optimizer/name", - "agent_optimizer/kwargs/learning_rate", - "agent_optimizer/kwargs/param_load_path", - ], - meta=[ - "agent_optimizer/name", - "agent_optimizer/kwargs/learning_rate", - "meta_optimizer/kwargs/learning_rate", - ], - online=[ - "agent_optimizer/name", - "agent_optimizer/kwargs/learning_rate", - "meta_optimizer/kwargs/learning_rate", - ], - star=[ - "agent_optimizer/name", - "agent_optimizer/kwargs/step_mult", - "agent_optimizer/kwargs/nominal_stepsize", - "agent_optimizer/kwargs/weight_decay", - "meta_optimizer/kwargs/learning_rate", - ], - ) - algo = exp.split("_")[-1].rstrip("0123456789") - plotter.sweep_keys = sweep_keys_dict[algo] + sweep_keys_dict = dict( + a2c = ['agent_optim/name', 'agent_optim/kwargs/learning_rate'], + collect = ['agent_optim/name', 'agent_optim/kwargs/learning_rate', 'agent_optim/kwargs/grid_clip'], + lopt = ['agent_optim/name', 'agent_optim/kwargs/learning_rate', 'agent_optim/kwargs/param_load_path'], + meta = ['agent/reset_interval', 'agent_optim/name', 'meta_optim/kwargs/learning_rate', 'meta_optim/kwargs/grad_clip', 'meta_optim/kwargs/grad_norm', 'meta_optim/kwargs/max_norm'] + ) + algo = exp.split('_')[0].rstrip('0123456789') + plotter.sweep_keys = sweep_keys_dict[algo] - for mode in modes: - plotter.csv_results(mode, get_csv_result_dict, get_process_result_dict) - plotter.plot_results(mode=mode, indexes="all") + for mode in modes: + plotter.csv_merged_results(mode, get_csv_result_dict, get_process_result_dict) + plotter.plot_results(mode=mode, indexes='all') if __name__ == "__main__": - exp, runs = "sds_lopt", 10 + meta_catch_list = ['meta_rl_catch', 'meta_lin_catch', 'meta_l2l_catch', 'meta_star_catch'] + meta_sdl_list = ['meta_rl_sdl', 'meta_rlp_sdl'] + meta_bdl_list = ['meta_rl_bdl', 'meta_rlp_bdl', 'meta_lin_bdl', 'meta_l2l_bdl', 'meta_star_bdl'] + meta_grid_list = ['meta_rl_grid'] + + a2c_list = ['a2c_grid', 'a2c_catch'] + lopt_catch_list = ['lopt_rl_catch', 'lopt_star_catch', 'lopt_l2l_catch', 'lopt_lin_catch'] + lopt_sdl_list = ['lopt_rl_sdl', 'lopt_rlp_sdl'] + lopt_bdl_list = 
['lopt_rl_bdl', 'lopt_rlp_bdl', 'lopt_lin_bdl', 'lopt_l2l_bdl', 'lopt_star_bdl'] + + exp_list, runs = meta_catch_list, 1 + exp_list, runs = lopt_catch_list, 10 + for exp in exp_list: unfinished_index(exp, runs=runs) memory_info(exp, runs=runs) time_info(exp, runs=runs) - analyze(exp, runs=runs) + analyze(exp, runs=runs) \ No newline at end of file diff --git a/analysis_identity.py b/analysis_identity.py deleted file mode 100644 index c4ced4c..0000000 --- a/analysis_identity.py +++ /dev/null @@ -1,74 +0,0 @@ -# Copyright 2022 Garena Online Private Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from utils.plotter import Plotter -from utils.sweeper import memory_info, time_info, unfinished_index - - -def get_process_result_dict(result, config_idx, mode="Train"): - result_dict = { - "Config Index": config_idx, - "Loss (mean)": result["Loss"][-5:].mean(skipna=False), - "Perf (mean)": result["Perf"][-5:].mean(skipna=False), - } - return result_dict - - -def get_csv_result_dict(result, config_idx, mode="Train"): - result_dict = { - "Config Index": config_idx, - "Loss (mean)": result["Loss (mean)"].mean(skipna=False), - "Perf (mean)": result["Perf (mean)"].mean(skipna=False), - } - return result_dict - - -cfg = { - "exp": "exp_name", - "merged": True, - "x_label": "Epoch", - "y_label": "Perf", - "rolling_score_window": -1, - "hue_label": "Agent", - "show": False, - "imgType": "png", - "ci": "sd", - "x_format": None, - "y_format": None, - "xlim": {"min": None, "max": None}, - "ylim": {"min": None, "max": None}, - "EMA": True, - "loc": "upper left", - "sweep_keys": ["meta_net/name", "optimizer/kwargs/learning_rate"], - "sort_by": ["Perf (mean)", "Loss (mean)"], - "ascending": [False, True], - "runs": 1, -} - - -def analyze(exp, runs=1): - cfg["exp"] = exp - cfg["runs"] = runs - plotter = Plotter(cfg) - - plotter.csv_results("Train", get_csv_result_dict, get_process_result_dict) - plotter.plot_results(mode="Train", indexes="all") - - -if __name__ == "__main__": - exp, runs = "bdl_identity", 10 - unfinished_index(exp, runs=runs) - memory_info(exp, runs=runs) - time_info(exp, runs=runs) - analyze(exp, runs=runs) diff --git a/components/gradients.py b/components/gradients.py index 7bfd9fa..e10d0dc 100644 --- a/components/gradients.py +++ b/components/gradients.py @@ -1,18 +1,4 @@ -# Copyright 2022 Garena Online Private Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Copyright 2022 The Brax Authors. +# Copyright 2023 The Brax Authors. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -28,9 +14,10 @@ """Brax training gradient utility functions.""" +from typing import Callable, Optional + import jax import optax -from typing import Callable, Optional def loss_and_pgrad(loss_fn: Callable[..., float], @@ -70,7 +57,7 @@ def f(*args, optimizer_state): value, grads = loss_and_pgrad_fn(*args) params_update, optimizer_state = optimizer.update(grads, optimizer_state) params = optax.apply_updates(args[0], params_update) - return value, params, optimizer_state, grads, params_update + return value, params, optimizer_state, grads return f diff --git a/components/losses.py b/components/losses.py deleted file mode 100644 index 554aebb..0000000 --- a/components/losses.py +++ /dev/null @@ -1,194 +0,0 @@ -# Copyright 2022 Garena Online Private Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Copyright 2022 The Brax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Proximal policy optimization training. -See: https://arxiv.org/pdf/1707.06347.pdf -""" - -from typing import Any, Tuple - -import flax -import jax -import jax.numpy as jnp -from brax.training import types -from brax.training.types import Params -from brax.training.agents.ppo import networks as ppo_networks - - -@flax.struct.dataclass -class PPONetworkParams: - """Contains training state for the learner.""" - policy: Params - value: Params - - -def compute_gae(truncation: jnp.ndarray, - termination: jnp.ndarray, - rewards: jnp.ndarray, - values: jnp.ndarray, - bootstrap_value: jnp.ndarray, - lambda_: float = 1.0, - discount: float = 0.99): - """Calculates the Generalized Advantage Estimation (GAE). - - Args: - truncation: A float32 tensor of shape [T, B] with truncation signal. - termination: A float32 tensor of shape [T, B] with termination signal. - rewards: A float32 tensor of shape [T, B] containing rewards generated by - following the behaviour policy. - values: A float32 tensor of shape [T, B] with the value function estimates - wrt. the target policy. - bootstrap_value: A float32 of shape [B] with the value function estimate at - time T. - lambda_: Mix between 1-step (lambda_=0) and n-step (lambda_=1). Defaults to - lambda_=1. - discount: TD discount. - - Returns: - A float32 tensor of shape [T, B]. Can be used as target to - train a baseline (V(x_t) - vs_t)^2. - A float32 tensor of shape [T, B] of advantages. 
- """ - - truncation_mask = 1 - truncation - # Append bootstrapped value to get [v1, ..., v_t+1] - values_t_plus_1 = jnp.concatenate( - [values[1:], jnp.expand_dims(bootstrap_value, 0)], axis=0) - deltas = rewards + discount * (1 - termination) * values_t_plus_1 - values - deltas *= truncation_mask - - acc = jnp.zeros_like(bootstrap_value) - vs_minus_v_xs = [] - - def compute_vs_minus_v_xs(carry, target_t): - lambda_, acc = carry - truncation_mask, delta, termination = target_t - acc = delta + discount * (1 - termination) * truncation_mask * lambda_ * acc - return (lambda_, acc), (acc) - - (_, _), (vs_minus_v_xs) = jax.lax.scan( - compute_vs_minus_v_xs, (lambda_, acc), - (truncation_mask, deltas, termination), - length=int(truncation_mask.shape[0]), - reverse=True) - # Add V(x_s) to get v_s. - vs = jnp.add(vs_minus_v_xs, values) - - vs_t_plus_1 = jnp.concatenate( - [vs[1:], jnp.expand_dims(bootstrap_value, 0)], axis=0) - advantages = (rewards + discount * - (1 - termination) * vs_t_plus_1 - values) * truncation_mask - return jax.lax.stop_gradient(vs), jax.lax.stop_gradient(advantages) - - -def compute_ppo_loss( - params: PPONetworkParams, - normalizer_params: Any, - data: types.Transition, - rng: jnp.ndarray, - ppo_network: ppo_networks.PPONetworks, - entropy_cost: float = 1e-4, - discounting: float = 0.9, - reward_scaling: float = 1.0, - gae_lambda: float = 0.95, - clip_ratio: float = 0.3, - normalize_advantage: bool = True) -> Tuple[jnp.ndarray, types.Metrics]: - """Computes PPO loss. - - Args: - params: Network parameters, - normalizer_params: Parameters of the normalizer. - data: Transition that with leading dimension [B, T]. extra fields required - are ['state_extras']['truncation'] ['policy_extras']['raw_action'] - ['policy_extras']['log_prob'] - rng: Random key - ppo_network: PPO networks. - entropy_cost: entropy cost. - discounting: discounting, - reward_scaling: reward multiplier. - gae_lambda: General advantage estimation lambda. - clip_ratio: Policy loss clipping epsilon - normalize_advantage: whether to normalize advantage estimate - - Returns: - A tuple (loss, metrics) - """ - parametric_action_distribution = ppo_network.parametric_action_distribution - policy_apply = ppo_network.policy_network.apply - value_apply = ppo_network.value_network.apply - - # Put the time dimension first. 
- data = jax.tree_util.tree_map(lambda x: jnp.swapaxes(x, 0, 1), data) - policy_logits = policy_apply(normalizer_params, params.policy, data.observation) - - baseline = value_apply(normalizer_params, params.value, data.observation) - - bootstrap_value = value_apply(normalizer_params, params.value, - data.next_observation[-1]) - - rewards = data.reward * reward_scaling - truncation = data.extras['state_extras']['truncation'] - termination = (1 - data.discount) * (1 - truncation) - - target_action_log_probs = parametric_action_distribution.log_prob( - policy_logits, data.extras['policy_extras']['raw_action']) - behaviour_action_log_probs = data.extras['policy_extras']['log_prob'] - - vs, advantages = compute_gae( - truncation=truncation, - termination=termination, - rewards=rewards, - values=baseline, - bootstrap_value=bootstrap_value, - lambda_=gae_lambda, - discount=discounting) - if normalize_advantage: - advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8) - rho_s = jnp.exp(target_action_log_probs - behaviour_action_log_probs) - - surrogate_loss1 = rho_s * advantages - surrogate_loss2 = jnp.clip(rho_s, 1 - clip_ratio, - 1 + clip_ratio) * advantages - - policy_loss = -jnp.mean(jnp.minimum(surrogate_loss1, surrogate_loss2)) - - # Value function loss - v_error = vs - baseline - v_loss = jnp.mean(v_error * v_error) * 0.5 * 0.5 - - # Entropy reward - entropy = jnp.mean(parametric_action_distribution.entropy(policy_logits, rng)) - entropy_loss = entropy_cost * -entropy - - total_loss = policy_loss + v_loss + entropy_loss - return total_loss, { - 'total_loss': total_loss, - 'policy_loss': policy_loss, - 'v_loss': v_loss, - 'entropy_loss': entropy_loss - } \ No newline at end of file diff --git a/components/network.py b/components/network.py index a216c83..253360b 100644 --- a/components/network.py +++ b/components/network.py @@ -1,4 +1,4 @@ -# Copyright 2022 Garena Online Private Limited. +# Copyright 2024 Garena Online Private Limited. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,17 +13,20 @@ # limitations under the License. 
import jax +import distrax from jax import lax -from jax import numpy as jnp +import jax.numpy as jnp import flax.linen as nn from flax.linen.initializers import lecun_uniform from typing import Any, Callable, Sequence + + Initializer = Callable[..., Any] class MLP(nn.Module): - """Multilayer Perceptron""" + '''Multilayer Perceptron''' layer_dims: Sequence[int] hidden_act: str = 'ReLU' output_act: str = 'Linear' @@ -63,14 +66,26 @@ def __call__(self, x): nn.relu ]) -D64_64 = nn.Sequential([ +D256 = nn.Sequential([ lambda x: x.reshape((x.shape[0], -1)), # flatten - nn.Dense(64), - nn.relu, - nn.Dense(64), + nn.Dense(256), nn.relu ]) + +class MNIST_CNN(nn.Module): + output_dim: int = 10 + + def setup(self): + self.feature_net = C16_D32 + self.head = nn.Dense(self.output_dim) + + def __call__(self, obs): + phi = self.feature_net(obs) + logits = self.head(phi) + return logits + + def select_feature_net(env_name): # Select a feature net if 'small' in env_name: @@ -79,8 +94,8 @@ def select_feature_net(env_name): return C16_D32 elif env_name in ['random_walk']: return lambda x: x - elif env_name in ['Pendulum-v1', 'CartPole-v1', 'MountainCar-v0', 'MountainCarContinuous-v0', 'Acrobot-v1'] or 'bsuite' in env_name: - return D64_64 + elif env_name in ['catch']: + return D256 class ActorVCriticNet(nn.Module): @@ -102,65 +117,4 @@ def __call__(self, obs): # Compute state value and action ditribution logits v = self.critic_net(phi).squeeze() action_logits = self.actor_net(phi) - return action_logits, v - - -class RobustRNN(nn.Module): - name: str = 'RobustRNN' - rnn_type: str = 'GRU' - mlp_dims: Sequence[int] = () - hidden_size: int = 8 - out_size: int = 1 - eps: float = 1e-18 - - def setup(self): - # Set up RNN - if self.rnn_type == 'LSTM': - self.rnn = nn.OptimizedLSTMCell() - elif self.rnn_type == 'GRU': - self.rnn = nn.GRUCell() - # Set up MLP - layers = [] - layer_dims = list(self.mlp_dims) - layer_dims.append(self.out_size) - for i in range(len(layer_dims)): - layers.append(nn.Dense(layer_dims[i])) - layers.append(nn.relu) - layers.pop() - self.mlp = nn.Sequential(layers) - - def __call__(self, h, g): - g_sign = jnp.sign(g) - g_log = jnp.log(jnp.abs(g) + self.eps) - g_sign = lax.stop_gradient(g_sign[..., None]) - g_log = lax.stop_gradient(g_log[..., None]) - g_input = jnp.concatenate([g_sign, g_log], axis=-1) - h, x = self.rnn(h, g_input) - outs = self.mlp(x) - out = g_sign[..., 0] * jnp.exp(outs[..., 0]) - return h, out - - def init_hidden_state(self, params): - # Use fixed random key since default state init fn is just zeros. - if self.rnn_type == 'LSTM': - h = nn.OptimizedLSTMCell.initialize_carry(jax.random.PRNGKey(0), params.shape, self.hidden_size) - elif self.rnn_type == 'GRU': - h = nn.GRUCell.initialize_carry(jax.random.PRNGKey(0), params.shape, self.hidden_size) - return h - - -class NormalRNN(RobustRNN): - name: str = 'NormalRNN' - rnn_type: str = 'GRU' - mlp_dims: Sequence[int] = () - hidden_size: int = 8 - out_size: int = 1 - - def __call__(self, h, g): - # Expand parameter dimension so that the network is "coodinatewise" - g = lax.stop_gradient(g[..., None]) - g_input = g - h, x = self.rnn(h, g_input) - outs = self.mlp(x) - out = outs[..., 0] - return h, out \ No newline at end of file + return action_logits, v \ No newline at end of file diff --git a/components/optim.py b/components/optim.py index 7e661ae..46a4b9b 100644 --- a/components/optim.py +++ b/components/optim.py @@ -1,4 +1,4 @@ -# Copyright 2022 Garena Online Private Limited. 
+# Copyright 2024 Garena Online Private Limited. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,76 +13,84 @@ # limitations under the License. import jax -import chex +import flax import optax -import pickle import flax.linen as nn +import jax.numpy as jnp +from copy import deepcopy from typing import Sequence -from jax import numpy as jnp from jax import lax, random, tree_util -from learned_optimization.learned_optimizers.adafac_nominal import MLPNomLOpt +from flax.linen.initializers import zeros_init +from components.star import MLPNomLOpt +from utils.helper import load_model_param -@chex.dataclass + +activations = { + 'ReLU': nn.relu, + 'ELU': nn.elu, + 'Softplus': nn.softplus, + 'LeakyReLU': nn.leaky_relu, + 'Tanh': jnp.tanh, + 'Sigmoid': nn.sigmoid, + 'Exp': jnp.exp +} + + +@flax.struct.dataclass class OptimState: """Contains training state for the learner.""" - hidden_state: chex.ArrayTree - optim_param: chex.ArrayTree + hidden_state: flax.core.FrozenDict + optim_param: flax.core.FrozenDict + iteration: jnp.ndarray -def set_optimizer(optimizer_name, optimizer_kwargs, key): - if optimizer_name in ['LinearOptim', 'Optim4RL', 'L2LGD2']: - optimizer = OptimizerWrapper(optimizer_name, optimizer_kwargs, key) - elif optimizer_name == 'Star': - optimizer = StarWrapper(optimizer_name, optimizer_kwargs, key) +def set_optim(optim_name, original_optim_cfg, key): + optim_cfg = deepcopy(original_optim_cfg) + if optim_name in ['LinearOptim', 'L2LGD2'] or 'Optim4RL' in optim_name: + optim = OptimizerWrapper(optim_name, optim_cfg, key) + elif optim_name == 'Star': + optim = StarWrapper(optim_name, optim_cfg, key) else: - gradient_clip = optimizer_kwargs['gradient_clip'] - del optimizer_kwargs['gradient_clip'] - if gradient_clip > 0: - optimizer = optax.chain( - optax.clip(gradient_clip), - getattr(optax, optimizer_name.lower())(**optimizer_kwargs) + optim_cfg.setdefault('grad_clip', -1) + optim_cfg.setdefault('grad_norm', -1) + grad_clip = optim_cfg['grad_clip'] + grad_norm = optim_cfg['grad_norm'] + del optim_cfg['grad_clip'], optim_cfg['grad_norm'] + if grad_clip > 0: + optim = optax.chain( + optax.clip(grad_clip), + getattr(optax, optim_name.lower())(**optim_cfg) + ) + elif grad_norm > 0: + optim = optax.chain( + optax.clip_by_global_norm(grad_norm), + getattr(optax, optim_name.lower())(**optim_cfg) ) else: - optimizer = getattr(optax, optimizer_name.lower())(**optimizer_kwargs) - return optimizer - - -def set_meta_optimizer(optimizer_name, optimizer_kwargs, key): - if 'gradient_clip' in optimizer_kwargs.keys(): - gradient_clip = optimizer_kwargs['gradient_clip'] - del optimizer_kwargs['gradient_clip'] - else: - gradient_clip = -1 - - if gradient_clip > 0: - optimizer = optax.chain( - optax.clip(gradient_clip), - getattr(optax, optimizer_name.lower())(**optimizer_kwargs) - ) - else: - optimizer = getattr(optax, optimizer_name.lower())(**optimizer_kwargs) - return optimizer + optim = getattr(optax, optim_name.lower())(**optim_cfg) + return optim class Optim4RL(nn.Module): name: str = 'Optim4RL' - mlp_dims: Sequence[int] = (16, 16) + rnn_hidden_act: str = 'Tanh' + mlp_dims: Sequence[int] = () hidden_size: int = 8 learning_rate: float = 1.0 - gradient_clip: float = -1.0 - eps: float = 1e-18 - eps_root: float = 1e-18 - bias: float = 1.0 + eps: float = 1e-8 + out_size1: int = 1 + out_size2: int = 1 def setup(self): - # Set RNNs - self.rnn1 = nn.GRUCell() - self.rnn2 = nn.GRUCell() - # Set MLPs + # Set up RNNs 
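+    # One GRU per pseudo moment: rnn1 reads the raw gradient and produces the
+    # first pseudo moment m, while rnn2 reads the squared gradient and produces
+    # the second pseudo moment v (see __call__ below).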
+ act_fn = activations[self.rnn_hidden_act] + self.rnn1 = nn.GRUCell(features=self.hidden_size, activation_fn=act_fn) + self.rnn2 = nn.GRUCell(features=self.hidden_size, activation_fn=act_fn) + # Set up MLPs layer_dims1, layer_dims2 = list(self.mlp_dims), list(self.mlp_dims) - layer_dims1.append(2) - layer_dims2.append(1) + layer_dims1.append(self.out_size1) + layer_dims2.append(self.out_size2) layers1, layers2 = [], [] for i in range(len(layer_dims1)): layers1.append(nn.Dense(layer_dims1[i])) @@ -94,225 +102,202 @@ def setup(self): self.mlp1 = nn.Sequential(layers1) self.mlp2 = nn.Sequential(layers2) - def __call__(self, h, g): - # Clip the gradient to prevent large update - g = lax.select(self.gradient_clip > 0, jnp.clip(g, -self.gradient_clip, self.gradient_clip), g) - g_sign = jnp.sign(g) - g_log = jnp.log(jnp.abs(g) + self.eps) + def __call__(self, h, g, t=0): # Expand parameter dimension so that the network is "coodinatewise" - g_sign = lax.stop_gradient(g_sign[..., None]) - g_log = lax.stop_gradient(g_log[..., None]) - g_input = jnp.concatenate([g_sign, g_log], axis=-1) + g = lax.stop_gradient(g[..., None]) + g_square = lax.stop_gradient(jnp.square(g)) + g_sign = lax.stop_gradient(jnp.sign(g)) # RNN h1, h2 = h # Compute m: 1st pseudo moment estimate - h1, x1 = self.rnn1(h1, g_input) + h1, x1 = self.rnn1(h1, g) o1 = self.mlp1(x1) - # Add a small bias so that m_sign=1 initially - m_sign_raw = jnp.tanh(o1[..., 0]+self.bias) - m_sign = lax.stop_gradient(2.0*(m_sign_raw >= 0.0) - 1.0 - m_sign_raw) + m_sign_raw - m = g_sign[..., 0] * m_sign * jnp.exp(o1[..., 1]) + m = g_sign * jnp.exp(o1) # Compute v: 2nd pseudo moment estimate - h2, x2 = self.rnn2(h2, 2.0*g_log) + h2, x2 = self.rnn2(h2, g_square) o2 = self.mlp2(x2) - sqrt_v = jnp.sqrt(jnp.exp(o2[..., 0]) + self.eps_root) - # Compute the parameter update - out = -self.learning_rate * m / sqrt_v - return (h1, h2), out + rsqrt_v = lax.rsqrt(jnp.exp(o2) + self.eps) + # Compute the output: Delta theta + out = -self.learning_rate * m * rsqrt_v + return jnp.array([h1, h2]), out[..., 0] - def init_hidden_state(self, params): + def init_hidden_state(self, param): # Use fixed random key since default state init fn is just zeros. 
- h = ( - nn.GRUCell.initialize_carry(random.PRNGKey(0), params.shape, self.hidden_size), - nn.GRUCell.initialize_carry(random.PRNGKey(0), params.shape, self.hidden_size) - ) + seed = random.PRNGKey(0) + mem_shape = param.shape + (self.hidden_size,) + h = jnp.array([ + zeros_init()(seed, mem_shape), + zeros_init()(seed, mem_shape) + ]) return h class LinearOptim(nn.Module): name: str = 'LinearOptim' - mlp_dims: Sequence[int] = (16, 16) + mlp_dims: Sequence[int] = () hidden_size: int = 8 learning_rate: float = 1.0 - gradient_clip: float = -1.0 - eps: float = 1e-18 + eps: float = 1e-8 + out_size: int = 3 def setup(self): - # Set RNN - self.rnn = nn.GRUCell() - # Set MLP + # Set up RNN + self.rnn = nn.GRUCell(features=self.hidden_size) + # Set up MLP layer_dims = list(self.mlp_dims) - layer_dims.append(3) + layer_dims.append(self.out_size) layers = [] - for i in range(len(layer_dims)): + for i in range(len(layer_dims)): layers.append(nn.Dense(layer_dims[i])) layers.append(nn.relu) layers.pop() self.mlp = nn.Sequential(layers) - def __call__(self, h, g): - # Clip the gradient to prevent large update - g = lax.select(self.gradient_clip > 0, jnp.clip(g, -self.gradient_clip, self.gradient_clip), g) - g = lax.stop_gradient(g) - g_sign = jnp.sign(g) - g_log = jnp.log(jnp.abs(g) + self.eps) + def __call__(self, h, g, t=0): # Expand parameter dimension so that the network is "coodinatewise" - g_sign = lax.stop_gradient(g_sign[..., None]) - g_log = lax.stop_gradient(g_log[..., None]) - g_input = jnp.concatenate([g_sign, g_log], axis=-1) - h, x = self.rnn(h, g_input) + g = lax.stop_gradient(g[..., None]) + h, x = self.rnn(h, g) outs = self.mlp(x) # Slice outs into several elements o1, o2, o3 = outs[..., 0], outs[..., 1], outs[..., 2] # Compute the output: Delta theta - out = -self.learning_rate * (jnp.exp(o1) * g + jnp.exp(o2) * o3) + out = -self.learning_rate * (jnp.exp(o1) * g[..., 0] + jnp.exp(o2) * o3) return h, out - def init_hidden_state(self, params): + def init_hidden_state(self, param): # Use fixed random key since default state init fn is just zeros. - h = nn.GRUCell.initialize_carry(random.PRNGKey(0), params.shape, self.hidden_size) + seed = random.PRNGKey(0) + mem_shape = param.shape + (self.hidden_size,) + h = jnp.array(zeros_init()(seed, mem_shape)) return h -class L2LGD2(nn.Module): - """ - Implementaion of [Learning to learn by gradient descent by gradient descent](http://arxiv.org/abs/1606.04474) - Note that this implementation is not exactly the same as the original paper. - For example, we use a slightly different gradient processing. 
- """ +class L2LGD2(LinearOptim): name: str = 'L2LGD2' - mlp_dims: Sequence[int] = (16, 16) + mlp_dims: Sequence[int] = () hidden_size: int = 8 learning_rate: float = 1.0 - gradient_clip: float = -1.0 - eps: float = 1e-18 + p: int = 10 + out_size: int = 1 def setup(self): - # Set RNN - self.rnn = nn.GRUCell() - # Set MLP - layer_dims = list(self.mlp_dims) - layer_dims.append(1) - layers = [] - for i in range(len(layer_dims)): - layers.append(nn.Dense(layer_dims[i])) - layers.append(nn.relu) - layers.pop() - self.mlp = nn.Sequential(layers) + super().setup() + self.f_select = jax.vmap(lambda s, x, y: lax.select(s>=-1.0, x, y)) - def __call__(self, h, g): - # Clip the gradient to prevent large update - g = lax.select(self.gradient_clip > 0, jnp.clip(g, -self.gradient_clip, self.gradient_clip), g) - g = lax.stop_gradient(g) - g_sign = jnp.sign(g) - g_log = jnp.log(jnp.abs(g) + self.eps) + def __call__(self, h, g, t=0): # Expand parameter dimension so that the network is "coodinatewise" - g_sign = lax.stop_gradient(g_sign[..., None]) - g_log = lax.stop_gradient(g_log[..., None]) - g_input = jnp.concatenate([g_sign, g_log], axis=-1) - h, x = self.rnn(h, g_input) + g = lax.stop_gradient(g[..., None]) + g_sign = jnp.sign(g) + g_log = jnp.log(jnp.abs(g) + self.eps) / self.p + g_in1 = jnp.concatenate([g_log, g_sign], axis=-1).reshape((-1,2)) + g_in2 = jnp.concatenate([-1.0*jnp.ones_like(g), jnp.exp(self.p)*g], axis=-1).reshape((-1,2)) + g_in = self.f_select(g_log.reshape(-1), g_in1, g_in2) + g_in = g_in.reshape(g.shape[:-1]+(2,)) + h, x = self.rnn(h, g_in) outs = self.mlp(x) - # Slice outs into several elements - o = outs[..., 0] # Compute the output: Delta theta - out = -self.learning_rate * jnp.exp(o) * g - return h, out - - def init_hidden_state(self, params): - # Use fixed random key since default state init fn is just zeros. 
- h = nn.GRUCell.initialize_carry(random.PRNGKey(0), params.shape, self.hidden_size) - return h - - -def load_model_param(filepath): - f = open(filepath, 'rb') - model_param = pickle.load(f) - model_param = tree_util.tree_map(jnp.array, model_param) - f.close() - return model_param - + out = -self.learning_rate * jnp.exp(outs) * g + return h, out[..., 0] + class OptimizerWrapper(object): - """Optimizer Wrapper for learned optimizers: Optim4RL, LinearOptim, and L2LGD2.""" - def __init__(self, optimizer_name, cfg, seed): + """Optimizer Wrapper for learned optimizers.""" + def __init__(self, optim_name, cfg, seed): self.seed = seed - self.optimizer_name = optimizer_name - cfg['name'] = optimizer_name - cfg['mlp_dims'] = tuple(cfg['mlp_dims']) + self.optim_name = optim_name + cfg['name'] = optim_name + if 'mlp_dims' in cfg.keys(): + cfg['mlp_dims'] = tuple(cfg['mlp_dims']) cfg.setdefault('param_load_path', '') self.param_load_path = cfg['param_load_path'] - del cfg['param_load_path'] + cfg.setdefault('grad_clip', -1.0) + self.grad_clip = cfg['grad_clip'] + del cfg['param_load_path'], cfg['grad_clip'] # Set RNN optimizer - if optimizer_name == 'Optim4RL': - self.optimizer = Optim4RL(**cfg) - self.is_rnn_output = lambda x: type(x)==tuple and type(x[0])==tuple and type(x[1])!=tuple - elif optimizer_name in ['LinearOptim', 'L2LGD2']: - if optimizer_name == 'LinearRNNOptimizer': - self.optimizer = LinearRNNOptimizer(**cfg) - elif optimizer_name == 'L2LGD2': - self.optimizer = L2LGD2(**cfg) - self.is_rnn_output = lambda x: type(x)==tuple and type(x[0])!=tuple and type(x[1])!=tuple + assert optim_name in ['LinearOptim', 'L2LGD2'] or 'Optim4RL' in optim_name, f'{optim_name} is not supported.' + self.optim = eval(optim_name)(**cfg) + self.is_rnn_output = lambda x: type(x)==tuple and type(x[0])!=tuple and type(x[1])!=tuple # Initialize param for RNN optimizer - if len(self.param_load_path) > 0: + if len(self.param_load_path)>0: self.optim_param = load_model_param(self.param_load_path) else: dummy_grad = jnp.array([0.0]) - dummy_hidden_state = self.optimizer.init_hidden_state(dummy_grad) - self.optim_param = self.optimizer.init(self.seed, dummy_hidden_state, dummy_grad) + dummy_hidden_state = self.optim.init_hidden_state(dummy_grad) + self.optim_param = self.optim.init(self.seed, dummy_hidden_state, dummy_grad) def init(self, param): """ - Initialize optim_state, i.e. hidden_state of RNN optimizer + optim parameter - optim_state = optimizer.init(param) + Initialize optim_state, i.e. 
hidden_state of RNN optimizer + optim_state = optim.init(param) """ - hidden_state = tree_util.tree_map(self.optimizer.init_hidden_state, param) - optim_state = OptimState(hidden_state=hidden_state, optim_param=self.optim_param) + hidden_state = tree_util.tree_map(self.optim.init_hidden_state, param) + optim_state = OptimState(hidden_state=hidden_state, optim_param=self.optim_param, iteration=0) return optim_state - def update(self, grad, optim_state, params=None): - """param_update, optim_state = optimizer.update(grad, optim_state)""" + def update(self, grad, optim_state): + """param_update, optim_state = optim.update(grad, optim_state)""" + # Clip the gradient to prevent large update + grad = lax.cond( + self.grad_clip > 0, + lambda x: jax.tree_util.tree_map(lambda g: jnp.clip(g, -self.grad_clip, self.grad_clip), x), + lambda x: x, + grad + ) out = jax.tree_util.tree_map( - lambda grad, hidden: self.optimizer.apply(optim_state.optim_param, hidden, grad), - grad, - optim_state.hidden_state + lambda hidden, grad: self.optim.apply(optim_state.optim_param, hidden, grad, optim_state.iteration), + optim_state.hidden_state, + grad ) # Split output into hidden_state and agent parameter update hidden_state = jax.tree_util.tree_map(lambda x: x[0], out, is_leaf=self.is_rnn_output) param_update = jax.tree_util.tree_map(lambda x: x[1], out, is_leaf=self.is_rnn_output) optim_state = optim_state.replace(hidden_state=hidden_state) + optim_state = optim_state.replace(iteration=optim_state.iteration+1) return param_update, optim_state def update_with_param(self, optim_param, grad, optim_state, lr=1.0): """ - param_update, optim_state = optimizer.update(optim_param, grad, optim_state) + param_update, optim_state = optim.update(optim_param, grad, optim_state) Used for training the optimizer. """ + grad = lax.cond( + self.grad_clip > 0, + lambda x: jax.tree_util.tree_map(lambda g: jnp.clip(g, -self.grad_clip, self.grad_clip), x), + lambda x: x, + grad + ) out = jax.tree_util.tree_map( - lambda grad, hidden: self.optimizer.apply(optim_param, hidden, grad), + lambda hidden, grad: self.optim.apply(optim_param, hidden, grad, optim_state.iteration), + optim_state.hidden_state, grad, - optim_state.hidden_state ) # Split output into hidden_state and agent parameter update hidden_state = jax.tree_util.tree_map(lambda x: x[0], out, is_leaf=self.is_rnn_output) param_update = jax.tree_util.tree_map(lambda x: lr * x[1], out, is_leaf=self.is_rnn_output) optim_state = optim_state.replace(hidden_state=hidden_state) + optim_state = optim_state.replace(iteration=optim_state.iteration+1) return param_update, optim_state class StarWrapper(object): """Optimizer Wrapper for STAR.""" - def __init__(self, optimizer_name, cfg, seed): + def __init__(self, optim_name, cfg, seed): self.seed = seed - self.optimizer_name = optimizer_name + self.optim_name = optim_name self.train_steps = cfg['train_steps'] cfg.setdefault('param_load_path', '') self.param_load_path = cfg['param_load_path'] - del cfg['param_load_path'], cfg['train_steps'] + cfg.setdefault('grad_clip', -1.0) + self.grad_clip = cfg['grad_clip'] + # Rename lr to step_mult + cfg['step_mult'] = cfg['learning_rate'] + cfg.setdefault('nominal_stepsize', 1e-3) + del cfg['param_load_path'], cfg['train_steps'], cfg['grad_clip'], cfg['learning_rate'] # Set Star optimizer - assert optimizer_name == 'Star', 'Only Star is supported.' + assert optim_name == 'Star', 'Only Star is supported.' 
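+    # Hypothetical usage sketch (values are illustrative, not defaults): the cfg
+    # passed in carries 'train_steps', 'learning_rate' (renamed to 'step_mult'
+    # above), and optionally 'param_load_path', 'grad_clip', 'nominal_stepsize',
+    # and 'weight_decay'; whatever remains is forwarded to MLPNomLOpt below, e.g.
+    #   optim = set_optim('Star', {'train_steps': 3e7, 'learning_rate': 3e-3,
+    #                              'nominal_stepsize': 1e-3, 'weight_decay': 0.0}, key)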
self.star_net = MLPNomLOpt(**cfg) # Initialize param for Star optimizer - if len(self.param_load_path) > 0: + if len(self.param_load_path)>0: optim_param = load_model_param(self.param_load_path) else: optim_param = self.star_net.init(self.seed) @@ -320,25 +305,37 @@ def __init__(self, optimizer_name, cfg, seed): self.reset_optimizer(optim_param) def reset_optimizer(self, optim_param): - self.optimizer = self.star_net.opt_fn(optim_param, is_training=True) + self.optim = self.star_net.opt_fn(optim_param, is_training=True) def init(self, param): """Initialize optim_state""" - optim_state = self.optimizer.init(params=param, num_steps=self.train_steps) + optim_state = self.optim.init(params=param, num_steps=self.train_steps) return optim_state def get_optim_param(self): - return self.optimizer.theta + return self.optim.theta def update(self, grad, optim_state, loss): - optim_state = self.optimizer.update(optim_state, grad, loss) + grad = lax.cond( + self.grad_clip > 0, + lambda x: jax.tree_util.tree_map(lambda g: jnp.clip(g, -self.grad_clip, self.grad_clip), x), + lambda x: x, + grad + ) + optim_state = self.optim.update(optim_state, grad, loss) return optim_state def update_with_param(self, optim_param, grad, optim_state, loss): """ - optim_state = optimizer.update(optim_param, grad, optim_state, loss) + param_update, optim_state = optim.update(optim_param, grad, optim_state) Used for training the optimizer. """ - self.optimizer.theta = optim_param # This is important, do not remove - optim_state = self.optimizer.update(optim_state, grad, loss) - return optim_state + grad = lax.cond( + self.grad_clip > 0, + lambda x: jax.tree_util.tree_map(lambda g: jnp.clip(g, -self.grad_clip, self.grad_clip), x), + lambda x: x, + grad + ) + self.optim.theta = optim_param # This is important + optim_state = self.optim.update(optim_state, grad, loss) + return optim_state \ No newline at end of file diff --git a/components/ppo_networks.py b/components/ppo_networks.py deleted file mode 100644 index 133fa48..0000000 --- a/components/ppo_networks.py +++ /dev/null @@ -1,99 +0,0 @@ -# Copyright 2022 Garena Online Private Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Copyright 2022 The Brax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""PPO networks.""" - -from typing import Sequence, Tuple - -import flax -from flax import linen -from brax.training import distribution, networks, types - - -@flax.struct.dataclass -class PPONetworks: - policy_network: networks.FeedForwardNetwork - value_network: networks.FeedForwardNetwork - parametric_action_distribution: distribution.ParametricDistribution - - -def make_inference_fn(ppo_networks: PPONetworks): - """Creates params and inference function for the PPO agent.""" - - def make_policy(params: types.PolicyParams, - deterministic: bool = False) -> types.Policy: - policy_network = ppo_networks.policy_network - parametric_action_distribution = ppo_networks.parametric_action_distribution - - def policy(observations: types.Observation, - key_sample: types.PRNGKey) -> Tuple[types.Action, types.Extra]: - logits = policy_network.apply(*params, observations) - if deterministic: - return ppo_networks.parametric_action_distribution.mode(logits), {} - raw_actions = parametric_action_distribution.sample_no_postprocessing( - logits, key_sample) - log_prob = parametric_action_distribution.log_prob(logits, raw_actions) - postprocessed_actions = parametric_action_distribution.postprocess( - raw_actions) - return postprocessed_actions, { - 'log_prob': log_prob, - 'raw_action': raw_actions - } - - return policy - - return make_policy - - -def make_ppo_networks( - observation_size: int, - action_size: int, - preprocess_observations_fn: types.PreprocessObservationFn = types - .identity_observation_preprocessor, - policy_hidden_layer_sizes: Sequence[int] = (32,) * 4, - value_hidden_layer_sizes: Sequence[int] = (256,) * 5, - activation: networks.ActivationFn = linen.swish) -> PPONetworks: - """Make PPO networks with preprocessor.""" - parametric_action_distribution = distribution.NormalTanhDistribution( - event_size=action_size) - policy_network = networks.make_policy_network( - parametric_action_distribution.param_size, - observation_size, - preprocess_observations_fn=preprocess_observations_fn, - hidden_layer_sizes=policy_hidden_layer_sizes, - activation=activation) - value_network = networks.make_value_network( - observation_size, - preprocess_observations_fn=preprocess_observations_fn, - hidden_layer_sizes=value_hidden_layer_sizes, - activation=activation) - - return PPONetworks( - policy_network=policy_network, - value_network=value_network, - parametric_action_distribution=parametric_action_distribution) \ No newline at end of file diff --git a/components/running_statistics.py b/components/running_statistics.py index c4a4f26..5e9667d 100644 --- a/components/running_statistics.py +++ b/components/running_statistics.py @@ -1,17 +1,3 @@ -# Copyright 2022 Garena Online Private Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - # Copyright 2022 The Brax Authors. 
# # Licensed under the Apache License, Version 2.0 (the "License"); @@ -31,12 +17,12 @@ https://github.com/deepmind/acme/blob/master/acme/jax/running_statistics.py """ -from typing import Optional, Tuple +from typing import Any, Optional, Tuple -import jax -import jax.numpy as jnp from brax.training.acme import types from flax import struct +import jax +import jax.numpy as jnp def _zeros_like(nest: types.Nest, dtype=None) -> types.Nest: diff --git a/components/star.py b/components/star.py new file mode 100644 index 0000000..70dd7a8 --- /dev/null +++ b/components/star.py @@ -0,0 +1,508 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""MLP learned optimizer with adafactor features and nominal AggMo term.""" +import functools +from typing import Any, Optional + +import flax +import gin +import haiku as hk +import jax +from jax import lax +import jax.numpy as jnp +from learned_optimization import tree_utils +from learned_optimization.learned_optimizers import base as lopt_base +from learned_optimization.learned_optimizers import common +from learned_optimization.optimizers import base as opt_base +import numpy as onp + +PRNGKey = jnp.ndarray + + +def second_moment_normalizer(x, axis, eps=1e-5): + return x * lax.rsqrt(eps + jnp.mean(jnp.square(x), axis=axis, keepdims=True)) + + +def tanh_embedding(x): + f32 = jnp.float32 + + def one_freq(timescale): + return jnp.tanh(x / (f32(timescale)) - 1.0) + + timescales = jnp.asarray( + [1, 3, 10, 30, 100, 300, 1000, 3000, 10000, 30000, 100000], + dtype=jnp.float32) + return jax.vmap(one_freq)(timescales) + + +@flax.struct.dataclass +class AdafacMLPLOptState: + params: Any + state: Any + mom_rolling: common.MomAccumulator + rms_rolling: common.RMSAccumulator + fac_rolling_features: common.FactoredAccum + num_steps: jnp.ndarray + iteration: jnp.ndarray + + +def decay_to_param(x): + return jnp.log(1 - x) / 10. + + +def param_to_decay(x): + return 1 - jnp.exp(x * 10.) 
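+
+# Worked example (comment only): decay_to_param(0.9) = log(1 - 0.9) / 10, about -0.23,
+# and param_to_decay(-0.23) recovers roughly 0.9. The meta-learned decay offsets
+# created in MLPNomLOpt.init() are therefore added in this unconstrained space and
+# mapped back to decay rates inside _Opt._get_rolling().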
+ + +@gin.configurable +class MLPNomLOpt(lopt_base.LearnedOptimizer): + """MLP based learned optimizer with adafactor style accumulators.""" + + def __init__(self, + exp_mult=(0.001, 0.001, 0.001), + step_mult=0.001, + hidden_size=4, + hidden_layers=2, + initial_momentum_decays=(0.9, 0.99, 0.999), + initial_rms_decays=(0.999,), + initial_adafactor_decays=(0.9, 0.99, 0.999), + nominal_stepsize=0., + weight_decay=0., + concat_weights=True, + make_separate_weights=False, + split_weights=False, + nominal_controller=False, + regularization_controller=False, + nominal_grad_estimator="AdamAggMo", + aggregate_magnitude=False, + aggregate_nom_magnitude=False, + aggregate_reg_magnitude=False, + normalize_blackbox=False, + selfnormalize_blackbox=False): + + super().__init__() + self._exp_mult = exp_mult + self._step_mult = step_mult + self._hidden_size = hidden_size + self._hidden_layers = hidden_layers + self._initial_momentum_decays = initial_momentum_decays + self._initial_rms_decays = initial_rms_decays + self._initial_adafactor_decays = initial_adafactor_decays + self._concat_weights = concat_weights + self._make_separate_weights = make_separate_weights + self._split_weights = split_weights + self._stepsize = nominal_stepsize + self._nom_controller = nominal_controller + self._reg_controller = regularization_controller + self._weight_decay = weight_decay + self._nominal_grad_estimator = nominal_grad_estimator + self._aggregate_magnitude = aggregate_magnitude + self._aggregate_nom_magnitude = aggregate_nom_magnitude + self._aggregate_reg_magnitude = aggregate_reg_magnitude + self._normalize_blackbox = normalize_blackbox + self._selfnormalize_blackbox = selfnormalize_blackbox + + self._mod_init, self._mod_apply = hk.without_apply_rng( + hk.transform(self._mod)) + + def _mod(self, global_feat, p, g, m, rms, fac_g, fac_vec_col, fac_vec_row, + fac_vec_v): + # this doesn't work with scalar parameters, so instead lets just reshape. + if not p.shape: + p = jnp.expand_dims(p, 0) + g = jnp.expand_dims(g, 0) + m = jnp.expand_dims(m, 0) + rms = jnp.expand_dims(rms, 0) + fac_g = jnp.expand_dims(fac_g, 0) + fac_vec_v = jnp.expand_dims(fac_vec_v, 0) + fac_vec_col = jnp.expand_dims(fac_vec_col, 0) + fac_vec_row = jnp.expand_dims(fac_vec_row, 0) + did_reshape = True + else: + did_reshape = False + inps = [] + + inps.append(jnp.expand_dims(g, axis=-1)) + inps.append(jnp.expand_dims(p, axis=-1)) + inps.append(m) + inps.append(rms) + rsqrt = lax.rsqrt(rms + 1e-6) + adam_feats = m * rsqrt + inps.append(adam_feats) + inps.append(rsqrt) + inps.append(fac_g) + + factored_dims = common.factored_dims(g.shape) + if factored_dims is not None: + # Construct features for + d1, d0 = factored_dims + + # add 2 dims: 1 for batch of decay, one because low rank + to_tile = [1] * (1 + len(g.shape)) + to_tile[d0] = g.shape[d0] + + row_feat = jnp.tile(jnp.expand_dims(fac_vec_row, axis=d0), to_tile) + + to_tile = [1] * (1 + len(g.shape)) + to_tile[d1] = g.shape[d1] + col_feat = jnp.tile(jnp.expand_dims(fac_vec_col, axis=d1), to_tile) + + # 3 possible kinds of adafactor style features. 
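+      # (1) the raw row/column accumulators, (2) their reciprocal square roots,
+      # and (3) the momentum rescaled by the factored rsqrt (fac_mom_mult below).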
+ # Raw values + inps.append(row_feat) + inps.append(col_feat) + + # 1/sqrt + inps.append(lax.rsqrt(row_feat + 1e-8)) + inps.append(lax.rsqrt(col_feat + 1e-8)) + + # multiplied by momentum + reduced_d1 = d1 - 1 if d1 > d0 else d1 + row_col_mean = jnp.mean(fac_vec_row, axis=reduced_d1, keepdims=True) + + row_factor = common.safe_rsqrt(fac_vec_row / (row_col_mean + 1e-9)) + col_factor = common.safe_rsqrt(fac_vec_col) + fac_mom_mult = ( + m * jnp.expand_dims(row_factor, axis=d0) * + jnp.expand_dims(col_factor, axis=d1)) + inps.append(fac_mom_mult) + else: + # In the non-factored case, match what RMSProp does. + inps.append(fac_vec_v) + inps.append(fac_vec_v) + + inps.append(lax.rsqrt(fac_vec_v + 1e-8)) + inps.append(lax.rsqrt(fac_vec_v + 1e-8)) + + fac_mom_mult = m * (fac_vec_v + 1e-6)**-0.5 + inps.append(fac_mom_mult) + + # Build the weights of the NN + last_size = jnp.concatenate(inps, axis=-1).shape[-1] + last_size += global_feat["training_step_feature"].shape[-1] + + weights = [] + biases = [] + + out_dim = [4] + for wi, w in enumerate([self._hidden_size] * self._hidden_layers + out_dim): + stddev = 1. / onp.sqrt(last_size) + w_init = hk.initializers.TruncatedNormal(stddev=stddev) + + make_full_weights = self._concat_weights or ( + not self._make_separate_weights) + if make_full_weights: + weights.append( + hk.get_parameter( + f"w{wi}", shape=(last_size, w), dtype=jnp.float32, init=w_init)) + biases.append( + hk.get_parameter( + f"b{wi}", shape=(w,), dtype=jnp.float32, init=jnp.zeros)) + else: + # Otherwise weights will be stored as scalars. + # these scalars could be made from scratch, split from weights made + # above + if self._make_separate_weights: + # Manually make the weight matrix in scalars. + weights.append([]) + for vi in range(last_size): + ww = [] + for oi in range(w): + wij = hk.get_parameter( + f"w{wi}_{vi}_{oi}", shape=[], dtype=jnp.float32, init=w_init) + ww.append(wij) + weights[-1].append(ww) + biases.append([]) + for oi in range(w): + b = hk.get_parameter( + f"b{wi}_{oi}", shape=[], dtype=jnp.float32, init=jnp.zeros) + biases[-1].append(b) + elif self._split_weights: + # split up the weights first before running computation. + f = list(x for x in weights[-1].ravel()) + weights[-1] = [[None] * w for i in range(last_size)] + for fi, ff in enumerate(f): + i = fi % last_size + j = fi // last_size + weights[-1][i][j] = ff + biases[-1] = list(b for b in biases[-1]) + last_size = w + + # 2 different methods to compute the learned optimizer weight update are + # provided. First, using matmuls (like a standard NN). Second, with the + # computation unpacked using only scalar math. This uses a different path + # in hardware and can be much faster for small learned optimizer hidden + # sizes. + if self._concat_weights: + # concat the inputs, normalize + inp_stack = jnp.concatenate(inps, axis=-1) + axis = list(range(len(p.shape))) + inp_stack = second_moment_normalizer(inp_stack, axis=axis) + + # add features that should not be normalized + training_step_feature = global_feat["training_step_feature"] + stacked = jnp.reshape(training_step_feature, [1] * len(axis) + + list(training_step_feature.shape[-1:])) + stacked = jnp.tile(stacked, list(p.shape) + [1]) + inp_stack = jnp.concatenate([inp_stack, stacked], axis=-1) + + # Manually run the neural network. 
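+      # Each layer is a matmul plus a broadcast bias, with ReLU on every layer
+      # except the last; the 4 outputs of the final layer are split below into
+      # direction, magnitude, and the optional nominal/regularization magnitudes.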
+ net = inp_stack + for wi, (w, b) in enumerate(zip(weights, biases)): + o_tmp = net @ w + net = o_tmp + jnp.broadcast_to(b, list(net.shape[0:-1]) + [w.shape[-1]]) # pytype: disable=attribute-error + + if wi != len(weights) - 1: + net = jax.nn.relu(net) + + direction = net[..., 0] + magnitude = net[..., 1] + if self._nom_controller: + nom_magnitude = net[..., 2] + else: + nom_magnitude = jnp.ones_like(net[..., 0]) + + if self._reg_controller: + reg_magnitude = net[..., 3] + else: + reg_magnitude = jnp.zeros_like(net[..., 0]) + else: + # The scalar math path. + flat_features = [] + for i in inps: + flat_features.extend( + [jnp.squeeze(x, -1) for x in jnp.split(i, i.shape[-1], axis=-1)]) + + # match the second moment normalize calculation but applied to each scalar + inp = [ + x * lax.rsqrt(1e-5 + jnp.mean(jnp.square(x), keepdims=True)) + for x in flat_features + ] + for wi, (w, b) in enumerate(zip(weights, biases)): + grids = [] + + # hidden layer wi + for oi in range(len(w[0])): + outs = [] + for vi, v in enumerate(inp): + if type(w) == list: # pylint: disable=unidiomatic-typecheck + outs.append(v * w[vi][oi]) + else: + outs.append(v * w[vi, oi]) # pytype: disable=unsupported-operands + + if wi == 0: + training_step_feature = global_feat["training_step_feature"] + for i, vi in enumerate( + range(vi + 1, vi + 1 + len(training_step_feature))): + if type(w) == list: # pylint: disable=unidiomatic-typecheck + outs.append(training_step_feature[i] * w[vi][oi]) + else: + outs.append(training_step_feature[i] * w[vi, oi]) # pytype: disable=unsupported-operands + + grids.append(outs) + + out_mul = [sum(g) for g in grids] + + # bias + inp = [] + for oi, net in enumerate(out_mul): + inp.append(net + b[oi]) + + # activation + if wi != len(weights) - 1: + inp = [jax.nn.relu(x) for x in inp] + + direction = inp[0] + magnitude = inp[1] + nom_magnitude = inp[2] if self._nom_controller else jnp.ones_like(inp[0]) + reg_magnitude = inp[3] if self._reg_controller else jnp.zeros_like(inp[0]) + + if self._aggregate_magnitude: + magnitude = jnp.mean(magnitude) + if self._aggregate_nom_magnitude: + nom_magnitude = jnp.mean(nom_magnitude) + if self._aggregate_reg_magnitude: + reg_magnitude = jnp.mean(reg_magnitude) + + step = direction * jnp.exp(magnitude * self._exp_mult[0]) + step *= self._step_mult + + if self._normalize_blackbox: + step = step * rsqrt[..., 0] + + if self._selfnormalize_blackbox: + step = step * jnp.mean(g) + + step = step.reshape(p.shape) + + reg = (1. 
- self._weight_decay * jnp.exp(reg_magnitude * self._exp_mult[2])) + reg = reg.reshape(p.shape) + + # nominal grad estimator + if self._nominal_grad_estimator == "SGD": + g_est = g + elif self._nominal_grad_estimator == "SGDM": + g_est = m[..., -1] + elif self._nominal_grad_estimator == "AggMo": + g_est = jnp.mean(m, axis=-1) + elif self._nominal_grad_estimator == "Adam": + g_est = adam_feats[..., -1] + elif self._nominal_grad_estimator == "AdamAggMo": + g_est = jnp.mean(adam_feats, axis=-1) + else: + raise NotImplementedError + + nom_step = self._stepsize * jnp.exp( + nom_magnitude * self._exp_mult[1]) * g_est + + # compute updated params + new_p = reg * p + new_p -= step + new_p -= nom_step + + if did_reshape: + new_p = jnp.squeeze(new_p, 0) + + return new_p + + def init(self, key: PRNGKey) -> lopt_base.MetaParams: + # We meta-learn: + # * weights of the MLP + # * decays of momentum, RMS, and adafactor style accumulators + + training_step_feature = tanh_embedding(1) + global_features = { + "iterations": 0, + "num_steps": 10, + "training_step_feature": training_step_feature, + } + # fake weights with 2 dimension + r = 10 + c = 10 + p = jnp.ones([r, c]) + g = jnp.ones([r, c]) + + m = jnp.ones([r, c, len(self._initial_momentum_decays)]) + rms = jnp.ones([r, c, len(self._initial_rms_decays)]) + + fac_g = jnp.ones([r, c, len(self._initial_adafactor_decays)]) + fac_vec_row = jnp.ones([r, len(self._initial_adafactor_decays)]) + fac_vec_col = jnp.ones([c, len(self._initial_adafactor_decays)]) + fac_vec_v = jnp.ones([len(self._initial_adafactor_decays)]) + mod_theta = self._mod_init(key, global_features, p, g, m, rms, fac_g, + fac_vec_col, fac_vec_row, fac_vec_v) + return hk.data_structures.to_haiku_dict({ + "momentum_decays": jnp.zeros([len(self._initial_momentum_decays)]), + "rms_decays": jnp.zeros([len(self._initial_rms_decays)]), + "adafactor_decays": jnp.zeros([len(self._initial_adafactor_decays)]), + "nn": mod_theta + }) + + def opt_fn(self, + theta: lopt_base.MetaParams, + is_training: Optional[bool] = False) -> opt_base.Optimizer: + + mod_apply = self._mod_apply + parent = self + + class _Opt(opt_base.Optimizer): + """Optimizer capturing the meta params.""" + + def __init__(self, theta): + self.theta = theta + + def _get_rolling(self): + mom_decay = param_to_decay( + decay_to_param(jnp.asarray(parent._initial_momentum_decays)) + # pylint: disable=protected-access + self.theta["momentum_decays"]) + mom_roll = common.vec_rolling_mom(mom_decay) + + rms_decay = param_to_decay( + decay_to_param(jnp.asarray(parent._initial_rms_decays)) + # pylint: disable=protected-access + self.theta["rms_decays"]) + rms_roll = common.vec_rolling_rms(rms_decay) + + adafactor_decay = param_to_decay( + decay_to_param(jnp.asarray(parent._initial_adafactor_decays)) + # pylint: disable=protected-access + self.theta["adafactor_decays"]) + fac_vec_roll = common.vec_factored_rolling(adafactor_decay) + return mom_roll, rms_roll, fac_vec_roll + + def init( + self, + params: opt_base.Params, + model_state: Optional[opt_base.ModelState] = None, + num_steps: Optional[int] = None, + key: Optional[PRNGKey] = None, + ) -> AdafacMLPLOptState: + if num_steps is None: + raise ValueError("Must specify number of steps for this lopt!") + + mom_roll, rms_roll, fac_vec_roll = self._get_rolling() + + return AdafacMLPLOptState( + params=params, + state=model_state, + rms_rolling=rms_roll.init(params), + mom_rolling=mom_roll.init(params), + fac_rolling_features=fac_vec_roll.init(params), + iteration=jnp.asarray(0, dtype=jnp.int32), + 
num_steps=jnp.asarray(num_steps)) + + def update( + self, # pytype: disable=signature-mismatch # overriding-parameter-count-checks + opt_state: AdafacMLPLOptState, + grad: opt_base.Gradient, + loss: jnp.ndarray, + model_state: Optional[opt_base.ModelState] = None, + is_valid: bool = False, + key: Optional[PRNGKey] = None, + ) -> AdafacMLPLOptState: + mom_roll, rms_roll, fac_vec_roll = self._get_rolling() + next_mom_rolling = mom_roll.update(opt_state.mom_rolling, grad) + next_rms_rolling = rms_roll.update(opt_state.rms_rolling, grad) + next_fac_rolling_features, fac_g = fac_vec_roll.update( + opt_state.fac_rolling_features, grad) + + # compute some global features + training_step_feature = tanh_embedding(opt_state.iteration) + + global_features = { + "iterations": opt_state.iteration, + "num_steps": opt_state.num_steps, + "training_step_feature": training_step_feature, + } + + fun = functools.partial(mod_apply, self.theta["nn"], global_features) + + next_params = jax.tree_util.tree_map(fun, opt_state.params, grad, + next_mom_rolling.m, + next_rms_rolling.rms, fac_g, + next_fac_rolling_features.v_col, + next_fac_rolling_features.v_row, + next_fac_rolling_features.v_diag) + + next_opt_state = AdafacMLPLOptState( + params=next_params, + mom_rolling=next_mom_rolling, + rms_rolling=next_rms_rolling, + fac_rolling_features=next_fac_rolling_features, + iteration=opt_state.iteration + 1, + state=model_state, + num_steps=opt_state.num_steps) + + return tree_utils.match_type(next_opt_state, opt_state) + + return _Opt(theta) \ No newline at end of file diff --git a/components/star_gradients.py b/components/star_gradients.py new file mode 100644 index 0000000..f105954 --- /dev/null +++ b/components/star_gradients.py @@ -0,0 +1,94 @@ +# Copyright 2022 The Brax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Brax training gradient utility functions.""" + +from typing import Callable, Optional + +import jax +import optax + + +def loss_and_pgrad(loss_fn: Callable[..., float], + pmap_axis_name: Optional[str], + has_aux: bool = False): + g = jax.value_and_grad(loss_fn, has_aux=has_aux) + + def h(*args, **kwargs): + value, grad = g(*args, **kwargs) + return value, jax.lax.pmean(grad, axis_name=pmap_axis_name) + + return g if pmap_axis_name is None else h + + +def gradient_update_fn(loss_fn: Callable[..., float], + optimizer: optax.GradientTransformation, + pmap_axis_name: Optional[str], + has_aux: bool = False): + """Wrapper of the loss function that apply gradient updates. + + Args: + loss_fn: The loss function. + optimizer: The optimizer to apply gradients. + pmap_axis_name: If relevant, the name of the pmap axis to synchronize + gradients. + has_aux: Whether the loss_fn has auxiliary data. + + Returns: + A function that takes the same argument as the loss function plus the + optimizer state. The output of this function is the loss, the new parameter, + and the new optimizer state. 
+ """ + loss_and_pgrad_fn = loss_and_pgrad( + loss_fn, pmap_axis_name=pmap_axis_name, has_aux=has_aux) + + def f(*args, optimizer_state): + value, grads = loss_and_pgrad_fn(*args) + loss = value[0] + optimizer_state = optimizer.update(grads, optimizer_state, loss) + params = optimizer_state.params + return value, params, optimizer_state + + return f + + +def gradient_update_fn_with_optim_param(loss_fn: Callable[..., float], + optimizer: optax.GradientTransformation, + pmap_axis_name: Optional[str], + has_aux: bool = False): + """Wrapper of the loss function that apply gradient updates. + + Args: + loss_fn: The loss function. + optimizer: The optimizer to apply gradients. + pmap_axis_name: If relevant, the name of the pmap axis to synchronize + gradients. + has_aux: Whether the loss_fn has auxiliary data. + + Returns: + A function that takes the same argument as the loss function plus the + optimizer state. The output of this function is the loss, the new parameter, + and the new optimizer state. + """ + loss_and_pgrad_fn = loss_and_pgrad( + loss_fn, pmap_axis_name=pmap_axis_name, has_aux=has_aux) + + def f(*args, optim_param, optimizer_state): + value, grads = loss_and_pgrad_fn(*args) + loss = value[0] + optimizer_state = optimizer.update_with_param(optim_param, grads, optimizer_state, loss) + params = optimizer_state.params + return value, params, optimizer_state + + return f \ No newline at end of file diff --git a/configs/a2c_catch.json b/configs/a2c_catch.json new file mode 100644 index 0000000..c1ffacd --- /dev/null +++ b/configs/a2c_catch.json @@ -0,0 +1,21 @@ +{ + "env": [{ + "name": [["catch"]], + "num_envs": [64], + "train_steps": [5e5] + }], + "agent": [{ + "name": ["A2C"], + "gae_lambda": [0.95], + "rollout_steps": [20], + "critic_loss_weight": [0.5], + "entropy_weight": [0.01] + }], + "agent_optim": [{ + "name": ["RMSProp", "Adam"], + "kwargs": [{"learning_rate": [1e-3], "grad_clip": [-1]}] + }], + "discount": [0.995], + "seed": [42], + "generate_random_seed": [true] +} \ No newline at end of file diff --git a/configs/a2c_grid.json b/configs/a2c_grid.json new file mode 100644 index 0000000..77fc0e5 --- /dev/null +++ b/configs/a2c_grid.json @@ -0,0 +1,21 @@ +{ + "env": [{ + "name": [["small_dense_short", "small_dense_long", "big_sparse_short", "big_sparse_long", "big_dense_short", "big_dense_long"]], + "num_envs": [512], + "train_steps": [3e7] + }], + "agent": [{ + "name": ["A2C"], + "gae_lambda": [0.95], + "rollout_steps": [20], + "critic_loss_weight": [0.5], + "entropy_weight": [0.01] + }], + "agent_optim": [{ + "name": ["RMSProp", "Adam"], + "kwargs": [{"learning_rate": [3e-2, 1e-2, 3e-3, 1e-3, 3e-4, 1e-4], "grad_clip": [1]}] + }], + "discount": [0.995], + "seed": [42], + "generate_random_seed": [true] +} \ No newline at end of file diff --git a/configs/ant_collect.json b/configs/ant_collect.json deleted file mode 100644 index 5e8ea0a..0000000 --- a/configs/ant_collect.json +++ /dev/null @@ -1,37 +0,0 @@ -{ - "env": [ - { - "name": ["ant"], - "train_steps": [1e8], - "episode_length": [1000], - "action_repeat": [1], - "reward_scaling": [10], - "num_envs": [2048], - "num_evals": [20], - "normalize_obs": [true] - } - ], - "agent": [ - { - "name": ["CollectPPO"], - "data_reduce": [100], - "gae_lambda": [0.95], - "rollout_steps": [5], - "num_minibatches": [32], - "clip_ratio": [0.3], - "update_epochs": [4], - "entropy_weight": [1e-2] - } - ], - "optimizer": [ - { - "name": ["RMSProp", "Adam"], - "kwargs": [{ "learning_rate": [3e-4], "gradient_clip": [1] }] - } - ], - "batch_size": 
[1024], - "discount": [0.97], - "max_devices_per_host": [-1], - "seed": [1], - "generate_random_seed": [true] -} diff --git a/configs/ant_lopt.json b/configs/ant_lopt.json deleted file mode 100644 index 018b369..0000000 --- a/configs/ant_lopt.json +++ /dev/null @@ -1,47 +0,0 @@ -{ - "env": [ - { - "name": ["ant"], - "train_steps": [1e8], - "episode_length": [1000], - "action_repeat": [1], - "reward_scaling": [10], - "num_envs": [2048], - "num_evals": [20], - "normalize_obs": [true] - } - ], - "agent": [ - { - "name": ["PPO"], - "gae_lambda": [0.95], - "rollout_steps": [5], - "num_minibatches": [32], - "clip_ratio": [0.3], - "update_epochs": [4], - "entropy_weight": [1e-2] - } - ], - "optimizer": [ - { - "name": ["Optim4RL"], - "kwargs": [ - { - "mlp_dims": [[16, 16]], - "hidden_size": [8], - "param_load_path": [ - "./logs/exp/index/meta_param_path1.pickle", - "./logs/exp/index/rnn_parameter_path2.pickle" - ], - "learning_rate": [3e-4], - "gradient_clip": [1] - } - ] - } - ], - "batch_size": [1024], - "discount": [0.97], - "max_devices_per_host": [-1], - "seed": [1], - "generate_random_seed": [true] -} diff --git a/configs/ant_meta.json b/configs/ant_meta.json deleted file mode 100644 index d099ded..0000000 --- a/configs/ant_meta.json +++ /dev/null @@ -1,55 +0,0 @@ -{ - "env": [ - { - "name": ["ant"], - "train_steps": [3e8], - "episode_length": [1000], - "action_repeat": [1], - "reward_scaling": [10], - "num_envs": [2048], - "normalize_obs": [true] - } - ], - "agent": [ - { - "name": ["MetaPPO"], - "inner_updates": [4], - "gae_lambda": [0.95], - "rollout_steps": [5], - "num_minibatches": [8], - "clip_ratio": [0.3], - "update_epochs": [4], - "entropy_weight": [1e-2], - "reset_interval": [512, 1024] - } - ], - "agent_optimizer": [ - { - "name": ["Optim4RL"], - "kwargs": [ - { - "mlp_dims": [[16, 16]], - "hidden_size": [8], - "param_load_path": [""], - "learning_rate": [3e-4], - "gradient_clip": [1] - } - ] - } - ], - "meta_optimizer": [ - { - "name": ["Adam"], - "kwargs": [ - { "learning_rate": [3e-5, 1e-4, 3e-4, 1e-3], "gradient_clip": [1] } - ] - } - ], - "save_param": [512], - "display_interval": [50], - "batch_size": [1024], - "discount": [0.97], - "max_devices_per_host": [-1], - "seed": [1], - "generate_random_seed": [true] -} diff --git a/configs/ant_ppo.json b/configs/ant_ppo.json deleted file mode 100644 index aee858b..0000000 --- a/configs/ant_ppo.json +++ /dev/null @@ -1,36 +0,0 @@ -{ - "env": [ - { - "name": ["ant"], - "train_steps": [1e8], - "episode_length": [1000], - "action_repeat": [1], - "reward_scaling": [10], - "num_envs": [2048], - "num_evals": [20], - "normalize_obs": [true] - } - ], - "agent": [ - { - "name": ["PPO"], - "gae_lambda": [0.95], - "rollout_steps": [5], - "num_minibatches": [32], - "clip_ratio": [0.3], - "update_epochs": [4], - "entropy_weight": [1e-2] - } - ], - "optimizer": [ - { - "name": ["RMSProp", "Adam"], - "kwargs": [{ "learning_rate": [3e-4], "gradient_clip": [1] }] - } - ], - "batch_size": [1024], - "discount": [0.97], - "max_devices_per_host": [-1], - "seed": [1], - "generate_random_seed": [true] -} diff --git a/configs/bdl_a2c.json b/configs/bdl_a2c.json deleted file mode 100644 index 05b59cc..0000000 --- a/configs/bdl_a2c.json +++ /dev/null @@ -1,32 +0,0 @@ -{ - "env": [ - { - "name": [["big_dense_long"]], - "num_envs": [512], - "train_steps": [3e7] - } - ], - "agent": [ - { - "name": ["A2C"], - "gae_lambda": [0.95], - "rollout_steps": [20], - "critic_loss_weight": [0.5], - "entropy_weight": [0.01] - } - ], - "agent_optimizer": [ - { - "name": 
["RMSProp", "Adam"], - "kwargs": [ - { - "learning_rate": [3e-2, 1e-2, 3e-3, 1e-3, 3e-4, 1e-4], - "gradient_clip": [1] - } - ] - } - ], - "discount": [0.995], - "seed": [42], - "generate_random_seed": [true] -} diff --git a/configs/bdl_collect.json b/configs/bdl_collect.json deleted file mode 100644 index 7accabe..0000000 --- a/configs/bdl_collect.json +++ /dev/null @@ -1,28 +0,0 @@ -{ - "env": [ - { - "name": [["big_dense_long"]], - "num_envs": [512], - "train_steps": [3e7] - } - ], - "agent": [ - { - "name": ["CollectA2C"], - "gae_lambda": [0.95], - "rollout_steps": [20], - "critic_loss_weight": [0.5], - "entropy_weight": [0.01], - "data_reduce": [100] - } - ], - "agent_optimizer": [ - { - "name": ["RMSProp"], - "kwargs": [{ "learning_rate": [3e-3], "gradient_clip": [-1] }] - } - ], - "discount": [0.995], - "seed": [42], - "generate_random_seed": [false] -} diff --git a/configs/bdl_identity.json b/configs/bdl_identity.json deleted file mode 100644 index 20fe500..0000000 --- a/configs/bdl_identity.json +++ /dev/null @@ -1,26 +0,0 @@ -{ - "name": [["big_dense_long"]], - "agent": [{ "name": ["RNNIndentity"] }], - "epoch": [1000], - "seq_len": [5], - "datapath": ["./logs/bdl_collect/1/data.npz"], - "meta_net": [ - { - "name": ["RobustRNN", "NormalRNN"], - "rnn_type": ["LSTM"], - "mlp_dims": [[16, 16]], - "hidden_size": [8] - } - ], - "optimizer": [ - { - "name": ["Adam"], - "kwargs": [ - { "learning_rate": [3e-3, 1e-3, 3e-4, 1e-4], "gradient_clip": [-1] } - ] - } - ], - "display_interval": [25], - "seed": [100], - "generate_random_seed": [true] -} diff --git a/configs/bdl_lopt.json b/configs/bdl_lopt.json deleted file mode 100644 index 2beea30..0000000 --- a/configs/bdl_lopt.json +++ /dev/null @@ -1,38 +0,0 @@ -{ - "env": [ - { - "name": [["big_dense_long"]], - "num_envs": [512], - "train_steps": [3e7] - } - ], - "agent": [ - { - "name": ["A2C"], - "gae_lambda": [0.95], - "rollout_steps": [20], - "critic_loss_weight": [0.5], - "entropy_weight": [0.01] - } - ], - "agent_optimizer": [ - { - "name": ["Optim4RL"], - "kwargs": [ - { - "mlp_dims": [[16, 16]], - "hidden_size": [8], - "param_load_path": [ - "./logs/exp/index/meta_param_path1.pickle", - "./logs/exp/index/rnn_parameter_path2.pickle" - ], - "learning_rate": [3e-3], - "gradient_clip": [1] - } - ] - } - ], - "discount": [0.995], - "seed": [42], - "generate_random_seed": [true] -} diff --git a/configs/bdl_meta.json b/configs/bdl_meta.json deleted file mode 100644 index 0f71df3..0000000 --- a/configs/bdl_meta.json +++ /dev/null @@ -1,50 +0,0 @@ -{ - "env": [ - { - "name": [["big_dense_long"]], - "num_envs": [512], - "train_steps": [3e7] - } - ], - "agent": [ - { - "name": ["MetaA2C"], - "inner_updates": [4], - "gae_lambda": [0.95], - "rollout_steps": [20], - "critic_loss_weight": [0.5], - "entropy_weight": [0.01], - "reset_interval": [256, 512] - } - ], - "agent_optimizer": [ - { - "name": ["Optim4RL", "LinearOptim", "L2LGD2"], - "kwargs": [ - { - "mlp_dims": [[16, 16]], - "hidden_size": [8], - "param_load_path": [""], - "learning_rate": [3e-3], - "gradient_clip": [1] - } - ] - } - ], - "meta_optimizer": [ - { - "name": ["Adam"], - "kwargs": [ - { - "learning_rate": [3e-5, 1e-4, 3e-4, 1e-3], - "gradient_clip": [1] - } - ] - } - ], - "discount": [0.995], - "seed": [42], - "save_param": [256], - "display_interval": [50], - "generate_random_seed": [true] -} diff --git a/configs/bdl_star.json b/configs/bdl_star.json deleted file mode 100644 index 8480f50..0000000 --- a/configs/bdl_star.json +++ /dev/null @@ -1,49 +0,0 @@ -{ - "env": [ - { - 
"name": [["big_dense_long"]], - "num_envs": [512], - "train_steps": [3e7] - } - ], - "agent": [ - { - "name": ["StarA2C"], - "inner_updates": [4], - "gae_lambda": [0.95], - "rollout_steps": [20], - "critic_loss_weight": [0.5], - "entropy_weight": [0.01], - "reset_interval": [256] - } - ], - "agent_optimizer": [ - { - "name": ["Star"], - "kwargs": [ - { - "train_steps": [3e7], - "step_mult": [3e-3, 1e-3, 3e-4], - "nominal_stepsize": [3e-3, 1e-3, 3e-4, 0.0], - "weight_decay": [0.0, 0.1, 0.5] - } - ] - } - ], - "meta_optimizer": [ - { - "name": ["Adam"], - "kwargs": [ - { - "learning_rate": [1e-4, 3e-4], - "gradient_clip": [1] - } - ] - } - ], - "discount": [0.995], - "seed": [42], - "save_param": [256], - "display_interval": [50], - "generate_random_seed": [true] -} diff --git a/configs/bdl_star_lopt.json b/configs/bdl_star_lopt.json deleted file mode 100644 index def0303..0000000 --- a/configs/bdl_star_lopt.json +++ /dev/null @@ -1,38 +0,0 @@ -{ - "env": [ - { - "name": [["big_dense_long"]], - "num_envs": [512], - "train_steps": [3e7] - } - ], - "agent": [ - { - "name": ["A2C2"], - "gae_lambda": [0.95], - "rollout_steps": [20], - "critic_loss_weight": [0.5], - "entropy_weight": [0.01] - } - ], - "agent_optimizer": [ - { - "name": ["Star"], - "kwargs": [ - { - "train_steps": [3e7], - "param_load_path": [ - "./logs/exp/index/meta_param_path1.pickle", - "./logs/exp/index/rnn_parameter_path2.pickle" - ], - "step_mult": [3e-3], - "nominal_stepsize": [3e-3], - "weight_decay": [0.0] - } - ] - } - ], - "discount": [0.995], - "seed": [42], - "generate_random_seed": [true] -} diff --git a/configs/collect_bdl.json b/configs/collect_bdl.json new file mode 100644 index 0000000..099036b --- /dev/null +++ b/configs/collect_bdl.json @@ -0,0 +1,22 @@ +{ + "env": [{ + "name": [["big_dense_long"]], + "num_envs": [512], + "train_steps": [3e7] + }], + "agent": [{ + "name": ["A2Ccollect"], + "data_reduce": [100], + "gae_lambda": [0.95], + "rollout_steps": [20], + "critic_loss_weight": [0.5], + "entropy_weight": [0.01] + }], + "agent_optim": [{ + "name": ["RMSProp", "Adam"], + "kwargs": [{"learning_rate": [0.003], "grad_clip": [1]}] + }], + "discount": [0.995], + "seed": [42], + "generate_random_seed": [false] +} \ No newline at end of file diff --git a/configs/collect_mnist.json b/configs/collect_mnist.json new file mode 100644 index 0000000..1f1944b --- /dev/null +++ b/configs/collect_mnist.json @@ -0,0 +1,17 @@ +{ + "task": ["MNIST"], + "agent": [{ + "name": ["SLCollect"], + "data_reduce": [1000] + }], + "model": [{"name": ["MNIST_CNN"], "kwargs": [{"output_dim": [10]}]}], + "optimizer": [{ + "name": ["RMSProp"], + "kwargs": [{"learning_rate": [3e-4], "grad_clip": [-1]}] + }], + "epochs": [10], + "batch_size": [128], + "display_interval": [1], + "seed": [1], + "generate_random_seed": [false] +} \ No newline at end of file diff --git a/configs/fetch_collect.json b/configs/fetch_collect.json deleted file mode 100644 index ba48e81..0000000 --- a/configs/fetch_collect.json +++ /dev/null @@ -1,37 +0,0 @@ -{ - "env": [ - { - "name": ["fetch"], - "train_steps": [1e8], - "episode_length": [1000], - "action_repeat": [1], - "reward_scaling": [5], - "num_envs": [2048], - "num_evals": [20], - "normalize_obs": [true] - } - ], - "agent": [ - { - "name": ["CollectPPO"], - "data_reduce": [100], - "gae_lambda": [0.95], - "rollout_steps": [20], - "num_minibatches": [32], - "clip_ratio": [0.3], - "update_epochs": [4], - "entropy_weight": [1e-3] - } - ], - "optimizer": [ - { - "name": ["RMSProp", "Adam"], - "kwargs": [{ 
"learning_rate": [3e-4], "gradient_clip": [1] }] - } - ], - "batch_size": [256], - "discount": [0.997], - "max_devices_per_host": [-1], - "seed": [1], - "generate_random_seed": [true] -} diff --git a/configs/fetch_lopt.json b/configs/fetch_lopt.json deleted file mode 100644 index 0aa2846..0000000 --- a/configs/fetch_lopt.json +++ /dev/null @@ -1,47 +0,0 @@ -{ - "env": [ - { - "name": ["fetch"], - "train_steps": [1e8], - "episode_length": [1000], - "action_repeat": [1], - "reward_scaling": [5], - "num_envs": [2048], - "num_evals": [20], - "normalize_obs": [true] - } - ], - "agent": [ - { - "name": ["PPO"], - "gae_lambda": [0.95], - "rollout_steps": [20], - "num_minibatches": [32], - "clip_ratio": [0.3], - "update_epochs": [4], - "entropy_weight": [1e-3] - } - ], - "optimizer": [ - { - "name": ["Optim4RL"], - "kwargs": [ - { - "mlp_dims": [[16, 16]], - "hidden_size": [8], - "param_load_path": [ - "./logs/exp/index/meta_param_path1.pickle", - "./logs/exp/index/rnn_parameter_path2.pickle" - ], - "learning_rate": [3e-4], - "gradient_clip": [1] - } - ] - } - ], - "batch_size": [256], - "discount": [0.997], - "max_devices_per_host": [-1], - "seed": [1], - "generate_random_seed": [true] -} diff --git a/configs/fetch_ppo.json b/configs/fetch_ppo.json deleted file mode 100644 index b569b13..0000000 --- a/configs/fetch_ppo.json +++ /dev/null @@ -1,36 +0,0 @@ -{ - "env": [ - { - "name": ["fetch"], - "train_steps": [1e8], - "episode_length": [1000], - "action_repeat": [1], - "reward_scaling": [5], - "num_envs": [2048], - "num_evals": [20], - "normalize_obs": [true] - } - ], - "agent": [ - { - "name": ["PPO"], - "gae_lambda": [0.95], - "rollout_steps": [20], - "num_minibatches": [32], - "clip_ratio": [0.3], - "update_epochs": [4], - "entropy_weight": [1e-3] - } - ], - "optimizer": [ - { - "name": ["RMSProp", "Adam"], - "kwargs": [{ "learning_rate": [3e-4], "gradient_clip": [1] }] - } - ], - "batch_size": [256], - "discount": [0.997], - "max_devices_per_host": [-1], - "seed": [1], - "generate_random_seed": [true] -} diff --git a/configs/grasp_collect.json b/configs/grasp_collect.json deleted file mode 100644 index b8cc091..0000000 --- a/configs/grasp_collect.json +++ /dev/null @@ -1,37 +0,0 @@ -{ - "env": [ - { - "name": ["grasp"], - "train_steps": [6e8], - "episode_length": [1000], - "action_repeat": [1], - "reward_scaling": [10], - "num_envs": [2048], - "num_evals": [10], - "normalize_obs": [true] - } - ], - "agent": [ - { - "name": ["CollectPPO"], - "data_reduce": [5000], - "gae_lambda": [0.95], - "rollout_steps": [20], - "num_minibatches": [32], - "clip_ratio": [0.3], - "update_epochs": [2], - "entropy_weight": [1e-3] - } - ], - "optimizer": [ - { - "name": ["RMSProp", "Adam"], - "kwargs": [{ "learning_rate": [3e-4], "gradient_clip": [1] }] - } - ], - "batch_size": [256], - "discount": [0.99], - "max_devices_per_host": [-1], - "seed": [1], - "generate_random_seed": [true] -} diff --git a/configs/grasp_lopt.json b/configs/grasp_lopt.json deleted file mode 100644 index e462951..0000000 --- a/configs/grasp_lopt.json +++ /dev/null @@ -1,47 +0,0 @@ -{ - "env": [ - { - "name": ["grasp"], - "train_steps": [6e8], - "episode_length": [1000], - "action_repeat": [1], - "reward_scaling": [10], - "num_envs": [2048], - "num_evals": [10], - "normalize_obs": [true] - } - ], - "agent": [ - { - "name": ["PPO"], - "gae_lambda": [0.95], - "rollout_steps": [20], - "num_minibatches": [32], - "clip_ratio": [0.3], - "update_epochs": [2], - "entropy_weight": [1e-3] - } - ], - "optimizer": [ - { - "name": ["Optim4RL"], - 
"kwargs": [ - { - "mlp_dims": [[16, 16]], - "hidden_size": [8], - "param_load_path": [ - "./logs/exp/index/meta_param_path1.pickle", - "./logs/exp/index/rnn_parameter_path2.pickle" - ], - "learning_rate": [3e-4], - "gradient_clip": [1] - } - ] - } - ], - "batch_size": [256], - "discount": [0.99], - "max_devices_per_host": [-1], - "seed": [1], - "generate_random_seed": [true] -} diff --git a/configs/grasp_ppo.json b/configs/grasp_ppo.json deleted file mode 100644 index 8983aa8..0000000 --- a/configs/grasp_ppo.json +++ /dev/null @@ -1,36 +0,0 @@ -{ - "env": [ - { - "name": ["grasp"], - "train_steps": [6e8], - "episode_length": [1000], - "action_repeat": [1], - "reward_scaling": [10], - "num_envs": [2048], - "num_evals": [10], - "normalize_obs": [true] - } - ], - "agent": [ - { - "name": ["PPO"], - "gae_lambda": [0.95], - "rollout_steps": [20], - "num_minibatches": [32], - "clip_ratio": [0.3], - "update_epochs": [2], - "entropy_weight": [1e-3] - } - ], - "optimizer": [ - { - "name": ["RMSProp", "Adam"], - "kwargs": [{ "learning_rate": [3e-4], "gradient_clip": [1] }] - } - ], - "batch_size": [256], - "discount": [0.99], - "max_devices_per_host": [-1], - "seed": [1], - "generate_random_seed": [true] -} diff --git a/configs/grid_a2c.json b/configs/grid_a2c.json deleted file mode 100644 index 87959c4..0000000 --- a/configs/grid_a2c.json +++ /dev/null @@ -1,43 +0,0 @@ -{ - "env": [ - { - "name": [ - [ - "small_sparse_short", - "small_sparse_long", - "small_dense_short", - "small_dense_long", - "big_sparse_short", - "big_sparse_long", - "big_dense_short", - "big_dense_long" - ] - ], - "num_envs": [512], - "train_steps": [3e7] - } - ], - "agent": [ - { - "name": ["A2C"], - "gae_lambda": [0.95], - "rollout_steps": [20], - "critic_loss_weight": [0.5], - "entropy_weight": [0.01] - } - ], - "agent_optimizer": [ - { - "name": ["RMSProp", "Adam"], - "kwargs": [ - { - "learning_rate": [3e-2, 1e-2, 3e-3, 1e-3, 3e-4, 1e-4], - "gradient_clip": [1] - } - ] - } - ], - "discount": [0.995], - "seed": [42], - "generate_random_seed": [true] -} diff --git a/configs/grid_collect.json b/configs/grid_collect.json deleted file mode 100644 index 8686b3e..0000000 --- a/configs/grid_collect.json +++ /dev/null @@ -1,43 +0,0 @@ -{ - "env": [ - { - "name": [ - ["small_sparse_short"], - ["small_sparse_long"], - ["small_dense_short"], - ["small_dense_long"], - ["big_sparse_short"], - ["big_sparse_long"], - ["big_dense_short"], - ["big_dense_long"] - ], - "reward_scaling": [1e-3, 1e-2, 1e-1, 1, 1e1, 1e2, 1e3], - "num_envs": [512], - "train_steps": [3e7] - } - ], - "agent": [ - { - "name": ["CollectA2C"], - "gae_lambda": [0.95], - "rollout_steps": [20], - "critic_loss_weight": [0.5], - "entropy_weight": [0.01], - "data_reduce": [10] - } - ], - "agent_optimizer": [ - { - "name": ["RMSProp", "Adam"], - "kwargs": [ - { - "learning_rate": [3e-2, 1e-2, 3e-3, 1e-3, 3e-4, 1e-4], - "gradient_clip": [1] - } - ] - } - ], - "discount": [0.995], - "seed": [42], - "generate_random_seed": [true] -} diff --git a/configs/grid_meta.json b/configs/grid_meta.json deleted file mode 100644 index e40e83b..0000000 --- a/configs/grid_meta.json +++ /dev/null @@ -1,61 +0,0 @@ -{ - "env": [ - { - "name": [ - [ - "small_dense_long", - "small_dense_short", - "big_sparse_short", - "big_dense_short", - "big_sparse_long", - "big_dense_long" - ] - ], - "reward_scaling": [[1e3, 1e2, 1e2, 1e1, 1e1, 1]], - "num_envs": [512], - "train_steps": [1.5e8] - } - ], - "agent": [ - { - "name": ["MetaA2C"], - "inner_updates": [4], - "gae_lambda": [0.95], - "rollout_steps": 
[20], - "critic_loss_weight": [0.5], - "entropy_weight": [0.01], - "reset_interval": [256, 512, 1024] - } - ], - "agent_optimizer": [ - { - "name": ["Optim4RL"], - "kwargs": [ - { - "mlp_dims": [[16, 16]], - "hidden_size": [8], - "param_load_path": [""], - "learning_rate": [[1e-3, 3e-3, 3e-3, 3e-3, 1e-3, 3e-3]], - "gradient_clip": [1] - } - ] - } - ], - "meta_optimizer": [ - { - "name": ["Adam"], - "kwargs": [ - { - "learning_rate": [3e-5, 1e-4, 3e-4, 1e-3], - "gradient_clip": [1, -1], - "max_norm": [-1, 1e2, 1e0, 1e-2] - } - ] - } - ], - "discount": [0.995], - "seed": [42], - "save_param": [256], - "display_interval": [50], - "generate_random_seed": [true] -} diff --git a/configs/halfcheetah_collect.json b/configs/halfcheetah_collect.json deleted file mode 100644 index bebe928..0000000 --- a/configs/halfcheetah_collect.json +++ /dev/null @@ -1,37 +0,0 @@ -{ - "env": [ - { - "name": ["halfcheetah"], - "train_steps": [1e8], - "episode_length": [1000], - "action_repeat": [1], - "reward_scaling": [1], - "num_envs": [2048], - "num_evals": [20], - "normalize_obs": [true] - } - ], - "agent": [ - { - "name": ["CollectPPO"], - "data_reduce": [100], - "gae_lambda": [0.95], - "rollout_steps": [20], - "num_minibatches": [32], - "clip_ratio": [0.3], - "update_epochs": [8], - "entropy_weight": [1e-3] - } - ], - "optimizer": [ - { - "name": ["RMSProp", "Adam"], - "kwargs": [{ "learning_rate": [3e-4], "gradient_clip": [1] }] - } - ], - "batch_size": [512], - "discount": [0.95], - "max_devices_per_host": [-1], - "seed": [1], - "generate_random_seed": [true] -} diff --git a/configs/halfcheetah_lopt.json b/configs/halfcheetah_lopt.json deleted file mode 100644 index 938062f..0000000 --- a/configs/halfcheetah_lopt.json +++ /dev/null @@ -1,47 +0,0 @@ -{ - "env": [ - { - "name": ["halfcheetah"], - "train_steps": [1e8], - "episode_length": [1000], - "action_repeat": [1], - "reward_scaling": [1], - "num_envs": [2048], - "num_evals": [20], - "normalize_obs": [true] - } - ], - "agent": [ - { - "name": ["PPO"], - "gae_lambda": [0.95], - "rollout_steps": [20], - "num_minibatches": [32], - "clip_ratio": [0.3], - "update_epochs": [8], - "entropy_weight": [1e-3] - } - ], - "optimizer": [ - { - "name": ["Optim4RL"], - "kwargs": [ - { - "mlp_dims": [[16, 16]], - "hidden_size": [8], - "param_load_path": [ - "./logs/exp/index/meta_param_path1.pickle", - "./logs/exp/index/rnn_parameter_path2.pickle" - ], - "learning_rate": [3e-4], - "gradient_clip": [1] - } - ] - } - ], - "batch_size": [512], - "discount": [0.95], - "max_devices_per_host": [-1], - "seed": [1], - "generate_random_seed": [true] -} diff --git a/configs/halfcheetah_ppo.json b/configs/halfcheetah_ppo.json deleted file mode 100644 index bad4142..0000000 --- a/configs/halfcheetah_ppo.json +++ /dev/null @@ -1,36 +0,0 @@ -{ - "env": [ - { - "name": ["halfcheetah"], - "train_steps": [1e8], - "episode_length": [1000], - "action_repeat": [1], - "reward_scaling": [1], - "num_envs": [2048], - "num_evals": [20], - "normalize_obs": [true] - } - ], - "agent": [ - { - "name": ["PPO"], - "gae_lambda": [0.95], - "rollout_steps": [20], - "num_minibatches": [32], - "clip_ratio": [0.3], - "update_epochs": [8], - "entropy_weight": [1e-3] - } - ], - "optimizer": [ - { - "name": ["RMSProp", "Adam"], - "kwargs": [{ "learning_rate": [3e-4], "gradient_clip": [1] }] - } - ], - "batch_size": [512], - "discount": [0.95], - "max_devices_per_host": [-1], - "seed": [1], - "generate_random_seed": [true] -} diff --git a/configs/humanoid_collect.json b/configs/humanoid_collect.json deleted file 
mode 100644 index e9e5ecd..0000000 --- a/configs/humanoid_collect.json +++ /dev/null @@ -1,37 +0,0 @@ -{ - "env": [ - { - "name": ["humanoid"], - "train_steps": [5e7], - "episode_length": [1000], - "action_repeat": [1], - "reward_scaling": [0.1], - "num_envs": [2048], - "num_evals": [20], - "normalize_obs": [true] - } - ], - "agent": [ - { - "name": ["CollectPPO"], - "data_reduce": [100], - "gae_lambda": [0.95], - "rollout_steps": [10], - "num_minibatches": [32], - "clip_ratio": [0.3], - "update_epochs": [8], - "entropy_weight": [1e-3] - } - ], - "optimizer": [ - { - "name": ["RMSProp", "Adam"], - "kwargs": [{ "learning_rate": [3e-4], "gradient_clip": [1] }] - } - ], - "batch_size": [1024], - "discount": [0.97], - "max_devices_per_host": [-1], - "seed": [1], - "generate_random_seed": [true] -} diff --git a/configs/humanoid_lopt.json b/configs/humanoid_lopt.json deleted file mode 100644 index e214e86..0000000 --- a/configs/humanoid_lopt.json +++ /dev/null @@ -1,47 +0,0 @@ -{ - "env": [ - { - "name": ["humanoid"], - "train_steps": [5e7], - "episode_length": [1000], - "action_repeat": [1], - "reward_scaling": [0.1], - "num_envs": [2048], - "num_evals": [20], - "normalize_obs": [true] - } - ], - "agent": [ - { - "name": ["PPO"], - "gae_lambda": [0.95], - "rollout_steps": [10], - "num_minibatches": [32], - "clip_ratio": [0.3], - "update_epochs": [8], - "entropy_weight": [1e-3] - } - ], - "optimizer": [ - { - "name": ["Optim4RL"], - "kwargs": [ - { - "mlp_dims": [[16, 16]], - "hidden_size": [8], - "param_load_path": [ - "./logs/exp/index/meta_param_path1.pickle", - "./logs/exp/index/rnn_parameter_path2.pickle" - ], - "learning_rate": [3e-4], - "gradient_clip": [1] - } - ] - } - ], - "batch_size": [1024], - "discount": [0.97], - "max_devices_per_host": [-1], - "seed": [1], - "generate_random_seed": [true] -} diff --git a/configs/humanoid_meta.json b/configs/humanoid_meta.json deleted file mode 100644 index 8cfb6d6..0000000 --- a/configs/humanoid_meta.json +++ /dev/null @@ -1,55 +0,0 @@ -{ - "env": [ - { - "name": ["humanoid"], - "train_steps": [3e8], - "episode_length": [1000], - "action_repeat": [1], - "reward_scaling": [0.1], - "num_envs": [2048], - "normalize_obs": [true] - } - ], - "agent": [ - { - "name": ["MetaPPO"], - "inner_updates": [4], - "gae_lambda": [0.95], - "rollout_steps": [10], - "num_minibatches": [8], - "clip_ratio": [0.3], - "update_epochs": [4], - "entropy_weight": [1e-3], - "reset_interval": [256, 512] - } - ], - "agent_optimizer": [ - { - "name": ["Optim4RL"], - "kwargs": [ - { - "mlp_dims": [[16, 16]], - "hidden_size": [8], - "param_load_path": [""], - "learning_rate": [3e-4], - "gradient_clip": [1] - } - ] - } - ], - "meta_optimizer": [ - { - "name": ["Adam"], - "kwargs": [ - { "learning_rate": [3e-5, 1e-4, 3e-4, 1e-3], "gradient_clip": [1] } - ] - } - ], - "save_param": [256], - "display_interval": [50], - "batch_size": [1024], - "discount": [0.97], - "max_devices_per_host": [-1], - "seed": [1], - "generate_random_seed": [true] -} diff --git a/configs/humanoid_ppo.json b/configs/humanoid_ppo.json deleted file mode 100644 index 8f4ae8e..0000000 --- a/configs/humanoid_ppo.json +++ /dev/null @@ -1,36 +0,0 @@ -{ - "env": [ - { - "name": ["humanoid"], - "train_steps": [5e7], - "episode_length": [1000], - "action_repeat": [1], - "reward_scaling": [0.1], - "num_envs": [2048], - "num_evals": [20], - "normalize_obs": [true] - } - ], - "agent": [ - { - "name": ["PPO"], - "gae_lambda": [0.95], - "rollout_steps": [10], - "num_minibatches": [32], - "clip_ratio": [0.3], - 
"update_epochs": [8], - "entropy_weight": [1e-3] - } - ], - "optimizer": [ - { - "name": ["RMSProp", "Adam"], - "kwargs": [{ "learning_rate": [3e-4], "gradient_clip": [1] }] - } - ], - "batch_size": [1024], - "discount": [0.97], - "max_devices_per_host": [-1], - "seed": [1], - "generate_random_seed": [true] -} diff --git a/configs/lopt_l2l_ant.json b/configs/lopt_l2l_ant.json new file mode 100644 index 0000000..5f52e61 --- /dev/null +++ b/configs/lopt_l2l_ant.json @@ -0,0 +1,85 @@ +{ + "env": [{ + "name": ["ant"], + "train_steps": [1e8], + "episode_length": [1000], + "action_repeat": [1], + "reward_scaling": [10], + "num_envs": [4096], + "num_evals": [10], + "normalize_obs": [true] + }], + "agent": [{ + "name": ["PPO"], + "gae_lambda": [0.95], + "rollout_steps": [5], + "num_minibatches": [32], + "clipping_epsilon": [0.3], + "update_epochs": [4], + "entropy_weight": [1e-2] + }], + "optim": [{ + "name": ["L2LGD2"], + "kwargs": [{ + "param_load_path": [ + "./logs/meta_l2l_ant/1/param.pickle", + "./logs/meta_l2l_ant/2/param.pickle", + "./logs/meta_l2l_ant/3/param.pickle", + "./logs/meta_l2l_ant/4/param.pickle", + "./logs/meta_l2l_ant/5/param.pickle", + "./logs/meta_l2l_ant/6/param.pickle", + "./logs/meta_l2l_ant/7/param.pickle", + "./logs/meta_l2l_ant/8/param.pickle", + "./logs/meta_l2l_ant/9/param.pickle", + "./logs/meta_l2l_ant/10/param.pickle", + "./logs/meta_l2l_ant/11/param.pickle", + "./logs/meta_l2l_ant/12/param.pickle", + "./logs/meta_l2l_ant/13/param.pickle", + "./logs/meta_l2l_ant/14/param.pickle", + "./logs/meta_l2l_ant/15/param.pickle", + "./logs/meta_l2l_ant/16/param.pickle", + "./logs/meta_l2l_ant/17/param.pickle", + "./logs/meta_l2l_ant/18/param.pickle", + "./logs/meta_l2l_ant/19/param.pickle", + "./logs/meta_l2l_ant/20/param.pickle", + "./logs/meta_l2l_ant/21/param.pickle", + "./logs/meta_l2l_ant/22/param.pickle", + "./logs/meta_l2l_ant/23/param.pickle", + "./logs/meta_l2l_ant/24/param.pickle", + "./logs/meta_l2l_ant/25/param.pickle", + "./logs/meta_l2l_ant/26/param.pickle", + "./logs/meta_l2l_ant/27/param.pickle", + "./logs/meta_l2l_ant/28/param.pickle", + "./logs/meta_l2l_ant/29/param.pickle", + "./logs/meta_l2l_ant/30/param.pickle", + "./logs/meta_l2l_ant/31/param.pickle", + "./logs/meta_l2l_ant/32/param.pickle", + "./logs/meta_l2l_ant/33/param.pickle", + "./logs/meta_l2l_ant/34/param.pickle", + "./logs/meta_l2l_ant/35/param.pickle", + "./logs/meta_l2l_ant/36/param.pickle", + "./logs/meta_l2l_ant/37/param.pickle", + "./logs/meta_l2l_ant/38/param.pickle", + "./logs/meta_l2l_ant/39/param.pickle", + "./logs/meta_l2l_ant/40/param.pickle", + "./logs/meta_l2l_ant/41/param.pickle", + "./logs/meta_l2l_ant/42/param.pickle", + "./logs/meta_l2l_ant/43/param.pickle", + "./logs/meta_l2l_ant/44/param.pickle", + "./logs/meta_l2l_ant/45/param.pickle", + "./logs/meta_l2l_ant/46/param.pickle", + "./logs/meta_l2l_ant/47/param.pickle", + "./logs/meta_l2l_ant/48/param.pickle", + "./logs/meta_l2l_ant/49/param.pickle", + "./logs/meta_l2l_ant/50/param.pickle" + ], + "learning_rate": [3e-4], + "grad_clip": [1] + }] + }], + "batch_size": [2048], + "discount": [0.97], + "max_devices_per_host": [-1], + "seed": [1], + "generate_random_seed": [true] +} \ No newline at end of file diff --git a/configs/lopt_l2l_bdl.json b/configs/lopt_l2l_bdl.json new file mode 100644 index 0000000..5a0240c --- /dev/null +++ b/configs/lopt_l2l_bdl.json @@ -0,0 +1,66 @@ +{ + "env": [{ + "name": [["big_dense_long"]], + "num_envs": [512], + "train_steps": [3e7] + }], + "agent": [{ + "name": ["A2C"], + "gae_lambda": [0.95], + 
"rollout_steps": [20], + "critic_loss_weight": [0.5], + "entropy_weight": [0.01] + }], + "agent_optim": [{ + "name": ["L2LGD2"], + "kwargs": [{ + "param_load_path": [ + "./logs/meta_l2l_bdl/1/param.pickle", + "./logs/meta_l2l_bdl/2/param.pickle", + "./logs/meta_l2l_bdl/3/param.pickle", + "./logs/meta_l2l_bdl/4/param.pickle", + "./logs/meta_l2l_bdl/5/param.pickle", + "./logs/meta_l2l_bdl/6/param.pickle", + "./logs/meta_l2l_bdl/7/param.pickle", + "./logs/meta_l2l_bdl/8/param.pickle", + "./logs/meta_l2l_bdl/9/param.pickle", + "./logs/meta_l2l_bdl/10/param.pickle", + "./logs/meta_l2l_bdl/11/param.pickle", + "./logs/meta_l2l_bdl/12/param.pickle", + "./logs/meta_l2l_bdl/13/param.pickle", + "./logs/meta_l2l_bdl/14/param.pickle", + "./logs/meta_l2l_bdl/15/param.pickle", + "./logs/meta_l2l_bdl/16/param.pickle", + "./logs/meta_l2l_bdl/17/param.pickle", + "./logs/meta_l2l_bdl/18/param.pickle", + "./logs/meta_l2l_bdl/19/param.pickle", + "./logs/meta_l2l_bdl/20/param.pickle", + "./logs/meta_l2l_bdl/21/param.pickle", + "./logs/meta_l2l_bdl/22/param.pickle", + "./logs/meta_l2l_bdl/23/param.pickle", + "./logs/meta_l2l_bdl/24/param.pickle", + "./logs/meta_l2l_bdl/25/param.pickle", + "./logs/meta_l2l_bdl/26/param.pickle", + "./logs/meta_l2l_bdl/27/param.pickle", + "./logs/meta_l2l_bdl/28/param.pickle", + "./logs/meta_l2l_bdl/29/param.pickle", + "./logs/meta_l2l_bdl/30/param.pickle", + "./logs/meta_l2l_bdl/31/param.pickle", + "./logs/meta_l2l_bdl/32/param.pickle", + "./logs/meta_l2l_bdl/33/param.pickle", + "./logs/meta_l2l_bdl/34/param.pickle", + "./logs/meta_l2l_bdl/35/param.pickle", + "./logs/meta_l2l_bdl/36/param.pickle", + "./logs/meta_l2l_bdl/37/param.pickle", + "./logs/meta_l2l_bdl/38/param.pickle", + "./logs/meta_l2l_bdl/39/param.pickle", + "./logs/meta_l2l_bdl/40/param.pickle" + ], + "learning_rate": [0.003], + "grad_clip": [1] + }] + }], + "discount": [0.995], + "seed": [42], + "generate_random_seed": [true] +} \ No newline at end of file diff --git a/configs/lopt_l2l_catch.json b/configs/lopt_l2l_catch.json new file mode 100644 index 0000000..0d9151b --- /dev/null +++ b/configs/lopt_l2l_catch.json @@ -0,0 +1,46 @@ +{ + "env": [{ + "name": [["catch"]], + "num_envs": [64], + "train_steps": [5e5] + }], + "agent": [{ + "name": ["A2C"], + "gae_lambda": [0.95], + "rollout_steps": [20], + "critic_loss_weight": [0.5], + "entropy_weight": [0.01] + }], + "agent_optim": [{ + "name": ["L2LGD2"], + "kwargs": [{ + "param_load_path": [ + "./logs/meta_l2l_catch/1/param.pickle", + "./logs/meta_l2l_catch/2/param.pickle", + "./logs/meta_l2l_catch/3/param.pickle", + "./logs/meta_l2l_catch/4/param.pickle", + "./logs/meta_l2l_catch/5/param.pickle", + "./logs/meta_l2l_catch/6/param.pickle", + "./logs/meta_l2l_catch/7/param.pickle", + "./logs/meta_l2l_catch/8/param.pickle", + "./logs/meta_l2l_catch/9/param.pickle", + "./logs/meta_l2l_catch/10/param.pickle", + "./logs/meta_l2l_catch/11/param.pickle", + "./logs/meta_l2l_catch/12/param.pickle", + "./logs/meta_l2l_catch/13/param.pickle", + "./logs/meta_l2l_catch/14/param.pickle", + "./logs/meta_l2l_catch/15/param.pickle", + "./logs/meta_l2l_catch/16/param.pickle", + "./logs/meta_l2l_catch/17/param.pickle", + "./logs/meta_l2l_catch/18/param.pickle", + "./logs/meta_l2l_catch/19/param.pickle", + "./logs/meta_l2l_catch/20/param.pickle" + ], + "learning_rate": [1e-3], + "grad_clip": [-1] + }] + }], + "discount": [0.995], + "seed": [42], + "generate_random_seed": [true] +} \ No newline at end of file diff --git a/configs/lopt_l2l_humanoid.json b/configs/lopt_l2l_humanoid.json new 
file mode 100644 index 0000000..7839df4 --- /dev/null +++ b/configs/lopt_l2l_humanoid.json @@ -0,0 +1,85 @@ +{ + "env": [{ + "name": ["humanoid"], + "train_steps": [1e8], + "episode_length": [1000], + "action_repeat": [1], + "reward_scaling": [0.1], + "num_envs": [2048], + "num_evals": [10], + "normalize_obs": [true] + }], + "agent": [{ + "name": ["PPO"], + "gae_lambda": [0.95], + "rollout_steps": [10], + "num_minibatches": [32], + "clipping_epsilon": [0.3], + "update_epochs": [8], + "entropy_weight": [1e-3] + }], + "optim": [{ + "name": ["L2LGD2"], + "kwargs": [{ + "param_load_path": [ + "./logs/meta_l2l_humanoid/1/param.pickle", + "./logs/meta_l2l_humanoid/2/param.pickle", + "./logs/meta_l2l_humanoid/3/param.pickle", + "./logs/meta_l2l_humanoid/4/param.pickle", + "./logs/meta_l2l_humanoid/5/param.pickle", + "./logs/meta_l2l_humanoid/6/param.pickle", + "./logs/meta_l2l_humanoid/7/param.pickle", + "./logs/meta_l2l_humanoid/8/param.pickle", + "./logs/meta_l2l_humanoid/9/param.pickle", + "./logs/meta_l2l_humanoid/10/param.pickle", + "./logs/meta_l2l_humanoid/11/param.pickle", + "./logs/meta_l2l_humanoid/12/param.pickle", + "./logs/meta_l2l_humanoid/13/param.pickle", + "./logs/meta_l2l_humanoid/14/param.pickle", + "./logs/meta_l2l_humanoid/15/param.pickle", + "./logs/meta_l2l_humanoid/16/param.pickle", + "./logs/meta_l2l_humanoid/17/param.pickle", + "./logs/meta_l2l_humanoid/18/param.pickle", + "./logs/meta_l2l_humanoid/19/param.pickle", + "./logs/meta_l2l_humanoid/20/param.pickle", + "./logs/meta_l2l_humanoid/21/param.pickle", + "./logs/meta_l2l_humanoid/22/param.pickle", + "./logs/meta_l2l_humanoid/23/param.pickle", + "./logs/meta_l2l_humanoid/24/param.pickle", + "./logs/meta_l2l_humanoid/25/param.pickle", + "./logs/meta_l2l_humanoid/26/param.pickle", + "./logs/meta_l2l_humanoid/27/param.pickle", + "./logs/meta_l2l_humanoid/28/param.pickle", + "./logs/meta_l2l_humanoid/29/param.pickle", + "./logs/meta_l2l_humanoid/30/param.pickle", + "./logs/meta_l2l_humanoid/31/param.pickle", + "./logs/meta_l2l_humanoid/32/param.pickle", + "./logs/meta_l2l_humanoid/33/param.pickle", + "./logs/meta_l2l_humanoid/34/param.pickle", + "./logs/meta_l2l_humanoid/35/param.pickle", + "./logs/meta_l2l_humanoid/36/param.pickle", + "./logs/meta_l2l_humanoid/37/param.pickle", + "./logs/meta_l2l_humanoid/38/param.pickle", + "./logs/meta_l2l_humanoid/39/param.pickle", + "./logs/meta_l2l_humanoid/40/param.pickle", + "./logs/meta_l2l_humanoid/41/param.pickle", + "./logs/meta_l2l_humanoid/42/param.pickle", + "./logs/meta_l2l_humanoid/43/param.pickle", + "./logs/meta_l2l_humanoid/44/param.pickle", + "./logs/meta_l2l_humanoid/45/param.pickle", + "./logs/meta_l2l_humanoid/46/param.pickle", + "./logs/meta_l2l_humanoid/47/param.pickle", + "./logs/meta_l2l_humanoid/48/param.pickle", + "./logs/meta_l2l_humanoid/49/param.pickle", + "./logs/meta_l2l_humanoid/50/param.pickle" + ], + "learning_rate": [3e-4], + "grad_clip": [1] + }] + }], + "batch_size": [1024], + "discount": [0.97], + "max_devices_per_host": [-1], + "seed": [1], + "generate_random_seed": [true] +} \ No newline at end of file diff --git a/configs/lopt_lin_ant.json b/configs/lopt_lin_ant.json new file mode 100644 index 0000000..613a636 --- /dev/null +++ b/configs/lopt_lin_ant.json @@ -0,0 +1,85 @@ +{ + "env": [{ + "name": ["ant"], + "train_steps": [1e8], + "episode_length": [1000], + "action_repeat": [1], + "reward_scaling": [10], + "num_envs": [4096], + "num_evals": [10], + "normalize_obs": [true] + }], + "agent": [{ + "name": ["PPO"], + "gae_lambda": [0.95], + 
"rollout_steps": [5], + "num_minibatches": [32], + "clipping_epsilon": [0.3], + "update_epochs": [4], + "entropy_weight": [1e-2] + }], + "optim": [{ + "name": ["LinearOptim"], + "kwargs": [{ + "param_load_path": [ + "./logs/meta_lin_ant/1/param.pickle", + "./logs/meta_lin_ant/2/param.pickle", + "./logs/meta_lin_ant/3/param.pickle", + "./logs/meta_lin_ant/4/param.pickle", + "./logs/meta_lin_ant/5/param.pickle", + "./logs/meta_lin_ant/6/param.pickle", + "./logs/meta_lin_ant/7/param.pickle", + "./logs/meta_lin_ant/8/param.pickle", + "./logs/meta_lin_ant/9/param.pickle", + "./logs/meta_lin_ant/10/param.pickle", + "./logs/meta_lin_ant/11/param.pickle", + "./logs/meta_lin_ant/12/param.pickle", + "./logs/meta_lin_ant/13/param.pickle", + "./logs/meta_lin_ant/14/param.pickle", + "./logs/meta_lin_ant/15/param.pickle", + "./logs/meta_lin_ant/16/param.pickle", + "./logs/meta_lin_ant/17/param.pickle", + "./logs/meta_lin_ant/18/param.pickle", + "./logs/meta_lin_ant/19/param.pickle", + "./logs/meta_lin_ant/20/param.pickle", + "./logs/meta_lin_ant/21/param.pickle", + "./logs/meta_lin_ant/22/param.pickle", + "./logs/meta_lin_ant/23/param.pickle", + "./logs/meta_lin_ant/24/param.pickle", + "./logs/meta_lin_ant/25/param.pickle", + "./logs/meta_lin_ant/26/param.pickle", + "./logs/meta_lin_ant/27/param.pickle", + "./logs/meta_lin_ant/28/param.pickle", + "./logs/meta_lin_ant/29/param.pickle", + "./logs/meta_lin_ant/30/param.pickle", + "./logs/meta_lin_ant/31/param.pickle", + "./logs/meta_lin_ant/32/param.pickle", + "./logs/meta_lin_ant/33/param.pickle", + "./logs/meta_lin_ant/34/param.pickle", + "./logs/meta_lin_ant/35/param.pickle", + "./logs/meta_lin_ant/36/param.pickle", + "./logs/meta_lin_ant/37/param.pickle", + "./logs/meta_lin_ant/38/param.pickle", + "./logs/meta_lin_ant/39/param.pickle", + "./logs/meta_lin_ant/40/param.pickle", + "./logs/meta_lin_ant/41/param.pickle", + "./logs/meta_lin_ant/42/param.pickle", + "./logs/meta_lin_ant/43/param.pickle", + "./logs/meta_lin_ant/44/param.pickle", + "./logs/meta_lin_ant/45/param.pickle", + "./logs/meta_lin_ant/46/param.pickle", + "./logs/meta_lin_ant/47/param.pickle", + "./logs/meta_lin_ant/48/param.pickle", + "./logs/meta_lin_ant/49/param.pickle", + "./logs/meta_lin_ant/50/param.pickle" + ], + "learning_rate": [3e-4], + "grad_clip": [1] + }] + }], + "batch_size": [2048], + "discount": [0.97], + "max_devices_per_host": [-1], + "seed": [1], + "generate_random_seed": [true] +} \ No newline at end of file diff --git a/configs/lopt_lin_bdl.json b/configs/lopt_lin_bdl.json new file mode 100644 index 0000000..1208465 --- /dev/null +++ b/configs/lopt_lin_bdl.json @@ -0,0 +1,66 @@ +{ + "env": [{ + "name": [["big_dense_long"]], + "num_envs": [512], + "train_steps": [3e7] + }], + "agent": [{ + "name": ["A2C"], + "gae_lambda": [0.95], + "rollout_steps": [20], + "critic_loss_weight": [0.5], + "entropy_weight": [0.01] + }], + "agent_optim": [{ + "name": ["LinearOptim"], + "kwargs": [{ + "param_load_path": [ + "./logs/meta_lin_bdl/1/param.pickle", + "./logs/meta_lin_bdl/2/param.pickle", + "./logs/meta_lin_bdl/3/param.pickle", + "./logs/meta_lin_bdl/4/param.pickle", + "./logs/meta_lin_bdl/5/param.pickle", + "./logs/meta_lin_bdl/6/param.pickle", + "./logs/meta_lin_bdl/7/param.pickle", + "./logs/meta_lin_bdl/8/param.pickle", + "./logs/meta_lin_bdl/9/param.pickle", + "./logs/meta_lin_bdl/10/param.pickle", + "./logs/meta_lin_bdl/11/param.pickle", + "./logs/meta_lin_bdl/12/param.pickle", + "./logs/meta_lin_bdl/13/param.pickle", + "./logs/meta_lin_bdl/14/param.pickle", + 
"./logs/meta_lin_bdl/15/param.pickle", + "./logs/meta_lin_bdl/16/param.pickle", + "./logs/meta_lin_bdl/17/param.pickle", + "./logs/meta_lin_bdl/18/param.pickle", + "./logs/meta_lin_bdl/19/param.pickle", + "./logs/meta_lin_bdl/20/param.pickle", + "./logs/meta_lin_bdl/21/param.pickle", + "./logs/meta_lin_bdl/22/param.pickle", + "./logs/meta_lin_bdl/23/param.pickle", + "./logs/meta_lin_bdl/24/param.pickle", + "./logs/meta_lin_bdl/25/param.pickle", + "./logs/meta_lin_bdl/26/param.pickle", + "./logs/meta_lin_bdl/27/param.pickle", + "./logs/meta_lin_bdl/28/param.pickle", + "./logs/meta_lin_bdl/29/param.pickle", + "./logs/meta_lin_bdl/30/param.pickle", + "./logs/meta_lin_bdl/31/param.pickle", + "./logs/meta_lin_bdl/32/param.pickle", + "./logs/meta_lin_bdl/33/param.pickle", + "./logs/meta_lin_bdl/34/param.pickle", + "./logs/meta_lin_bdl/35/param.pickle", + "./logs/meta_lin_bdl/36/param.pickle", + "./logs/meta_lin_bdl/37/param.pickle", + "./logs/meta_lin_bdl/38/param.pickle", + "./logs/meta_lin_bdl/39/param.pickle", + "./logs/meta_lin_bdl/40/param.pickle" + ], + "learning_rate": [0.003], + "grad_clip": [1] + }] + }], + "discount": [0.995], + "seed": [42], + "generate_random_seed": [true] +} \ No newline at end of file diff --git a/configs/lopt_lin_catch.json b/configs/lopt_lin_catch.json new file mode 100644 index 0000000..7b3f9a6 --- /dev/null +++ b/configs/lopt_lin_catch.json @@ -0,0 +1,46 @@ +{ + "env": [{ + "name": [["catch"]], + "num_envs": [64], + "train_steps": [5e5] + }], + "agent": [{ + "name": ["A2C"], + "gae_lambda": [0.95], + "rollout_steps": [20], + "critic_loss_weight": [0.5], + "entropy_weight": [0.01] + }], + "agent_optim": [{ + "name": ["LinearOptim"], + "kwargs": [{ + "param_load_path": [ + "./logs/meta_lin_catch/1/param.pickle", + "./logs/meta_lin_catch/2/param.pickle", + "./logs/meta_lin_catch/3/param.pickle", + "./logs/meta_lin_catch/4/param.pickle", + "./logs/meta_lin_catch/5/param.pickle", + "./logs/meta_lin_catch/6/param.pickle", + "./logs/meta_lin_catch/7/param.pickle", + "./logs/meta_lin_catch/8/param.pickle", + "./logs/meta_lin_catch/9/param.pickle", + "./logs/meta_lin_catch/10/param.pickle", + "./logs/meta_lin_catch/11/param.pickle", + "./logs/meta_lin_catch/12/param.pickle", + "./logs/meta_lin_catch/13/param.pickle", + "./logs/meta_lin_catch/14/param.pickle", + "./logs/meta_lin_catch/15/param.pickle", + "./logs/meta_lin_catch/16/param.pickle", + "./logs/meta_lin_catch/17/param.pickle", + "./logs/meta_lin_catch/18/param.pickle", + "./logs/meta_lin_catch/19/param.pickle", + "./logs/meta_lin_catch/20/param.pickle" + ], + "learning_rate": [1e-3], + "grad_clip": [-1] + }] + }], + "discount": [0.995], + "seed": [42], + "generate_random_seed": [true] +} \ No newline at end of file diff --git a/configs/lopt_lin_humanoid.json b/configs/lopt_lin_humanoid.json new file mode 100644 index 0000000..7e7b823 --- /dev/null +++ b/configs/lopt_lin_humanoid.json @@ -0,0 +1,85 @@ +{ + "env": [{ + "name": ["humanoid"], + "train_steps": [1e8], + "episode_length": [1000], + "action_repeat": [1], + "reward_scaling": [0.1], + "num_envs": [2048], + "num_evals": [10], + "normalize_obs": [true] + }], + "agent": [{ + "name": ["PPO"], + "gae_lambda": [0.95], + "rollout_steps": [10], + "num_minibatches": [32], + "clipping_epsilon": [0.3], + "update_epochs": [8], + "entropy_weight": [1e-3] + }], + "optim": [{ + "name": ["LinearOptim"], + "kwargs": [{ + "param_load_path": [ + "./logs/meta_lin_humanoid/1/param.pickle", + "./logs/meta_lin_humanoid/2/param.pickle", + 
"./logs/meta_lin_humanoid/3/param.pickle", + "./logs/meta_lin_humanoid/4/param.pickle", + "./logs/meta_lin_humanoid/5/param.pickle", + "./logs/meta_lin_humanoid/6/param.pickle", + "./logs/meta_lin_humanoid/7/param.pickle", + "./logs/meta_lin_humanoid/8/param.pickle", + "./logs/meta_lin_humanoid/9/param.pickle", + "./logs/meta_lin_humanoid/10/param.pickle", + "./logs/meta_lin_humanoid/11/param.pickle", + "./logs/meta_lin_humanoid/12/param.pickle", + "./logs/meta_lin_humanoid/13/param.pickle", + "./logs/meta_lin_humanoid/14/param.pickle", + "./logs/meta_lin_humanoid/15/param.pickle", + "./logs/meta_lin_humanoid/16/param.pickle", + "./logs/meta_lin_humanoid/17/param.pickle", + "./logs/meta_lin_humanoid/18/param.pickle", + "./logs/meta_lin_humanoid/19/param.pickle", + "./logs/meta_lin_humanoid/20/param.pickle", + "./logs/meta_lin_humanoid/21/param.pickle", + "./logs/meta_lin_humanoid/22/param.pickle", + "./logs/meta_lin_humanoid/23/param.pickle", + "./logs/meta_lin_humanoid/24/param.pickle", + "./logs/meta_lin_humanoid/25/param.pickle", + "./logs/meta_lin_humanoid/26/param.pickle", + "./logs/meta_lin_humanoid/27/param.pickle", + "./logs/meta_lin_humanoid/28/param.pickle", + "./logs/meta_lin_humanoid/29/param.pickle", + "./logs/meta_lin_humanoid/30/param.pickle", + "./logs/meta_lin_humanoid/31/param.pickle", + "./logs/meta_lin_humanoid/32/param.pickle", + "./logs/meta_lin_humanoid/33/param.pickle", + "./logs/meta_lin_humanoid/34/param.pickle", + "./logs/meta_lin_humanoid/35/param.pickle", + "./logs/meta_lin_humanoid/36/param.pickle", + "./logs/meta_lin_humanoid/37/param.pickle", + "./logs/meta_lin_humanoid/38/param.pickle", + "./logs/meta_lin_humanoid/39/param.pickle", + "./logs/meta_lin_humanoid/40/param.pickle", + "./logs/meta_lin_humanoid/41/param.pickle", + "./logs/meta_lin_humanoid/42/param.pickle", + "./logs/meta_lin_humanoid/43/param.pickle", + "./logs/meta_lin_humanoid/44/param.pickle", + "./logs/meta_lin_humanoid/45/param.pickle", + "./logs/meta_lin_humanoid/46/param.pickle", + "./logs/meta_lin_humanoid/47/param.pickle", + "./logs/meta_lin_humanoid/48/param.pickle", + "./logs/meta_lin_humanoid/49/param.pickle", + "./logs/meta_lin_humanoid/50/param.pickle" + ], + "learning_rate": [3e-4], + "grad_clip": [1] + }] + }], + "batch_size": [1024], + "discount": [0.97], + "max_devices_per_host": [-1], + "seed": [1], + "generate_random_seed": [true] +} \ No newline at end of file diff --git a/configs/lopt_rl_ant.json b/configs/lopt_rl_ant.json new file mode 100644 index 0000000..6faf251 --- /dev/null +++ b/configs/lopt_rl_ant.json @@ -0,0 +1,85 @@ +{ + "env": [{ + "name": ["ant"], + "train_steps": [1e8], + "episode_length": [1000], + "action_repeat": [1], + "reward_scaling": [10], + "num_envs": [4096], + "num_evals": [10], + "normalize_obs": [true] + }], + "agent": [{ + "name": ["PPO"], + "gae_lambda": [0.95], + "rollout_steps": [5], + "num_minibatches": [32], + "clipping_epsilon": [0.3], + "update_epochs": [4], + "entropy_weight": [1e-2] + }], + "optim": [{ + "name": ["Optim4RL"], + "kwargs": [{ + "param_load_path": [ + "./logs/meta_rl_ant/1/param.pickle", + "./logs/meta_rl_ant/2/param.pickle", + "./logs/meta_rl_ant/3/param.pickle", + "./logs/meta_rl_ant/4/param.pickle", + "./logs/meta_rl_ant/5/param.pickle", + "./logs/meta_rl_ant/6/param.pickle", + "./logs/meta_rl_ant/7/param.pickle", + "./logs/meta_rl_ant/8/param.pickle", + "./logs/meta_rl_ant/9/param.pickle", + "./logs/meta_rl_ant/10/param.pickle", + "./logs/meta_rl_ant/11/param.pickle", + "./logs/meta_rl_ant/12/param.pickle", + 
"./logs/meta_rl_ant/13/param.pickle", + "./logs/meta_rl_ant/14/param.pickle", + "./logs/meta_rl_ant/15/param.pickle", + "./logs/meta_rl_ant/16/param.pickle", + "./logs/meta_rl_ant/17/param.pickle", + "./logs/meta_rl_ant/18/param.pickle", + "./logs/meta_rl_ant/19/param.pickle", + "./logs/meta_rl_ant/20/param.pickle", + "./logs/meta_rl_ant/21/param.pickle", + "./logs/meta_rl_ant/22/param.pickle", + "./logs/meta_rl_ant/23/param.pickle", + "./logs/meta_rl_ant/24/param.pickle", + "./logs/meta_rl_ant/25/param.pickle", + "./logs/meta_rl_ant/26/param.pickle", + "./logs/meta_rl_ant/27/param.pickle", + "./logs/meta_rl_ant/28/param.pickle", + "./logs/meta_rl_ant/29/param.pickle", + "./logs/meta_rl_ant/30/param.pickle", + "./logs/meta_rl_ant/31/param.pickle", + "./logs/meta_rl_ant/32/param.pickle", + "./logs/meta_rl_ant/33/param.pickle", + "./logs/meta_rl_ant/34/param.pickle", + "./logs/meta_rl_ant/35/param.pickle", + "./logs/meta_rl_ant/36/param.pickle", + "./logs/meta_rl_ant/37/param.pickle", + "./logs/meta_rl_ant/38/param.pickle", + "./logs/meta_rl_ant/39/param.pickle", + "./logs/meta_rl_ant/40/param.pickle", + "./logs/meta_rl_ant/41/param.pickle", + "./logs/meta_rl_ant/42/param.pickle", + "./logs/meta_rl_ant/43/param.pickle", + "./logs/meta_rl_ant/44/param.pickle", + "./logs/meta_rl_ant/45/param.pickle", + "./logs/meta_rl_ant/46/param.pickle", + "./logs/meta_rl_ant/47/param.pickle", + "./logs/meta_rl_ant/48/param.pickle", + "./logs/meta_rl_ant/49/param.pickle", + "./logs/meta_rl_ant/50/param.pickle" + ], + "learning_rate": [3e-4], + "grad_clip": [1] + }] + }], + "batch_size": [2048], + "discount": [0.97], + "max_devices_per_host": [-1], + "seed": [1], + "generate_random_seed": [true] +} \ No newline at end of file diff --git a/configs/lopt_rl_bdl.json b/configs/lopt_rl_bdl.json new file mode 100644 index 0000000..34de63c --- /dev/null +++ b/configs/lopt_rl_bdl.json @@ -0,0 +1,66 @@ +{ + "env": [{ + "name": [["big_dense_long"]], + "num_envs": [512], + "train_steps": [3e7] + }], + "agent": [{ + "name": ["A2C"], + "gae_lambda": [0.95], + "rollout_steps": [20], + "critic_loss_weight": [0.5], + "entropy_weight": [0.01] + }], + "agent_optim": [{ + "name": ["Optim4RL"], + "kwargs": [{ + "param_load_path": [ + "./logs/meta_rl_bdl/1/param.pickle", + "./logs/meta_rl_bdl/2/param.pickle", + "./logs/meta_rl_bdl/3/param.pickle", + "./logs/meta_rl_bdl/4/param.pickle", + "./logs/meta_rl_bdl/5/param.pickle", + "./logs/meta_rl_bdl/6/param.pickle", + "./logs/meta_rl_bdl/7/param.pickle", + "./logs/meta_rl_bdl/8/param.pickle", + "./logs/meta_rl_bdl/9/param.pickle", + "./logs/meta_rl_bdl/10/param.pickle", + "./logs/meta_rl_bdl/11/param.pickle", + "./logs/meta_rl_bdl/12/param.pickle", + "./logs/meta_rl_bdl/13/param.pickle", + "./logs/meta_rl_bdl/14/param.pickle", + "./logs/meta_rl_bdl/15/param.pickle", + "./logs/meta_rl_bdl/16/param.pickle", + "./logs/meta_rl_bdl/17/param.pickle", + "./logs/meta_rl_bdl/18/param.pickle", + "./logs/meta_rl_bdl/19/param.pickle", + "./logs/meta_rl_bdl/20/param.pickle", + "./logs/meta_rl_bdl/21/param.pickle", + "./logs/meta_rl_bdl/22/param.pickle", + "./logs/meta_rl_bdl/23/param.pickle", + "./logs/meta_rl_bdl/24/param.pickle", + "./logs/meta_rl_bdl/25/param.pickle", + "./logs/meta_rl_bdl/26/param.pickle", + "./logs/meta_rl_bdl/27/param.pickle", + "./logs/meta_rl_bdl/28/param.pickle", + "./logs/meta_rl_bdl/29/param.pickle", + "./logs/meta_rl_bdl/30/param.pickle", + "./logs/meta_rl_bdl/31/param.pickle", + "./logs/meta_rl_bdl/32/param.pickle", + "./logs/meta_rl_bdl/33/param.pickle", + 
"./logs/meta_rl_bdl/34/param.pickle", + "./logs/meta_rl_bdl/35/param.pickle", + "./logs/meta_rl_bdl/36/param.pickle", + "./logs/meta_rl_bdl/37/param.pickle", + "./logs/meta_rl_bdl/38/param.pickle", + "./logs/meta_rl_bdl/39/param.pickle", + "./logs/meta_rl_bdl/40/param.pickle" + ], + "learning_rate": [0.003], + "grad_clip": [1] + }] + }], + "discount": [0.995], + "seed": [42], + "generate_random_seed": [true] +} \ No newline at end of file diff --git a/configs/lopt_rl_catch.json b/configs/lopt_rl_catch.json new file mode 100644 index 0000000..9f544f4 --- /dev/null +++ b/configs/lopt_rl_catch.json @@ -0,0 +1,46 @@ +{ + "env": [{ + "name": [["catch"]], + "num_envs": [64], + "train_steps": [5e5] + }], + "agent": [{ + "name": ["A2C"], + "gae_lambda": [0.95], + "rollout_steps": [20], + "critic_loss_weight": [0.5], + "entropy_weight": [0.01] + }], + "agent_optim": [{ + "name": ["Optim4RL"], + "kwargs": [{ + "param_load_path": [ + "./logs/meta_rl_catch/1/param.pickle", + "./logs/meta_rl_catch/2/param.pickle", + "./logs/meta_rl_catch/3/param.pickle", + "./logs/meta_rl_catch/4/param.pickle", + "./logs/meta_rl_catch/5/param.pickle", + "./logs/meta_rl_catch/6/param.pickle", + "./logs/meta_rl_catch/7/param.pickle", + "./logs/meta_rl_catch/8/param.pickle", + "./logs/meta_rl_catch/9/param.pickle", + "./logs/meta_rl_catch/10/param.pickle", + "./logs/meta_rl_catch/11/param.pickle", + "./logs/meta_rl_catch/12/param.pickle", + "./logs/meta_rl_catch/13/param.pickle", + "./logs/meta_rl_catch/14/param.pickle", + "./logs/meta_rl_catch/15/param.pickle", + "./logs/meta_rl_catch/16/param.pickle", + "./logs/meta_rl_catch/17/param.pickle", + "./logs/meta_rl_catch/18/param.pickle", + "./logs/meta_rl_catch/19/param.pickle", + "./logs/meta_rl_catch/20/param.pickle" + ], + "learning_rate": [1e-3], + "grad_clip": [-1] + }] + }], + "discount": [0.995], + "seed": [42], + "generate_random_seed": [true] +} \ No newline at end of file diff --git a/configs/lopt_rl_grid_ant.json b/configs/lopt_rl_grid_ant.json new file mode 100644 index 0000000..16c5ebc --- /dev/null +++ b/configs/lopt_rl_grid_ant.json @@ -0,0 +1,131 @@ +{ + "env": [{ + "name": ["ant"], + "train_steps": [1e8], + "episode_length": [1000], + "action_repeat": [1], + "reward_scaling": [10], + "num_envs": [4096], + "num_evals": [10], + "normalize_obs": [true] + }], + "agent": [{ + "name": ["PPO"], + "gae_lambda": [0.95], + "rollout_steps": [5], + "num_minibatches": [32], + "clipping_epsilon": [0.3], + "update_epochs": [4], + "entropy_weight": [1e-2] + }], + "optim": [{ + "name": ["Optim4RL"], + "kwargs": [{ + "param_load_path": [ + "./logs/meta_rl_grid/1/param.pickle", + "./logs/meta_rl_grid/2/param.pickle", + "./logs/meta_rl_grid/3/param.pickle", + "./logs/meta_rl_grid/4/param.pickle", + "./logs/meta_rl_grid/5/param.pickle", + "./logs/meta_rl_grid/6/param.pickle", + "./logs/meta_rl_grid/7/param.pickle", + "./logs/meta_rl_grid/8/param.pickle", + "./logs/meta_rl_grid/9/param.pickle", + "./logs/meta_rl_grid/10/param.pickle", + "./logs/meta_rl_grid/11/param.pickle", + "./logs/meta_rl_grid/12/param.pickle", + "./logs/meta_rl_grid/13/param.pickle", + "./logs/meta_rl_grid/14/param.pickle", + "./logs/meta_rl_grid/15/param.pickle", + "./logs/meta_rl_grid/16/param.pickle", + "./logs/meta_rl_grid/17/param.pickle", + "./logs/meta_rl_grid/18/param.pickle", + "./logs/meta_rl_grid/19/param.pickle", + "./logs/meta_rl_grid/20/param.pickle", + "./logs/meta_rl_grid/21/param.pickle", + "./logs/meta_rl_grid/22/param.pickle", + "./logs/meta_rl_grid/23/param.pickle", + 
"./logs/meta_rl_grid/24/param.pickle", + "./logs/meta_rl_grid/25/param.pickle", + "./logs/meta_rl_grid/26/param.pickle", + "./logs/meta_rl_grid/27/param.pickle", + "./logs/meta_rl_grid/28/param.pickle", + "./logs/meta_rl_grid/29/param.pickle", + "./logs/meta_rl_grid/30/param.pickle", + "./logs/meta_rl_grid/31/param.pickle", + "./logs/meta_rl_grid/32/param.pickle", + "./logs/meta_rl_grid/33/param.pickle", + "./logs/meta_rl_grid/34/param.pickle", + "./logs/meta_rl_grid/35/param.pickle", + "./logs/meta_rl_grid/36/param.pickle", + "./logs/meta_rl_grid/37/param.pickle", + "./logs/meta_rl_grid/38/param.pickle", + "./logs/meta_rl_grid/39/param.pickle", + "./logs/meta_rl_grid/40/param.pickle", + "./logs/meta_rl_grid/41/param.pickle", + "./logs/meta_rl_grid/42/param.pickle", + "./logs/meta_rl_grid/43/param.pickle", + "./logs/meta_rl_grid/44/param.pickle", + "./logs/meta_rl_grid/45/param.pickle", + "./logs/meta_rl_grid/46/param.pickle", + "./logs/meta_rl_grid/47/param.pickle", + "./logs/meta_rl_grid/48/param.pickle", + "./logs/meta_rl_grid/49/param.pickle", + "./logs/meta_rl_grid/50/param.pickle", + "./logs/meta_rl_grid/51/param.pickle", + "./logs/meta_rl_grid/52/param.pickle", + "./logs/meta_rl_grid/53/param.pickle", + "./logs/meta_rl_grid/54/param.pickle", + "./logs/meta_rl_grid/55/param.pickle", + "./logs/meta_rl_grid/56/param.pickle", + "./logs/meta_rl_grid/57/param.pickle", + "./logs/meta_rl_grid/58/param.pickle", + "./logs/meta_rl_grid/59/param.pickle", + "./logs/meta_rl_grid/60/param.pickle", + "./logs/meta_rl_grid/61/param.pickle", + "./logs/meta_rl_grid/62/param.pickle", + "./logs/meta_rl_grid/63/param.pickle", + "./logs/meta_rl_grid/64/param.pickle", + "./logs/meta_rl_grid/65/param.pickle", + "./logs/meta_rl_grid/66/param.pickle", + "./logs/meta_rl_grid/67/param.pickle", + "./logs/meta_rl_grid/68/param.pickle", + "./logs/meta_rl_grid/69/param.pickle", + "./logs/meta_rl_grid/70/param.pickle", + "./logs/meta_rl_grid/71/param.pickle", + "./logs/meta_rl_grid/72/param.pickle", + "./logs/meta_rl_grid/73/param.pickle", + "./logs/meta_rl_grid/74/param.pickle", + "./logs/meta_rl_grid/75/param.pickle", + "./logs/meta_rl_grid/76/param.pickle", + "./logs/meta_rl_grid/77/param.pickle", + "./logs/meta_rl_grid/78/param.pickle", + "./logs/meta_rl_grid/79/param.pickle", + "./logs/meta_rl_grid/80/param.pickle", + "./logs/meta_rl_grid/81/param.pickle", + "./logs/meta_rl_grid/82/param.pickle", + "./logs/meta_rl_grid/83/param.pickle", + "./logs/meta_rl_grid/84/param.pickle", + "./logs/meta_rl_grid/85/param.pickle", + "./logs/meta_rl_grid/86/param.pickle", + "./logs/meta_rl_grid/87/param.pickle", + "./logs/meta_rl_grid/88/param.pickle", + "./logs/meta_rl_grid/89/param.pickle", + "./logs/meta_rl_grid/90/param.pickle", + "./logs/meta_rl_grid/91/param.pickle", + "./logs/meta_rl_grid/92/param.pickle", + "./logs/meta_rl_grid/93/param.pickle", + "./logs/meta_rl_grid/94/param.pickle", + "./logs/meta_rl_grid/95/param.pickle", + "./logs/meta_rl_grid/96/param.pickle" + ], + "learning_rate": [3e-4], + "grad_clip": [1] + }] + }], + "batch_size": [2048], + "discount": [0.97], + "max_devices_per_host": [-1], + "seed": [1], + "generate_random_seed": [true] +} \ No newline at end of file diff --git a/configs/lopt_rl_grid_humanoid.json b/configs/lopt_rl_grid_humanoid.json new file mode 100644 index 0000000..78c5f5c --- /dev/null +++ b/configs/lopt_rl_grid_humanoid.json @@ -0,0 +1,131 @@ +{ + "env": [{ + "name": ["humanoid"], + "train_steps": [1e8], + "episode_length": [1000], + "action_repeat": [1], + "reward_scaling": [0.1], + 
"num_envs": [2048], + "num_evals": [10], + "normalize_obs": [true] + }], + "agent": [{ + "name": ["PPO"], + "gae_lambda": [0.95], + "rollout_steps": [10], + "num_minibatches": [32], + "clipping_epsilon": [0.3], + "update_epochs": [8], + "entropy_weight": [1e-3] + }], + "optim": [{ + "name": ["Optim4RL"], + "kwargs": [{ + "param_load_path": [ + "./logs/meta_rl_grid/1/param.pickle", + "./logs/meta_rl_grid/2/param.pickle", + "./logs/meta_rl_grid/3/param.pickle", + "./logs/meta_rl_grid/4/param.pickle", + "./logs/meta_rl_grid/5/param.pickle", + "./logs/meta_rl_grid/6/param.pickle", + "./logs/meta_rl_grid/7/param.pickle", + "./logs/meta_rl_grid/8/param.pickle", + "./logs/meta_rl_grid/9/param.pickle", + "./logs/meta_rl_grid/10/param.pickle", + "./logs/meta_rl_grid/11/param.pickle", + "./logs/meta_rl_grid/12/param.pickle", + "./logs/meta_rl_grid/13/param.pickle", + "./logs/meta_rl_grid/14/param.pickle", + "./logs/meta_rl_grid/15/param.pickle", + "./logs/meta_rl_grid/16/param.pickle", + "./logs/meta_rl_grid/17/param.pickle", + "./logs/meta_rl_grid/18/param.pickle", + "./logs/meta_rl_grid/19/param.pickle", + "./logs/meta_rl_grid/20/param.pickle", + "./logs/meta_rl_grid/21/param.pickle", + "./logs/meta_rl_grid/22/param.pickle", + "./logs/meta_rl_grid/23/param.pickle", + "./logs/meta_rl_grid/24/param.pickle", + "./logs/meta_rl_grid/25/param.pickle", + "./logs/meta_rl_grid/26/param.pickle", + "./logs/meta_rl_grid/27/param.pickle", + "./logs/meta_rl_grid/28/param.pickle", + "./logs/meta_rl_grid/29/param.pickle", + "./logs/meta_rl_grid/30/param.pickle", + "./logs/meta_rl_grid/31/param.pickle", + "./logs/meta_rl_grid/32/param.pickle", + "./logs/meta_rl_grid/33/param.pickle", + "./logs/meta_rl_grid/34/param.pickle", + "./logs/meta_rl_grid/35/param.pickle", + "./logs/meta_rl_grid/36/param.pickle", + "./logs/meta_rl_grid/37/param.pickle", + "./logs/meta_rl_grid/38/param.pickle", + "./logs/meta_rl_grid/39/param.pickle", + "./logs/meta_rl_grid/40/param.pickle", + "./logs/meta_rl_grid/41/param.pickle", + "./logs/meta_rl_grid/42/param.pickle", + "./logs/meta_rl_grid/43/param.pickle", + "./logs/meta_rl_grid/44/param.pickle", + "./logs/meta_rl_grid/45/param.pickle", + "./logs/meta_rl_grid/46/param.pickle", + "./logs/meta_rl_grid/47/param.pickle", + "./logs/meta_rl_grid/48/param.pickle", + "./logs/meta_rl_grid/49/param.pickle", + "./logs/meta_rl_grid/50/param.pickle", + "./logs/meta_rl_grid/51/param.pickle", + "./logs/meta_rl_grid/52/param.pickle", + "./logs/meta_rl_grid/53/param.pickle", + "./logs/meta_rl_grid/54/param.pickle", + "./logs/meta_rl_grid/55/param.pickle", + "./logs/meta_rl_grid/56/param.pickle", + "./logs/meta_rl_grid/57/param.pickle", + "./logs/meta_rl_grid/58/param.pickle", + "./logs/meta_rl_grid/59/param.pickle", + "./logs/meta_rl_grid/60/param.pickle", + "./logs/meta_rl_grid/61/param.pickle", + "./logs/meta_rl_grid/62/param.pickle", + "./logs/meta_rl_grid/63/param.pickle", + "./logs/meta_rl_grid/64/param.pickle", + "./logs/meta_rl_grid/65/param.pickle", + "./logs/meta_rl_grid/66/param.pickle", + "./logs/meta_rl_grid/67/param.pickle", + "./logs/meta_rl_grid/68/param.pickle", + "./logs/meta_rl_grid/69/param.pickle", + "./logs/meta_rl_grid/70/param.pickle", + "./logs/meta_rl_grid/71/param.pickle", + "./logs/meta_rl_grid/72/param.pickle", + "./logs/meta_rl_grid/73/param.pickle", + "./logs/meta_rl_grid/74/param.pickle", + "./logs/meta_rl_grid/75/param.pickle", + "./logs/meta_rl_grid/76/param.pickle", + "./logs/meta_rl_grid/77/param.pickle", + "./logs/meta_rl_grid/78/param.pickle", + 
"./logs/meta_rl_grid/79/param.pickle", + "./logs/meta_rl_grid/80/param.pickle", + "./logs/meta_rl_grid/81/param.pickle", + "./logs/meta_rl_grid/82/param.pickle", + "./logs/meta_rl_grid/83/param.pickle", + "./logs/meta_rl_grid/84/param.pickle", + "./logs/meta_rl_grid/85/param.pickle", + "./logs/meta_rl_grid/86/param.pickle", + "./logs/meta_rl_grid/87/param.pickle", + "./logs/meta_rl_grid/88/param.pickle", + "./logs/meta_rl_grid/89/param.pickle", + "./logs/meta_rl_grid/90/param.pickle", + "./logs/meta_rl_grid/91/param.pickle", + "./logs/meta_rl_grid/92/param.pickle", + "./logs/meta_rl_grid/93/param.pickle", + "./logs/meta_rl_grid/94/param.pickle", + "./logs/meta_rl_grid/95/param.pickle", + "./logs/meta_rl_grid/96/param.pickle" + ], + "learning_rate": [3e-4], + "grad_clip": [1] + }] + }], + "batch_size": [1024], + "discount": [0.97], + "max_devices_per_host": [-1], + "seed": [1], + "generate_random_seed": [true] +} \ No newline at end of file diff --git a/configs/lopt_rl_grid_pendulum.json b/configs/lopt_rl_grid_pendulum.json new file mode 100644 index 0000000..26afee7 --- /dev/null +++ b/configs/lopt_rl_grid_pendulum.json @@ -0,0 +1,131 @@ +{ + "env": [{ + "name": ["inverted_double_pendulum"], + "train_steps": [2e7], + "episode_length": [1000], + "action_repeat": [1], + "reward_scaling": [10], + "num_envs": [2048], + "num_evals": [20], + "normalize_obs": [true] + }], + "agent": [{ + "name": ["PPO"], + "gae_lambda": [0.95], + "rollout_steps": [5], + "num_minibatches": [32], + "clipping_epsilon": [0.3], + "update_epochs": [4], + "entropy_weight": [1e-2] + }], + "optim": [{ + "name": ["Optim4RL"], + "kwargs": [{ + "param_load_path": [ + "./logs/meta_rl_grid/1/param.pickle", + "./logs/meta_rl_grid/2/param.pickle", + "./logs/meta_rl_grid/3/param.pickle", + "./logs/meta_rl_grid/4/param.pickle", + "./logs/meta_rl_grid/5/param.pickle", + "./logs/meta_rl_grid/6/param.pickle", + "./logs/meta_rl_grid/7/param.pickle", + "./logs/meta_rl_grid/8/param.pickle", + "./logs/meta_rl_grid/9/param.pickle", + "./logs/meta_rl_grid/10/param.pickle", + "./logs/meta_rl_grid/11/param.pickle", + "./logs/meta_rl_grid/12/param.pickle", + "./logs/meta_rl_grid/13/param.pickle", + "./logs/meta_rl_grid/14/param.pickle", + "./logs/meta_rl_grid/15/param.pickle", + "./logs/meta_rl_grid/16/param.pickle", + "./logs/meta_rl_grid/17/param.pickle", + "./logs/meta_rl_grid/18/param.pickle", + "./logs/meta_rl_grid/19/param.pickle", + "./logs/meta_rl_grid/20/param.pickle", + "./logs/meta_rl_grid/21/param.pickle", + "./logs/meta_rl_grid/22/param.pickle", + "./logs/meta_rl_grid/23/param.pickle", + "./logs/meta_rl_grid/24/param.pickle", + "./logs/meta_rl_grid/25/param.pickle", + "./logs/meta_rl_grid/26/param.pickle", + "./logs/meta_rl_grid/27/param.pickle", + "./logs/meta_rl_grid/28/param.pickle", + "./logs/meta_rl_grid/29/param.pickle", + "./logs/meta_rl_grid/30/param.pickle", + "./logs/meta_rl_grid/31/param.pickle", + "./logs/meta_rl_grid/32/param.pickle", + "./logs/meta_rl_grid/33/param.pickle", + "./logs/meta_rl_grid/34/param.pickle", + "./logs/meta_rl_grid/35/param.pickle", + "./logs/meta_rl_grid/36/param.pickle", + "./logs/meta_rl_grid/37/param.pickle", + "./logs/meta_rl_grid/38/param.pickle", + "./logs/meta_rl_grid/39/param.pickle", + "./logs/meta_rl_grid/40/param.pickle", + "./logs/meta_rl_grid/41/param.pickle", + "./logs/meta_rl_grid/42/param.pickle", + "./logs/meta_rl_grid/43/param.pickle", + "./logs/meta_rl_grid/44/param.pickle", + "./logs/meta_rl_grid/45/param.pickle", + "./logs/meta_rl_grid/46/param.pickle", + 
"./logs/meta_rl_grid/47/param.pickle", + "./logs/meta_rl_grid/48/param.pickle", + "./logs/meta_rl_grid/49/param.pickle", + "./logs/meta_rl_grid/50/param.pickle", + "./logs/meta_rl_grid/51/param.pickle", + "./logs/meta_rl_grid/52/param.pickle", + "./logs/meta_rl_grid/53/param.pickle", + "./logs/meta_rl_grid/54/param.pickle", + "./logs/meta_rl_grid/55/param.pickle", + "./logs/meta_rl_grid/56/param.pickle", + "./logs/meta_rl_grid/57/param.pickle", + "./logs/meta_rl_grid/58/param.pickle", + "./logs/meta_rl_grid/59/param.pickle", + "./logs/meta_rl_grid/60/param.pickle", + "./logs/meta_rl_grid/61/param.pickle", + "./logs/meta_rl_grid/62/param.pickle", + "./logs/meta_rl_grid/63/param.pickle", + "./logs/meta_rl_grid/64/param.pickle", + "./logs/meta_rl_grid/65/param.pickle", + "./logs/meta_rl_grid/66/param.pickle", + "./logs/meta_rl_grid/67/param.pickle", + "./logs/meta_rl_grid/68/param.pickle", + "./logs/meta_rl_grid/69/param.pickle", + "./logs/meta_rl_grid/70/param.pickle", + "./logs/meta_rl_grid/71/param.pickle", + "./logs/meta_rl_grid/72/param.pickle", + "./logs/meta_rl_grid/73/param.pickle", + "./logs/meta_rl_grid/74/param.pickle", + "./logs/meta_rl_grid/75/param.pickle", + "./logs/meta_rl_grid/76/param.pickle", + "./logs/meta_rl_grid/77/param.pickle", + "./logs/meta_rl_grid/78/param.pickle", + "./logs/meta_rl_grid/79/param.pickle", + "./logs/meta_rl_grid/80/param.pickle", + "./logs/meta_rl_grid/81/param.pickle", + "./logs/meta_rl_grid/82/param.pickle", + "./logs/meta_rl_grid/83/param.pickle", + "./logs/meta_rl_grid/84/param.pickle", + "./logs/meta_rl_grid/85/param.pickle", + "./logs/meta_rl_grid/86/param.pickle", + "./logs/meta_rl_grid/87/param.pickle", + "./logs/meta_rl_grid/88/param.pickle", + "./logs/meta_rl_grid/89/param.pickle", + "./logs/meta_rl_grid/90/param.pickle", + "./logs/meta_rl_grid/91/param.pickle", + "./logs/meta_rl_grid/92/param.pickle", + "./logs/meta_rl_grid/93/param.pickle", + "./logs/meta_rl_grid/94/param.pickle", + "./logs/meta_rl_grid/95/param.pickle", + "./logs/meta_rl_grid/96/param.pickle" + ], + "learning_rate": [3e-4], + "grad_clip": [1] + }] + }], + "batch_size": [1024], + "discount": [0.97], + "max_devices_per_host": [-1], + "seed": [1], + "generate_random_seed": [true] +} \ No newline at end of file diff --git a/configs/lopt_rl_grid_walker2d.json b/configs/lopt_rl_grid_walker2d.json new file mode 100644 index 0000000..97e3179 --- /dev/null +++ b/configs/lopt_rl_grid_walker2d.json @@ -0,0 +1,131 @@ +{ + "env": [{ + "name": ["walker2d"], + "train_steps": [5e7], + "episode_length": [1000], + "action_repeat": [1], + "reward_scaling": [5], + "num_envs": [2048], + "num_evals": [20], + "normalize_obs": [true] + }], + "agent": [{ + "name": ["PPO"], + "gae_lambda": [0.95], + "rollout_steps": [20], + "num_minibatches": [32], + "clipping_epsilon": [0.3], + "update_epochs": [8], + "entropy_weight": [1e-3] + }], + "optim": [{ + "name": ["Optim4RL"], + "kwargs": [{ + "param_load_path": [ + "./logs/meta_rl_grid/1/param.pickle", + "./logs/meta_rl_grid/2/param.pickle", + "./logs/meta_rl_grid/3/param.pickle", + "./logs/meta_rl_grid/4/param.pickle", + "./logs/meta_rl_grid/5/param.pickle", + "./logs/meta_rl_grid/6/param.pickle", + "./logs/meta_rl_grid/7/param.pickle", + "./logs/meta_rl_grid/8/param.pickle", + "./logs/meta_rl_grid/9/param.pickle", + "./logs/meta_rl_grid/10/param.pickle", + "./logs/meta_rl_grid/11/param.pickle", + "./logs/meta_rl_grid/12/param.pickle", + "./logs/meta_rl_grid/13/param.pickle", + "./logs/meta_rl_grid/14/param.pickle", + 
"./logs/meta_rl_grid/15/param.pickle", + "./logs/meta_rl_grid/16/param.pickle", + "./logs/meta_rl_grid/17/param.pickle", + "./logs/meta_rl_grid/18/param.pickle", + "./logs/meta_rl_grid/19/param.pickle", + "./logs/meta_rl_grid/20/param.pickle", + "./logs/meta_rl_grid/21/param.pickle", + "./logs/meta_rl_grid/22/param.pickle", + "./logs/meta_rl_grid/23/param.pickle", + "./logs/meta_rl_grid/24/param.pickle", + "./logs/meta_rl_grid/25/param.pickle", + "./logs/meta_rl_grid/26/param.pickle", + "./logs/meta_rl_grid/27/param.pickle", + "./logs/meta_rl_grid/28/param.pickle", + "./logs/meta_rl_grid/29/param.pickle", + "./logs/meta_rl_grid/30/param.pickle", + "./logs/meta_rl_grid/31/param.pickle", + "./logs/meta_rl_grid/32/param.pickle", + "./logs/meta_rl_grid/33/param.pickle", + "./logs/meta_rl_grid/34/param.pickle", + "./logs/meta_rl_grid/35/param.pickle", + "./logs/meta_rl_grid/36/param.pickle", + "./logs/meta_rl_grid/37/param.pickle", + "./logs/meta_rl_grid/38/param.pickle", + "./logs/meta_rl_grid/39/param.pickle", + "./logs/meta_rl_grid/40/param.pickle", + "./logs/meta_rl_grid/41/param.pickle", + "./logs/meta_rl_grid/42/param.pickle", + "./logs/meta_rl_grid/43/param.pickle", + "./logs/meta_rl_grid/44/param.pickle", + "./logs/meta_rl_grid/45/param.pickle", + "./logs/meta_rl_grid/46/param.pickle", + "./logs/meta_rl_grid/47/param.pickle", + "./logs/meta_rl_grid/48/param.pickle", + "./logs/meta_rl_grid/49/param.pickle", + "./logs/meta_rl_grid/50/param.pickle", + "./logs/meta_rl_grid/51/param.pickle", + "./logs/meta_rl_grid/52/param.pickle", + "./logs/meta_rl_grid/53/param.pickle", + "./logs/meta_rl_grid/54/param.pickle", + "./logs/meta_rl_grid/55/param.pickle", + "./logs/meta_rl_grid/56/param.pickle", + "./logs/meta_rl_grid/57/param.pickle", + "./logs/meta_rl_grid/58/param.pickle", + "./logs/meta_rl_grid/59/param.pickle", + "./logs/meta_rl_grid/60/param.pickle", + "./logs/meta_rl_grid/61/param.pickle", + "./logs/meta_rl_grid/62/param.pickle", + "./logs/meta_rl_grid/63/param.pickle", + "./logs/meta_rl_grid/64/param.pickle", + "./logs/meta_rl_grid/65/param.pickle", + "./logs/meta_rl_grid/66/param.pickle", + "./logs/meta_rl_grid/67/param.pickle", + "./logs/meta_rl_grid/68/param.pickle", + "./logs/meta_rl_grid/69/param.pickle", + "./logs/meta_rl_grid/70/param.pickle", + "./logs/meta_rl_grid/71/param.pickle", + "./logs/meta_rl_grid/72/param.pickle", + "./logs/meta_rl_grid/73/param.pickle", + "./logs/meta_rl_grid/74/param.pickle", + "./logs/meta_rl_grid/75/param.pickle", + "./logs/meta_rl_grid/76/param.pickle", + "./logs/meta_rl_grid/77/param.pickle", + "./logs/meta_rl_grid/78/param.pickle", + "./logs/meta_rl_grid/79/param.pickle", + "./logs/meta_rl_grid/80/param.pickle", + "./logs/meta_rl_grid/81/param.pickle", + "./logs/meta_rl_grid/82/param.pickle", + "./logs/meta_rl_grid/83/param.pickle", + "./logs/meta_rl_grid/84/param.pickle", + "./logs/meta_rl_grid/85/param.pickle", + "./logs/meta_rl_grid/86/param.pickle", + "./logs/meta_rl_grid/87/param.pickle", + "./logs/meta_rl_grid/88/param.pickle", + "./logs/meta_rl_grid/89/param.pickle", + "./logs/meta_rl_grid/90/param.pickle", + "./logs/meta_rl_grid/91/param.pickle", + "./logs/meta_rl_grid/92/param.pickle", + "./logs/meta_rl_grid/93/param.pickle", + "./logs/meta_rl_grid/94/param.pickle", + "./logs/meta_rl_grid/95/param.pickle", + "./logs/meta_rl_grid/96/param.pickle" + ], + "learning_rate": [3e-5], + "grad_clip": [1] + }] + }], + "batch_size": [512], + "discount": [0.997], + "max_devices_per_host": [-1], + "seed": [1], + "generate_random_seed": [true] +} \ No 
newline at end of file diff --git a/configs/lopt_rl_humanoid.json b/configs/lopt_rl_humanoid.json new file mode 100644 index 0000000..3b74fb4 --- /dev/null +++ b/configs/lopt_rl_humanoid.json @@ -0,0 +1,85 @@ +{ + "env": [{ + "name": ["humanoid"], + "train_steps": [1e8], + "episode_length": [1000], + "action_repeat": [1], + "reward_scaling": [0.1], + "num_envs": [2048], + "num_evals": [10], + "normalize_obs": [true] + }], + "agent": [{ + "name": ["PPO"], + "gae_lambda": [0.95], + "rollout_steps": [10], + "num_minibatches": [32], + "clipping_epsilon": [0.3], + "update_epochs": [8], + "entropy_weight": [1e-3] + }], + "optim": [{ + "name": ["Optim4RL"], + "kwargs": [{ + "param_load_path": [ + "./logs/meta_rl_humanoid/1/param.pickle", + "./logs/meta_rl_humanoid/2/param.pickle", + "./logs/meta_rl_humanoid/3/param.pickle", + "./logs/meta_rl_humanoid/4/param.pickle", + "./logs/meta_rl_humanoid/5/param.pickle", + "./logs/meta_rl_humanoid/6/param.pickle", + "./logs/meta_rl_humanoid/7/param.pickle", + "./logs/meta_rl_humanoid/8/param.pickle", + "./logs/meta_rl_humanoid/9/param.pickle", + "./logs/meta_rl_humanoid/10/param.pickle", + "./logs/meta_rl_humanoid/11/param.pickle", + "./logs/meta_rl_humanoid/12/param.pickle", + "./logs/meta_rl_humanoid/13/param.pickle", + "./logs/meta_rl_humanoid/14/param.pickle", + "./logs/meta_rl_humanoid/15/param.pickle", + "./logs/meta_rl_humanoid/16/param.pickle", + "./logs/meta_rl_humanoid/17/param.pickle", + "./logs/meta_rl_humanoid/18/param.pickle", + "./logs/meta_rl_humanoid/19/param.pickle", + "./logs/meta_rl_humanoid/20/param.pickle", + "./logs/meta_rl_humanoid/21/param.pickle", + "./logs/meta_rl_humanoid/22/param.pickle", + "./logs/meta_rl_humanoid/23/param.pickle", + "./logs/meta_rl_humanoid/24/param.pickle", + "./logs/meta_rl_humanoid/25/param.pickle", + "./logs/meta_rl_humanoid/26/param.pickle", + "./logs/meta_rl_humanoid/27/param.pickle", + "./logs/meta_rl_humanoid/28/param.pickle", + "./logs/meta_rl_humanoid/29/param.pickle", + "./logs/meta_rl_humanoid/30/param.pickle", + "./logs/meta_rl_humanoid/31/param.pickle", + "./logs/meta_rl_humanoid/32/param.pickle", + "./logs/meta_rl_humanoid/33/param.pickle", + "./logs/meta_rl_humanoid/34/param.pickle", + "./logs/meta_rl_humanoid/35/param.pickle", + "./logs/meta_rl_humanoid/36/param.pickle", + "./logs/meta_rl_humanoid/37/param.pickle", + "./logs/meta_rl_humanoid/38/param.pickle", + "./logs/meta_rl_humanoid/39/param.pickle", + "./logs/meta_rl_humanoid/40/param.pickle", + "./logs/meta_rl_humanoid/41/param.pickle", + "./logs/meta_rl_humanoid/42/param.pickle", + "./logs/meta_rl_humanoid/43/param.pickle", + "./logs/meta_rl_humanoid/44/param.pickle", + "./logs/meta_rl_humanoid/45/param.pickle", + "./logs/meta_rl_humanoid/46/param.pickle", + "./logs/meta_rl_humanoid/47/param.pickle", + "./logs/meta_rl_humanoid/48/param.pickle", + "./logs/meta_rl_humanoid/49/param.pickle", + "./logs/meta_rl_humanoid/50/param.pickle" + ], + "learning_rate": [3e-4], + "grad_clip": [1] + }] + }], + "batch_size": [1024], + "discount": [0.97], + "max_devices_per_host": [-1], + "seed": [1], + "generate_random_seed": [true] +} \ No newline at end of file diff --git a/configs/lopt_rl_sdl.json b/configs/lopt_rl_sdl.json new file mode 100644 index 0000000..231d3f5 --- /dev/null +++ b/configs/lopt_rl_sdl.json @@ -0,0 +1,66 @@ +{ + "env": [{ + "name": [["small_dense_long"]], + "num_envs": [512], + "train_steps": [3e7] + }], + "agent": [{ + "name": ["A2C"], + "gae_lambda": [0.95], + "rollout_steps": [20], + "critic_loss_weight": [0.5], + 
"entropy_weight": [0.01] + }], + "agent_optim": [{ + "name": ["Optim4RL"], + "kwargs": [{ + "param_load_path": [ + "./logs/meta_rl_sdl/1/param.pickle", + "./logs/meta_rl_sdl/2/param.pickle", + "./logs/meta_rl_sdl/3/param.pickle", + "./logs/meta_rl_sdl/4/param.pickle", + "./logs/meta_rl_sdl/5/param.pickle", + "./logs/meta_rl_sdl/6/param.pickle", + "./logs/meta_rl_sdl/7/param.pickle", + "./logs/meta_rl_sdl/8/param.pickle", + "./logs/meta_rl_sdl/9/param.pickle", + "./logs/meta_rl_sdl/10/param.pickle", + "./logs/meta_rl_sdl/11/param.pickle", + "./logs/meta_rl_sdl/12/param.pickle", + "./logs/meta_rl_sdl/13/param.pickle", + "./logs/meta_rl_sdl/14/param.pickle", + "./logs/meta_rl_sdl/15/param.pickle", + "./logs/meta_rl_sdl/16/param.pickle", + "./logs/meta_rl_sdl/17/param.pickle", + "./logs/meta_rl_sdl/18/param.pickle", + "./logs/meta_rl_sdl/19/param.pickle", + "./logs/meta_rl_sdl/20/param.pickle", + "./logs/meta_rl_sdl/21/param.pickle", + "./logs/meta_rl_sdl/22/param.pickle", + "./logs/meta_rl_sdl/23/param.pickle", + "./logs/meta_rl_sdl/24/param.pickle", + "./logs/meta_rl_sdl/25/param.pickle", + "./logs/meta_rl_sdl/26/param.pickle", + "./logs/meta_rl_sdl/27/param.pickle", + "./logs/meta_rl_sdl/28/param.pickle", + "./logs/meta_rl_sdl/29/param.pickle", + "./logs/meta_rl_sdl/30/param.pickle", + "./logs/meta_rl_sdl/31/param.pickle", + "./logs/meta_rl_sdl/32/param.pickle", + "./logs/meta_rl_sdl/33/param.pickle", + "./logs/meta_rl_sdl/34/param.pickle", + "./logs/meta_rl_sdl/35/param.pickle", + "./logs/meta_rl_sdl/36/param.pickle", + "./logs/meta_rl_sdl/37/param.pickle", + "./logs/meta_rl_sdl/38/param.pickle", + "./logs/meta_rl_sdl/39/param.pickle", + "./logs/meta_rl_sdl/40/param.pickle" + ], + "learning_rate": [0.03], + "grad_clip": [1] + }] + }], + "discount": [0.995], + "seed": [42], + "generate_random_seed": [true] +} \ No newline at end of file diff --git a/configs/lopt_rlp_ant.json b/configs/lopt_rlp_ant.json new file mode 100644 index 0000000..c1fe8a1 --- /dev/null +++ b/configs/lopt_rlp_ant.json @@ -0,0 +1,45 @@ +{ + "env": [{ + "name": ["ant"], + "train_steps": [1e8], + "episode_length": [1000], + "action_repeat": [1], + "reward_scaling": [10], + "num_envs": [4096], + "num_evals": [10], + "normalize_obs": [true] + }], + "agent": [{ + "name": ["PPO"], + "gae_lambda": [0.95], + "rollout_steps": [5], + "num_minibatches": [32], + "clipping_epsilon": [0.3], + "update_epochs": [4], + "entropy_weight": [1e-2] + }], + "optim": [{ + "name": ["Optim4RL"], + "kwargs": [{ + "param_load_path": [ + "./logs/meta_rlp_ant/1/param.pickle", + "./logs/meta_rlp_ant/2/param.pickle", + "./logs/meta_rlp_ant/3/param.pickle", + "./logs/meta_rlp_ant/4/param.pickle", + "./logs/meta_rlp_ant/5/param.pickle", + "./logs/meta_rlp_ant/6/param.pickle", + "./logs/meta_rlp_ant/7/param.pickle", + "./logs/meta_rlp_ant/8/param.pickle", + "./logs/meta_rlp_ant/9/param.pickle", + "./logs/meta_rlp_ant/10/param.pickle" + ], + "learning_rate": [3e-4], + "grad_clip": [1] + }] + }], + "batch_size": [2048], + "discount": [0.97], + "max_devices_per_host": [-1], + "seed": [1], + "generate_random_seed": [true] +} \ No newline at end of file diff --git a/configs/lopt_rlp_bdl.json b/configs/lopt_rlp_bdl.json new file mode 100644 index 0000000..2834d80 --- /dev/null +++ b/configs/lopt_rlp_bdl.json @@ -0,0 +1,36 @@ +{ + "env": [{ + "name": [["big_dense_long"]], + "num_envs": [512], + "train_steps": [3e7] + }], + "agent": [{ + "name": ["A2C"], + "gae_lambda": [0.95], + "rollout_steps": [20], + "critic_loss_weight": [0.5], + "entropy_weight": [0.01] + 
}], + "agent_optim": [{ + "name": ["Optim4RL"], + "kwargs": [{ + "param_load_path": [ + "./logs/meta_rlp_bdl/1/param.pickle", + "./logs/meta_rlp_bdl/2/param.pickle", + "./logs/meta_rlp_bdl/3/param.pickle", + "./logs/meta_rlp_bdl/4/param.pickle", + "./logs/meta_rlp_bdl/5/param.pickle", + "./logs/meta_rlp_bdl/6/param.pickle", + "./logs/meta_rlp_bdl/7/param.pickle", + "./logs/meta_rlp_bdl/8/param.pickle", + "./logs/meta_rlp_bdl/9/param.pickle", + "./logs/meta_rlp_bdl/10/param.pickle" + ], + "learning_rate": [0.003], + "grad_clip": [1] + }] + }], + "discount": [0.995], + "seed": [42], + "generate_random_seed": [true] +} \ No newline at end of file diff --git a/configs/lopt_rlp_humanoid.json b/configs/lopt_rlp_humanoid.json new file mode 100644 index 0000000..1a9e095 --- /dev/null +++ b/configs/lopt_rlp_humanoid.json @@ -0,0 +1,45 @@ +{ + "env": [{ + "name": ["humanoid"], + "train_steps": [1e8], + "episode_length": [1000], + "action_repeat": [1], + "reward_scaling": [0.1], + "num_envs": [2048], + "num_evals": [10], + "normalize_obs": [true] + }], + "agent": [{ + "name": ["PPO"], + "gae_lambda": [0.95], + "rollout_steps": [10], + "num_minibatches": [32], + "clipping_epsilon": [0.3], + "update_epochs": [8], + "entropy_weight": [1e-3] + }], + "optim": [{ + "name": ["Optim4RL"], + "kwargs": [{ + "param_load_path": [ + "./logs/meta_rlp_humanoid/1/param.pickle", + "./logs/meta_rlp_humanoid/2/param.pickle", + "./logs/meta_rlp_humanoid/3/param.pickle", + "./logs/meta_rlp_humanoid/4/param.pickle", + "./logs/meta_rlp_humanoid/5/param.pickle", + "./logs/meta_rlp_humanoid/6/param.pickle", + "./logs/meta_rlp_humanoid/7/param.pickle", + "./logs/meta_rlp_humanoid/8/param.pickle", + "./logs/meta_rlp_humanoid/9/param.pickle", + "./logs/meta_rlp_humanoid/10/param.pickle" + ], + "learning_rate": [3e-4], + "grad_clip": [1] + }] + }], + "batch_size": [1024], + "discount": [0.97], + "max_devices_per_host": [-1], + "seed": [1], + "generate_random_seed": [true] +} \ No newline at end of file diff --git a/configs/lopt_rlp_sdl.json b/configs/lopt_rlp_sdl.json new file mode 100644 index 0000000..c3c8cb2 --- /dev/null +++ b/configs/lopt_rlp_sdl.json @@ -0,0 +1,36 @@ +{ + "env": [{ + "name": [["small_dense_long"]], + "num_envs": [512], + "train_steps": [3e7] + }], + "agent": [{ + "name": ["A2C"], + "gae_lambda": [0.95], + "rollout_steps": [20], + "critic_loss_weight": [0.5], + "entropy_weight": [0.01] + }], + "agent_optim": [{ + "name": ["Optim4RL"], + "kwargs": [{ + "param_load_path": [ + "./logs/meta_rlp_sdl/1/param.pickle", + "./logs/meta_rlp_sdl/2/param.pickle", + "./logs/meta_rlp_sdl/3/param.pickle", + "./logs/meta_rlp_sdl/4/param.pickle", + "./logs/meta_rlp_sdl/5/param.pickle", + "./logs/meta_rlp_sdl/6/param.pickle", + "./logs/meta_rlp_sdl/7/param.pickle", + "./logs/meta_rlp_sdl/8/param.pickle", + "./logs/meta_rlp_sdl/9/param.pickle", + "./logs/meta_rlp_sdl/10/param.pickle" + ], + "learning_rate": [0.03], + "grad_clip": [1] + }] + }], + "discount": [0.995], + "seed": [42], + "generate_random_seed": [true] +} \ No newline at end of file diff --git a/configs/lopt_star_ant.json b/configs/lopt_star_ant.json new file mode 100644 index 0000000..18381aa --- /dev/null +++ b/configs/lopt_star_ant.json @@ -0,0 +1,86 @@ +{ + "env": [{ + "name": ["ant"], + "train_steps": [1e8], + "episode_length": [1000], + "action_repeat": [1], + "reward_scaling": [10], + "num_envs": [4096], + "num_evals": [10], + "normalize_obs": [true] + }], + "agent": [{ + "name": ["PPOstar"], + "gae_lambda": [0.95], + "rollout_steps": [5], + 
"num_minibatches": [32], + "clipping_epsilon": [0.3], + "update_epochs": [4], + "entropy_weight": [1e-2] + }], + "optim": [{ + "name": ["Star"], + "kwargs": [{ + "train_steps": [1e8], + "param_load_path": [ + "./logs/meta_star_ant/1/param.pickle", + "./logs/meta_star_ant/2/param.pickle", + "./logs/meta_star_ant/3/param.pickle", + "./logs/meta_star_ant/4/param.pickle", + "./logs/meta_star_ant/5/param.pickle", + "./logs/meta_star_ant/6/param.pickle", + "./logs/meta_star_ant/7/param.pickle", + "./logs/meta_star_ant/8/param.pickle", + "./logs/meta_star_ant/9/param.pickle", + "./logs/meta_star_ant/10/param.pickle", + "./logs/meta_star_ant/11/param.pickle", + "./logs/meta_star_ant/12/param.pickle", + "./logs/meta_star_ant/13/param.pickle", + "./logs/meta_star_ant/14/param.pickle", + "./logs/meta_star_ant/15/param.pickle", + "./logs/meta_star_ant/16/param.pickle", + "./logs/meta_star_ant/17/param.pickle", + "./logs/meta_star_ant/18/param.pickle", + "./logs/meta_star_ant/19/param.pickle", + "./logs/meta_star_ant/20/param.pickle", + "./logs/meta_star_ant/21/param.pickle", + "./logs/meta_star_ant/22/param.pickle", + "./logs/meta_star_ant/23/param.pickle", + "./logs/meta_star_ant/24/param.pickle", + "./logs/meta_star_ant/25/param.pickle", + "./logs/meta_star_ant/26/param.pickle", + "./logs/meta_star_ant/27/param.pickle", + "./logs/meta_star_ant/28/param.pickle", + "./logs/meta_star_ant/29/param.pickle", + "./logs/meta_star_ant/30/param.pickle", + "./logs/meta_star_ant/31/param.pickle", + "./logs/meta_star_ant/32/param.pickle", + "./logs/meta_star_ant/33/param.pickle", + "./logs/meta_star_ant/34/param.pickle", + "./logs/meta_star_ant/35/param.pickle", + "./logs/meta_star_ant/36/param.pickle", + "./logs/meta_star_ant/37/param.pickle", + "./logs/meta_star_ant/38/param.pickle", + "./logs/meta_star_ant/39/param.pickle", + "./logs/meta_star_ant/40/param.pickle", + "./logs/meta_star_ant/41/param.pickle", + "./logs/meta_star_ant/42/param.pickle", + "./logs/meta_star_ant/43/param.pickle", + "./logs/meta_star_ant/44/param.pickle", + "./logs/meta_star_ant/45/param.pickle", + "./logs/meta_star_ant/46/param.pickle", + "./logs/meta_star_ant/47/param.pickle", + "./logs/meta_star_ant/48/param.pickle", + "./logs/meta_star_ant/49/param.pickle", + "./logs/meta_star_ant/50/param.pickle" + ], + "learning_rate": [3e-4], + "grad_clip": [1] + }] + }], + "batch_size": [2048], + "discount": [0.97], + "max_devices_per_host": [-1], + "seed": [1], + "generate_random_seed": [true] +} \ No newline at end of file diff --git a/configs/lopt_star_bdl.json b/configs/lopt_star_bdl.json new file mode 100644 index 0000000..ff912f5 --- /dev/null +++ b/configs/lopt_star_bdl.json @@ -0,0 +1,67 @@ +{ + "env": [{ + "name": [["big_dense_long"]], + "num_envs": [512], + "train_steps": [3e7] + }], + "agent": [{ + "name": ["A2Cstar"], + "gae_lambda": [0.95], + "rollout_steps": [20], + "critic_loss_weight": [0.5], + "entropy_weight": [0.01] + }], + "agent_optim": [{ + "name": ["Star"], + "kwargs": [{ + "train_steps": [3e7], + "param_load_path": [ + "./logs/meta_star_bdl/1/param.pickle", + "./logs/meta_star_bdl/2/param.pickle", + "./logs/meta_star_bdl/3/param.pickle", + "./logs/meta_star_bdl/4/param.pickle", + "./logs/meta_star_bdl/5/param.pickle", + "./logs/meta_star_bdl/6/param.pickle", + "./logs/meta_star_bdl/7/param.pickle", + "./logs/meta_star_bdl/8/param.pickle", + "./logs/meta_star_bdl/9/param.pickle", + "./logs/meta_star_bdl/10/param.pickle", + "./logs/meta_star_bdl/11/param.pickle", + "./logs/meta_star_bdl/12/param.pickle", + 
"./logs/meta_star_bdl/13/param.pickle", + "./logs/meta_star_bdl/14/param.pickle", + "./logs/meta_star_bdl/15/param.pickle", + "./logs/meta_star_bdl/16/param.pickle", + "./logs/meta_star_bdl/17/param.pickle", + "./logs/meta_star_bdl/18/param.pickle", + "./logs/meta_star_bdl/19/param.pickle", + "./logs/meta_star_bdl/20/param.pickle", + "./logs/meta_star_bdl/21/param.pickle", + "./logs/meta_star_bdl/22/param.pickle", + "./logs/meta_star_bdl/23/param.pickle", + "./logs/meta_star_bdl/24/param.pickle", + "./logs/meta_star_bdl/25/param.pickle", + "./logs/meta_star_bdl/26/param.pickle", + "./logs/meta_star_bdl/27/param.pickle", + "./logs/meta_star_bdl/28/param.pickle", + "./logs/meta_star_bdl/29/param.pickle", + "./logs/meta_star_bdl/30/param.pickle", + "./logs/meta_star_bdl/31/param.pickle", + "./logs/meta_star_bdl/32/param.pickle", + "./logs/meta_star_bdl/33/param.pickle", + "./logs/meta_star_bdl/34/param.pickle", + "./logs/meta_star_bdl/35/param.pickle", + "./logs/meta_star_bdl/36/param.pickle", + "./logs/meta_star_bdl/37/param.pickle", + "./logs/meta_star_bdl/38/param.pickle", + "./logs/meta_star_bdl/39/param.pickle", + "./logs/meta_star_bdl/40/param.pickle" + ], + "learning_rate": [0.003], + "grad_clip": [1] + }] + }], + "discount": [0.995], + "seed": [42], + "generate_random_seed": [true] +} \ No newline at end of file diff --git a/configs/lopt_star_catch.json b/configs/lopt_star_catch.json new file mode 100644 index 0000000..497c6e5 --- /dev/null +++ b/configs/lopt_star_catch.json @@ -0,0 +1,47 @@ +{ + "env": [{ + "name": [["catch"]], + "num_envs": [64], + "train_steps": [5e5] + }], + "agent": [{ + "name": ["A2Cstar"], + "gae_lambda": [0.95], + "rollout_steps": [20], + "critic_loss_weight": [0.5], + "entropy_weight": [0.01] + }], + "agent_optim": [{ + "name": ["Star"], + "kwargs": [{ + "train_steps": [5e5], + "param_load_path": [ + "./logs/meta_star_catch/1/param.pickle", + "./logs/meta_star_catch/2/param.pickle", + "./logs/meta_star_catch/3/param.pickle", + "./logs/meta_star_catch/4/param.pickle", + "./logs/meta_star_catch/5/param.pickle", + "./logs/meta_star_catch/6/param.pickle", + "./logs/meta_star_catch/7/param.pickle", + "./logs/meta_star_catch/8/param.pickle", + "./logs/meta_star_catch/9/param.pickle", + "./logs/meta_star_catch/10/param.pickle", + "./logs/meta_star_catch/11/param.pickle", + "./logs/meta_star_catch/12/param.pickle", + "./logs/meta_star_catch/13/param.pickle", + "./logs/meta_star_catch/14/param.pickle", + "./logs/meta_star_catch/15/param.pickle", + "./logs/meta_star_catch/16/param.pickle", + "./logs/meta_star_catch/17/param.pickle", + "./logs/meta_star_catch/18/param.pickle", + "./logs/meta_star_catch/19/param.pickle", + "./logs/meta_star_catch/20/param.pickle" + ], + "learning_rate": [1e-3], + "grad_clip": [-1] + }] + }], + "discount": [0.995], + "seed": [42], + "generate_random_seed": [true] +} \ No newline at end of file diff --git a/configs/lopt_star_humanoid.json b/configs/lopt_star_humanoid.json new file mode 100644 index 0000000..848fcfe --- /dev/null +++ b/configs/lopt_star_humanoid.json @@ -0,0 +1,86 @@ +{ + "env": [{ + "name": ["humanoid"], + "train_steps": [1e8], + "episode_length": [1000], + "action_repeat": [1], + "reward_scaling": [0.1], + "num_envs": [2048], + "num_evals": [10], + "normalize_obs": [true] + }], + "agent": [{ + "name": ["PPOstar"], + "gae_lambda": [0.95], + "rollout_steps": [10], + "num_minibatches": [32], + "clipping_epsilon": [0.3], + "update_epochs": [8], + "entropy_weight": [1e-3] + }], + "optim": [{ + "name": ["Star"], + "kwargs": [{ 
+ "train_steps": [1e8], + "param_load_path": [ + "./logs/meta_star_humanoid/1/param.pickle", + "./logs/meta_star_humanoid/2/param.pickle", + "./logs/meta_star_humanoid/3/param.pickle", + "./logs/meta_star_humanoid/4/param.pickle", + "./logs/meta_star_humanoid/5/param.pickle", + "./logs/meta_star_humanoid/6/param.pickle", + "./logs/meta_star_humanoid/7/param.pickle", + "./logs/meta_star_humanoid/8/param.pickle", + "./logs/meta_star_humanoid/9/param.pickle", + "./logs/meta_star_humanoid/10/param.pickle", + "./logs/meta_star_humanoid/11/param.pickle", + "./logs/meta_star_humanoid/12/param.pickle", + "./logs/meta_star_humanoid/13/param.pickle", + "./logs/meta_star_humanoid/14/param.pickle", + "./logs/meta_star_humanoid/15/param.pickle", + "./logs/meta_star_humanoid/16/param.pickle", + "./logs/meta_star_humanoid/17/param.pickle", + "./logs/meta_star_humanoid/18/param.pickle", + "./logs/meta_star_humanoid/19/param.pickle", + "./logs/meta_star_humanoid/20/param.pickle", + "./logs/meta_star_humanoid/21/param.pickle", + "./logs/meta_star_humanoid/22/param.pickle", + "./logs/meta_star_humanoid/23/param.pickle", + "./logs/meta_star_humanoid/24/param.pickle", + "./logs/meta_star_humanoid/25/param.pickle", + "./logs/meta_star_humanoid/26/param.pickle", + "./logs/meta_star_humanoid/27/param.pickle", + "./logs/meta_star_humanoid/28/param.pickle", + "./logs/meta_star_humanoid/29/param.pickle", + "./logs/meta_star_humanoid/30/param.pickle", + "./logs/meta_star_humanoid/31/param.pickle", + "./logs/meta_star_humanoid/32/param.pickle", + "./logs/meta_star_humanoid/33/param.pickle", + "./logs/meta_star_humanoid/34/param.pickle", + "./logs/meta_star_humanoid/35/param.pickle", + "./logs/meta_star_humanoid/36/param.pickle", + "./logs/meta_star_humanoid/37/param.pickle", + "./logs/meta_star_humanoid/38/param.pickle", + "./logs/meta_star_humanoid/39/param.pickle", + "./logs/meta_star_humanoid/40/param.pickle", + "./logs/meta_star_humanoid/41/param.pickle", + "./logs/meta_star_humanoid/42/param.pickle", + "./logs/meta_star_humanoid/43/param.pickle", + "./logs/meta_star_humanoid/44/param.pickle", + "./logs/meta_star_humanoid/45/param.pickle", + "./logs/meta_star_humanoid/46/param.pickle", + "./logs/meta_star_humanoid/47/param.pickle", + "./logs/meta_star_humanoid/48/param.pickle", + "./logs/meta_star_humanoid/49/param.pickle", + "./logs/meta_star_humanoid/50/param.pickle" + ], + "learning_rate": [3e-4], + "grad_clip": [1] + }] + }], + "batch_size": [1024], + "discount": [0.97], + "max_devices_per_host": [-1], + "seed": [1], + "generate_random_seed": [true] +} \ No newline at end of file diff --git a/configs/meta_l2l_ant.json b/configs/meta_l2l_ant.json new file mode 100644 index 0000000..c568a0f --- /dev/null +++ b/configs/meta_l2l_ant.json @@ -0,0 +1,43 @@ +{ + "env": [{ + "name": ["ant"], + "train_steps": [1e8], + "episode_length": [1000], + "action_repeat": [1], + "reward_scaling": [10], + "num_envs": [2048], + "normalize_obs": [true] + }], + "agent": [{ + "name": ["MetaPPO"], + "inner_updates": [4], + "gae_lambda": [0.95], + "rollout_steps": [5], + "num_minibatches": [8], + "clipping_epsilon": [0.3], + "update_epochs": [4], + "entropy_weight": [1e-2], + "reset_interval": [32, 64, 128, 256, 512] + }], + "agent_optim": [{ + "name": ["L2LGD2"], + "kwargs": [{ + "param_load_path": [""], + "learning_rate": [3e-4], + "grad_clip": [1] + }] + }], + "meta_optim": [{ + "name": ["Adam"], + "kwargs": [ + {"learning_rate": [1e-5, 3e-5, 1e-4, 3e-4, 1e-3], "grad_norm": [0.5]}, + {"learning_rate": [1e-5, 3e-5, 1e-4, 3e-4, 1e-3], 
"grad_clip": [1.0]} + ] + }], + "display_interval": [10], + "batch_size": [1024], + "discount": [0.97], + "max_devices_per_host": [-1], + "seed": [1], + "generate_random_seed": [true] +} \ No newline at end of file diff --git a/configs/meta_l2l_bdl.json b/configs/meta_l2l_bdl.json new file mode 100644 index 0000000..5c23616 --- /dev/null +++ b/configs/meta_l2l_bdl.json @@ -0,0 +1,35 @@ +{ + "env": [{ + "name": [["big_dense_long"]], + "num_envs": [512], + "train_steps": [3e7] + }], + "agent": [{ + "name": ["MetaA2C"], + "inner_updates": [4], + "gae_lambda": [0.95], + "rollout_steps": [20], + "critic_loss_weight": [0.5], + "entropy_weight": [0.01], + "reset_interval": [72, 144, 288, 576] + }], + "agent_optim": [{ + "name": ["L2LGD2"], + "kwargs": [{ + "param_load_path": [""], + "learning_rate": [0.003], + "grad_clip": [1] + }] + }], + "meta_optim": [{ + "name": ["Adam"], + "kwargs": [ + {"learning_rate": [1e-5, 3e-5, 1e-4, 3e-4, 1e-3], "grad_norm": [0.5]}, + {"learning_rate": [1e-5, 3e-5, 1e-4, 3e-4, 1e-3], "grad_clip": [1.0]} + ] + }], + "discount": [0.995], + "seed": [42], + "display_interval": [50], + "generate_random_seed": [true] +} \ No newline at end of file diff --git a/configs/meta_l2l_catch.json b/configs/meta_l2l_catch.json new file mode 100644 index 0000000..f92c798 --- /dev/null +++ b/configs/meta_l2l_catch.json @@ -0,0 +1,35 @@ +{ + "env": [{ + "name": [["catch"]], + "num_envs": [64], + "train_steps": [5e5] + }], + "agent": [{ + "name": ["MetaA2C"], + "inner_updates": [4], + "gae_lambda": [0.95], + "rollout_steps": [20], + "critic_loss_weight": [0.5], + "entropy_weight": [0.01], + "reset_interval": [32, 64] + }], + "agent_optim": [{ + "name": ["L2LGD2"], + "kwargs": [{ + "param_load_path": [""], + "learning_rate": [1e-3], + "grad_clip": [-1] + }] + }], + "meta_optim": [{ + "name": ["Adam"], + "kwargs": [ + {"learning_rate": [1e-4, 3e-4, 1e-3, 3e-3, 1e-2], "grad_norm": [0.5]}, + {"learning_rate": [1e-4, 3e-4, 1e-3, 3e-3, 1e-2], "grad_clip": [1.0]} + ] + }], + "discount": [0.995], + "seed": [42], + "display_interval": [5], + "generate_random_seed": [true] +} \ No newline at end of file diff --git a/configs/meta_l2l_humanoid.json b/configs/meta_l2l_humanoid.json new file mode 100644 index 0000000..31d3c84 --- /dev/null +++ b/configs/meta_l2l_humanoid.json @@ -0,0 +1,43 @@ +{ + "env": [{ + "name": ["humanoid"], + "train_steps": [1e8], + "episode_length": [1000], + "action_repeat": [1], + "reward_scaling": [0.1], + "num_envs": [2048], + "normalize_obs": [true] + }], + "agent": [{ + "name": ["MetaPPO"], + "inner_updates": [4], + "gae_lambda": [0.95], + "rollout_steps": [5], + "num_minibatches": [8], + "clipping_epsilon": [0.3], + "update_epochs": [4], + "entropy_weight": [1e-3], + "reset_interval": [32, 64, 128, 256, 512] + }], + "agent_optim": [{ + "name": ["L2LGD2"], + "kwargs": [{ + "param_load_path": [""], + "learning_rate": [3e-4], + "grad_clip": [1] + }] + }], + "meta_optim": [{ + "name": ["Adam"], + "kwargs": [ + {"learning_rate": [1e-5, 3e-5, 1e-4, 3e-4, 1e-3], "grad_norm": [0.5]}, + {"learning_rate": [1e-5, 3e-5, 1e-4, 3e-4, 1e-3], "grad_clip": [1.0]} + ] + }], + "display_interval": [10], + "batch_size": [1024], + "discount": [0.97], + "max_devices_per_host": [-1], + "seed": [1], + "generate_random_seed": [true] +} \ No newline at end of file diff --git a/configs/meta_lin_ant.json b/configs/meta_lin_ant.json new file mode 100644 index 0000000..6a32f5b --- /dev/null +++ b/configs/meta_lin_ant.json @@ -0,0 +1,43 @@ +{ + "env": [{ + "name": ["ant"], + "train_steps": [1e8], + 
"episode_length": [1000], + "action_repeat": [1], + "reward_scaling": [10], + "num_envs": [2048], + "normalize_obs": [true] + }], + "agent": [{ + "name": ["MetaPPO"], + "inner_updates": [4], + "gae_lambda": [0.95], + "rollout_steps": [5], + "num_minibatches": [8], + "clipping_epsilon": [0.3], + "update_epochs": [4], + "entropy_weight": [1e-2], + "reset_interval": [32, 64, 128, 256, 512] + }], + "agent_optim": [{ + "name": ["LinearOptim"], + "kwargs": [{ + "param_load_path": [""], + "learning_rate": [3e-4], + "grad_clip": [1] + }] + }], + "meta_optim": [{ + "name": ["Adam"], + "kwargs": [ + {"learning_rate": [1e-5, 3e-5, 1e-4, 3e-4, 1e-3], "grad_norm": [0.5]}, + {"learning_rate": [1e-5, 3e-5, 1e-4, 3e-4, 1e-3], "grad_clip": [1.0]} + ] + }], + "display_interval": [10], + "batch_size": [1024], + "discount": [0.97], + "max_devices_per_host": [-1], + "seed": [1], + "generate_random_seed": [true] +} \ No newline at end of file diff --git a/configs/meta_lin_bdl.json b/configs/meta_lin_bdl.json new file mode 100644 index 0000000..2ab62b1 --- /dev/null +++ b/configs/meta_lin_bdl.json @@ -0,0 +1,35 @@ +{ + "env": [{ + "name": [["big_dense_long"]], + "num_envs": [512], + "train_steps": [3e7] + }], + "agent": [{ + "name": ["MetaA2C"], + "inner_updates": [4], + "gae_lambda": [0.95], + "rollout_steps": [20], + "critic_loss_weight": [0.5], + "entropy_weight": [0.01], + "reset_interval": [72, 144, 288, 576] + }], + "agent_optim": [{ + "name": ["LinearOptim"], + "kwargs": [{ + "param_load_path": [""], + "learning_rate": [0.003], + "grad_clip": [1] + }] + }], + "meta_optim": [{ + "name": ["Adam"], + "kwargs": [ + {"learning_rate": [1e-5, 3e-5, 1e-4, 3e-4, 1e-3], "grad_norm": [0.5]}, + {"learning_rate": [1e-5, 3e-5, 1e-4, 3e-4, 1e-3], "grad_clip": [1.0]} + ] + }], + "discount": [0.995], + "seed": [42], + "display_interval": [50], + "generate_random_seed": [true] +} \ No newline at end of file diff --git a/configs/meta_lin_catch.json b/configs/meta_lin_catch.json new file mode 100644 index 0000000..fa3feef --- /dev/null +++ b/configs/meta_lin_catch.json @@ -0,0 +1,35 @@ +{ + "env": [{ + "name": [["catch"]], + "num_envs": [64], + "train_steps": [5e5] + }], + "agent": [{ + "name": ["MetaA2C"], + "inner_updates": [4], + "gae_lambda": [0.95], + "rollout_steps": [20], + "critic_loss_weight": [0.5], + "entropy_weight": [0.01], + "reset_interval": [32, 64] + }], + "agent_optim": [{ + "name": ["LinearOptim"], + "kwargs": [{ + "param_load_path": [""], + "learning_rate": [1e-3], + "grad_clip": [-1] + }] + }], + "meta_optim": [{ + "name": ["Adam"], + "kwargs": [ + {"learning_rate": [1e-4, 3e-4, 1e-3, 3e-3, 1e-2], "grad_norm": [0.5]}, + {"learning_rate": [1e-4, 3e-4, 1e-3, 3e-3, 1e-2], "grad_clip": [1.0]} + ] + }], + "discount": [0.995], + "seed": [42], + "display_interval": [5], + "generate_random_seed": [true] +} \ No newline at end of file diff --git a/configs/meta_lin_humanoid.json b/configs/meta_lin_humanoid.json new file mode 100644 index 0000000..624da6d --- /dev/null +++ b/configs/meta_lin_humanoid.json @@ -0,0 +1,43 @@ +{ + "env": [{ + "name": ["humanoid"], + "train_steps": [1e8], + "episode_length": [1000], + "action_repeat": [1], + "reward_scaling": [0.1], + "num_envs": [2048], + "normalize_obs": [true] + }], + "agent": [{ + "name": ["MetaPPO"], + "inner_updates": [4], + "gae_lambda": [0.95], + "rollout_steps": [5], + "num_minibatches": [8], + "clipping_epsilon": [0.3], + "update_epochs": [4], + "entropy_weight": [1e-3], + "reset_interval": [32, 64, 128, 256, 512] + }], + "agent_optim": [{ + "name": 
["LinearOptim"], + "kwargs": [{ + "param_load_path": [""], + "learning_rate": [3e-4], + "grad_clip": [1] + }] + }], + "meta_optim": [{ + "name": ["Adam"], + "kwargs": [ + {"learning_rate": [1e-5, 3e-5, 1e-4, 3e-4, 1e-3], "grad_norm": [0.5]}, + {"learning_rate": [1e-5, 3e-5, 1e-4, 3e-4, 1e-3], "grad_clip": [1.0]} + ] + }], + "display_interval": [10], + "batch_size": [1024], + "discount": [0.97], + "max_devices_per_host": [-1], + "seed": [1], + "generate_random_seed": [true] +} \ No newline at end of file diff --git a/configs/meta_rl_ant.json b/configs/meta_rl_ant.json new file mode 100644 index 0000000..3bfdb9f --- /dev/null +++ b/configs/meta_rl_ant.json @@ -0,0 +1,43 @@ +{ + "env": [{ + "name": ["ant"], + "train_steps": [1e8], + "episode_length": [1000], + "action_repeat": [1], + "reward_scaling": [10], + "num_envs": [2048], + "normalize_obs": [true] + }], + "agent": [{ + "name": ["MetaPPO"], + "inner_updates": [4], + "gae_lambda": [0.95], + "rollout_steps": [5], + "num_minibatches": [8], + "clipping_epsilon": [0.3], + "update_epochs": [4], + "entropy_weight": [1e-2], + "reset_interval": [32, 64, 128, 256, 512] + }], + "agent_optim": [{ + "name": ["Optim4RL"], + "kwargs": [{ + "param_load_path": [""], + "learning_rate": [3e-4], + "grad_clip": [1] + }] + }], + "meta_optim": [{ + "name": ["Adam"], + "kwargs": [ + {"learning_rate": [1e-5, 3e-5, 1e-4, 3e-4, 1e-3], "grad_norm": [0.5]}, + {"learning_rate": [1e-5, 3e-5, 1e-4, 3e-4, 1e-3], "grad_clip": [1.0]} + ] + }], + "display_interval": [10], + "batch_size": [1024], + "discount": [0.97], + "max_devices_per_host": [-1], + "seed": [1], + "generate_random_seed": [true] +} \ No newline at end of file diff --git a/configs/meta_rl_bdl.json b/configs/meta_rl_bdl.json new file mode 100644 index 0000000..9f5265d --- /dev/null +++ b/configs/meta_rl_bdl.json @@ -0,0 +1,35 @@ +{ + "env": [{ + "name": [["big_dense_long"]], + "num_envs": [512], + "train_steps": [3e7] + }], + "agent": [{ + "name": ["MetaA2C"], + "inner_updates": [4], + "gae_lambda": [0.95], + "rollout_steps": [20], + "critic_loss_weight": [0.5], + "entropy_weight": [0.01], + "reset_interval": [72, 144, 288, 576] + }], + "agent_optim": [{ + "name": ["Optim4RL"], + "kwargs": [{ + "param_load_path": [""], + "learning_rate": [0.003], + "grad_clip": [1] + }] + }], + "meta_optim": [{ + "name": ["Adam"], + "kwargs": [ + {"learning_rate": [1e-5, 3e-5, 1e-4, 3e-4, 1e-3], "grad_norm": [0.5]}, + {"learning_rate": [1e-5, 3e-5, 1e-4, 3e-4, 1e-3], "grad_clip": [1.0]} + ] + }], + "discount": [0.995], + "seed": [42], + "display_interval": [50], + "generate_random_seed": [true] +} \ No newline at end of file diff --git a/configs/meta_rl_catch.json b/configs/meta_rl_catch.json new file mode 100644 index 0000000..c8b1080 --- /dev/null +++ b/configs/meta_rl_catch.json @@ -0,0 +1,35 @@ +{ + "env": [{ + "name": [["catch"]], + "num_envs": [64], + "train_steps": [5e5] + }], + "agent": [{ + "name": ["MetaA2C"], + "inner_updates": [4], + "gae_lambda": [0.95], + "rollout_steps": [20], + "critic_loss_weight": [0.5], + "entropy_weight": [0.01], + "reset_interval": [32, 64] + }], + "agent_optim": [{ + "name": ["Optim4RL"], + "kwargs": [{ + "param_load_path": [""], + "learning_rate": [1e-3], + "grad_clip": [-1] + }] + }], + "meta_optim": [{ + "name": ["Adam"], + "kwargs": [ + {"learning_rate": [1e-4, 3e-4, 1e-3, 3e-3, 1e-2], "grad_norm": [0.5]}, + {"learning_rate": [1e-4, 3e-4, 1e-3, 3e-3, 1e-2], "grad_clip": [1.0]} + ] + }], + "discount": [0.995], + "seed": [42], + "display_interval": [5], + "generate_random_seed": 
[true] +} \ No newline at end of file diff --git a/configs/meta_rl_grid.json b/configs/meta_rl_grid.json new file mode 100644 index 0000000..56b52de --- /dev/null +++ b/configs/meta_rl_grid.json @@ -0,0 +1,37 @@ +{ + "env": [{ + "name": [["small_dense_long", "small_dense_short", "big_sparse_short", "big_dense_short", "big_sparse_long", "big_dense_long"]], + "reward_scaling": [[1e3, 1e2, 1e2, 1e1, 1e1, 1e0]], + "num_envs": [512], + "train_steps": [3e7] + }], + "agent": [{ + "name": ["MetaA2C"], + "inner_updates": [4], + "gae_lambda": [0.95], + "rollout_steps": [20], + "critic_loss_weight": [0.5], + "entropy_weight": [0.01], + "reset_interval": [72, 144, 288, 576] + }], + "agent_optim": [{ + "name": ["Optim4RL"], + "kwargs": [{ + "param_load_path": [""], + "learning_rate": [[1e-3, 3e-3, 3e-3, 3e-3, 1e-3, 3e-3]], + "grad_clip": [1] + }] + }], + "meta_optim": [{ + "name": ["Adam"], + "kwargs": [{ + "learning_rate": [1e-5, 3e-5, 1e-4, 3e-4, 1e-3, 3e-3], + "grad_clip": [1.0, -1.0], + "max_norm": [0.5, -1.0] + }] + }], + "discount": [0.995], + "seed": [42], + "display_interval": [50], + "generate_random_seed": [true] +} \ No newline at end of file diff --git a/configs/meta_rl_humanoid.json b/configs/meta_rl_humanoid.json new file mode 100644 index 0000000..c49a665 --- /dev/null +++ b/configs/meta_rl_humanoid.json @@ -0,0 +1,43 @@ +{ + "env": [{ + "name": ["humanoid"], + "train_steps": [1e8], + "episode_length": [1000], + "action_repeat": [1], + "reward_scaling": [0.1], + "num_envs": [2048], + "normalize_obs": [true] + }], + "agent": [{ + "name": ["MetaPPO"], + "inner_updates": [4], + "gae_lambda": [0.95], + "rollout_steps": [5], + "num_minibatches": [8], + "clipping_epsilon": [0.3], + "update_epochs": [4], + "entropy_weight": [1e-3], + "reset_interval": [32, 64, 128, 256, 512] + }], + "agent_optim": [{ + "name": ["Optim4RL"], + "kwargs": [{ + "param_load_path": [""], + "learning_rate": [3e-4], + "grad_clip": [1] + }] + }], + "meta_optim": [{ + "name": ["Adam"], + "kwargs": [ + {"learning_rate": [1e-5, 3e-5, 1e-4, 3e-4, 1e-3], "grad_norm": [0.5]}, + {"learning_rate": [1e-5, 3e-5, 1e-4, 3e-4, 1e-3], "grad_clip": [1.0]} + ] + }], + "display_interval": [10], + "batch_size": [1024], + "discount": [0.97], + "max_devices_per_host": [-1], + "seed": [1], + "generate_random_seed": [true] +} \ No newline at end of file diff --git a/configs/meta_rl_sdl.json b/configs/meta_rl_sdl.json new file mode 100644 index 0000000..8e1ee5a --- /dev/null +++ b/configs/meta_rl_sdl.json @@ -0,0 +1,35 @@ +{ + "env": [{ + "name": [["small_dense_long"]], + "num_envs": [512], + "train_steps": [3e7] + }], + "agent": [{ + "name": ["MetaA2C"], + "inner_updates": [4], + "gae_lambda": [0.95], + "rollout_steps": [20], + "critic_loss_weight": [0.5], + "entropy_weight": [0.01], + "reset_interval": [72, 144, 288, 576] + }], + "agent_optim": [{ + "name": ["Optim4RL"], + "kwargs": [{ + "param_load_path": [""], + "learning_rate": [0.03], + "grad_clip": [1] + }] + }], + "meta_optim": [{ + "name": ["Adam"], + "kwargs": [ + {"learning_rate": [1e-5, 3e-5, 1e-4, 3e-4, 1e-3], "grad_norm": [0.5]}, + {"learning_rate": [1e-5, 3e-5, 1e-4, 3e-4, 1e-3], "grad_clip": [1.0]} + ] + }], + "discount": [0.995], + "seed": [42], + "display_interval": [50], + "generate_random_seed": [true] +} \ No newline at end of file diff --git a/configs/meta_rlp_ant.json b/configs/meta_rlp_ant.json new file mode 100644 index 0000000..cd62580 --- /dev/null +++ b/configs/meta_rlp_ant.json @@ -0,0 +1,43 @@ +{ + "env": [{ + "name": ["ant"], + "train_steps": [1e8], + 
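The lopt_* configs point param_load_path at param.pickle files written by the corresponding meta_* runs (e.g. "./logs/meta_rl_grid/1/param.pickle" produced by the meta_rl_grid sweep above). A minimal sketch of reading one of those files, assuming it is a plain pickle of the meta-learned optimizer parameters; load_meta_params is an illustrative name, not this repository's loader:

    import pickle

    def load_meta_params(path: str):
        # Deserialize the meta-learned optimizer parameters saved at the given path.
        with open(path, "rb") as f:
            return pickle.load(f)

    params = load_meta_params("./logs/meta_rl_grid/1/param.pickle")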
"episode_length": [1000], + "action_repeat": [1], + "reward_scaling": [10], + "num_envs": [2048], + "normalize_obs": [true] + }], + "agent": [{ + "name": ["MetapPPO"], + "inner_updates": [4], + "gae_lambda": [0.95], + "rollout_steps": [5], + "num_minibatches": [8], + "clipping_epsilon": [0.3], + "update_epochs": [4], + "entropy_weight": [1e-2], + "reset_interval": [-1] + }], + "agent_optim": [{ + "name": ["Optim4RL"], + "kwargs": [{ + "param_load_path": [""], + "learning_rate": [3e-4], + "grad_clip": [1] + }] + }], + "meta_optim": [{ + "name": ["Adam"], + "kwargs": [ + {"learning_rate": [1e-5, 3e-5, 1e-4, 3e-4, 1e-3], "grad_norm": [0.5]}, + {"learning_rate": [1e-5, 3e-5, 1e-4, 3e-4, 1e-3], "grad_clip": [1.0]} + ] + }], + "display_interval": [10], + "batch_size": [1024], + "discount": [0.97], + "max_devices_per_host": [-1], + "seed": [1], + "generate_random_seed": [true] +} \ No newline at end of file diff --git a/configs/meta_rlp_bdl.json b/configs/meta_rlp_bdl.json new file mode 100644 index 0000000..b16d2ee --- /dev/null +++ b/configs/meta_rlp_bdl.json @@ -0,0 +1,35 @@ +{ + "env": [{ + "name": [["big_dense_long"]], + "num_envs": [512], + "train_steps": [3e7] + }], + "agent": [{ + "name": ["MetapA2C"], + "inner_updates": [4], + "gae_lambda": [0.95], + "rollout_steps": [20], + "critic_loss_weight": [0.5], + "entropy_weight": [0.01], + "reset_interval": [-1] + }], + "agent_optim": [{ + "name": ["Optim4RL"], + "kwargs": [{ + "param_load_path": [""], + "learning_rate": [0.003], + "grad_clip": [1] + }] + }], + "meta_optim": [{ + "name": ["Adam"], + "kwargs": [ + {"learning_rate": [1e-5, 3e-5, 1e-4, 3e-4, 1e-3], "grad_norm": [0.5]}, + {"learning_rate": [1e-5, 3e-5, 1e-4, 3e-4, 1e-3], "grad_clip": [1.0]} + ] + }], + "discount": [0.995], + "seed": [42], + "display_interval": [50], + "generate_random_seed": [true] +} \ No newline at end of file diff --git a/configs/meta_rlp_humanoid.json b/configs/meta_rlp_humanoid.json new file mode 100644 index 0000000..8ba41bc --- /dev/null +++ b/configs/meta_rlp_humanoid.json @@ -0,0 +1,43 @@ +{ + "env": [{ + "name": ["humanoid"], + "train_steps": [1e8], + "episode_length": [1000], + "action_repeat": [1], + "reward_scaling": [0.1], + "num_envs": [2048], + "normalize_obs": [true] + }], + "agent": [{ + "name": ["MetapPPO"], + "inner_updates": [4], + "gae_lambda": [0.95], + "rollout_steps": [5], + "num_minibatches": [8], + "clipping_epsilon": [0.3], + "update_epochs": [4], + "entropy_weight": [1e-3], + "reset_interval": [-1] + }], + "agent_optim": [{ + "name": ["Optim4RL"], + "kwargs": [{ + "param_load_path": [""], + "learning_rate": [3e-4], + "grad_clip": [1] + }] + }], + "meta_optim": [{ + "name": ["Adam"], + "kwargs": [ + {"learning_rate": [1e-5, 3e-5, 1e-4, 3e-4, 1e-3], "grad_norm": [0.5]}, + {"learning_rate": [1e-5, 3e-5, 1e-4, 3e-4, 1e-3], "grad_clip": [1.0]} + ] + }], + "display_interval": [10], + "batch_size": [1024], + "discount": [0.97], + "max_devices_per_host": [-1], + "seed": [1], + "generate_random_seed": [true] +} \ No newline at end of file diff --git a/configs/meta_rlp_sdl.json b/configs/meta_rlp_sdl.json new file mode 100644 index 0000000..72054d7 --- /dev/null +++ b/configs/meta_rlp_sdl.json @@ -0,0 +1,35 @@ +{ + "env": [{ + "name": [["small_dense_long"]], + "num_envs": [512], + "train_steps": [3e7] + }], + "agent": [{ + "name": ["MetapA2C"], + "inner_updates": [4], + "gae_lambda": [0.95], + "rollout_steps": [20], + "critic_loss_weight": [0.5], + "entropy_weight": [0.01], + "reset_interval": [-1] + }], + "agent_optim": [{ + "name": 
["Optim4RL"], + "kwargs": [{ + "param_load_path": [""], + "learning_rate": [0.03], + "grad_clip": [1] + }] + }], + "meta_optim": [{ + "name": ["Adam"], + "kwargs": [ + {"learning_rate": [1e-5, 3e-5, 1e-4, 3e-4, 1e-3], "grad_norm": [0.5]}, + {"learning_rate": [1e-5, 3e-5, 1e-4, 3e-4, 1e-3], "grad_clip": [1.0]} + ] + }], + "discount": [0.995], + "seed": [42], + "display_interval": [50], + "generate_random_seed": [true] +} \ No newline at end of file diff --git a/configs/meta_star_ant.json b/configs/meta_star_ant.json new file mode 100644 index 0000000..cef4336 --- /dev/null +++ b/configs/meta_star_ant.json @@ -0,0 +1,44 @@ +{ + "env": [{ + "name": ["ant"], + "train_steps": [1e8], + "episode_length": [1000], + "action_repeat": [1], + "reward_scaling": [10], + "num_envs": [2048], + "normalize_obs": [true] + }], + "agent": [{ + "name": ["MetaPPOstar"], + "inner_updates": [4], + "gae_lambda": [0.95], + "rollout_steps": [5], + "num_minibatches": [8], + "clipping_epsilon": [0.3], + "update_epochs": [4], + "entropy_weight": [1e-2], + "reset_interval": [32, 64, 128, 256, 512] + }], + "agent_optim": [{ + "name": ["Star"], + "kwargs": [{ + "train_steps": [1e8], + "param_load_path": [""], + "learning_rate": [3e-4], + "grad_clip": [1] + }] + }], + "meta_optim": [{ + "name": ["Adam"], + "kwargs": [ + {"learning_rate": [1e-5, 3e-5, 1e-4, 3e-4, 1e-3], "grad_norm": [0.5]}, + {"learning_rate": [1e-5, 3e-5, 1e-4, 3e-4, 1e-3], "grad_clip": [1.0]} + ] + }], + "display_interval": [10], + "batch_size": [1024], + "discount": [0.97], + "max_devices_per_host": [-1], + "seed": [1], + "generate_random_seed": [true] +} \ No newline at end of file diff --git a/configs/meta_star_bdl.json b/configs/meta_star_bdl.json new file mode 100644 index 0000000..8922f0c --- /dev/null +++ b/configs/meta_star_bdl.json @@ -0,0 +1,36 @@ +{ + "env": [{ + "name": [["big_dense_long"]], + "num_envs": [512], + "train_steps": [3e7] + }], + "agent": [{ + "name": ["MetaA2Cstar"], + "inner_updates": [4], + "gae_lambda": [0.95], + "rollout_steps": [20], + "critic_loss_weight": [0.5], + "entropy_weight": [0.01], + "reset_interval": [72, 144, 288, 576] + }], + "agent_optim": [{ + "name": ["Star"], + "kwargs": [{ + "train_steps": [3e7], + "param_load_path": [""], + "learning_rate": [0.003], + "grad_clip": [1] + }] + }], + "meta_optim": [{ + "name": ["Adam"], + "kwargs": [ + {"learning_rate": [1e-5, 3e-5, 1e-4, 3e-4, 1e-3], "grad_norm": [0.5]}, + {"learning_rate": [1e-5, 3e-5, 1e-4, 3e-4, 1e-3], "grad_clip": [1.0]} + ] + }], + "discount": [0.995], + "seed": [42], + "display_interval": [50], + "generate_random_seed": [true] +} \ No newline at end of file diff --git a/configs/meta_star_catch.json b/configs/meta_star_catch.json new file mode 100644 index 0000000..c1384ca --- /dev/null +++ b/configs/meta_star_catch.json @@ -0,0 +1,36 @@ +{ + "env": [{ + "name": [["catch"]], + "num_envs": [64], + "train_steps": [5e5] + }], + "agent": [{ + "name": ["MetaA2Cstar"], + "inner_updates": [4], + "gae_lambda": [0.95], + "rollout_steps": [20], + "critic_loss_weight": [0.5], + "entropy_weight": [0.01], + "reset_interval": [32, 64] + }], + "agent_optim": [{ + "name": ["Star"], + "kwargs": [{ + "train_steps": [5e5], + "param_load_path": [""], + "learning_rate": [1e-3], + "grad_clip": [-1] + }] + }], + "meta_optim": [{ + "name": ["Adam"], + "kwargs": [ + {"learning_rate": [1e-4, 3e-4, 1e-3, 3e-3, 1e-2], "grad_norm": [0.5]}, + {"learning_rate": [1e-4, 3e-4, 1e-3, 3e-3, 1e-2], "grad_clip": [1.0]} + ] + }], + "discount": [0.995], + "seed": [42], + "display_interval": 
[5], + "generate_random_seed": [true] +} \ No newline at end of file diff --git a/configs/meta_star_humanoid.json b/configs/meta_star_humanoid.json new file mode 100644 index 0000000..a0236c5 --- /dev/null +++ b/configs/meta_star_humanoid.json @@ -0,0 +1,44 @@ +{ + "env": [{ + "name": ["humanoid"], + "train_steps": [1e8], + "episode_length": [1000], + "action_repeat": [1], + "reward_scaling": [0.1], + "num_envs": [2048], + "normalize_obs": [true] + }], + "agent": [{ + "name": ["MetaPPOstar"], + "inner_updates": [4], + "gae_lambda": [0.95], + "rollout_steps": [5], + "num_minibatches": [8], + "clipping_epsilon": [0.3], + "update_epochs": [4], + "entropy_weight": [1e-3], + "reset_interval": [32, 64, 128, 256, 512] + }], + "agent_optim": [{ + "name": ["Star"], + "kwargs": [{ + "train_steps": [1e8], + "param_load_path": [""], + "learning_rate": [3e-4], + "grad_clip": [1] + }] + }], + "meta_optim": [{ + "name": ["Adam"], + "kwargs": [ + {"learning_rate": [1e-5, 3e-5, 1e-4, 3e-4, 1e-3], "grad_norm": [0.5]}, + {"learning_rate": [1e-5, 3e-5, 1e-4, 3e-4, 1e-3], "grad_clip": [1.0]} + ] + }], + "display_interval": [10], + "batch_size": [1024], + "discount": [0.97], + "max_devices_per_host": [-1], + "seed": [1], + "generate_random_seed": [true] +} \ No newline at end of file diff --git a/configs/ppo_ant.json b/configs/ppo_ant.json new file mode 100644 index 0000000..01dde8e --- /dev/null +++ b/configs/ppo_ant.json @@ -0,0 +1,30 @@ +{ + "env": [{ + "name": ["ant"], + "train_steps": [1e8], + "episode_length": [1000], + "action_repeat": [1], + "reward_scaling": [10], + "num_envs": [4096], + "num_evals": [10], + "normalize_obs": [true] + }], + "agent": [{ + "name": ["PPO"], + "gae_lambda": [0.95], + "rollout_steps": [5], + "num_minibatches": [32], + "clipping_epsilon": [0.3], + "update_epochs": [4], + "entropy_weight": [1e-2] + }], + "optim": [{ + "name": ["RMSProp", "Adam"], + "kwargs": [{"learning_rate": [3e-4], "grad_clip": [1]}] + }], + "batch_size": [2048], + "discount": [0.97], + "max_devices_per_host": [-1], + "seed": [1], + "generate_random_seed": [true] +} \ No newline at end of file diff --git a/configs/ppo_humanoid.json b/configs/ppo_humanoid.json new file mode 100644 index 0000000..db11cee --- /dev/null +++ b/configs/ppo_humanoid.json @@ -0,0 +1,30 @@ +{ + "env": [{ + "name": ["humanoid"], + "train_steps": [1e8], + "episode_length": [1000], + "action_repeat": [1], + "reward_scaling": [0.1], + "num_envs": [2048], + "num_evals": [10], + "normalize_obs": [true] + }], + "agent": [{ + "name": ["PPO"], + "gae_lambda": [0.95], + "rollout_steps": [10], + "num_minibatches": [32], + "clipping_epsilon": [0.3], + "update_epochs": [8], + "entropy_weight": [1e-3] + }], + "optim": [{ + "name": ["RMSProp", "Adam"], + "kwargs": [{"learning_rate": [3e-4], "grad_clip": [1]}] + }], + "batch_size": [1024], + "discount": [0.97], + "max_devices_per_host": [-1], + "seed": [1], + "generate_random_seed": [true] +} \ No newline at end of file diff --git a/configs/ppo_pendulum.json b/configs/ppo_pendulum.json new file mode 100644 index 0000000..9d37513 --- /dev/null +++ b/configs/ppo_pendulum.json @@ -0,0 +1,30 @@ +{ + "env": [{ + "name": ["inverted_double_pendulum"], + "train_steps": [2e7], + "episode_length": [1000], + "action_repeat": [1], + "reward_scaling": [10], + "num_envs": [2048], + "num_evals": [20], + "normalize_obs": [true] + }], + "agent": [{ + "name": ["PPO"], + "gae_lambda": [0.95], + "rollout_steps": [5], + "num_minibatches": [32], + "clipping_epsilon": [0.3], + "update_epochs": [4], + "entropy_weight": 
[1e-2] + }], + "optim": [{ + "name": ["RMSProp", "Adam"], + "kwargs": [{"learning_rate": [3e-4], "grad_clip": [1]}] + }], + "batch_size": [1024], + "discount": [0.97], + "max_devices_per_host": [-1], + "seed": [1], + "generate_random_seed": [true] +} \ No newline at end of file diff --git a/configs/ppo_walker2d.json b/configs/ppo_walker2d.json new file mode 100644 index 0000000..03e97f0 --- /dev/null +++ b/configs/ppo_walker2d.json @@ -0,0 +1,30 @@ +{ + "env": [{ + "name": ["walker2d"], + "train_steps": [5e7], + "episode_length": [1000], + "action_repeat": [1], + "reward_scaling": [5], + "num_envs": [2048], + "num_evals": [20], + "normalize_obs": [true] + }], + "agent": [{ + "name": ["PPO"], + "gae_lambda": [0.95], + "rollout_steps": [20], + "num_minibatches": [32], + "clipping_epsilon": [0.3], + "update_epochs": [8], + "entropy_weight": [1e-3] + }], + "optim": [{ + "name": ["RMSProp", "Adam"], + "kwargs": [{"learning_rate": [3e-5], "grad_clip": [1]}] + }], + "batch_size": [512], + "discount": [0.997], + "max_devices_per_host": [-1], + "seed": [1], + "generate_random_seed": [true] +} \ No newline at end of file diff --git a/configs/pusher_collect.json b/configs/pusher_collect.json deleted file mode 100644 index f8939bf..0000000 --- a/configs/pusher_collect.json +++ /dev/null @@ -1,37 +0,0 @@ -{ - "env": [ - { - "name": ["pusher"], - "train_steps": [5e7], - "episode_length": [100], - "action_repeat": [1], - "reward_scaling": [5], - "num_envs": [2048], - "num_evals": [20], - "normalize_obs": [true] - } - ], - "agent": [ - { - "name": ["CollectPPO"], - "data_reduce": [100], - "gae_lambda": [0.95], - "rollout_steps": [30], - "num_minibatches": [16], - "clip_ratio": [0.3], - "update_epochs": [8], - "entropy_weight": [1e-2] - } - ], - "optimizer": [ - { - "name": ["RMSProp", "Adam"], - "kwargs": [{ "learning_rate": [3e-4], "gradient_clip": [1] }] - } - ], - "batch_size": [512], - "discount": [0.95], - "max_devices_per_host": [-1], - "seed": [1], - "generate_random_seed": [true] -} diff --git a/configs/pusher_lopt.json b/configs/pusher_lopt.json deleted file mode 100644 index 1473b02..0000000 --- a/configs/pusher_lopt.json +++ /dev/null @@ -1,47 +0,0 @@ -{ - "env": [ - { - "name": ["pusher"], - "train_steps": [5e7], - "episode_length": [100], - "action_repeat": [1], - "reward_scaling": [5], - "num_envs": [2048], - "num_evals": [20], - "normalize_obs": [true] - } - ], - "agent": [ - { - "name": ["PPO"], - "gae_lambda": [0.95], - "rollout_steps": [30], - "num_minibatches": [16], - "clip_ratio": [0.3], - "update_epochs": [8], - "entropy_weight": [1e-2] - } - ], - "optimizer": [ - { - "name": ["Optim4RL"], - "kwargs": [ - { - "mlp_dims": [[16, 16]], - "hidden_size": [8], - "param_load_path": [ - "./logs/exp/index/meta_param_path1.pickle", - "./logs/exp/index/rnn_parameter_path2.pickle" - ], - "learning_rate": [3e-4], - "gradient_clip": [1] - } - ] - } - ], - "batch_size": [512], - "discount": [0.95], - "max_devices_per_host": [-1], - "seed": [1], - "generate_random_seed": [true] -} diff --git a/configs/pusher_ppo.json b/configs/pusher_ppo.json deleted file mode 100644 index a0c0abf..0000000 --- a/configs/pusher_ppo.json +++ /dev/null @@ -1,36 +0,0 @@ -{ - "env": [ - { - "name": ["pusher"], - "train_steps": [5e7], - "episode_length": [100], - "action_repeat": [1], - "reward_scaling": [5], - "num_envs": [2048], - "num_evals": [20], - "normalize_obs": [true] - } - ], - "agent": [ - { - "name": ["PPO"], - "gae_lambda": [0.95], - "rollout_steps": [30], - "num_minibatches": [16], - "clip_ratio": [0.3], - 
"update_epochs": [8], - "entropy_weight": [1e-2] - } - ], - "optimizer": [ - { - "name": ["RMSProp", "Adam"], - "kwargs": [{ "learning_rate": [3e-4], "gradient_clip": [1] }] - } - ], - "batch_size": [512], - "discount": [0.95], - "max_devices_per_host": [-1], - "seed": [1], - "generate_random_seed": [true] -} diff --git a/configs/reacher_collect.json b/configs/reacher_collect.json deleted file mode 100644 index b5ad895..0000000 --- a/configs/reacher_collect.json +++ /dev/null @@ -1,37 +0,0 @@ -{ - "env": [ - { - "name": ["reacher"], - "train_steps": [1e8], - "episode_length": [1000], - "action_repeat": [4], - "reward_scaling": [5], - "num_envs": [2048], - "num_evals": [20], - "normalize_obs": [true] - } - ], - "agent": [ - { - "name": ["CollectPPO"], - "data_reduce": [100], - "gae_lambda": [0.95], - "rollout_steps": [50], - "num_minibatches": [32], - "clip_ratio": [0.3], - "update_epochs": [8], - "entropy_weight": [1e-3] - } - ], - "optimizer": [ - { - "name": ["RMSProp", "Adam"], - "kwargs": [{ "learning_rate": [3e-4], "gradient_clip": [1] }] - } - ], - "batch_size": [256], - "discount": [0.95], - "max_devices_per_host": [-1], - "seed": [1], - "generate_random_seed": [true] -} diff --git a/configs/reacher_lopt.json b/configs/reacher_lopt.json deleted file mode 100644 index 7c2a7ff..0000000 --- a/configs/reacher_lopt.json +++ /dev/null @@ -1,47 +0,0 @@ -{ - "env": [ - { - "name": ["reacher"], - "train_steps": [1e8], - "episode_length": [1000], - "action_repeat": [4], - "reward_scaling": [5], - "num_envs": [2048], - "num_evals": [20], - "normalize_obs": [true] - } - ], - "agent": [ - { - "name": ["PPO"], - "gae_lambda": [0.95], - "rollout_steps": [50], - "num_minibatches": [32], - "clip_ratio": [0.3], - "update_epochs": [8], - "entropy_weight": [1e-3] - } - ], - "optimizer": [ - { - "name": ["Optim4RL"], - "kwargs": [ - { - "mlp_dims": [[16, 16]], - "hidden_size": [8], - "param_load_path": [ - "./logs/exp/index/meta_param_path1.pickle", - "./logs/exp/index/rnn_parameter_path2.pickle" - ], - "learning_rate": [3e-4], - "gradient_clip": [1] - } - ] - } - ], - "batch_size": [256], - "discount": [0.95], - "max_devices_per_host": [-1], - "seed": [1], - "generate_random_seed": [true] -} diff --git a/configs/reacher_ppo.json b/configs/reacher_ppo.json deleted file mode 100644 index 0f71b5b..0000000 --- a/configs/reacher_ppo.json +++ /dev/null @@ -1,36 +0,0 @@ -{ - "env": [ - { - "name": ["reacher"], - "train_steps": [1e8], - "episode_length": [1000], - "action_repeat": [4], - "reward_scaling": [5], - "num_envs": [2048], - "num_evals": [20], - "normalize_obs": [true] - } - ], - "agent": [ - { - "name": ["PPO"], - "gae_lambda": [0.95], - "rollout_steps": [50], - "num_minibatches": [32], - "clip_ratio": [0.3], - "update_epochs": [8], - "entropy_weight": [1e-3] - } - ], - "optimizer": [ - { - "name": ["RMSProp", "Adam"], - "kwargs": [{ "learning_rate": [3e-4], "gradient_clip": [1] }] - } - ], - "batch_size": [256], - "discount": [0.95], - "max_devices_per_host": [-1], - "seed": [1], - "generate_random_seed": [true] -} diff --git a/configs/sds_a2c.json b/configs/sds_a2c.json deleted file mode 100644 index 311ec2b..0000000 --- a/configs/sds_a2c.json +++ /dev/null @@ -1,32 +0,0 @@ -{ - "env": [ - { - "name": [["small_dense_short"]], - "num_envs": [512], - "train_steps": [3e7] - } - ], - "agent": [ - { - "name": ["A2C"], - "gae_lambda": [0.95], - "rollout_steps": [20], - "critic_loss_weight": [0.5], - "entropy_weight": [0.01] - } - ], - "agent_optimizer": [ - { - "name": ["RMSProp", "Adam"], - "kwargs": [ 
- { - "learning_rate": [3e-2, 1e-2, 3e-3, 1e-3, 3e-4, 1e-4], - "gradient_clip": [1] - } - ] - } - ], - "discount": [0.995], - "seed": [42], - "generate_random_seed": [true] -} diff --git a/configs/sds_lopt.json b/configs/sds_lopt.json deleted file mode 100644 index ad50737..0000000 --- a/configs/sds_lopt.json +++ /dev/null @@ -1,68 +0,0 @@ -{ - "env": [ - { - "name": [["small_dense_short"]], - "num_envs": [512], - "train_steps": [3e7] - } - ], - "agent": [ - { - "name": ["A2C"], - "gae_lambda": [0.95], - "rollout_steps": [20], - "critic_loss_weight": [0.5], - "entropy_weight": [0.01] - } - ], - "agent_optimizer": [ - { - "name": ["LinearOptim"], - "kwargs": [ - { - "mlp_dims": [[16, 16]], - "hidden_size": [8], - "param_load_path": [ - "./logs/sds_meta/index/meta_param_path1.pickle", - "./logs/sds_meta/index/rnn_parameter_path2.pickle" - ], - "learning_rate": [1e-2], - "gradient_clip": [1] - } - ] - }, - { - "name": ["Optim4RL"], - "kwargs": [ - { - "mlp_dims": [[16, 16]], - "hidden_size": [8], - "param_load_path": [ - "./logs/sds_meta/index/meta_param_path1.pickle", - "./logs/sds_meta/index/rnn_parameter_path2.pickle" - ], - "learning_rate": [1e-2], - "gradient_clip": [1] - } - ] - }, - { - "name": ["L2LGD2"], - "kwargs": [ - { - "mlp_dims": [[16, 16]], - "hidden_size": [8], - "param_load_path": [ - "./logs/sds_meta/index/meta_param_path1.pickle", - "./logs/sds_meta/index/rnn_parameter_path2.pickle" - ], - "learning_rate": [1e-2], - "gradient_clip": [1] - } - ] - } - ], - "discount": [0.995], - "seed": [42], - "generate_random_seed": [true] -} diff --git a/configs/sds_meta.json b/configs/sds_meta.json deleted file mode 100644 index 33fe1ab..0000000 --- a/configs/sds_meta.json +++ /dev/null @@ -1,50 +0,0 @@ -{ - "env": [ - { - "name": [["small_dense_sparse"]], - "num_envs": [512], - "train_steps": [3e7] - } - ], - "agent": [ - { - "name": ["MetaA2C"], - "inner_updates": [4], - "gae_lambda": [0.95], - "rollout_steps": [20], - "critic_loss_weight": [0.5], - "entropy_weight": [0.01], - "reset_interval": [256, 512] - } - ], - "agent_optimizer": [ - { - "name": ["Optim4RL", "LinearOptim", "L2LGD2"], - "kwargs": [ - { - "mlp_dims": [[16, 16]], - "hidden_size": [8], - "param_load_path": [""], - "learning_rate": [1e-2], - "gradient_clip": [1] - } - ] - } - ], - "meta_optimizer": [ - { - "name": ["Adam"], - "kwargs": [ - { - "learning_rate": [3e-5, 1e-4, 3e-4, 1e-3], - "gradient_clip": [1] - } - ] - } - ], - "discount": [0.995], - "seed": [42], - "save_param": [256], - "display_interval": [50], - "generate_random_seed": [true] -} diff --git a/configs/sds_star.json b/configs/sds_star.json deleted file mode 100644 index 2fce932..0000000 --- a/configs/sds_star.json +++ /dev/null @@ -1,49 +0,0 @@ -{ - "env": [ - { - "name": [["small_dense_short"]], - "num_envs": [512], - "train_steps": [3e7] - } - ], - "agent": [ - { - "name": ["StarA2C"], - "inner_updates": [4], - "gae_lambda": [0.95], - "rollout_steps": [20], - "critic_loss_weight": [0.5], - "entropy_weight": [0.01], - "reset_interval": [256] - } - ], - "agent_optimizer": [ - { - "name": ["Star"], - "kwargs": [ - { - "train_steps": [3e7], - "step_mult": [3e-3, 1e-3, 3e-4], - "nominal_stepsize": [3e-3, 1e-3, 3e-4, 0.0], - "weight_decay": [0.0, 0.1, 0.5] - } - ] - } - ], - "meta_optimizer": [ - { - "name": ["Adam"], - "kwargs": [ - { - "learning_rate": [3e-5, 1e-4, 3e-4, 1e-3], - "gradient_clip": [1] - } - ] - } - ], - "discount": [0.995], - "seed": [42], - "save_param": [256], - "display_interval": [50], - "generate_random_seed": [true] -} diff --git 
a/configs/sds_star_lopt.json b/configs/sds_star_lopt.json deleted file mode 100644 index c1561c6..0000000 --- a/configs/sds_star_lopt.json +++ /dev/null @@ -1,38 +0,0 @@ -{ - "env": [ - { - "name": [["small_dense_short"]], - "num_envs": [512], - "train_steps": [3e7] - } - ], - "agent": [ - { - "name": ["A2C2"], - "gae_lambda": [0.95], - "rollout_steps": [20], - "critic_loss_weight": [0.5], - "entropy_weight": [0.01] - } - ], - "agent_optimizer": [ - { - "name": ["Star"], - "kwargs": [ - { - "train_steps": [3e7], - "param_load_path": [ - "./logs/exp/index/meta_param_path1.pickle", - "./logs/exp/index/rnn_parameter_path2.pickle" - ], - "step_mult": [3e-3], - "nominal_stepsize": [3e-3], - "weight_decay": [0.0] - } - ] - } - ], - "discount": [0.995], - "seed": [42], - "generate_random_seed": [true] -} diff --git a/configs/ur5e_collect.json b/configs/ur5e_collect.json deleted file mode 100644 index 906aa1d..0000000 --- a/configs/ur5e_collect.json +++ /dev/null @@ -1,37 +0,0 @@ -{ - "env": [ - { - "name": ["ur5e"], - "train_steps": [2e7], - "episode_length": [1000], - "action_repeat": [1], - "reward_scaling": [10], - "num_envs": [2048], - "num_evals": [20], - "normalize_obs": [true] - } - ], - "agent": [ - { - "name": ["CollectPPO"], - "data_reduce": [100], - "gae_lambda": [0.95], - "rollout_steps": [5], - "num_minibatches": [32], - "clip_ratio": [0.3], - "update_epochs": [4], - "entropy_weight": [1e-2] - } - ], - "optimizer": [ - { - "name": ["RMSProp", "Adam"], - "kwargs": [{ "learning_rate": [2e-4], "gradient_clip": [1] }] - } - ], - "batch_size": [1024], - "discount": [0.95], - "max_devices_per_host": [-1], - "seed": [1], - "generate_random_seed": [true] -} diff --git a/configs/ur5e_lopt.json b/configs/ur5e_lopt.json deleted file mode 100644 index 39e6eeb..0000000 --- a/configs/ur5e_lopt.json +++ /dev/null @@ -1,47 +0,0 @@ -{ - "env": [ - { - "name": ["ur5e"], - "train_steps": [2e7], - "episode_length": [1000], - "action_repeat": [1], - "reward_scaling": [10], - "num_envs": [2048], - "num_evals": [20], - "normalize_obs": [true] - } - ], - "agent": [ - { - "name": ["PPO"], - "gae_lambda": [0.95], - "rollout_steps": [5], - "num_minibatches": [32], - "clip_ratio": [0.3], - "update_epochs": [4], - "entropy_weight": [1e-2] - } - ], - "optimizer": [ - { - "name": ["Optim4RL"], - "kwargs": [ - { - "mlp_dims": [[16, 16]], - "hidden_size": [8], - "param_load_path": [ - "./logs/exp/index/meta_param_path1.pickle", - "./logs/exp/index/rnn_parameter_path2.pickle" - ], - "learning_rate": [2e-4], - "gradient_clip": [1] - } - ] - } - ], - "batch_size": [1024], - "discount": [0.95], - "max_devices_per_host": [-1], - "seed": [1], - "generate_random_seed": [true] -} diff --git a/configs/ur5e_ppo.json b/configs/ur5e_ppo.json deleted file mode 100644 index fc8a99e..0000000 --- a/configs/ur5e_ppo.json +++ /dev/null @@ -1,36 +0,0 @@ -{ - "env": [ - { - "name": ["ur5e"], - "train_steps": [2e7], - "episode_length": [1000], - "action_repeat": [1], - "reward_scaling": [10], - "num_envs": [2048], - "num_evals": [20], - "normalize_obs": [true] - } - ], - "agent": [ - { - "name": ["PPO"], - "gae_lambda": [0.95], - "rollout_steps": [5], - "num_minibatches": [32], - "clip_ratio": [0.3], - "update_epochs": [4], - "entropy_weight": [1e-2] - } - ], - "optimizer": [ - { - "name": ["RMSProp", "Adam"], - "kwargs": [{ "learning_rate": [2e-4], "gradient_clip": [1] }] - } - ], - "batch_size": [1024], - "discount": [0.95], - "max_devices_per_host": [-1], - "seed": [1], - "generate_random_seed": [true] -} diff --git a/download.py 
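The configs removed above (and the lopt_*/meta_*/ppo_* files that replace them, invoked in run.sh further below) all follow the same sweep format: every leaf value is a list of candidates, and utils/sweeper.py expands their cross product so that a single --config_idx selects one combination. As a hand check of that count for the ur5e_ppo.json block just above (a rough illustrative sketch, not repository code):

from itertools import product

optimizer_names = ['RMSProp', 'Adam']   # the "name" list of the optimizer block
learning_rates = [2e-4]                 # "learning_rate"
gradient_clips = [1]                    # "gradient_clip"

grid = list(product(optimizer_names, learning_rates, gradient_clips))
print(len(grid))  # 2: all other lists in the file are singletons, so the sweep
                  # has two combinations; which index maps to which value is
                  # decided by Sweeper.get_dict_value
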
b/download.py new file mode 100644 index 0000000..29f9932 --- /dev/null +++ b/download.py @@ -0,0 +1,13 @@ +import os +import numpy as np +import tensorflow_datasets as tfds +os.environ['NO_GCE_CHECK'] = 'true' + + +dataset = 'mnist' +ds_builder = tfds.builder(dataset, data_dir='./data/') +ds_builder.download_and_prepare() +train_data = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1)) +test_data = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1)) +np.savez(f'./data/{dataset}/train.npz', x=train_data['image'], y=train_data['label']) +np.savez(f'./data/{dataset}/test.npz', x=test_data['image'], y=test_data['label']) \ No newline at end of file diff --git a/envs/catch.py b/envs/catch.py index 3badb8a..dfc2a47 100644 --- a/envs/catch.py +++ b/envs/catch.py @@ -1,4 +1,4 @@ -# Copyright 2022 Garena Online Private Limited. +# Copyright 2024 Garena Online Private Limited. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,58 +12,58 @@ # See the License for the specific language governing permissions and # limitations under the License. +from functools import partial + import jax import jax.numpy as jnp -from jax import lax, random +from jax import jit, lax, random -from envs.spaces import Box, Discrete +from envs.spaces import Discrete, Box class Catch(object): - """A JAX implementation of Catch.""" + """A JAX implementation of the Catch gridworld.""" def __init__(self, rows=10, columns=5): self._rows = rows self._columns = columns - self.num_actions = 3 + self.num_as = 3 self.action_space = Discrete(3) - self.observation_space = Box(0.0, 1.0, shape=(self._rows, self._columns, 1), dtype=bool) + self.obs_space = Box(0.0, 1.0, shape=(self._rows, self._columns, 1), dtype=bool) + @partial(jit, static_argnames=['self']) def reset(self, seed): ball_y = 0 ball_x = random.randint(seed, (), 0, self._columns) paddle_y = self._rows - 1 paddle_x = self._columns // 2 - state = jnp.array([ball_y, ball_x, paddle_y, paddle_x], dtype=jnp.int32) - return lax.stop_gradient(state) + s = jnp.array([ball_y, ball_x, paddle_y, paddle_x], dtype=jnp.int32) + return s - def step(self, seed, state, action): - # Generate the next env state: next_state - paddle_x = jnp.clip(state[3] + action - 1, 0, self._columns - 1) - next_state = jnp.array([state[0] + 1, state[1], state[2], paddle_x]) - # Check if next_state is a teriminal state - done = self.is_terminal(next_state) + @partial(jit, static_argnames=['self']) + def step(self, seed, s, a): + # Generate the next env s: next_s + paddle_x = jnp.clip(s[3]+a-1, 0, self._columns-1) + next_s = jnp.array([s[0]+1, s[1], s[2], paddle_x]) + # Check if next_s is a teriminal s + done = lax.select(next_s[0] == self._rows-1, 1, 0) # Compute the reward - reward = self.reward(state, action, next_state, done) - # Reset the next_state if done - next_state = lax.select(done, self.reset(seed), next_state) - next_state = lax.stop_gradient(next_state) - return next_state, reward, done + r = lax.select(done, lax.select(next_s[1] == next_s[3], 1., -1.), 0.) + # Reset the next_s if done + next_s = lax.select(done, self.reset(seed), next_s) + return next_s, r, done - def render_obs(self, state): + @partial(jit, static_argnames=['self']) + def render_obs(self, s): def f(y, x): - return lax.select( + fn = lax.select( jnp.bitwise_or( - jnp.bitwise_and(y == state[0], x == state[1]), - jnp.bitwise_and(y == state[2], x == state[3]) - ), 1., 0.) 
+ jnp.bitwise_and(y == s[0], x == s[1]), + jnp.bitwise_and(y == s[2], x == s[3]) + ), + 1., + 0. + ) + return fn y_board = jnp.repeat(jnp.arange(self._rows), self._columns) x_board = jnp.tile(jnp.arange(self._columns), self._rows) - return lax.stop_gradient(jax.vmap(f)(y_board, x_board).reshape((self._rows, self._columns, 1))) - - def reward(self, state, action, next_state, done): - r = lax.select(done, lax.select(next_state[1] == next_state[3], 1., -1.), 0.) - return r - - def is_terminal(self, state): - done = lax.select(state[0] == self._rows-1, 1, 0) - return done \ No newline at end of file + return jax.vmap(f)(y_board, x_board).reshape((self._rows, self._columns, 1)) \ No newline at end of file diff --git a/envs/gridworld.py b/envs/gridworld.py index da3aa83..cb5b08a 100644 --- a/envs/gridworld.py +++ b/envs/gridworld.py @@ -1,4 +1,4 @@ -# Copyright 2022 Garena Online Private Limited. +# Copyright 2024 Garena Online Private Limited. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,37 +12,39 @@ # See the License for the specific language governing permissions and # limitations under the License. -import chex from typing import List +from functools import partial + +import jax +import flax import jax.numpy as jnp -from jax import lax, random +from jax import jit, lax, random -from envs.spaces import Box, Discrete +from envs.spaces import Discrete, Box -@chex.dataclass(frozen=True) +@flax.struct.dataclass class GridworldConfig: - env_map: chex.Array - empty_pos_list: chex.Array - objects: chex.Array + env_map: jax.Array + empty_pos_list: jax.Array + objects: jax.Array max_steps: int -@chex.dataclass +@flax.struct.dataclass class EnvState: - agent_pos: chex.Array - objects_pos: chex.Array + agent_pos: jax.Array + objects_pos: jax.Array time: int - -def string_to_bool_map(str_map: List[str]) -> chex.Array: - """Convert string map into boolean walking map.""" +def string_to_bool_map(str_map: List[str]) -> jax.Array: + '''Convert string map into boolean walking map.''' bool_map = [] for row in str_map: bool_map.append([r=='#' for r in row]) return jnp.array(bool_map) -def get_all_empty_pos(env_map: chex.Array) -> chex.Array: - """Get all empty positions, i.e. {(x,y): env_map[x,y]==0}""" +def get_all_empty_pos(env_map: jax.Array) -> jax.Array: + '''Get all empty positions, i.e. {(x,y): env_map[x,y]==0}''' pos_list = [] for x in range(env_map.shape[0]): for y in range(env_map.shape[1]): @@ -57,8 +59,6 @@ def get_all_empty_pos(env_map: chex.Array) -> chex.Array: - Reward: dense/sparse - Horizon: long/short - object: [reward, terminate_prob, respawn_prob] -Note: - Due to a feature/bug in Jax, please make sure all envs have different shapes for (str_map, objects). """ GridworldConfigDict = dict( small_sparse_short = dict( @@ -226,7 +226,8 @@ def get_all_empty_pos(env_map: chex.Array) -> chex.Array: class Gridworld(object): """ - A JAX implementation of Gridworlds. + A JAX implementation of the Gridworld in http://arxiv.org/abs/2007.08794. + We include the agent position in the observation. 
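For orientation, a minimal rollout of the jitted Catch environment defined above (the agents drive it through vmap/scan during training; the Python loop and the random policy here are purely illustrative):

import jax
from envs.catch import Catch

env = Catch()
key = jax.random.PRNGKey(0)
key, reset_key = jax.random.split(key)
s = env.reset(reset_key)                             # [ball_y, ball_x, paddle_y, paddle_x]
for _ in range(20):
  key, a_key, step_key = jax.random.split(key, 3)
  a = jax.random.randint(a_key, (), 0, env.num_as)   # move paddle left / stay / right
  s, r, done = env.step(step_key, s, a)              # env auto-resets internally when done == 1
  obs = env.render_obs(s)                            # (rows, columns, 1) board rendering
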
""" def __init__(self, env_name, env_cfg): env_dict = GridworldConfigDict[env_name] @@ -239,41 +240,43 @@ def __init__(self, env_name, env_cfg): env_cfg.setdefault('reward_scaling', 1.0) self.reward_scaling = env_cfg['reward_scaling'] self.num_actions = 9 - self.action_space = Discrete(self.num_actions) self.move_delta = jnp.array([[0, 0], [-1, 0], [0, 1], [1, 0], [0, -1], [1, -1], [1, 1], [-1, 1], [-1, -1]]) - self.observation_space = Box( + self.action_space = Discrete(self.num_actions) + self.obs_space = Box( low = 0, high = 1, shape = (self.num_objects+1, *self.env_map.shape), dtype = bool ) + @partial(jit, static_argnames=['self']) def reset(self, seed): pos = random.choice(seed, self.empty_pos_list, shape=(1+self.num_objects,), replace=False) agent_pos = jnp.array(pos[0]) objects_pos = jnp.array(pos[1:]) - state = EnvState( + s = EnvState( agent_pos=agent_pos, objects_pos=objects_pos, time=0 ) - return lax.stop_gradient(state) + return s - def step(self, seed, state, action): + @partial(jit, static_argnames=['self']) + def step(self, seed, s, a): # Agent move one step: if hit a wall, go back - agent_pos = state.agent_pos + self.move_delta[action] + agent_pos = s.agent_pos + self.move_delta[a] agent_pos = jnp.maximum(jnp.minimum(agent_pos, jnp.array(self.env_map.shape)-1), jnp.array([0,0])) - agent_pos = lax.select(self.env_map[agent_pos[0], agent_pos[1]]==0, agent_pos, state.agent_pos) + agent_pos = lax.select(self.env_map[agent_pos[0], agent_pos[1]]==0, agent_pos, s.agent_pos) # Collect objects and compute reward def body_func(i, carry_in): - seed, reward, done_flag, state = carry_in + seed, r, done_flag, s = carry_in seed_terminate, seed_respawn1, seed_respawn2, seed = random.split(seed, 4) # Collect the object - is_collected = jnp.logical_and(agent_pos[0]==state.objects_pos[i][0], agent_pos[1]==state.objects_pos[i][1]) + is_collected = jnp.logical_and(agent_pos[0]==s.objects_pos[i][0], agent_pos[1]==s.objects_pos[i][1]) # Compute the reward - reward = lax.select(is_collected, reward+self.objects[i][0], reward) + r += lax.select(is_collected, self.objects[i][0], 0.0) # Remove the object by changing its position to [-1,-1] - obj_pos = lax.select(is_collected, jnp.array([-1,-1]), state.objects_pos[i]) + obj_pos = lax.select(is_collected, jnp.array([-1,-1]), s.objects_pos[i]) # Terminate with probability done = jnp.logical_and(is_collected, random.uniform(seed_terminate) <= self.objects[i][1]) done_flag = lax.select(done, 1, done_flag) @@ -283,27 +286,28 @@ def body_func(i, carry_in): empty_pos = random.choice(seed_respawn2, self.empty_pos_list, shape=(), replace=False) obj_pos = lax.select(respawn, empty_pos, obj_pos) # Set object position - state = state.replace(objects_pos=state.objects_pos.at[i].set(obj_pos)) - carry_out = (seed, reward, done_flag, state) + s = s.replace(objects_pos=s.objects_pos.at[i].set(obj_pos)) + carry_out = (seed, r, done_flag, s) return carry_out - reward, done_flag = 0., 0 - carry_in = (seed, reward, done_flag, state) - seed, reward, done_flag, state = lax.fori_loop(0, self.num_objects, body_func, carry_in) + r, done_flag = 0., 0 + carry_in = (seed, r, done_flag, s) + seed, r, done_flag, s = lax.fori_loop(0, self.num_objects, body_func, carry_in) # Generate the next env state - next_state = EnvState(agent_pos=agent_pos, objects_pos=state.objects_pos, time=state.time+1) + next_s = EnvState(agent_pos=agent_pos, objects_pos=s.objects_pos, time=s.time+1) # Check if next_state is a teriminal state - done = lax.select(done_flag, 1, self.is_terminal(next_state)) 
+ done = lax.select(done_flag, 1, self.is_terminal(next_s)) # Reset the next_state if done - reset_state = self.reset(seed) - next_state.agent_pos = lax.select(done, reset_state.agent_pos, next_state.agent_pos) - next_state.objects_pos = lax.select(done, reset_state.objects_pos, next_state.objects_pos) - next_state.time = lax.select(done, reset_state.time, next_state.time) - # Generate the next env state - next_state = lax.stop_gradient(next_state) - return next_state, self.reward_scaling*reward, done + reset_s = self.reset(seed) + next_s = next_s.replace( + agent_pos = lax.select(done, reset_s.agent_pos, next_s.agent_pos), + objects_pos = lax.select(done, reset_s.objects_pos, next_s.objects_pos), + time = lax.select(done, reset_s.time, next_s.time) + ) + return next_s, self.reward_scaling*r, done + @partial(jit, static_argnames=['self']) def render_obs(self, state): - obs_map = jnp.zeros(self.observation_space.shape) + obs_map = jnp.zeros(self.obs_space.shape) # Render objects def body_func(i, maps): obj_is_present = jnp.logical_and(state.objects_pos[i][0]>=0, state.objects_pos[i][1]>=0) @@ -314,6 +318,7 @@ def body_func(i, maps): obs_map = obs_map.at[-1, state.agent_pos[0], state.agent_pos[1]].set(1) return lax.stop_gradient(obs_map) + @partial(jit, static_argnames=['self']) def is_terminal(self, state): done = lax.select(state.time >= self.max_steps, 1, 0) return done \ No newline at end of file diff --git a/envs/random_walk.py b/envs/random_walk.py deleted file mode 100644 index c741b45..0000000 --- a/envs/random_walk.py +++ /dev/null @@ -1,56 +0,0 @@ -# Copyright 2022 Garena Online Private Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import jax -import jax.numpy as jnp -from jax import lax - -from envs.spaces import Box, Discrete - - -class RandomWalk(object): - """ - A JAX implementation of Random Walk (Example 6.2 in Rich Sutton's RL introduction book). 
- T <--- A(0) <---> B(1) <---> C(2) <---> D(3) <---> E(4) ---> T - True values without discount: A(1/6), B(2/6), C(3/6), D(4/6), E(5/6) - """ - def __init__(self): - self.num_actions = 2 - self.num_states = 5 - self.action_space = Discrete(2) - self.observation_space = Box(0, 1, shape=(5,), dtype=bool) - - def reset(self, seed): - state = jnp.array(2) - return lax.stop_gradient(state) - - def step(self, seed, state, action): - # Generate the next env state: next_state - next_state = jnp.array(state+2*action-1) - # Check if next_state is a teriminal state - done = self.is_terminal(next_state) - # Compute the reward - reward = lax.select(next_state==self.num_states, 1.0, 0.0) - # Reset the next_state if done - next_state = lax.select(done, self.reset(seed), next_state) - next_state = lax.stop_gradient(next_state) - return next_state, reward, done - - def render_obs(self, state): - obs = jax.nn.one_hot(state, self.num_states) - return lax.stop_gradient(obs) - - def is_terminal(self, state): - done = jnp.logical_or(state==-1, state==self.num_states) - return done \ No newline at end of file diff --git a/envs/spaces.py b/envs/spaces.py index 56341d4..18e03ab 100644 --- a/envs/spaces.py +++ b/envs/spaces.py @@ -1,4 +1,4 @@ -# Copyright 2022 Garena Online Private Limited. +# Copyright 2024 Garena Online Private Limited. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,9 +17,9 @@ class Discrete(object): - """ + ''' Minimal jittable class for discrete gymnax spaces. - """ + ''' def __init__(self, n): assert n >= 0 self.n = n diff --git a/envs/utils.py b/envs/utils.py index 942f0f8..afda8a1 100644 --- a/envs/utils.py +++ b/envs/utils.py @@ -1,4 +1,4 @@ -# Copyright 2022 Garena Online Private Limited. +# Copyright 2024 Garena Online Private Limited. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,14 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from jax import lax -import jax.numpy as jnp -from jax.tree_util import tree_map - -import brax.envs -from envs.spaces import Box from envs.catch import Catch -from envs.random_walk import RandomWalk from envs.gridworld import Gridworld, GridworldConfigDict @@ -28,85 +21,5 @@ def make_env(env_name, env_cfg): return Catch() elif env_name in GridworldConfigDict.keys(): return Gridworld(env_name, env_cfg) - elif env_name == 'random_walk': - return RandomWalk() - elif is_gymnax_env(env_name): - import gymnax - env, env_param = gymnax.make(env_name) - if env_name == 'MountainCar-v0': - object.__setattr__(env_param, 'max_steps_in_episode', 1000) - return GymnaxWrapper(env, env_param) else: - raise NameError('Please choose a valid environment name!') - - -def is_gymnax_env(env_name): - if env_name in ['Pendulum-v1', 'CartPole-v1', 'MountainCar-v0', 'MountainCarContinuous-v0', 'Acrobot-v1']: - return True - if 'bsuite' in env_name: - return True - if 'MinAtar' in env_name: - return True - if 'misc' in env_name: - return True - return False - - -class GymnaxWrapper(object): - """A wrapper for gymnax games""" - def __init__(self, env, env_param): - self.env = env - self.env_param = env_param - self.action_space = env.action_space(env_param) - self.observation_space = env.observation_space(env_param) - - def reset(self, seed): - obs, state = self.env.reset(seed, self.env_param) - return state - - def step(self, seed, state, action): - next_obs, next_state, reward, done, info = self.env.step(seed, state, action, self.env_param) - return next_state, reward, done - - def render_obs(self, state): - try: - return self.env.get_obs(state) - except Exception: - return self.env.get_obs(state, self.env_param) - - def is_terminal(self, state): - return self.env.is_terminal(state, self.env_param) - - -brax_envs = ['acrobot', 'ant', 'fast', 'fetch', 'grasp', 'halfcheetah', 'hopper', 'humanoid', 'humanoidstandup', 'inverted_pendulum', 'inverted_double_pendulum', 'pusher', 'reacher', 'reacherangle', 'swimmer', 'ur5e', 'walker2d'] - -def make_brax_env(env_name: str, episode_length: int=1000, action_repeat: int=1): - assert env_name in brax_envs, f'{env_name} is not a Brax environment!' - env = brax.envs.create( - env_name = env_name, - episode_length = episode_length, - action_repeat = action_repeat, - auto_reset = False - ) - return BraxWrapper(env) - -class BraxWrapper(object): - """A wrapper for gymnax games""" - def __init__(self, env): - self.env = env - self.action_space = Box(low=-1.0, high=1.0, shape=(self.env.action_size,)) - self.observation_space = Box(low=-jnp.inf, high=jnp.inf, shape=(self.env.observation_size,)) - - def reset(self, seed): - state = self.env.reset(seed) - return lax.stop_gradient(state) - - def step(self, seed, state, action): - next_state = lax.stop_gradient(self.env.step(state, action)) - reward, done = next_state.reward, next_state.done - reset_state = self.reset(seed) - next_state = tree_map(lambda reset_s, next_s: lax.select(done>0, reset_s, next_s), reset_state, next_state) - return next_state, reward, done - - def render_obs(self, state): - return state.obs \ No newline at end of file + raise NameError('Please choose a valid environment name!') \ No newline at end of file diff --git a/experiment.py b/experiment.py index bf1b4d7..b2bc9e8 100644 --- a/experiment.py +++ b/experiment.py @@ -1,4 +1,4 @@ -# Copyright 2022 Garena Online Private Limited. +# Copyright 2024 Garena Online Private Limited. 
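With the gymnax and Brax wrappers removed above, make_env only builds Catch and the Gridworld family. A usage sketch for a gridworld, assuming no env_cfg keys are required beyond the reward_scaling default visible in the constructor hunk above:

import jax
from envs.utils import make_env

env = make_env('small_dense_short', {'reward_scaling': 1.0})
key = jax.random.PRNGKey(0)
key, reset_key = jax.random.split(key)
s = env.reset(reset_key)                               # EnvState(agent_pos, objects_pos, time)
key, a_key, step_key = jax.random.split(key, 3)
a = jax.random.randint(a_key, (), 0, env.num_actions)  # 9 movement actions (stay + 8 directions)
s, r, done = env.step(step_key, s, a)                  # auto-resets on termination
obs = env.render_obs(s)                                # (num_objects + 1, H, W) occupancy maps
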
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,9 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os +import sys import copy -import json import time +import json import numpy as np import agents @@ -22,22 +24,23 @@ class Experiment(object): - """ + ''' Train the agent to play the game. - """ + ''' def __init__(self, cfg): self.cfg = copy.deepcopy(cfg) self.config_idx = cfg['config_idx'] self.agent_name = cfg['agent']['name'] if self.cfg['generate_random_seed']: self.cfg['seed'] = np.random.randint(int(1e6)) + self.model_path = self.cfg['model_path'] self.cfg_path = self.cfg['cfg_path'] self.save_config() def run(self): - """ + ''' Run the game for multiple steps - """ + ''' self.start_time = time.time() set_random_seed(self.cfg['seed']) self.agent = getattr(agents, self.agent_name)(self.cfg) diff --git a/main.py b/main.py index b4b3345..7373b93 100644 --- a/main.py +++ b/main.py @@ -1,4 +1,4 @@ -# Copyright 2022 Garena Online Private Limited. +# Copyright 2024 Garena Online Private Limited. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -19,22 +19,10 @@ from utils.helper import make_dir from utils.sweeper import Sweeper -# from jax.config import config -# config.update('jax_disable_jit', True) -# config.update("jax_debug_nans", True) -# config.update("jax_enable_x64", True) -# config.update('jax_default_matmul_precision', jax.lax.Precision.HIGHEST) -# Set a specific platform -# config.update('jax_platform_name', 'cpu') - -# Fake devices -# import os -# os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=2' - def main(argv): parser = argparse.ArgumentParser(description="Config file") - parser.add_argument('--config_file', type=str, default='./configs/sds_a2c.json', help='Configuration file') + parser.add_argument('--config_file', type=str, default='./configs/catch.json', help='Configuration file') parser.add_argument('--config_idx', type=int, default=1, help='Configuration index') args = parser.parse_args() @@ -49,6 +37,7 @@ def main(argv): cfg['exp'] = args.config_file.split('/')[-1].split('.')[0] cfg['logs_dir'] = f"./logs/{cfg['exp']}/{cfg['config_idx']}/" make_dir(f"./logs/{cfg['exp']}/{cfg['config_idx']}/") + cfg['model_path'] = cfg['logs_dir'] + 'model.pt' cfg['cfg_path'] = cfg['logs_dir'] + 'config.json' exp = Experiment(cfg) diff --git a/requirements.txt b/requirements.txt index 94123c2..89cab86 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,9 +1,13 @@ -flax>=0.6.3 -chex==0.1.5 -psutil>=5.9.0 -pandas>=1.4.0 -rlax==0.1.5 -gymnax==0.0.5 -seaborn==0.12.2 -optax==0.1.4 -pyarrow>=8.0.0 \ No newline at end of file +jax==0.4.19 +distrax==0.1.5 +chex==0.1.85 +optax==0.1.7 +flax==0.7.5 +rlax==0.1.6 +brax==0.9.4 +tensorflow==2.8.4 +matplotlib +psutil +pandas +pyarrow +seaborn \ No newline at end of file diff --git a/run.sh b/run.sh new file mode 100644 index 0000000..9c0daa0 --- /dev/null +++ b/run.sh @@ -0,0 +1,86 @@ +clear + +# Download MNIST +python download.py + +# Collect +python main.py --config_file ./configs/collect_mnist.json --config_idx 1 +parallel --eta --ungroup --j 1 python main.py --config_file ./configs/collect_bdl.json --config_idx {1} ::: $(seq 1 2) + +# A2C +parallel --eta --ungroup --j 1 python main.py --config_file ./configs/a2c_catch.json --config_idx {1} ::: $(seq 1 20) +parallel --eta 
--ungroup --j 1 python main.py --config_file ./configs/a2c_grid.json --config_idx {1} ::: $(seq 1 120) + +# Meta A2C jobs +## Catch +parallel --eta --ungroup --j 1 python main.py --config_file ./configs/meta_rl_catch.json --config_idx {1} ::: $(seq 1 20) +parallel --eta --ungroup --j 1 python main.py --config_file ./configs/meta_l2l_catch.json --config_idx {1} ::: $(seq 1 20) +parallel --eta --ungroup --j 1 python main.py --config_file ./configs/meta_lin_catch.json --config_idx {1} ::: $(seq 1 20) +parallel --eta --ungroup --j 1 python main.py --config_file ./configs/meta_star_catch.json --config_idx {1} ::: $(seq 1 20) +## sdl +parallel --eta --ungroup --j 1 python main.py --config_file ./configs/meta_rl_sdl.json --config_idx {1} ::: $(seq 1 40) +parallel --eta --ungroup --j 1 python main.py --config_file ./configs/meta_rlp_sdl.json --config_idx {1} ::: $(seq 1 10) +## bdl +parallel --eta --ungroup --j 1 python main.py --config_file ./configs/meta_rl_bdl.json --config_idx {1} ::: $(seq 1 40) +parallel --eta --ungroup --j 1 python main.py --config_file ./configs/meta_rlp_bdl.json --config_idx {1} ::: $(seq 1 10) +parallel --eta --ungroup --j 1 python main.py --config_file ./configs/meta_l2l_bdl.json --config_idx {1} ::: $(seq 1 20) +parallel --eta --ungroup --j 1 python main.py --config_file ./configs/meta_lin_bdl.json --config_idx {1} ::: $(seq 1 20) +parallel --eta --ungroup --j 1 python main.py --config_file ./configs/meta_star_bdl.json --config_idx {1} ::: $(seq 1 20) +## Gridworld +parallel --eta --ungroup --j 1 python main.py --config_file ./configs/meta_rl_grid.json --config_idx {1} ::: $(seq 1 96) + +# Lopt A2C jobs +### catch +parallel --eta --ungroup --j 1 python main.py --config_file ./configs/lopt_rl_catch.json --config_idx {1} ::: $(seq 1 200) +parallel --eta --ungroup --j 1 python main.py --config_file ./configs/lopt_l2l_catch.json --config_idx {1} ::: $(seq 1 200) +parallel --eta --ungroup --j 1 python main.py --config_file ./configs/lopt_lin_catch.json --config_idx {1} ::: $(seq 1 200) +parallel --eta --ungroup --j 1 python main.py --config_file ./configs/lopt_star_catch.json --config_idx {1} ::: $(seq 1 200) +## sdl +parallel --eta --ungroup --j 1 python main.py --config_file ./configs/lopt_rl_sdl.json --config_idx {1} ::: $(seq 1 400) +parallel --eta --ungroup --j 1 python main.py --config_file ./configs/lopt_rlp_sdl.json --config_idx {1} ::: $(seq 1 100) +## bdl +parallel --eta --ungroup --j 1 python main.py --config_file ./configs/lopt_rl_bdl.json --config_idx {1} ::: $(seq 1 400) +parallel --eta --ungroup --j 1 python main.py --config_file ./configs/lopt_rlp_bdl.json --config_idx {1} ::: $(seq 1 100) +parallel --eta --ungroup --j 1 python main.py --config_file ./configs/lopt_lin_bdl.json --config_idx {1} ::: $(seq 1 400) +parallel --eta --ungroup --j 1 python main.py --config_file ./configs/lopt_l2l_bdl.json --config_idx {1} ::: $(seq 1 400) +parallel --eta --ungroup --j 1 python main.py --config_file ./configs/lopt_star_bdl.json --config_idx {1} ::: $(seq 1 400) + +# PPO +parallel --eta --ungroup --j 1 python main.py --config_file ./configs/ppo_ant.json --config_idx {1} ::: $(seq 1 20) +parallel --eta --ungroup --j 1 python main.py --config_file ./configs/ppo_humanoid.json --config_idx {1} ::: $(seq 1 20) +parallel --eta --ungroup --j 1 python main.py --config_file ./configs/ppo_pendulum.json --config_idx {1} ::: $(seq 1 20) +parallel --eta --ungroup --j 1 python main.py --config_file ./configs/ppo_walker2d.json --config_idx {1} ::: $(seq 1 20) + +# Meta PPO jobs +## 
Ant +parallel --eta --ungroup --j 1 python main.py --config_file ./configs/meta_rl_ant.json --config_idx {1} ::: $(seq 1 50) +parallel --eta --ungroup --j 1 python main.py --config_file ./configs/meta_rlp_ant.json --config_idx {1} ::: $(seq 1 10) +parallel --eta --ungroup --j 1 python main.py --config_file ./configs/meta_l2l_ant.json --config_idx {1} ::: $(seq 1 50) +parallel --eta --ungroup --j 1 python main.py --config_file ./configs/meta_lin_ant.json --config_idx {1} ::: $(seq 1 50) +parallel --eta --ungroup --j 1 python main.py --config_file ./configs/meta_star_ant.json --config_idx {1} ::: $(seq 1 50) +## Humanoid +parallel --eta --ungroup --j 1 python main.py --config_file ./configs/meta_rl_humanoid.json --config_idx {1} ::: $(seq 1 50) +parallel --eta --ungroup --j 1 python main.py --config_file ./configs/meta_rlp_humanoid.json --config_idx {1} ::: $(seq 1 10) +parallel --eta --ungroup --j 1 python main.py --config_file ./configs/meta_l2l_humanoid.json --config_idx {1} ::: $(seq 1 50) +parallel --eta --ungroup --j 1 python main.py --config_file ./configs/meta_lin_humanoid.json --config_idx {1} ::: $(seq 1 50) +parallel --eta --ungroup --j 1 python main.py --config_file ./configs/meta_star_humanoid.json --config_idx {1} ::: $(seq 1 50) + +# Lopt PPO jobs +## Ant +parallel --eta --ungroup --j 1 python main.py --config_file ./configs/lopt_rl_ant.json --config_idx {1} ::: $(seq 1 500) +parallel --eta --ungroup --j 1 python main.py --config_file ./configs/lopt_rlp_ant.json --config_idx {1} ::: $(seq 1 100) +parallel --eta --ungroup --j 1 python main.py --config_file ./configs/lopt_lin_ant.json --config_idx {1} ::: $(seq 1 500) +parallel --eta --ungroup --j 1 python main.py --config_file ./configs/lopt_l2l_ant.json --config_idx {1} ::: $(seq 1 500) +parallel --eta --ungroup --j 1 python main.py --config_file ./configs/lopt_star_ant.json --config_idx {1} ::: $(seq 1 500) +## Humanoid +parallel --eta --ungroup --j 1 python main.py --config_file ./configs/lopt_rl_humanoid.json --config_idx {1} ::: $(seq 1 500) +parallel --eta --ungroup --j 1 python main.py --config_file ./configs/lopt_rlp_humanoid.json --config_idx {1} ::: $(seq 1 100) +parallel --eta --ungroup --j 1 python main.py --config_file ./configs/lopt_lin_humanoid.json --config_idx {1} ::: $(seq 1 500) +parallel --eta --ungroup --j 1 python main.py --config_file ./configs/lopt_l2l_humanoid.json --config_idx {1} ::: $(seq 1 500) +parallel --eta --ungroup --j 1 python main.py --config_file ./configs/lopt_star_humanoid.json --config_idx {1} ::: $(seq 1 500) + +## Lopt: Gridworld --> Brax +parallel --eta --ungroup --j 1 python main.py --config_file ./configs/lopt_rl_grid_ant.json --config_idx {1} ::: $(seq 1 960) +parallel --eta --ungroup --j 1 python main.py --config_file ./configs/lopt_rl_grid_humanoid.json --config_idx {1} ::: $(seq 1 960) +parallel --eta --ungroup --j 1 python main.py --config_file ./configs/lopt_rl_grid_pendulum.json --config_idx {1} ::: $(seq 1 960) +parallel --eta --ungroup --j 1 python main.py --config_file ./configs/lopt_rl_grid_walker2d.json --config_idx {1} ::: $(seq 1 960) \ No newline at end of file diff --git a/utils/dataloader.py b/utils/dataloader.py new file mode 100644 index 0000000..2f4601a --- /dev/null +++ b/utils/dataloader.py @@ -0,0 +1,51 @@ +import jax +import numpy as np +from jax import random +import jax.numpy as jnp + + +def load_data(dataset, seed=None, batch_size=1): + dataset = dataset.lower() + assert dataset == 'mnist' + dummy_input = jnp.ones([1, 28, 28, 1]) + data_path = 
f'./data/{dataset}/' + train_file, test_file = np.load(data_path+'train.npz'), np.load(data_path+'test.npz') + train_data = dict( + x = jnp.float32(train_file['x']) / 255., + y = jnp.array(train_file['y']) + ) + test_data = dict( + x = jnp.float32(test_file['x']) / 255., + y = jnp.array(test_file['y']) + ) + train_data = make_batches(train_data, batch_size, seed) + data_loader = { + 'Train': train_data, + 'Test': test_data, + 'dummy_input': dummy_input + } + return data_loader + + +def make_batches(data, batch_size, seed=None): + # Sort data according y label + index = jnp.argsort(data['y']) + data['x'] = data['x'][index] + data['y'] = data['y'][index] + # Find indexes for each label + label_idxs = np.where(data['y'][1:] != data['y'][:-1])[0] + 1 + label_idxs = np.insert(label_idxs, 0, 0) + min_num_batch = np.min(label_idxs[1:] - label_idxs[:-1]) // batch_size + # Truncate into batches and shuffle the training dataset + assert seed is not None, 'Need a random seed.' + new_idxs = jnp.array([], dtype=int) + for i in range(0, len(label_idxs)): + seed, shuffle_seed = random.split(seed) + unshuffled_idxs = jnp.array([], dtype=int) + idxs = jnp.array(range(label_idxs[i], label_idxs[i]+min_num_batch*batch_size), dtype=int) + unshuffled_idxs = jnp.append(unshuffled_idxs, idxs) + shuffled_idxs = jax.random.permutation(shuffle_seed, unshuffled_idxs, independent=True) + new_idxs = jnp.append(new_idxs, shuffled_idxs) + data['x'] = data['x'][new_idxs] + data['y'] = data['y'][new_idxs] + return data \ No newline at end of file diff --git a/utils/helper.py b/utils/helper.py index 6d2a2ac..927e3e7 100644 --- a/utils/helper.py +++ b/utils/helper.py @@ -1,4 +1,4 @@ -# Copyright 2022 Garena Online Private Limited. +# Copyright 2024 Garena Online Private Limited. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,104 +12,129 @@ # See the License for the specific language governing permissions and # limitations under the License. -import datetime import os import random -import sys +import pickle +import psutil +import datetime +import numpy as np +from functools import partial +import jax import jax.numpy as jnp -import numpy as np -import psutil -from jax import tree_util +from jax import jit, lax, tree_util def get_time_str(): return datetime.datetime.now().strftime("%y.%m.%d-%H:%M:%S") + def rss_memory_usage(): - """ + ''' Return the resident memory usage in MB - """ + ''' process = psutil.Process(os.getpid()) mem = process.memory_info().rss / float(2 ** 20) return mem -def str_to_class(module_name, class_name): - """ - Convert string to class - """ - return getattr(sys.modules[module_name], class_name) def set_random_seed(seed): - """ + ''' Set all random seeds - """ + ''' random.seed(seed) np.random.seed(seed) + def make_dir(dir): if not os.path.exists(dir): os.makedirs(dir, exist_ok=True) + +jitted_split = partial(jit, static_argnames=['num'])(jax.random.split) + + +def pytree2array(values): + leaves = tree_util.tree_leaves(lax.stop_gradient(values)) + a = jnp.concatenate(leaves, axis=None) + return a + + def tree_stack(trees, axis=0): - """ + ''' From: https://gist.github.com/willwhitney/dd89cac6a5b771ccff18b06b33372c75 Takes a list of trees and stacks every corresponding leaf. For example, given two trees ((a, b), c) and ((a', b'), c'), returns ((stack(a, a'), stack(b, b')), stack(c, c')). Useful for turning a list of objects into something you can feed to a vmapped function. 
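A usage sketch for the new utils/dataloader.py above; it assumes download.py has already written ./data/mnist/train.npz and ./data/mnist/test.npz. Note that make_batches sorts the training set by label and shuffles only within each class, so any contiguous batch_size slice is class-homogeneous (the slicing below is illustrative of the layout, not of how the training code consumes it):

from jax import random
from utils.dataloader import load_data

data = load_data('mnist', seed=random.PRNGKey(0), batch_size=512)
x_train, y_train = data['Train']['x'], data['Train']['y']  # float images in [0, 1], int labels
first_batch = {'x': x_train[:512], 'y': y_train[:512]}     # a single-class batch
dummy = data['dummy_input']                                 # shape (1, 28, 28, 1), e.g. for network init
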
- """ + ''' leaves_list = [] treedef_list = [] for tree in trees: leaves, treedef = tree_util.tree_flatten(tree) leaves_list.append(leaves) treedef_list.append(treedef) - grouped_leaves = zip(*leaves_list) - result_leaves = [jnp.stack(leaf, axis=axis) for leaf in grouped_leaves] + result_leaves = [jnp.stack(l, axis=axis) for l in grouped_leaves] return treedef_list[0].unflatten(result_leaves) -def tree_transpose(list_of_trees): - """ - Convert a list of trees of identical structure into a single tree of lists. - Act the same as tree_stack - """ - return tree_util.tree_map(lambda *xs: jnp.array(xs), *list_of_trees) def tree_unstack(tree): - """ + ''' From: https://gist.github.com/willwhitney/dd89cac6a5b771ccff18b06b33372c75 Takes a tree and turns it into a list of trees. Inverse of tree_stack. For example, given a tree ((a, b), c), where a, b, and c all have first dimension k, will make k trees [((a[0], b[0]), c[0]), ..., ((a[k], b[k]), c[k])] Useful for turning the output of a vmapped function into normal objects. - """ + ''' leaves, treedef = tree_util.tree_flatten(tree) n_trees = leaves[0].shape[0] new_leaves = [[] for _ in range(n_trees)] for leaf in leaves: for i in range(n_trees): new_leaves[i].append(leaf[i]) - new_trees = [treedef.unflatten(leaf) for leaf in new_leaves] + new_trees = [treedef.unflatten(l) for l in new_leaves] return new_trees + +def tree_transpose(list_of_trees): + ''' + Convert a list of trees of identical structure into a single tree of lists. + We can replace tree_stack with tree_transpose. + JAX also provides jax.tree_transpose, + which is more verbose, but allows you specify the structure of the inner and outer Pytree for more flexibility. + ''' + return tree_util.tree_map(lambda *xs: jnp.array(xs), *list_of_trees) + + def tree_concatenate(trees): - """ + ''' Adapted from tree_stack. Takes a list of trees and stacks every corresponding leaf. For example, given two trees ((a, b), c) and ((a', b'), c'), returns ((concatenate(a, a'), concatenate(b, b')), concatenate(c, c')). - """ + ''' leaves_list = [] treedef_list = [] for tree in trees: leaves, treedef = tree_util.tree_flatten(tree) leaves_list.append(leaves) treedef_list.append(treedef) - grouped_leaves = zip(*leaves_list) - result_leaves = [jnp.concatenate(leaf) for leaf in grouped_leaves] - return treedef_list[0].unflatten(result_leaves) \ No newline at end of file + result_leaves = [jnp.concatenate(l) for l in grouped_leaves] + return treedef_list[0].unflatten(result_leaves) + + +def save_model_param(model_param, filepath): + with open(filepath, 'wb') as f: + pickle.dump(model_param, f) + + +def load_model_param(filepath): + f = open(filepath, 'rb') + model_param = pickle.load(f) + model_param = tree_util.tree_map(jnp.array, model_param) + f.close() + return model_param \ No newline at end of file diff --git a/utils/logger.py b/utils/logger.py index af1a751..f5d4818 100644 --- a/utils/logger.py +++ b/utils/logger.py @@ -1,4 +1,4 @@ -# Copyright 2022 Garena Online Private Limited. +# Copyright 2024 Garena Online Private Limited. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/utils/plotter.py b/utils/plotter.py index cd57853..cff4c6e 100644 --- a/utils/plotter.py +++ b/utils/plotter.py @@ -1,4 +1,4 @@ -# Copyright 2022 Garena Online Private Limited. +# Copyright 2024 Garena Online Private Limited. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,33 +13,37 @@ # limitations under the License. import os +import copy import json -import math import numpy as np import pandas as pd -import seaborn as sns +import seaborn as sns; sns.set(style="ticks"); sns.set_context("notebook") # paper, talk, notebook import matplotlib import matplotlib.pyplot as plt +from matplotlib.ticker import FuncFormatter +# Set font family, bold, and font size +font = {'size': 16} # font = {'family':'normal', 'weight':'normal', 'size': 12} +matplotlib.rc('font', **font) +# Avoid Type 3 fonts: http://phyletica.org/matplotlib-fonts/ +matplotlib.rcParams['pdf.fonttype'] = 42 +matplotlib.rcParams['ps.fonttype'] = 42 +plt.rcParams['axes.autolimit_mode'] = 'round_numbers' +plt.rcParams['axes.xmargin'] = 0 +plt.rcParams['axes.ymargin'] = 0 + from utils.helper import make_dir from utils.sweeper import Sweeper -# Set plot style -sns.set(style="ticks") -sns.set_context("paper") -#sns.set_context("talk") - -# Avoid Type 3 fonts in matplotlib plots: http://phyletica.org/matplotlib-fonts/ -matplotlib.rcParams['pdf.fonttype'] = 42 -matplotlib.rcParams['ps.fonttype'] = 42 - class Plotter(object): def __init__(self, cfg): # Set default value for symmetric EMA (exponential moving average) # Note that EMA only works when merged is True cfg.setdefault('EMA', False) - cfg.setdefault('ci', None) + cfg.setdefault('ci', ('ci', 95)) + cfg.setdefault('rolling_score_window', -1) + self.cfg = copy.deepcopy(cfg) # Copy parameters self.exp = cfg['exp'] self.merged = cfg['merged'] @@ -49,14 +53,17 @@ def __init__(self, cfg): self.hue_label = cfg['hue_label'] self.show = cfg['show'] self.imgType = cfg['imgType'] - self.ci = cfg['ci'] + if type(cfg['ci']) == int: + self.ci = ('ci', cfg['ci']) + else: + self.ci = cfg['ci'] self.EMA = cfg['EMA'] - """ Set sweep_keys: + ''' Set sweep_keys: For a hierarchical dict, all keys along the path are cancatenated with '/' into one key. - For example, for config_dict = {"env": {"name": ["Catch"]}}, - key = 'env': return the whole dict {"name": ["Catch"]}, - key = 'env/name': return ["Catch"]. - """ + For example, for config_dict = {"env": {"names": ["Catch-bsuite"]}}, + key = 'env': return the whole dict {"names": ["Catch-bsuite"]}, + key = 'env/names': return ["Catch-bsuite"]. + ''' self.sweep_keys = cfg['sweep_keys'] self.sort_by = cfg['sort_by'] self.ascending = cfg['ascending'] @@ -66,9 +73,9 @@ def __init__(self, cfg): self.total_combination = get_total_combination(self.exp) def merge_index(self, config_idx, mode, processed, exp=None): - """ + ''' Given exp and config index, merge the results of multiple runs - """ + ''' if exp is None: exp = self.exp result_list = [] @@ -87,7 +94,7 @@ def merge_index(self, config_idx, mode, processed, exp=None): # Do symmetric EMA (exponential moving average) only # when we want the original data (i.e. 
no processed) - if (self.EMA) and (not processed): + if (self.EMA) and (processed == False): # Get x's and y's in form of numpy arries xs, ys = [], [] for result in result_list: @@ -106,11 +113,11 @@ def merge_index(self, config_idx, mode, processed, exp=None): result_list[i] = result_list[i][:n] result_list[i].loc[:, self.x_label] = new_x result_list[i].loc[:, self.y_label] = new_y - elif not processed: + elif processed == False: # Moving average if self.rolling_score_window > 0: for i in range(len(result_list)): - y = result_list[i][self.y_label].to_numpy() + x, y = result_list[i][self.x_label].to_numpy(), result_list[i][self.y_label].to_numpy() y = moving_average(y, self.rolling_score_window) result_list[i].loc[:, self.x_label] = new_x result_list[i].loc[:, self.y_label] = new_y @@ -122,19 +129,19 @@ def merge_index(self, config_idx, mode, processed, exp=None): return result_list def get_result(self, exp, config_idx, mode, get_process_result_dict=None): - """ + ''' Return: (merged, processed) result - if (merged == True) or (get_process_result_dict is not None): Return a list of (processed) result for all runs. - if (merged == False): Return unmerged result of one single run in a list. - """ + ''' if get_process_result_dict is not None: processed = True else: processed = False - if self.merged or processed: + if self.merged == True or processed == True: # Merge results print(f'[{exp}]: Merge {mode} results: {config_idx}/{get_total_combination(exp)}') result_list = self.merge_index(config_idx, mode, processed, exp) @@ -157,36 +164,36 @@ def get_result(self, exp, config_idx, mode, get_process_result_dict=None): return [result] def plot_vanilla(self, data, image_path): - """ + ''' Plot results for data: data = [result_1_list, result_2_list, ...] result_i_list = [result_run_1, result_run_2, ...] result_run_i is a Dataframe - """ + ''' fig, ax = plt.subplots() for i in range(len(data)): # Convert to numpy array ys = [] for result in data[i]: ys.append(result[self.y_label].to_numpy()) - # Compute x_mean, y_mean and y_ci + # Put all results in a dataframe ys = np.array(ys) x_mean = data[i][0][self.x_label].to_numpy() - y_mean = np.mean(ys, axis=0) - if self.ci == 'sd': - y_ci = np.std(ys, axis=0, ddof=0) - elif self.ci == 'se': - y_ci = np.std(ys, axis=0, ddof=0)/math.sqrt(len(ys)) - # Plot - plt.plot(x_mean, y_mean, linewidth=1.0, label=data[i][0][self.hue_label][0]) - if self.ci in ['sd', 'se']: - plt.fill_between(x_mean, y_mean - y_ci, y_mean + y_ci, alpha=0.5) - - # ax.set_title(title) - ax.legend(loc=self.loc) - ax.set_xlabel(self.x_label) - ax.set_ylabel(self.y_label) - ax.get_figure().savefig(image_path) + runs = len(data[i]) + x = np.tile(x_mean, runs) + y = ys.reshape((-1)) + result_df = pd.DataFrame(list(zip(x, y)), columns=['x', 'y']) + sns.lineplot( + data=result_df, x='x', y='y', + estimator=self.cfg['estimator'], + errorbar=self.ci, err_kws={'alpha':0.5}, + linewidth=1.0, label=data[i][0][self.hue_label][0] + ) + plt.legend(loc=self.loc) + plt.xlabel(self.x_label) + plt.ylabel(self.y_label) + plt.tight_layout() + plt.savefig(image_path) if self.show: plt.show() plt.clf() # clear figure @@ -194,33 +201,33 @@ def plot_vanilla(self, data, image_path): plt.close() # close window def plot_indexList(self, indexList, mode, image_name): - """ + ''' Func: Given (config index) list and mode - merged == True: plot merged result for all runs. - merged == False: plot unmerged result of one single run. 
- """ + ''' expIndexModeList = [] for x in indexList: expIndexModeList.append([self.exp, x ,mode]) self.plot_expIndexModeList(expIndexModeList, image_name) def plot_indexModeList(self, indexModeList, image_name): - """ + ''' Func: Given (config index, mode) list - merged == True: plot merged result for all runs. - merged == False: plot unmerged result of one single run. - """ + ''' expIndexModeList = [] for x in indexModeList: expIndexModeList.append([self.exp] + x) self.plot_expIndexModeList(expIndexModeList, image_name) def plot_expIndexModeList(self, expIndexModeList, image_name): - """ + ''' Func: Given (exp, config index, mode) list - merged == True: plot merged result for all runs. - merged == False: plot unmerged result of one single run. - """ + ''' # Get results results = [] for exp, config_idx, mode in expIndexModeList: @@ -242,9 +249,9 @@ def plot_expIndexModeList(self, expIndexModeList, image_name): self.plot_vanilla(results, image_path) def plot_results(self, mode, indexes='all'): - """ + ''' Plot merged result for all config indexes - """ + ''' if indexes == 'all': if self.merged: indexes = range(1, self.total_combination+1) @@ -264,10 +271,10 @@ def plot_results(self, mode, indexes='all'): image_path = f'./logs/{self.exp}/{config_idx}/{self.y_label}_{mode}.{self.imgType}' self.plot_vanilla([result_list], image_path) - def csv_results(self, mode, get_csv_result_dict, get_process_result_dict): - """ - Show results: generate a *.csv file that store all merged results - """ + def csv_merged_results(self, mode, get_csv_result_dict, get_process_result_dict): + ''' + Show results: generate a *.csv file that store all **merged** results + ''' new_result_list = [] for config_idx in range(1, self.total_combination+1): print(f'[{self.exp}]: CSV {mode} results: {config_idx}/{self.total_combination}') @@ -278,7 +285,11 @@ def csv_results(self, mode, get_csv_result_dict, get_process_result_dict): # Get test results dict result_dict = get_csv_result_dict(result, config_idx, mode) # Expand test result dict from config dict - config_file = f'./logs/{self.exp}/{config_idx}/config.json' + for i in range(self.runs): + config_file = f'./logs/{self.exp}/{config_idx+i*self.total_combination}/config.json' + if os.path.exists(config_file): + break + with open(config_file, 'r') as f: config_dict = json.load(f) for key in self.sweep_keys: @@ -290,16 +301,98 @@ def csv_results(self, mode, get_csv_result_dict, get_process_result_dict): return make_dir(f'./logs/{self.exp}/0/') results = pd.DataFrame(new_result_list) - # Sort by mean and ste of test result label value + # Sort by mean and se of result label value sorted_results = results.sort_values(by=self.sort_by, ascending=self.ascending) - # Save sorted test results into a .feather file - sorted_results_file = f'./logs/{self.exp}/0/results_{mode}.csv' + # Save sorted results into a .feather file + sorted_results_file = f'./logs/{self.exp}/0/results_{mode}_merged.csv' sorted_results.to_csv(sorted_results_file, index=False) + def csv_unmerged_results(self, mode, get_process_result_dict): + ''' + Show results: generate a *.csv file that store all **unmerged** results + ''' + new_result_list = [] + for config_idx in range(1, self.runs*self.total_combination+1): + print(f'[{self.exp}]: CSV {mode} results: {config_idx}/{self.runs*self.total_combination}') + result_file = f'./logs/{self.exp}/{config_idx}/result_{mode}.feather' + config_file = f'./logs/{self.exp}/{config_idx}/config.json' + result = read_file(result_file) + if result is None or not 
os.path.exists(config_file): + continue + result = get_process_result_dict(result, config_idx, mode) + with open(config_file, 'r') as f: + config_dict = json.load(f) + for key in self.sweep_keys: + result[key] = find_key_value(config_dict, key.split('/')) + new_result_list.append(result) + if len(new_result_list) == 0: + print(f'[{self.exp}]: No {mode} results') + return + make_dir(f'./logs/{self.exp}/0/') + results = pd.DataFrame(new_result_list) + # Save results into a .feather file + results_file = f'./logs/{self.exp}/0/results_{mode}_unmerged.csv' + results.to_csv(results_file, index=False) + + def compare_parameter(self, param_name, perf_name=None, image_name=None, constraints=[], mode='Train', stat='count', kde=False): + ''' + Plot histograms for hyper-parameter selection. + perf_name: the performance metric from results_{mode}.csv, such as Return (mean). + param_name: the name of considered hyper-parameter, such lr. + image_name: the name of the plotted image. + constraints: a list of tuple (k, [x,y,...]). We only consider index with config_dict[k] in [x,y,...]. + mode: Train or Test. + stat: for seaborn plot function + kde: if True, plot all kdes (kernel density estimations) in one figure; o.w. plot histograms in different subfigures + ''' + param_name_short = param_name.split('/')[-1] + if image_name is None: + image_name = param_name_short + if perf_name is None: + perf_name = f'{self.y_label} (mean)' + config_file = f'./configs/{self.exp}.json' + results_file = f'./logs/{self.exp}/0/results_{mode}_unmerged.csv' + if kde: + image_path = f'./logs/{self.exp}/0/{image_name}_{mode}_kde.{self.imgType}' + else: + image_path = f'./logs/{self.exp}/0/{image_name}_{mode}.{self.imgType}' + assert os.path.exists(results_file), f'{results_file} does not exist. Please generate it first with csv_unmerged_results.' + assert os.path.exists(config_file), f'{config_file} does not exist.' 
+ # Load all results + results = pd.read_csv(results_file) + # Select results based on the constraints and param_name + for k, vs in constraints: + results = results.loc[lambda df: df[k].isin(vs), :] + results = results.loc[:, [perf_name, param_name]] + results.rename(columns={param_name: param_name_short}, inplace=True) + # Plot + param_values = sorted(list(set(results[param_name_short]))) + if len(param_values) == 1 and param_values[0] == '/': + return + if kde: # Plot all kdes in one figure + fig, ax = plt.subplots() + # sns.histplot(data=results, x=perf_name, hue=param_name_short, kde=True, stat=stat, palette='bright', discrete=True) + sns.kdeplot(data=results, x=perf_name, hue=param_name_short, palette='bright') + ax.grid(axis='y') + else: # Plot histograms in different subfigures + fig, axs = plt.subplots(len(param_values), 1, sharex=True, sharey=True, figsize=(7, 3*len(param_values))) + if len(param_values) == 1: + axs = [axs] + for i, param_v in enumerate(param_values): + sns.histplot(data=results[results[param_name_short]==param_v], x=perf_name, hue=param_name_short, kde=False, stat=stat, palette='bright', ax=axs[i], discrete=True) + axs[i].grid(axis='y') + plt.xlabel(perf_name) + plt.tight_layout() + plt.savefig(image_path) + if self.show: + plt.show() + plt.clf() # clear figure + plt.cla() # clear axis + plt.close() # close window def one_sided_ema(xolds, yolds, low=None, high=None, n=512, decay_steps=1.0, low_counts_threshold=0.0): - """ Copy from baselines.common.plot_util + ''' Copy from baselines.common.plot_util Functionality: perform one-sided (causal) EMA (exponential moving average) smoothing and resampling to an even grid with n points. @@ -319,7 +412,7 @@ def one_sided_ema(xolds, yolds, low=None, high=None, n=512, decay_steps=1.0, low xs - array with new x grid ys - array of EMA of y at each point of the new x grid count_ys - array of EMA of y counts at each point of the new x grid - """ + ''' low = xolds[0] if low is None else low high = xolds[-1] if high is None else high @@ -342,16 +435,14 @@ def one_sided_ema(xolds, yolds, low=None, high=None, n=512, decay_steps=1.0, low sum_y *= interstep_decay count_y *= interstep_decay while True: - if luoi >= len(xolds): - break + if luoi >= len(xolds): break xold = xolds[luoi] if xold <= xnew: decay = np.exp(- (xnew - xold) / decay_period) sum_y += decay * yolds[luoi] count_y += decay luoi += 1 - else: - break + else: break sum_ys[i] = sum_y count_ys[i] = count_y @@ -359,9 +450,8 @@ def one_sided_ema(xolds, yolds, low=None, high=None, n=512, decay_steps=1.0, low ys[count_ys < low_counts_threshold] = np.nan return xnews, ys, count_ys - def symmetric_ema(xolds, yolds, low=None, high=None, n=512, decay_steps=1.0, low_counts_threshold=0.0): - """ Copied from baselines.common.plot_util + ''' Copy from baselines.common.plot_util Functionality: Perform symmetric EMA (exponential moving average) smoothing and resampling to an even grid with n points. 
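A hypothetical driver for the unmerged-CSV and hyper-parameter comparison path added above. Only the keys visibly read in this patch are known; the experiment name, the process function, and any constructor key not shown here (e.g. 'runs', 'loc', 'x_label', 'y_label') are assumptions for illustration:

from utils.plotter import Plotter

def process_fn(result, config_idx, mode):
  # Hypothetical reduction of one run's dataframe to a single row; the real
  # metric columns depend on what the agent writes into result_{mode}.feather.
  return {'Return (mean)': result['Return'].mean()}

cfg = {
  'exp': 'lopt_rl_bdl',                    # must match ./configs/lopt_rl_bdl.json
  'merged': True, 'runs': 10,
  'x_label': 'Step', 'y_label': 'Return', 'hue_label': 'Agent',
  'show': False, 'imgType': 'png', 'estimator': 'mean', 'loc': 'best',
  'sweep_keys': ['agent_optimizer/name', 'agent_optimizer/kwargs/learning_rate'],
  'sort_by': ['Return (mean)'], 'ascending': [False],
}
plotter = Plotter(cfg)
plotter.csv_unmerged_results(mode='Train', get_process_result_dict=process_fn)
plotter.compare_parameter(param_name='agent_optimizer/kwargs/learning_rate', mode='Train', kde=True)
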
@@ -381,8 +471,8 @@ def symmetric_ema(xolds, yolds, low=None, high=None, n=512, decay_steps=1.0, low xs - array with new x grid ys - array of EMA of y at each point of the new x grid count_ys - array of EMA of y counts at each point of the new x grid - """ + ''' xs, ys1, count_ys1 = one_sided_ema(xolds, yolds, low, high, n, decay_steps, low_counts_threshold) _, ys2, count_ys2 = one_sided_ema(-xolds[::-1], yolds[::-1], -high, -low, n, decay_steps, low_counts_threshold) ys2 = ys2[::-1] @@ -393,14 +483,8 @@ def symmetric_ema(xolds, yolds, low=None, high=None, n=512, decay_steps=1.0, low xs = [int(x) for x in xs] return xs, ys, count_ys - def moving_average(values, window): - """ Copied from https://github.com/deepmind/dqn_zoo/blob/master/dqn_zoo_plots.ipynb - Smooth values by doing a moving average - :param values: (numpy array) - :param window: (int) - :return: (numpy array) - """ + # Copied from https://github.com/deepmind/dqn_zoo/blob/master/dqn_zoo_plots.ipynb numerator = np.nancumsum(values) numerator[window:] = numerator[window:] - numerator[:-window] denominator = np.ones(len(values)) * window @@ -408,28 +492,26 @@ def moving_average(values, window): smoothed = numerator / denominator assert values.shape == smoothed.shape return smoothed - - + def get_total_combination(exp): - """ + ''' Get total combination of experiment configuration - """ + ''' config_file = f'./configs/{exp}.json' assert os.path.isfile(config_file), f'[{exp}]: No config file <{config_file}>!' sweeper = Sweeper(config_file) return sweeper.config_dicts['num_combinations'] def find_key_value(config_dict, key_list): - """ + ''' Find key value in config dict recursively given a key_list which represents the keys in path. - """ - for k, v in config_dict.items(): - if k == key_list[0]: - if len(key_list)==1: - return v - else: - return find_key_value(v, key_list[1:]) - return '/' + ''' + for k in key_list: + try: + config_dict = config_dict[k] + except: + return '/' + return config_dict def read_file(result_file): if not os.path.isfile(result_file): @@ -440,7 +522,5 @@ def read_file(result_file): print(f'No result in file <{result_file}>') return None else: - if result.isnull().values.any(): - print(f'NaN detected in file <{result_file}>. Replace NaN with 0.') - result = result.replace(np.nan, 0) + result = result.replace(np.nan, 0) return result \ No newline at end of file diff --git a/utils/sweeper.py b/utils/sweeper.py index 144250f..f92102c 100644 --- a/utils/sweeper.py +++ b/utils/sweeper.py @@ -1,4 +1,4 @@ -# Copyright 2022 Garena Online Private Limited. +# Copyright 2024 Garena Online Private Limited. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,236 +12,232 @@ # See the License for the specific language governing permissions and # limitations under the License. -import json import os - -import matplotlib.pyplot as plt +import sys +import json +import argparse import numpy as np +import matplotlib.pyplot as plt parentdir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) os.sys.path.insert(0, parentdir) class Sweeper(object): - """ - This class generates a Config object and corresponding config dict - given an index and the config file. 
- """ - - def __init__(self, config_file): - with open(config_file, "r") as f: - self.config_dicts = json.load(f) - self.get_num_combinations_of_dict(self.config_dicts) - - def get_num_combinations_of_dict(self, config_dict): - """ - Get # of combinations for configurations in a config dict - """ - assert type(config_dict) == dict, "Config file must be a dict!" - num_combinations_of_dict = 1 - for key, values in config_dict.items(): - num_combinations_of_list = self.get_num_combinations_of_list(values) - num_combinations_of_dict *= num_combinations_of_list - config_dict["num_combinations"] = num_combinations_of_dict - - def get_num_combinations_of_list(self, config_list): - """ - Get # of combinations for configurations in a config list - """ - assert type(config_list) == list, "Elements in a config dict must be a list!" - num_combinations_of_list = 0 - for value in config_list: - if type(value) == dict: - if not ("num_combinations" in value.keys()): - self.get_num_combinations_of_dict(value) - num_combinations_of_list += value["num_combinations"] - else: - num_combinations_of_list += 1 - return num_combinations_of_list - - def generate_config_for_idx(self, idx): - """ - Generate a config dict for the index. - Index is from 1 to # of conbinations. - """ - # Get config dict given the index - cfg = self.get_dict_value( - self.config_dicts, (idx - 1) % self.config_dicts["num_combinations"] - ) - # Set config index - cfg["config_idx"] = idx - # Set number of combinations - cfg["num_combinations"] = self.config_dicts["num_combinations"] - return cfg - - def get_list_value(self, config_list, idx): - for value in config_list: - if type(value) == dict: - if idx + 1 - value["num_combinations"] <= 0: - return self.get_dict_value(value, idx) - else: - idx -= value["num_combinations"] - else: - if idx == 0: - return value - else: - idx -= 1 - - def get_dict_value(self, config_dict, idx): - cfg = dict() - for key, values in config_dict.items(): - if key == "num_combinations": - continue - num_combinations_of_list = self.get_num_combinations_of_list(values) - value = self.get_list_value(values, idx % num_combinations_of_list) - cfg[key] = value - idx = idx // num_combinations_of_list - return cfg - - def print_config_dict(self, config_dict): - cfg_json = json.dumps(config_dict, indent=2) - print(cfg_json, end="\n") - - -def unfinished_index(exp, file_name="log.txt", runs=1, max_line_length=10000): - """ - Find unfinished config indexes based on the existence of time info in the log file - """ - # Read config files - config_file = f"./configs/{exp}.json" - sweeper = Sweeper(config_file) - # Read a list of logs - print(f"[{exp}]: ", end=" ") - for i in range(runs * sweeper.config_dicts["num_combinations"]): - log_file = f"./logs/{exp}/{i+1}/{file_name}" + ''' + This class generates a Config object and corresponding config dict + given an index and the config file. + ''' + def __init__(self, config_file): + with open(config_file, 'r') as f: + self.config_dicts = json.load(f) + self.get_num_combinations_of_dict(self.config_dicts) + + def get_num_combinations_of_dict(self, config_dict): + ''' + Get # of combinations for configurations in a config dict + ''' + assert type(config_dict) == dict, 'Config file must be a dict!' 
+ num_combinations_of_dict = 1 + for key, values in config_dict.items(): + num_combinations_of_list = self.get_num_combinations_of_list(values) + num_combinations_of_dict *= num_combinations_of_list + config_dict['num_combinations'] = num_combinations_of_dict + + def get_num_combinations_of_list(self, config_list): + ''' + Get # of combinations for configurations in a config list + ''' + assert type(config_list) == list, 'Elements in a config dict must be a list!' + num_combinations_of_list = 0 + for value in config_list: + if type(value) == dict: + if not('num_combinations' in value.keys()): + self.get_num_combinations_of_dict(value) + num_combinations_of_list += value['num_combinations'] + else: + num_combinations_of_list += 1 + return num_combinations_of_list + + def generate_config_for_idx(self, idx): + ''' + Generate a config dict for the index. + If index < 0, set it to index+num_combinations+1. + If index > 0, set it to index % num_combinations. + Note that the first index is 1, not zero. + ''' + if idx < 0: + idx = idx + self.config_dicts['num_combinations'] + 1 + assert idx > 0, 'Index must >= -num_combinations.' + # Get config dict given the index + cfg = self.get_dict_value(self.config_dicts, (idx-1) % self.config_dicts['num_combinations']) + # Set config index + cfg['config_idx'] = idx + # Set number of combinations + cfg['num_combinations'] = self.config_dicts['num_combinations'] + return cfg + + def get_list_value(self, config_list, idx): + for value in config_list: + if type(value) == dict: + if idx + 1 - value['num_combinations'] <= 0: + return self.get_dict_value(value, idx) + else: + idx -= value['num_combinations'] + else: + if idx == 0: + return value + else: + idx -= 1 + + def get_dict_value(self, config_dict, idx): + cfg = dict() + for key, values in config_dict.items(): + if key == 'num_combinations': + continue + num_combinations_of_list = self.get_num_combinations_of_list(values) + value = self.get_list_value(values, idx % num_combinations_of_list) + cfg[key] = value + idx = idx // num_combinations_of_list + return cfg + + def print_config_dict(self, config_dict): + cfg_json = json.dumps(config_dict, indent=2) + print(cfg_json, end='\n') + + +def unfinished_index(exp, file_name='log.txt', runs=1, max_line_length=10000): + ''' + Find unfinished config indexes based on the existence of time info in the log file + ''' + # Read config files + config_file = f'./configs/{exp}.json' + sweeper = Sweeper(config_file) + # Read a list of logs + print(f'[{exp}]: ', end=' ') + for i in range(runs * sweeper.config_dicts['num_combinations']): + log_file = f'./logs/{exp}/{i+1}/{file_name}' + try: + with open(log_file, 'r') as f: + # Get last line try: - with open(log_file, "r") as f: - # Get last line - try: - f.seek(-max_line_length, os.SEEK_END) - except IOError: - # either file is too small, or too many lines requested - f.seek(0) - last_line = f.readlines()[-1] - # Get time info in last line - try: - float(last_line.split(" ")[-2]) - except Exception: - print(i + 1, end=", ") - continue - except Exception: - print(i + 1, end=", ") - continue - print() - - -def time_info(exp, file_name="log.txt", runs=1, nbins=10, max_line_length=10000): - time_list = [] - # Read config file - config_file = f"./configs/{exp}.json" - sweeper = Sweeper(config_file) - # Read a list of logs - for i in range(runs * sweeper.config_dicts["num_combinations"]): - log_file = f"./logs/{exp}/{i+1}/{file_name}" + f.seek(-max_line_length, os.SEEK_END) + except IOError: + # either file is too small, or too 
many lines requested + f.seek(0) + last_line = f.readlines()[-1] + # Get time info in last line try: - with open(log_file, "r") as f: - # Get last line - try: - f.seek(-max_line_length, os.SEEK_END) - except IOError: - # either file is too small, or too many lines requested - f.seek(0) - last_line = f.readlines()[-1] - # Get time info in last line - try: - t = float(last_line.split(" ")[-2]) - time_list.append(t) - except Exception: - print("No time info in file: " + log_file) - continue - except Exception: - continue - - if len(time_list) > 0: - time_list = np.array(time_list) - print(f"{exp} max time: {np.max(time_list):.4f} minutes") - print(f"{exp} mean time: {np.mean(time_list):.4f} minutes") - print(f"{exp} min time: {np.min(time_list):.4f} minutes") - - # Plot histogram of time distribution - from utils.helper import make_dir - - make_dir(f"./logs/{exp}/0/") - num, bins, patches = plt.hist(time_list, nbins) - plt.xlabel("Time (min)") - plt.ylabel("Counts in the bin") - plt.savefig(f"./logs/{exp}/0/time_info.png") - # plt.show() - plt.clf() # clear figure - plt.cla() # clear axis - plt.close() # close window - else: - print(f"{exp}: no time info!") - - -def memory_info(exp, file_name="log.txt", runs=1, nbins=10, max_line_length=10000): - mem_list = [] - # Read config file - config_file = f"./configs/{exp}.json" - sweeper = Sweeper(config_file) - # Read a list of logs - for i in range(runs * sweeper.config_dicts["num_combinations"]): - log_file = f"./logs/{exp}/{i+1}/{file_name}" + t = float(last_line.split(' ')[-2]) + except: + print(i+1, end=', ') + continue + except: + print(i+1, end=', ') + continue + print() + + +def time_info(exp, file_name='log.txt', runs=1, nbins=10, max_line_length=10000): + time_list = [] + # Read config file + config_file = f'./configs/{exp}.json' + sweeper = Sweeper(config_file) + # Read a list of logs + for i in range(runs * sweeper.config_dicts['num_combinations']): + log_file = f'./logs/{exp}/{i+1}/{file_name}' + try: + with open(log_file, 'r') as f: + # Get last line try: - with open(log_file, "r") as f: - # Get last line - try: - f.seek(-max_line_length, os.SEEK_END) - except IOError: - # either file is too small, or too many lines requested - f.seek(0) - last_second_line = f.readlines()[-2] - # Get memory info in last line - try: - m = float(last_second_line.split(" ")[-2]) - mem_list.append(m) - except Exception: - print("No memory info in file: " + log_file) - continue - except Exception: - continue - - if len(mem_list) > 0: - mem_list = np.array(mem_list) - print(f"{exp} max memory: {np.max(mem_list):.2f} MB") - print(f"{exp} mean memory: {np.mean(mem_list):.2f} MB") - print(f"{exp} min memory: {np.min(mem_list):.2f} MB") - - # Plot histogram of time distribution - from utils.helper import make_dir - - make_dir(f"./logs/{exp}/0/") - num, bins, patches = plt.hist(mem_list, nbins) - plt.xlabel("Memory (MB)") - plt.ylabel("Counts in the bin") - plt.savefig(f"./logs/{exp}/0/memory_info.png") - # plt.show() - plt.clf() # clear figure - plt.cla() # clear axis - plt.close() # close window - else: - print(f"{exp}: no memory info!") - + f.seek(-max_line_length, os.SEEK_END) + except IOError: + # either file is too small, or too many lines requested + f.seek(0) + last_line = f.readlines()[-1] + # Get time info in last line + try: + t = float(last_line.split(' ')[-2]) + time_list.append(t) + except: + print('No time info in file: '+log_file) + continue + except: + continue + + if len(time_list) > 0: + time_list = np.array(time_list) + print(f'{exp} max time: 
{np.max(time_list):.2f} minutes') + print(f'{exp} mean time: {np.mean(time_list):.2f} minutes') + print(f'{exp} min time: {np.min(time_list):.2f} minutes') + + # Plot histogram of time distribution + from utils.helper import make_dir + make_dir(f'./logs/{exp}/0/') + num, bins, patches = plt.hist(time_list, nbins) + plt.xlabel('Time (min)') + plt.ylabel('Counts in the bin') + plt.savefig(f'./logs/{exp}/0/time_info.png') + # plt.show() + plt.clf() # clear figure + plt.cla() # clear axis + plt.close() # close window + else: + print(f'{exp}: no time info!') + +def memory_info(exp, file_name='log.txt', runs=1, nbins=10, max_line_length=10000): + mem_list = [] + # Read config file + config_file = f'./configs/{exp}.json' + sweeper = Sweeper(config_file) + # Read a list of logs + for i in range(runs * sweeper.config_dicts['num_combinations']): + log_file = f'./logs/{exp}/{i+1}/{file_name}' + try: + with open(log_file, 'r') as f: + # Get the second-to-last line + try: + f.seek(-max_line_length, os.SEEK_END) + except IOError: + # either file is too small, or too many lines requested + f.seek(0) + last_second_line = f.readlines()[-2] + # Get memory info in the second-to-last line + try: + m = float(last_second_line.split(' ')[-2]) + mem_list.append(m) + except: + print('No memory info in file: '+log_file) + continue + except: + continue + + if len(mem_list) > 0: + mem_list = np.array(mem_list) + print(f'{exp} max memory: {np.max(mem_list):.2f} MB') + print(f'{exp} mean memory: {np.mean(mem_list):.2f} MB') + print(f'{exp} min memory: {np.min(mem_list):.2f} MB') + + # Plot histogram of memory distribution + from utils.helper import make_dir + make_dir(f'./logs/{exp}/0/') + num, bins, patches = plt.hist(mem_list, nbins) + plt.xlabel('Memory (MB)') + plt.ylabel('Counts in the bin') + plt.savefig(f'./logs/{exp}/0/memory_info.png') + # plt.show() + plt.clf() # clear figure + plt.cla() # clear axis + plt.close() # close window + else: + print(f'{exp}: no memory info!') if __name__ == "__main__": - for agent_config in os.listdir("./configs/"): - if ".json" not in agent_config: - continue - config_file = os.path.join("./configs/", agent_config) - sweeper = Sweeper(config_file) - # sweeper.print_config_dict(sweeper.config_dicts) - # sweeper.print_config_dict(sweeper.generate_config_for_idx(213)) - print( - f"The number of total combinations in {agent_config}:", - sweeper.config_dicts["num_combinations"], - ) + for agent_config in os.listdir('./configs/'): + if '.json' not in agent_config: + continue + config_file = os.path.join('./configs/', agent_config) + sweeper = Sweeper(config_file) + # sweeper.print_config_dict(sweeper.config_dicts) + # sweeper.print_config_dict(sweeper.generate_config_for_idx(213)) + print(f'Number of total combinations in {agent_config}:', sweeper.config_dicts['num_combinations']) \ No newline at end of file
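
Usage sketch (illustrative, not part of the patch): the snippet below shows how the refactored Sweeper and the log-inspection helpers above are intended to be combined. The experiment name 'my_exp' is a placeholder for any JSON sweep config under ./configs/; everything else uses only functions and attributes defined in utils/sweeper.py.

# Hypothetical end-to-end use of utils/sweeper.py; 'my_exp' is a placeholder experiment name.
from utils.sweeper import Sweeper, unfinished_index, time_info, memory_info

exp = 'my_exp'
sweeper = Sweeper(f'./configs/{exp}.json')
total = sweeper.config_dicts['num_combinations']  # total number of hyper-parameter combinations
cfg = sweeper.generate_config_for_idx(1)          # first combination; indexes start at 1, not 0
sweeper.print_config_dict(cfg)
cfg_last = sweeper.generate_config_for_idx(-1)    # negative indexes count back from the last combination

# After runs finish under ./logs/my_exp/<idx>/log.txt, list unfinished config indexes
# and plot run-time / memory histograms into ./logs/my_exp/0/.
unfinished_index(exp, runs=1)
time_info(exp, runs=1, nbins=10)
memory_info(exp, runs=1, nbins=10)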