@@ -17,7 +17,6 @@ def __init__(
17
17
super (MatchHead , self ).__init__ ()
18
18
self .GRU_1 = nn .GRU (base_model_feature_size , rnn_dimension , bidirectional = False )
19
19
self .GRU_2 = nn .GRU (base_model_feature_size , rnn_dimension , bidirectional = False )
20
- self .linear_1 = nn .Linear (additional_feature_size , linear_1_dimension )
21
20
self .linear_1 = nn .Linear (rnn_dimension * 2 + additional_feature_size , linear_1_dimension )
22
21
self .linear_2 = nn .Linear (linear_1_dimension , num_classes )
23
22
@@ -47,6 +46,7 @@ def forward(self, data_1, data_2, additional_feats):
47
46
48
47
return sigmoid_output
49
48
49
+
50
50
class MatchArchitecture (nn .Module ):
51
51
"Transformer base model for matching."
52
52
def __init__ (
@@ -108,8 +108,54 @@ def forward(
108
108
sequence_output_2 = outputs_2 [0 ]
109
109
110
110
match_classification = self .match_head (sequence_output_1 , sequence_output_2 , additional_feats )
111
- #match_classification = self.match_head(None, None, additional_feats)
112
111
113
112
return match_classification
114
113
115
114
115
+ class FFMatchHead (nn .Module ):
116
+ """Roberta Head for Matching."""
117
+ def __init__ (
118
+ self ,
119
+ additional_feature_size ,
120
+ num_classes ,
121
+ linear_1_dimension ,
122
+ ):
123
+ """Model architecture definition for the capitalization model in torch."""
124
+ super (MatchHead , self ).__init__ ()
125
+ self .linear_1 = nn .Linear (additional_feature_size , linear_1_dimension )
126
+ self .linear_2 = nn .Linear (linear_1_dimension , num_classes )
127
+
128
+ def forward (self , additional_feats ):
129
+ """Forward pass"""
130
+
131
+ # batch second is faster
132
+ linear_input = additional_feats
133
+ linear_output = self .linear_1 (linear_input )
134
+ activated_linear_output = F .relu (linear_output )
135
+ pre_sigmoid_output = self .linear_2 (activated_linear_output )
136
+ sigmoid_output = F .sigmoid (pre_sigmoid_output )
137
+
138
+ return sigmoid_output
139
+
140
+
141
+ class FFMatchArchitecture (nn .Module ):
142
+ "Transformer base model for matching."
143
+ def __init__ (
144
+ self ,
145
+ additional_feature_size ,
146
+ num_classes ,
147
+ linear_1_dimension ,
148
+ ):
149
+ super (FFMatchArchitecture , self ).__init__ ()
150
+ self .match_head = FFMatchHead (
151
+ additional_feature_size , num_classes , linear_1_dimension
152
+ )
153
+
154
+ def forward (self , additional_feats ):
155
+ """Forward pass"""
156
+
157
+ match_classification = self .match_head (additional_feats )
158
+ return match_classification
159
+
160
+
161
+
0 commit comments