diff --git a/ChatTTS/core.py b/ChatTTS/core.py
index 3919e1333..b74310fe9 100644
--- a/ChatTTS/core.py
+++ b/ChatTTS/core.py
@@ -3,7 +3,6 @@
 import tempfile
 from dataclasses import dataclass
 from typing import Literal, Optional, List, Callable, Tuple, Dict
-from functools import lru_cache
 from json import load
 from pathlib import Path
 
@@ -13,9 +12,8 @@
 from omegaconf import OmegaConf
 from vocos import Vocos
 from huggingface_hub import snapshot_download
-from transformers.generation import TopKLogitsWarper, TopPLogitsWarper
 
-from .model import DVAE, GPT, CustomRepetitionPenaltyLogitsProcessorRepeat
+from .model import DVAE, GPT, gen_logits
 from .utils import (
     check_all_assets,
     download_all_assets,
@@ -72,10 +70,10 @@ def download_models(
     ) -> Optional[str]:
         if source == "local":
             download_path = os.getcwd()
-            if not check_all_assets(self.sha256_map, update=True) or force_redownload:
+            if not check_all_assets(Path(download_path), self.sha256_map, update=True) or force_redownload:
                 with tempfile.TemporaryDirectory() as tmp:
                     download_all_assets(tmpdir=tmp)
-                if not check_all_assets(self.sha256_map, update=False):
+                if not check_all_assets(Path(download_path), self.sha256_map, update=False):
                     self.logger.error(
                         "download to local path %s failed.", download_path
                     )
@@ -108,7 +106,7 @@ def download_models(
                 return None
         elif source == "custom":
             self.logger.log(logging.INFO, f"try to load from local: {custom_path}")
-            if not check_all_assets(self.sha256_map, update=False, base_dir=Path(custom_path)):
+            if not check_all_assets(Path(custom_path), self.sha256_map, update=False):
                 self.logger.error(
                     "check models in custom path %s failed.", custom_path
                 )
@@ -147,7 +145,6 @@ def unload(self):
         self.normalizer.destroy()
         del self.normalizer
         del self.sha256_map
-        self._gen_logits.cache_clear()
         del_list = ["vocos", "_vocos_decode", "gpt", "decoder", "dvae"]
         for module in del_list:
             if hasattr(self, module):
@@ -415,30 +412,6 @@ def _gen_gpt_inputs(self, text: str, device="cpu"):
 
         return input_ids, text_token, text_mask
 
-    @lru_cache
-    def _gen_logits(
-        self,
-        num_code: int,
-        top_P=0.7,
-        top_K=20,
-        repetition_penalty=1.0,
-    ):
-        logits_warpers = []
-        if top_P is not None:
-            logits_warpers.append(TopPLogitsWarper(top_P, min_tokens_to_keep=3))
-        if top_K is not None:
-            logits_warpers.append(TopKLogitsWarper(top_K, min_tokens_to_keep=3))
-
-        logits_processors = []
-        if repetition_penalty is not None and repetition_penalty != 1:
-            logits_processors.append(
-                CustomRepetitionPenaltyLogitsProcessorRepeat(
-                    repetition_penalty, num_code, 16
-                )
-            )
-
-        return logits_warpers, logits_processors
-
     def _apply_spk_emb(
         self,
         emb: torch.Tensor,
@@ -494,7 +467,7 @@ def _infer_code(
 
         num_code = int(gpt.emb_code[0].num_embeddings - 1)
 
-        logits_warpers, logits_processors = self._gen_logits(
+        logits_warpers, logits_processors = gen_logits(
             num_code=num_code,
             top_P=params.top_P,
             top_K=params.top_K,
@@ -519,6 +492,8 @@ def _infer_code(
 
         del_all(text_token)
         del emb, text_token, input_ids
+        del_all(logits_warpers)
+        del_all(logits_processors)
 
         return result
 
@@ -539,7 +514,7 @@ def _refine_text(
 
         input_ids, text_token, text_mask = self._gen_gpt_inputs(text, gpt.device_gpt)
 
-        logits_warpers, logits_processors = self._gen_logits(
+        logits_warpers, logits_processors = gen_logits(
             num_code=len(tokenizer),
             top_P=params.top_P,
             top_K=params.top_K,
@@ -568,5 +543,7 @@ def _refine_text(
 
         del_all(text_token)
         del emb, text_token, input_ids
+        del_all(logits_warpers)
+        del_all(logits_processors)
 
         return next(result)
diff --git a/ChatTTS/model/__init__.py b/ChatTTS/model/__init__.py
index e2bc5bdf9..aee737f63 100644
--- a/ChatTTS/model/__init__.py
+++ b/ChatTTS/model/__init__.py
@@ -1,3 +1,3 @@
 from .dvae import DVAE
 from .gpt import GPT
-from .processors import CustomRepetitionPenaltyLogitsProcessorRepeat
+from .processors import gen_logits
diff --git a/ChatTTS/model/processors.py b/ChatTTS/model/processors.py
index de1879990..41c4aff92 100644
--- a/ChatTTS/model/processors.py
+++ b/ChatTTS/model/processors.py
@@ -1,5 +1,6 @@
 import torch
 import torch.nn.functional as F
+from transformers.generation import TopKLogitsWarper, TopPLogitsWarper
 
 
 class CustomRepetitionPenaltyLogitsProcessorRepeat:
@@ -31,24 +32,24 @@ def __call__(
         del inp, oth, scores, con, alpha
         return out
 
+def gen_logits(
+    num_code: int,
+    top_P=0.7,
+    top_K=20,
+    repetition_penalty=1.0,
+):
+    logits_warpers = []
+    if top_P is not None:
+        logits_warpers.append(TopPLogitsWarper(top_P, min_tokens_to_keep=3))
+    if top_K is not None:
+        logits_warpers.append(TopKLogitsWarper(top_K, min_tokens_to_keep=3))
+
+    logits_processors = []
+    if repetition_penalty is not None and repetition_penalty != 1:
+        logits_processors.append(
+            CustomRepetitionPenaltyLogitsProcessorRepeat(
+                repetition_penalty, num_code, 16
+            )
+        )
 
-"""class CustomRepetitionPenaltyLogitsProcessor():
-
-    def __init__(self, penalty: float, max_input_ids: int, past_window: int):
-        if not isinstance(penalty, float) or not (penalty > 0):
-            raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")
-
-        self.penalty = penalty
-        self.max_input_ids = max_input_ids
-        self.past_window = past_window
-
-    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
-
-        input_ids = input_ids[:, -self.past_window:]
-        score = torch.gather(scores, 1, input_ids)
-        _score = score.detach().clone()
-        score = torch.where(score < 0, score * self.penalty, score / self.penalty)
-        score[input_ids>=self.max_input_ids] = _score[input_ids>=self.max_input_ids]
-        scores.scatter_(1, input_ids, score)
-
-        return scores"""
+    return logits_warpers, logits_processors
diff --git a/ChatTTS/utils/dl.py b/ChatTTS/utils/dl.py
index ede306605..462f509ad 100644
--- a/ChatTTS/utils/dl.py
+++ b/ChatTTS/utils/dl.py
@@ -42,7 +42,7 @@ def check_model(
     return True
 
 
-def check_all_assets(sha256_map: dict[str, str], update=False, base_dir = Path(os.getcwd())) -> bool:
+def check_all_assets(base_dir: Path, sha256_map: dict[str, str], update=False) -> bool:
     logger.get_logger().info("checking assets...")
     current_dir = base_dir / "asset"
     names = [
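Note: as a quick sanity check of the relocated helper, here is a minimal sketch that exercises gen_logits through the new ChatTTS.model export. The num_code value, penalty, and tensor shapes below are illustrative placeholders; in core.py the real values come from gpt.emb_code[0].num_embeddings - 1 (or len(tokenizer) for refinement), and the returned warpers/processors are created in _infer_code/_refine_text, consumed during generation, and released with del_all afterwards.

import torch

from ChatTTS.model import gen_logits

# Illustrative codebook size only; core.py derives it from the GPT embedding table.
num_code = 626

logits_warpers, logits_processors = gen_logits(
    num_code=num_code,
    top_P=0.7,
    top_K=20,
    repetition_penalty=1.05,
)

# Dummy decoding prefix and next-token scores, just to show the call pattern.
input_ids = torch.zeros((1, 4), dtype=torch.long)
scores = torch.randn(1, num_code + 1)

for processor in logits_processors:
    scores = processor(input_ids, scores)
for warper in logits_warpers:
    scores = warper(input_ids, scores)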