feature: TorchServe example
JaysonAlbert committed Jan 8, 2025
1 parent 00c56ee commit 6966cd3
Showing 6 changed files with 268 additions and 0 deletions.
2 changes: 2 additions & 0 deletions examples/torchserve/.gitignore
@@ -0,0 +1,2 @@
*.mar
logs/
14 changes: 14 additions & 0 deletions examples/torchserve/Dockerfile
@@ -0,0 +1,14 @@
FROM pytorch/torchserve:latest-gpu

WORKDIR /app

RUN pip config set global.index-url https://mirrors.cloud.tencent.com/pypi/simple

RUN pip install --upgrade pip \
    && pip install ChatTTS nvgpu soundfile nemo_text_processing WeTextProcessing --no-cache-dir # install the ChatTTS library


COPY ./model_store /app/model_store
COPY ./config.properties /app/config.properties

CMD ["torchserve", "--start", "--model-store", "/app/model_store", "--models", "chattts=chattts.mar", "--ts-config", "/app/config.properties"]
106 changes: 106 additions & 0 deletions examples/torchserve/README.md
@@ -0,0 +1,106 @@
# Generating voice with ChatTTS via TorchServe for high-performance inference

## Why We Use TorchServe

TorchServe is designed to deliver high performance for serving PyTorch models, and it excels in the following key areas:

1. Batching Requests: TorchServe automatically batches incoming requests, processing multiple predictions in parallel. This reduces overhead, improves throughput, and ensures efficient use of resources, especially when dealing with large volumes of requests.

2. Horizontal Scaling: TorchServe allows for horizontal scaling, meaning it can easily scale across multiple machines or containers to handle increasing traffic. This ensures that the system remains responsive and can handle large volumes of inference requests without sacrificing performance.

## Navigate to the examples/torchserve directory

``` bash
cd examples/torchserve
```

## Download the model if needed

``` bash
huggingface-cli download 2Noise/ChatTTS
```
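The command prints the local snapshot path when it finishes. If you need to find that path again later, you can list the Hugging Face cache:

``` bash
# show cached repos and their local paths, including 2Noise/ChatTTS
huggingface-cli scan-cache
```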

## Store the Model

Replace `/path/to/your/model` with the actual path to your model files, for example: `/home/username/.cache/huggingface/hub/models--2Noise--ChatTTS/snapshots/1a3cxxxx`

``` bash

torch-model-archiver --model-name chattts \
--version 1.0 \
--serialized-file /path/to/your/model/asset/Decoder.pt \
--handler model_handler.py \
--extra-files "/path/to/your/model" \
--export-path model_store
```
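If the command succeeds, the resulting archive lands in the export path given above:

``` bash
# the archive that TorchServe will load
ls model_store/chattts.mar
```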

## Optional: TorchServe Model Configuration

TorchServe supports batch inference, which aggregates incoming inference requests and sends them through the ML/DL framework all at once. TorchServe was designed to natively support batching of incoming requests, which lets you use host resources optimally, because most ML/DL frameworks are optimized for batch requests.

Starting from TorchServe 0.4.1, there are two ways to configure the batching feature: through the model management API when registering a model, or through the TorchServe configuration file (`config.properties`), as this example does (a management API sketch follows the list below).

The configuration properties that we are interested in are the following:

1. `batch_size`: the maximum batch size that a model is expected to handle. In this example we set `batch_size` to `32`.

2. `max_batch_delay`: the maximum delay, in milliseconds, that TorchServe waits to receive `batch_size` requests. If TorchServe does not receive `batch_size` requests before this timer times out, it sends whatever requests were received to the model handler. In this example we set `max_batch_delay` to `1000`.
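These are the same values set in this example's `config.properties`. If you prefer to register the model dynamically instead, the management API accepts the batching settings as query parameters; a minimal sketch, assuming the default management port `8081` and that `chattts.mar` is already in the model store:

``` bash
# register chattts with batching enabled via the management API
curl -X POST "http://127.0.0.1:8081/models?url=chattts.mar&batch_size=32&max_batch_delay=1000&initial_workers=1"
```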

## Start TorchServe locally

``` bash
torchserve --start --model-store model_store --models chattts=chattts.mar
```
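Once the server is up, you can confirm it is healthy and that the model is registered using TorchServe's standard endpoints:

``` bash
# health check on the inference port
curl http://127.0.0.1:8080/ping
# list registered models on the management port
curl http://127.0.0.1:8081/models
```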

## Optional: Start TorchServe with Docker

### Prerequisites

* Docker - refer to the [official Docker installation guide](https://docs.docker.com/install/)

### 1. Build the Docker image

```bash
docker build -t torchserve-chattts:latest-gpu .
```

### 2. Start the container

```bash
docker run --name torchserve-chattts-gpu --gpus all -p 8080:8080 -p 8081:8081 torchserve-chattts:latest-gpu
```
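The container exposes the same ports as the local setup, so the model's status (workers, batch size) can be checked once it is running:

```bash
# describe the registered chattts model
curl http://127.0.0.1:8081/models/chattts
```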

## Inference with the RESTful API

Note that the `text` parameter takes a string instead of a list; TorchServe automatically batches multiple concurrent requests into one inference call.

``` bash
curl --location --request GET 'http://127.0.0.1:8080/predictions/chattts' \
--header 'Content-Type: application/json' \
--data '{
"text": "今天天气不错哦",
"stream": false,
"temperature": 0.3,
"lang": "zh",
"skip_refine_text": true,
"refine_text_only": false,
"use_decoder": true,
"do_text_normalization": true,
"do_homophone_replacement": false,
"params_infer_code": {
"prompt": "",
"top_P": 0.7,
"top_K": 20,
"temperature": 0.3,
"repetition_penalty": 1.05,
"max_new_token": 2048,
"min_new_token": 0,
"show_tqdm": true,
"ensure_non_empty": true,
"manual_seed": 888,
"stream_batch": 24,
"stream_speed": 12000,
"pass_first_n_batches": 2
}
}' --output test.wav
```
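When you are finished with a local (non-Docker) run, stop the server:

``` bash
torchserve --stop
```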
18 changes: 18 additions & 0 deletions examples/torchserve/config.properties
@@ -0,0 +1,18 @@
model_store=model-store
inference_address=http://0.0.0.0:8080
management_address=http://0.0.0.0:8081
metrics_address=http://0.0.0.0:8082
disable_token_authorization=true
models={\
"chattts": {\
"1.0": {\
"defaultVersion": true,\
"marName": "chattts.mar",\
"minWorkers": 1,\
"maxWorkers": 1,\
"batchSize": 32,\
"maxBatchDelay": 1000,\
"responseTimeout": 120\
}\
}\
}
128 changes: 128 additions & 0 deletions examples/torchserve/model_handler.py
@@ -0,0 +1,128 @@
import os
import json
import torch
from typing import Callable
from functools import partial
import io
from ts.torch_handler.base_handler import BaseHandler
import logging
import torchaudio

import ChatTTS

logger = logging.getLogger("TorchServeHandler")


def normalizer_zh_tn() -> Callable[[str], str]:
from tn.chinese.normalizer import Normalizer

return Normalizer(remove_interjections=False).normalize


def normalizer_en_nemo_text() -> Callable[[str], str]:
from nemo_text_processing.text_normalization.normalize import Normalizer

return partial(
Normalizer(input_case="cased", lang="en").normalize,
verbose=False,
punct_post_process=True,
)


class ChatTTSHandler(BaseHandler):
def __init__(self):
super(ChatTTSHandler, self).__init__()
self.chat = None
self.initialized = False

def initialize(self, ctx):
"""Load the model and initialize resources."""
logger.info("Initializing ChatTTS...")
self.chat = ChatTTS.Chat(logging.getLogger("ChatTTS"))
self.chat.normalizer.register("en", normalizer_en_nemo_text())
self.chat.normalizer.register("zh", normalizer_zh_tn())

model_dir = ctx.system_properties.get("model_dir")
os.chdir(model_dir)
if self.chat.load(source="custom", custom_path=model_dir):
logger.info("Models loaded successfully.")
else:
logger.error("Models load failed.")
raise RuntimeError("Failed to load models.")
self.initialized = True

def preprocess(self, data):
"""Preprocess incoming requests."""
if len(data) == 0:
raise ValueError("No data received for inference.")
logger.info(f"batch size: {len(data)}")
        return self._group_request_by_config(data)

    def _group_request_by_config(self, data):
batched_requests = {}
for req in data:
params = req.get("body")
text = params.pop("text")

key = json.dumps(params)

if key not in batched_requests:
params_refine_text = params.get("params_refine_text")
params_infer_code = params.get("params_infer_code")

if params_infer_code and params_infer_code.get("manual_seed") is not None:
torch.manual_seed(params_infer_code.get("manual_seed"))
params_infer_code["spk_emb"] = self.chat.sample_random_speaker()

batched_requests[key] = {
"text": [text],
"stream": params.get("stream", False),
"lang": params.get("lang"),
"skip_refine_text": params.get("skip_refine_text", False),
"use_decoder": params.get("use_decoder", True),
"do_text_normalization": params.get("do_text_normalization", True),
"do_homophone_replacement": params.get("do_homophone_replacement", False),
"params_refine_text": ChatTTS.Chat.InferCodeParams(
**params_refine_text
) if params_refine_text else None,
"params_infer_code": ChatTTS.Chat.InferCodeParams(
**params_infer_code
) if params_infer_code else None,
}
else:
batched_requests[key]["text"].append(text)

return batched_requests

def inference(self, data):
"""Run inference."""

for key, params in data.items():
logger.info(f"Request: {key}")
logger.info(f"Text input: {str(params['text'])}")

text = params["text"]
if params["params_refine_text"]:
text = self.chat.infer(text=text, refine_text_only=True)
logger.info(f"Refined text: {text}")

yield self.chat.infer(**params)

def postprocess(self, batched_results):
"""Post-process inference results into raw wav data."""
results = []
for wavs in batched_results:
for wav in wavs:
buf = io.BytesIO()
                try:
                    # mono waveform: add a channel dimension before saving
                    torchaudio.save(
                        buf, torch.from_numpy(wav).unsqueeze(0), 24000, format="wav"
                    )
                except Exception:
                    # waveform already has a channel dimension; save as-is
                    torchaudio.save(buf, torch.from_numpy(wav), 24000, format="wav")
buf.seek(0)
results.append(buf.getvalue())
return results
Empty file.
