diff --git a/ChatTTS/core.py b/ChatTTS/core.py index c178a9ad2..28cebd92a 100644 --- a/ChatTTS/core.py +++ b/ChatTTS/core.py @@ -68,7 +68,7 @@ def download_models( custom_path: Optional[torch.serialization.FILE_LIKE] = None, ) -> Optional[str]: if source == "local": - download_path = os.getcwd() + download_path = custom_path if custom_path is not None else os.getcwd() if ( not check_all_assets(Path(download_path), self.sha256_map, update=True) or force_redownload @@ -83,10 +83,20 @@ def download_models( ) return None elif source == "huggingface": - hf_home = os.getenv("HF_HOME", os.path.expanduser("~/.cache/huggingface")) try: - download_path = get_latest_modified_file( - os.path.join(hf_home, "hub/models--2Noise--ChatTTS/snapshots") + download_path = ( + get_latest_modified_file( + os.path.join( + os.getenv( + "HF_HOME", os.path.expanduser("~/.cache/huggingface") + ), + "hub/models--2Noise--ChatTTS/snapshots", + ) + ) + if custom_path is None + else get_latest_modified_file( + os.path.join(custom_path, "models--2Noise--ChatTTS/snapshots") + ) ) except: download_path = None @@ -99,16 +109,16 @@ def download_models( download_path = snapshot_download( repo_id="2Noise/ChatTTS", allow_patterns=["*.yaml", "*.json", "*.safetensors"], + cache_dir=custom_path, + force_download=force_redownload, ) except: download_path = None - else: - self.logger.log( - logging.INFO, f"load latest snapshot from cache: {download_path}" - ) - if download_path is None: - self.logger.error("download from huggingface failed.") - return None + else: + self.logger.log( + logging.INFO, + f"load latest snapshot from cache: {download_path}", + ) elif source == "custom": self.logger.log(logging.INFO, f"try to load from local: {custom_path}") if not check_all_assets(Path(custom_path), self.sha256_map, update=False): @@ -116,6 +126,10 @@ def download_models( return None download_path = custom_path + if download_path is None: + self.logger.error("Model download failed") + return None + return download_path def load(