Skip to content

Commit

Permalink
Update core.py
Browse files Browse the repository at this point in the history
  • Loading branch information
fumiama authored Jan 7, 2025
1 parent 12318b5 commit fa21047
Showing 1 changed file with 17 additions and 24 deletions.
41 changes: 17 additions & 24 deletions ChatTTS/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,23 @@ def download_models(
)
return None
elif source == "huggingface":
if cache_dir:
if cache_dir is None:
hf_home = os.getenv(
"HF_HOME", os.path.expanduser("~/.cache/huggingface")
)
else:
hf_home = cache_dir
try:
download_path = get_latest_modified_file(
os.path.join(hf_home, "hub/models--2Noise--ChatTTS/snapshots")
)
except:
download_path = None
if download_path is None or force_redownload:
self.logger.log(
logging.INFO,
f"download from HF: https://huggingface.co/2Noise/ChatTTS",
)
try:
download_path = snapshot_download(
repo_id="2Noise/ChatTTS",
Expand All @@ -94,28 +110,6 @@ def download_models(
)
except:
download_path = None
else:
hf_home = os.getenv(
"HF_HOME", os.path.expanduser("~/.cache/huggingface")
)
try:
download_path = get_latest_modified_file(
os.path.join(hf_home, "hub/models--2Noise--ChatTTS/snapshots")
)
except:
download_path = None
if download_path is None or force_redownload:
self.logger.log(
logging.INFO,
f"download from HF: https://huggingface.co/2Noise/ChatTTS",
)
try:
download_path = snapshot_download(
repo_id="2Noise/ChatTTS",
allow_patterns=["*.yaml", "*.json", "*.safetensors"],
)
except:
download_path = None
else:
self.logger.log(
logging.INFO,
Expand All @@ -134,7 +128,6 @@ def download_models(

return download_path

# Modified
def load(
self,
source: Literal["huggingface", "local", "custom"] = "local",
Expand Down

0 comments on commit fa21047

Please sign in to comment.