@@ -50,26 +50,29 @@ def forward(self, src, sim, tgt, src_lengths, sim_lengths, bptt=False, with_alig
50
50
src_enc_out , src_memory_bank , src_lens = self .encoder (src , src_lengths )
51
51
sim_enc_out , sim_memory_bank , sim_lens = self .encoder2 (sim , sim_lengths )
52
52
53
- sim_pooled_enc = self .pooler (sim_enc_out .transpose (2 ,0 )).transpose (2 ,0 )
54
- sim_pooled_mb = self .pooler (sim_memory_bank .transpose (2 ,0 )).transpose (2 ,0 )
55
- sim_lineared_enc = torch .bmm (sim_pooled_enc .transpose (0 ,1 ), self .sim_weight .expand (src .size ()[1 ],512 ,512 ).to (src .device )).transpose (0 ,1 )
56
- sim_lineared_mb = torch .bmm (sim_pooled_mb .transpose (0 ,1 ), self .sim_weight .expand (src .size ()[1 ],512 ,512 ).to (src .device )).transpose (0 ,1 )
53
+ # sim_pooled_enc = self.pooler(sim_enc_out.transpose(2,0)).transpose(2,0)
54
+ # sim_pooled_mb = self.pooler(sim_memory_bank.transpose(2,0)).transpose(2,0)
55
+ # sim_lineared_enc = torch.bmm(sim_pooled_enc.transpose(0,1), self.sim_weight.expand(src.size()[1],512,512).to(src.device)).transpose(0,1)
56
+ # sim_lineared_mb = torch.bmm(sim_pooled_mb.transpose(0,1), self.sim_weight.expand(src.size()[1],512,512).to(src.device)).transpose(0,1)
57
57
58
- src_out = torch .cat ([torch .zeros (10 ,src .size ()[1 ],src .size ()[2 ],dtype = src .dtype , device = src .device ),src ])
58
+ # src_out=torch.cat([torch.zeros(10,src.size()[1],src.size()[2],dtype=src.dtype, device=src.device),src])
59
59
60
60
print (src_enc_out .size ())
61
61
#print(sim_lineared_enc.size())
62
62
63
- enc_out = torch .cat ([src_enc_out , sim_lineared_enc ])
64
- mb_out = torch .cat ([src_memory_bank , sim_lineared_mb ])
63
+ # enc_out = torch.cat([sim_lineared_enc, src_enc_out ])
64
+ # mb_out = torch.cat([sim_lineared_mb, src_memory_bank ])
65
65
66
- #src_out=src
67
- #enc_out=src_enc_out
68
- #mb_out=src_memory_bank
66
+ print (sim .size ())
67
+ print (src .size ())
68
+ src_out = torch .cat ([sim , src ])
69
+ print (src_out .size ())
70
+ enc_out = torch .cat ([sim_enc_out , src_enc_out ])
71
+ mb_out = torch .cat ([sim_memory_bank , src_memory_bank ])
69
72
70
73
if bptt is False :
71
74
self .decoder .init_state (src_out , mb_out , enc_out )
72
- dec_out , attns = self .decoder (dec_in , mb_out , memory_lengths = src_lens , with_align = with_align )
75
+ dec_out , attns = self .decoder (dec_in , mb_out , memory_lengths = ( src_lens + sim_lens ) , with_align = with_align )
73
76
return dec_out , attns
74
77
75
78
def update_dropout (self , dropout ):
0 commit comments