@@ -3,6 +3,7 @@
 import logging
 import os
 import tempfile
+from collections import OrderedDict

 import numpy as np
 import onnx
@@ -155,8 +156,8 @@ def forward(self, src_tokens, src_lengths):
             encoder_outputs = encoder_out[0]
             outputs.append(encoder_outputs)
             output_names.append(f"encoder_output_{i}")
-            if hasattr(model.decoder, "_init_prev_states"):
-                states.extend(model.decoder._init_prev_states(encoder_out))
+            if hasattr(model.decoder, "get_init_prev_states"):
+                states.extend(model.decoder.get_init_prev_states(encoder_out))

         # underlying assumption is each model has same vocab_reduction_module
         vocab_reduction_module = self.models[0].decoder.vocab_reduction_module
@@ -272,9 +273,6 @@ def forward(self, input_tokens, prev_scores, timestep, *inputs):

         next_state_input = len(self.models)

-        # size of "batch" dimension of input as tensor
-        batch_size = torch.onnx.operators.shape_as_tensor(input_tokens)[0]
-
         # underlying assumption is each model has same vocab_reduction_module
         vocab_reduction_module = self.models[0].decoder.vocab_reduction_module
         if vocab_reduction_module is not None:
@@ -285,20 +283,6 @@ def forward(self, input_tokens, prev_scores, timestep, *inputs):

         for i, model in enumerate(self.models):
             encoder_output = inputs[i]
-            prev_hiddens = []
-            prev_cells = []
-
-            for _ in range(len(model.decoder.layers)):
-                prev_hiddens.append(inputs[next_state_input])
-                prev_cells.append(inputs[next_state_input + 1])
-                next_state_input += 2
-
-            # ensure previous attention context has batch dimension
-            input_feed_shape = torch.cat((batch_size.view(1), torch.LongTensor([-1])))
-            prev_input_feed = torch.onnx.operators.reshape_from_tensor_shape(
-                inputs[next_state_input], input_feed_shape
-            )
-            next_state_input += 1

             # no batching, we only care about care about "max" length
             src_length_int = int(encoder_output.size()[0])
@@ -310,8 +294,8 @@ def forward(self, input_tokens, prev_scores, timestep, *inputs):

             encoder_out = (
                 encoder_output,
-                prev_hiddens,
-                prev_cells,
+                None,
+                None,
                 src_length,
                 src_tokens,
                 src_embeddings,
@@ -321,16 +305,12 @@ def forward(self, input_tokens, prev_scores, timestep, *inputs):
             model.decoder._is_incremental_eval = True
             model.eval()

-            # placeholder
-            incremental_state = {}
-
-            # cache previous state inputs
-            utils.set_incremental_state(
-                model.decoder,
-                incremental_state,
-                "cached_state",
-                (prev_hiddens, prev_cells, prev_input_feed),
-            )
+            # pass state inputs via incremental_state
+            num_states = model.decoder.get_num_states()
+            prev_states = inputs[next_state_input : next_state_input + num_states]
+            next_state_input += num_states
+            incremental_state = OrderedDict()
+            model.decoder.populate_incremental_state(incremental_state, prev_states)

             decoder_output = model.decoder(
                 input_tokens,
@@ -345,13 +325,8 @@ def forward(self, input_tokens, prev_scores, timestep, *inputs):
             log_probs_per_model.append(log_probs)
             attn_weights_per_model.append(attn_scores)

-            (next_hiddens, next_cells, next_input_feed) = utils.get_incremental_state(
-                model.decoder, incremental_state, "cached_state"
-            )
-
-            for h, c in zip(next_hiddens, next_cells):
-                state_outputs.extend([h, c])
-            state_outputs.append(next_input_feed)
+            next_states = model.decoder.serialize_incremental_state(incremental_state)
+            state_outputs.extend(next_states)

         average_log_probs = torch.mean(
             torch.cat(log_probs_per_model, dim=1), dim=1, keepdim=True
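The new calls above assume the decoder now owns its state layout: get_num_states() reports how many flat state tensors it consumes, populate_incremental_state() unpacks them into incremental_state, and serialize_incremental_state() flattens the updated state back out for state_outputs. The sketch below shows one way a decoder could implement these methods; it is illustrative only, not code from this PR, and it assumes the flat layout matches the removed code, i.e. per-layer (hidden, cell) pairs followed by the attention input feed.

from fairseq import utils


class RNNDecoderStateAPI:
    # Illustrative sketch only: one possible decoder-side implementation of the
    # state-passing methods used by the step ensembles above. The flat layout
    # assumed here is [h_0, c_0, ..., h_{n-1}, c_{n-1}, input_feed].

    def get_num_states(self):
        # one hidden and one cell tensor per LSTM layer, plus the attention input feed
        return 2 * len(self.layers) + 1

    def populate_incremental_state(self, incremental_state, prev_states):
        # unpack the flat state list and cache it the same way the removed code did
        num_layers = len(self.layers)
        prev_hiddens = [prev_states[2 * i] for i in range(num_layers)]
        prev_cells = [prev_states[2 * i + 1] for i in range(num_layers)]
        prev_input_feed = prev_states[2 * num_layers]
        utils.set_incremental_state(
            self,
            incremental_state,
            "cached_state",
            (prev_hiddens, prev_cells, prev_input_feed),
        )

    def serialize_incremental_state(self, incremental_state):
        # flatten the cached state back into the order expected by the step modules
        next_hiddens, next_cells, next_input_feed = utils.get_incremental_state(
            self, incremental_state, "cached_state"
        )
        flat_states = []
        for h, c in zip(next_hiddens, next_cells):
            flat_states.extend([h, c])
        flat_states.append(next_input_feed)
        return flat_states

With that layout, populate followed by serialize is a round trip, so each step module emits exactly as many state outputs as it consumed state inputs, and in the same order.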
@@ -735,15 +710,6 @@ def forward(self, input_token, target_token, timestep, *inputs):

         for i, model in enumerate(self.models):
             encoder_output = inputs[i]
-            prev_hiddens = []
-            prev_cells = []
-
-            for _ in range(len(model.decoder.layers)):
-                prev_hiddens.append(inputs[next_state_input])
-                prev_cells.append(inputs[next_state_input + 1])
-                next_state_input += 2
-            prev_input_feed = inputs[next_state_input].view(1, -1)
-            next_state_input += 1

             # no batching, we only care about care about "max" length
             src_length_int = int(encoder_output.size()[0])
@@ -755,8 +721,8 @@ def forward(self, input_token, target_token, timestep, *inputs):

             encoder_out = (
                 encoder_output,
-                prev_hiddens,
-                prev_cells,
+                None,
+                None,
                 src_length,
                 src_tokens,
                 src_embeddings,
@@ -766,16 +732,12 @@ def forward(self, input_token, target_token, timestep, *inputs):
             model.decoder._is_incremental_eval = True
             model.eval()

-            # placeholder
-            incremental_state = {}
-
-            # cache previous state inputs
-            utils.set_incremental_state(
-                model.decoder,
-                incremental_state,
-                "cached_state",
-                (prev_hiddens, prev_cells, prev_input_feed),
-            )
+            # pass state inputs via incremental_state
+            num_states = model.decoder.get_num_states()
+            prev_states = inputs[next_state_input : next_state_input + num_states]
+            next_state_input += num_states
+            incremental_state = OrderedDict()
+            model.decoder.populate_incremental_state(incremental_state, prev_states)

             decoder_output = model.decoder(
                 input_token.view(1, 1),
@@ -789,13 +751,8 @@ def forward(self, input_token, target_token, timestep, *inputs):

             log_probs_per_model.append(log_probs)

-            (next_hiddens, next_cells, next_input_feed) = utils.get_incremental_state(
-                model.decoder, incremental_state, "cached_state"
-            )
-
-            for h, c in zip(next_hiddens, next_cells):
-                state_outputs.extend([h, c])
-            state_outputs.append(next_input_feed)
+            next_states = model.decoder.serialize_incremental_state(incremental_state)
+            state_outputs.extend(next_states)

         average_log_probs = torch.mean(
             torch.cat(log_probs_per_model, dim=0), dim=0, keepdim=True
@@ -1020,8 +977,8 @@ def forward(self, src_tokens, src_lengths, char_inds, word_lengths):
             outputs.append(encoder_outputs)
             output_names.append(f"encoder_output_{i}")

-            if hasattr(model.decoder, "_init_prev_states"):
-                states.extend(model.decoder._init_prev_states(encoder_out))
+            if hasattr(model.decoder, "get_init_prev_states"):
+                states.extend(model.decoder.get_init_prev_states(encoder_out))

         # underlying assumption is each model has same vocab_reduction_module
         vocab_reduction_module = self.models[0].decoder.vocab_reduction_module
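In the encoder hunks, the private _init_prev_states helper is replaced by a public get_init_prev_states method with the same role: produce the flat list of initial decoder states that gets appended to the encoder outputs. The following method, which could be added to the sketch class above, is an illustration only, not code from this PR; it zero-initializes the states to match the layout sketched earlier, whereas a real decoder might derive them from the encoder's final hidden states, and the attribute names hidden_dim and encoder_output_dim are assumptions.

    def get_init_prev_states(self, encoder_out):
        # Illustrative sketch: zero-initialized per-layer (hidden, cell) pairs
        # followed by a zero attention input feed, consistent with get_num_states().
        encoder_output = encoder_out[0]  # assumed shape: (src_len, batch, encoder_output_dim)
        batch_size = encoder_output.size(1)
        states = []
        for _ in range(len(self.layers)):
            states.append(encoder_output.new_zeros(batch_size, self.hidden_dim))  # hidden
            states.append(encoder_output.new_zeros(batch_size, self.hidden_dim))  # cell
        states.append(encoder_output.new_zeros(batch_size, self.encoder_output_dim))  # input feed
        return states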