diff --git a/examples/torchserve/.gitignore b/examples/torchserve/.gitignore new file mode 100644 index 000000000..097cfe186 --- /dev/null +++ b/examples/torchserve/.gitignore @@ -0,0 +1,2 @@ +*.mar +logs/ \ No newline at end of file diff --git a/examples/torchserve/Dockerfile b/examples/torchserve/Dockerfile new file mode 100644 index 000000000..511cc25d2 --- /dev/null +++ b/examples/torchserve/Dockerfile @@ -0,0 +1,14 @@ +FROM pytorch/torchserve:latest-gpu + +WORKDIR /app + +RUN pip config set global.index-url https://mirrors.cloud.tencent.com/pypi/simple + +RUN pip install --upgrade pip \ + && pip install ChatTTS nvgpu soundfile nemo_text_processing WeTextProcessing --no-cache-dir # 安装 ChatTTS 库 + + +COPY ./model_store /app/model_store +COPY ./config.properties /app/config.properties + +CMD ["torchserve", "--start", "--model-store", "/app/model_store", "--models", "chattts=chattts.mar", "--ts-config", "/app/config.properties"] diff --git a/examples/torchserve/README.md b/examples/torchserve/README.md new file mode 100644 index 000000000..b2bb1ecc1 --- /dev/null +++ b/examples/torchserve/README.md @@ -0,0 +1,108 @@ +# Generating voice with ChatTTS via TorchServe for high-performance inference + +## Why We Use TorchServe + +TorchServe is designed to deliver high performance for serving PyTorch models, and it excels in the following key areas: + +1. Batching Requests: TorchServe automatically batches incoming requests, processing multiple predictions in parallel. This reduces overhead, improves throughput, and ensures efficient use of resources, especially when dealing with large volumes of requests. + +2. Horizontal Scaling: TorchServe allows for horizontal scaling, meaning it can easily scale across multiple machines or containers to handle increasing traffic. This ensures that the system remains responsive and can handle large volumes of inference requests without sacrificing performance. + +## Install requirements + +``` bash +pip install -r requirements.txt +``` + +## Download the model if needed + +``` bash +huggingface-cli download 2Noise/ChatTTS +``` + +## Store the Model + +Replace `/path/to/your/model` with the actual path to your model files, for example: `/home/username/.cache/huggingface/hub/models--2Noise--ChatTTS/snapshots/1a3cxxxx` + +``` bash + +torch-model-archiver --model-name chattts \ + --version 1.0 \ + --serialized-file /path/to/your/model/asset/Decoder.pt \ + --handler model_handler.py \ + --extra-files "/path/to/your/model" \ + --export-path model_store +``` + +## Optional: TorchServe Model Configuration + +TorchServe support batch inference which aggregates inference requests and sending this aggregated requests through the ML/DL framework for inference all at once. TorchServe was designed to natively support batching of incoming inference requests. This functionality enables you to use your host resources optimally, because most ML/DL frameworks are optimized for batch requests. + +Started from Torchserve 0.4.1, there are two methods to configure TorchServe to use the batching feature: + +The configuration properties that we are interested in are the following: + +1. `batch_size`: This is the maximum batch size that a model is expected to handle, in this example we set `batch_size` to `32` + +2. `max_batch_delay`: This is the maximum batch delay time in ms TorchServe waits to receive batch_size number of requests. If TorchServe doesn’t receive batch_size number of requests before this timer time’s out, it sends what ever requests that were received to the model handler, in this example we set `max_batch_delay` to `1000` + +## Start TorchServe locally + +``` bash +pip install ChatTTS nvgpu soundfile nemo_text_processing WeTextProcessing + +torchserve --start --model-store model_store --models chattts=chattts.mar --ts-config config.properties +``` + +## Optional: Start TorchServe with docker + +### Prerequisites + +* docker - Refer to the [official docker installation guide](https://docs.docker.com/install/) + +### 1. Build the docker image + +```bash +docker build -t torchserve-chattts:latest-gpu . +``` + +### 2. Start the container + +```bash +docker run --name torchserve-chattts-gpu --gpus all -p 8080:8080 -p 8081:8081 torchserve-chattts:latest-gpu +``` + +## Inference with restful api + +Note that the `text` parameter takes a string instead of a list, TorchServe will automaticly batch multiple inferences into one request + +``` bash +curl --location --request GET 'http://127.0.0.1:8080/predictions/chattts' \ +--header 'Content-Type: application/json' \ +--data '{ + "text": "今天天气不错哦", + "stream": false, + "temperature": 0.3, + "lang": "zh", + "skip_refine_text": true, + "refine_text_only": false, + "use_decoder": true, + "do_text_normalization": true, + "do_homophone_replacement": false, + "params_infer_code": { + "prompt": "", + "top_P": 0.7, + "top_K": 20, + "temperature": 0.3, + "repetition_penalty": 1.05, + "max_new_token": 2048, + "min_new_token": 0, + "show_tqdm": true, + "ensure_non_empty": true, + "manual_seed": 888, + "stream_batch": 24, + "stream_speed": 12000, + "pass_first_n_batches": 2 + } +}' --output test.wav +``` diff --git a/examples/torchserve/config.properties b/examples/torchserve/config.properties new file mode 100644 index 000000000..785e30ff3 --- /dev/null +++ b/examples/torchserve/config.properties @@ -0,0 +1,18 @@ +model_store=model-store +inference_address=http://0.0.0.0:8080 +management_address=http://0.0.0.0:8081 +metrics_address=http://0.0.0.0:8082 +disable_token_authorization=true +models={\ + "chattts": {\ + "1.0": {\ + "defaultVersion": true,\ + "marName": "chattts.mar",\ + "minWorkers": 1,\ + "maxWorkers": 1,\ + "batchSize": 32,\ + "maxBatchDelay": 1000,\ + "responseTimeout": 120\ + }\ + }\ +} diff --git a/examples/torchserve/model_handler.py b/examples/torchserve/model_handler.py new file mode 100644 index 000000000..2d034bf0f --- /dev/null +++ b/examples/torchserve/model_handler.py @@ -0,0 +1,137 @@ +import os +import json +import torch +from typing import Callable +from functools import partial +import io +from ts.torch_handler.base_handler import BaseHandler +import logging +import torchaudio + +import ChatTTS + +logger = logging.getLogger("TorchServeHandler") + +from typing import Callable + + +def normalizer_zh_tn() -> Callable[[str], str]: + from tn.chinese.normalizer import Normalizer + + return Normalizer(remove_interjections=False).normalize + + +def normalizer_en_nemo_text() -> Callable[[str], str]: + from nemo_text_processing.text_normalization.normalize import Normalizer + + return partial( + Normalizer(input_case="cased", lang="en").normalize, + verbose=False, + punct_post_process=True, + ) + + +class ChatTTSHandler(BaseHandler): + def __init__(self): + super(ChatTTSHandler, self).__init__() + self.chat = None + self.initialized = False + + def initialize(self, ctx): + """Load the model and initialize resources.""" + logger.info("Initializing ChatTTS...") + self.chat = ChatTTS.Chat(logging.getLogger("ChatTTS")) + self.chat.normalizer.register("en", normalizer_en_nemo_text()) + self.chat.normalizer.register("zh", normalizer_zh_tn()) + + model_dir = ctx.system_properties.get("model_dir") + os.chdir(model_dir) + if self.chat.load(source="custom", custom_path=model_dir): + logger.info("Models loaded successfully.") + else: + logger.error("Models load failed.") + raise RuntimeError("Failed to load models.") + self.initialized = True + + def preprocess(self, data): + """Preprocess incoming requests.""" + if len(data) == 0: + raise ValueError("No data received for inference.") + logger.info(f"batch size: {len(data)}") + return self._group_reuqest_by_config(data) + + def _group_reuqest_by_config(self, data): + + batched_requests = {} + for req in data: + params = req.get("body") + text = params.pop("text") + + key = json.dumps(params) + + if key not in batched_requests: + params_refine_text = params.get("params_refine_text") + params_infer_code = params.get("params_infer_code") + + if ( + params_infer_code + and params_infer_code.get("manual_seed") is not None + ): + torch.manual_seed(params_infer_code.get("manual_seed")) + params_infer_code["spk_emb"] = self.chat.sample_random_speaker() + + batched_requests[key] = { + "text": [text], + "stream": params.get("stream", False), + "lang": params.get("lang"), + "skip_refine_text": params.get("skip_refine_text", False), + "use_decoder": params.get("use_decoder", True), + "do_text_normalization": params.get("do_text_normalization", True), + "do_homophone_replacement": params.get( + "do_homophone_replacement", False + ), + "params_refine_text": ( + ChatTTS.Chat.InferCodeParams(**params_refine_text) + if params_refine_text + else None + ), + "params_infer_code": ( + ChatTTS.Chat.InferCodeParams(**params_infer_code) + if params_infer_code + else None + ), + } + else: + batched_requests[key]["text"].append(text) + + return batched_requests + + def inference(self, data): + """Run inference.""" + + for key, params in data.items(): + logger.info(f"Request: {key}") + logger.info(f"Text input: {str(params['text'])}") + + text = params["text"] + if params["params_refine_text"]: + text = self.chat.infer(text=text, refine_text_only=True) + logger.info(f"Refined text: {text}") + + yield self.chat.infer(**params) + + def postprocess(self, batched_results): + """Post-process inference results into raw wav data.""" + results = [] + for wavs in batched_results: + for wav in wavs: + buf = io.BytesIO() + try: + torchaudio.save( + buf, torch.from_numpy(wav).unsqueeze(0), 24000, format="wav" + ) + except: + torchaudio.save(buf, torch.from_numpy(wav), 24000, format="wav") + buf.seek(0) + results.append(buf.getvalue()) + return results diff --git a/examples/torchserve/model_store/.gitkeep b/examples/torchserve/model_store/.gitkeep new file mode 100644 index 000000000..e69de29bb diff --git a/examples/torchserve/requirements.txt b/examples/torchserve/requirements.txt new file mode 100644 index 000000000..ec98bac0c --- /dev/null +++ b/examples/torchserve/requirements.txt @@ -0,0 +1,3 @@ +huggingface_hub[cli] +torchserve +torch-model-archiver \ No newline at end of file