Commit a6c3438: Add Dueling DQN
1 parent 235a677

4 files changed: 134 additions & 15 deletions


examples/torch/dqn_atari.py

Lines changed: 5 additions & 1 deletion
@@ -41,6 +41,8 @@
     target_update_freq=2,
     buffer_batch_size=32,
     max_epsilon=1.0,
+    double=True,
+    dueling=True,
     min_epsilon=0.01,
     decay_ratio=0.1,
     buffer_size=int(1e4),
@@ -104,7 +106,7 @@ def main(env=None,
 
 
 # pylint: disable=unused-argument
-@wrap_experiment(snapshot_mode='gap_overwrite', snapshot_gap=30)
+@wrap_experiment(snapshot_mode='gap_overwrite', snapshot_gap=50)
 def dqn_atari(ctxt=None,
               env=None,
               seed=24,
@@ -162,6 +164,7 @@ def dqn_atari(ctxt=None,
         hidden_channels=hyperparams['hidden_channels'],
         kernel_sizes=hyperparams['kernel_sizes'],
         strides=hyperparams['strides'],
+        dueling=hyperparams['dueling'],
         hidden_w_init=(
             lambda x: torch.nn.init.orthogonal_(x, gain=np.sqrt(2))),
         hidden_sizes=hyperparams['hidden_sizes'],
@@ -183,6 +186,7 @@ def dqn_atari(ctxt=None,
         replay_buffer=replay_buffer,
         steps_per_epoch=steps_per_epoch,
         qf_lr=hyperparams['lr'],
+        double_q=hyperparams['double'],
         clip_gradient=hyperparams['clip_gradient'],
         discount=hyperparams['discount'],
         min_buffer_size=hyperparams['min_buffer_size'],
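
The two new hyperparameters act at different levels: `dueling=True` changes the Q-network architecture (see the module changes below), while `double=True` is forwarded to the algorithm as `double_q` and switches the bootstrap target to Double DQN. As a rough illustration of what `double_q` toggles, a minimal sketch in plain PyTorch (not garage's actual implementation; all names here are illustrative):

    import torch

    def td_target(reward, next_obs, terminal, qf, target_qf, discount, double_q):
        """Sketch of the (Double) DQN bootstrap target, for illustration only."""
        with torch.no_grad():
            if double_q:
                # Double DQN: the online network selects the action,
                # the target network evaluates it.
                best = qf(next_obs).argmax(dim=1, keepdim=True)
                next_q = target_qf(next_obs).gather(1, best).squeeze(1)
            else:
                # Vanilla DQN: the target network both selects and evaluates.
                next_q = target_qf(next_obs).max(dim=1).values
            return reward + discount * (1. - terminal) * next_q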

src/garage/torch/modules/discrete_cnn_module.py

Lines changed: 58 additions & 14 deletions
@@ -31,6 +31,8 @@ class DiscreteCNNModule(nn.Module):
         hidden_sizes (list[int]): Output dimension of dense layer(s) for
             the MLP for mean. For example, (32, 32) means the MLP consists
             of two hidden layers, each with 32 hidden units.
+        dueling (bool): Whether to use a dueling architecture for the
+            fully-connected layer.
         mlp_hidden_nonlinearity (callable): Activation function for
             intermediate dense layer(s) in the MLP. It should return
             a torch.Tensor. Set it to None to maintain a linear activation.
@@ -73,6 +75,7 @@ def __init__(self,
                  hidden_channels,
                  strides,
                  hidden_sizes=(32, 32),
+                 dueling=False,
                  cnn_hidden_nonlinearity=nn.ReLU,
                  mlp_hidden_nonlinearity=nn.ReLU,
                  hidden_w_init=nn.init.xavier_uniform_,
@@ -90,6 +93,8 @@ def __init__(self,
 
         super().__init__()
 
+        self._dueling = dueling
+
         input_var = torch.zeros(input_shape)
         cnn_module = CNNModule(input_var=input_var,
                                kernel_sizes=kernel_sizes,
@@ -109,22 +114,54 @@ def __init__(self,
         with torch.no_grad():
             cnn_out = cnn_module(input_var)
         flat_dim = torch.flatten(cnn_out, start_dim=1).shape[1]
-        mlp_module = MLPModule(flat_dim,
-                               output_dim,
-                               hidden_sizes,
-                               hidden_nonlinearity=mlp_hidden_nonlinearity,
-                               hidden_w_init=hidden_w_init,
-                               hidden_b_init=hidden_b_init,
-                               output_nonlinearity=output_nonlinearity,
-                               output_w_init=output_w_init,
-                               output_b_init=output_b_init,
-                               layer_normalization=layer_normalization)
 
-        if mlp_hidden_nonlinearity is None:
-            self._module = nn.Sequential(cnn_module, nn.Flatten(), mlp_module)
+        if dueling:
+            self._val = MLPModule(flat_dim,
+                                  1,
+                                  hidden_sizes,
+                                  hidden_nonlinearity=mlp_hidden_nonlinearity,
+                                  hidden_w_init=hidden_w_init,
+                                  hidden_b_init=hidden_b_init,
+                                  output_nonlinearity=output_nonlinearity,
+                                  output_w_init=output_w_init,
+                                  output_b_init=output_b_init,
+                                  layer_normalization=layer_normalization)
+            self._act = MLPModule(flat_dim,
+                                  output_dim,
+                                  hidden_sizes,
+                                  hidden_nonlinearity=mlp_hidden_nonlinearity,
+                                  hidden_w_init=hidden_w_init,
+                                  hidden_b_init=hidden_b_init,
+                                  output_nonlinearity=output_nonlinearity,
+                                  output_w_init=output_w_init,
+                                  output_b_init=output_b_init,
+                                  layer_normalization=layer_normalization)
+            if mlp_hidden_nonlinearity is None:
+                self._module = nn.Sequential(cnn_module, nn.Flatten())
+            else:
+                self._module = nn.Sequential(cnn_module,
+                                             mlp_hidden_nonlinearity(),
+                                             nn.Flatten())
+
         else:
-            self._module = nn.Sequential(cnn_module, mlp_hidden_nonlinearity(),
-                                         nn.Flatten(), mlp_module)
+            mlp_module = MLPModule(flat_dim,
+                                   output_dim,
+                                   hidden_sizes,
+                                   hidden_nonlinearity=mlp_hidden_nonlinearity,
+                                   hidden_w_init=hidden_w_init,
+                                   hidden_b_init=hidden_b_init,
+                                   output_nonlinearity=output_nonlinearity,
+                                   output_w_init=output_w_init,
+                                   output_b_init=output_b_init,
+                                   layer_normalization=layer_normalization)
+
+            if mlp_hidden_nonlinearity is None:
+                self._module = nn.Sequential(cnn_module, nn.Flatten(),
+                                             mlp_module)
+            else:
+                self._module = nn.Sequential(cnn_module,
+                                             mlp_hidden_nonlinearity(),
+                                             nn.Flatten(), mlp_module)
 
     def forward(self, inputs):
         """Forward method.
@@ -137,4 +174,11 @@ def forward(self, inputs):
             torch.Tensor: Output tensor of shape :math:`(N, output_dim)`.
 
         """
+        if self._dueling:
+            out = self._module(inputs)
+            val = self._val(out)
+            act = self._act(out)
+            act = act - act.mean(1).unsqueeze(1)
+            return val + act
+
         return self._module(inputs)
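
The new forward pass implements the dueling aggregation of Wang et al. (2016): a shared CNN trunk feeds two MLP heads, a scalar state value V(s) (`self._val`) and per-action advantages A(s, a) (`self._act`), combined as Q(s, a) = V(s) + (A(s, a) - mean over a' of A(s, a')). Subtracting the mean advantage keeps the decomposition identifiable: without it, any constant could be shifted between V and A while leaving Q unchanged. A self-contained sketch of just the aggregation step (shapes and layer sizes are illustrative, not from the commit):

    import torch
    import torch.nn as nn

    batch, feat_dim, n_actions = 4, 128, 6
    features = torch.randn(batch, feat_dim)    # stands in for flattened CNN output

    val_head = nn.Linear(feat_dim, 1)          # V(s): one scalar per state
    adv_head = nn.Linear(feat_dim, n_actions)  # A(s, a): one value per action

    val = val_head(features)                   # shape (batch, 1)
    adv = adv_head(features)                   # shape (batch, n_actions)

    # `adv.mean(1).unsqueeze(1)` in the module is equivalent to keepdim=True here;
    # V(s) broadcasts across the action dimension.
    q = val + (adv - adv.mean(dim=1, keepdim=True))
    assert q.shape == (batch, n_actions)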

src/garage/torch/q_functions/discrete_cnn_q_function.py

Lines changed: 4 additions & 0 deletions
@@ -27,6 +27,8 @@ class DiscreteCNNQFunction(DiscreteCNNModule):
             For example, (3, 32) means there are two convolutional layers.
             The filter for the first conv layer outputs 3 channels
             and the second one outputs 32 channels.
+        dueling (bool): Whether to use a dueling architecture for the
+            fully-connected layer.
         hidden_sizes (list[int]): Output dimension of dense layer(s) for
             the MLP for mean. For example, (32, 32) means the MLP consists
             of two hidden layers, each with 32 hidden units.
@@ -70,6 +72,7 @@ def __init__(self,
                  kernel_sizes,
                  hidden_channels,
                  strides,
+                 dueling=False,
                  hidden_sizes=(32, 32),
                  cnn_hidden_nonlinearity=torch.nn.ReLU,
                  mlp_hidden_nonlinearity=torch.nn.ReLU,
@@ -94,6 +97,7 @@ def __init__(self,
                          kernel_sizes=kernel_sizes,
                          strides=strides,
                          hidden_sizes=hidden_sizes,
+                         dueling=dueling,
                          hidden_channels=hidden_channels,
                          cnn_hidden_nonlinearity=cnn_hidden_nonlinearity,
                          mlp_hidden_nonlinearity=mlp_hidden_nonlinearity,
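
With the flag threaded through `DiscreteCNNQFunction` into `DiscreteCNNModule`, enabling the dueling head from user code is a one-argument change. A construction sketch modeled on the dqn_atari.py changes above (the layer sizes and `env_spec` are placeholders, not values from this commit):

    qf = DiscreteCNNQFunction(env_spec=env.spec,
                              hidden_channels=(32, 64, 64),
                              kernel_sizes=(8, 4, 3),
                              strides=(4, 2, 1),
                              hidden_sizes=(512, ),
                              dueling=True)  # new in this commit; defaults to False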

tests/garage/torch/modules/test_discrete_cnn_module.py

Lines changed: 67 additions & 0 deletions
@@ -65,6 +65,73 @@ def test_output_values(output_dim, kernel_sizes, hidden_channels, strides,
     assert torch.all(torch.eq(output.detach(), module(obs).detach()))
 
 
+@pytest.mark.parametrize(
+    'output_dim, kernel_sizes, hidden_channels, strides, paddings', [
+        (1, (1, ), (32, ), (1, ), (0, )),
+        (2, (3, ), (32, ), (1, ), (0, )),
+        (5, (3, ), (32, ), (2, ), (0, )),
+        (5, (5, ), (12, ), (1, ), (2, )),
+        (5, (1, 1), (32, 64), (1, 1), (0, 0)),
+        (10, (3, 3), (32, 64), (1, 1), (0, 0)),
+        (10, (3, 3), (32, 64), (2, 2), (0, 0)),
+    ])
+def test_dueling_output_values(output_dim, kernel_sizes, hidden_channels,
+                               strides, paddings):
+
+    batch_size = 64
+    input_width = 32
+    input_height = 32
+    in_channel = 3
+    input_shape = (batch_size, in_channel, input_height, input_width)
+    obs = torch.rand(input_shape)
+
+    module = DiscreteCNNModule(input_shape=input_shape,
+                               output_dim=output_dim,
+                               hidden_channels=hidden_channels,
+                               hidden_sizes=hidden_channels,
+                               kernel_sizes=kernel_sizes,
+                               strides=strides,
+                               paddings=paddings,
+                               padding_mode='zeros',
+                               dueling=True,
+                               hidden_w_init=nn.init.ones_,
+                               output_w_init=nn.init.ones_,
+                               is_image=False)
+
+    cnn = CNNModule(input_var=obs,
+                    hidden_channels=hidden_channels,
+                    kernel_sizes=kernel_sizes,
+                    strides=strides,
+                    paddings=paddings,
+                    padding_mode='zeros',
+                    hidden_w_init=nn.init.ones_,
+                    is_image=False)
+    flat_dim = torch.flatten(cnn(obs).detach(), start_dim=1).shape[1]
+
+    mlp_adv = MLPModule(
+        flat_dim,
+        output_dim,
+        hidden_channels,
+        hidden_w_init=nn.init.ones_,
+        output_w_init=nn.init.ones_,
+    )
+
+    mlp_val = MLPModule(
+        flat_dim,
+        1,
+        hidden_channels,
+        hidden_w_init=nn.init.ones_,
+        output_w_init=nn.init.ones_,
+    )
+
+    cnn_out = cnn(obs)
+    val = mlp_val(torch.flatten(cnn_out, start_dim=1))
+    adv = mlp_adv(torch.flatten(cnn_out, start_dim=1))
+    output = val + (adv - adv.mean(1).unsqueeze(1))
+
+    assert torch.all(torch.eq(output.detach(), module(obs).detach()))
+
+
 @pytest.mark.parametrize('output_dim, hidden_channels, kernel_sizes, strides',
                          [(1, (32, ), (1, ), (1, ))])
 def test_without_nonlinearity(output_dim, hidden_channels, kernel_sizes,
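
Note that the test pins every weight to ones via `nn.init.ones_`, so the hand-assembled CNN plus value/advantage MLPs compute exactly the same numbers as the module under test; that is what lets the assertion use exact equality (`torch.eq`) instead of an approximate comparison. To run just this test, a standard pytest invocation should work:

    pytest tests/garage/torch/modules/test_discrete_cnn_module.py -k test_dueling_output_values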
