Skip to content

Commit 00cd943

Browse files
authored
chore: optimize tensor padding in model_runner.py (#628)
1 parent 4f72f4a commit 00cd943

File tree

1 file changed

+14
-7
lines changed

1 file changed

+14
-7
lines changed

ChatTTS/model/velocity/model_runner.py

Lines changed: 14 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -536,6 +536,9 @@ def execute_model(
536536
)
537537
# print(hidden_states.shape)
538538
# print(input_tokens)
539+
B_NO_PAD = input_tokens_history.shape[0]
540+
input_tokens = input_tokens[:B_NO_PAD, :, :]
541+
hidden_states = hidden_states[:B_NO_PAD, :, :]
539542
idx_next, logprob, finish = self.sampler.sample(
540543
inputs_ids=(
541544
input_tokens
@@ -774,13 +777,17 @@ def _make_tensor_with_pad(
774777
device: Union[str, torch.device] = "cuda",
775778
pin_memory: bool = False,
776779
) -> torch.Tensor:
777-
padded_x = [_pad_to_max(x_i, max_len, pad) for x_i in x]
778-
return torch.tensor(
779-
padded_x,
780-
dtype=dtype,
781-
device=device,
782-
pin_memory=pin_memory and str(device) == "cpu",
783-
)
780+
padded_x = []
781+
for x_i in x:
782+
pad_i = pad
783+
if isinstance(x[0][0],tuple):
784+
pad_i = (0,) * len(x[0][0])
785+
padded_x.append(_pad_to_max(x_i, max_len, pad_i))
786+
787+
return torch.tensor(padded_x,
788+
dtype=dtype,
789+
device=device,
790+
pin_memory=pin_memory and str(device) == "cpu")
784791

785792

786793
def _get_graph_batch_size(batch_size: int) -> int:

0 commit comments

Comments
 (0)