Commit
feat: add infer param show_tqdm
fumiama committed Jul 1, 2024
1 parent eb0fb71 commit 3836db8
Showing 2 changed files with 148 additions and 142 deletions.
8 changes: 4 additions & 4 deletions ChatTTS/core.py
@@ -197,17 +197,15 @@ class RefineTextParams:
         repetition_penalty: float = 1.0
         max_new_token: int = 384
         min_new_token: int = 0
+        show_tqdm: bool = True
 
     @dataclass(repr=False, eq=False)
-    class InferCodeParams:
+    class InferCodeParams(RefineTextParams):
         prompt: str = "[speed_5]"
         spk_emb: Optional[str] = None
-        top_P: float = 0.7
-        top_K: int = 20
         temperature: float = 0.3
         repetition_penalty: float = 1.05
         max_new_token: int = 2048
-        min_new_token: int = 0
 
     def infer(
         self,
@@ -596,6 +594,7 @@ def _infer_code(
             infer_text=False,
             return_hidden=return_hidden,
             stream=stream,
+            show_tqdm=params.show_tqdm,
             context=self.context,
         )

@@ -644,6 +643,7 @@ def _refine_text(
                 logits_processors=logits_processors,
                 infer_text=True,
                 stream=False,
+                show_tqdm=params.show_tqdm,
                 context=self.context,
             )
         )
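For reference, a minimal caller-side sketch of the new option. The `chat.load()` call and the `params_refine_text` / `params_infer_code` keyword names are assumptions about the surrounding ChatTTS API and do not appear in this diff:

import ChatTTS

chat = ChatTTS.Chat()
chat.load()  # assumed loader name; not part of this commit

texts = ["Progress bars can now be switched off during inference."]

# show_tqdm defaults to True, so existing callers keep their progress bars;
# passing False silences both the text-refinement and code-generation bars.
wavs = chat.infer(
    texts,
    params_refine_text=ChatTTS.Chat.RefineTextParams(show_tqdm=False),
    params_infer_code=ChatTTS.Chat.InferCodeParams(show_tqdm=False),
)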
282 changes: 144 additions & 138 deletions ChatTTS/model/gpt.py
@@ -335,6 +335,7 @@ def generate(
         return_attn=False,
         return_hidden=False,
         stream=False,
+        show_tqdm=True,
         context=Context(),
     ):

Expand Down Expand Up @@ -368,160 +369,165 @@ def generate(
         attention_mask_cache.narrow(1, 0, attention_mask.shape[1]).copy_(
             attention_mask
         )
 
-        with tqdm(
-            total=max_new_token,
-            desc="text" if infer_text else "code",
-            bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt}(max) [{elapsed}, {rate_fmt}{postfix}]",
-        ) as pbar:
-
-            past_key_values = None
+        pbar: Optional[tqdm] = None
+
+        if show_tqdm:
+            pbar = tqdm(
+                total=max_new_token,
+                desc="text" if infer_text else "code",
+                bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt}(max) [{elapsed}, {rate_fmt}{postfix}]",
+            )
+
+        past_key_values = None
 
[The rest of this hunk re-indents the generation loop one level to the left, out of the removed "with" block. The loop body is otherwise unchanged and is shown once below at its new indentation, with only the progress-bar lines marked.]
 
         for i in range(max_new_token):
             model_input = self._prepare_generation_inputs(
                 inputs_ids,
                 past_key_values,
                 attention_mask_cache.narrow(1, 0, inputs_ids.shape[1]),
                 use_cache=True,
             )
 
             if i > 0:
                 del emb
             inputs_ids_emb = model_input.input_ids.to(self.device_gpt)
             if infer_text:
                 emb: torch.Tensor = self.emb_text(inputs_ids_emb[:, :, 0])
             else:
                 code_emb = [
                     self.emb_code[i](inputs_ids_emb[:, :, i])
                     for i in range(self.num_vq)
                 ]
                 emb = torch.stack(code_emb, 3).sum(3)
             del inputs_ids_emb, model_input.input_ids
             model_input.inputs_embeds = emb
 
             model_input.to(self.device_gpt)
 
             outputs: BaseModelOutputWithPast = self.gpt(
                 attention_mask=model_input.attention_mask,
                 position_ids=model_input.position_ids,
                 past_key_values=model_input.past_key_values,
                 inputs_embeds=model_input.inputs_embeds,
                 use_cache=model_input.use_cache,
                 output_attentions=return_attn,
                 cache_position=model_input.cache_position,
             )
             del_all(model_input)
             attentions.append(outputs.attentions)
             hidden_states = outputs.last_hidden_state.to(self.device)  # 🐻
             past_key_values = outputs.past_key_values
             del_all(outputs)
             if return_hidden:
                 hiddens.append(hidden_states.narrow(1, -1, 1).squeeze_(1))
 
             with P.cached():
                 if infer_text:
                     logits: torch.Tensor = self.head_text(hidden_states)
                 else:
                     # logits = torch.stack([self.head_code[i](hidden_states) for i in range(self.num_vq)], 3)
                     logits = torch.empty(
                         hidden_states.size(0),
                         hidden_states.size(1),
                         self.num_audio_tokens,
                         self.num_vq,
                         dtype=torch.float,
                         device=self.device,
                     )
                     for i in range(self.num_vq):
                         x: torch.Tensor = self.head_code[i](hidden_states)
                         logits[..., i] = x
                         del x
 
             # logits = logits[:, -1].float()
             logits = logits.narrow(1, -1, 1).squeeze_(1).float()
 
             if not infer_text:
                 # logits = rearrange(logits, "b c n -> (b n) c")
                 logits = logits.permute(0, 2, 1)
                 logits = logits.reshape(-1, logits.size(2))
                 # logits_token = rearrange(inputs_ids[:, start_idx:], "b c n -> (b n) c")
                 inputs_ids_sliced = inputs_ids[:, start_idx:].permute(0, 2, 1)
                 logits_token = inputs_ids_sliced.reshape(
                     inputs_ids_sliced.size(0) * inputs_ids_sliced.size(1),
                     -1,
                 ).to(self.device)
             else:
                 logits_token = inputs_ids[:, start_idx:, 0].to(self.device)
 
             logits /= temperature
 
             for logitsProcessors in logits_processors:
                 logits = logitsProcessors(logits_token, logits)
 
             for logitsWarpers in logits_warpers:
                 logits = logitsWarpers(logits_token, logits)
 
             del logits_token
 
             if i < min_new_token:
                 logits[:, eos_token] = -torch.inf
 
             scores = F.softmax(logits, dim=-1)
 
             del logits
 
             idx_next = torch.multinomial(scores, num_samples=1).to(
                 finish.device
             )
 
             if not infer_text:
                 # idx_next = rearrange(idx_next, "(b n) 1 -> b n", n=self.num_vq)
                 idx_next = idx_next.view(-1, self.num_vq)
                 finish_or = idx_next.eq(eos_token).any(1)
                 finish.logical_or_(finish_or)
                 del finish_or
                 inputs_ids_tmp = torch.cat(
                     [inputs_ids, idx_next.unsqueeze_(1)], 1
                 )
             else:
                 finish_or = idx_next.eq(eos_token).any(1)
                 finish.logical_or_(finish_or)
                 del finish_or
                 inputs_ids_tmp = torch.cat(
                     [
                         inputs_ids,
                         idx_next.unsqueeze_(-1).expand(-1, -1, self.num_vq),
                     ],
                     1,
                 )
 
             del inputs_ids
             inputs_ids = inputs_ids_tmp
             del inputs_ids_tmp, idx_next
 
             if stream:
                 minus_prev_end_index = end_idx.neg()
             end_idx.add_((finish.logical_not().to(end_idx.device)).int())
             if stream:
                 if (
                     end_idx.all()
                     and end_idx.fmod(24).eq(0).any()
                     and minus_prev_end_index.add_(end_idx).any()
                 ):
                     self.logger.debug("yield stream result, end: %d", end_idx)
                     yield self._prepare_generation_outputs(
                         inputs_ids,
                         start_idx,
                         end_idx,
                         attentions,
                         hiddens,
                         infer_text,
                     )
                 del minus_prev_end_index
 
             if finish.all() or context.get():
                 break
 
-                pbar.update(1)
+            if pbar is not None: pbar.update(1)
+
+        if pbar is not None: pbar.close()
 
         if not finish.all():
             if context.get():
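The gpt.py change follows the usual pattern for an optional tqdm bar: construct it only when requested, guard each update, and close it explicitly because the `with` block is gone. A self-contained sketch of that pattern with generic, hypothetical names (not code from this repository):

from typing import Optional

from tqdm import tqdm


def run_steps(max_steps: int, show_tqdm: bool = True) -> int:
    # Create the bar only when the caller asks for it.
    pbar: Optional[tqdm] = None
    if show_tqdm:
        pbar = tqdm(total=max_steps, desc="steps")

    done = 0
    for _ in range(max_steps):
        done += 1  # placeholder for the real per-step work
        if pbar is not None:
            pbar.update(1)

    # No context manager anymore, so close the bar explicitly.
    if pbar is not None:
        pbar.close()
    return done


run_steps(10, show_tqdm=False)  # runs silently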
