This repository was archived by the owner on Aug 1, 2023. It is now read-only.

Commit b5fc2a5

jhcross authored and facebook-github-bot committed
more general state init/passing
Summary: A more general mechanism for passing states during incremental decoding, which in particular makes requirements for ONNX export more explicit.

Differential Revision: D9599067

fbshipit-source-id: 806a0d6ba213fb531f8b44bbc9bc2fb089066b4e
1 parent adb281f commit b5fc2a5
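In outline, this commit replaces hard-coded (hidden, cell, input_feed) plumbing in the exported step functions with a four-method protocol on the decoder: get_init_prev_states, get_num_states, populate_incremental_state, and serialize_incremental_state. The toy class below is a minimal, self-contained sketch of that protocol distilled from the diff; it is not the RNNDecoder from rnn.py, and the shapes and zero initialization are illustrative assumptions.

from collections import OrderedDict

import torch


class ToyDecoder:
    """Illustrative stand-in for a decoder implementing the new state protocol."""

    def __init__(self, num_layers=2, hidden_dim=4, use_attention=True):
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim
        self.use_attention = use_attention

    def get_num_states(self):
        # one hidden and one cell tensor per layer, plus an optional
        # attention context ("input feed") tensor
        return 2 * self.num_layers + (1 if self.use_attention else 0)

    def get_init_prev_states(self, encoder_out=None):
        # flat list: [h_0, c_0, h_1, c_1, ..., input_feed]
        states = []
        for _ in range(self.num_layers):
            states.extend([torch.zeros(1, self.hidden_dim),
                           torch.zeros(1, self.hidden_dim)])
        if self.use_attention:
            states.append(torch.zeros(1, self.hidden_dim))
        return states

    def populate_incremental_state(self, incremental_state, states):
        # unflatten the state list received as step inputs
        hiddens = [states[2 * i] for i in range(self.num_layers)]
        cells = [states[2 * i + 1] for i in range(self.num_layers)]
        input_feed = states[-1] if self.use_attention else None
        incremental_state["cached_state"] = (hiddens, cells, input_feed)

    def serialize_incremental_state(self, incremental_state):
        # flatten the cached state back out as explicit step outputs
        hiddens, cells, input_feed = incremental_state["cached_state"]
        out = []
        for h, c in zip(hiddens, cells):
            out.extend([h, c])
        if self.use_attention:
            out.append(input_feed)
        return out


decoder = ToyDecoder()
flat = decoder.get_init_prev_states()
assert len(flat) == decoder.get_num_states()

# round trip: flat list -> incremental_state -> flat list
incremental_state = OrderedDict()
decoder.populate_incremental_state(incremental_state, flat)
round_tripped = decoder.serialize_incremental_state(incremental_state)
assert all(a is b for a, b in zip(round_tripped, flat))

The final assertion checks that serialization returns exactly the tensors that were cached, which is what lets the exported graph thread states through successive decoding steps.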

File tree: 2 files changed, +67 −71 lines
  pytorch_translate/ensemble_export.py
  pytorch_translate/rnn.py


pytorch_translate/ensemble_export.py

Lines changed: 25 additions & 68 deletions
@@ -3,6 +3,7 @@
 import logging
 import os
 import tempfile
+from collections import OrderedDict
 
 import numpy as np
 import onnx
@@ -155,8 +156,8 @@ def forward(self, src_tokens, src_lengths):
             encoder_outputs = encoder_out[0]
             outputs.append(encoder_outputs)
             output_names.append(f"encoder_output_{i}")
-            if hasattr(model.decoder, "_init_prev_states"):
-                states.extend(model.decoder._init_prev_states(encoder_out))
+            if hasattr(model.decoder, "get_init_prev_states"):
+                states.extend(model.decoder.get_init_prev_states(encoder_out))
 
         # underlying assumption is each model has same vocab_reduction_module
         vocab_reduction_module = self.models[0].decoder.vocab_reduction_module
@@ -272,9 +273,6 @@ def forward(self, input_tokens, prev_scores, timestep, *inputs):
 
         next_state_input = len(self.models)
 
-        # size of "batch" dimension of input as tensor
-        batch_size = torch.onnx.operators.shape_as_tensor(input_tokens)[0]
-
         # underlying assumption is each model has same vocab_reduction_module
         vocab_reduction_module = self.models[0].decoder.vocab_reduction_module
         if vocab_reduction_module is not None:
@@ -285,20 +283,6 @@ def forward(self, input_tokens, prev_scores, timestep, *inputs):
 
         for i, model in enumerate(self.models):
             encoder_output = inputs[i]
-            prev_hiddens = []
-            prev_cells = []
-
-            for _ in range(len(model.decoder.layers)):
-                prev_hiddens.append(inputs[next_state_input])
-                prev_cells.append(inputs[next_state_input + 1])
-                next_state_input += 2
-
-            # ensure previous attention context has batch dimension
-            input_feed_shape = torch.cat((batch_size.view(1), torch.LongTensor([-1])))
-            prev_input_feed = torch.onnx.operators.reshape_from_tensor_shape(
-                inputs[next_state_input], input_feed_shape
-            )
-            next_state_input += 1
 
             # no batching, we only care about care about "max" length
             src_length_int = int(encoder_output.size()[0])
@@ -310,8 +294,8 @@ def forward(self, input_tokens, prev_scores, timestep, *inputs):
 
             encoder_out = (
                 encoder_output,
-                prev_hiddens,
-                prev_cells,
+                None,
+                None,
                 src_length,
                 src_tokens,
                 src_embeddings,
@@ -321,16 +305,12 @@ def forward(self, input_tokens, prev_scores, timestep, *inputs):
             model.decoder._is_incremental_eval = True
             model.eval()
 
-            # placeholder
-            incremental_state = {}
-
-            # cache previous state inputs
-            utils.set_incremental_state(
-                model.decoder,
-                incremental_state,
-                "cached_state",
-                (prev_hiddens, prev_cells, prev_input_feed),
-            )
+            # pass state inputs via incremental_state
+            num_states = model.decoder.get_num_states()
+            prev_states = inputs[next_state_input : next_state_input + num_states]
+            next_state_input += num_states
+            incremental_state = OrderedDict()
+            model.decoder.populate_incremental_state(incremental_state, prev_states)
 
             decoder_output = model.decoder(
                 input_tokens,
@@ -345,13 +325,8 @@ def forward(self, input_tokens, prev_scores, timestep, *inputs):
             log_probs_per_model.append(log_probs)
             attn_weights_per_model.append(attn_scores)
 
-            (next_hiddens, next_cells, next_input_feed) = utils.get_incremental_state(
-                model.decoder, incremental_state, "cached_state"
-            )
-
-            for h, c in zip(next_hiddens, next_cells):
-                state_outputs.extend([h, c])
-            state_outputs.append(next_input_feed)
+            next_states = model.decoder.serialize_incremental_state(incremental_state)
+            state_outputs.extend(next_states)
 
         average_log_probs = torch.mean(
             torch.cat(log_probs_per_model, dim=1), dim=1, keepdim=True
@@ -735,15 +710,6 @@ def forward(self, input_token, target_token, timestep, *inputs):
 
         for i, model in enumerate(self.models):
            encoder_output = inputs[i]
-            prev_hiddens = []
-            prev_cells = []
-
-            for _ in range(len(model.decoder.layers)):
-                prev_hiddens.append(inputs[next_state_input])
-                prev_cells.append(inputs[next_state_input + 1])
-                next_state_input += 2
-            prev_input_feed = inputs[next_state_input].view(1, -1)
-            next_state_input += 1
 
             # no batching, we only care about care about "max" length
             src_length_int = int(encoder_output.size()[0])
@@ -755,8 +721,8 @@ def forward(self, input_token, target_token, timestep, *inputs):
 
             encoder_out = (
                 encoder_output,
-                prev_hiddens,
-                prev_cells,
+                None,
+                None,
                 src_length,
                 src_tokens,
                 src_embeddings,
@@ -766,16 +732,12 @@ def forward(self, input_token, target_token, timestep, *inputs):
             model.decoder._is_incremental_eval = True
             model.eval()
 
-            # placeholder
-            incremental_state = {}
-
-            # cache previous state inputs
-            utils.set_incremental_state(
-                model.decoder,
-                incremental_state,
-                "cached_state",
-                (prev_hiddens, prev_cells, prev_input_feed),
-            )
+            # pass state inputs via incremental_state
+            num_states = model.decoder.get_num_states()
+            prev_states = inputs[next_state_input : next_state_input + num_states]
+            next_state_input += num_states
+            incremental_state = OrderedDict()
+            model.decoder.populate_incremental_state(incremental_state, prev_states)
 
             decoder_output = model.decoder(
                 input_token.view(1, 1),
@@ -789,13 +751,8 @@ def forward(self, input_token, target_token, timestep, *inputs):
 
             log_probs_per_model.append(log_probs)
 
-            (next_hiddens, next_cells, next_input_feed) = utils.get_incremental_state(
-                model.decoder, incremental_state, "cached_state"
-            )
-
-            for h, c in zip(next_hiddens, next_cells):
-                state_outputs.extend([h, c])
-            state_outputs.append(next_input_feed)
+            next_states = model.decoder.serialize_incremental_state(incremental_state)
+            state_outputs.extend(next_states)
 
         average_log_probs = torch.mean(
             torch.cat(log_probs_per_model, dim=0), dim=0, keepdim=True
@@ -1020,8 +977,8 @@ def forward(self, src_tokens, src_lengths, char_inds, word_lengths):
             outputs.append(encoder_outputs)
             output_names.append(f"encoder_output_{i}")
 
-            if hasattr(model.decoder, "_init_prev_states"):
-                states.extend(model.decoder._init_prev_states(encoder_out))
+            if hasattr(model.decoder, "get_init_prev_states"):
+                states.extend(model.decoder.get_init_prev_states(encoder_out))
 
         # underlying assumption is each model has same vocab_reduction_module
         vocab_reduction_module = self.models[0].decoder.vocab_reduction_module
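The reason states are threaded through a flat tensor list: ONNX tracing requires a fixed set of named tensor inputs and outputs, so per-step recurrent state cannot cross the graph boundary inside an opaque Python dict. A minimal, hypothetical sketch of that pattern follows (StepModule, its arithmetic, and the file name are illustrative, not part of this repo):

import torch


class StepModule(torch.nn.Module):
    """One decoding step whose recurrent state crosses the graph as explicit tensors."""

    def forward(self, token_emb, prev_hidden, prev_cell):
        next_hidden = torch.tanh(token_emb + prev_hidden)
        next_cell = prev_cell + token_emb
        logits = 2.0 * next_hidden
        # state tensors are returned alongside the logits so the caller
        # can feed them back in at the next step
        return logits, next_hidden, next_cell


dummy_inputs = (torch.zeros(1, 8), torch.zeros(1, 8), torch.zeros(1, 8))
torch.onnx.export(
    StepModule(), dummy_inputs, "decoder_step.onnx",
    input_names=["token_emb", "prev_hidden", "prev_cell"],
    output_names=["logits", "next_hidden", "next_cell"],
)

Making the decoder report its own state count via get_num_states is what lets the export code build these input and output name lists without hard-coding the (hidden, cell, input_feed) structure.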

pytorch_translate/rnn.py

Lines changed: 42 additions & 3 deletions
@@ -1139,7 +1139,7 @@ def forward_unprojected(self, input_tokens, encoder_out, incremental_state=None)
             prev_hiddens, prev_cells, input_feed = cached_state
         else:
             # first time step, initialize previous states
-            init_prev_states = self._init_prev_states(encoder_out)
+            init_prev_states = self.get_init_prev_states(encoder_out)
             prev_hiddens = []
             prev_cells = []
 
@@ -1247,7 +1247,13 @@ def max_positions(self):
         """Maximum output length supported by the decoder."""
         return int(1e5)  # an arbitrary large number
 
-    def _init_prev_states(self, encoder_out):
+    def get_num_states(self):
+        num_states = 2 * len(self.layers)
+        if self.attention.context_dim:
+            num_states += 1
+        return num_states
+
+    def get_init_prev_states(self, encoder_out):
         (
             encoder_output,
             final_hiddens,
@@ -1274,10 +1280,43 @@ def _init_prev_states(self, encoder_out):
         for h, c in zip(prev_hiddens, prev_cells):
             prev_states.extend([h, c])
         if self.attention.context_dim:
-            prev_states.append(self.initial_attn_context)
+            prev_states.append(self.initial_attn_context.view(1, -1))
 
         return prev_states
 
+    def populate_incremental_state(self, incremental_state, states):
+        """
+        From output of previous step outputs, for ONNX tracing.
+        """
+        prev_hiddens = []
+        prev_cells = []
+
+        for i in range(len(self.layers)):
+            prev_hiddens.append(states[2 * i])
+            prev_cells.append(states[2 * i + 1])
+
+        input_feed = states[-1]
+
+        # cache previous state inputs
+        utils.set_incremental_state(
+            self,
+            incremental_state,
+            "cached_state",
+            (prev_hiddens, prev_cells, input_feed),
+        )
+
+    def serialize_incremental_state(self, incremental_state):
+        state_outputs = []
+        (hiddens, cells, input_feed) = utils.get_incremental_state(
+            self, incremental_state, "cached_state"
+        )
+
+        for h, c in zip(hiddens, cells):
+            state_outputs.extend([h, c])
+        state_outputs.append(input_feed)
+
+        return state_outputs
+
 
 @register_model_architecture("rnn", "rnn")
 def base_architecture(args):
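For concreteness, the flat layout that get_num_states, get_init_prev_states, populate_incremental_state, and serialize_incremental_state all agree on: with two layers and attention enabled, get_num_states() returns 2 * 2 + 1 = 5 tensors, ordered hidden/cell per layer with the attention context ("input feed") last. A tiny sketch of the indexing, with placeholder names standing in for tensors:

# hypothetical 2-layer decoder with attention: 2 * 2 + 1 = 5 state tensors
num_layers = 2
states = ["hidden_0", "cell_0", "hidden_1", "cell_1", "input_feed"]

assert len(states) == 2 * num_layers + 1                # matches get_num_states()
hiddens = [states[2 * i] for i in range(num_layers)]    # ["hidden_0", "hidden_1"]
cells = [states[2 * i + 1] for i in range(num_layers)]  # ["cell_0", "cell_1"]
input_feed = states[-1]  # always last, per populate_incremental_state()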
