-# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
+#!/usr/bin/env python3
+# coding: utf-8
+"""
+CosyVoice gRPC back-end – updated to mirror the FastAPI logic
+* loads CosyVoice2 with TRT / FP16 (raises if the model cannot be loaded)
+* inference_zero_shot ➜ adds stream=False + speed
+* inference_instruct2 ➜ prompt-audio + speed, replacing the original
+  speaker-ID instruct path
+"""
+
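+# NOTE: 'prompt_audio' in zero-shot / cross-lingual / instruct-2 requests may
+# hold either raw 16 kHz int16 PCM bytes or a UTF-8 encoded http(s) URL; the
+# handlers below sniff the payload and decode accordingly.
+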
+import tempfile
+from urllib.parse import urlparse
+
+import requests
+import torchaudio
 import os
 import sys
 from concurrent import futures
 import argparse
-import cosyvoice_pb2
-import cosyvoice_pb2_grpc
 import logging
-logging.getLogger('matplotlib').setLevel(logging.WARNING)
 import grpc
-import torch
 import numpy as np
+import torch
+
+import cosyvoice_pb2
+import cosyvoice_pb2_grpc
+
+# ────────────────────────────────────────────────────────────────────────────
+# set-up
+# ────────────────────────────────────────────────────────────────────────────
+logging.getLogger("matplotlib").setLevel(logging.WARNING)
+logging.basicConfig(level=logging.INFO,
+                    format="%(asctime)s %(levelname)s %(message)s")
+
 ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
-sys.path.append('{}/../../..'.format(ROOT_DIR))
-sys.path.append('{}/../../../third_party/Matcha-TTS'.format(ROOT_DIR))
-from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
+sys.path.extend([
+    f"{ROOT_DIR}/../../..",
+    f"{ROOT_DIR}/../../../third_party/Matcha-TTS",
+])
+
+from cosyvoice.cli.cosyvoice import CosyVoice2  # noqa: E402
+
+
+# ────────────────────────────────────────────────────────────────────────────
+# helpers
+# ────────────────────────────────────────────────────────────────────────────
+def _bytes_to_tensor(wav_bytes: bytes) -> torch.Tensor:
+    """
+    Convert int16 little-endian PCM bytes -> torch.FloatTensor in [-1, 1].
+    """
+    speech = torch.from_numpy(
+        np.frombuffer(wav_bytes, dtype=np.int16).copy()  # copy: frombuffer is read-only
+    ).unsqueeze(0).float() / (2 ** 15)
+    return speech  # shape [1, T]
+
+
+def _yield_audio(model_output):
+    """
+    Generator that converts CosyVoice output -> protobuf Response messages.
+    """
+    for seg in model_output:
+        pcm16 = (seg["tts_speech"].numpy() * (2 ** 15)).astype(np.int16)
+        resp = cosyvoice_pb2.Response(tts_audio=pcm16.tobytes())
+        yield resp
+
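+# (The handlers only populate Response.tts_audio with little-endian int16 PCM;
+#  the sample rate is the model's native output rate and is not sent.)
+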
+
+def _load_prompt_from_url(url: str, target_sr: int = 16_000) -> torch.Tensor:
+    """Download an audio file from ``url`` (wav / mp3 / flac / ogg …),
+    convert it to mono, resample to ``target_sr`` if necessary,
+    and return a 1 x T float tensor in the range -1…1."""
+
+    # ─── 1. Download ─────────────────────────────────────────────────────────
+    resp = requests.get(url, timeout=10)
+    if resp.status_code != 200:
+        raise RuntimeError(f"Failed to download audio from URL: {url} "
+                           f"(HTTP {resp.status_code})")
 
-logging.basicConfig(level=logging.DEBUG,
-                    format='%(asctime)s %(levelname)s %(message)s')
+    # Infer extension from URL *or* Content-Type header
+    ext = os.path.splitext(urlparse(url).path)[1].lower()
+    if not ext and 'content-type' in resp.headers:
+        mime = resp.headers['content-type'].split(';')[0].strip()
+        ext = {
+            'audio/mpeg': '.mp3',
+            'audio/wav': '.wav',
+            'audio/x-wav': '.wav',
+            'audio/flac': '.flac',
+            'audio/ogg': '.ogg',
+            'audio/x-m4a': '.m4a',
+        }.get(mime, '.audio')  # generic fallback
 
+    with tempfile.NamedTemporaryFile(suffix=ext or '.audio', delete=False) as f:
+        f.write(resp.content)
+        temp_path = f.name
 
+    # ─── 2. Decode (torchaudio first, pydub fallback) ────────────────────────
+    try:
+        # Let torchaudio pick the right backend automatically
+        speech, sample_rate = torchaudio.load(temp_path)
+    except Exception:
+        # Fallback that works as long as ffmpeg is present
+        from pydub import AudioSegment
+
+        seg = AudioSegment.from_file(temp_path)  # any ffmpeg-supported format
+        seg = seg.set_channels(1)                # force mono
+        sample_rate = seg.frame_rate
+        np_audio = np.array(seg.get_array_of_samples()).astype(np.float32)
+        # normalise to -1…1 based on sample width
+        np_audio /= float(1 << (8 * seg.sample_width - 1))
+        speech = torch.from_numpy(np_audio).unsqueeze(0)
+    finally:
+        os.unlink(temp_path)
+
+    # ─── 3. Ensure mono + correct sample rate ────────────────────────────────
+    if speech.dim() > 1 and speech.size(0) > 1:
+        speech = speech.mean(dim=0, keepdim=True)  # average to mono
+
+    if sample_rate != target_sr:
+        speech = torchaudio.transforms.Resample(orig_freq=sample_rate,
+                                                new_freq=target_sr)(speech)
+    return speech
+
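+# Example with a hypothetical URL:
+#   prompt = _load_prompt_from_url("https://example.com/prompt.wav")
+#   -> FloatTensor of shape [1, T], mono, 16 kHz, values in [-1, 1]
+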
+# ────────────────────────────────────────────────────────────────────────────
+# gRPC service
+# ────────────────────────────────────────────────────────────────────────────
 class CosyVoiceServiceImpl(cosyvoice_pb2_grpc.CosyVoiceServicer):
     def __init__(self, args):
+        # load CosyVoice2 (TRT / FP16); there is no CosyVoice fallback
         try:
-            self.cosyvoice = CosyVoice(args.model_dir)
+            self.cosyvoice = CosyVoice2(args.model_dir,
+                                        load_jit=False,
+                                        load_trt=True,
+                                        fp16=True)
+            logging.info("Loaded CosyVoice2 (TRT / FP16).")
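+            # (load_trt=True / fp16=True assume a CUDA-capable GPU)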
         except Exception:
-            try:
-                self.cosyvoice = CosyVoice2(args.model_dir)
-            except Exception:
-                raise TypeError('no valid model_type!')
-        logging.info('grpc service initialized')
+            raise TypeError("No valid CosyVoice model found!")
 
+    # ---------------------------------------------------------------------
+    # single server-streaming RPC: one Request in, a stream of Responses out
+    # ---------------------------------------------------------------------
     def Inference(self, request, context):
|
46 |
| - if request.HasField('sft_request'): |
47 |
| - logging.info('get sft inference request') |
48 |
| - model_output = self.cosyvoice.inference_sft(request.sft_request.tts_text, request.sft_request.spk_id) |
49 |
| - elif request.HasField('zero_shot_request'): |
50 |
| - logging.info('get zero_shot inference request') |
51 |
| - prompt_speech_16k = torch.from_numpy(np.array(np.frombuffer(request.zero_shot_request.prompt_audio, dtype=np.int16))).unsqueeze(dim=0) |
52 |
| - prompt_speech_16k = prompt_speech_16k.float() / (2**15) |
53 |
| - model_output = self.cosyvoice.inference_zero_shot(request.zero_shot_request.tts_text, |
54 |
| - request.zero_shot_request.prompt_text, |
55 |
| - prompt_speech_16k) |
56 |
| - elif request.HasField('cross_lingual_request'): |
57 |
| - logging.info('get cross_lingual inference request') |
58 |
| - prompt_speech_16k = torch.from_numpy(np.array(np.frombuffer(request.cross_lingual_request.prompt_audio, dtype=np.int16))).unsqueeze(dim=0) |
59 |
| - prompt_speech_16k = prompt_speech_16k.float() / (2**15) |
60 |
| - model_output = self.cosyvoice.inference_cross_lingual(request.cross_lingual_request.tts_text, prompt_speech_16k) |
61 |
| - else: |
62 |
| - logging.info('get instruct inference request') |
63 |
| - model_output = self.cosyvoice.inference_instruct(request.instruct_request.tts_text, |
64 |
| - request.instruct_request.spk_id, |
65 |
| - request.instruct_request.instruct_text) |
66 |
| - |
67 |
| - logging.info('send inference response') |
68 |
| - for i in model_output: |
69 |
| - response = cosyvoice_pb2.Response() |
70 |
| - response.tts_audio = (i['tts_speech'].numpy() * (2 ** 15)).astype(np.int16).tobytes() |
71 |
| - yield response |
72 |
| - |
73 |
| - |
74 |
| -def main(): |
75 |
| - grpcServer = grpc.server(futures.ThreadPoolExecutor(max_workers=args.max_conc), maximum_concurrent_rpcs=args.max_conc) |
76 |
| - cosyvoice_pb2_grpc.add_CosyVoiceServicer_to_server(CosyVoiceServiceImpl(args), grpcServer) |
77 |
| - grpcServer.add_insecure_port('0.0.0.0:{}'.format(args.port)) |
78 |
| - grpcServer.start() |
79 |
| - logging.info("server listening on 0.0.0.0:{}".format(args.port)) |
80 |
| - grpcServer.wait_for_termination() |
81 |
| - |
82 |
| - |
83 |
| -if __name__ == '__main__': |
+        """Route to the correct model call based on the oneof field present."""
+        # 1. Supervised fine-tuning (predefined speaker ID)
+        if request.HasField("sft_request"):
+            logging.info("Received SFT inference request")
+            mo = self.cosyvoice.inference_sft(
+                request.sft_request.tts_text,
+                request.sft_request.spk_id
+            )
+            yield from _yield_audio(mo)
+            return
+
+        # 2. Zero-shot speaker cloning (prompt_audio = raw bytes OR a URL)
+        if request.HasField("zero_shot_request"):
+            logging.info("Received zero-shot inference request")
+            zr = request.zero_shot_request
+
+            # ───── determine payload type ────────────────────────────────────
+            if zr.prompt_audio.startswith(b'http'):
+                prompt = _load_prompt_from_url(zr.prompt_audio.decode('utf-8'))
+            else:
+                # legacy raw PCM bytes
+                prompt = _bytes_to_tensor(zr.prompt_audio)
+
+            # ───── call the model ────────────────────────────────────────────
+            speed = getattr(zr, "speed", 1.0)
+            mo = self.cosyvoice.inference_zero_shot(
+                zr.tts_text,
+                zr.prompt_text,
+                prompt,
+                stream=False,
+                speed=speed,
+            )
+            yield from _yield_audio(mo)
+            return
+
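+        # (stream=False asks the model for complete audio segments rather than
+        #  incremental chunks; the gRPC response is still a message stream.)
+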
+        # 3. Cross-lingual cloning (prompt_audio = raw bytes OR a URL)
+        if request.HasField("cross_lingual_request"):
+            logging.info("Received cross-lingual inference request")
+            cr = request.cross_lingual_request
+
+            if cr.prompt_audio.startswith(b'http'):    # URL case
+                prompt = _load_prompt_from_url(cr.prompt_audio.decode('utf-8'))
+            else:                                      # legacy raw bytes
+                prompt = _bytes_to_tensor(cr.prompt_audio)
+
+            mo = self.cosyvoice.inference_cross_lingual(
+                cr.tts_text,
+                prompt
+            )
+            yield from _yield_audio(mo)
+            return
+
+        # 4. Instruct-2 (CosyVoice2: instruct_text + prompt audio, no speaker ID)
+        if request.HasField("instruct_request"):
+            ir = request.instruct_request
+
+            # require that the proto descriptor actually has the field
+            if 'prompt_audio' not in ir.DESCRIPTOR.fields_by_name:
+                context.abort(
+                    grpc.StatusCode.INVALID_ARGUMENT,
+                    "Server expects instruct-2 proto with a 'prompt_audio' field."
+                )
+
+            # make sure it is non-empty (no HasField for proto3 scalars)
+            if len(ir.prompt_audio) == 0:
+                context.abort(
+                    grpc.StatusCode.INVALID_ARGUMENT,
+                    "'prompt_audio' must not be empty for instruct-2 requests."
+                )
+
+            logging.info("Received instruct-2 inference request")
+
+            # normalise to bytes regardless of the proto scalar type
+            pa_bytes = (ir.prompt_audio.encode('utf-8')
+                        if isinstance(ir.prompt_audio, str)
+                        else ir.prompt_audio)
+
+            # URL vs raw bytes
+            if pa_bytes.startswith(b"http"):
+                prompt = _load_prompt_from_url(pa_bytes.decode('utf-8'))
+            else:
+                prompt = _bytes_to_tensor(pa_bytes)
+
+            speed = getattr(ir, "speed", 1.0)
+            mo = self.cosyvoice.inference_instruct2(
+                ir.tts_text,
+                ir.instruct_text,
+                prompt,
+                stream=False,
+                speed=speed,
+            )
+            yield from _yield_audio(mo)
+            return
+
+        # unknown request type
+        context.abort(grpc.StatusCode.INVALID_ARGUMENT,
+                      "Unsupported request type in oneof field.")
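+        # (context.abort raises, so execution never falls through this point)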
+
+
+# ────────────────────────────────────────────────────────────────────────────
+# entry-point
+# ────────────────────────────────────────────────────────────────────────────
+def serve(args):
+    server = grpc.server(
+        futures.ThreadPoolExecutor(max_workers=args.max_conc),
+        maximum_concurrent_rpcs=args.max_conc
+    )
+    cosyvoice_pb2_grpc.add_CosyVoiceServicer_to_server(
+        CosyVoiceServiceImpl(args), server
+    )
+    server.add_insecure_port(f"0.0.0.0:{args.port}")
+    server.start()
+    logging.info("CosyVoice gRPC server listening on 0.0.0.0:%d", args.port)
+    server.wait_for_termination()
+
+
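+# Minimal client sketch (assumes stubs generated from cosyvoice.proto; field
+# names follow the handlers above – adjust if your proto differs):
+#   with grpc.insecure_channel("localhost:8000") as ch:
+#       stub = cosyvoice_pb2_grpc.CosyVoiceStub(ch)
+#       req = cosyvoice_pb2.Request()
+#       req.zero_shot_request.tts_text = "Hello there!"
+#       req.zero_shot_request.prompt_text = "transcript of the prompt audio"
+#       req.zero_shot_request.prompt_audio = b"https://example.com/prompt.wav"
+#       pcm = b"".join(r.tts_audio for r in stub.Inference(req))
+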
+if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument('--port',
-                        type=int,
-                        default=50000)
-    parser.add_argument('--max_conc',
-                        type=int,
-                        default=4)
-    parser.add_argument('--model_dir',
-                        type=str,
-                        default='iic/CosyVoice-300M',
-                        help='local path or modelscope repo id')
-    args = parser.parse_args()
-    main()
+    parser.add_argument("--port", type=int, default=8000)
+    parser.add_argument("--max_conc", type=int, default=4,
+                        help="maximum concurrent requests / threads")
+    parser.add_argument("--model_dir", type=str,
+                        default="pretrained_models/CosyVoice2-0.5B",
+                        help="local path or ModelScope repo id")
+    serve(parser.parse_args())
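+# Launch example (assuming this file is saved as server.py):
+#   python server.py --port 8000 --model_dir pretrained_models/CosyVoice2-0.5B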