diff --git a/examples/cmd/run.py b/examples/cmd/run.py index 9f483db3b..b4a387f24 100644 --- a/examples/cmd/run.py +++ b/examples/cmd/run.py @@ -11,27 +11,25 @@ import wave import ChatTTS -from IPython.display import Audio +from tools.audio import unsafe_float_to_int16 from tools.logger import get_logger logger = get_logger("Command") def save_wav_file(wav, index): wav_filename = f"output_audio_{index}.wav" - # Convert numpy array to bytes and write to WAV file - wav_bytes = (wav * 32768).astype('int16').tobytes() with wave.open(wav_filename, "wb") as wf: wf.setnchannels(1) # Mono channel wf.setsampwidth(2) # Sample width in bytes wf.setframerate(24000) # Sample rate in Hz - wf.writeframes(wav_bytes) + wf.writeframes(unsafe_float_to_int16(wav)) logger.info(f"Audio saved to {wav_filename}") def main(): # Retrieve text from command line argument text_input = sys.argv[1] if len(sys.argv) > 1 else "" - logger.info("Received text input: %s", text_input) + logger.info("Text input: %s", text_input) chat = ChatTTS.Chat(get_logger("ChatTTS")) logger.info("Initializing ChatTTS...") @@ -41,17 +39,12 @@ def main(): logger.error("Models load failed.") sys.exit(1) - texts = [text_input] - logger.info("Text prepared for inference: %s", texts) - - wavs = chat.infer(texts, use_decoder=True) + wavs = chat.infer((text_input), use_decoder=True) logger.info("Inference completed. Audio generation successful.") # Save each generated wav file to a local file for index, wav in enumerate(wavs): save_wav_file(wav, index) - return Audio(wavs[0], rate=24_000, autoplay=True) - if __name__ == "__main__": logger.info("Starting the TTS application...") main() diff --git a/examples/web/funcs.py b/examples/web/funcs.py index 4b25df66c..e0dedb93d 100644 --- a/examples/web/funcs.py +++ b/examples/web/funcs.py @@ -4,6 +4,7 @@ import gradio as gr import numpy as np +from tools.audio import unsafe_float_to_int16 from tools.logger import get_logger logger = get_logger(" WebUI ") @@ -74,27 +75,7 @@ def generate_audio(text, temperature, top_P, top_K, audio_seed_input, text_seed_ if stream: for gen in wav: - wavs = [np.array([[]])] - wavs[0] = np.hstack([wavs[0], np.array(gen[0])]) - audio = wavs[0][0] - - # normalize - am = np.abs(audio).max() * 32768 - if am > 32768: - am = 32768 * 32768 / am - np.multiply(audio, am, audio) - audio = audio.astype(np.int16) - - yield 24000, audio + yield 24000, unsafe_float_to_int16(gen[0][0]) return - audio_data = np.array(wav[0]).flatten() - # normalize - am = np.abs(audio_data).max() * 32768 - if am > 32768: - am = 32768 * 32768 / am - np.multiply(audio_data, am, audio_data) - audio_data = audio_data.astype(np.int16) - sample_rate = 24000 - - yield sample_rate, audio_data + yield 24000, unsafe_float_to_int16(np.array(wav[0]).flatten()) diff --git a/requirements.txt b/requirements.txt index 33bfb73ba..1b9bba011 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,5 @@ numpy<2.0.0 +numba omegaconf>=2.3.0 torch>=2.1.0 tqdm diff --git a/setup.py b/setup.py index c70bd3fbf..ecd0b3afd 100644 --- a/setup.py +++ b/setup.py @@ -4,6 +4,8 @@ author='2noise', url='https://github.com/2noise/ChatTTS', install_requires=['omegaconf>=2.3.0', + 'numpy<2.0.0', + 'numba', 'torch>=2.1.0', 'tqdm', 'vector_quantize_pytorch', diff --git a/tools/audio/__init__.py b/tools/audio/__init__.py new file mode 100644 index 000000000..fc55f41fe --- /dev/null +++ b/tools/audio/__init__.py @@ -0,0 +1 @@ +from .np import unsafe_float_to_int16 diff --git a/tools/audio/np.py b/tools/audio/np.py new file mode 100644 index 000000000..3dfb46d1c --- /dev/null +++ b/tools/audio/np.py @@ -0,0 +1,14 @@ +import numpy as np +from numba import jit + +@jit +def unsafe_float_to_int16(audio: np.ndarray) -> np.ndarray: + """ + This function will destroy audio, use only once. + """ + am = np.abs(audio).max() * 32768 + if am > 32768: + am = 32768 * 32768 / am + np.multiply(audio, am, audio) + audio16 = audio.astype(np.int16) + return audio16 diff --git a/tools/logger/log.py b/tools/logger/log.py index 5e5066d99..b36212047 100644 --- a/tools/logger/log.py +++ b/tools/logger/log.py @@ -2,6 +2,9 @@ import logging from datetime import datetime, timezone +logging.getLogger("numba").setLevel(logging.WARNING) +logging.getLogger("httpx").setLevel(logging.WARNING) + # from https://github.com/FloatTech/ZeroBot-Plugin/blob/c70766a989698452e60e5e48fb2f802a2444330d/console/console_windows.go#L89-L96 colorCodePanic = "\x1b[1;31m" colorCodeFatal = "\x1b[1;31m" @@ -41,7 +44,7 @@ def format(self, record: logging.LogRecord): logstr += log_level_msg_str.get(record.levelno, record.levelname) if self.color: logstr += colorReset - logstr += f"] {str(record.name)} | {str(record.msg)}" + logstr += f"] {str(record.name)} | {str(record.msg)%record.args}" return logstr def get_logger(name: str, lv = logging.INFO):