Skip to content

Commit 8620dfb

Browse files
authored
fix(load): KeyError: 'sha256_config_decoder_yaml' (fix #438) (#441)
by removing environment vars
1 parent c23e514 commit 8620dfb

File tree

12 files changed

+42
-44
lines changed

12 files changed

+42
-44
lines changed

ChatTTS/core.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@
22
import logging
33
import tempfile
44
from dataclasses import dataclass
5-
from typing import Literal, Optional, List, Callable, Tuple
5+
from typing import Literal, Optional, List, Callable, Tuple, Dict
66
from functools import lru_cache
7+
from json import load
78

89
import numpy as np
910
import torch
@@ -30,6 +31,8 @@ def __init__(self, logger=logging.getLogger(__name__)):
3031
os.path.join(os.path.dirname(__file__), 'res', 'homophones_map.json'),
3132
logger,
3233
)
34+
with open(os.path.join(os.path.dirname(__file__), 'res', 'sha256_map.json')) as f:
35+
self.sha256_map: Dict[str, str] = load(f)
3336

3437
self.context = GPT.Context()
3538

@@ -60,10 +63,10 @@ def download_models(
6063
) -> Optional[str]:
6164
if source == 'local':
6265
download_path = os.getcwd()
63-
if not check_all_assets(update=True) or force_redownload:
66+
if not check_all_assets(self.sha256_map, update=True) or force_redownload:
6467
with tempfile.TemporaryDirectory() as tmp:
6568
download_all_assets(tmpdir=tmp)
66-
if not check_all_assets(update=False):
69+
if not check_all_assets(self.sha256_map, update=False):
6770
self.logger.error("download to local path %s failed.", download_path)
6871
return None
6972
elif source == 'huggingface':
@@ -111,6 +114,7 @@ def unload(self):
111114
del_all(self.pretrain_models)
112115
self.normalizer.destroy()
113116
del self.normalizer
117+
del self.sha256_map
114118
self._gen_logits.cache_clear()
115119
del_list = ["vocos", "_vocos_decode", 'gpt', 'decoder', 'dvae']
116120
for module in del_list:

ChatTTS/model/gpt.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
import os
22
os.environ["TOKENIZERS_PARALLELISM"] = "false"
3+
"""
4+
https://stackoverflow.com/questions/62691279/how-to-disable-tokenizers-parallelism-true-false-warning
5+
"""
36

47
from dataclasses import dataclass
58
import logging

ChatTTS/res/sha256_map.json

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
{
2+
"sha256_asset_Decoder_pt" : "9964e36e840f0e3a748c5f716fe6de6490d2135a5f5155f4a642d51860e2ec38",
3+
"sha256_asset_DVAE_pt" : "613cb128adf89188c93ea5880ea0b798e66b1fe6186d0c535d99bcd87bfd6976",
4+
"sha256_asset_GPT_pt" : "d7d4ee6461ea097a2be23eb40d73fb94ad3b3d39cb64fbb50cb3357fd466cadb",
5+
"sha256_asset_spk_stat_pt" : "3228d8a4cbbf349d107a1b76d2f47820865bd3c9928c4bdfe1cefd5c7071105f",
6+
"sha256_asset_tokenizer_pt" : "e911ae7c6a7c27953433f35c44227a67838fe229a1f428503bdb6cd3d1bcc69c",
7+
"sha256_asset_Vocos_pt" : "09a670eda1c08b740013679c7a90ebb7f1a97646ea7673069a6838e6b51d6c58",
8+
9+
"sha256_config_decoder_yaml": "0890ab719716b0ad8abcb9eba0a9bf52c59c2e45ddedbbbb5ed514ff87bff369",
10+
"sha256_config_dvae_yaml" : "1b3a5aa0c6a314f766d4432ab36f84e882e29561648d837f71c04c7bea494fc6",
11+
"sha256_config_gpt_yaml" : "0c3c7277b674094bdd00b63b18b18aa3156502101dbd03c7f802e0fcf26cff51",
12+
"sha256_config_path_yaml" : "79829705c2d2a29b3f55e3b3f228bb81875e4e265211595fb50a73eb6434684b",
13+
"sha256_config_vocos_yaml" : "1ca837ce790dd8b55bdd5a16c6af8f813926b9c9b48f2a4da305e7e9ff0c9b0c"
14+
}

ChatTTS/utils/dl.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def check_model(
4141
return True
4242

4343

44-
def check_all_assets(update=False) -> bool:
44+
def check_all_assets(sha256_map: dict[str, str], update=False) -> bool:
4545
BASE_DIR = Path(os.getcwd())
4646

4747
logger.get_logger().info("checking assets...")
@@ -57,7 +57,7 @@ def check_all_assets(update=False) -> bool:
5757
for model in names:
5858
menv = model.replace(".", "_")
5959
if not check_model(
60-
current_dir, model, os.environ[f"sha256_asset_{menv}"], update
60+
current_dir, model, sha256_map[f"sha256_asset_{menv}"], update
6161
):
6262
return False
6363

@@ -73,7 +73,7 @@ def check_all_assets(update=False) -> bool:
7373
for model in names:
7474
menv = model.replace(".", "_")
7575
if not check_model(
76-
current_dir, model, os.environ[f"sha256_config_{menv}"], update
76+
current_dir, model, sha256_map[f"sha256_config_{menv}"], update
7777
):
7878
return False
7979

examples/cmd/run.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,6 @@
99
import wave
1010
import argparse
1111

12-
from dotenv import load_dotenv
13-
load_dotenv("sha256.env")
14-
1512
import ChatTTS
1613

1714
from tools.audio import unsafe_float_to_int16

examples/ipynb/colab.ipynb

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,6 @@
4141
},
4242
"outputs": [],
4343
"source": [
44-
"from dotenv import load_dotenv\n",
45-
"load_dotenv(\"ChatTTS/sha256.env\")\n",
46-
"\n",
4744
"import torch\n",
4845
"torch._dynamo.config.cache_size_limit = 64\n",
4946
"torch._dynamo.config.suppress_errors = True\n",

examples/ipynb/example.ipynb

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,6 @@
2424
" sys.path.append(root_dir)\n",
2525
" print(\"init root dir to\", root_dir)\n",
2626
"\n",
27-
"from dotenv import load_dotenv\n",
28-
"load_dotenv(os.path.join(root_dir, \"sha256.env\"))\n",
29-
"\n",
3027
"import torch\n",
3128
"torch._dynamo.config.cache_size_limit = 64\n",
3229
"torch._dynamo.config.suppress_errors = True\n",

examples/web/webui.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,6 @@
1010

1111
import gradio as gr
1212

13-
from dotenv import load_dotenv
14-
load_dotenv("sha256.env")
15-
1613
from examples.web.funcs import *
1714

1815
def main():

requirements.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ transformers>=4.41.1
88
vocos
99
IPython
1010
gradio
11-
python-dotenv
1211
pybase16384
1312
pynini==2.1.5; sys_platform == 'linux'
1413
WeTextProcessing; sys_platform == 'linux'

sha256.env

Lines changed: 0 additions & 12 deletions
This file was deleted.

0 commit comments

Comments
 (0)