feature: TorchServe example #870

Closed
2 changes: 2 additions & 0 deletions examples/torchserve/.gitignore
@@ -0,0 +1,2 @@
*.mar
logs/
14 changes: 14 additions & 0 deletions examples/torchserve/Dockerfile
@@ -0,0 +1,14 @@
FROM pytorch/torchserve:latest-gpu

WORKDIR /app

RUN pip config set global.index-url https://mirrors.cloud.tencent.com/pypi/simple

RUN pip install --upgrade pip \
    && pip install ChatTTS nvgpu soundfile nemo_text_processing WeTextProcessing --no-cache-dir # install ChatTTS and its runtime dependencies


COPY ./model_store /app/model_store
COPY ./config.properties /app/config.properties

CMD ["torchserve", "--start", "--model-store", "/app/model_store", "--models", "chattts=chattts.mar", "--ts-config", "/app/config.properties"]
108 changes: 108 additions & 0 deletions examples/torchserve/README.md
@@ -0,0 +1,108 @@
# Generating speech with ChatTTS via TorchServe for high-performance inference

## Why We Use TorchServe

TorchServe is designed to deliver high performance for serving PyTorch models, and it excels in the following key areas:

1. Batching Requests: TorchServe automatically batches incoming requests, processing multiple predictions in parallel. This reduces overhead, improves throughput, and ensures efficient use of resources, especially when dealing with large volumes of requests.

2. Horizontal Scaling: TorchServe allows for horizontal scaling, meaning it can easily scale across multiple machines or containers to handle increasing traffic. This ensures that the system remains responsive and can handle large volumes of inference requests without sacrificing performance.

## Install requirements

``` bash
pip install -r requirements.txt
```

## Download the model if needed

``` bash
huggingface-cli download 2Noise/ChatTTS
```
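
`huggingface-cli download` prints the local snapshot path once the files are cached, so you can capture it for the archiving step below. A small sketch (the printed path is illustrative):

``` bash
# capture the local snapshot directory resolved by huggingface-cli
MODEL_DIR=$(huggingface-cli download 2Noise/ChatTTS)
echo "$MODEL_DIR"  # e.g. ~/.cache/huggingface/hub/models--2Noise--ChatTTS/snapshots/<revision>
```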

## Store the Model

Replace `/path/to/your/model` with the actual path to your model files, for example: `/home/username/.cache/huggingface/hub/models--2Noise--ChatTTS/snapshots/1a3cxxxx`

``` bash
torch-model-archiver --model-name chattts \
    --version 1.0 \
    --serialized-file /path/to/your/model/asset/Decoder.pt \
    --handler model_handler.py \
    --extra-files "/path/to/your/model" \
    --export-path model_store
```
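
The archiver writes `chattts.mar` into `model_store/`. Since a `.mar` file is a zip archive, a quick sanity check of its contents before serving looks roughly like this:

``` bash
ls -lh model_store/chattts.mar
unzip -l model_store/chattts.mar | head   # should list model_handler.py, Decoder.pt and the extra files
```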

## Optional: TorchServe Model Configuration

TorchServe supports batch inference: it aggregates incoming inference requests and sends the aggregated batch through the ML/DL framework in a single pass. TorchServe was designed to natively support batching of incoming requests, which lets you use your host resources optimally, because most ML/DL frameworks are optimized for batch requests.

Starting from TorchServe 0.4.1, there are two ways to configure batching: register the model through the management API with batching parameters, or declare them for the model in `config.properties`. This example uses the bundled `config.properties`; a management-API sketch is shown after the list below.

The configuration properties that we are interested in are the following:

1. `batch_size`: This is the maximum batch size that a model is expected to handle; in this example we set `batch_size` to `32`.

2. `max_batch_delay`: This is the maximum delay, in milliseconds, that TorchServe waits to receive `batch_size` requests. If TorchServe doesn't receive `batch_size` requests before this timer expires, it sends whatever requests were received to the model handler. In this example we set `max_batch_delay` to `1000`.
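
As an alternative to `config.properties`, the same batching parameters can be passed when registering the model through the management API (port `8081` in this example). A minimal sketch, assuming TorchServe is already running and `chattts.mar` has not been registered yet:

``` bash
# register chattts.mar with batching enabled via the management API
curl -X POST "http://127.0.0.1:8081/models?url=chattts.mar&initial_workers=1&batch_size=32&max_batch_delay=1000"
```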

## Start TorchServe locally

``` bash
pip install ChatTTS nvgpu soundfile nemo_text_processing WeTextProcessing

torchserve --start --model-store model_store --models chattts=chattts.mar --ts-config config.properties
```
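
Once the server is up, you can confirm that the model is registered and its workers are healthy through the management API, and stop the server when you are done:

``` bash
# describe the registered model (workers, batch size, status)
curl http://127.0.0.1:8081/models/chattts

# stop TorchServe when finished
torchserve --stop
```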

## Optional: Start TorchServe with docker

### Prerequisites

* docker - Refer to the [official docker installation guide](https://docs.docker.com/install/)

### 1. Build the docker image

```bash
docker build -t torchserve-chattts:latest-gpu .
```

### 2. Start the container

```bash
docker run --name torchserve-chattts-gpu --gpus all -p 8080:8080 -p 8081:8081 torchserve-chattts:latest-gpu
```
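
Once the container is running, the inference API's ping endpoint (mapped to port `8080` above) gives a quick liveness check:

``` bash
# expect {"status": "Healthy"} when the server is ready
curl http://127.0.0.1:8080/ping
```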

## Inference with the RESTful API

Note that the `text` parameter takes a string instead of a list; TorchServe automatically batches multiple incoming requests into one inference call.

``` bash
curl --location --request GET 'http://127.0.0.1:8080/predictions/chattts' \
--header 'Content-Type: application/json' \
--data '{
    "text": "今天天气不错哦",
    "stream": false,
    "temperature": 0.3,
    "lang": "zh",
    "skip_refine_text": true,
    "refine_text_only": false,
    "use_decoder": true,
    "do_text_normalization": true,
    "do_homophone_replacement": false,
    "params_infer_code": {
        "prompt": "",
        "top_P": 0.7,
        "top_K": 20,
        "temperature": 0.3,
        "repetition_penalty": 1.05,
        "max_new_token": 2048,
        "min_new_token": 0,
        "show_tqdm": true,
        "ensure_non_empty": true,
        "manual_seed": 888,
        "stream_batch": 24,
        "stream_speed": 12000,
        "pass_first_n_batches": 2
    }
}' --output test.wav
```
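
Because the handler groups requests that share identical parameters, batching is easiest to observe by firing several requests concurrently within the `max_batch_delay` window. A rough sketch, assuming the JSON body from the example above has been saved to a file named `request.json`:

``` bash
# send 4 concurrent requests; with max_batch_delay=1000 ms they should be served as a single batch
for i in 1 2 3 4; do
  curl -s --location --request GET 'http://127.0.0.1:8080/predictions/chattts' \
    --header 'Content-Type: application/json' \
    --data @request.json \
    --output "test_$i.wav" &
done
wait
```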
18 changes: 18 additions & 0 deletions examples/torchserve/config.properties
@@ -0,0 +1,18 @@
model_store=model_store
inference_address=http://0.0.0.0:8080
management_address=http://0.0.0.0:8081
metrics_address=http://0.0.0.0:8082
disable_token_authorization=true
models={\
  "chattts": {\
    "1.0": {\
        "defaultVersion": true,\
        "marName": "chattts.mar",\
        "minWorkers": 1,\
        "maxWorkers": 1,\
        "batchSize": 32,\
        "maxBatchDelay": 1000,\
        "responseTimeout": 120\
    }\
  }\
}
137 changes: 137 additions & 0 deletions examples/torchserve/model_handler.py
@@ -0,0 +1,137 @@
import os
import json
import torch
from typing import Callable
from functools import partial
import io
from ts.torch_handler.base_handler import BaseHandler
import logging
import torchaudio

import ChatTTS

logger = logging.getLogger("TorchServeHandler")


def normalizer_zh_tn() -> Callable[[str], str]:
    from tn.chinese.normalizer import Normalizer

    return Normalizer(remove_interjections=False).normalize


def normalizer_en_nemo_text() -> Callable[[str], str]:
    from nemo_text_processing.text_normalization.normalize import Normalizer

    return partial(
        Normalizer(input_case="cased", lang="en").normalize,
        verbose=False,
        punct_post_process=True,
    )


class ChatTTSHandler(BaseHandler):
    def __init__(self):
        super(ChatTTSHandler, self).__init__()
        self.chat = None
        self.initialized = False

    def initialize(self, ctx):
        """Load the model and initialize resources."""
        logger.info("Initializing ChatTTS...")
        self.chat = ChatTTS.Chat(logging.getLogger("ChatTTS"))
        self.chat.normalizer.register("en", normalizer_en_nemo_text())
        self.chat.normalizer.register("zh", normalizer_zh_tn())

        # Work from the extracted model archive directory before loading.
        model_dir = ctx.system_properties.get("model_dir")
        os.chdir(model_dir)
        if self.chat.load(source="custom", custom_path=model_dir):
            logger.info("Models loaded successfully.")
        else:
            logger.error("Models load failed.")
            raise RuntimeError("Failed to load models.")
        self.initialized = True

    def preprocess(self, data):
        """Preprocess incoming requests."""
        if len(data) == 0:
            raise ValueError("No data received for inference.")
        logger.info(f"batch size: {len(data)}")
        return self._group_request_by_config(data)

    def _group_request_by_config(self, data):
        """Group requests that share identical parameters so each group is inferred as one batch."""
        batched_requests = {}
        for req in data:
            params = req.get("body")
            text = params.pop("text")

            # Requests with identical parameters (after removing "text") share one key.
            key = json.dumps(params)

            if key not in batched_requests:
                params_refine_text = params.get("params_refine_text")
                params_infer_code = params.get("params_infer_code")

                if (
                    params_infer_code
                    and params_infer_code.get("manual_seed") is not None
                ):
                    torch.manual_seed(params_infer_code.get("manual_seed"))
                    params_infer_code["spk_emb"] = self.chat.sample_random_speaker()

                batched_requests[key] = {
                    "text": [text],
                    "stream": params.get("stream", False),
                    "lang": params.get("lang"),
                    "skip_refine_text": params.get("skip_refine_text", False),
                    "use_decoder": params.get("use_decoder", True),
                    "do_text_normalization": params.get("do_text_normalization", True),
                    "do_homophone_replacement": params.get(
                        "do_homophone_replacement", False
                    ),
                    "params_refine_text": (
                        ChatTTS.Chat.RefineTextParams(**params_refine_text)
                        if params_refine_text
                        else None
                    ),
                    "params_infer_code": (
                        ChatTTS.Chat.InferCodeParams(**params_infer_code)
                        if params_infer_code
                        else None
                    ),
                }
            else:
                batched_requests[key]["text"].append(text)

        return batched_requests

    def inference(self, data):
        """Run inference."""

        for key, params in data.items():
            logger.info(f"Request: {key}")
            logger.info(f"Text input: {str(params['text'])}")

            text = params["text"]
            if params["params_refine_text"]:
                text = self.chat.infer(text=text, refine_text_only=True)
                logger.info(f"Refined text: {text}")

            yield self.chat.infer(**params)

    def postprocess(self, batched_results):
        """Post-process inference results into raw wav data."""
        results = []
        for wavs in batched_results:
            for wav in wavs:
                buf = io.BytesIO()
                try:
                    # Mono waveform: add a channel dimension before saving.
                    torchaudio.save(
                        buf, torch.from_numpy(wav).unsqueeze(0), 24000, format="wav"
                    )
                except Exception:
                    # Fallback for waveforms that already carry a channel dimension.
                    torchaudio.save(buf, torch.from_numpy(wav), 24000, format="wav")
                buf.seek(0)
                results.append(buf.getvalue())
        return results
3 changes: 3 additions & 0 deletions examples/torchserve/requirements.txt
@@ -0,0 +1,3 @@
huggingface_hub[cli]
torchserve
torch-model-archiver