Skip to content

Commit 44212a0

Browse files
authored
Remove periods from parameter names
See pytorch/pytorch#6941 for this change to pytorch
1 parent c979995 commit 44212a0

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

pyvarinf/vi.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -255,10 +255,10 @@ def _variationalize_module(self, dico, module, prefix, zero_mean,
255255
None)
256256

257257
if learn_mean:
258-
self.register_parameter(prefix + '.' + name + '_mean',
258+
self.register_parameter(prefix + '_' + name + '_mean',
259259
dico[name].mean)
260260
if learn_rho:
261-
self.register_parameter(prefix + '.' + name + '_rho',
261+
self.register_parameter(prefix + '_' + name + '_rho',
262262
dico[name].rho)
263263

264264
to_erase.append(name)
@@ -269,7 +269,7 @@ def _variationalize_module(self, dico, module, prefix, zero_mean,
269269
for mname, sub_module in module.named_children():
270270
sub_dico = OrderedDict()
271271
self._variationalize_module(sub_dico, sub_module,
272-
prefix + ('.' if prefix else '') +
272+
prefix + ('_' if prefix else '') +
273273
mname, zero_mean,
274274
learn_mean, learn_rho)
275275
dico[mname] = sub_dico

0 commit comments

Comments
 (0)