Skip to content

Commit 615fbdd

Browse files
authored
Factorized convolution: fix dilation
1 parent 1668c75 commit 615fbdd

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

tltorch/factorized_layers/factorized_convolution.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,7 @@ def from_factorization(cls, factorization, implementation='factorized',
258258
order = len(kernel_size)
259259

260260
instance = cls(in_channels, out_channels, kernel_size, order=order, implementation=implementation,
261-
padding=padding, stride=stride, bias=(bias is not None), n_layers=n_layers,
261+
padding=padding, stride=stride, bias=(bias is not None), n_layers=n_layers, dilation=dilation,
262262
factorization=factorization, rank=factorization.rank)
263263

264264
instance.weight = factorization
@@ -296,10 +296,12 @@ def from_conv(cls, conv_layer, rank='same', implementation='reconstructed', fact
296296
out_channels, in_channels, *kernel_size = conv_layer.weight.shape
297297
stride = conv_layer.stride[0]
298298
bias = conv_layer.bias is not None
299+
dilation = conv_layer.dilation
299300

300301
instance = cls(in_channels, out_channels, kernel_size,
301302
factorization=factorization, implementation=implementation, rank=rank,
302-
padding=padding, stride=stride, fixed_rank_modes=fixed_rank_modes, bias=bias, **kwargs)
303+
dilation=dilation, padding=padding, stride=stride, bias=bias,
304+
fixed_rank_modes=fixed_rank_modes, **kwargs)
303305

304306
if decompose_weights:
305307
if conv_layer.bias is not None:
@@ -321,9 +323,10 @@ def from_conv_list(cls, conv_list, rank='same', implementation='reconstructed',
321323
out_channels, in_channels, *kernel_size = conv_layer.weight.shape
322324
stride = conv_layer.stride[0]
323325
bias = True
326+
dilation = conv_layer.dilation
324327

325328
instance = cls(in_channels, out_channels, kernel_size, implementation=implementation, rank=rank, factorization=factorization,
326-
padding=padding, stride=stride, bias=bias, n_layers=len(conv_list), fixed_rank_modes=None, **kwargs)
329+
padding=padding, stride=stride, bias=bias, dilation=dilation, n_layers=len(conv_list), fixed_rank_modes=None, **kwargs)
327330

328331
if decompose_weights:
329332
with torch.no_grad():

0 commit comments

Comments
 (0)