From 4baac43226a6232f41183c240c14fecf6caa4857 Mon Sep 17 00:00:00 2001 From: cijad Date: Tue, 19 Nov 2019 08:09:39 +0100 Subject: [PATCH 1/3] Updated bn_fusion to handle ConvTranspose2d layers --- bn_fusion.py | 41 +++++++++++++++++++++++++---------------- 1 file changed, 25 insertions(+), 16 deletions(-) diff --git a/bn_fusion.py b/bn_fusion.py index eef81a6..68c3b7a 100644 --- a/bn_fusion.py +++ b/bn_fusion.py @@ -14,43 +14,52 @@ def fuse_bn_sequential(block): stack = [] for m in block.children(): if isinstance(m, nn.BatchNorm2d): - if isinstance(stack[-1], nn.Conv2d): + if isinstance(stack[-1], nn.Conv2d) or isinstance( + stack[-1], nn.ConvTranspose2d + ): + # Extract params of BatchNorm and Convolution layers bn_st_dict = m.state_dict() conv_st_dict = stack[-1].state_dict() # BatchNorm params eps = m.eps - mu = bn_st_dict['running_mean'] - var = bn_st_dict['running_var'] - gamma = bn_st_dict['weight'] + mu = bn_st_dict["running_mean"] + var = bn_st_dict["running_var"] + gamma = bn_st_dict["weight"] - if 'bias' in bn_st_dict: - beta = bn_st_dict['bias'] + if "bias" in bn_st_dict: + beta = bn_st_dict["bias"] else: beta = torch.zeros(gamma.size(0)).float().to(gamma.device) # Conv params - W = conv_st_dict['weight'] - if 'bias' in conv_st_dict: - bias = conv_st_dict['bias'] + W = conv_st_dict["weight"] + + if isinstance(stack[-1], nn.ConvTranspose2d): + W = W.transpose(0, 1) + + if "bias" in conv_st_dict: + bias = conv_st_dict["bias"] else: bias = torch.zeros(W.size(0)).float().to(gamma.device) denom = torch.sqrt(var + eps) - b = beta - gamma.mul(mu).div(denom) - A = gamma.div(denom) - bias *= A - A = A.expand_as(W.transpose(0, -1)).transpose(0, -1) + b_BN = beta - gamma.mul(mu).div(denom) + W_BN = gamma.div(denom) + bias *= W_BN + W_BN = W_BN.expand_as(W.transpose(0, -1)).transpose(0, -1) - W.mul_(A) - bias.add_(b) + W.mul_(W_BN) + if isinstance(stack[-1], nn.ConvTranspose2d): + W = W.transpose(0, 1) + + bias.add_(b_BN) stack[-1].weight.data.copy_(W) if stack[-1].bias is None: stack[-1].bias = torch.nn.Parameter(bias) else: stack[-1].bias.data.copy_(bias) - else: stack.append(m) From be9ba3f052ee26440cc1cebec373d2c5ba392998 Mon Sep 17 00:00:00 2001 From: cijad Date: Tue, 19 Nov 2019 08:21:52 +0100 Subject: [PATCH 2/3] Updated bn_fusion function to handle ConvTranspose2d layers --- bn_fusion.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/bn_fusion.py b/bn_fusion.py index 68c3b7a..3bba1cc 100644 --- a/bn_fusion.py +++ b/bn_fusion.py @@ -14,10 +14,7 @@ def fuse_bn_sequential(block): stack = [] for m in block.children(): if isinstance(m, nn.BatchNorm2d): - if isinstance(stack[-1], nn.Conv2d) or isinstance( - stack[-1], nn.ConvTranspose2d - ): - # Extract params of BatchNorm and Convolution layers + if isinstance(stack[-1], nn.Conv2d) or isinstance(stack[-1], nn.ConvTranspose2d): bn_st_dict = m.state_dict() conv_st_dict = stack[-1].state_dict() @@ -34,7 +31,6 @@ def fuse_bn_sequential(block): # Conv params W = conv_st_dict["weight"] - if isinstance(stack[-1], nn.ConvTranspose2d): W = W.transpose(0, 1) From 78d15d7d7263ff7548a497006d4b34491468eb35 Mon Sep 17 00:00:00 2001 From: cijad Date: Mon, 6 Apr 2020 09:30:54 +0200 Subject: [PATCH 3/3] Misc fix --- bn_fusion.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/bn_fusion.py b/bn_fusion.py index 3bba1cc..76dda6b 100644 --- a/bn_fusion.py +++ b/bn_fusion.py @@ -20,22 +20,22 @@ def fuse_bn_sequential(block): # BatchNorm params eps = m.eps - mu = bn_st_dict["running_mean"] - var = bn_st_dict["running_var"] - gamma = bn_st_dict["weight"] + mu = bn_st_dict['running_mean'] + var = bn_st_dict['running_var'] + gamma = bn_st_dict['weight'] - if "bias" in bn_st_dict: - beta = bn_st_dict["bias"] + if 'bias' in bn_st_dict: + beta = bn_st_dict['bias'] else: beta = torch.zeros(gamma.size(0)).float().to(gamma.device) # Conv params - W = conv_st_dict["weight"] + W = conv_st_dict['weight'] if isinstance(stack[-1], nn.ConvTranspose2d): W = W.transpose(0, 1) - if "bias" in conv_st_dict: - bias = conv_st_dict["bias"] + if 'bias' in conv_st_dict: + bias = conv_st_dict['bias'] else: bias = torch.zeros(W.size(0)).float().to(gamma.device) @@ -56,6 +56,7 @@ def fuse_bn_sequential(block): stack[-1].bias = torch.nn.Parameter(bias) else: stack[-1].bias.data.copy_(bias) + else: stack.append(m)