Commit 3836db8

feat: add infer param show_tqdm
1 parent eb0fb71 commit 3836db8

2 files changed: +148 -142 lines changed

ChatTTS/core.py

Lines changed: 4 additions & 4 deletions
@@ -197,17 +197,15 @@ class RefineTextParams:
         repetition_penalty: float = 1.0
         max_new_token: int = 384
         min_new_token: int = 0
+        show_tqdm: bool = True

     @dataclass(repr=False, eq=False)
-    class InferCodeParams:
+    class InferCodeParams(RefineTextParams):
         prompt: str = "[speed_5]"
         spk_emb: Optional[str] = None
-        top_P: float = 0.7
-        top_K: int = 20
         temperature: float = 0.3
         repetition_penalty: float = 1.05
         max_new_token: int = 2048
-        min_new_token: int = 0

     def infer(
         self,
@@ -596,6 +594,7 @@ def _infer_code(
             infer_text=False,
             return_hidden=return_hidden,
             stream=stream,
+            show_tqdm=params.show_tqdm,
             context=self.context,
         )

@@ -644,6 +643,7 @@ def _refine_text(
                 logits_processors=logits_processors,
                 infer_text=True,
                 stream=False,
+                show_tqdm=params.show_tqdm,
                 context=self.context,
             )
         )
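
Usage note: because show_tqdm is declared on RefineTextParams with a default of True, and InferCodeParams now inherits from RefineTextParams, both progress bars can be switched off from the public infer API. A minimal sketch, assuming the usual ChatTTS entry points (chat.load() and the params_refine_text / params_infer_code keyword names are not part of this diff and may differ in your version):

import ChatTTS

chat = ChatTTS.Chat()
chat.load()  # assumed model-loading entry point; not shown in this diff

# show_tqdm=False silences both the "text" bar (_refine_text) and the
# "code" bar (_infer_code); after this commit both params classes accept it.
wavs = chat.infer(
    ["A short progress-bar-free synthesis run."],
    params_refine_text=ChatTTS.Chat.RefineTextParams(show_tqdm=False),
    params_infer_code=ChatTTS.Chat.InferCodeParams(show_tqdm=False),
)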

ChatTTS/model/gpt.py

Lines changed: 144 additions & 138 deletions
@@ -335,6 +335,7 @@ def generate(
         return_attn=False,
         return_hidden=False,
         stream=False,
+        show_tqdm=True,
         context=Context(),
     ):

@@ -368,160 +369,165 @@ def generate(
         attention_mask_cache.narrow(1, 0, attention_mask.shape[1]).copy_(
             attention_mask
         )
+
+        pbar: Optional[tqdm] = None
+
+        if show_tqdm:
+            pbar = tqdm(
+                total=max_new_token,
+                desc="text" if infer_text else "code",
+                bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt}(max) [{elapsed}, {rate_fmt}{postfix}]",
+            )

-        with tqdm(
-            total=max_new_token,
-            desc="text" if infer_text else "code",
-            bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt}(max) [{elapsed}, {rate_fmt}{postfix}]",
-        ) as pbar:
-
-            past_key_values = None
+        past_key_values = None

-            for i in range(max_new_token):
-                model_input = self._prepare_generation_inputs(
-                    inputs_ids,
-                    past_key_values,
-                    attention_mask_cache.narrow(1, 0, inputs_ids.shape[1]),
-                    use_cache=True,
-                )
+        for i in range(max_new_token):
+            model_input = self._prepare_generation_inputs(
+                inputs_ids,
+                past_key_values,
+                attention_mask_cache.narrow(1, 0, inputs_ids.shape[1]),
+                use_cache=True,
+            )

-                if i > 0:
-                    del emb
-                inputs_ids_emb = model_input.input_ids.to(self.device_gpt)
-                if infer_text:
-                    emb: torch.Tensor = self.emb_text(inputs_ids_emb[:, :, 0])
-                else:
-                    code_emb = [
-                        self.emb_code[i](inputs_ids_emb[:, :, i])
-                        for i in range(self.num_vq)
-                    ]
-                    emb = torch.stack(code_emb, 3).sum(3)
-                del inputs_ids_emb, model_input.input_ids
-                model_input.inputs_embeds = emb
-
-                model_input.to(self.device_gpt)
-
-                outputs: BaseModelOutputWithPast = self.gpt(
-                    attention_mask=model_input.attention_mask,
-                    position_ids=model_input.position_ids,
-                    past_key_values=model_input.past_key_values,
-                    inputs_embeds=model_input.inputs_embeds,
-                    use_cache=model_input.use_cache,
-                    output_attentions=return_attn,
-                    cache_position=model_input.cache_position,
-                )
-                del_all(model_input)
-                attentions.append(outputs.attentions)
-                hidden_states = outputs.last_hidden_state.to(self.device)  # 🐻
-                past_key_values = outputs.past_key_values
-                del_all(outputs)
-                if return_hidden:
-                    hiddens.append(hidden_states.narrow(1, -1, 1).squeeze_(1))
-
-                with P.cached():
-                    if infer_text:
-                        logits: torch.Tensor = self.head_text(hidden_states)
-                    else:
-                        # logits = torch.stack([self.head_code[i](hidden_states) for i in range(self.num_vq)], 3)
-                        logits = torch.empty(
-                            hidden_states.size(0),
-                            hidden_states.size(1),
-                            self.num_audio_tokens,
-                            self.num_vq,
-                            dtype=torch.float,
-                            device=self.device,
-                        )
-                        for i in range(self.num_vq):
-                            x: torch.Tensor = self.head_code[i](hidden_states)
-                            logits[..., i] = x
-                            del x
-
-                # logits = logits[:, -1].float()
-                logits = logits.narrow(1, -1, 1).squeeze_(1).float()
-
-                if not infer_text:
-                    # logits = rearrange(logits, "b c n -> (b n) c")
-                    logits = logits.permute(0, 2, 1)
-                    logits = logits.reshape(-1, logits.size(2))
-                    # logits_token = rearrange(inputs_ids[:, start_idx:], "b c n -> (b n) c")
-                    inputs_ids_sliced = inputs_ids[:, start_idx:].permute(0, 2, 1)
-                    logits_token = inputs_ids_sliced.reshape(
-                        inputs_ids_sliced.size(0) * inputs_ids_sliced.size(1),
-                        -1,
-                    ).to(self.device)
+            if i > 0:
+                del emb
+            inputs_ids_emb = model_input.input_ids.to(self.device_gpt)
+            if infer_text:
+                emb: torch.Tensor = self.emb_text(inputs_ids_emb[:, :, 0])
+            else:
+                code_emb = [
+                    self.emb_code[i](inputs_ids_emb[:, :, i])
+                    for i in range(self.num_vq)
+                ]
+                emb = torch.stack(code_emb, 3).sum(3)
+            del inputs_ids_emb, model_input.input_ids
+            model_input.inputs_embeds = emb
+
+            model_input.to(self.device_gpt)
+
+            outputs: BaseModelOutputWithPast = self.gpt(
+                attention_mask=model_input.attention_mask,
+                position_ids=model_input.position_ids,
+                past_key_values=model_input.past_key_values,
+                inputs_embeds=model_input.inputs_embeds,
+                use_cache=model_input.use_cache,
+                output_attentions=return_attn,
+                cache_position=model_input.cache_position,
+            )
+            del_all(model_input)
+            attentions.append(outputs.attentions)
+            hidden_states = outputs.last_hidden_state.to(self.device)  # 🐻
+            past_key_values = outputs.past_key_values
+            del_all(outputs)
+            if return_hidden:
+                hiddens.append(hidden_states.narrow(1, -1, 1).squeeze_(1))
+
+            with P.cached():
+                if infer_text:
+                    logits: torch.Tensor = self.head_text(hidden_states)
                 else:
-                    logits_token = inputs_ids[:, start_idx:, 0].to(self.device)
+                    # logits = torch.stack([self.head_code[i](hidden_states) for i in range(self.num_vq)], 3)
+                    logits = torch.empty(
+                        hidden_states.size(0),
+                        hidden_states.size(1),
+                        self.num_audio_tokens,
+                        self.num_vq,
+                        dtype=torch.float,
+                        device=self.device,
+                    )
+                    for i in range(self.num_vq):
+                        x: torch.Tensor = self.head_code[i](hidden_states)
+                        logits[..., i] = x
+                        del x
+
+            # logits = logits[:, -1].float()
+            logits = logits.narrow(1, -1, 1).squeeze_(1).float()
+
+            if not infer_text:
+                # logits = rearrange(logits, "b c n -> (b n) c")
+                logits = logits.permute(0, 2, 1)
+                logits = logits.reshape(-1, logits.size(2))
+                # logits_token = rearrange(inputs_ids[:, start_idx:], "b c n -> (b n) c")
+                inputs_ids_sliced = inputs_ids[:, start_idx:].permute(0, 2, 1)
+                logits_token = inputs_ids_sliced.reshape(
+                    inputs_ids_sliced.size(0) * inputs_ids_sliced.size(1),
+                    -1,
+                ).to(self.device)
+            else:
+                logits_token = inputs_ids[:, start_idx:, 0].to(self.device)
+
+            logits /= temperature

-                logits /= temperature
+            for logitsProcessors in logits_processors:
+                logits = logitsProcessors(logits_token, logits)

-                for logitsProcessors in logits_processors:
-                    logits = logitsProcessors(logits_token, logits)
+            for logitsWarpers in logits_warpers:
+                logits = logitsWarpers(logits_token, logits)

-                for logitsWarpers in logits_warpers:
-                    logits = logitsWarpers(logits_token, logits)
+            del logits_token

-                del logits_token
+            if i < min_new_token:
+                logits[:, eos_token] = -torch.inf

-                if i < min_new_token:
-                    logits[:, eos_token] = -torch.inf
+            scores = F.softmax(logits, dim=-1)

-                scores = F.softmax(logits, dim=-1)
+            del logits

-                del logits
+            idx_next = torch.multinomial(scores, num_samples=1).to(
+                finish.device
+            )

-                idx_next = torch.multinomial(scores, num_samples=1).to(
-                    finish.device
+            if not infer_text:
+                # idx_next = rearrange(idx_next, "(b n) 1 -> b n", n=self.num_vq)
+                idx_next = idx_next.view(-1, self.num_vq)
+                finish_or = idx_next.eq(eos_token).any(1)
+                finish.logical_or_(finish_or)
+                del finish_or
+                inputs_ids_tmp = torch.cat(
+                    [inputs_ids, idx_next.unsqueeze_(1)], 1
+                )
+            else:
+                finish_or = idx_next.eq(eos_token).any(1)
+                finish.logical_or_(finish_or)
+                del finish_or
+                inputs_ids_tmp = torch.cat(
+                    [
+                        inputs_ids,
+                        idx_next.unsqueeze_(-1).expand(-1, -1, self.num_vq),
+                    ],
+                    1,
                 )

-                if not infer_text:
-                    # idx_next = rearrange(idx_next, "(b n) 1 -> b n", n=self.num_vq)
-                    idx_next = idx_next.view(-1, self.num_vq)
-                    finish_or = idx_next.eq(eos_token).any(1)
-                    finish.logical_or_(finish_or)
-                    del finish_or
-                    inputs_ids_tmp = torch.cat(
-                        [inputs_ids, idx_next.unsqueeze_(1)], 1
-                    )
-                else:
-                    finish_or = idx_next.eq(eos_token).any(1)
-                    finish.logical_or_(finish_or)
-                    del finish_or
-                    inputs_ids_tmp = torch.cat(
-                        [
-                            inputs_ids,
-                            idx_next.unsqueeze_(-1).expand(-1, -1, self.num_vq),
-                        ],
-                        1,
+            del inputs_ids
+            inputs_ids = inputs_ids_tmp
+            del inputs_ids_tmp, idx_next
+
+            if stream:
+                minus_prev_end_index = end_idx.neg()
+            end_idx.add_((finish.logical_not().to(end_idx.device)).int())
+            if stream:
+                if (
+                    end_idx.all()
+                    and end_idx.fmod(24).eq(0).any()
+                    and minus_prev_end_index.add_(end_idx).any()
+                ):
+                    self.logger.debug("yield stream result, end: %d", end_idx)
+                    yield self._prepare_generation_outputs(
+                        inputs_ids,
+                        start_idx,
+                        end_idx,
+                        attentions,
+                        hiddens,
+                        infer_text,
                     )
+                del minus_prev_end_index
+
+            if finish.all() or context.get():
+                break

-                del inputs_ids
-                inputs_ids = inputs_ids_tmp
-                del inputs_ids_tmp, idx_next
-
-                if stream:
-                    minus_prev_end_index = end_idx.neg()
-                end_idx.add_((finish.logical_not().to(end_idx.device)).int())
-                if stream:
-                    if (
-                        end_idx.all()
-                        and end_idx.fmod(24).eq(0).any()
-                        and minus_prev_end_index.add_(end_idx).any()
-                    ):
-                        self.logger.debug("yield stream result, end: %d", end_idx)
-                        yield self._prepare_generation_outputs(
-                            inputs_ids,
-                            start_idx,
-                            end_idx,
-                            attentions,
-                            hiddens,
-                            infer_text,
-                        )
-                    del minus_prev_end_index
-
-                if finish.all() or context.get():
-                    break
-
-                pbar.update(1)
+            if pbar is not None: pbar.update(1)
+
+        if pbar is not None: pbar.close()

         if not finish.all():
             if context.get():
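
Design note: most of the churn in this hunk is a one-level dedent. The with tqdm(...) as pbar: context manager is replaced by an Optional[tqdm] that is created only when show_tqdm is true, updated once per decoding step, and closed explicitly after the loop. A standalone sketch of that pattern (the loop body is elided; names and values are illustrative, not from the diff):

from typing import Optional

from tqdm import tqdm


def generate(max_new_token: int = 16, show_tqdm: bool = True) -> None:
    # Optional progress bar instead of a `with tqdm(...)` block, so the
    # loop body no longer needs an extra indentation level.
    pbar: Optional[tqdm] = None
    if show_tqdm:
        pbar = tqdm(total=max_new_token, desc="demo")

    for _ in range(max_new_token):
        # ... one decoding step would run here ...
        if pbar is not None:
            pbar.update(1)

    if pbar is not None:
        pbar.close()


generate(show_tqdm=False)  # no bar is created, updated, or closed

One trade-off versus the context-manager form: if the loop raises, pbar.close() is never reached; wrapping the loop in try/finally would restore that guarantee.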

0 commit comments
