From 62c35c9720ff88153b10511797d0513a38aec7fd Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Wed, 8 Jan 2025 03:10:45 +0000 Subject: [PATCH] chore(format): run black on dev --- examples/torchserve/model_handler.py | 33 ++++++++++++++++++---------- tools/audio/av.py | 14 ++++++------ 2 files changed, 28 insertions(+), 19 deletions(-) diff --git a/examples/torchserve/model_handler.py b/examples/torchserve/model_handler.py index ae1c73b8e..2d034bf0f 100644 --- a/examples/torchserve/model_handler.py +++ b/examples/torchserve/model_handler.py @@ -43,7 +43,7 @@ def initialize(self, ctx): self.chat = ChatTTS.Chat(logging.getLogger("ChatTTS")) self.chat.normalizer.register("en", normalizer_en_nemo_text()) self.chat.normalizer.register("zh", normalizer_zh_tn()) - + model_dir = ctx.system_properties.get("model_dir") os.chdir(model_dir) if self.chat.load(source="custom", custom_path=model_dir): @@ -68,29 +68,38 @@ def _group_reuqest_by_config(self, data): text = params.pop("text") key = json.dumps(params) - + if key not in batched_requests: params_refine_text = params.get("params_refine_text") params_infer_code = params.get("params_infer_code") - if params_infer_code and params_infer_code.get("manual_seed") is not None: + if ( + params_infer_code + and params_infer_code.get("manual_seed") is not None + ): torch.manual_seed(params_infer_code.get("manual_seed")) params_infer_code["spk_emb"] = self.chat.sample_random_speaker() batched_requests[key] = { "text": [text], - "stream": params.get("stream", False), + "stream": params.get("stream", False), "lang": params.get("lang"), "skip_refine_text": params.get("skip_refine_text", False), - "use_decoder": params.get("use_decoder", True), + "use_decoder": params.get("use_decoder", True), "do_text_normalization": params.get("do_text_normalization", True), - "do_homophone_replacement": params.get("do_homophone_replacement", False), - "params_refine_text": ChatTTS.Chat.InferCodeParams( - **params_refine_text - ) if params_refine_text else None, - "params_infer_code": ChatTTS.Chat.InferCodeParams( - **params_infer_code - ) if params_infer_code else None, + "do_homophone_replacement": params.get( + "do_homophone_replacement", False + ), + "params_refine_text": ( + ChatTTS.Chat.InferCodeParams(**params_refine_text) + if params_refine_text + else None + ), + "params_infer_code": ( + ChatTTS.Chat.InferCodeParams(**params_infer_code) + if params_infer_code + else None + ), } else: batched_requests[key]["text"].append(text) diff --git a/tools/audio/av.py b/tools/audio/av.py index 333b423d6..cd3a7d66a 100644 --- a/tools/audio/av.py +++ b/tools/audio/av.py @@ -41,11 +41,11 @@ def wav2(i: BytesIO, o: BufferedWriter, format: str): def load_audio( - file: Union[str, BytesIO, Path], - sr: Optional[int] = None, - format: Optional[str] = None, - mono=True, - ) -> Union[np.ndarray, Tuple[np.ndarray, int]]: + file: Union[str, BytesIO, Path], + sr: Optional[int] = None, + format: Optional[str] = None, + mono=True, +) -> Union[np.ndarray, Tuple[np.ndarray, int]]: """ https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI/blob/412a9950a1e371a018c381d1bfb8579c4b0de329/infer/lib/audio.py#L39 """ @@ -113,7 +113,7 @@ def frame_iter(container): np.copyto(decoded_audio[..., offset:end_index], frame_data) offset += len(frame_data[0]) - + container.close() # Truncate the array to the actual size @@ -124,4 +124,4 @@ def frame_iter(container): if sr is not None: return decoded_audio - return decoded_audio, rate \ No newline at end of file + return decoded_audio, rate