Skip to content

Commit 9941306

Browse files
Author: futuran
Commit message: fix
1 parent 45906a4 commit 9941306

File tree

3 files changed

+15
-13
lines changed

3 files changed

+15
-13
lines changed

.vscode/launch.json

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
55
"version": "0.2.0",
66
"configurations": [
7-
87
{
98
"name": "Python: Current File",
109
"type": "python",

multimodal/train.conf

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ queue_size: 10000
1616
world_size: 1
1717
gpu_ranks: [0]
1818
batch_type: "sents"
19-
batch_size: 50
19+
batch_size: 1
2020
valid_batch_size: 32
2121
max_generator_batches: 2
2222
accum_count: [1]

onmt/models/model.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -50,26 +50,29 @@ def forward(self, src, sim, tgt, src_lengths, sim_lengths, bptt=False, with_alig
5050
src_enc_out, src_memory_bank, src_lens = self.encoder(src, src_lengths)
5151
sim_enc_out, sim_memory_bank, sim_lens = self.encoder2(sim, sim_lengths)
5252

53-
sim_pooled_enc = self.pooler(sim_enc_out.transpose(2,0)).transpose(2,0)
54-
sim_pooled_mb = self.pooler(sim_memory_bank.transpose(2,0)).transpose(2,0)
55-
sim_lineared_enc = torch.bmm(sim_pooled_enc.transpose(0,1), self.sim_weight.expand(src.size()[1],512,512).to(src.device)).transpose(0,1)
56-
sim_lineared_mb = torch.bmm(sim_pooled_mb.transpose(0,1), self.sim_weight.expand(src.size()[1],512,512).to(src.device)).transpose(0,1)
53+
#sim_pooled_enc = self.pooler(sim_enc_out.transpose(2,0)).transpose(2,0)
54+
#sim_pooled_mb = self.pooler(sim_memory_bank.transpose(2,0)).transpose(2,0)
55+
#sim_lineared_enc = torch.bmm(sim_pooled_enc.transpose(0,1), self.sim_weight.expand(src.size()[1],512,512).to(src.device)).transpose(0,1)
56+
#sim_lineared_mb = torch.bmm(sim_pooled_mb.transpose(0,1), self.sim_weight.expand(src.size()[1],512,512).to(src.device)).transpose(0,1)
5757

58-
src_out=torch.cat([torch.zeros(10,src.size()[1],src.size()[2],dtype=src.dtype, device=src.device),src])
58+
#src_out=torch.cat([torch.zeros(10,src.size()[1],src.size()[2],dtype=src.dtype, device=src.device),src])
5959

6060
print(src_enc_out.size())
6161
#print(sim_lineared_enc.size())
6262

63-
enc_out = torch.cat([src_enc_out, sim_lineared_enc])
64-
mb_out = torch.cat([src_memory_bank, sim_lineared_mb])
63+
#enc_out = torch.cat([sim_lineared_enc, src_enc_out])
64+
#mb_out = torch.cat([sim_lineared_mb, src_memory_bank])
6565

66-
#src_out=src
67-
#enc_out=src_enc_out
68-
#mb_out=src_memory_bank
66+
print(sim.size())
67+
print(src.size())
68+
src_out = torch.cat([sim, src])
69+
print(src_out.size())
70+
enc_out = torch.cat([sim_enc_out, src_enc_out])
71+
mb_out = torch.cat([sim_memory_bank, src_memory_bank])
6972

7073
if bptt is False:
7174
self.decoder.init_state(src_out, mb_out, enc_out)
72-
dec_out, attns = self.decoder(dec_in, mb_out, memory_lengths=src_lens, with_align=with_align)
75+
dec_out, attns = self.decoder(dec_in, mb_out, memory_lengths=(src_lens+sim_lens), with_align=with_align)
7376
return dec_out, attns
7477

7578
def update_dropout(self, dropout):

0 commit comments

Comments (0)