Skip to content

Commit

Permalink
fix(core): mps infer incorrect on some device (fix #444)
Browse files Browse the repository at this point in the history
  • Loading branch information
fumiama committed Jun 25, 2024
1 parent 21e48be commit 39026d9
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 55 deletions.
43 changes: 10 additions & 33 deletions ChatTTS/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import tempfile
from dataclasses import dataclass
from typing import Literal, Optional, List, Callable, Tuple, Dict
from functools import lru_cache
from json import load
from pathlib import Path

Expand All @@ -13,9 +12,8 @@
from omegaconf import OmegaConf
from vocos import Vocos
from huggingface_hub import snapshot_download
from transformers.generation import TopKLogitsWarper, TopPLogitsWarper

from .model import DVAE, GPT, CustomRepetitionPenaltyLogitsProcessorRepeat
from .model import DVAE, GPT, gen_logits
from .utils import (
check_all_assets,
download_all_assets,
Expand Down Expand Up @@ -72,10 +70,10 @@ def download_models(
) -> Optional[str]:
if source == "local":
download_path = os.getcwd()
if not check_all_assets(self.sha256_map, update=True) or force_redownload:
if not check_all_assets(Path(download_path), self.sha256_map, update=True) or force_redownload:
with tempfile.TemporaryDirectory() as tmp:
download_all_assets(tmpdir=tmp)
if not check_all_assets(self.sha256_map, update=False):
if not check_all_assets(Path(download_path), self.sha256_map, update=False):
self.logger.error(
"download to local path %s failed.", download_path
)
Expand Down Expand Up @@ -108,7 +106,7 @@ def download_models(
return None
elif source == "custom":
self.logger.log(logging.INFO, f"try to load from local: {custom_path}")
if not check_all_assets(self.sha256_map, update=False, base_dir=Path(custom_path)):
if not check_all_assets(Path(custom_path), self.sha256_map, update=False):
self.logger.error(
"check models in custom path %s failed.", custom_path
)
Expand Down Expand Up @@ -147,7 +145,6 @@ def unload(self):
self.normalizer.destroy()
del self.normalizer
del self.sha256_map
self._gen_logits.cache_clear()
del_list = ["vocos", "_vocos_decode", "gpt", "decoder", "dvae"]
for module in del_list:
if hasattr(self, module):
Expand Down Expand Up @@ -415,30 +412,6 @@ def _gen_gpt_inputs(self, text: str, device="cpu"):

return input_ids, text_token, text_mask

@lru_cache
def _gen_logits(
self,
num_code: int,
top_P=0.7,
top_K=20,
repetition_penalty=1.0,
):
logits_warpers = []
if top_P is not None:
logits_warpers.append(TopPLogitsWarper(top_P, min_tokens_to_keep=3))
if top_K is not None:
logits_warpers.append(TopKLogitsWarper(top_K, min_tokens_to_keep=3))

logits_processors = []
if repetition_penalty is not None and repetition_penalty != 1:
logits_processors.append(
CustomRepetitionPenaltyLogitsProcessorRepeat(
repetition_penalty, num_code, 16
)
)

return logits_warpers, logits_processors

def _apply_spk_emb(
self,
emb: torch.Tensor,
Expand Down Expand Up @@ -494,7 +467,7 @@ def _infer_code(

num_code = int(gpt.emb_code[0].num_embeddings - 1)

logits_warpers, logits_processors = self._gen_logits(
logits_warpers, logits_processors = gen_logits(
num_code=num_code,
top_P=params.top_P,
top_K=params.top_K,
Expand All @@ -519,6 +492,8 @@ def _infer_code(

del_all(text_token)
del emb, text_token, input_ids
del_all(logits_warpers)
del_all(logits_processors)

return result

Expand All @@ -539,7 +514,7 @@ def _refine_text(

input_ids, text_token, text_mask = self._gen_gpt_inputs(text, gpt.device_gpt)

logits_warpers, logits_processors = self._gen_logits(
logits_warpers, logits_processors = gen_logits(
num_code=len(tokenizer),
top_P=params.top_P,
top_K=params.top_K,
Expand Down Expand Up @@ -568,5 +543,7 @@ def _refine_text(

del_all(text_token)
del emb, text_token, input_ids
del_all(logits_warpers)
del_all(logits_processors)

return next(result)
2 changes: 1 addition & 1 deletion ChatTTS/model/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .dvae import DVAE
from .gpt import GPT
from .processors import CustomRepetitionPenaltyLogitsProcessorRepeat
from .processors import gen_logits
41 changes: 21 additions & 20 deletions ChatTTS/model/processors.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import torch
import torch.nn.functional as F
from transformers.generation import TopKLogitsWarper, TopPLogitsWarper


class CustomRepetitionPenaltyLogitsProcessorRepeat:
Expand Down Expand Up @@ -31,24 +32,24 @@ def __call__(
del inp, oth, scores, con, alpha
return out

def gen_logits(
num_code: int,
top_P=0.7,
top_K=20,
repetition_penalty=1.0,
):
logits_warpers = []
if top_P is not None:
logits_warpers.append(TopPLogitsWarper(top_P, min_tokens_to_keep=3))
if top_K is not None:
logits_warpers.append(TopKLogitsWarper(top_K, min_tokens_to_keep=3))

logits_processors = []
if repetition_penalty is not None and repetition_penalty != 1:
logits_processors.append(
CustomRepetitionPenaltyLogitsProcessorRepeat(
repetition_penalty, num_code, 16
)
)

"""class CustomRepetitionPenaltyLogitsProcessor():
def __init__(self, penalty: float, max_input_ids: int, past_window: int):
if not isinstance(penalty, float) or not (penalty > 0):
raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")
self.penalty = penalty
self.max_input_ids = max_input_ids
self.past_window = past_window
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
input_ids = input_ids[:, -self.past_window:]
score = torch.gather(scores, 1, input_ids)
_score = score.detach().clone()
score = torch.where(score < 0, score * self.penalty, score / self.penalty)
score[input_ids>=self.max_input_ids] = _score[input_ids>=self.max_input_ids]
scores.scatter_(1, input_ids, score)
return scores"""
return logits_warpers, logits_processors
2 changes: 1 addition & 1 deletion ChatTTS/utils/dl.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def check_model(
return True


def check_all_assets(sha256_map: dict[str, str], update=False, base_dir = Path(os.getcwd())) -> bool:
def check_all_assets(base_dir: Path, sha256_map: dict[str, str], update=False) -> bool:
logger.get_logger().info("checking assets...")
current_dir = base_dir / "asset"
names = [
Expand Down

0 comments on commit 39026d9

Please sign in to comment.