Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Noisy DQN #2152

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 31 additions & 9 deletions examples/torch/dqn_atari.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

Here it creates a gym environment CartPole, and trains a DQN with 50k steps.
"""
import math

import click
import gym
import numpy as np
Expand All @@ -29,7 +31,7 @@
from garage.torch.q_functions import DiscreteCNNQFunction
from garage.trainer import Trainer

hyperparams = dict(n_epochs=500,
hyperparams = dict(n_epochs=1000,
steps_per_epoch=20,
sampler_batch_size=500,
lr=1e-4,
Expand All @@ -39,6 +41,10 @@
target_update_freq=2,
buffer_batch_size=32,
max_epsilon=1.0,
double=True,
dueling=False,
noisy=True,
noisy_sigma=0.5,
min_epsilon=0.01,
decay_ratio=0.1,
buffer_size=int(1e4),
Expand All @@ -54,11 +60,13 @@
@click.option('--seed', default=24)
@click.option('--n', type=int, default=psutil.cpu_count(logical=False))
@click.option('--buffer_size', type=int, default=None)
@click.option('--n_steps', type=float, default=None)
@click.option('--max_episode_length', type=int, default=None)
def main(env=None,
seed=24,
n=psutil.cpu_count(logical=False),
buffer_size=None,
n_steps=None,
max_episode_length=None):
"""Wrapper to setup the logging directory.

