fix(core): mps infer incorrect on some device (fix #444) #454

Merged 1 commit on Jun 25, 2024
43 changes: 10 additions & 33 deletions ChatTTS/core.py
@@ -3,7 +3,6 @@
 import tempfile
 from dataclasses import dataclass
 from typing import Literal, Optional, List, Callable, Tuple, Dict
-from functools import lru_cache
 from json import load
 from pathlib import Path

@@ -13,9 +12,8 @@
 from omegaconf import OmegaConf
 from vocos import Vocos
 from huggingface_hub import snapshot_download
-from transformers.generation import TopKLogitsWarper, TopPLogitsWarper

-from .model import DVAE, GPT, CustomRepetitionPenaltyLogitsProcessorRepeat
+from .model import DVAE, GPT, gen_logits
 from .utils import (
     check_all_assets,
     download_all_assets,
@@ -72,10 +70,10 @@ def download_models(
     ) -> Optional[str]:
         if source == "local":
             download_path = os.getcwd()
-            if not check_all_assets(self.sha256_map, update=True) or force_redownload:
+            if not check_all_assets(Path(download_path), self.sha256_map, update=True) or force_redownload:
                 with tempfile.TemporaryDirectory() as tmp:
                     download_all_assets(tmpdir=tmp)
-                if not check_all_assets(self.sha256_map, update=False):
+                if not check_all_assets(Path(download_path), self.sha256_map, update=False):
                     self.logger.error(
                         "download to local path %s failed.", download_path
                     )
@@ -108,7 +106,7 @@ def download_models(
                 return None
         elif source == "custom":
             self.logger.log(logging.INFO, f"try to load from local: {custom_path}")
-            if not check_all_assets(self.sha256_map, update=False, base_dir=Path(custom_path)):
+            if not check_all_assets(Path(custom_path), self.sha256_map, update=False):
                 self.logger.error(
                     "check models in custom path %s failed.", custom_path
                 )
@@ -147,7 +145,6 @@ def unload(self):
         self.normalizer.destroy()
         del self.normalizer
         del self.sha256_map
-        self._gen_logits.cache_clear()
         del_list = ["vocos", "_vocos_decode", "gpt", "decoder", "dvae"]
         for module in del_list:
             if hasattr(self, module):
@@ -415,30 +412,6 @@ def _gen_gpt_inputs(self, text: str, device="cpu"):

         return input_ids, text_token, text_mask

-    @lru_cache
-    def _gen_logits(
-        self,
-        num_code: int,
-        top_P=0.7,
-        top_K=20,
-        repetition_penalty=1.0,
-    ):
-        logits_warpers = []
-        if top_P is not None:
-            logits_warpers.append(TopPLogitsWarper(top_P, min_tokens_to_keep=3))
-        if top_K is not None:
-            logits_warpers.append(TopKLogitsWarper(top_K, min_tokens_to_keep=3))
-
-        logits_processors = []
-        if repetition_penalty is not None and repetition_penalty != 1:
-            logits_processors.append(
-                CustomRepetitionPenaltyLogitsProcessorRepeat(
-                    repetition_penalty, num_code, 16
-                )
-            )
-
-        return logits_warpers, logits_processors
-
     def _apply_spk_emb(
         self,
         emb: torch.Tensor,
@@ -494,7 +467,7 @@ def _infer_code(

         num_code = int(gpt.emb_code[0].num_embeddings - 1)

-        logits_warpers, logits_processors = self._gen_logits(
+        logits_warpers, logits_processors = gen_logits(
             num_code=num_code,
             top_P=params.top_P,
             top_K=params.top_K,
@@ -519,6 +492,8 @@

         del_all(text_token)
         del emb, text_token, input_ids
+        del_all(logits_warpers)
+        del_all(logits_processors)

         return result

@@ -539,7 +514,7 @@ def _refine_text(

         input_ids, text_token, text_mask = self._gen_gpt_inputs(text, gpt.device_gpt)

-        logits_warpers, logits_processors = self._gen_logits(
+        logits_warpers, logits_processors = gen_logits(
             num_code=len(tokenizer),
             top_P=params.top_P,
             top_K=params.top_K,
@@ -568,5 +543,7 @@

         del_all(text_token)
         del emb, text_token, input_ids
+        del_all(logits_warpers)
+        del_all(logits_processors)

         return next(result)
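
Note on the ChatTTS/core.py changes above: with the lru_cache-wrapped _gen_logits method removed, _infer_code and _refine_text now build fresh samplers through the module-level gen_logits helper on every call and explicitly release them with del_all once generation is done. A minimal sketch of the new call pattern, with placeholder values (the del_all import path and the 626 vocabulary size are assumptions for illustration):

from ChatTTS.model import gen_logits
from ChatTTS.utils import del_all  # helper used by core.py; import path assumed

num_code = 626  # hypothetical code-vocabulary size; core.py derives it from gpt.emb_code

# Fresh warpers/processors for this call only, nothing cached across calls or devices.
logits_warpers, logits_processors = gen_logits(
    num_code=num_code,
    top_P=0.7,                # core.py forwards params.top_P
    top_K=20,                 # core.py forwards params.top_K
    repetition_penalty=1.05,  # 1.0 would disable the repetition processor
)

# ... the generation step consumes logits_warpers / logits_processors here ...

# Drop them explicitly so nothing carries over into the next inference call.
del_all(logits_warpers)
del_all(logits_processors)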
2 changes: 1 addition & 1 deletion ChatTTS/model/__init__.py
@@ -1,3 +1,3 @@
 from .dvae import DVAE
 from .gpt import GPT
-from .processors import CustomRepetitionPenaltyLogitsProcessorRepeat
+from .processors import gen_logits
41 changes: 21 additions & 20 deletions ChatTTS/model/processors.py
@@ -1,5 +1,6 @@
 import torch
 import torch.nn.functional as F
+from transformers.generation import TopKLogitsWarper, TopPLogitsWarper


 class CustomRepetitionPenaltyLogitsProcessorRepeat:
@@ -31,24 +32,24 @@ def __call__(
         del inp, oth, scores, con, alpha
         return out

-"""class CustomRepetitionPenaltyLogitsProcessor():
-
-    def __init__(self, penalty: float, max_input_ids: int, past_window: int):
-        if not isinstance(penalty, float) or not (penalty > 0):
-            raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")
-
-        self.penalty = penalty
-        self.max_input_ids = max_input_ids
-        self.past_window = past_window
-
-    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
-
-        input_ids = input_ids[:, -self.past_window:]
-        score = torch.gather(scores, 1, input_ids)
-        _score = score.detach().clone()
-        score = torch.where(score < 0, score * self.penalty, score / self.penalty)
-        score[input_ids>=self.max_input_ids] = _score[input_ids>=self.max_input_ids]
-        scores.scatter_(1, input_ids, score)
-
-        return scores"""
+
+def gen_logits(
+    num_code: int,
+    top_P=0.7,
+    top_K=20,
+    repetition_penalty=1.0,
+):
+    logits_warpers = []
+    if top_P is not None:
+        logits_warpers.append(TopPLogitsWarper(top_P, min_tokens_to_keep=3))
+    if top_K is not None:
+        logits_warpers.append(TopKLogitsWarper(top_K, min_tokens_to_keep=3))
+
+    logits_processors = []
+    if repetition_penalty is not None and repetition_penalty != 1:
+        logits_processors.append(
+            CustomRepetitionPenaltyLogitsProcessorRepeat(
+                repetition_penalty, num_code, 16
+            )
+        )
+
+    return logits_warpers, logits_processors
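
The objects returned by gen_logits follow the transformers logits-processor calling convention, scores = f(input_ids, scores). A small self-contained sketch of applying them outside the model, with made-up shapes purely for illustration (inside ChatTTS they are consumed during GPT generation):

import torch
from ChatTTS.model.processors import gen_logits

num_code = 626  # hypothetical vocabulary size
batch = 2

logits_warpers, logits_processors = gen_logits(
    num_code=num_code, top_P=0.7, top_K=20, repetition_penalty=1.05
)

input_ids = torch.randint(0, num_code, (batch, 8))  # previously sampled tokens
scores = torch.randn(batch, num_code + 1)           # raw next-token logits

# Apply the repetition penalty, then top-p / top-k filtering (order shown is illustrative).
for processor in logits_processors:
    scores = processor(input_ids, scores)
for warper in logits_warpers:
    scores = warper(input_ids, scores)

next_token = torch.multinomial(torch.softmax(scores, dim=-1), num_samples=1)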
2 changes: 1 addition & 1 deletion ChatTTS/utils/dl.py
@@ -42,7 +42,7 @@ def check_model(
     return True


-def check_all_assets(sha256_map: dict[str, str], update=False, base_dir = Path(os.getcwd())) -> bool:
+def check_all_assets(base_dir: Path, sha256_map: dict[str, str], update=False) -> bool:
     logger.get_logger().info("checking assets...")
     current_dir = base_dir / "asset"
     names = [
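
check_all_assets now takes the asset base directory as a required first positional argument instead of a trailing base_dir keyword that defaulted to os.getcwd(), which is why every call site in core.py above passes Path(...) explicitly. A minimal sketch of the new calling convention, with a placeholder sha256 map and a hypothetical model directory:

from pathlib import Path
from ChatTTS.utils import check_all_assets  # re-exported from ChatTTS/utils/dl.py

sha256_map = {"sha256_asset_decoder_pt": "0" * 64}  # placeholder entry, not a real hash

download_path = Path("/tmp/chattts-models")  # hypothetical local model directory
if not check_all_assets(download_path, sha256_map, update=False):
    print(f"check models in {download_path} failed")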