Skip to content

Commit

Permalink
feat: add cache_dir parameter (#863)
Browse files Browse the repository at this point in the history
  • Loading branch information
BBC-Esq authored Jan 7, 2025
1 parent d21106f commit a933b66
Showing 1 changed file with 25 additions and 11 deletions.
36 changes: 25 additions & 11 deletions ChatTTS/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def download_models(
custom_path: Optional[torch.serialization.FILE_LIKE] = None,
) -> Optional[str]:
if source == "local":
download_path = os.getcwd()
download_path = custom_path if custom_path is not None else os.getcwd()
if (
not check_all_assets(Path(download_path), self.sha256_map, update=True)
or force_redownload
Expand All @@ -83,10 +83,20 @@ def download_models(
)
return None
elif source == "huggingface":
hf_home = os.getenv("HF_HOME", os.path.expanduser("~/.cache/huggingface"))
try:
download_path = get_latest_modified_file(
os.path.join(hf_home, "hub/models--2Noise--ChatTTS/snapshots")
download_path = (
get_latest_modified_file(
os.path.join(
os.getenv(
"HF_HOME", os.path.expanduser("~/.cache/huggingface")
),
"hub/models--2Noise--ChatTTS/snapshots",
)
)
if custom_path is None
else get_latest_modified_file(
os.path.join(custom_path, "models--2Noise--ChatTTS/snapshots")
)
)
except:
download_path = None
Expand All @@ -99,23 +109,27 @@ def download_models(
download_path = snapshot_download(
repo_id="2Noise/ChatTTS",
allow_patterns=["*.yaml", "*.json", "*.safetensors"],
cache_dir=custom_path,
force_download=force_redownload,
)
except:
download_path = None
else:
self.logger.log(
logging.INFO, f"load latest snapshot from cache: {download_path}"
)
if download_path is None:
self.logger.error("download from huggingface failed.")
return None
else:
self.logger.log(
logging.INFO,
f"load latest snapshot from cache: {download_path}",
)
elif source == "custom":
self.logger.log(logging.INFO, f"try to load from local: {custom_path}")
if not check_all_assets(Path(custom_path), self.sha256_map, update=False):
self.logger.error("check models in custom path %s failed.", custom_path)
return None
download_path = custom_path

if download_path is None:
self.logger.error("Model download failed")
return None

return download_path

def load(
Expand Down

0 comments on commit a933b66

Please sign in to comment.