Commit ab069de

rhgrossman committed
QOL improvements and updates for FFNN baseline
1 parent 76b5022 commit ab069de

File tree

2 files changed: +51 -16 lines changed


code/Step4_RemoveRedundancy/bert/model.py

Lines changed: 48 additions & 2 deletions
@@ -17,7 +17,6 @@ def __init__(
         super(MatchHead, self).__init__()
         self.GRU_1 = nn.GRU(base_model_feature_size, rnn_dimension, bidirectional=False)
         self.GRU_2 = nn.GRU(base_model_feature_size, rnn_dimension, bidirectional=False)
-        self.linear_1 = nn.Linear(additional_feature_size, linear_1_dimension)
         self.linear_1 = nn.Linear(rnn_dimension * 2 + additional_feature_size, linear_1_dimension)
         self.linear_2 = nn.Linear(linear_1_dimension, num_classes)

@@ -47,6 +46,7 @@ def forward(self, data_1, data_2, additional_feats):

         return sigmoid_output

+
 class MatchArchitecture(nn.Module):
     "Transformer base model for matching."
     def __init__(

@@ -108,8 +108,54 @@ def forward(
         sequence_output_2 = outputs_2[0]

         match_classification = self.match_head(sequence_output_1, sequence_output_2, additional_feats)
-        #match_classification = self.match_head(None, None, additional_feats)

         return match_classification


+class FFMatchHead(nn.Module):
+    """Feed-forward head for matching."""
+    def __init__(
+        self,
+        additional_feature_size,
+        num_classes,
+        linear_1_dimension,
+    ):
+        """Model architecture definition for the feed-forward matching head in torch."""
+        super(FFMatchHead, self).__init__()
+        self.linear_1 = nn.Linear(additional_feature_size, linear_1_dimension)
+        self.linear_2 = nn.Linear(linear_1_dimension, num_classes)
+
+    def forward(self, additional_feats):
+        """Forward pass"""
+
+        # engineered features go straight into the feed-forward layers
+        linear_input = additional_feats
+        linear_output = self.linear_1(linear_input)
+        activated_linear_output = F.relu(linear_output)
+        pre_sigmoid_output = self.linear_2(activated_linear_output)
+        sigmoid_output = F.sigmoid(pre_sigmoid_output)
+
+        return sigmoid_output
+
+
+class FFMatchArchitecture(nn.Module):
+    "Feed-forward base model for matching."
+    def __init__(
+        self,
+        additional_feature_size,
+        num_classes,
+        linear_1_dimension,
+    ):
+        super(FFMatchArchitecture, self).__init__()
+        self.match_head = FFMatchHead(
+            additional_feature_size, num_classes, linear_1_dimension
+        )
+
+    def forward(self, additional_feats):
+        """Forward pass"""
+
+        match_classification = self.match_head(additional_feats)
+        return match_classification
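
For context, a minimal usage sketch of the FFMatchArchitecture added above, assuming the constants from train.py (LINEAR_DIM = 64, CLASSES = 1); the feature count of 10 and the random input batch are hypothetical, since train.py derives ADDITIONAL_FEAT_SIZE from the feature dataframe at runtime:

import torch
from model import FFMatchArchitecture

ADDITIONAL_FEAT_SIZE = 10  # hypothetical; train.py sets this from additional_feats.shape[1]
LINEAR_DIM = 64            # matches LINEAR_DIM in train.py
CLASSES = 1                # matches CLASSES in train.py

model = FFMatchArchitecture(ADDITIONAL_FEAT_SIZE, CLASSES, LINEAR_DIM)

# dummy batch of 32 rows of engineered pair features
additional_feats = torch.randn(32, ADDITIONAL_FEAT_SIZE)

with torch.no_grad():
    match_prob = model(additional_feats)  # sigmoid outputs in [0, 1], shape (32, 1)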

code/Step4_RemoveRedundancy/bert/train.py

Lines changed: 3 additions & 14 deletions
@@ -17,10 +17,10 @@
 from model import MatchArchitecture
 from data_utils import MatchingDataset

-
+RANDOM_SEED = 117
 SEQ_LEN = 10
 RNN_DIM = 64
-LINEAR_DIM=64
+LINEAR_DIM = 64
 CLASSES = 1
 ROBERTA_FEAT_SIZE = 768
 ADDITIONAL_FEAT_SIZE = 0

@@ -53,7 +53,7 @@
 if "target" in additional_feats.columns:
     additional_feats.drop("target", axis=1, inplace=True)
 ADDITIONAL_FEAT_SIZE = additional_feats.shape[1]
-kf = KFold(n_splits=5, random_state = 117, shuffle = True)
+kf = KFold(n_splits=5, random_state = RANDOM_SEED, shuffle = True)


 # TODO:@Ray improve the fold selection

@@ -271,13 +271,8 @@ def _run_training_loop(model, train_config):
             else:
                 temp_preds = sigmoid_output.cpu().detach().numpy()
                 preds = np.concatenate([preds, temp_preds], axis=0)
-                temp_labels = y_batch.cpu().detach().numpy()
-                labels = np.concatenate([labels, temp_labels], axis =0)

-        assert(len(preds)==len(labels))
         oof_preds[val_inx, 0] = preds[:len(val_inx), 0]
-        oof_preds2[cur_oof_inx:cur_oof_inx + len(labels), 0] = preds[:len(val_inx), 0]
-        oof_labels[cur_oof_inx:cur_oof_inx + len(labels), 0] = labels[:len(val_inx), 0]
         cur_oof_inx += len(labels)
         del model

@@ -291,9 +286,3 @@ def _run_training_loop(model, train_config):
     print('Precision at {}: '.format(threshold), mt.precision_score(target, oof_preds > threshold))


-for threshold in thresholds:
-    print('F1 at {}: '.format(threshold), mt.f1_score(oof_labels, oof_preds2 > threshold))
-    print('Recall at {}: '.format(threshold), mt.recall_score(oof_labels, oof_preds2 > threshold))
-    print('Precision at {}: '.format(threshold), mt.precision_score(oof_labels, oof_preds2 > threshold))
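
The train.py changes introduce a single RANDOM_SEED constant for the KFold split and drop the redundant second set of out-of-fold arrays (oof_preds2, oof_labels), keeping only oof_preds indexed by each fold's validation indices. A minimal sketch of that pattern, with a hypothetical feature matrix and a stand-in for model predictions (X, y, and the random predictions are placeholders, not code from the repo):

import numpy as np
from sklearn.model_selection import KFold

RANDOM_SEED = 117

# hypothetical data: 100 pairs, 10 engineered features, binary target
X = np.random.rand(100, 10)
y = (np.random.rand(100) > 0.5).astype(float)

kf = KFold(n_splits=5, random_state=RANDOM_SEED, shuffle=True)
oof_preds = np.zeros((len(X), 1))

for train_inx, val_inx in kf.split(X):
    # train a model on X[train_inx] here; random values stand in for its predictions
    preds = np.random.rand(len(val_inx), 1)
    # writing by validation index removes the need for a running cursor
    # (cur_oof_inx) or a duplicate label array
    oof_preds[val_inx, 0] = preds[:, 0]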
