Skip to content

Commit

Permalink
Merge branch 'feature/torchserve-example' of https://github.com/Jayso…
Browse files Browse the repository at this point in the history
…nAlbert/ChatTTS into feature/torchserve-example
  • Loading branch information
JaysonAlbert committed Jan 8, 2025
2 parents f333612 + 62c35c9 commit 6945654
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 19 deletions.
33 changes: 21 additions & 12 deletions examples/torchserve/model_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def initialize(self, ctx):
self.chat = ChatTTS.Chat(logging.getLogger("ChatTTS"))
self.chat.normalizer.register("en", normalizer_en_nemo_text())
self.chat.normalizer.register("zh", normalizer_zh_tn())

model_dir = ctx.system_properties.get("model_dir")
os.chdir(model_dir)
if self.chat.load(source="custom", custom_path=model_dir, compile=True):
Expand All @@ -68,29 +68,38 @@ def _group_reuqest_by_config(self, data):
text = params.pop("text")

key = json.dumps(params)

if key not in batched_requests:
params_refine_text = params.get("params_refine_text")
params_infer_code = params.get("params_infer_code")

if params_infer_code and params_infer_code.get("manual_seed") is not None:
if (
params_infer_code
and params_infer_code.get("manual_seed") is not None
):
torch.manual_seed(params_infer_code.get("manual_seed"))
params_infer_code["spk_emb"] = self.chat.sample_random_speaker()

batched_requests[key] = {
"text": [text],
"stream": params.get("stream", False),
"stream": params.get("stream", False),
"lang": params.get("lang"),
"skip_refine_text": params.get("skip_refine_text", False),
"use_decoder": params.get("use_decoder", True),
"use_decoder": params.get("use_decoder", True),
"do_text_normalization": params.get("do_text_normalization", True),
"do_homophone_replacement": params.get("do_homophone_replacement", False),
"params_refine_text": ChatTTS.Chat.InferCodeParams(
**params_refine_text
) if params_refine_text else None,
"params_infer_code": ChatTTS.Chat.InferCodeParams(
**params_infer_code
) if params_infer_code else None,
"do_homophone_replacement": params.get(
"do_homophone_replacement", False
),
"params_refine_text": (
ChatTTS.Chat.InferCodeParams(**params_refine_text)
if params_refine_text
else None
),
"params_infer_code": (
ChatTTS.Chat.InferCodeParams(**params_infer_code)
if params_infer_code
else None
),
}
else:
batched_requests[key]["text"].append(text)
Expand Down
14 changes: 7 additions & 7 deletions tools/audio/av.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,11 @@ def wav2(i: BytesIO, o: BufferedWriter, format: str):


def load_audio(
file: Union[str, BytesIO, Path],
sr: Optional[int] = None,
format: Optional[str] = None,
mono=True,
) -> Union[np.ndarray, Tuple[np.ndarray, int]]:
file: Union[str, BytesIO, Path],
sr: Optional[int] = None,
format: Optional[str] = None,
mono=True,
) -> Union[np.ndarray, Tuple[np.ndarray, int]]:
"""
https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI/blob/412a9950a1e371a018c381d1bfb8579c4b0de329/infer/lib/audio.py#L39
"""
Expand Down Expand Up @@ -113,7 +113,7 @@ def frame_iter(container):

np.copyto(decoded_audio[..., offset:end_index], frame_data)
offset += len(frame_data[0])

container.close()

# Truncate the array to the actual size
Expand All @@ -124,4 +124,4 @@ def frame_iter(container):

if sr is not None:
return decoded_audio
return decoded_audio, rate
return decoded_audio, rate

0 comments on commit 6945654

Please sign in to comment.