
Commit c2fbd68

jhcross authored and facebook-github-bot committed
return empty tensor instead of None (#332)
Summary:
Pull Request resolved: #332

To allow efficient use of fork/join annotation, we return an empty tensor instead of `None` for `encoder_padding_mask` from the transformer encoder in the unmasked/inference case. Note that this slight hack is preferable to more far-reaching changes in, e.g., Fairseq multihead_attention.

Differential Revision: D13969691

fbshipit-source-id: 862ed44019012449554527f236cb344046c75184
1 parent 17c1f47 commit c2fbd68
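
For readers skimming the diffs below: the encoder substitutes an empty tensor for `None`, and each decoder turns that sentinel back into `None` before the mask reaches attention. A minimal standalone sketch of that round trip; the `encode`/`decode_step` helpers here are illustrative placeholders, not actual pytorch_translate functions:

    import torch

    def encode(src_tokens, encoder_padding_mask):
        # In the unmasked case, stand in an empty tensor for None so the
        # encoder's output types stay stable for fork/join tracing.
        if encoder_padding_mask is None:
            encoder_padding_mask = torch.Tensor().type_as(src_tokens)
        return src_tokens, encoder_padding_mask

    def decode_step(encoder_padding_mask):
        # On the decoder side, turn the empty-tensor sentinel back into None
        # before handing the mask to attention.
        if encoder_padding_mask is not None and encoder_padding_mask.numel() == 0:
            encoder_padding_mask = None
        return encoder_padding_mask

Keeping the encoder output a plain tensor means its type no longer depends on whether padding is present, which is what lets the fork/join path below treat every encoder uniformly.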

3 files changed: +15 −15 lines


pytorch_translate/ensemble_export.py

Lines changed: 4 additions & 14 deletions
@@ -256,22 +256,12 @@ def forward(self, src_tokens, src_lengths):
             # evaluation mode
             model.eval()
 
-            # TODO(jamesreed): transformer encodder returns a None output, and
-            # the fork/join API doesn't handle that well. We should figure out
-            # a way to annotate outputs as Optional and record that in fork/join
-            # traces.
-            if isinstance(model.encoder, TransformerEncoder):
-                futures.append(model.encoder(src_tokens_seq_first, src_lengths))
-            else:
-                futures.append(
-                    torch.jit._fork(model.encoder, src_tokens_seq_first, src_lengths)
-                )
+            futures.append(
+                torch.jit._fork(model.encoder, src_tokens_seq_first, src_lengths)
+            )
 
         for i, (model, future) in enumerate(zip(self.models, futures)):
-            if isinstance(model.encoder, TransformerEncoder):
-                encoder_out = future
-            else:
-                encoder_out = torch.jit._wait(future)
+            encoder_out = torch.jit._wait(future)
             # "primary" encoder output (vector representations per source token)
             encoder_outputs = encoder_out[0]
             outputs.append(encoder_outputs)
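
For context on the simplified loop above: `torch.jit._fork` launches a callable asynchronously and returns a future, and `torch.jit._wait` blocks until that future's result is ready. A minimal sketch of the fork/wait pattern, assuming a list of models whose `encoder(src_tokens, src_lengths)` returns a tuple with the per-token representations first (illustrative only, not the ensemble class itself):

    import torch

    def run_encoders(models, src_tokens_seq_first, src_lengths):
        # Launch every encoder asynchronously; _fork returns a future immediately.
        futures = [
            torch.jit._fork(model.encoder, src_tokens_seq_first, src_lengths)
            for model in models
        ]
        outputs = []
        for future in futures:
            encoder_out = torch.jit._wait(future)  # block until this encoder finishes
            outputs.append(encoder_out[0])  # "primary" per-token representations
        return outputs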

pytorch_translate/hybrid_transformer_rnn.py

Lines changed: 3 additions & 0 deletions
@@ -247,6 +247,9 @@ def forward(
     ):
         (encoder_x, src_tokens, encoder_padding_mask) = encoder_out
 
+        if encoder_padding_mask is not None and encoder_padding_mask.numel() == 0:
+            encoder_padding_mask = None
+
         bsz, seqlen = prev_output_tokens.size()
         if incremental_state is not None:
             prev_output_tokens = prev_output_tokens[:, -1:]

pytorch_translate/transformer.py

Lines changed: 8 additions & 1 deletion
@@ -277,6 +277,10 @@ def forward(self, src_tokens, src_lengths):
             x=x, positions=positions, encoder_padding_mask=encoder_padding_mask
         )
 
+        if encoder_padding_mask is None:
+            # using an empty tensor instead of None for PyTorch native export
+            encoder_padding_mask = torch.Tensor().type_as(src_tokens)
+
         return x, src_tokens, encoder_padding_mask
 
     def reorder_encoder_out(self, encoder_out, new_order):
@@ -285,7 +289,7 @@ def reorder_encoder_out(self, encoder_out, new_order):
         x = x.index_select(1, new_order)
         if src_tokens is not None:
             src_tokens = src_tokens.index_select(0, new_order)
-        if encoder_padding_mask is not None:
+        if encoder_padding_mask is not None and encoder_padding_mask.numel() != 0:
             encoder_padding_mask = encoder_padding_mask.index_select(0, new_order)
         return (x, src_tokens, encoder_padding_mask)
 
@@ -382,6 +386,9 @@ def forward(
     ):
         (encoder_x, src_tokens, encoder_padding_mask) = encoder_out
 
+        if encoder_padding_mask is not None and encoder_padding_mask.numel() == 0:
+            encoder_padding_mask = None
+
         # embed positions
         positions = self.embed_positions(
             prev_output_tokens, incremental_state=incremental_state, timestep=timestep
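
The extra `numel() != 0` guard in `reorder_encoder_out` keeps the empty sentinel from reaching `index_select`, which would fail because the reorder indices do not exist in a zero-element tensor. A quick illustrative check (the values here are made up):

    import torch

    encoder_padding_mask = torch.LongTensor()  # empty sentinel, as produced by type_as(src_tokens)
    new_order = torch.LongTensor([1, 0])

    # With the added guard, the sentinel passes through untouched instead of
    # hitting index_select with out-of-range indices on a 0-element tensor.
    if encoder_padding_mask is not None and encoder_padding_mask.numel() != 0:
        encoder_padding_mask = encoder_padding_mask.index_select(0, new_order)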
