@@ -1,8 +1,9 @@
 """Discrete CNN Q Function."""
+from dowel import tabular
 import torch
 from torch import nn
 
-from garage.torch.modules import CNNModule, MLPModule
+from garage.torch.modules import CNNModule, MLPModule, NoisyMLPModule
 
 
 # pytorch v1.6 issue, see https://github.com/pytorch/pytorch/issues/42305
@@ -33,6 +34,13 @@ class DiscreteCNNModule(nn.Module):
             of two hidden layers, each with 32 hidden units.
         dueling (bool): Whether to use a dueling architecture for the
             fully-connected layer.
+        noisy (bool): Whether to use parameter noise for the fully-connected
+            layers. If True, hidden_w_init, hidden_b_init, output_w_init, and
+            output_b_init are ignored.
+        noisy_sigma (float): Level of scaling to apply to the parameter noise.
+            This is ignored if noisy is set to False.
+        std_noise (float): Standard deviation of the Gaussian parameter noise.
+            This is ignored if noisy is set to False.
         mlp_hidden_nonlinearity (callable): Activation function for
             intermediate dense layer(s) in the MLP. It should return
             a torch.Tensor. Set it to None to maintain a linear activation.
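For readers unfamiliar with parameter noise, the sketch below illustrates the NoisyNet-style factorized Gaussian noise that the new noisy, noisy_sigma (sigma_naught), and std_noise options refer to. It is a minimal standalone example, not garage's NoisyMLPModule; the class name, initialization scheme, and shapes are assumptions for illustration only.

    # Minimal sketch of a NoisyNet-style linear layer (factorized Gaussian
    # parameter noise). NOT garage's NoisyMLPModule -- names and init scheme
    # here are illustrative assumptions.
    import math
    import torch
    from torch import nn


    class NoisyLinearSketch(nn.Module):
        """Linear layer whose weights are perturbed by learned, scaled noise."""

        def __init__(self, in_dim, out_dim, sigma_naught=0.5, std_noise=1.):
            super().__init__()
            self._std_noise = std_noise
            # Learnable means and noise scales (sigma) for weights and biases.
            self.weight_mu = nn.Parameter(torch.empty(out_dim, in_dim))
            self.weight_sigma = nn.Parameter(
                torch.full((out_dim, in_dim), sigma_naught / math.sqrt(in_dim)))
            self.bias_mu = nn.Parameter(torch.empty(out_dim))
            self.bias_sigma = nn.Parameter(
                torch.full((out_dim, ), sigma_naught / math.sqrt(in_dim)))
            bound = 1 / math.sqrt(in_dim)
            nn.init.uniform_(self.weight_mu, -bound, bound)
            nn.init.uniform_(self.bias_mu, -bound, bound)

        def _scaled_noise(self, size):
            # std_noise scales the raw Gaussian sample before the signed sqrt.
            eps = torch.randn(size) * self._std_noise
            return eps.sign() * eps.abs().sqrt()

        def forward(self, x):
            # Factorized noise: one vector per input dim, one per output dim.
            eps_in = self._scaled_noise(self.weight_mu.shape[1])
            eps_out = self._scaled_noise(self.weight_mu.shape[0])
            weight = self.weight_mu + self.weight_sigma * (
                eps_out.unsqueeze(1) * eps_in.unsqueeze(0))
            bias = self.bias_mu + self.bias_sigma * eps_out
            return nn.functional.linear(x, weight, bias)


    layer = NoisyLinearSketch(64, 4)
    print(layer(torch.zeros(1, 64)).shape)  # torch.Size([1, 4])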
@@ -81,6 +89,9 @@ def __init__(self,
                  hidden_w_init=nn.init.xavier_uniform_,
                  hidden_b_init=nn.init.zeros_,
                  paddings=0,
+                 noisy=True,
+                 noisy_sigma=0.5,
+                 std_noise=1.,
                  padding_mode='zeros',
                  max_pool=False,
                  pool_shape=None,
@@ -94,6 +105,8 @@ def __init__(self,
         super().__init__()
 
         self._dueling = dueling
+        self._noisy = noisy
+        self._noisy_layers = None
 
         input_var = torch.zeros(input_shape)
         cnn_module = CNNModule(input_var=input_var,
@@ -116,26 +129,49 @@ def __init__(self,
         flat_dim = torch.flatten(cnn_out, start_dim=1).shape[1]
 
         if dueling:
-            self._val = MLPModule(flat_dim,
-                                  1,
-                                  hidden_sizes,
-                                  hidden_nonlinearity=mlp_hidden_nonlinearity,
-                                  hidden_w_init=hidden_w_init,
-                                  hidden_b_init=hidden_b_init,
-                                  output_nonlinearity=output_nonlinearity,
-                                  output_w_init=output_w_init,
-                                  output_b_init=output_b_init,
-                                  layer_normalization=layer_normalization)
-            self._act = MLPModule(flat_dim,
-                                  output_dim,
-                                  hidden_sizes,
-                                  hidden_nonlinearity=mlp_hidden_nonlinearity,
-                                  hidden_w_init=hidden_w_init,
-                                  hidden_b_init=hidden_b_init,
-                                  output_nonlinearity=output_nonlinearity,
-                                  output_w_init=output_w_init,
-                                  output_b_init=output_b_init,
-                                  layer_normalization=layer_normalization)
+            if noisy:
+                self._val = NoisyMLPModule(
+                    flat_dim,
+                    1,
+                    hidden_sizes,
+                    sigma_naught=noisy_sigma,
+                    std_noise=std_noise,
+                    hidden_nonlinearity=mlp_hidden_nonlinearity,
+                    output_nonlinearity=output_nonlinearity)
+                self._act = NoisyMLPModule(
+                    flat_dim,
+                    output_dim,
+                    hidden_sizes,
+                    sigma_naught=noisy_sigma,
+                    std_noise=std_noise,
+                    hidden_nonlinearity=mlp_hidden_nonlinearity,
+                    output_nonlinearity=output_nonlinearity)
+                self._noisy_layers = [self._val, self._act]
+            else:
+                self._val = MLPModule(
+                    flat_dim,
+                    1,
+                    hidden_sizes,
+                    hidden_nonlinearity=mlp_hidden_nonlinearity,
+                    hidden_w_init=hidden_w_init,
+                    hidden_b_init=hidden_b_init,
+                    output_nonlinearity=output_nonlinearity,
+                    output_w_init=output_w_init,
+                    output_b_init=output_b_init,
+                    layer_normalization=layer_normalization)
+
+                self._act = MLPModule(
+                    flat_dim,
+                    output_dim,
+                    hidden_sizes,
+                    hidden_nonlinearity=mlp_hidden_nonlinearity,
+                    hidden_w_init=hidden_w_init,
+                    hidden_b_init=hidden_b_init,
+                    output_nonlinearity=output_nonlinearity,
+                    output_w_init=output_w_init,
+                    output_b_init=output_b_init,
+                    layer_normalization=layer_normalization)
+
             if mlp_hidden_nonlinearity is None:
                 self._module = nn.Sequential(cnn_module, nn.Flatten())
             else:
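As a quick sanity check on the dueling heads built above: self._val produces a (N, 1) state-value stream and self._act a (N, output_dim) per-action stream, and the forward pass shown in the last hunk combines them by broadcasting. A short sketch with placeholder tensors:

    # Placeholder tensors standing in for the outputs of self._val and self._act.
    import torch

    val = torch.zeros(8, 1)   # state-value stream, one scalar per sample
    act = torch.zeros(8, 4)   # per-action stream, output_dim = 4 here
    q_values = val + act      # broadcasts to (8, 4), as in `return val + act`
    print(q_values.shape)     # torch.Size([8, 4])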
@@ -144,16 +180,29 @@ def __init__(self,
                                              nn.Flatten())
 
         else:
-            mlp_module = MLPModule(flat_dim,
-                                   output_dim,
-                                   hidden_sizes,
-                                   hidden_nonlinearity=mlp_hidden_nonlinearity,
-                                   hidden_w_init=hidden_w_init,
-                                   hidden_b_init=hidden_b_init,
-                                   output_nonlinearity=output_nonlinearity,
-                                   output_w_init=output_w_init,
-                                   output_b_init=output_b_init,
-                                   layer_normalization=layer_normalization)
+            mlp_module = None
+            if noisy:
+                mlp_module = NoisyMLPModule(
+                    flat_dim,
+                    output_dim,
+                    hidden_sizes,
+                    sigma_naught=noisy_sigma,
+                    std_noise=std_noise,
+                    hidden_nonlinearity=mlp_hidden_nonlinearity,
+                    output_nonlinearity=output_nonlinearity)
+                self._noisy_layers = [mlp_module]
+            else:
+                mlp_module = MLPModule(
+                    flat_dim,
+                    output_dim,
+                    hidden_sizes,
+                    hidden_nonlinearity=mlp_hidden_nonlinearity,
+                    hidden_w_init=hidden_w_init,
+                    hidden_b_init=hidden_b_init,
+                    output_nonlinearity=output_nonlinearity,
+                    output_w_init=output_w_init,
+                    output_b_init=output_b_init,
+                    layer_normalization=layer_normalization)
 
             if mlp_hidden_nonlinearity is None:
                 self._module = nn.Sequential(cnn_module, nn.Flatten(),
@@ -182,3 +231,21 @@ def forward(self, inputs):
             return val + act
 
         return self._module(inputs)
+
+    def log_noise(self, key):
+        """Log sigma levels for noisy layers.
+
+        Args:
+            key (str): Prefix to use for logging.
+
+        """
+        if self._noisy:
+            layer_num = 0
+            for layer in self._noisy_layers:
+                for name, param in layer.named_parameters():
+                    if name.endswith('weight_sigma'):
+                        layer_num += 1
+                        sigma_mean = float(
+                            (param**2).mean().sqrt().data.cpu().numpy())
+                        tabular.record(key + '_layer_' + str(layer_num),
+                                       sigma_mean)
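The new log_noise hook writes into dowel's tabular logger (hence the added `from dowel import tabular`). The snippet below is a rough sketch of where those records end up; the key name is illustrative, and the logger setup follows dowel's usual logger/StdOutput pattern rather than anything added in this diff.

    # Rough sketch of the dowel pipeline that log_noise feeds.
    # 'QF/Sigma' is an illustrative key; log_noise(key) records one
    # '<key>_layer_<n>' entry per noisy layer via tabular.record.
    from dowel import StdOutput, logger, tabular

    logger.add_output(StdOutput())
    tabular.record('QF/Sigma_layer_1', 0.42)  # what a log_noise call would emit
    logger.log(tabular)
    logger.dump_all()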