From d9145e602dddcb911f20fee4987c98487ca485f0 Mon Sep 17 00:00:00 2001 From: qlan3 Date: Mon, 23 May 2022 21:07:39 -0600 Subject: [PATCH] add MeDQN --- .gitignore | 6 +- README.md | 64 ++++++------ agents/AveragedDQN.py | 6 +- agents/BootstrappedDQN.py | 3 - agents/DQN.py | 6 +- agents/MaxminDQN.py | 5 +- agents/MeDQN_Real.py | 55 +++++++++++ agents/MeDQN_Uniform.py | 66 +++++++++++++ agents/RepOffPG.py | 47 --------- agents/RepOnPG.py | 97 ------------------- agents/SAC.py | 4 +- agents/VanillaDQN.py | 18 ++-- agents/__init__.py | 4 +- analysis.py | 13 +-- components/exploration.py | 8 +- components/network.py | 72 ++++++++++++-- components/normalizer.py | 12 +++ components/replay.py | 49 ++++++++-- configs/MERL_acrobot_dqn.json | 39 ++++++++ configs/MERL_acrobot_medqn.json | 53 ++++++++++ configs/MERL_catcher_dqn.json | 39 ++++++++ configs/MERL_catcher_medqn.json | 53 ++++++++++ configs/MERL_copter_dqn.json | 39 ++++++++ configs/MERL_copter_medqn.json | 53 ++++++++++ configs/MERL_mc_dqn.json | 39 ++++++++ configs/MERL_mc_dqn_small.json | 39 ++++++++ configs/MERL_mc_medqn.json | 53 ++++++++++ configs/MERL_mc_medqn_lambda.json | 47 +++++++++ configs/MERL_minatar_dqn.json | 45 +++++++++ configs/MERL_minatar_medqn_real.json | 53 ++++++++++ configs/MERL_minatar_medqn_uniform.json | 52 ++++++++++ configs/{catcher.json => Maxmin_catcher.json} | 6 +- configs/{copter.json => Maxmin_copter.json} | 6 +- configs/{lunar.json => Maxmin_lunar.json} | 6 +- configs/{minatar.json => Maxmin_minatar.json} | 6 +- configs/{mujoco_rpg.json => RPG.json} | 2 +- configs/mujoco_ppo.json | 35 ------- configs/mujoco_reinforce.json | 28 ------ configs/mujoco_repoffpg.json | 34 ------- copyfile.sh | 8 ++ find_config.py | 37 +++++-- main.py | 2 +- plot.py | 31 ++++-- procfile | 1 + requirements.txt | 18 ++-- run.py | 58 +++++++++-- run.sh | 2 +- sbatch.sh | 11 ++- unfinish_job.py | 48 +++++++++ utils/plotter.py | 4 + 50 files changed, 1109 insertions(+), 373 deletions(-) create mode 100644 agents/MeDQN_Real.py create mode 100644 agents/MeDQN_Uniform.py delete mode 100644 agents/RepOffPG.py delete mode 100644 agents/RepOnPG.py create mode 100644 configs/MERL_acrobot_dqn.json create mode 100644 configs/MERL_acrobot_medqn.json create mode 100644 configs/MERL_catcher_dqn.json create mode 100644 configs/MERL_catcher_medqn.json create mode 100644 configs/MERL_copter_dqn.json create mode 100644 configs/MERL_copter_medqn.json create mode 100644 configs/MERL_mc_dqn.json create mode 100644 configs/MERL_mc_dqn_small.json create mode 100644 configs/MERL_mc_medqn.json create mode 100644 configs/MERL_mc_medqn_lambda.json create mode 100644 configs/MERL_minatar_dqn.json create mode 100644 configs/MERL_minatar_medqn_real.json create mode 100644 configs/MERL_minatar_medqn_uniform.json rename configs/{catcher.json => Maxmin_catcher.json} (90%) rename configs/{copter.json => Maxmin_copter.json} (90%) rename configs/{lunar.json => Maxmin_lunar.json} (90%) rename configs/{minatar.json => Maxmin_minatar.json} (91%) rename configs/{mujoco_rpg.json => RPG.json} (95%) delete mode 100755 configs/mujoco_ppo.json delete mode 100755 configs/mujoco_reinforce.json delete mode 100755 configs/mujoco_repoffpg.json create mode 100644 copyfile.sh create mode 100644 procfile create mode 100644 unfinish_job.py diff --git a/.gitignore b/.gitignore index 85ee62d..07ff7ec 100644 --- a/.gitignore +++ b/.gitignore @@ -2,10 +2,8 @@ # Edit at https://www.gitignore.io/?templates=python,windows,visualstudiocode # My ignores -logs* +*logs* logfile -procfile -*figure* *output* *DS_Store* @@ -170,4 +168,4 @@ $RECYCLE.BIN/ # Windows shortcuts *.lnk -# End of https://www.gitignore.io/api/python,windows,visualstudiocode +# End of https://www.gitignore.io/api/python,windows,visualstudiocode \ No newline at end of file diff --git a/README.md b/README.md index a41c481..947ca16 100755 --- a/README.md +++ b/README.md @@ -8,18 +8,19 @@ Explorer is a PyTorch reinforcement learning framework for **exploring** new ide - Vanilla Deep Q-learning (VanillaDQN): No target network. - [Deep Q-Learning (DQN)](https://users.cs.duke.edu/~pdinesh/sources/MnihEtAlHassibis15NatureControlDeepRL.pdf) - [Double Deep Q-learning (DDQN)](https://arxiv.org/pdf/1509.06461.pdf) -- [Maxmin Deep Q-learning (MaxminDQN)](https://openreview.net/pdf?id=Bkg0u3Etwr) +- [Maxmin Deep Q-learning (MaxminDQN)](https://arxiv.org/pdf/2002.06487.pdf) - [Averaged Deep Q-learning (AveragedDQN)](https://arxiv.org/pdf/1611.01929.pdf) - [Ensemble Deep Q-learning (EnsembleDQN)](https://arxiv.org/pdf/1611.01929.pdf) -- [Bootstrapped Deep Q-learning (BootstrappedDQN)](https://arxiv.org/abs/1602.04621) -- [NoisyNet Deep Q-learning (NoisyNetDQN)](https://arxiv.org/abs/1706.10295) +- [Bootstrapped Deep Q-learning (BootstrappedDQN)](https://arxiv.org/pdf/1602.04621.pdf) +- [NoisyNet Deep Q-learning (NoisyNetDQN)](https://arxiv.org/pdf/1706.10295.pdf) - [REINFORCE](http://incompleteideas.net/book/RLbook2020.pdf) - [Actor-Critic](http://incompleteideas.net/book/RLbook2020.pdf) - [Proximal Policy Optimisation (PPO)](https://arxiv.org/pdf/1707.06347.pdf) - [Soft Actor-Critic (SAC)](https://arxiv.org/pdf/1812.05905.pdf) - [Deep Deterministic Policy Gradients (DDPG)](https://arxiv.org/pdf/1509.02971.pdf) - [Twin Delayed Deep Deterministic Policy Gradients (TD3)](https://arxiv.org/pdf/1802.09477.pdf) -- [Reward Policy Gradient (RPG)](https://arxiv.org/abs/2103.05147) +- [Reward Policy Gradient (RPG)](https://arxiv.org/pdf/2103.05147.pdf) +- [Memory-efficient Deep Q-learning (MeDQN)](https://arxiv.org/pdf/2205.10868.pdf) ## To do list @@ -33,16 +34,14 @@ Explorer is a PyTorch reinforcement learning framework for **exploring** new ide | ├── DQN | | ├── DDQN | | ├── NoisyNetDQN - | | └── BootstrappedDQN + | | ├── BootstrappedDQN + | | └── MeDQN_Uniform, MeDQN_Real | ├── Maxmin DQN ── Ensemble DQN | └── Averaged DQN └── REINFORCE ├── Actor-Critic - | ├── PPO ── RPG - | └── RepOnPG (experimental) - └── SAC ── DDPG - ├── TD3 - └── RepOffPG (experimental) + | └── PPO ── RPG + └── SAC ── DDPG ── TD3 ## Requirements @@ -51,8 +50,12 @@ Explorer is a PyTorch reinforcement learning framework for **exploring** new ide - [PyTorch](https://pytorch.org/) - [Gym && Gym Games](https://github.com/qlan3/gym-games): You may only install part of Gym (`classic_control, box2d`) by command `pip install 'gym[classic_control, box2d]'`. - Optional: - - [Gym Atari](https://github.com/openai/gym/blob/master/docs/environments.md#atari) - - [Gym Mujoco](https://github.com/openai/gym/blob/master/docs/environments.md#mujoco) + - [Gym Atari](https://www.gymlibrary.ml/environments/atari/): `pip install gym[atari,accept-rom-license]` + - [Gym Mujoco](https://www.gymlibrary.ml/environments/mujoco/): + - Download MuJoCo version 1.50 from [MuJoCo website](https://www.roboti.us/download.html). + - Unzip the downloaded `mjpro150` directory into `~/.mujoco/mjpro150`, and place the activation key (the `mjkey.txt` file downloaded from [here](https://www.roboti.us/license.html)) at `~/.mujoco/mjkey.txt`. + - Install [mujoco-py](https://github.com/openai/mujoco-py): `pip install 'mujoco-py<1.50.2,>=1.50.1'` + - Install gym[mujoco]: `pip install gym[mujoco]` - [PyBullet](https://pybullet.org/): `pip install pybullet` - [DeepMind Control Suite](https://github.com/denisyarats/dmc2gym): `pip install git+git://github.com/denisyarats/dmc2gym.git` - Others: Please check `requirements.txt`. @@ -62,76 +65,77 @@ Explorer is a PyTorch reinforcement learning framework for **exploring** new ide ### Train && Test -All hyperparameters including parameters for grid search are stored in a configuration file in directory `configs`. To run an experiment, a configuration index is first used to generate a configuration dict corresponding to this specific configuration index. Then we run an experiment defined by this configuration dict. All results including log files and the model file are saved in directory `logs`. Please refer to the code for details. +All hyperparameters including parameters for grid search are stored in a configuration file in directory `configs`. To run an experiment, a configuration index is first used to generate a configuration dict corresponding to this specific configuration index. Then we run an experiment defined by this configuration dict. All results including log files are saved in directory `logs`. Please refer to the code for details. -For example, run the experiment with configuration file `catcher.json` and configuration index `1`: +For example, run the experiment with configuration file `RPG.json` and configuration index `1`: -```python main.py --config_file ./configs/catcher.json --config_idx 1``` +```python main.py --config_file ./configs/RPG.json --config_idx 1``` The models are tested for one episode after every `test_per_episodes` training episodes which can be set in the configuration file. ### Grid Search (Optional) -First, we calculate the number of total combinations in a configuration file (e.g. `catcher.json`): +First, we calculate the number of total combinations in a configuration file (e.g. `RPG.json`): `python utils/sweeper.py` The output will be: -`Number of total combinations in catcher.json: 90` +`Number of total combinations in RPG.json: 12` -Then we run through all configuration indexes from `1` to `90`. The simplest way is a bash script: +Then we run through all configuration indexes from `1` to `12`. The simplest way is using a bash script: ``` bash -for index in {1..90} +for index in {1..12} do - python main.py --config_file ./configs/catcher.json --config_idx $index + python main.py --config_file ./configs/RPG.json --config_idx $index done ``` [Parallel](https://www.gnu.org/software/parallel/) is usually a better choice to schedule a large number of jobs: ``` bash -parallel --eta --ungroup python main.py --config_file ./configs/catcher.json --config_idx {1} ::: $(seq 1 90) +parallel --eta --ungroup python main.py --config_file ./configs/RPG.json --config_idx {1} ::: $(seq 1 12) ``` -Any configuration index that has the same remainder (divided by the number of total combinations) should has the same configuration dict. So for multiple runs, we just need to add the number of total combinations to the configuration index. For example, 5 runs for configuration index `1`: +Any configuration index that has the same remainder (divided by the number of total combinations) should have the same configuration dict. So for multiple runs, we just need to add the number of total combinations to the configuration index. For example, 5 runs for configuration index `1`: ``` -for index in 1 91 181 271 361 +for index in 1 13 25 37 49 do - python main.py --config_file ./configs/catcher.json --config_idx $index + python main.py --config_file ./configs/RPG.json --config_idx $index done ``` Or a simpler way: ``` -parallel --eta --ungroup python main.py --config_file ./configs/catcher.json --config_idx {1} ::: $(seq 1 90 450) +parallel --eta --ungroup python main.py --config_file ./configs/RPG.json --config_idx {1} ::: $(seq 1 12 60) ``` ### Analysis (Optional) -To analysis the experimental results, just run: +To analyze the experimental results, just run: `python analysis.py` -Inside `analysis.py`, `unfinished_index` will print out the configuration indexes of unfinished jobs based on the existence of the result file. `memory_info` will print out the memory usage information and generate a histogram to show the distribution of memory usages in directory `logs/catcher/0`. Similarly, `time_info` will print out the time information and generate a histogram to show the distribution of time in directory `logs/catcher/0`. Finally, `analyze` will generate `csv` files that store training and test results. More functions are available in `utils/plotter.py`. +Inside `analysis.py`, `unfinished_index` will print out the configuration indexes of unfinished jobs based on the existence of the result file. `memory_info` will print out the memory usage information and generate a histogram to show the distribution of memory usages in directory `logs/RPG/0`. Similarly, `time_info` will print out the time information and generate a histogram to show the distribution of time in directory `logs/RPG/0`. Finally, `analyze` will generate `csv` files that store training and test results. Please check `analysis.py` for more details. More functions are available in `utils/plotter.py`. Enjoy! ## Code of My Papers -- **Qingfeng Lan**, Yangchen Pan, Alona Fyshe, Martha White. **Maxmin Q-learning: Controlling the Estimation Bias of Q-learning.** ICLR, 2020. **(Poster)** [[paper]](https://openreview.net/pdf?id=Bkg0u3Etwr) [[code]](https://github.com/qlan3/Explorer/releases/tag/maxmin1.0) [[video]](https://iclr.cc/virtual/poster_Bkg0u3Etwr.html) +- **Qingfeng Lan**, Yangchen Pan, Alona Fyshe, Martha White. **Maxmin Q-learning: Controlling the Estimation Bias of Q-learning.** ICLR, 2020. **(Poster)** [[paper]](https://openreview.net/pdf?id=Bkg0u3Etwr) [[code]](https://github.com/qlan3/Explorer/releases/tag/maxmin1.0) -- **Qingfeng Lan**, Samuele Tosatto, Homayoon Farrahi, A. Rupam Mahmood. **Model-free Policy Learning with Reward Gradients.** AISTATS, 2022. **(Poster)** [[paper]](https://arxiv.org/abs/2103.05147) [[code]](https://github.com/qlan3/Explorer/tree/RPG) +- **Qingfeng Lan**, Samuele Tosatto, Homayoon Farrahi, A. Rupam Mahmood. **Model-free Policy Learning with Reward Gradients.** AISTATS, 2022. **(Poster)** [[paper]](https://arxiv.org/pdf/2103.05147.pdf) [[code]](https://github.com/qlan3/Explorer/tree/RPG) +- **Qingfeng Lan**, Yangchen Pan, Jun Luo, A. Rupam Mahmood. **Memory-efficient Reinforcement Learning with Knowledge Consolidation.** Arxiv [[paper]](https://arxiv.org/pdf/2205.10868.pdf) [[code]](https://github.com/qlan3/Explorer/) ## Cite -If you find this repo useful to your research, please cite my paper if it is related. Otherwisee, please use this bibtex to cite this repo +If you find this repo useful to your research, please cite my paper if related. Otherwise, please cite this repo: ~~~bibtex @misc{Explorer, diff --git a/agents/AveragedDQN.py b/agents/AveragedDQN.py index b1b2ad2..22df508 100644 --- a/agents/AveragedDQN.py +++ b/agents/AveragedDQN.py @@ -17,10 +17,8 @@ def __init__(self, cfg): self.Q_net_target[i].eval() self.update_target_net_index = 0 - def learn(self): - super().learn() - # Update target network - if (self.step_count // self.cfg['network_update_frequency']) % self.cfg['target_network_update_frequency'] == 0: + def update_target_net(self): + if self.step_count % self.cfg['target_network_update_steps'] == 0: self.Q_net_target[self.update_target_net_index].load_state_dict(self.Q_net[self.update_Q_net_index].state_dict()) self.update_target_net_index = (self.update_target_net_index + 1) % self.k diff --git a/agents/BootstrappedDQN.py b/agents/BootstrappedDQN.py index 10ddbc1..797d739 100644 --- a/agents/BootstrappedDQN.py +++ b/agents/BootstrappedDQN.py @@ -62,9 +62,6 @@ def learn(self): if self.gradient_clip > 0: nn.utils.clip_grad_norm_(self.Q_net[0].parameters(), self.gradient_clip) self.optimizer[0].step() - # Update target network - if (self.step_count // self.cfg['network_update_frequency']) % self.cfg['target_network_update_frequency'] == 0: - self.Q_net_target[0].load_state_dict(self.Q_net[0].state_dict()) if self.show_tb: self.logger.add_scalar(f'Loss', loss.item(), self.step_count) diff --git a/agents/DQN.py b/agents/DQN.py index 9acf9f3..d65f738 100644 --- a/agents/DQN.py +++ b/agents/DQN.py @@ -14,10 +14,8 @@ def __init__(self, cfg): self.Q_net_target[0].load_state_dict(self.Q_net[0].state_dict()) self.Q_net_target[0].eval() - def learn(self): - super().learn() - # Update target network - if (self.step_count // self.cfg['network_update_frequency']) % self.cfg['target_network_update_frequency'] == 0: + def update_target_net(self): + if self.step_count % self.cfg['target_network_update_steps'] == 0: self.Q_net_target[self.update_Q_net_index].load_state_dict(self.Q_net[self.update_Q_net_index].state_dict()) def compute_q_target(self, batch): diff --git a/agents/MaxminDQN.py b/agents/MaxminDQN.py index 6162002..8f3722a 100644 --- a/agents/MaxminDQN.py +++ b/agents/MaxminDQN.py @@ -27,8 +27,9 @@ def learn(self): # Choose a Q_net to udpate self.update_Q_net_index = np.random.choice(list(range(self.k))) super().learn() - # Update target network - if (self.step_count // self.cfg['network_update_frequency']) % self.cfg['target_network_update_frequency'] == 0: + + def update_target_net(self): + if self.step_count % self.cfg['target_network_update_steps'] == 0: for i in range(self.k): self.Q_net_target[i].load_state_dict(self.Q_net[i].state_dict()) diff --git a/agents/MeDQN_Real.py b/agents/MeDQN_Real.py new file mode 100644 index 0000000..d5397e0 --- /dev/null +++ b/agents/MeDQN_Real.py @@ -0,0 +1,55 @@ +from agents.DQN import * + + +class MeDQN_Real(DQN): + ''' + Implementation of MeDQN_Real (Memory-efficient DQN with real state sampling) + - Consolidatie knowledge from target Q-network to current Q-network. + - A state replay buffer is applied. + - A tiny (e.g., one mini-batch size) experience replay buffer is used in practice. + ''' + def __init__(self, cfg): + # Set the consolidation batch size + if 'consod_batch_size' not in cfg['agent'].keys(): + cfg['agent']['consod_batch_size'] = cfg['batch_size'] + super().__init__(cfg) + self.replay = getattr(components.replay, cfg['memory_type'])(cfg['memory_size'], keys=['state', 'action', 'next_state', 'reward', 'mask']) + # Set real state sampler for knowledge consolidation + self.state_sampler = getattr(components.replay, cfg['memory_type'])(cfg['agent']['consod_size'], keys=['state']) + # Set consolidation regularization strategy + epsilon = { + 'steps': float(cfg['train_steps']), + 'start': cfg['agent']['consod_start'], + 'end': cfg['agent']['consod_end'] + } + self.consolidate = getattr(components.exploration, 'LinearEpsilonGreedy')(-1, epsilon) + + def save_experience(self): + super().save_experience() + self.state_sampler.add({'state': to_tensor(self.state['Train'], self.device)}) + + def learn(self): + mode = 'Train' + batch = self.replay.get(['state', 'action', 'reward', 'next_state', 'mask'], self.cfg['memory_size']) + q_target = self.compute_q_target(batch) + lamda = self.consolidate.get_epsilon(self.step_count) # Compute consolidation regularization parameter + for _ in range(self.cfg['agent']['consod_epoch']): + q = self.compute_q(batch) + sample_state = self.state_sampler.sample(['state'], self.cfg['agent']['consod_batch_size']).state + # Compute loss + loss = self.loss(q, q_target) + loss += lamda * self.consolidation_loss(sample_state) + # Take an optimization step + self.optimizer[0].zero_grad() + loss.backward() + if self.gradient_clip > 0: + nn.utils.clip_grad_norm_(self.Q_net[0].parameters(), self.gradient_clip) + self.optimizer[0].step() + if self.show_tb: + self.logger.add_scalar(f'Loss', loss.item(), self.step_count) + + def consolidation_loss(self, state): + q_values = self.Q_net[0](state).squeeze() + q_target_values = self.Q_net_target[0](state).squeeze().detach() + loss = nn.MSELoss(reduction='mean')(q_values, q_target_values) + return loss \ No newline at end of file diff --git a/agents/MeDQN_Uniform.py b/agents/MeDQN_Uniform.py new file mode 100644 index 0000000..bde93a5 --- /dev/null +++ b/agents/MeDQN_Uniform.py @@ -0,0 +1,66 @@ +from agents.DQN import * + + +class MeDQN_Uniform(DQN): + ''' + Implementation of MeDQN_Uniform (Memory-efficient DQN with uniform state sampling) + - Consolidatie knowledge from target Q-network to current Q-network. + - The bounds of state space are updated with real states frequently. + - A tiny (e.g., one mini-batch size) experience replay buffer is used in practice. + ''' + def __init__(self, cfg): + # Set the consolidation batch size + if 'consod_batch_size' not in cfg['agent'].keys(): + cfg['agent']['consod_batch_size'] = cfg['batch_size'] + super().__init__(cfg) + self.replay = getattr(components.replay, cfg['memory_type'])(cfg['memory_size'], keys=['state', 'action', 'next_state', 'reward', 'mask']) + # Set uniform state sampler for knowledge consolidation + if 'MinAtar' in self.env_name: + self.state_sampler = DiscreteUniformSampler( + shape=self.env['Train'].observation_space.shape, + normalizer=self.state_normalizer, + device=self.device + ) + else: + self.state_sampler = ContinousUniformSampler( + shape=self.env['Train'].observation_space.shape, + normalizer=self.state_normalizer, + device=self.device + ) + # Set consolidation regularization strategy + epsilon = { + 'steps': float(cfg['train_steps']), + 'start': cfg['agent']['consod_start'], + 'end': cfg['agent']['consod_end'] + } + self.consolidate = getattr(components.exploration, 'LinearEpsilonGreedy')(-1, epsilon) + + def save_experience(self): + super().save_experience() + self.state_sampler.update_bound(self.original_state) + + def learn(self): + mode = 'Train' + batch = self.replay.get(['state', 'action', 'reward', 'next_state', 'mask'], self.cfg['memory_size']) + q_target = self.compute_q_target(batch) + lamda = self.consolidate.get_epsilon(self.step_count) # Compute consolidation regularization parameter + for _ in range(self.cfg['agent']['consod_epoch']): + q = self.compute_q(batch) + sample_state = self.state_sampler.sample(self.cfg['agent']['consod_batch_size']) + # Compute loss + loss = self.loss(q, q_target) + loss += lamda * self.consolidation_loss(sample_state) + # Take an optimization step + self.optimizer[0].zero_grad() + loss.backward() + if self.gradient_clip > 0: + nn.utils.clip_grad_norm_(self.Q_net[0].parameters(), self.gradient_clip) + self.optimizer[0].step() + if self.show_tb: + self.logger.add_scalar(f'Loss', loss.item(), self.step_count) + + def consolidation_loss(self, state): + q_values = self.Q_net[0](state).squeeze() + q_target_values = self.Q_net_target[0](state).squeeze().detach() + loss = nn.MSELoss(reduction='mean')(q_values, q_target_values) + return loss \ No newline at end of file diff --git a/agents/RepOffPG.py b/agents/RepOffPG.py deleted file mode 100644 index a9f5cad..0000000 --- a/agents/RepOffPG.py +++ /dev/null @@ -1,47 +0,0 @@ -from agents.DDPG import * - - -class RepOffPG(DDPG): - ''' - Implementation of RepOffPG (Reparameterization Off-Policy Gradient), almost the same as SVG(0). - ''' - def __init__(self, cfg): - super().__init__(cfg) - - def createNN(self, input_type): - # Set feature network - if input_type == 'pixel': - input_size = self.cfg['feature_dim'] - if 'MinAtar' in self.env_name: - feature_net = Conv2d_MinAtar(in_channels=self.env[mode].game.state_shape()[2], feature_dim=input_size) - else: - feature_net = Conv2d_Atari(in_channels=4, feature_dim=input_size) - elif input_type == 'feature': - input_size = self.state_size - feature_net = nn.Identity() - # Set actor network - assert self.action_type == 'CONTINUOUS', f"{self.cfg['agent']['name']} only supports continous action spaces." - actor_net = MLPStdGaussianActor(action_lim=self.action_lim, layer_dims=[input_size]+self.cfg['hidden_layers']+[2*self.action_size], hidden_act=self.cfg['hidden_act'], rsample=True) - # Set critic network - critic_net = MLPQCritic(layer_dims=[input_size+self.action_size]+self.cfg['hidden_layers']+[1], hidden_act=self.cfg['hidden_act'], output_act=self.cfg['output_act']) - # Set the model - NN = ActorQCriticNet(feature_net, actor_net, critic_net) - return NN - - def get_action(self, mode='Train'): - ''' - Pick an action from policy network - ''' - if self.step_count <= self.cfg['exploration_steps']: - prediction = {'action': torch.as_tensor(self.env[mode].action_space.sample())} - else: - deterministic = True if mode == 'Test' else False - state = to_tensor(self.state[mode], self.device) - prediction = self.network(state, deterministic=deterministic) - return prediction - - def compute_q_target(self, batch): - with torch.no_grad(): - q_next = self.network_target(batch.next_state, deterministic=True)['q'] - q_target = batch.reward + self.discount * batch.mask * q_next - return q_target \ No newline at end of file diff --git a/agents/RepOnPG.py b/agents/RepOnPG.py deleted file mode 100644 index b250be6..0000000 --- a/agents/RepOnPG.py +++ /dev/null @@ -1,97 +0,0 @@ -from agents.ActorCritic import * - - -class RepOnPG(ActorCritic): - ''' - Implementation of RepOnPG (Reparameterization On-Policy Gradient) - ''' - def __init__(self, cfg): - super().__init__(cfg) - # Set replay buffer - self.replay = FiniteReplay(self.steps_per_epoch+1, keys=['state', 'action', 'reward', 'mask', 'q', 'q_target']) - - def save_experience(self, prediction): - # Save state, action, reward, mask, q - mode = 'Train' - if self.reward[mode] is not None: - prediction = { - 'state': to_tensor(self.state[mode], self.device), - 'action': to_tensor(self.action[mode], self.device), - 'reward': to_tensor(self.reward[mode], self.device), - 'mask': to_tensor(1-self.done[mode], self.device), - 'q': prediction['prediction'] - } - self.replay.add(prediction) - else: - self.replay.add({'q': prediction['q']}) - - def createNN(self, input_type): - # Set feature network - if input_type == 'pixel': - input_size = self.cfg['feature_dim'] - if 'MinAtar' in self.env_name: - feature_net = Conv2d_MinAtar(in_channels=self.env[mode].game.state_shape()[2], feature_dim=input_size) - else: - feature_net = Conv2d_Atari(in_channels=4, feature_dim=input_size) - elif input_type == 'feature': - input_size = self.state_size - feature_net = nn.Identity() - # Set actor network - assert self.action_type == 'CONTINUOUS', f"{self.cfg['agent']['name']} only supports continous action spaces." - actor_net = MLPStdGaussianActor(action_lim=self.action_lim, layer_dims=[input_size]+self.cfg['hidden_layers']+[2*self.action_size], hidden_act=self.cfg['hidden_act'], rsample=True) - # Set critic network - critic_net = MLPQCritic(layer_dims=[input_size+self.action_size]+self.cfg['hidden_layers']+[1], hidden_act=self.cfg['hidden_act'], output_act=self.cfg['output_act']) - # Set the model - NN = ActorQCriticNet(feature_net, actor_net, critic_net) - return NN - - def get_action(self, mode='Train'): - ''' - Pick an action from policy network - ''' - state = to_tensor(self.state[mode], self.device) - deterministic = True if mode == 'Test' else False - prediction = self.network(state, deterministic=deterministic) - return prediction - - def learn(self): - mode = 'Train' - # Compute q target - for i in range(self.steps_per_epoch): - q_target = self.replay.reward[i] + self.discount * self.replay.mask[i] * self.replay.q[i+1] - self.replay.q_target[i] = q_target.detach() - # Get training data - batch = self.replay.get(['state', 'action', 'q', 'q_target'], self.steps_per_epoch) - # Take an optimization step for critic - critic_loss = self.compute_critic_loss(batch) - self.optimizer['critic'].zero_grad() - critic_loss.backward() - if self.gradient_clip > 0: - nn.utils.clip_grad_norm_(self.network.critic_params, self.gradient_clip) - self.optimizer['critic'].step() - # Freeze Q-networks to avoid computing gradients for them - for p in self.network.critic_net.parameters(): - p.requires_grad = False - # Take an optimization step for actor - actor_loss = self.compute_actor_loss(batch) - self.optimizer['actor'].zero_grad() - actor_loss.backward() - if self.gradient_clip > 0: - nn.utils.clip_grad_norm_(self.network.actor_params, self.gradient_clip) - self.optimizer['actor'].step() - # Unfreeze Q-networks - for p in self.network.critic_net.parameters(): - p.requires_grad = True - # Log - if self.show_tb: - self.logger.add_scalar(f'actor_loss', actor_loss.item(), self.step_count) - self.logger.add_scalar(f'critic_loss', critic_loss.item(), self.step_count) - - def compute_actor_loss(self, batch): - actor_loss = - batch.q.mean() - return actor_loss - - def compute_critic_loss(self, batch): - q = self.network.get_q(batch.state, batch.action) - critic_loss = (q - batch.q_target).pow(2).mean() - return critic_loss \ No newline at end of file diff --git a/agents/SAC.py b/agents/SAC.py index 9d614bf..df14c06 100644 --- a/agents/SAC.py +++ b/agents/SAC.py @@ -106,7 +106,7 @@ def time_to_learn(self): - The agent is not on exploration stage - It is time to update network """ - if self.step_count > self.cfg['exploration_steps'] and self.step_count % self.cfg['network_update_frequency'] == 0: + if self.step_count > self.cfg['exploration_steps'] and self.step_count % self.cfg['network_update_steps'] == 0: return True else: return False @@ -122,7 +122,7 @@ def learn(self): nn.utils.clip_grad_norm_(self.network.critic_params, self.gradient_clip) self.optimizer['critic'].step() # Take an optimization step for actor - if (self.step_count // self.cfg['network_update_frequency']) % self.cfg['actor_update_frequency'] == 0: + if (self.step_count // self.cfg['network_update_steps']) % self.cfg['actor_update_frequency'] == 0: # Freeze Q-networks to avoid computing gradients for them for p in self.network.critic_net.parameters(): p.requires_grad = False diff --git a/agents/VanillaDQN.py b/agents/VanillaDQN.py index 7c8bdd8..b69cb4e 100644 --- a/agents/VanillaDQN.py +++ b/agents/VanillaDQN.py @@ -70,8 +70,7 @@ def __init__(self, cfg): # Set loss function self.loss = getattr(torch.nn, cfg['loss'])(reduction='mean') # Set replay buffer - # self.replay = getattr(components.replay, cfg['memory_type'])(cfg['memory_size'], self.cfg['batch_size'], self.device) - self.replay = FiniteReplay(cfg['memory_size'], keys=['state', 'action', 'next_state', 'reward', 'mask']) + self.replay = getattr(components.replay, cfg['memory_type'])(cfg['memory_size'], keys=['state', 'action', 'next_state', 'reward', 'mask']) # Set log dict for key in ['state', 'next_state', 'action', 'reward', 'done', 'episode_return', 'episode_step_count']: setattr(self, key, {'Train': None, 'Test': None}) @@ -95,7 +94,8 @@ def createNN(self, input_type): def reset_game(self, mode): # Reset the game before a new episode - self.state[mode] = self.state_normalizer(self.env[mode].reset()) + self.original_state = self.env[mode].reset() # state before processed + self.state[mode] = self.state_normalizer(self.original_state) self.next_state[mode] = None self.action[mode] = None self.reward[mode] = None @@ -129,8 +129,8 @@ def run_episode(self, mode, render): if render: self.env[mode].render() # Take a step - self.next_state[mode], self.reward[mode], self.done[mode], _ = self.env[mode].step(self.action[mode]) - self.next_state[mode] = self.state_normalizer(self.next_state[mode]) + next_state, self.reward[mode], self.done[mode], _ = self.env[mode].step(self.action[mode]) + self.next_state[mode] = self.state_normalizer(next_state) self.reward[mode] = self.reward_normalizer(self.reward[mode]) self.episode_return[mode] += self.reward[mode] self.episode_step_count[mode] += 1 @@ -140,9 +140,12 @@ def run_episode(self, mode, render): # Update policy if self.time_to_learn(): self.learn() + # Update target Q network: used only in DQN variants + self.update_target_net() self.step_count += 1 # Update state self.state[mode] = self.next_state[mode] + self.original_state = next_state # End of one episode self.save_episode_result(mode) # Reset environment @@ -195,11 +198,14 @@ def time_to_learn(self): - The agent is not on exploration stage - There are enough experiences in replay buffer """ - if self.step_count > self.cfg['exploration_steps'] and self.step_count % self.cfg['network_update_frequency'] == 0: + if self.step_count > self.cfg['exploration_steps'] and self.step_count % self.cfg['network_update_steps'] == 0: return True else: return False + def update_target_net(self): + pass + def learn(self): mode = 'Train' batch = self.replay.sample(['state', 'action', 'reward', 'next_state', 'mask'], self.cfg['batch_size']) diff --git a/agents/__init__.py b/agents/__init__.py index d5bd96f..72d432c 100644 --- a/agents/__init__.py +++ b/agents/__init__.py @@ -7,6 +7,8 @@ from .AveragedDQN import AveragedDQN from .BootstrappedDQN import BootstrappedDQN from .NoisyNetDQN import NoisyNetDQN +from .MeDQN_Real import MeDQN_Real +from .MeDQN_Uniform import MeDQN_Uniform from .REINFORCE import REINFORCE from .ActorCritic import ActorCritic @@ -14,6 +16,4 @@ from .SAC import SAC from .DDPG import DDPG from .TD3 import TD3 -from .RepOffPG import RepOffPG -from .RepOnPG import RepOnPG from .RPG import RPG \ No newline at end of file diff --git a/analysis.py b/analysis.py index 9b69945..5a47a80 100755 --- a/analysis.py +++ b/analysis.py @@ -59,15 +59,16 @@ def analyze(exp, runs=1): 'ppo': [1, 2, 3, 4, 5, 6], 'rpg': [7, 8, 9, 10, 11, 12] } - if exp == 'rpg': + if exp == 'RPG': for i in range(6): for mode in ['Test']: - expIndexModeList = [['rpg', indexes['ppo'][i], mode], ['rpg', indexes['rpg'][i], mode]] + expIndexModeList = [['RPG', indexes['ppo'][i], mode], ['RPG', indexes['rpg'][i], mode]] plotter.plot_expIndexModeList(expIndexModeList, f'{mode}_{envs[i]}') if __name__ == "__main__": - unfinished_index('rpg', runs=30) - memory_info('rpg', runs=30) - time_info('rpg', runs=30) - analyze('rpg', runs=30) \ No newline at end of file + exp, runs = 'RPG', 30 + unfinished_index(exp, runs=runs) + memory_info(exp, runs=runs) + time_info(exp, runs=runs) + analyze(exp, runs=runs) \ No newline at end of file diff --git a/components/exploration.py b/components/exploration.py index e8833d2..9437a6b 100644 --- a/components/exploration.py +++ b/components/exploration.py @@ -47,6 +47,9 @@ def select_action(self, q_values, step_count): else: action = np.argmax(q_values) return action + + def get_epsilon(self, step_count): + return self.bound(self.start + step_count * self.inc, self.end) class ExponentialEpsilonGreedy(BaseExploration): @@ -70,4 +73,7 @@ def select_action(self, q_values, step_count): action = np.random.randint(0, len(q_values)) else: action = np.argmax(q_values) - return action \ No newline at end of file + return action + + def get_epsilon(self, step_count): + return self.bound(self.start * math.pow(self.decay, step_count), self.end) \ No newline at end of file diff --git a/components/network.py b/components/network.py index 00866df..ebb2b55 100644 --- a/components/network.py +++ b/components/network.py @@ -3,14 +3,18 @@ import torch.nn as nn import torch.nn.functional as F from torch.distributions import Categorical, Normal +# torch.autograd.set_detect_anomaly(True) activations = { 'Linear': nn.Identity(), 'ReLU': nn.ReLU(), + 'ELU': nn.ELU(), + 'Softplus': nn.Softplus(), 'LeakyReLU': nn.LeakyReLU(), 'Tanh': nn.Tanh(), 'Sigmoid': nn.Sigmoid(), + 'Hardsigmoid': nn.Hardsigmoid(), 'Softmax-1': nn.Softmax(dim=-1), 'Softmax0': nn.Softmax(dim=0), 'Softmax1': nn.Softmax(dim=1), @@ -94,7 +98,7 @@ def __init__(self, layer_dims, hidden_act='ReLU', output_act='Linear', init_type layer_init( nn.Linear(layer_dims[i], layer_dims[i+1], bias=True), init_type=init_type, - nonlinearity=act, + nonlinearity=act, w_scale=w_s ) ) @@ -102,9 +106,7 @@ def __init__(self, layer_dims, hidden_act='ReLU', output_act='Linear', init_type self.mlp = nn.Sequential(*layers) def forward(self, x): - for layer in self.mlp: - x = layer(x) - return x + return self.mlp(x) class NoisyMLP(nn.Module): @@ -122,9 +124,7 @@ def __init__(self, layer_dims, hidden_act='ReLU', output_act='Linear'): self.mlp = nn.Sequential(*layers) def forward(self, x): - for layer in self.mlp: - x = layer(x) - return x + return self.mlp(x) def reset_noise(self): for layer in self.mlp: @@ -132,6 +132,64 @@ def reset_noise(self): layer.reset_noise() +class Conv2dLayers(nn.Module): + ''' + Multiple Conv2d layers + ''' + def __init__(self, layer_dims, hidden_act='ReLU', output_act='Linear'): + super().__init__() + # Create layers + layers = [] + for i in range(len(layer_dims)-1): + layers.append( + layer_init( + nn.Conv2d(layer_dims[i], layer_dims[i+1], kernel_size=len(layer_dims)-i, stride=1), + nonlinearity=hidden_act + ) + ) + layers.append(activations[hidden_act]) + layers.append( + layer_init( + nn.Conv2d(layer_dims[-1], layer_dims[-1], kernel_size=1, stride=1), + nonlinearity=output_act + ) + ) + layers.append(activations[output_act]) + self.conv = nn.Sequential(*layers) + + def forward(self, x): + return self.conv(x) + + +class ConvTranspose2dLayers(nn.Module): + ''' + Multiple ConvTranspose2d layers + ''' + def __init__(self, layer_dims, hidden_act='ReLU', output_act='Sigmoid'): + super().__init__() + # Create layers + layers = [] + layers.append( + layer_init( + nn.ConvTranspose2d(layer_dims[0], layer_dims[0], kernel_size=1, stride=1), + nonlinearity=hidden_act + ) + ) + for i in range(len(layer_dims)-1): + layers.append(activations[hidden_act]) + layers.append( + layer_init( + nn.ConvTranspose2d(layer_dims[i], layer_dims[i+1], kernel_size=i+2, stride=1), + nonlinearity=hidden_act + ) + ) + layers.append(activations[output_act]) + self.conv = nn.Sequential(*layers) + + def forward(self, x): + return self.conv(x) + + class Conv2d_Atari(nn.Module): ''' 2D convolution neural network for Atari games diff --git a/components/normalizer.py b/components/normalizer.py index 90cc292..235f3e0 100644 --- a/components/normalizer.py +++ b/components/normalizer.py @@ -40,6 +40,18 @@ def __init__(self): RescaleNormalizer.__init__(self, 1.0 / 255) +class RescaleShiftNormalizer(BaseNormalizer): + def __init__(self, coef=2.0, bias=-1.0): + BaseNormalizer.__init__(self) + self.coef = coef + self.bias = bias + + def __call__(self, x): + if not isinstance(x, torch.Tensor): + x = np.asarray(x) + return self.coef * x + self.bias + + class SignNormalizer(BaseNormalizer): def __call__(self, x): return np.sign(x) diff --git a/components/replay.py b/components/replay.py index e3de808..04a33d5 100644 --- a/components/replay.py +++ b/components/replay.py @@ -39,7 +39,7 @@ def get(self, keys, data_size): class FiniteReplay(object): ''' - Finite replay buffer to store experiences + Finite replay buffer to store experiences: FIFO (first in, firt out) ''' def __init__(self, memory_size, keys=None): if keys is None: @@ -64,6 +64,8 @@ def add(self, data): self.full = True def get(self, keys, data_size, detach=False): + # Get first several samples (without replacement) + data_size = min(self.size(), data_size) data = [getattr(self, k)[:data_size] for k in keys] data = map(lambda x: torch.stack(x), data) if detach: @@ -72,12 +74,8 @@ def get(self, keys, data_size, detach=False): return Entry(*list(data)) def sample(self, keys, batch_size, detach=False): - ''' - if self.size() < batch_size: - return None - ''' + # Sampling with replacement idxs = np.random.randint(0, self.size(), size=batch_size) - # data = [getattr(self, k)[idxs] for k in keys] data = [[getattr(self, k)[idx] for idx in idxs] for k in keys] data = map(lambda x: torch.stack(x), data) if detach: @@ -98,4 +96,41 @@ def size(self): if self.full: return self.memory_size else: - return self.pos \ No newline at end of file + return self.pos + + +class ContinousUniformSampler(object): + ''' + A uniform sampler for continous space + ''' + def __init__(self, shape, normalizer, device): + self.shape = shape + self.normalizer = normalizer + self.device = device + self.reset() + + def reset(self): + self.low = np.inf * np.ones(self.shape) + self.high = -np.inf * np.ones(self.shape) + + def update_bound(self, data): + self.low = np.minimum(self.low, data) + self.high = np.maximum(self.high, data) + + def sample(self, batch_size): + data = np.random.uniform(low=self.low, high=self.high, size=tuple([batch_size]+list(self.shape))) + data = to_tensor(self.normalizer(data), self.device) + return data + + +class DiscreteUniformSampler(ContinousUniformSampler): + ''' + A uniform sampler for discrete space + ''' + def __init__(self, shape, normalizer, device): + super().__init__(shape, normalizer, device) + + def sample(self, batch_size): + data = np.random.randint(low=np.rint(self.low), high=np.rint(self.high)+1, size=tuple([batch_size]+list(self.shape))) + data = to_tensor(self.normalizer(data), self.device) + return data \ No newline at end of file diff --git a/configs/MERL_acrobot_dqn.json b/configs/MERL_acrobot_dqn.json new file mode 100644 index 0000000..9abbc3f --- /dev/null +++ b/configs/MERL_acrobot_dqn.json @@ -0,0 +1,39 @@ +{ + "env": [ + { + "name": ["Acrobot-v1"], + "max_episode_steps": [-1], + "input_type": ["feature"] + } + ], + "agent": [{"name": ["DQN"]}], + "train_steps": [1e5], + "test_per_episodes": [-1], + "device": ["cpu"], + "hidden_layers": [[32,32]], + "memory_type": ["FiniteReplay"], + "memory_size": [32, 1e4], + "exploration_type": ["LinearEpsilonGreedy"], + "exploration_steps": [1e3], + "epsilon_steps": [1e3], + "epsilon_start": [1.0], + "epsilon_end": [0.01], + "epsilon_decay": [0.999], + "loss": ["MSELoss"], + "optimizer": [ + { + "name": ["Adam"], + "kwargs": [{"lr": [1e-2, 3e-3, 1e-3, 3e-4, 1e-4]}] + } + ], + "batch_size": [32], + "display_interval": [50], + "rolling_score_window": [{"Train": [20], "Test": [5]}], + "discount": [0.99], + "seed": [1], + "show_tb": [false], + "gradient_clip": [5], + "target_network_update_steps": [100], + "network_update_steps": [1], + "generate_random_seed": [true] +} \ No newline at end of file diff --git a/configs/MERL_acrobot_medqn.json b/configs/MERL_acrobot_medqn.json new file mode 100644 index 0000000..c5c5b0f --- /dev/null +++ b/configs/MERL_acrobot_medqn.json @@ -0,0 +1,53 @@ +{ + "env": [ + { + "name": ["Acrobot-v1"], + "max_episode_steps": [-1], + "input_type": ["feature"] + } + ], + "agent": [ + { + "name": ["MeDQN_Real"], + "consod_start": [0.01], + "consod_end": [4, 2, 1], + "consod_epoch": [4, 2, 1], + "consod_size": [1e4] + }, + { + "name": ["MeDQN_Uniform"], + "consod_start": [0.01], + "consod_end": [4, 2, 1], + "consod_epoch": [4, 2, 1] + } + ], + "train_steps": [1e5], + "test_per_episodes": [-1], + "device": ["cpu"], + "hidden_layers": [[32,32]], + "memory_type": ["FiniteReplay"], + "memory_size": [1e4], + "exploration_type": ["LinearEpsilonGreedy"], + "exploration_steps": [1e3], + "epsilon_steps": [1e3], + "epsilon_start": [1.0], + "epsilon_end": [0.01], + "epsilon_decay": [0.999], + "loss": ["MSELoss"], + "optimizer": [ + { + "name": ["Adam"], + "kwargs": [{"lr": [1e-2, 3e-3, 1e-3, 3e-4, 1e-4]}] + } + ], + "batch_size": [32], + "display_interval": [50], + "rolling_score_window": [{"Train": [20], "Test": [5]}], + "discount": [0.99], + "seed": [1], + "show_tb": [false], + "gradient_clip": [5], + "target_network_update_steps": [100], + "network_update_steps": [1, 2, 4, 8], + "generate_random_seed": [true] +} \ No newline at end of file diff --git a/configs/MERL_catcher_dqn.json b/configs/MERL_catcher_dqn.json new file mode 100644 index 0000000..c5edd42 --- /dev/null +++ b/configs/MERL_catcher_dqn.json @@ -0,0 +1,39 @@ +{ + "env": [ + { + "name": ["Catcher-PLE-v0"], + "max_episode_steps": [2000], + "input_type": ["feature"] + } + ], + "agent": [{"name": ["DQN"]}], + "train_steps": [1.5e6], + "test_per_episodes": [-1], + "device": ["cpu"], + "hidden_layers": [[64,64]], + "memory_type": ["FiniteReplay"], + "memory_size": [32, 1e4], + "exploration_type": ["LinearEpsilonGreedy"], + "exploration_steps": [1e3], + "epsilon_steps": [1e3], + "epsilon_start": [1.0], + "epsilon_end": [0.01], + "epsilon_decay": [0.999], + "loss": ["MSELoss"], + "optimizer": [ + { + "name": ["RMSprop"], + "kwargs": [{"lr": [1e-3, 3e-4, 1e-4, 3e-5, 1e-5]}] + } + ], + "batch_size": [32], + "display_interval": [100], + "rolling_score_window": [{"Train": [100], "Test": [10]}], + "discount": [0.99], + "seed": [1], + "show_tb": [false], + "gradient_clip": [5], + "target_network_update_steps": [200], + "network_update_steps": [1], + "generate_random_seed": [true] +} \ No newline at end of file diff --git a/configs/MERL_catcher_medqn.json b/configs/MERL_catcher_medqn.json new file mode 100644 index 0000000..3dbff74 --- /dev/null +++ b/configs/MERL_catcher_medqn.json @@ -0,0 +1,53 @@ +{ + "env": [ + { + "name": ["Catcher-PLE-v0"], + "max_episode_steps": [2000], + "input_type": ["feature"] + } + ], + "agent": [ + { + "name": ["MeDQN_Real"], + "consod_start": [0.01], + "consod_end": [4, 2, 1], + "consod_epoch": [4, 2, 1], + "consod_size": [1e4] + }, + { + "name": ["MeDQN_Uniform"], + "consod_start": [0.01], + "consod_end": [4, 2, 1], + "consod_epoch": [4, 2, 1] + } + ], + "train_steps": [1.5e6], + "test_per_episodes": [-1], + "device": ["cpu"], + "hidden_layers": [[64,64]], + "memory_type": ["FiniteReplay"], + "memory_size": [32], + "exploration_type": ["LinearEpsilonGreedy"], + "exploration_steps": [1e3], + "epsilon_steps": [1e3], + "epsilon_start": [1.0], + "epsilon_end": [0.01], + "epsilon_decay": [0.999], + "loss": ["MSELoss"], + "optimizer": [ + { + "name": ["RMSprop"], + "kwargs": [{"lr": [1e-3, 3e-4, 1e-4, 3e-5, 1e-5]}] + } + ], + "batch_size": [32], + "display_interval": [100], + "rolling_score_window": [{"Train": [100], "Test": [10]}], + "discount": [0.99], + "seed": [1], + "show_tb": [false], + "gradient_clip": [5], + "target_network_update_steps": [200], + "network_update_steps": [1, 2, 4, 8], + "generate_random_seed": [true] +} \ No newline at end of file diff --git a/configs/MERL_copter_dqn.json b/configs/MERL_copter_dqn.json new file mode 100644 index 0000000..9a6d541 --- /dev/null +++ b/configs/MERL_copter_dqn.json @@ -0,0 +1,39 @@ +{ + "env": [ + { + "name": ["Pixelcopter-PLE-v0"], + "max_episode_steps": [500], + "input_type": ["feature"] + } + ], + "agent": [{"name": ["DQN"]}], + "train_steps": [2e6], + "test_per_episodes": [-1], + "device": ["cpu"], + "hidden_layers": [[64,64]], + "memory_type": ["FiniteReplay"], + "memory_size": [32, 1e4], + "exploration_type": ["LinearEpsilonGreedy"], + "exploration_steps": [1e3], + "epsilon_steps": [1e3], + "epsilon_start": [1.0], + "epsilon_end": [0.01], + "epsilon_decay": [0.999], + "loss": ["MSELoss"], + "optimizer": [ + { + "name": ["RMSprop"], + "kwargs": [{"lr": [1e-3, 3e-4, 1e-4, 3e-5, 1e-5]}] + } + ], + "batch_size": [32], + "display_interval": [100], + "rolling_score_window": [{"Train": [100], "Test": [10]}], + "discount": [0.99], + "seed": [1], + "show_tb": [false], + "gradient_clip": [5], + "target_network_update_steps": [200], + "network_update_steps": [1], + "generate_random_seed": [true] +} \ No newline at end of file diff --git a/configs/MERL_copter_medqn.json b/configs/MERL_copter_medqn.json new file mode 100644 index 0000000..cf1594f --- /dev/null +++ b/configs/MERL_copter_medqn.json @@ -0,0 +1,53 @@ +{ + "env": [ + { + "name": ["Pixelcopter-PLE-v0"], + "max_episode_steps": [500], + "input_type": ["feature"] + } + ], + "agent": [ + { + "name": ["MeDQN_Real"], + "consod_start": [0.01], + "consod_end": [4, 2, 1], + "consod_epoch": [4, 2, 1], + "consod_size": [1e4] + }, + { + "name": ["MeDQN_Uniform"], + "consod_start": [0.01], + "consod_end": [4, 2, 1], + "consod_epoch": [4, 2, 1] + } + ], + "train_steps": [2e6], + "test_per_episodes": [-1], + "device": ["cpu"], + "hidden_layers": [[64,64]], + "memory_type": ["FiniteReplay"], + "memory_size": [32], + "exploration_type": ["LinearEpsilonGreedy"], + "exploration_steps": [1e3], + "epsilon_steps": [1e3], + "epsilon_start": [1.0], + "epsilon_end": [0.01], + "epsilon_decay": [0.999], + "loss": ["MSELoss"], + "optimizer": [ + { + "name": ["RMSprop"], + "kwargs": [{"lr": [1e-3, 3e-4, 1e-4, 3e-5, 1e-5]}] + } + ], + "batch_size": [32], + "display_interval": [100], + "rolling_score_window": [{"Train": [100], "Test": [10]}], + "discount": [0.99], + "seed": [1], + "show_tb": [false], + "gradient_clip": [5], + "target_network_update_steps": [200], + "network_update_steps": [1, 2, 4, 8], + "generate_random_seed": [true] +} \ No newline at end of file diff --git a/configs/MERL_mc_dqn.json b/configs/MERL_mc_dqn.json new file mode 100644 index 0000000..6fd6a62 --- /dev/null +++ b/configs/MERL_mc_dqn.json @@ -0,0 +1,39 @@ +{ + "env": [ + { + "name": ["MountainCar-v0"], + "max_episode_steps": [1000], + "input_type": ["feature"] + } + ], + "agent": [{"name": ["DQN"]}], + "train_steps": [1e5], + "test_per_episodes": [-1], + "device": ["cpu"], + "hidden_layers": [[32,32]], + "memory_type": ["FiniteReplay"], + "memory_size": [100, 300, 1000, 3000, 10000], + "exploration_type": ["LinearEpsilonGreedy"], + "exploration_steps": [1e3], + "epsilon_steps": [1e3], + "epsilon_start": [1.0], + "epsilon_end": [0.01], + "epsilon_decay": [0.999], + "loss": ["MSELoss"], + "optimizer": [ + { + "name": ["Adam"], + "kwargs": [{"lr": [1e-2, 3e-3, 1e-3, 3e-4, 1e-4]}] + } + ], + "batch_size": [32], + "display_interval": [20], + "rolling_score_window": [{"Train": [20], "Test": [5]}], + "discount": [0.99], + "seed": [1], + "show_tb": [false], + "gradient_clip": [5], + "target_network_update_steps": [100], + "network_update_steps": [1, 2, 4, 8], + "generate_random_seed": [true] +} \ No newline at end of file diff --git a/configs/MERL_mc_dqn_small.json b/configs/MERL_mc_dqn_small.json new file mode 100644 index 0000000..ddb914d --- /dev/null +++ b/configs/MERL_mc_dqn_small.json @@ -0,0 +1,39 @@ +{ + "env": [ + { + "name": ["MountainCar-v0"], + "max_episode_steps": [1000], + "input_type": ["feature"] + } + ], + "agent": [{"name": ["DQN"]}], + "train_steps": [1e5], + "test_per_episodes": [-1], + "device": ["cpu"], + "hidden_layers": [[32,32]], + "memory_type": ["FiniteReplay"], + "memory_size": [32], + "exploration_type": ["LinearEpsilonGreedy"], + "exploration_steps": [1e3], + "epsilon_steps": [1e3], + "epsilon_start": [1.0], + "epsilon_end": [0.01], + "epsilon_decay": [0.999], + "loss": ["MSELoss"], + "optimizer": [ + { + "name": ["Adam"], + "kwargs": [{"lr": [1e-2, 3e-3, 1e-3, 3e-4, 1e-4]}] + } + ], + "batch_size": [32], + "display_interval": [20], + "rolling_score_window": [{"Train": [20], "Test": [5]}], + "discount": [0.99], + "seed": [1], + "show_tb": [false], + "gradient_clip": [5], + "target_network_update_steps": [100], + "network_update_steps": [1], + "generate_random_seed": [true] +} \ No newline at end of file diff --git a/configs/MERL_mc_medqn.json b/configs/MERL_mc_medqn.json new file mode 100644 index 0000000..e80c378 --- /dev/null +++ b/configs/MERL_mc_medqn.json @@ -0,0 +1,53 @@ +{ + "env": [ + { + "name": ["MountainCar-v0"], + "max_episode_steps": [1000], + "input_type": ["feature"] + } + ], + "agent": [ + { + "name": ["MeDQN_Real"], + "consod_start": [0.01], + "consod_end": [8, 4, 2, 1], + "consod_epoch": [16, 12, 8, 4, 2, 1], + "consod_size": [100, 300, 1000, 3000, 10000] + }, + { + "name": ["MeDQN_Uniform"], + "consod_start": [0.01], + "consod_end": [4, 2, 1], + "consod_epoch": [4, 2, 1] + } + ], + "train_steps": [1e5], + "test_per_episodes": [-1], + "device": ["cpu"], + "hidden_layers": [[32,32]], + "memory_type": ["FiniteReplay"], + "memory_size": [32], + "exploration_type": ["LinearEpsilonGreedy"], + "exploration_steps": [1e3], + "epsilon_steps": [1e3], + "epsilon_start": [1.0], + "epsilon_end": [0.01], + "epsilon_decay": [0.999], + "loss": ["MSELoss"], + "optimizer": [ + { + "name": ["Adam"], + "kwargs": [{"lr": [1e-2, 3e-3, 1e-3, 3e-4, 1e-4]}] + } + ], + "batch_size": [32], + "display_interval": [20], + "rolling_score_window": [{"Train": [20], "Test": [5]}], + "discount": [0.99], + "seed": [1], + "show_tb": [false], + "gradient_clip": [5], + "target_network_update_steps": [100], + "network_update_steps": [1, 2, 4, 8], + "generate_random_seed": [true] +} \ No newline at end of file diff --git a/configs/MERL_mc_medqn_lambda.json b/configs/MERL_mc_medqn_lambda.json new file mode 100644 index 0000000..790d80e --- /dev/null +++ b/configs/MERL_mc_medqn_lambda.json @@ -0,0 +1,47 @@ +{ + "env": [ + { + "name": ["MountainCar-v0"], + "max_episode_steps": [1000], + "input_type": ["feature"] + } + ], + "agent": [ + { + "name": ["MeDQN_Real", "MeDQN_Uniform"], + "consod_start": [0.01, 0.1, 2, 4, 8], + "consod_end": [0.01, 0.1, 2, 4, 8], + "consod_epoch": [4], + "consod_size": [1e4] + } + ], + "train_steps": [1e5], + "test_per_episodes": [-1], + "device": ["cpu"], + "hidden_layers": [[32,32]], + "memory_type": ["FiniteReplay"], + "memory_size": [32], + "exploration_type": ["LinearEpsilonGreedy"], + "exploration_steps": [1e3], + "epsilon_steps": [1e3], + "epsilon_start": [1.0], + "epsilon_end": [0.01], + "epsilon_decay": [0.999], + "loss": ["MSELoss"], + "optimizer": [ + { + "name": ["Adam"], + "kwargs": [{"lr": [1e-2, 3e-3, 1e-3, 3e-4, 1e-4]}] + } + ], + "batch_size": [32], + "display_interval": [20], + "rolling_score_window": [{"Train": [20], "Test": [5]}], + "discount": [0.99], + "seed": [1], + "show_tb": [false], + "gradient_clip": [5], + "target_network_update_steps": [100], + "network_update_steps": [1], + "generate_random_seed": [true] +} \ No newline at end of file diff --git a/configs/MERL_minatar_dqn.json b/configs/MERL_minatar_dqn.json new file mode 100644 index 0000000..1061484 --- /dev/null +++ b/configs/MERL_minatar_dqn.json @@ -0,0 +1,45 @@ +{ + "env": [ + { + "name": ["Asterix-MinAtar-v0", "Breakout-MinAtar-v0", "SpaceInvaders-MinAtar-v0"], + "max_episode_steps": [-1], + "input_type": ["pixel"] + }, + { + "name": ["Seaquest-MinAtar-v0"], + "max_episode_steps": [1e4], + "input_type": ["pixel"] + } + ], + "agent": [{"name": ["DQN"]}], + "train_steps": [5e6], + "test_per_episodes": [-1], + "device": ["cpu"], + "feature_dim": [128], + "hidden_layers": [[]], + "memory_type": ["FiniteReplay"], + "memory_size": [32, 1e5], + "exploration_type": ["LinearEpsilonGreedy"], + "exploration_steps": [5e3], + "epsilon_steps": [1e5], + "epsilon_start": [1.0], + "epsilon_end": [0.1], + "epsilon_decay": [0.999], + "loss": ["SmoothL1Loss"], + "optimizer": [ + { + "name": ["RMSprop"], + "kwargs": [{"lr": [3e-3, 1e-3, 3e-4, 1e-4, 3e-5], "alpha": [0.95], "centered": [true], "eps": [0.01]}] + } + ], + "batch_size": [32], + "display_interval": [500], + "rolling_score_window": [{"Train": [100], "Test": [10]}], + "discount": [0.99], + "seed": [1], + "show_tb": [false], + "gradient_clip": [-1], + "target_network_update_steps": [1000], + "network_update_steps": [1], + "generate_random_seed": [true] +} \ No newline at end of file diff --git a/configs/MERL_minatar_medqn_real.json b/configs/MERL_minatar_medqn_real.json new file mode 100644 index 0000000..7a4a4f9 --- /dev/null +++ b/configs/MERL_minatar_medqn_real.json @@ -0,0 +1,53 @@ +{ + "env": [ + { + "name": ["Asterix-MinAtar-v0", "Breakout-MinAtar-v0", "SpaceInvaders-MinAtar-v0"], + "max_episode_steps": [-1], + "input_type": ["pixel"] + }, + { + "name": ["Seaquest-MinAtar-v0"], + "max_episode_steps": [1e4], + "input_type": ["pixel"] + } + ], + "agent": [ + { + "name": ["MeDQN_Real"], + "consod_start": [0.01], + "consod_end": [4, 2], + "consod_epoch": [2, 1], + "consod_size": [1e5] + } + ], + "train_steps": [5e6], + "test_per_episodes": [-1], + "device": ["cpu"], + "feature_dim": [128], + "hidden_layers": [[]], + "memory_type": ["FiniteReplay"], + "memory_size": [32], + "exploration_type": ["LinearEpsilonGreedy"], + "exploration_steps": [5e3], + "epsilon_steps": [1e5], + "epsilon_start": [1.0], + "epsilon_end": [0.1], + "epsilon_decay": [0.999], + "loss": ["SmoothL1Loss"], + "optimizer": [ + { + "name": ["RMSprop"], + "kwargs": [{"lr": [3e-3, 1e-3, 3e-4, 1e-4, 3e-5], "alpha": [0.95], "centered": [true], "eps": [0.01]}] + } + ], + "batch_size": [32], + "display_interval": [500], + "rolling_score_window": [{"Train": [100], "Test": [10]}], + "discount": [0.99], + "seed": [1], + "show_tb": [false], + "gradient_clip": [-1], + "target_network_update_steps": [1000], + "network_update_steps": [4, 8, 16, 32], + "generate_random_seed": [true] +} \ No newline at end of file diff --git a/configs/MERL_minatar_medqn_uniform.json b/configs/MERL_minatar_medqn_uniform.json new file mode 100644 index 0000000..6de42d2 --- /dev/null +++ b/configs/MERL_minatar_medqn_uniform.json @@ -0,0 +1,52 @@ +{ + "env": [ + { + "name": ["Asterix-MinAtar-v0", "Breakout-MinAtar-v0", "SpaceInvaders-MinAtar-v0"], + "max_episode_steps": [-1], + "input_type": ["pixel"] + }, + { + "name": ["Seaquest-MinAtar-v0"], + "max_episode_steps": [1e4], + "input_type": ["pixel"] + } + ], + "agent": [ + { + "name": ["MeDQN_Uniform"], + "consod_start": [0.01], + "consod_end": [4, 2], + "consod_epoch": [2, 1] + } + ], + "train_steps": [5e6], + "test_per_episodes": [-1], + "device": ["cpu"], + "feature_dim": [128], + "hidden_layers": [[]], + "memory_type": ["FiniteReplay"], + "memory_size": [32], + "exploration_type": ["LinearEpsilonGreedy"], + "exploration_steps": [5e3], + "epsilon_steps": [1e5], + "epsilon_start": [1.0], + "epsilon_end": [0.1], + "epsilon_decay": [0.999], + "loss": ["SmoothL1Loss"], + "optimizer": [ + { + "name": ["RMSprop"], + "kwargs": [{"lr": [3e-3, 1e-3, 3e-4, 1e-4, 3e-5], "alpha": [0.95], "centered": [true], "eps": [0.01]}] + } + ], + "batch_size": [32], + "display_interval": [500], + "rolling_score_window": [{"Train": [100], "Test": [10]}], + "discount": [0.99], + "seed": [1], + "show_tb": [false], + "gradient_clip": [-1], + "target_network_update_steps": [300], + "network_update_steps": [4, 8, 16, 32], + "generate_random_seed": [true] +} \ No newline at end of file diff --git a/configs/catcher.json b/configs/Maxmin_catcher.json similarity index 90% rename from configs/catcher.json rename to configs/Maxmin_catcher.json index 1e78611..4f96115 100755 --- a/configs/catcher.json +++ b/configs/Maxmin_catcher.json @@ -19,7 +19,7 @@ "test_per_episodes": [-1], "device": ["cpu"], "hidden_layers": [[64,64]], - "memory_type": ["Replay"], + "memory_type": ["FiniteReplay"], "memory_size": [1e4], "exploration_type": ["LinearEpsilonGreedy"], "exploration_steps": [1e3], @@ -41,7 +41,7 @@ "seed": [1], "show_tb": [false], "gradient_clip": [5], - "target_network_update_frequency": [200], - "network_update_frequency": [1], + "target_network_update_steps": [200], + "network_update_steps": [1], "generate_random_seed": [true] } \ No newline at end of file diff --git a/configs/copter.json b/configs/Maxmin_copter.json similarity index 90% rename from configs/copter.json rename to configs/Maxmin_copter.json index e712bb5..483fe45 100755 --- a/configs/copter.json +++ b/configs/Maxmin_copter.json @@ -19,7 +19,7 @@ "test_per_episodes": [-1], "device": ["cpu"], "hidden_layers": [[64,64]], - "memory_type": ["Replay"], + "memory_type": ["FiniteReplay"], "memory_size": [1e4], "exploration_type": ["LinearEpsilonGreedy"], "exploration_steps": [1e3], @@ -41,7 +41,7 @@ "seed": [1], "show_tb": [false], "gradient_clip": [5], - "target_network_update_frequency": [200], - "network_update_frequency": [1], + "target_network_update_steps": [200], + "network_update_steps": [1], "generate_random_seed": [true] } \ No newline at end of file diff --git a/configs/lunar.json b/configs/Maxmin_lunar.json similarity index 90% rename from configs/lunar.json rename to configs/Maxmin_lunar.json index fb41761..0c3c79e 100755 --- a/configs/lunar.json +++ b/configs/Maxmin_lunar.json @@ -19,7 +19,7 @@ "test_per_episodes": [-1], "device": ["cpu"], "hidden_layers": [[64,64]], - "memory_type": ["Replay"], + "memory_type": ["FiniteReplay"], "memory_size": [1e4], "exploration_type": ["LinearEpsilonGreedy"], "exploration_steps": [1e3], @@ -41,7 +41,7 @@ "seed": [1], "show_tb": [false], "gradient_clip": [5], - "target_network_update_frequency": [200], - "network_update_frequency": [1], + "target_network_update_steps": [200], + "network_update_steps": [1], "generate_random_seed": [true] } diff --git a/configs/minatar.json b/configs/Maxmin_minatar.json similarity index 91% rename from configs/minatar.json rename to configs/Maxmin_minatar.json index 1def90c..69fcea6 100755 --- a/configs/minatar.json +++ b/configs/Maxmin_minatar.json @@ -25,7 +25,7 @@ "device": ["cpu"], "feature_dim": [128], "hidden_layers": [[]], - "memory_type": ["Replay"], + "memory_type": ["FiniteReplay"], "memory_size": [1e5], "exploration_type": ["LinearEpsilonGreedy"], "exploration_steps": [5e3], @@ -47,7 +47,7 @@ "seed": [1], "show_tb": [false], "gradient_clip": [-1], - "target_network_update_frequency": [1000], - "network_update_frequency": [1], + "target_network_update_steps": [1000], + "network_update_steps": [1], "generate_random_seed": [true] } \ No newline at end of file diff --git a/configs/mujoco_rpg.json b/configs/RPG.json similarity index 95% rename from configs/mujoco_rpg.json rename to configs/RPG.json index 7a85c8b..ff8a2e9 100755 --- a/configs/mujoco_rpg.json +++ b/configs/RPG.json @@ -9,7 +9,7 @@ "train_steps": [3e6], "steps_per_epoch": [2048], "test_per_epochs": [4], - "agent": [{"name": ["RPG"]}], + "agent": [{"name": ["PPO", "RPG"]}], "optimizer": [ { "name": ["Adam"], diff --git a/configs/mujoco_ppo.json b/configs/mujoco_ppo.json deleted file mode 100755 index 1700f8e..0000000 --- a/configs/mujoco_ppo.json +++ /dev/null @@ -1,35 +0,0 @@ -{ - "env": [ - { - "name": ["HalfCheetah-v2", "Hopper-v2", "Walker2d-v2", "Swimmer-v2", "Ant-v2", "Reacher-v2"], - "max_episode_steps": [-1], - "input_type": ["feature"] - } - ], - "train_steps": [3e6], - "steps_per_epoch": [2048], - "test_per_epochs": [4], - "agent": [{"name": ["PPO"]}], - "optimizer": [ - { - "name": ["Adam"], - "actor_kwargs": [{"lr": [3e-4]}], - "critic_kwargs": [{"lr": [1e-3]}] - } - ], - "batch_size": [64], - "clip_ratio": [0.2], - "target_kl": [0.01], - "optimize_epochs": [10], - "gradient_clip": [2], - "hidden_layers": [[64,64]], - "hidden_act": ["Tanh"], - "display_interval": [20], - "rolling_score_window": [{"Train": [20], "Test": [5]}], - "discount": [0.99], - "gae": [0.95], - "seed": [1], - "device": ["cpu"], - "show_tb": [false], - "generate_random_seed": [true] -} \ No newline at end of file diff --git a/configs/mujoco_reinforce.json b/configs/mujoco_reinforce.json deleted file mode 100755 index 6827711..0000000 --- a/configs/mujoco_reinforce.json +++ /dev/null @@ -1,28 +0,0 @@ -{ - "env": [ - { - "name": ["HalfCheetah-v2", "Hopper-v2", "Walker2d-v2", "Swimmer-v2", "Ant-v2", "Reacher-v2"], - "max_episode_steps": [-1], - "input_type": ["feature"] - } - ], - "train_steps": [3e6], - "test_per_episodes": [10], - "agent": [{"name": ["REINFORCE"]}], - "optimizer": [ - { - "name": ["Adam"], - "actor_kwargs": [{"lr": [3e-4]}] - } - ], - "gradient_clip": [-1], - "hidden_layers": [[64,64]], - "hidden_act": ["Tanh"], - "display_interval": [20], - "rolling_score_window": [{"Train": [20], "Test": [5]}], - "discount": [0.99], - "seed": [1], - "device": ["cpu"], - "show_tb": [false], - "generate_random_seed": [true] -} \ No newline at end of file diff --git a/configs/mujoco_repoffpg.json b/configs/mujoco_repoffpg.json deleted file mode 100755 index c5de98b..0000000 --- a/configs/mujoco_repoffpg.json +++ /dev/null @@ -1,34 +0,0 @@ -{ - "env": [ - { - "name": ["HalfCheetah-v2", "Hopper-v2", "Walker2d-v2", "Swimmer-v2", "Ant-v2", "Reacher-v2"], - "max_episode_steps": [-1], - "input_type": ["feature"] - } - ], - "train_steps": [3e6], - "test_per_episodes": [10], - "exploration_steps": [1e4], - "memory_size": [1e6], - "agent": [{"name": ["RepOffPG"]}], - "optimizer": [ - { - "name": ["Adam"], - "actor_kwargs": [{"lr": [3e-4]}], - "critic_kwargs": [{"lr": [1e-3]}] - } - ], - "batch_size": [64], - "network_update_frequency": [1], - "discount": [0.99], - "polyak": [0.995], - "gradient_clip": [-1], - "hidden_layers": [[256,256]], - "hidden_act": ["ReLU"], - "display_interval": [20], - "rolling_score_window": [{"Train": [20], "Test": [5]}], - "seed": [1], - "device": ["cpu"], - "show_tb": [true], - "generate_random_seed": [true] -} \ No newline at end of file diff --git a/copyfile.sh b/copyfile.sh new file mode 100644 index 0000000..1c57996 --- /dev/null +++ b/copyfile.sh @@ -0,0 +1,8 @@ +for index in 328 159 134 174 +do + for filename in $(seq $index 720 14400) + do + printf "copy $filename\n" + cp -r ./logs/mc_medqn2/$filename/ ./logs-mc/mc_medqn2/ + done +done \ No newline at end of file diff --git a/find_config.py b/find_config.py index 7d7bd40..60c4df3 100644 --- a/find_config.py +++ b/find_config.py @@ -1,13 +1,30 @@ import os from utils.sweeper import Sweeper -agent_config = 'catcher.json' -config_file = os.path.join('./configs/', agent_config) -sweeper = Sweeper(config_file) - -# Find cfg index with certain constraint -for i in range(1, 1+sweeper.config_dicts['num_combinations']): - cfg = sweeper.generate_config_for_idx(i) - if cfg['agent']['name'] == 'MaxminDQN': - print(i, end=',') -print() \ No newline at end of file + +def find_one_run(): + agent_config = 'mc_medqn.json' + config_file = os.path.join('./configs/', agent_config) + sweeper = Sweeper(config_file) + for i in range(1, 1+sweeper.config_dicts['num_combinations']): + cfg = sweeper.generate_config_for_idx(i) + if cfg['agent']['consod_start'] == cfg['agent']['consod_end']: + print(i, end=',') + print() + + +def find_many_runs(): + l = [23,146,150,147,255,207,133,130,114,55,235,210,138,82,140,209,228,69,71,353,317] + l.sort() + print('len(l)=', len(l)) + ll = [] + for r in range(1,20): + for x in l: + ll.append(x+360*r) + print('len(ll)=', len(ll)) + print(*ll) + + +if __name__ == "__main__": + find_one_run() + # find_many_runs() \ No newline at end of file diff --git a/main.py b/main.py index 4613729..568b3d5 100644 --- a/main.py +++ b/main.py @@ -17,7 +17,7 @@ def main(argv): cfg = sweeper.generate_config_for_idx(args.config_idx) # Set config dict default value - cfg.setdefault('network_update_frequency', 1) + cfg.setdefault('network_update_steps', 1) cfg['env'].setdefault('max_episode_steps', -1) cfg.setdefault('show_tb', False) cfg.setdefault('render', False) diff --git a/plot.py b/plot.py index 41058ec..db9380a 100644 --- a/plot.py +++ b/plot.py @@ -6,15 +6,18 @@ import matplotlib import matplotlib.pyplot as plt; plt.style.use('seaborn-ticks') from matplotlib.ticker import FuncFormatter - -from utils.helper import make_dir -from utils.plotter import read_file, get_total_combination - +# Avoid Type 3 fonts: http://phyletica.org/matplotlib-fonts/ +matplotlib.rcParams['pdf.fonttype'] = 42 +matplotlib.rcParams['ps.fonttype'] = 42 # Set font family, bold, and font size #font = {'family':'normal', 'weight':'normal', 'size': 12} font = {'size': 15} matplotlib.rc('font', **font) +from utils.helper import make_dir +from utils.plotter import read_file, get_total_combination, symmetric_ema + + class Plotter(object): def __init__(self, cfg): cfg.setdefault('ci', None) @@ -42,6 +45,22 @@ def get_result(self, exp, config_idx, mode): result_list.append(result) config_idx += total_combination + # Do symmetric EMA (exponential moving average) + # Get x's and y's in form of numpy arries + xs, ys = [], [] + for result in result_list: + xs.append(result[self.x_label].to_numpy()) + ys.append(result[self.y_label].to_numpy()) + # Do symetric EMA to get new x's and y's + low = max(x[0] for x in xs) + high = min(x[-1] for x in xs) + n = min(len(x) for x in xs) + for i in range(len(xs)): + new_x, new_y, _ = symmetric_ema(xs[i], ys[i], low, high, n) + result_list[i] = result_list[i][:n] + result_list[i].loc[:, self.x_label] = new_x + result_list[i].loc[:, self.y_label] = new_y + ys = [] for result in result_list: ys.append(result[self.y_label].to_numpy()) @@ -64,7 +83,7 @@ def x_format(x, pos): 'x_label': 'Step', 'y_label': 'Average Return', 'show': False, - 'imgType': 'pdf', + 'imgType': 'png', 'ci': 'se', 'x_format': None, 'y_format': None, @@ -132,4 +151,4 @@ def learning_curve(exp, runs=1): plt.close() # close window if __name__ == "__main__": - learning_curve('rpg', 30) \ No newline at end of file + learning_curve('RPG', 30) \ No newline at end of file diff --git a/procfile b/procfile new file mode 100644 index 0000000..2b82dfe --- /dev/null +++ b/procfile @@ -0,0 +1 @@ +60 \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 930b7a2..d4d0536 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,9 +1,9 @@ -pandas==1.1.3 -pybullet==3.0.7 -torch>=1.8.0 -numpy>=1.21.0 -gym==0.17.3 -psutil==5.7.2 -matplotlib==3.3.2 -seaborn==0.11.0 -opencv_python==4.4.0.46 \ No newline at end of file +gym==0.23.1 +gym_games==1.0.4 +matplotlib==3.5.2 +numpy==1.21.5 +opencv_python==4.5.5.64 +pandas==1.4.2 +psutil==5.9.0 +seaborn==0.11.2 +torch==1.11.0 \ No newline at end of file diff --git a/run.py b/run.py index b031d4b..84ca36a 100644 --- a/run.py +++ b/run.py @@ -9,18 +9,22 @@ def make_dir(dir): def main(argv): - sbatch_cfg = { # Account name - 'account': 'rrg-whitem', + # 'account': 'def-ashique', + 'account': 'rrg-ashique', # Job name - 'job-name': 'catcher', + # 'job-name': 'minatar_dqn', + 'job-name': 'minatar_dqn_sm', + # 'job-name': 'minatar_medqn_real', + # 'job-name': 'minatar_medqn_uniform', # Job time - 'time': '0-10:00:00', + 'time': '0-05:00:00', # GPU/CPU type 'cpus-per-task': 1, # Memory - 'mem-per-cpu': '2000M', + # 'mem-per-cpu': '2500M', + 'mem-per-cpu': '1500M', # Email address 'mail-user': 'qlan3@ualberta.ca' } @@ -29,9 +33,37 @@ def main(argv): # sbatch_cfg['job-name'], sbatch_cfg['time'], sbatch_cfg['mem-per-cpu'] = 'catcher', '0-10:00:00', '2000M' # sbatch_cfg['job-name'], sbatch_cfg['time'], sbatch_cfg['mem-per-cpu'] = 'copter', '0-05:00:00', '2000M' # sbatch_cfg['job-name'], sbatch_cfg['time'], sbatch_cfg['mem-per-cpu'] = 'lunar', '0-07:00:00', '2000M' - # sbatch_cfg['job-name'], sbatch_cfg['time'], sbatch_cfg['mem-per-cpu'] = 'minatar', '1-08:00:00', '4000M' - + # sbatch_cfg['job-name'], sbatch_cfg['time'], sbatch_cfg['mem-per-cpu'] = 'minatar', '0-05:00:00', '2500M' + + + l_dqn = [11,15,19,7,13,17,9,12,16,6,10,2,18] + l_dqn.sort() + ll_dqn = [] + for r in range(1,10): + for x in l_dqn: + ll_dqn.append(x+20*r) + + l_dqn_sm = [19,11,15,13,17,9,18,14,2,20,12] + l_dqn_sm.sort() + ll_dqn_sm = [] + for r in range(1,10): + for x in l_dqn_sm: + ll_dqn_sm.append(x+20*r) + + l_uniform = [827,267,927,155,583,691,751,351,147,747,587,277,269,273,497,357,669,433,501,509,821,205,517,577,254,270,826,746,490,510,730,430,830,734,732,736,652,888,656,892,512,496,572,592,508,352] + l_uniform.sort() + ll_uniform = [] + for r in range(1,10): + for x in l_uniform: + ll_uniform.append(x+960*r) + l_real = [195,643,27,179,423,115,403,187,191,267,419,43,351,199,31,203,357,121,41,201,125,285,129,133,213,749,429,433,517,493,501,505,489,57,432,888,884,564,648,664,644,188,416,652,276,352,340,108,256,426,106,402,566,110,510,406,410,254,414,206,574] + l_real.sort() + ll_real = [] + for r in range(1,10): + for x in l_real: + ll_real.append(x+960*r) + general_cfg = { # User name 'user': 'qlan3', @@ -40,10 +72,16 @@ def main(argv): # Check time interval in minutes 'check-time-interval': 5, # Clusters info: {name: capacity} - 'clusters': {'Cedar': 3000}, + 'clusters': {'Narval': 1000}, # Job indexes list - 'job-list': list(range(1, 30+1)) - } + # 'job-list': list(range(1, 20+1)) + # 'job-list': list(range(1, 960+1)) + # 'job-list': ll_uniform + # 'job-list': ll_real + # 'job-list': ll_dqn + 'job-list': ll_dqn_sm + # 'job-list': [] + } make_dir(f"output/{sbatch_cfg['job-name']}") submitter = Submitter(general_cfg, sbatch_cfg) diff --git a/run.sh b/run.sh index 06d724a..ea17bab 100644 --- a/run.sh +++ b/run.sh @@ -1,3 +1,3 @@ export OMP_NUM_THREADS=1 # git rev-parse --short HEAD -parallel --eta --ungroup --jobs 120 python main.py --config_file ./configs/mujoco_rpg.json --config_idx {1} ::: $(seq 1 360) \ No newline at end of file +parallel --eta --ungroup --jobs 120 python main.py --config_file ./configs/RPG.json --config_idx {1} ::: $(seq 1 360) \ No newline at end of file diff --git a/sbatch.sh b/sbatch.sh index 70dbd2d..2aab1a8 100644 --- a/sbatch.sh +++ b/sbatch.sh @@ -1,10 +1,10 @@ #!/bin/bash -# Ask SLURM to send the USR1 signal 120 seconds before end of the time limit -#SBATCH --signal=B:USR1@120 +# Ask SLURM to send the USR1 signal 300 seconds before end of the time limit +#SBATCH --signal=B:USR1@300 #SBATCH --output=output/%x/%a.txt #SBATCH --mail-type=ALL #SBATCH --mail-type=TIME_LIMIT - +#SBATCH --exclude=nc20552,nc11103,nc11126,nc10303,nc20305,nc10249,nc20325,nc11124,nc20529,nc20526,nc20342,nc20354,nc30616,nc30305,nc20133,nc10220 # --------------------------------------------------------------------- echo "Current working directory: `pwd`" echo "Starting run at: `date`" @@ -21,13 +21,16 @@ cleanup() dest=./logs/$SLURM_JOB_NAME/ echo "Source directory: $sour" echo "Destination directory: $dest" - cp -r $sour $dest + cp -rf $sour $dest } # Call `cleanup` once we receive USR1 or EXIT signal trap 'cleanup' USR1 EXIT # --------------------------------------------------------------------- export OMP_NUM_THREADS=1 +module load gcc/9.3.0 arrow/2.0.0 python/3.7 scipy-stack +source ~/envs/gym/bin/activate python main.py --config_file ./configs/${SLURM_JOB_NAME}.json --config_idx $SLURM_ARRAY_TASK_ID --slurm_dir $SLURM_TMPDIR +# python main.py --config_file ./configs/${SLURM_JOB_NAME}.json --config_idx $SLURM_ARRAY_TASK_ID # --------------------------------------------------------------------- echo "Job finished with exit code $? at: `date`" # --------------------------------------------------------------------- \ No newline at end of file diff --git a/unfinish_job.py b/unfinish_job.py new file mode 100644 index 0000000..d015814 --- /dev/null +++ b/unfinish_job.py @@ -0,0 +1,48 @@ +import os +import sys +import json +import argparse +import numpy as np +import matplotlib.pyplot as plt + +parentdir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +os.sys.path.insert(0, parentdir) + +from utils.sweeper import Sweeper + +exp = 'minatar_me_vae' +l = [1124,996,1216,1360,928,732,720,608,1018,562,1178,846,1014,842,1290,970,511,963,507,523,1335,971,979,1339,1323,1019,489,509,989,1289,1145,1309,1153,993,1149,841] +ll = [] +for r in range(10): + for x in l: + ll.append(x+1440*r) +ll.sort() + +file_name='log.txt' +max_line_length=10000 + +config_file = f'./configs/{exp}.json' +sweeper = Sweeper(config_file) +# Read a list of logs +print(f'[{exp}]: ', end=' ') +for i in ll: + log_file = f'./logs/{exp}/{i}/{file_name}' + try: + with open(log_file, 'r') as f: + # Get last line + try: + f.seek(-max_line_length, os.SEEK_END) + except IOError: + # either file is too small, or too many lines requested + f.seek(0) + last_line = f.readlines()[-1] + # Get time info in last line + try: + t = float(last_line.split(' ')[-2]) + except: + print(i, end=', ') + continue + except: + print(i, end=', ') + continue +print() \ No newline at end of file diff --git a/utils/plotter.py b/utils/plotter.py index c8e6a8d..36ca9b0 100644 --- a/utils/plotter.py +++ b/utils/plotter.py @@ -4,8 +4,12 @@ import numpy as np import pandas as pd import seaborn as sns; sns.set(style="ticks"); sns.set_context("paper") #sns.set_context("talk") +import matplotlib import matplotlib.pyplot as plt from matplotlib.ticker import FuncFormatter +# Avoid Type 3 fonts in matplotlib plots: http://phyletica.org/matplotlib-fonts/ +matplotlib.rcParams['pdf.fonttype'] = 42 +matplotlib.rcParams['ps.fonttype'] = 42 from utils.helper import make_dir from utils.sweeper import Sweeper