Skip to content

Commit 3e0d87b

Browse files
authored
fix(tools): log formatting (#401)
- introduce numba - move audio normalization functions to tools/audio - remove IPython from tools/cmd
1 parent f809de7 commit 3e0d87b

File tree

7 files changed

+29
-34
lines changed

7 files changed

+29
-34
lines changed

examples/cmd/run.py

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,27 +11,25 @@
1111

1212
import wave
1313
import ChatTTS
14-
from IPython.display import Audio
1514

15+
from tools.audio import unsafe_float_to_int16
1616
from tools.logger import get_logger
1717

1818
logger = get_logger("Command")
1919

2020
def save_wav_file(wav, index):
2121
wav_filename = f"output_audio_{index}.wav"
22-
# Convert numpy array to bytes and write to WAV file
23-
wav_bytes = (wav * 32768).astype('int16').tobytes()
2422
with wave.open(wav_filename, "wb") as wf:
2523
wf.setnchannels(1) # Mono channel
2624
wf.setsampwidth(2) # Sample width in bytes
2725
wf.setframerate(24000) # Sample rate in Hz
28-
wf.writeframes(wav_bytes)
26+
wf.writeframes(unsafe_float_to_int16(wav))
2927
logger.info(f"Audio saved to {wav_filename}")
3028

3129
def main():
3230
# Retrieve text from command line argument
3331
text_input = sys.argv[1] if len(sys.argv) > 1 else "<YOUR TEXT HERE>"
34-
logger.info("Received text input: %s", text_input)
32+
logger.info("Text input: %s", text_input)
3533

3634
chat = ChatTTS.Chat(get_logger("ChatTTS"))
3735
logger.info("Initializing ChatTTS...")
@@ -41,17 +39,12 @@ def main():
4139
logger.error("Models load failed.")
4240
sys.exit(1)
4341

44-
texts = [text_input]
45-
logger.info("Text prepared for inference: %s", texts)
46-
47-
wavs = chat.infer(texts, use_decoder=True)
42+
wavs = chat.infer((text_input), use_decoder=True)
4843
logger.info("Inference completed. Audio generation successful.")
4944
# Save each generated wav file to a local file
5045
for index, wav in enumerate(wavs):
5146
save_wav_file(wav, index)
5247

53-
return Audio(wavs[0], rate=24_000, autoplay=True)
54-
5548
if __name__ == "__main__":
5649
logger.info("Starting the TTS application...")
5750
main()

examples/web/funcs.py

Lines changed: 3 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import gradio as gr
55
import numpy as np
66

7+
from tools.audio import unsafe_float_to_int16
78
from tools.logger import get_logger
89
logger = get_logger(" WebUI ")
910

@@ -74,27 +75,7 @@ def generate_audio(text, temperature, top_P, top_K, audio_seed_input, text_seed_
7475

7576
if stream:
7677
for gen in wav:
77-
wavs = [np.array([[]])]
78-
wavs[0] = np.hstack([wavs[0], np.array(gen[0])])
79-
audio = wavs[0][0]
80-
81-
# normalize
82-
am = np.abs(audio).max() * 32768
83-
if am > 32768:
84-
am = 32768 * 32768 / am
85-
np.multiply(audio, am, audio)
86-
audio = audio.astype(np.int16)
87-
88-
yield 24000, audio
78+
yield 24000, unsafe_float_to_int16(gen[0][0])
8979
return
9080

91-
audio_data = np.array(wav[0]).flatten()
92-
# normalize
93-
am = np.abs(audio_data).max() * 32768
94-
if am > 32768:
95-
am = 32768 * 32768 / am
96-
np.multiply(audio_data, am, audio_data)
97-
audio_data = audio_data.astype(np.int16)
98-
sample_rate = 24000
99-
100-
yield sample_rate, audio_data
81+
yield 24000, unsafe_float_to_int16(np.array(wav[0]).flatten())

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
numpy<2.0.0
2+
numba
23
omegaconf>=2.3.0
34
torch>=2.1.0
45
tqdm

setup.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
author='2noise',
55
url='https://github.com/2noise/ChatTTS',
66
install_requires=['omegaconf>=2.3.0',
7+
'numpy<2.0.0',
8+
'numba',
79
'torch>=2.1.0',
810
'tqdm',
911
'vector_quantize_pytorch',

tools/audio/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .np import unsafe_float_to_int16

tools/audio/np.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
import numpy as np
2+
from numba import jit
3+
4+
@jit
5+
def unsafe_float_to_int16(audio: np.ndarray) -> np.ndarray:
6+
"""
7+
This function will destroy audio, use only once.
8+
"""
9+
am = np.abs(audio).max() * 32768
10+
if am > 32768:
11+
am = 32768 * 32768 / am
12+
np.multiply(audio, am, audio)
13+
audio16 = audio.astype(np.int16)
14+
return audio16

tools/logger/log.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22
import logging
33
from datetime import datetime, timezone
44

5+
logging.getLogger("numba").setLevel(logging.WARNING)
6+
logging.getLogger("httpx").setLevel(logging.WARNING)
7+
58
# from https://github.com/FloatTech/ZeroBot-Plugin/blob/c70766a989698452e60e5e48fb2f802a2444330d/console/console_windows.go#L89-L96
69
colorCodePanic = "\x1b[1;31m"
710
colorCodeFatal = "\x1b[1;31m"
@@ -41,7 +44,7 @@ def format(self, record: logging.LogRecord):
4144
logstr += log_level_msg_str.get(record.levelno, record.levelname)
4245
if self.color:
4346
logstr += colorReset
44-
logstr += f"] {str(record.name)} | {str(record.msg)}"
47+
logstr += f"] {str(record.name)} | {str(record.msg)%record.args}"
4548
return logstr
4649

4750
def get_logger(name: str, lv = logging.INFO):

0 commit comments

Comments
 (0)