Skip to content

Commit 51ec0c7

Browse files
authored
fix(gpt): stream mode dim mismatch (fix #606) (#607)
1 parent a0e6cd8 commit 51ec0c7

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

ChatTTS/model/gpt.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -416,6 +416,7 @@ def generate(
416416
)
417417
inputs_ids_buf.narrow(1, 0, progress).copy_(inputs_ids)
418418
del inputs_ids
419+
inputs_ids = inputs_ids_buf.narrow(1, 0, progress)
419420

420421
pbar: Optional[tqdm] = None
421422

@@ -430,8 +431,6 @@ def generate(
430431

431432
for i in range(max_new_token):
432433

433-
inputs_ids = inputs_ids_buf.narrow(1, 0, progress)
434-
435434
model_input = self._prepare_generation_inputs(
436435
inputs_ids,
437436
past_key_values,
@@ -606,6 +605,7 @@ def generate(
606605

607606
del idx_next
608607
progress += 1
608+
inputs_ids = inputs_ids_buf.narrow(1, 0, progress)
609609

610610
not_finished = finish.logical_not().to(end_idx.device)
611611
end_idx.add_(not_finished.int())

0 commit comments

Comments
 (0)