Add Double DQN (#2148)
maliesa96 authored Nov 3, 2020
1 parent 08987a1 commit cdda5fc
Showing 4 changed files with 77 additions and 6 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -57,7 +57,7 @@ The table below summarizes the algorithms available in garage.
| REINFORCE (a.k.a. VPG) | PyTorch, TensorFlow |
| DDPG | PyTorch, TensorFlow |
| DQN | PyTorch, TensorFlow |
-| DDQN | TensorFlow |
+| DDQN | PyTorch, TensorFlow |
| ERWR | TensorFlow |
| NPO | TensorFlow |
| PPO | PyTorch, TensorFlow |
11 changes: 11 additions & 0 deletions examples/torch/dqn_atari.py
@@ -3,6 +3,8 @@
Here it creates a gym environment CartPole, and trains a DQN with 50k steps.
"""
+import math
+
import click
import gym
import numpy as np
@@ -54,11 +56,13 @@
@click.option('--seed', default=24)
@click.option('--n', type=int, default=psutil.cpu_count(logical=False))
@click.option('--buffer_size', type=int, default=None)
+@click.option('--n_steps', type=float, default=None)
@click.option('--max_episode_length', type=int, default=None)
def main(env=None,
         seed=24,
         n=psutil.cpu_count(logical=False),
         buffer_size=None,
+         n_steps=None,
         max_episode_length=None):
    """Wrapper to setup the logging directory.
@@ -73,6 +77,9 @@ def main(env=None,
        buffer_size (int): size of the replay buffer in transitions. If None,
            defaults to hyperparams['buffer_size']. This is used by the
            integration tests.
+        n_steps (float): Total number of environment steps to run for, not
+            including evaluation. If this is not None, n_epochs will be
+            recalculated based on this value.
        max_episode_length (int): Max length of an episode. If None, defaults
            to the timelimit specific to the environment. Used by integration
            tests.
@@ -81,6 +88,10 @@ def main(env=None,
    env += 'NoFrameskip-v4'
    logdir = 'data/local/experiment/' + env

+    if n_steps is not None:
+        hyperparams['n_epochs'] = math.ceil(
+            int(n_steps) / (hyperparams['steps_per_epoch'] *
+                            hyperparams['sampler_batch_size']))
    if buffer_size is not None:
        hyperparams['buffer_size'] = buffer_size

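The recalculation added above turns a total step budget into an epoch count. As a rough worked example, the standalone snippet below mirrors the same arithmetic with made-up values for n_steps, steps_per_epoch, and sampler_batch_size; the script's real defaults live in its hyperparams dict, which is outside this diff.

import math

# Assumed values, for illustration only; the real defaults come from the
# example script's hyperparams dict and are not shown in this diff.
n_steps = 1e6
steps_per_epoch = 20
sampler_batch_size = 500

# Each epoch consumes steps_per_epoch * sampler_batch_size environment steps,
# so round up to cover at least n_steps steps in total.
n_epochs = math.ceil(int(n_steps) / (steps_per_epoch * sampler_batch_size))
print(n_epochs)  # 100
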
20 changes: 16 additions & 4 deletions src/garage/torch/algos/dqn.py
@@ -32,6 +32,8 @@ class DQN(RLAlgorithm):
        n_train_steps (int): Training steps.
        eval_env (Environment): Evaluation environment. If None, a copy of the
            main environment is used for evaluation.
+        double_q (bool): Whether to use Double DQN.
+            See https://arxiv.org/abs/1509.06461.
        max_episode_length_eval (int or None): Maximum length of episodes used
            for off-policy evaluation. If `None`, defaults to
            `env_spec.max_episode_length`.
@@ -67,6 +69,7 @@ def __init__(
            replay_buffer,
            exploration_policy=None,
            eval_env=None,
+            double_q=True,
            qf_optimizer=torch.optim.Adam,
            *,  # Everything after this is numbers.
            steps_per_epoch=20,
@@ -100,6 +103,7 @@ def __init__(
        self._steps_per_epoch = steps_per_epoch
        self._n_train_steps = n_train_steps
        self._buffer_batch_size = buffer_batch_size
+        self._double_q = double_q
        self._discount = discount
        self._reward_scale = reward_scale
        self.max_episode_length = env_spec.max_episode_length
@@ -246,10 +250,18 @@ def _optimize_qf(self, timesteps):
        next_inputs = next_observations
        inputs = observations
        with torch.no_grad():
-            # discrete, outputs Qs for all possible actions
-            target_qvals = self._target_qf(next_inputs)
-            best_qvals, _ = torch.max(target_qvals, 1)
-            best_qvals = best_qvals.unsqueeze(1)
+            if self._double_q:
+                # Use online qf to get optimal actions
+                selected_actions = torch.argmax(self._qf(next_inputs), axis=1)
+                # use target qf to get Q values for those actions
+                selected_actions = selected_actions.long().unsqueeze(1)
+                best_qvals = torch.gather(self._target_qf(next_inputs),
+                                          dim=1,
+                                          index=selected_actions)
+            else:
+                target_qvals = self._target_qf(next_inputs)
+                best_qvals, _ = torch.max(target_qvals, 1)
+                best_qvals = best_qvals.unsqueeze(1)

        rewards_clipped = rewards
        if self._clip_reward is not None:
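For readers comparing the two branches in the hunk above, here is a minimal standalone sketch of how the Double DQN target differs from the vanilla DQN target. It is not garage code: random tensors stand in for self._qf(next_inputs) and self._target_qf(next_inputs), and the batch size and action count are made up for illustration.

import torch

torch.manual_seed(0)
batch_size, n_actions = 4, 3
q_next = torch.rand(batch_size, n_actions)         # online network Q(s', a)
target_q_next = torch.rand(batch_size, n_actions)  # target network Q(s', a)

# Vanilla DQN: the target network both selects and evaluates the action.
dqn_best, _ = torch.max(target_q_next, dim=1)
dqn_best = dqn_best.unsqueeze(1)

# Double DQN: the online network selects the action, the target network
# evaluates it, which reduces the overestimation bias of the plain max.
selected = torch.argmax(q_next, dim=1, keepdim=True)
ddqn_best = torch.gather(target_q_next, dim=1, index=selected)

print(dqn_best.squeeze(1))
print(ddqn_best.squeeze(1))

This is why the double_q branch in the diff uses torch.argmax on self._qf followed by torch.gather on self._target_qf instead of a single torch.max over the target network's outputs.
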
50 changes: 49 additions & 1 deletion tests/garage/torch/algos/test_dqn.py
@@ -29,7 +29,6 @@ def setup():
    steps_per_epoch = 10
    sampler_batch_size = 512
    num_timesteps = 100 * steps_per_epoch * sampler_batch_size
-
    env = GymEnv('CartPole-v0')

    replay_buffer = PathBuffer(capacity_in_transitions=int(1e6))
@@ -50,6 +49,7 @@ def setup():
               replay_buffer=replay_buffer,
               steps_per_epoch=steps_per_epoch,
               qf_lr=5e-5,
+               double_q=False,
               discount=0.9,
               min_buffer_size=int(1e4),
               n_train_steps=500,
@@ -121,6 +121,54 @@ def test_dqn_loss(setup):
    assert (selected_qs == algo_selected_qs).all()


+def test_double_dqn_loss(setup):
+    algo, env, buff, _, batch_size = setup
+
+    algo._double_q = True
+    trainer = Trainer(snapshot_config)
+    trainer.setup(algo, env, sampler_cls=LocalSampler)
+
+    paths = trainer.obtain_episodes(0, batch_size=batch_size)
+    buff.add_episode_batch(paths)
+    timesteps = buff.sample_timesteps(algo._buffer_batch_size)
+    timesteps_copy = copy.deepcopy(timesteps)
+
+    observations = np_to_torch(timesteps.observations)
+    rewards = np_to_torch(timesteps.rewards).reshape(-1, 1)
+    actions = np_to_torch(timesteps.actions)
+    next_observations = np_to_torch(timesteps.next_observations)
+    terminals = np_to_torch(timesteps.terminals).reshape(-1, 1)
+
+    next_inputs = next_observations
+    inputs = observations
+    with torch.no_grad():
+        # double Q loss
+        selected_actions = torch.argmax(algo._qf(next_inputs), axis=1)
+        # use target qf to get Q values for those actions
+        selected_actions = selected_actions.long().unsqueeze(1)
+        best_qvals = torch.gather(algo._target_qf(next_inputs),
+                                  dim=1,
+                                  index=selected_actions)
+
+    rewards_clipped = rewards
+    y_target = (rewards_clipped +
+                (1.0 - terminals) * algo._discount * best_qvals)
+    y_target = y_target.squeeze(1)
+
+    # optimize qf
+    qvals = algo._qf(inputs)
+    selected_qs = torch.sum(qvals * actions, axis=1)
+    qval_loss = F.smooth_l1_loss(selected_qs, y_target)
+
+    algo_loss, algo_targets, algo_selected_qs = algo._optimize_qf(
+        timesteps_copy)
+    env.close()
+
+    assert (qval_loss.detach() == algo_loss).all()
+    assert (y_target == algo_targets).all()
+    assert (selected_qs == algo_selected_qs).all()
+
+
 def test_to_device(setup):
    algo, _, _, _, _ = setup
    algo._qf.to = MagicMock(name='to')
