We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 319d037 commit 2bfb097Copy full SHA for 2bfb097
ChatTTS/model/velocity/model_runner.py
@@ -569,13 +569,16 @@ def execute_model(
569
for i in range(idx_next.shape[0]):
570
idx_next_i = idx_next[i, 0, :].cpu().tolist()
571
logprob_i = logprob[i].cpu().tolist()
572
+ tmp_hidden_states = hidden_states[i].cpu()
573
+ if input_tokens[i].shape[-2] != 1:
574
+ tmp_hidden_states = tmp_hidden_states[-1:,:]
575
result = SequenceGroupOutput(
576
samples=[
577
SequenceOutput(
578
parent_seq_id=seq_groups[i],
579
logprobs={tuple(idx_next_i): logprob_i},
580
output_token=tuple(idx_next_i),
- hidden_states=hidden_states[i].cpu(),
581
+ hidden_states=tmp_hidden_states,
582
finished=finish[i].item(),
583
),
584
],
0 commit comments