Skip to content

Commit

Permalink
feat(chat): add model chack to custom load (#453)
Browse files Browse the repository at this point in the history
  • Loading branch information
fumiama authored Jun 25, 2024
1 parent b65e450 commit 21e48be
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 5 deletions.
6 changes: 6 additions & 0 deletions ChatTTS/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Literal, Optional, List, Callable, Tuple, Dict
from functools import lru_cache
from json import load
from pathlib import Path

import numpy as np
import torch
Expand Down Expand Up @@ -107,6 +108,11 @@ def download_models(
return None
elif source == "custom":
self.logger.log(logging.INFO, f"try to load from local: {custom_path}")
if not check_all_assets(self.sha256_map, update=False, base_dir=Path(custom_path)):
self.logger.error(
"check models in custom path %s failed.", custom_path
)
return None
download_path = custom_path

return download_path
Expand Down
8 changes: 3 additions & 5 deletions ChatTTS/utils/dl.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,9 @@ def check_model(
return True


def check_all_assets(sha256_map: dict[str, str], update=False) -> bool:
BASE_DIR = Path(os.getcwd())

def check_all_assets(sha256_map: dict[str, str], update=False, base_dir = Path(os.getcwd())) -> bool:
logger.get_logger().info("checking assets...")
current_dir = BASE_DIR / "asset"
current_dir = base_dir / "asset"
names = [
"Decoder.pt",
"DVAE.pt",
Expand All @@ -63,7 +61,7 @@ def check_all_assets(sha256_map: dict[str, str], update=False) -> bool:
return False

logger.get_logger().info("checking configs...")
current_dir = BASE_DIR / "config"
current_dir = base_dir / "config"
names = [
"decoder.yaml",
"dvae.yaml",
Expand Down

0 comments on commit 21e48be

Please sign in to comment.