# model.py (forked from skit-ai/N-Best-ASR-Transformer)
import torch
import torch.nn as nn

from models.modules.hierarchical_classifier import HierarchicalClassifier


def make_model(opt):
    return TOD_ASR_Transformer_STC(opt)

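
# `opt` is expected to provide (names taken from the attribute reads below):
#   pretrained_model   -- a loaded Hugging Face encoder (e.g. BERT / XLM-RoBERTa)
#   pre_trained_model  -- encoder name; "xlm-roberta" disables token type ids
#   dropout, device, score_util, sent_repr, cls_type
#   top2bottom_dict, label_vocab_size -- forwarded to HierarchicalClassifier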

class TOD_ASR_Transformer_STC(nn.Module):
    '''TOD ASR Transformer Semantic Tuple Classifier.'''

    def __init__(self, opt):
        super(TOD_ASR_Transformer_STC, self).__init__()

        # pretrained transformer encoder, shared between ASR hypotheses
        # and manual transcriptions
        self.bert_encoder = opt.pretrained_model
        self.dropout_layer = nn.Dropout(opt.dropout)
        self.device = opt.device
        self.score_util = opt.score_util
        self.sent_repr = opt.sent_repr
        self.cls_type = opt.cls_type

        # encoder output feature dimension (768 for base-sized BERT/XLM-R)
        fea_dim = 768
        self.clf = HierarchicalClassifier(
            opt.top2bottom_dict, fea_dim, opt.label_vocab_size, opt.dropout)
    def forward(self, opt, input_ids, trans_input_ids=None, seg_ids=None,
                trans_seg_ids=None, return_attns=False,
                classifier_input_type="asr"):
        # Encode the ASR output. The attention mask assumes the pad token id
        # is 0 (true for BERT; note XLM-RoBERTa pads with id 1).
        # XLM-RoBERTa has no token type embeddings, so skip token_type_ids.
        if opt.pre_trained_model == "xlm-roberta":
            outputs = self.bert_encoder(input_ids=input_ids,
                                        attention_mask=input_ids > 0,
                                        output_attentions=return_attns)
        else:
            outputs = self.bert_encoder(input_ids=input_ids,
                                        attention_mask=input_ids > 0,
                                        token_type_ids=seg_ids,
                                        output_attentions=return_attns)
        sequence_output = outputs[0]
        # sentence representation: the [CLS] (first) token
        asr_lin_in = sequence_output[:, 0, :]
        # attention weights, assuming a Hugging Face encoder (the original
        # returned an undefined `attns` here)
        attns = outputs.attentions if return_attns else None

        # optionally encode the manual transcription with the same encoder
        trans_lin_in = None
        if trans_input_ids is not None:
            if opt.pre_trained_model == "xlm-roberta":
                trans_outputs = self.bert_encoder(
                    input_ids=trans_input_ids,
                    attention_mask=trans_input_ids > 0)
            else:
                trans_outputs = self.bert_encoder(
                    input_ids=trans_input_ids,
                    attention_mask=trans_input_ids > 0,
                    token_type_ids=trans_seg_ids)
            trans_sequence_output = trans_outputs[0]
            trans_lin_in = trans_sequence_output[:, 0, :]

        # choose which representation feeds the downstream classifier
        if classifier_input_type == "transcript":
            lin_in = trans_lin_in
        else:
            lin_in = asr_lin_in

        # decoder / classifier ('stc' is the only type handled here)
        if self.cls_type == 'stc':
            top_scores, bottom_scores_dict, final_scores = self.clf(lin_in)
        if return_attns:
            return (top_scores, bottom_scores_dict, final_scores, attns,
                    asr_lin_in, trans_lin_in)
        else:
            return (top_scores, bottom_scores_dict, final_scores,
                    asr_lin_in, trans_lin_in)
    def load_model(self, load_dir):
        # map_location keeps CPU-only loading working for GPU-saved checkpoints;
        # passing the path directly avoids the unclosed file handles of the
        # original open() calls
        if self.device.type == 'cuda':
            self.load_state_dict(torch.load(load_dir))
        else:
            self.load_state_dict(torch.load(
                load_dir, map_location=lambda storage, loc: storage))

    def save_model(self, save_dir):
        torch.save(self.state_dict(), save_dir)
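

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the original module. It assumes
    # the encoder is a Hugging Face BERT; the `top2bottom_dict` below is a
    # hypothetical hierarchy -- the real format is defined by
    # models.modules.hierarchical_classifier.HierarchicalClassifier.
    from argparse import Namespace
    from transformers import AutoModel, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    opt = Namespace(
        pretrained_model=AutoModel.from_pretrained("bert-base-uncased"),
        pre_trained_model="bert",  # use "xlm-roberta" to skip token type ids
        dropout=0.1,
        device=torch.device("cpu"),
        score_util="none",
        sent_repr="cls",
        cls_type="stc",
        top2bottom_dict={0: [0, 1], 1: [2, 3]},  # hypothetical mapping
        label_vocab_size=4,
    )
    model = make_model(opt)

    batch = tokenizer(["book a table for two"], return_tensors="pt")
    top_scores, bottom_scores_dict, final_scores, asr_repr, trans_repr = model(
        opt, batch["input_ids"], seg_ids=batch.get("token_type_ids"))
    print(type(top_scores), type(final_scores))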