We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 354eb8e commit 334978eCopy full SHA for 334978e
utils.py
@@ -1,6 +1,8 @@
1
import torch.nn as nn
2
import numpy as np
3
4
+from models import AMCNNAttention
5
+
6
7
def init_weights(mat):
8
for m in mat.modules():
@@ -16,11 +18,13 @@ def init_weights(mat):
16
18
nn.init.orthogonal_(param[idx*mul:(idx+1)*mul])
17
19
elif 'bias' in name:
20
param.data.fill_(0)
- else:
- if type(m) in [nn.Linear]:
21
- nn.init.uniform_(m.weight, -0.01, 0.01)
22
- if m.bias != None:
23
- m.bias.data.fill_(0.01)
+ elif type(m) in [nn.Linear]:
+ nn.init.uniform_(m.weight, -0.01, 0.01)
+ if m.bias != None:
24
+ m.bias.data.fill_(0.01)
25
+ elif type(m) in [AMCNNAttention]:
26
+ for p in m.parameters():
27
+ nn.init.uniform_(p, -0.01, 0.01)
28
29
30
def removeObjectiveSents(docs_sents, mask, tokenized=False):
0 commit comments