@@ -536,6 +536,9 @@ def execute_model(
         )
         # print(hidden_states.shape)
         # print(input_tokens)
+        B_NO_PAD = input_tokens_history.shape[0]
+        input_tokens = input_tokens[:B_NO_PAD, :, :]
+        hidden_states = hidden_states[:B_NO_PAD, :, :]
         idx_next, logprob, finish = self.sampler.sample(
             inputs_ids=(
                 input_tokens
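
A minimal, self-contained sketch of what the three added lines do, under the assumption that input_tokens and hidden_states are padded along the batch dimension (e.g. up to a CUDA-graph batch size) while input_tokens_history keeps one row per real sequence; all shapes below are hypothetical:

import torch

# Illustrative shapes only: 3 real sequences padded to a batch of 8.
input_tokens_history = torch.zeros(3, 10, 4, dtype=torch.long)
input_tokens = torch.zeros(8, 10, 4, dtype=torch.long)
hidden_states = torch.zeros(8, 10, 768)

# Number of real (non-padded) rows, taken from the un-padded history tensor.
B_NO_PAD = input_tokens_history.shape[0]

# Drop the padding rows before sampling so the sampler only sees real sequences.
input_tokens = input_tokens[:B_NO_PAD, :, :]
hidden_states = hidden_states[:B_NO_PAD, :, :]

assert input_tokens.shape[0] == hidden_states.shape[0] == 3
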
@@ -774,13 +777,17 @@ def _make_tensor_with_pad(
     device: Union[str, torch.device] = "cuda",
     pin_memory: bool = False,
 ) -> torch.Tensor:
-    padded_x = [_pad_to_max(x_i, max_len, pad) for x_i in x]
-    return torch.tensor(
-        padded_x,
-        dtype=dtype,
-        device=device,
-        pin_memory=pin_memory and str(device) == "cpu",
-    )
+    padded_x = []
+    for x_i in x:
+        pad_i = pad
+        if isinstance(x[0][0], tuple):
+            pad_i = (0,) * len(x[0][0])
+        padded_x.append(_pad_to_max(x_i, max_len, pad_i))
+
+    return torch.tensor(padded_x,
+                        dtype=dtype,
+                        device=device,
+                        pin_memory=pin_memory and str(device) == "cpu")
 
 
 def _get_graph_batch_size(batch_size: int) -> int:
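
A self-contained sketch of the tuple-aware padding path, assuming _pad_to_max appends the pad value until each row reaches max_len; the 4-element token tuples and the device="cpu" default are hypothetical choices so the example runs anywhere:

from typing import List, Union

import torch


def _pad_to_max(x: list, max_len: int, pad) -> list:
    # Assumed behaviour: append `pad` until the row is max_len long.
    assert len(x) <= max_len
    return x + [pad] * (max_len - len(x))


def _make_tensor_with_pad(
    x: List[list],
    max_len: int,
    pad,
    dtype: torch.dtype,
    device: Union[str, torch.device] = "cpu",
    pin_memory: bool = False,
) -> torch.Tensor:
    padded_x = []
    for x_i in x:
        pad_i = pad
        # When each element is itself a tuple (e.g. one token id per codebook),
        # pad with a zero tuple of the same width so torch.tensor() still
        # receives a rectangular nested structure.
        if isinstance(x[0][0], tuple):
            pad_i = (0,) * len(x[0][0])
        padded_x.append(_pad_to_max(x_i, max_len, pad_i))

    return torch.tensor(padded_x,
                        dtype=dtype,
                        device=device,
                        pin_memory=pin_memory and str(device) == "cpu")


# Two rows of 4-way token tuples with different lengths, padded to max_len=3.
rows = [[(1, 2, 3, 4), (5, 6, 7, 8)], [(9, 10, 11, 12)]]
out = _make_tensor_with_pad(rows, max_len=3, pad=0, dtype=torch.long)
print(out.shape)  # torch.Size([2, 3, 4])
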