
Commit 6feb586

fix: parallel inference
Parent: 1e39fdc

7 files changed: +234 −101 lines

ChatTTS/core.py

Lines changed: 49 additions & 19 deletions
```diff
@@ -174,17 +174,15 @@ class RefineTextParams:
     min_new_token: int = 0
     show_tqdm: bool = True
     ensure_non_empty: bool = True
-    manual_seed: Optional[int] = 0
+    manual_seed: Optional[int] = None


 @dataclass(repr=False, eq=False)
 class InferCodeParams(RefineTextParams):
     prompt: str = "[speed_5]"
     spk_emb: Optional[str] = None
     spk_smp: Optional[str] = None
     txt_smp: Optional[str] = None
-    top_P: float = 1
-    top_K: int = 1
-    temperature: float = 0.01
+    temperature: float = 0.3
     repetition_penalty: float = 1.05
     max_new_token: int = 2048
     stream_batch: int = 24
```
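The sampling defaults above replace a near-greedy configuration (top_P=1, top_K=1, temperature=0.01) with temperature=0.3. A quick illustrative sketch of what temperature does to a sampling distribution (toy logits, not values from the repo):

```python
import torch

# Toy logits: temperature rescales them before softmax.
logits = torch.tensor([2.0, 1.0, 0.5])
for t in (0.01, 0.3, 1.0):
    probs = torch.softmax(logits / t, dim=-1)
    print(f"temperature={t}: {probs.tolist()}")
# At t=0.01 the mass collapses onto the argmax (effectively greedy decoding);
# t=0.3 keeps some diversity while still favoring the top token.
```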
```diff
@@ -196,13 +194,13 @@ def infer(
         text,
         stream=False,
         lang=None,
-        skip_refine_text=True,
+        skip_refine_text=False,
         refine_text_only=False,
         use_decoder=True,
         do_text_normalization=True,
         do_homophone_replacement=True,
-        params_refine_text=None,
-        params_infer_code=None,
+        params_refine_text=RefineTextParams(),
+        params_infer_code=InferCodeParams(),
         stream_batch_size=16,
     ):
         self.context.set(False)
```
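With these defaults, text refinement runs unless explicitly skipped, and callers no longer have to construct parameter objects themselves. A hypothetical usage sketch (the load()/infer() call shape is assumed from the project README and may differ between versions):

```python
import ChatTTS

chat = ChatTTS.Chat()
chat.load()  # assumed loader; some versions expose load_models() instead

# skip_refine_text now defaults to False and the params_* arguments default
# to fresh dataclass instances, so a bare call exercises the full
# refine-text + infer-code pipeline.
wavs = chat.infer(["Hello from ChatTTS."])
```

One caveat worth remembering: Python evaluates default arguments once, at definition time, so the two dataclass instances are shared across every call that relies on the defaults; that stays harmless only as long as infer() never mutates its parameter objects.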
```diff
@@ -273,7 +271,7 @@ def _load(
                 vq_config=asdict(self.config.dvae.vq),
                 dim=self.config.dvae.decoder.idim,
                 coef=coef,
-                device=device,
+                device=self.device,
             )
             .to(device)
             .eval()
@@ -290,8 +288,8 @@ def _load(
             self.config.embed.num_text_tokens,
             self.config.embed.num_vq,
         )
-        embed.from_pretrained(embed_path, device=device)
-        self.embed = embed.to(device)
+        embed.from_pretrained(embed_path, device=self.device)
+        self.embed = embed.to(self.device)
         self.logger.log(logging.INFO, "embed loaded.")

         gpt = GPT(
@@ -343,15 +341,15 @@ def _load(
     async def _infer(
         self,
         text,
-        stream=True,
+        stream=False,
         lang=None,
-        skip_refine_text=True,
+        skip_refine_text=False,
         refine_text_only=False,
         use_decoder=True,
         do_text_normalization=True,
         do_homophone_replacement=True,
-        params_refine_text=None,
-        params_infer_code=None,
+        params_refine_text=RefineTextParams(),
+        params_infer_code=InferCodeParams(),
         stream_batch_size=16,
     ):

```
```diff
@@ -399,13 +397,11 @@ async def _infer(
                     result.hiddens if use_decoder else result.ids,
                     use_decoder,
                 )
-
                 if result.finished:
                     yield wavs[:, length:]
                 else:
                     # Hacker:Check if there are any silent segments; if so, take the last segment. Otherwise, try waiting for another loop.
                     import librosa
-
                     silence_intervals = librosa.effects.split(wavs[0][length:], top_db=10)
                     silence_left = 0
                     if len(silence_intervals) == 0:
```
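A note on the snippet above: librosa.effects.split returns the non-silent intervals of a signal as an (n, 2) array of [start, end] sample indices, so an empty result means the whole tail fell below the top_db threshold. A small self-contained sketch:

```python
import numpy as np
import librosa

# Half a buffer of silence followed by noise; split should report only
# the noisy region as a non-silent interval.
y = np.concatenate([
    np.zeros(2048, dtype=np.float32),
    0.5 * np.random.randn(2048).astype(np.float32),
])
intervals = librosa.effects.split(y, top_db=10)
print(intervals)  # roughly [[2048, 4096]]; boundaries are frame-quantized
```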
```diff
@@ -504,8 +500,8 @@ async def _infer_code(
             repetition_penalty=params.repetition_penalty,
         )

-        speaker_embedding_param = gpt(input_ids, text_mask)
-
+        speaker_embedding_param = self.embed(input_ids, text_mask)
+        del text_mask
         if params.spk_emb is not None:
             self.speaker.apply(
                 speaker_embedding_param,
@@ -536,7 +532,7 @@ async def _infer_code(
             async for i in results_generator:
                 token_ids = []
                 hidden_states = []
-                if len(i.outputs[0].token_ids) % stream_batch_size == 0 or i.finished:
+                if (stream and len(i.outputs[0].token_ids) % stream_batch_size == 0) or i.finished:
                     token_ids.append(torch.tensor(i.outputs[0].token_ids))
                     hidden_states.append(
                         i.outputs[0].hidden_states.to(torch.float32).to(self.device)
@@ -547,6 +543,40 @@ async def _infer_code(
                     hiddens=hidden_states,
                     attentions=[],
                 )
+        else:
+            results_generator = gpt.generate(
+                speaker_embedding_param,
+                input_ids,
+                temperature=torch.tensor(temperature, device=device),
+                eos_token=num_code,
+                attention_mask=attention_mask,
+                max_new_token=params.max_new_token,
+                min_new_token=params.min_new_token,
+                logits_processors=(*logits_processors, *logits_warpers),
+                infer_text=False,
+                return_hidden=return_hidden,
+                stream=stream,
+                show_tqdm=params.show_tqdm,
+                ensure_non_empty=params.ensure_non_empty,
+                stream_batch=params.stream_batch,
+                manual_seed=params.manual_seed,
+                context=self.context,
+            )
+            del speaker_embedding_param, input_ids
+            async for i in results_generator:
+                token_ids = []
+                hidden_states = []
+                if (stream and len(i.ids[0]) % stream_batch_size == 0) or i.finished:
+                    token_ids.append(i.ids[0])
+                    hidden_states.append(
+                        i.hiddens[0].to(torch.float32).to(self.device)
+                    )
+                yield GPT.GenerationOutputs(
+                    ids=token_ids,
+                    finished=i.finished,
+                    hiddens=hidden_states,
+                    attentions=[],
+                )

     @torch.no_grad()
     def _refine_text(
```
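Both generation branches now gate their yields identically: intermediate chunks are emitted only when streaming, while the final chunk is always emitted. Distilled into a standalone predicate (a hypothetical helper, not part of the commit):

```python
def should_yield(n_tokens: int, stream: bool, finished: bool,
                 stream_batch_size: int = 16) -> bool:
    # When streaming, emit a chunk every stream_batch_size tokens;
    # otherwise emit only once generation has finished.
    return (stream and n_tokens % stream_batch_size == 0) or finished
```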

ChatTTS/model/gpt.py

Lines changed: 11 additions & 44 deletions
```diff
@@ -46,15 +46,16 @@ def __init__(
         self.is_te_llama = False
         self.is_vllm = use_vllm

-        self.emb_code = [ec.__call__ for ec in embed.emb_code]
-        self.emb_text = embed.emb_text.__call__
-        self.head_text = embed.head_text.__call__
-        self.head_code = [hc.__call__ for hc in embed.head_code]
         if self.is_vllm:
             return

         self.llama_config = self._build_llama_config(gpt_config)

+        self.emb_code = [ec.__call__ for ec in embed.emb_code]
+        self.emb_text = embed.emb_text.__call__
+        self.head_text = embed.head_text.__call__
+        self.head_code = [hc.__call__ for hc in embed.head_code]
+
     def from_pretrained(
         self, gpt_folder: str, embed_file_path: str, experimental=False
     ):
@@ -67,7 +68,7 @@ def from_pretrained(
                 num_audio_tokens=self.num_audio_tokens,
                 num_text_tokens=self.num_text_tokens,
                 post_model_path=embed_file_path,
-                dtype="float32",
+                dtype="float32"
             )
             self.logger.info("vLLM model loaded")
             return
@@ -138,44 +139,6 @@ def prepare(self, compile=False):
             except RuntimeError as e:
                 self.logger.warning(f"compile failed: {e}. fallback to normal mode.")

-    def __call__(
-        self, input_ids: torch.Tensor, text_mask: torch.Tensor
-    ) -> torch.Tensor:
-        """
-        get_emb
-        """
-        return super().__call__(input_ids, text_mask)
-
-    def forward(self, input_ids: torch.Tensor, text_mask: torch.Tensor) -> torch.Tensor:
-        """
-        get_emb
-        """
-        input_ids = input_ids.clone()
-        text_mask = text_mask.clone()
-        emb_text: torch.Tensor = self.emb_text(
-            input_ids[text_mask].narrow(1, 0, 1).squeeze_(1).to(self.device_gpt)
-        )
-
-        text_mask_inv = text_mask.logical_not().to(self.device_gpt)
-        masked_input_ids: torch.Tensor = input_ids[text_mask_inv].to(self.device_gpt)
-
-        emb_code = [
-            self.emb_code[i](masked_input_ids[:, i]) for i in range(self.num_vq)
-        ]
-        emb_code = torch.stack(emb_code, 2).sum(2)
-
-        emb = torch.zeros(
-            (input_ids.shape[:-1]) + (emb_text.shape[-1],),
-            device=emb_text.device,
-            dtype=emb_text.dtype,
-        )
-        emb[text_mask] = emb_text
-        emb[text_mask_inv] = emb_code.to(emb.dtype)
-
-        del emb_text, emb_code, text_mask_inv
-
-        return emb
-
 @dataclass(repr=False, eq=False)
 class _GenerationInputs:
     position_ids: torch.Tensor
@@ -327,6 +290,7 @@ def _prepare_generation_outputs(
         attentions: List[Optional[Tuple[torch.FloatTensor, ...]]],
         hiddens: List[torch.Tensor],
         infer_text: bool,
+        finished: bool,
     ) -> GenerationOutputs:
         inputs_ids = [
             inputs_ids[idx].narrow(0, start_idx, i) for idx, i in enumerate(end_idx)
@@ -344,10 +308,11 @@ def _prepare_generation_outputs(
             ids=inputs_ids,
             attentions=attentions,
             hiddens=hiddens,
+            finished=finished,
         )

     @torch.no_grad()
-    def generate(
+    async def generate(
         self,
         emb: torch.Tensor,
         inputs_ids: torch.Tensor,
@@ -620,6 +585,7 @@ def generate(
                 attentions,
                 hiddens,
                 infer_text,
+                False
             )
             del not_finished

@@ -649,4 +615,5 @@ def generate(
             attentions,
             hiddens,
             infer_text,
+            True
         )
```
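Because generate is now declared async (and GenerationOutputs carries the new finished flag), callers have to consume it as an async generator. A minimal consumption sketch, with gpt, emb, and inputs_ids as placeholders for objects built elsewhere:

```python
import asyncio

async def collect(gpt, emb, inputs_ids, **kwargs):
    # generate is an async generator now: iterate with `async for`
    # and stop once the outputs are flagged as finished.
    async for out in gpt.generate(emb, inputs_ids, **kwargs):
        if out.finished:
            return out.ids
    return None

# ids = asyncio.run(collect(gpt, emb, inputs_ids, infer_text=False))
```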

ChatTTS/model/velocity/llm.py

Lines changed: 0 additions & 1 deletion
```diff
@@ -6,7 +6,6 @@

 from .async_llm_engine import AsyncLLMEngine
 from .configs import EngineArgs
-from .llm_engine import LLMEngine
 from .output import RequestOutput
 from .sampling_params import SamplingParams

```

ChatTTS/model/velocity/llm_engine.py

Lines changed: 0 additions & 1 deletion
```diff
@@ -22,7 +22,6 @@
 )
 from vllm.transformers_utils.tokenizer import detokenize_incrementally, get_tokenizer
 from vllm.utils import Counter, set_cuda_visible_devices, get_ip, get_open_port
-import numpy as np
 import torch

 if ray:
```
