Skip to content

Commit

Permalink
fix(tools): log formatting (#401)
Browse files Browse the repository at this point in the history
- introduce numba
- move audio normalization functions to tools/audio
- remove IPython from tools/cmd
  • Loading branch information
fumiama authored Jun 22, 2024
1 parent f809de7 commit 3e0d87b
Show file tree
Hide file tree
Showing 7 changed files with 29 additions and 34 deletions.
15 changes: 4 additions & 11 deletions examples/cmd/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,27 +11,25 @@

import wave
import ChatTTS
from IPython.display import Audio

from tools.audio import unsafe_float_to_int16
from tools.logger import get_logger

logger = get_logger("Command")

def save_wav_file(wav, index):
wav_filename = f"output_audio_{index}.wav"
# Convert numpy array to bytes and write to WAV file
wav_bytes = (wav * 32768).astype('int16').tobytes()
with wave.open(wav_filename, "wb") as wf:
wf.setnchannels(1) # Mono channel
wf.setsampwidth(2) # Sample width in bytes
wf.setframerate(24000) # Sample rate in Hz
wf.writeframes(wav_bytes)
wf.writeframes(unsafe_float_to_int16(wav))
logger.info(f"Audio saved to {wav_filename}")

def main():
# Retrieve text from command line argument
text_input = sys.argv[1] if len(sys.argv) > 1 else "<YOUR TEXT HERE>"
logger.info("Received text input: %s", text_input)
logger.info("Text input: %s", text_input)

chat = ChatTTS.Chat(get_logger("ChatTTS"))
logger.info("Initializing ChatTTS...")
Expand All @@ -41,17 +39,12 @@ def main():
logger.error("Models load failed.")
sys.exit(1)

texts = [text_input]
logger.info("Text prepared for inference: %s", texts)

wavs = chat.infer(texts, use_decoder=True)
wavs = chat.infer((text_input), use_decoder=True)
logger.info("Inference completed. Audio generation successful.")
# Save each generated wav file to a local file
for index, wav in enumerate(wavs):
save_wav_file(wav, index)

return Audio(wavs[0], rate=24_000, autoplay=True)

if __name__ == "__main__":
logger.info("Starting the TTS application...")
main()
Expand Down
25 changes: 3 additions & 22 deletions examples/web/funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import gradio as gr
import numpy as np

from tools.audio import unsafe_float_to_int16
from tools.logger import get_logger
logger = get_logger(" WebUI ")

Expand Down Expand Up @@ -74,27 +75,7 @@ def generate_audio(text, temperature, top_P, top_K, audio_seed_input, text_seed_

if stream:
for gen in wav:
wavs = [np.array([[]])]
wavs[0] = np.hstack([wavs[0], np.array(gen[0])])
audio = wavs[0][0]

# normalize
am = np.abs(audio).max() * 32768
if am > 32768:
am = 32768 * 32768 / am
np.multiply(audio, am, audio)
audio = audio.astype(np.int16)

yield 24000, audio
yield 24000, unsafe_float_to_int16(gen[0][0])
return

audio_data = np.array(wav[0]).flatten()
# normalize
am = np.abs(audio_data).max() * 32768
if am > 32768:
am = 32768 * 32768 / am
np.multiply(audio_data, am, audio_data)
audio_data = audio_data.astype(np.int16)
sample_rate = 24000

yield sample_rate, audio_data
yield 24000, unsafe_float_to_int16(np.array(wav[0]).flatten())
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
numpy<2.0.0
numba
omegaconf>=2.3.0
torch>=2.1.0
tqdm
Expand Down
2 changes: 2 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
author='2noise',
url='https://github.com/2noise/ChatTTS',
install_requires=['omegaconf>=2.3.0',
'numpy<2.0.0',
'numba',
'torch>=2.1.0',
'tqdm',
'vector_quantize_pytorch',
Expand Down
1 change: 1 addition & 0 deletions tools/audio/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .np import unsafe_float_to_int16
14 changes: 14 additions & 0 deletions tools/audio/np.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import numpy as np
from numba import jit

@jit
def unsafe_float_to_int16(audio: np.ndarray) -> np.ndarray:
"""
This function will destroy audio, use only once.
"""
am = np.abs(audio).max() * 32768
if am > 32768:
am = 32768 * 32768 / am
np.multiply(audio, am, audio)
audio16 = audio.astype(np.int16)
return audio16
5 changes: 4 additions & 1 deletion tools/logger/log.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
import logging
from datetime import datetime, timezone

logging.getLogger("numba").setLevel(logging.WARNING)
logging.getLogger("httpx").setLevel(logging.WARNING)

# from https://github.com/FloatTech/ZeroBot-Plugin/blob/c70766a989698452e60e5e48fb2f802a2444330d/console/console_windows.go#L89-L96
colorCodePanic = "\x1b[1;31m"
colorCodeFatal = "\x1b[1;31m"
Expand Down Expand Up @@ -41,7 +44,7 @@ def format(self, record: logging.LogRecord):
logstr += log_level_msg_str.get(record.levelno, record.levelname)
if self.color:
logstr += colorReset
logstr += f"] {str(record.name)} | {str(record.msg)}"
logstr += f"] {str(record.name)} | {str(record.msg)%record.args}"
return logstr

def get_logger(name: str, lv = logging.INFO):
Expand Down

0 comments on commit 3e0d87b

Please sign in to comment.