@@ -258,7 +258,7 @@ def from_factorization(cls, factorization, implementation='factorized',
258
258
order = len (kernel_size )
259
259
260
260
instance = cls (in_channels , out_channels , kernel_size , order = order , implementation = implementation ,
261
- padding = padding , stride = stride , bias = (bias is not None ), n_layers = n_layers ,
261
+ padding = padding , stride = stride , bias = (bias is not None ), n_layers = n_layers , dilation = dilation ,
262
262
factorization = factorization , rank = factorization .rank )
263
263
264
264
instance .weight = factorization
@@ -296,10 +296,12 @@ def from_conv(cls, conv_layer, rank='same', implementation='reconstructed', fact
296
296
out_channels , in_channels , * kernel_size = conv_layer .weight .shape
297
297
stride = conv_layer .stride [0 ]
298
298
bias = conv_layer .bias is not None
299
+ dilation = conv_layer .dilation
299
300
300
301
instance = cls (in_channels , out_channels , kernel_size ,
301
302
factorization = factorization , implementation = implementation , rank = rank ,
302
- padding = padding , stride = stride , fixed_rank_modes = fixed_rank_modes , bias = bias , ** kwargs )
303
+ dilation = dilation , padding = padding , stride = stride , bias = bias ,
304
+ fixed_rank_modes = fixed_rank_modes , ** kwargs )
303
305
304
306
if decompose_weights :
305
307
if conv_layer .bias is not None :
@@ -321,9 +323,10 @@ def from_conv_list(cls, conv_list, rank='same', implementation='reconstructed',
321
323
out_channels , in_channels , * kernel_size = conv_layer .weight .shape
322
324
stride = conv_layer .stride [0 ]
323
325
bias = True
326
+ dilation = conv_layer .dilation
324
327
325
328
instance = cls (in_channels , out_channels , kernel_size , implementation = implementation , rank = rank , factorization = factorization ,
326
- padding = padding , stride = stride , bias = bias , n_layers = len (conv_list ), fixed_rank_modes = None , ** kwargs )
329
+ padding = padding , stride = stride , bias = bias , dilation = dilation , n_layers = len (conv_list ), fixed_rank_modes = None , ** kwargs )
327
330
328
331
if decompose_weights :
329
332
with torch .no_grad ():
0 commit comments