Skip to content

Commit 334978e

Browse files
committedAug 23, 2022
✨ Add init weights for AMCNN
1 parent 354eb8e commit 334978e

File tree

1 file changed

+9
-5
lines changed

1 file changed

+9
-5
lines changed
 

‎utils.py

+9-5
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import torch.nn as nn
22
import numpy as np
33

4+
from models import AMCNNAttention
5+
46

57
def init_weights(mat):
68
for m in mat.modules():
@@ -16,11 +18,13 @@ def init_weights(mat):
1618
nn.init.orthogonal_(param[idx*mul:(idx+1)*mul])
1719
elif 'bias' in name:
1820
param.data.fill_(0)
19-
else:
20-
if type(m) in [nn.Linear]:
21-
nn.init.uniform_(m.weight, -0.01, 0.01)
22-
if m.bias != None:
23-
m.bias.data.fill_(0.01)
21+
elif type(m) in [nn.Linear]:
22+
nn.init.uniform_(m.weight, -0.01, 0.01)
23+
if m.bias != None:
24+
m.bias.data.fill_(0.01)
25+
elif type(m) in [AMCNNAttention]:
26+
for p in m.parameters():
27+
nn.init.uniform_(p, -0.01, 0.01)
2428

2529

2630
def removeObjectiveSents(docs_sents, mask, tokenized=False):

0 commit comments

Comments
 (0)
Please sign in to comment.