Expand All @@ -73,6 +81,9 @@ def main(env=None,
buffer_size (int): size of the replay buffer in transitions. If None,
defaults to hyperparams['buffer_size']. This is used by the
integration tests.
n_steps (float): Total number of environment steps to run for, not
not including evaluation. If this is not None, n_epochs will
be recalculated based on this value.
max_episode_length (int): Max length of an episode. If None, defaults
to the timelimit specific to the environment. Used by integration
tests.
Expand All @@ -81,6 +92,10 @@ def main(env=None,
env += 'NoFrameskip-v4'
logdir = 'data/local/experiment/' + env

if n_steps is not None:
hyperparams['n_epochs'] = math.ceil(
int(n_steps) / (hyperparams['steps_per_epoch'] *
hyperparams['sampler_batch_size']))
if buffer_size is not None:
hyperparams['buffer_size'] = buffer_size

Expand All @@ -93,7 +108,7 @@ def main(env=None,


# pylint: disable=unused-argument
@wrap_experiment(snapshot_mode='gap_overwrite', snapshot_gap=30)
@wrap_experiment(snapshot_mode='gap_overwrite', snapshot_gap=50)
def dqn_atari(ctxt=None,
env=None,
seed=24,
Expand Down Expand Up @@ -147,19 +162,25 @@ def dqn_atari(ctxt=None,
hidden_channels=hyperparams['hidden_channels'],
kernel_sizes=hyperparams['kernel_sizes'],
strides=hyperparams['strides'],
dueling=hyperparams['dueling'],
noisy=hyperparams['noisy'],
noisy_sigma=hyperparams['noisy_sigma'],
hidden_w_init=(
lambda x: torch.nn.init.orthogonal_(x, gain=np.sqrt(2))),
hidden_sizes=hyperparams['hidden_sizes'],
is_image=True)

policy = DiscreteQFArgmaxPolicy(env_spec=env.spec, qf=qf)
exploration_policy = EpsilonGreedyPolicy(
env_spec=env.spec,
policy=policy,
total_timesteps=num_timesteps,
max_epsilon=hyperparams['max_epsilon'],
min_epsilon=hyperparams['min_epsilon'],
decay_ratio=hyperparams['decay_ratio'])

exploration_policy = policy
if not hyperparams['noisy']:
exploration_policy = EpsilonGreedyPolicy(
env_spec=env.spec,
policy=policy,
total_timesteps=num_timesteps,
max_epsilon=hyperparams['max_epsilon'],
min_epsilon=hyperparams['min_epsilon'],
decay_ratio=hyperparams['decay_ratio'])

algo = DQN(env_spec=env.spec,
policy=policy,
Expand All @@ -168,6 +189,7 @@ def dqn_atari(ctxt=None,
replay_buffer=replay_buffer,
steps_per_epoch=steps_per_epoch,
qf_lr=hyperparams['lr'],
double_q=hyperparams['double'],
clip_gradient=hyperparams['clip_gradient'],
discount=hyperparams['discount'],
min_buffer_size=hyperparams['min_buffer_size'],
Expand Down
23 changes: 19 additions & 4 deletions src/garage/torch/algos/dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ class DQN(RLAlgorithm):
n_train_steps (int): Training steps.
eval_env (Environment): Evaluation environment. If None, a copy of the
main environment is used for evaluation.
double_q (bool): Whether to use Double DQN.
See https://arxiv.org/abs/1509.06461.
max_episode_length_eval (int or None): Maximum length of episodes used
for off-policy evaluation. If `None`, defaults to
`env_spec.max_episode_length`.
Expand Down Expand Up @@ -67,6 +69,7 @@ def __init__(
replay_buffer,
exploration_policy=None,
eval_env=None,
double_q=True,
qf_optimizer=torch.optim.Adam,
*, # Everything after this is numbers.
steps_per_epoch=20,
Expand Down Expand Up @@ -100,6 +103,7 @@ def __init__(
self._steps_per_epoch = steps_per_epoch
self._n_train_steps = n_train_steps
self._buffer_batch_size = buffer_batch_size
self._double_q = double_q
self._discount = discount
self._reward_scale = reward_scale
self.max_episode_length = env_spec.max_episode_length
Expand Down Expand Up @@ -223,6 +227,9 @@ def _log_eval_results(self, epoch):
tabular.record('QFunction/MaxY', np.max(self._epoch_ys))
tabular.record('QFunction/AverageAbsY',
np.mean(np.abs(self._epoch_ys)))
# log noise levels if using a NoisyNet.
if hasattr(self._qf, 'log_noise'):
self._qf.log_noise('QFunction/Noisy-Sigma')

def _optimize_qf(self, timesteps):
"""Perform algorithm optimizing.
Expand All @@ -246,10 +253,18 @@ def _optimize_qf(self, timesteps):
next_inputs = next_observations
inputs = observations
with torch.no_grad():
# discrete, outputs Qs for all possible actions
target_qvals = self._target_qf(next_inputs)
best_qvals, _ = torch.max(target_qvals, 1)
best_qvals = best_qvals.unsqueeze(1)
if self._double_q:
# Use online qf to get optimal actions
selected_actions = torch.argmax(self._qf(next_inputs), axis=1)
# use target qf to get Q values for those actions
selected_actions = selected_actions.long().unsqueeze(1)
best_qvals = torch.gather(self._target_qf(next_inputs),
dim=1,
index=selected_actions)
else:
target_qvals = self._target_qf(next_inputs)
best_qvals, _ = torch.max(target_qvals, 1)
best_qvals = best_qvals.unsqueeze(1)

rewards_clipped = rewards
if self._clip_reward is not None:
Expand Down
2 changes: 2 additions & 0 deletions src/garage/torch/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from garage.torch.modules.gaussian_mlp_module import GaussianMLPModule
from garage.torch.modules.mlp_module import MLPModule
from garage.torch.modules.multi_headed_mlp_module import MultiHeadedMLPModule
from garage.torch.modules.noisy_mlp_module import NoisyMLPModule
# DiscreteCNNModule must go after MLPModule
from garage.torch.modules.discrete_cnn_module import DiscreteCNNModule
# yapf: enable
Expand All @@ -20,6 +21,7 @@
'DiscreteCNNModule',
'MLPModule',
'MultiHeadedMLPModule',
'NoisyMLPModule',
'GaussianMLPModule',
'GaussianMLPIndependentStdModule',
'GaussianMLPTwoHeadedModule',
Expand Down
141 changes: 126 additions & 15 deletions src/garage/torch/modules/discrete_cnn_module.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
"""Discrete CNN Q Function."""
from dowel import tabular
import torch
from torch import nn

from garage.torch.modules import CNNModule, MLPModule
from garage.torch.modules import CNNModule, MLPModule, NoisyMLPModule


# pytorch v1.6 issue, see https://github.com/pytorch/pytorch/issues/42305
Expand Down Expand Up @@ -31,6 +32,15 @@ class DiscreteCNNModule(nn.Module):
hidden_sizes (list[int]): Output dimension of dense layer(s) for
the MLP for mean. For example, (32, 32) means the MLP consists
of two hidden layers, each with 32 hidden units.
dueling (bool): Whether to use a dueling architecture for the
fully-connected layer.
noisy (bool): Whether to use parameter noise for the fully-connected
layers. If True, hidden_w_init, hidden_b_init, output_w_init, and
output_b_init are ignored.
noisy_sigma (float): Level of scaling to apply to the parameter noise.
This is ignored if noisy is set to False.
std_noise (float): Standard deviation of the gaussian parameters noise.
This is ignored if noisy is set to False.
mlp_hidden_nonlinearity (callable): Activation function for
intermediate dense layer(s) in the MLP. It should return
a torch.Tensor. Set it to None to maintain a linear activation.
Expand Down Expand Up @@ -73,11 +83,15 @@ def __init__(self,
hidden_channels,
strides,
hidden_sizes=(32, 32),
dueling=False,
cnn_hidden_nonlinearity=nn.ReLU,
mlp_hidden_nonlinearity=nn.ReLU,
hidden_w_init=nn.init.xavier_uniform_,
hidden_b_init=nn.init.zeros_,
paddings=0,
noisy=True,
noisy_sigma=0.5,
std_noise=1.,
padding_mode='zeros',
max_pool=False,
pool_shape=None,
Expand All @@ -90,6 +104,10 @@ def __init__(self,

super().__init__()

self._dueling = dueling
self._noisy = noisy
self._noisy_layers = None

input_var = torch.zeros(input_shape)
cnn_module = CNNModule(input_var=input_var,
kernel_sizes=kernel_sizes,
Expand All @@ -109,22 +127,90 @@ def __init__(self,
with torch.no_grad():
cnn_out = cnn_module(input_var)
flat_dim = torch.flatten(cnn_out, start_dim=1).shape[1]
mlp_module = MLPModule(flat_dim,
output_dim,
hidden_sizes,
hidden_nonlinearity=mlp_hidden_nonlinearity,
hidden_w_init=hidden_w_init,
hidden_b_init=hidden_b_init,
output_nonlinearity=output_nonlinearity,
output_w_init=output_w_init,
output_b_init=output_b_init,
layer_normalization=layer_normalization)

if mlp_hidden_nonlinearity is None:
self._module = nn.Sequential(cnn_module, nn.Flatten(), mlp_module)
if dueling:
if noisy:
self._val = NoisyMLPModule(
flat_dim,
1,
hidden_sizes,
sigma_naught=noisy_sigma,
std_noise=std_noise,
hidden_nonlinearity=mlp_hidden_nonlinearity,
output_nonlinearity=output_nonlinearity)
self._act = NoisyMLPModule(
flat_dim,
output_dim,
hidden_sizes,
sigma_naught=noisy_sigma,
std_noise=std_noise,
hidden_nonlinearity=mlp_hidden_nonlinearity,
output_nonlinearity=output_nonlinearity)
self._noisy_layers = [self._val, self._act]
else:
self._val = MLPModule(
flat_dim,
1,
hidden_sizes,
hidden_nonlinearity=mlp_hidden_nonlinearity,
hidden_w_init=hidden_w_init,
hidden_b_init=hidden_b_init,
output_nonlinearity=output_nonlinearity,
output_w_init=output_w_init,
output_b_init=output_b_init,
layer_normalization=layer_normalization)

self._act = MLPModule(
flat_dim,
output_dim,
hidden_sizes,
hidden_nonlinearity=mlp_hidden_nonlinearity,
hidden_w_init=hidden_w_init,
hidden_b_init=hidden_b_init,
output_nonlinearity=output_nonlinearity,
output_w_init=output_w_init,
output_b_init=output_b_init,
layer_normalization=layer_normalization)

if mlp_hidden_nonlinearity is None:
self._module = nn.Sequential(cnn_module, nn.Flatten())
else:
self._module = nn.Sequential(cnn_module,
mlp_hidden_nonlinearity(),
nn.Flatten())

else:
self._module = nn.Sequential(cnn_module, mlp_hidden_nonlinearity(),
nn.Flatten(), mlp_module)
mlp_module = None
if noisy:
mlp_module = NoisyMLPModule(
flat_dim,
output_dim,
hidden_sizes,
sigma_naught=noisy_sigma,
std_noise=std_noise,
hidden_nonlinearity=mlp_hidden_nonlinearity,
output_nonlinearity=output_nonlinearity)
self._noisy_layers = [mlp_module]
else:
mlp_module = MLPModule(
flat_dim,
output_dim,
hidden_sizes,
hidden_nonlinearity=mlp_hidden_nonlinearity,
hidden_w_init=hidden_w_init,
hidden_b_init=hidden_b_init,
output_nonlinearity=output_nonlinearity,
output_w_init=output_w_init,
output_b_init=output_b_init,
layer_normalization=layer_normalization)

if mlp_hidden_nonlinearity is None:
self._module = nn.Sequential(cnn_module, nn.Flatten(),
mlp_module)
else:
self._module = nn.Sequential(cnn_module,
mlp_hidden_nonlinearity(),
nn.Flatten(), mlp_module)

def forward(self, inputs):
"""Forward method.
Expand All @@ -137,4 +223,29 @@ def forward(self, inputs):
torch.Tensor: Output tensor of shape :math:`(N, output_dim)`.

"""
if self._dueling:
out = self._module(inputs)
val = self._val(out)
act = self._act(out)
act = act - act.mean(1).unsqueeze(1)
return val + act

return self._module(inputs)

def log_noise(self, key):
"""Log sigma levels for noisy layers.

Args:
key (str): Prefix to use for logging.

"""
if self._noisy:
layer_num = 0
for layer in self._noisy_layers:
for name, param in layer.named_parameters():
if name.endswith('weight_sigma'):
layer_num += 1
sigma_mean = float(
(param**2).mean().sqrt().data.cpu().numpy())
tabular.record(key + '_layer_' + str(layer_num),
sigma_mean)
4 changes: 2 additions & 2 deletions src/garage/torch/modules/multi_headed_mlp_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from garage.torch import NonLinearity


# pytorch v1.6 issue, see https://github.com/pytorch/pytorch/issues/42305
# pylint: disable=abstract-method
class MultiHeadedMLPModule(nn.Module):
"""MultiHeadedMLPModule Model.

Expand Down Expand Up @@ -71,8 +73,6 @@ def __init__(self,
output_nonlinearities = self._check_parameter_for_output_layer(
'output_nonlinearities', output_nonlinearities, n_heads)

self._layers = nn.ModuleList()

prev_size = input_dim
for size in hidden_sizes:
hidden_layers = nn.Sequential()
Expand Down
Loading