add MeDQN
qlan3 committed May 24, 2022
1 parent fc75506 commit d9145e6
Showing 50 changed files with 1,109 additions and 373 deletions.
6 changes: 2 additions & 4 deletions .gitignore
@@ -2,10 +2,8 @@
# Edit at https://www.gitignore.io/?templates=python,windows,visualstudiocode

# My ignores
logs*
*logs*
logfile
procfile
*figure*
*output*
*DS_Store*

@@ -170,4 +168,4 @@ $RECYCLE.BIN/
# Windows shortcuts
*.lnk

# End of https://www.gitignore.io/api/python,windows,visualstudiocode
# End of https://www.gitignore.io/api/python,windows,visualstudiocode
64 changes: 34 additions & 30 deletions README.md
@@ -8,18 +8,19 @@ Explorer is a PyTorch reinforcement learning framework for **exploring** new ide
- Vanilla Deep Q-learning (VanillaDQN): No target network.
- [Deep Q-Learning (DQN)](https://users.cs.duke.edu/~pdinesh/sources/MnihEtAlHassibis15NatureControlDeepRL.pdf)
- [Double Deep Q-learning (DDQN)](https://arxiv.org/pdf/1509.06461.pdf)
- [Maxmin Deep Q-learning (MaxminDQN)](https://openreview.net/pdf?id=Bkg0u3Etwr)
- [Maxmin Deep Q-learning (MaxminDQN)](https://arxiv.org/pdf/2002.06487.pdf)
- [Averaged Deep Q-learning (AveragedDQN)](https://arxiv.org/pdf/1611.01929.pdf)
- [Ensemble Deep Q-learning (EnsembleDQN)](https://arxiv.org/pdf/1611.01929.pdf)
- [Bootstrapped Deep Q-learning (BootstrappedDQN)](https://arxiv.org/abs/1602.04621)
- [NoisyNet Deep Q-learning (NoisyNetDQN)](https://arxiv.org/abs/1706.10295)
- [Bootstrapped Deep Q-learning (BootstrappedDQN)](https://arxiv.org/pdf/1602.04621.pdf)
- [NoisyNet Deep Q-learning (NoisyNetDQN)](https://arxiv.org/pdf/1706.10295.pdf)
- [REINFORCE](http://incompleteideas.net/book/RLbook2020.pdf)
- [Actor-Critic](http://incompleteideas.net/book/RLbook2020.pdf)
- [Proximal Policy Optimisation (PPO)](https://arxiv.org/pdf/1707.06347.pdf)
- [Soft Actor-Critic (SAC)](https://arxiv.org/pdf/1812.05905.pdf)
- [Deep Deterministic Policy Gradients (DDPG)](https://arxiv.org/pdf/1509.02971.pdf)
- [Twin Delayed Deep Deterministic Policy Gradients (TD3)](https://arxiv.org/pdf/1802.09477.pdf)
- [Reward Policy Gradient (RPG)](https://arxiv.org/abs/2103.05147)
- [Reward Policy Gradient (RPG)](https://arxiv.org/pdf/2103.05147.pdf)
- [Memory-efficient Deep Q-learning (MeDQN)](https://arxiv.org/pdf/2205.10868.pdf)

## To do list

@@ -33,16 +34,14 @@ Explorer is a PyTorch reinforcement learning framework for **exploring** new ide
| ├── DQN
| | ├── DDQN
| | ├── NoisyNetDQN
| | └── BootstrappedDQN
| | ├── BootstrappedDQN
| | └── MeDQN_Uniform, MeDQN_Real
| ├── Maxmin DQN ── Ensemble DQN
| └── Averaged DQN
└── REINFORCE
├── Actor-Critic
| ├── PPO ── RPG
| └── RepOnPG (experimental)
└── SAC ── DDPG
├── TD3
└── RepOffPG (experimental)
| └── PPO ── RPG
└── SAC ── DDPG ── TD3


## Requirements
@@ -51,8 +50,12 @@ Explorer is a PyTorch reinforcement learning framework for **exploring** new ide
- [PyTorch](https://pytorch.org/)
- [Gym && Gym Games](https://github.com/qlan3/gym-games): You may install only part of Gym (`classic_control, box2d`) with the command `pip install 'gym[classic_control, box2d]'`.
- Optional:
- [Gym Atari](https://github.com/openai/gym/blob/master/docs/environments.md#atari)
- [Gym Mujoco](https://github.com/openai/gym/blob/master/docs/environments.md#mujoco)
- [Gym Atari](https://www.gymlibrary.ml/environments/atari/): `pip install gym[atari,accept-rom-license]`
- [Gym Mujoco](https://www.gymlibrary.ml/environments/mujoco/):
- Download MuJoCo version 1.50 from the [MuJoCo website](https://www.roboti.us/download.html).
- Unzip the downloaded `mjpro150` directory into `~/.mujoco/mjpro150`, and place the activation key (the `mjkey.txt` file downloaded from [here](https://www.roboti.us/license.html)) at `~/.mujoco/mjkey.txt`.
- Install [mujoco-py](https://github.com/openai/mujoco-py): `pip install 'mujoco-py<1.50.2,>=1.50.1'`
- Install gym[mujoco]: `pip install gym[mujoco]`
- [PyBullet](https://pybullet.org/): `pip install pybullet`
- [DeepMind Control Suite](https://github.com/denisyarats/dmc2gym): `pip install git+git://github.com/denisyarats/dmc2gym.git`
- Others: Please check `requirements.txt`.
@@ -62,76 +65,77 @@ Explorer is a PyTorch reinforcement learning framework for **exploring** new ide

### Train && Test

All hyperparameters including parameters for grid search are stored in a configuration file in directory `configs`. To run an experiment, a configuration index is first used to generate a configuration dict corresponding to this specific configuration index. Then we run an experiment defined by this configuration dict. All results including log files and the model file are saved in directory `logs`. Please refer to the code for details.
All hyperparameters, including parameters for grid search, are stored in a configuration file in directory `configs`. To run an experiment, a configuration index is first used to generate the configuration dict corresponding to that index; we then run the experiment defined by this dict. All results, including log files, are saved in directory `logs`. Please refer to the code for details.

For example, run the experiment with configuration file `catcher.json` and configuration index `1`:
For example, run the experiment with configuration file `RPG.json` and configuration index `1`:

```python main.py --config_file ./configs/catcher.json --config_idx 1```
```python main.py --config_file ./configs/RPG.json --config_idx 1```

The models are tested for one episode after every `test_per_episodes` training episodes; `test_per_episodes` is set in the configuration file.
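
For illustration only, the sketch below shows (as a Python dict) the kind of content such a configuration file might hold. The exact keys and nesting are assumptions borrowed from the agent code (e.g. `batch_size`, `memory_size`, `target_network_update_steps`), not the real schema of `RPG.json`; list values stand for hyperparameters to sweep.

```python
# Hypothetical configuration sketch -- not the actual schema of the files in `configs/`.
sketch_cfg = {
    'agent': [{'name': 'DQN'}],
    'env': [{'name': 'CartPole-v0'}],            # hypothetical environment entry
    'train_steps': [100000],
    'test_per_episodes': [10],
    'batch_size': [32],
    'memory_size': [10000],
    'target_network_update_steps': [200, 1000],  # 2 values to sweep
    'lr': [1e-3, 3e-4, 1e-4],                    # 3 values to sweep
}
# 2 * 3 = 6 combinations in total; --config_idx selects one of them.
```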


### Grid Search (Optional)

First, we calculate the number of total combinations in a configuration file (e.g. `catcher.json`):
First, we calculate the number of total combinations in a configuration file (e.g. `RPG.json`):

`python utils/sweeper.py`

The output will be:

`Number of total combinations in catcher.json: 90`
`Number of total combinations in RPG.json: 12`

Then we run through all configuration indexes from `1` to `90`. The simplest way is a bash script:
Then we run through all configuration indexes from `1` to `12`. The simplest way is using a bash script:

``` bash
for index in {1..90}
for index in {1..12}
do
python main.py --config_file ./configs/catcher.json --config_idx $index
python main.py --config_file ./configs/RPG.json --config_idx $index
done
```

[Parallel](https://www.gnu.org/software/parallel/) is usually a better choice for scheduling a large number of jobs:

``` bash
parallel --eta --ungroup python main.py --config_file ./configs/catcher.json --config_idx {1} ::: $(seq 1 90)
parallel --eta --ungroup python main.py --config_file ./configs/RPG.json --config_idx {1} ::: $(seq 1 12)
```

Any configuration index that has the same remainder (divided by the number of total combinations) should has the same configuration dict. So for multiple runs, we just need to add the number of total combinations to the configuration index. For example, 5 runs for configuration index `1`:
Any configuration index that has the same remainder (divided by the number of total combinations) should have the same configuration dict. So for multiple runs, we just need to add the number of total combinations to the configuration index. For example, 5 runs for configuration index `1`:

```
for index in 1 91 181 271 361
for index in 1 13 25 37 49
do
python main.py --config_file ./configs/catcher.json --config_idx $index
python main.py --config_file ./configs/RPG.json --config_idx $index
done
```

Or a simpler way:
```
parallel --eta --ungroup python main.py --config_file ./configs/catcher.json --config_idx {1} ::: $(seq 1 90 450)
parallel --eta --ungroup python main.py --config_file ./configs/RPG.json --config_idx {1} ::: $(seq 1 12 60)
```
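
The indexing convention can be made concrete with a small self-contained sketch (this is not the repository's sweeper implementation, and the swept values are hypothetical):

```python
import itertools

# Hypothetical sweep with 2 * 3 = 6 combinations; the real values live in the JSON config.
grid = {'lr': [1e-3, 3e-4], 'target_network_update_steps': [100, 500, 1000]}
combos = list(itertools.product(*grid.values()))

def config_for(index, num_combinations=len(combos)):
    # Indexes 1..N, N+1..2N, ... share a remainder and therefore a configuration dict;
    # only the run (e.g. the random seed) differs between them.
    return dict(zip(grid.keys(), combos[(index - 1) % num_combinations]))

assert config_for(1) == config_for(7) == config_for(13)  # 6 apart => same combination
```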


### Analysis (Optional)

To analysis the experimental results, just run:
To analyze the experimental results, just run:

`python analysis.py`

Inside `analysis.py`, `unfinished_index` will print out the configuration indexes of unfinished jobs based on the existence of the result file. `memory_info` will print out the memory usage information and generate a histogram to show the distribution of memory usages in directory `logs/catcher/0`. Similarly, `time_info` will print out the time information and generate a histogram to show the distribution of time in directory `logs/catcher/0`. Finally, `analyze` will generate `csv` files that store training and test results. More functions are available in `utils/plotter.py`.
Inside `analysis.py`, `unfinished_index` prints the configuration indexes of unfinished jobs based on the existence of the result file. `memory_info` prints memory usage information and generates a histogram of the memory usage distribution in directory `logs/RPG/0`. Similarly, `time_info` prints time information and generates a histogram of the run-time distribution in directory `logs/RPG/0`. Finally, `analyze` generates `csv` files that store training and test results. Please check `analysis.py` for more details. More functions are available in `utils/plotter.py`.
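
For a rough idea of the `unfinished_index` logic only, the sketch below scans for missing result files; the directory layout and file name used here are assumptions, not the actual output format of `analysis.py`:

```python
import os

def unfinished_index_sketch(exp_dir='logs/RPG', total_combinations=12, runs=5):
    # Hypothetical: treat a job as unfinished if its result file is missing.
    missing = []
    for idx in range(1, total_combinations * runs + 1):
        result_file = os.path.join(exp_dir, str(idx), 'result_Test.feather')  # assumed name
        if not os.path.exists(result_file):
            missing.append(idx)
    return missing
```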

Enjoy!


## Code of My Papers

- **Qingfeng Lan**, Yangchen Pan, Alona Fyshe, Martha White. **Maxmin Q-learning: Controlling the Estimation Bias of Q-learning.** ICLR, 2020. **(Poster)** [[paper]](https://openreview.net/pdf?id=Bkg0u3Etwr) [[code]](https://github.com/qlan3/Explorer/releases/tag/maxmin1.0) [[video]](https://iclr.cc/virtual/poster_Bkg0u3Etwr.html)
- **Qingfeng Lan**, Yangchen Pan, Alona Fyshe, Martha White. **Maxmin Q-learning: Controlling the Estimation Bias of Q-learning.** ICLR, 2020. **(Poster)** [[paper]](https://openreview.net/pdf?id=Bkg0u3Etwr) [[code]](https://github.com/qlan3/Explorer/releases/tag/maxmin1.0)

- **Qingfeng Lan**, Samuele Tosatto, Homayoon Farrahi, A. Rupam Mahmood. **Model-free Policy Learning with Reward Gradients.** AISTATS, 2022. **(Poster)** [[paper]](https://arxiv.org/abs/2103.05147) [[code]](https://github.com/qlan3/Explorer/tree/RPG)
- **Qingfeng Lan**, Samuele Tosatto, Homayoon Farrahi, A. Rupam Mahmood. **Model-free Policy Learning with Reward Gradients.** AISTATS, 2022. **(Poster)** [[paper]](https://arxiv.org/pdf/2103.05147.pdf) [[code]](https://github.com/qlan3/Explorer/tree/RPG)

- **Qingfeng Lan**, Yangchen Pan, Jun Luo, A. Rupam Mahmood. **Memory-efficient Reinforcement Learning with Knowledge Consolidation.** Arxiv [[paper]](https://arxiv.org/pdf/2205.10868.pdf) [[code]](https://github.com/qlan3/Explorer/)

## Cite

If you find this repo useful to your research, please cite my paper if it is related. Otherwisee, please use this bibtex to cite this repo
If you find this repo useful in your research, please cite the related paper above; otherwise, please cite this repo:

~~~bibtex
@misc{Explorer,
6 changes: 2 additions & 4 deletions agents/AveragedDQN.py
@@ -17,10 +17,8 @@ def __init__(self, cfg):
self.Q_net_target[i].eval()
self.update_target_net_index = 0

def learn(self):
super().learn()
# Update target network
if (self.step_count // self.cfg['network_update_frequency']) % self.cfg['target_network_update_frequency'] == 0:
def update_target_net(self):
if self.step_count % self.cfg['target_network_update_steps'] == 0:
self.Q_net_target[self.update_target_net_index].load_state_dict(self.Q_net[self.update_Q_net_index].state_dict())
self.update_target_net_index = (self.update_target_net_index + 1) % self.k

3 changes: 0 additions & 3 deletions agents/BootstrappedDQN.py
@@ -62,9 +62,6 @@ def learn(self):
if self.gradient_clip > 0:
nn.utils.clip_grad_norm_(self.Q_net[0].parameters(), self.gradient_clip)
self.optimizer[0].step()
# Update target network
if (self.step_count // self.cfg['network_update_frequency']) % self.cfg['target_network_update_frequency'] == 0:
self.Q_net_target[0].load_state_dict(self.Q_net[0].state_dict())
if self.show_tb:
self.logger.add_scalar(f'Loss', loss.item(), self.step_count)

6 changes: 2 additions & 4 deletions agents/DQN.py
@@ -14,10 +14,8 @@ def __init__(self, cfg):
self.Q_net_target[0].load_state_dict(self.Q_net[0].state_dict())
self.Q_net_target[0].eval()

def learn(self):
super().learn()
# Update target network
if (self.step_count // self.cfg['network_update_frequency']) % self.cfg['target_network_update_frequency'] == 0:
def update_target_net(self):
if self.step_count % self.cfg['target_network_update_steps'] == 0:
self.Q_net_target[self.update_Q_net_index].load_state_dict(self.Q_net[self.update_Q_net_index].state_dict())

def compute_q_target(self, batch):
5 changes: 3 additions & 2 deletions agents/MaxminDQN.py
@@ -27,8 +27,9 @@ def learn(self):
# Choose a Q_net to update
self.update_Q_net_index = np.random.choice(list(range(self.k)))
super().learn()
# Update target network
if (self.step_count // self.cfg['network_update_frequency']) % self.cfg['target_network_update_frequency'] == 0:

def update_target_net(self):
if self.step_count % self.cfg['target_network_update_steps'] == 0:
for i in range(self.k):
self.Q_net_target[i].load_state_dict(self.Q_net[i].state_dict())

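
These three diffs (AveragedDQN, DQN, MaxminDQN) move the target-network copy out of `learn()` into an `update_target_net()` hook driven by `step_count` and `target_network_update_steps`. The base class that calls this hook is not shown in this commit, so the following is only a sketch of the presumed pattern, with hypothetical method names:

```python
class BaseValueAgentSketch:
    '''Hypothetical sketch of the hook pattern; not the repository's actual base agent.'''
    def __init__(self):
        self.step_count = 0

    def optimize_q(self):
        pass  # placeholder for the TD/gradient step

    def update_target_net(self):
        pass  # no target network in the base agent, so nothing to copy

    def learn(self):
        self.optimize_q()
        # Subclasses (DQN, AveragedDQN, MaxminDQN, ...) override update_target_net()
        # and decide there whether step_count warrants copying weights.
        self.update_target_net()
        self.step_count += 1
```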
55 changes: 55 additions & 0 deletions agents/MeDQN_Real.py
@@ -0,0 +1,55 @@
from agents.DQN import *


class MeDQN_Real(DQN):
'''
Implementation of MeDQN_Real (Memory-efficient DQN with real state sampling)
- Consolidate knowledge from the target Q-network to the current Q-network.
- A state replay buffer is applied.
- A tiny (e.g., one mini-batch size) experience replay buffer is used in practice.
'''
def __init__(self, cfg):
# Set the consolidation batch size
if 'consod_batch_size' not in cfg['agent'].keys():
cfg['agent']['consod_batch_size'] = cfg['batch_size']
super().__init__(cfg)
self.replay = getattr(components.replay, cfg['memory_type'])(cfg['memory_size'], keys=['state', 'action', 'next_state', 'reward', 'mask'])
# Set real state sampler for knowledge consolidation
self.state_sampler = getattr(components.replay, cfg['memory_type'])(cfg['agent']['consod_size'], keys=['state'])
# Set consolidation regularization strategy
epsilon = {
'steps': float(cfg['train_steps']),
'start': cfg['agent']['consod_start'],
'end': cfg['agent']['consod_end']
}
self.consolidate = getattr(components.exploration, 'LinearEpsilonGreedy')(-1, epsilon)

def save_experience(self):
super().save_experience()
self.state_sampler.add({'state': to_tensor(self.state['Train'], self.device)})

def learn(self):
mode = 'Train'
batch = self.replay.get(['state', 'action', 'reward', 'next_state', 'mask'], self.cfg['memory_size'])
q_target = self.compute_q_target(batch)
lamda = self.consolidate.get_epsilon(self.step_count) # Compute consolidation regularization parameter
for _ in range(self.cfg['agent']['consod_epoch']):
q = self.compute_q(batch)
sample_state = self.state_sampler.sample(['state'], self.cfg['agent']['consod_batch_size']).state
# Compute loss
loss = self.loss(q, q_target)
loss += lamda * self.consolidation_loss(sample_state)
# Take an optimization step
self.optimizer[0].zero_grad()
loss.backward()
if self.gradient_clip > 0:
nn.utils.clip_grad_norm_(self.Q_net[0].parameters(), self.gradient_clip)
self.optimizer[0].step()
if self.show_tb:
self.logger.add_scalar(f'Loss', loss.item(), self.step_count)

def consolidation_loss(self, state):
q_values = self.Q_net[0](state).squeeze()
q_target_values = self.Q_net_target[0](state).squeeze().detach()
loss = nn.MSELoss(reduction='mean')(q_values, q_target_values)
return loss
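
Read directly from `learn` and `consolidation_loss` above, the objective minimized in each of the `consod_epoch` inner updates is the TD loss on the (tiny) replay batch plus a consolidation term on real states drawn from the state sampler, weighted by the linearly annealed `lamda`:

```math
\mathcal{L}(\theta) = \mathrm{MSE}\big(Q_\theta(s, a),\, y\big) + \lambda \cdot \mathrm{MSE}\big(Q_\theta(\tilde{s}, \cdot),\, Q_{\bar{\theta}}(\tilde{s}, \cdot)\big)
```

Here `y` is the output of `compute_q_target`, computed once and held fixed across the inner updates, and $\bar{\theta}$ denotes the frozen target-network parameters, so the second term pulls the current network toward the knowledge stored in the target network without needing a large replay buffer.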
66 changes: 66 additions & 0 deletions agents/MeDQN_Uniform.py
@@ -0,0 +1,66 @@
from agents.DQN import *


class MeDQN_Uniform(DQN):
'''
Implementation of MeDQN_Uniform (Memory-efficient DQN with uniform state sampling)
- Consolidate knowledge from the target Q-network to the current Q-network.
- The bounds of state space are updated with real states frequently.
- A tiny (e.g., one mini-batch size) experience replay buffer is used in practice.
'''
def __init__(self, cfg):
# Set the consolidation batch size
if 'consod_batch_size' not in cfg['agent'].keys():
cfg['agent']['consod_batch_size'] = cfg['batch_size']
super().__init__(cfg)
self.replay = getattr(components.replay, cfg['memory_type'])(cfg['memory_size'], keys=['state', 'action', 'next_state', 'reward', 'mask'])
# Set uniform state sampler for knowledge consolidation
if 'MinAtar' in self.env_name:
self.state_sampler = DiscreteUniformSampler(
shape=self.env['Train'].observation_space.shape,
normalizer=self.state_normalizer,
device=self.device
)
else:
self.state_sampler = ContinousUniformSampler(
shape=self.env['Train'].observation_space.shape,
normalizer=self.state_normalizer,
device=self.device
)
# Set consolidation regularization strategy
epsilon = {
'steps': float(cfg['train_steps']),
'start': cfg['agent']['consod_start'],
'end': cfg['agent']['consod_end']
}
self.consolidate = getattr(components.exploration, 'LinearEpsilonGreedy')(-1, epsilon)

def save_experience(self):
super().save_experience()
self.state_sampler.update_bound(self.original_state)

def learn(self):
mode = 'Train'
batch = self.replay.get(['state', 'action', 'reward', 'next_state', 'mask'], self.cfg['memory_size'])
q_target = self.compute_q_target(batch)
lamda = self.consolidate.get_epsilon(self.step_count) # Compute consolidation regularization parameter
for _ in range(self.cfg['agent']['consod_epoch']):
q = self.compute_q(batch)
sample_state = self.state_sampler.sample(self.cfg['agent']['consod_batch_size'])
# Compute loss
loss = self.loss(q, q_target)
loss += lamda * self.consolidation_loss(sample_state)
# Take an optimization step
self.optimizer[0].zero_grad()
loss.backward()
if self.gradient_clip > 0:
nn.utils.clip_grad_norm_(self.Q_net[0].parameters(), self.gradient_clip)
self.optimizer[0].step()
if self.show_tb:
self.logger.add_scalar(f'Loss', loss.item(), self.step_count)

def consolidation_loss(self, state):
q_values = self.Q_net[0](state).squeeze()
q_target_values = self.Q_net_target[0](state).squeeze().detach()
loss = nn.MSELoss(reduction='mean')(q_values, q_target_values)
return loss
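
MeDQN_Uniform samples pseudo-states uniformly inside bounds tracked from the real states seen so far (`update_bound`), via `DiscreteUniformSampler`/`ContinousUniformSampler`, whose implementations are not part of this diff. Below is a minimal sketch of what such a continuous sampler could look like; the names and details here are assumptions, not the repository's code.

```python
import torch

class UniformStateSamplerSketch:
    '''Hypothetical sketch only; the repository's ContinousUniformSampler may differ.'''
    def __init__(self, shape, device='cpu'):
        self.low = torch.full(shape, float('inf'), device=device)
        self.high = torch.full(shape, float('-inf'), device=device)

    def update_bound(self, state):
        # Track element-wise min/max of the real states observed so far.
        state = torch.as_tensor(state, dtype=torch.float32, device=self.low.device)
        self.low = torch.minimum(self.low, state)
        self.high = torch.maximum(self.high, state)

    def sample(self, batch_size):
        # Draw pseudo-states uniformly within the observed bounds
        # (call update_bound at least once before sampling).
        u = torch.rand((batch_size,) + tuple(self.low.shape), device=self.low.device)
        return self.low + u * (self.high - self.low)
```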