Skip to content

Commit 844db2b

Browse files
authored
Update server.py
1 parent 614c308 commit 844db2b

File tree

1 file changed

+272
-78
lines changed

1 file changed

+272
-78
lines changed

runtime/python/grpc/server.py

+272-78
Original file line numberDiff line numberDiff line change
@@ -1,96 +1,290 @@
1-
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
2-
#
3-
# Licensed under the Apache License, Version 2.0 (the "License");
4-
# you may not use this file except in compliance with the License.
5-
# You may obtain a copy of the License at
6-
#
7-
# http://www.apache.org/licenses/LICENSE-2.0
8-
#
9-
# Unless required by applicable law or agreed to in writing, software
10-
# distributed under the License is distributed on an "AS IS" BASIS,
11-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12-
# See the License for the specific language governing permissions and
13-
# limitations under the License.
1+
#!/usr/bin/env python3
2+
# coding: utf-8
3+
"""
4+
CosyVoice gRPC back‑end – updated to mirror the FastAPI logic
5+
* loads CosyVoice2 with TRT / FP16 (no fallback: start-up fails if it cannot be loaded)
6+
* inference_zero_shot ➜ adds stream=False + speed
7+
* instruct_request ➜ always served by inference_instruct2 (the speaker‑ID inference_instruct path is not called)
8+
* inference_instruct2 ➜ new: prompt‑audio + speed (no speaker‑ID)
9+
"""
10+
11+
import io, tempfile, requests, soundfile as sf, torchaudio
1412
import os
1513
import sys
1614
from concurrent import futures
1715
import argparse
18-
import cosyvoice_pb2
19-
import cosyvoice_pb2_grpc
2016
import logging
21-
logging.getLogger('matplotlib').setLevel(logging.WARNING)
2217
import grpc
23-
import torch
2418
import numpy as np
19+
import torch
20+
21+
import cosyvoice_pb2
22+
import cosyvoice_pb2_grpc
23+
24+
# ────────────────────────────────────────────────────────────────────────────────
25+
# set‑up
26+
# ────────────────────────────────────────────────────────────────────────────────
27+
# Logging: keep matplotlib at WARNING, then configure the root logger.
logging.getLogger("matplotlib").setLevel(logging.WARNING)
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(message)s",
    level=logging.INFO,
)

# Directory containing this file; the two entries appended to sys.path below
# are resolved relative to it so that the `cosyvoice` package import works.
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path += [
    f"{ROOT_DIR}/../../..",
    f"{ROOT_DIR}/../../../third_party/Matcha-TTS",
]
36+
37+
from cosyvoice.cli.cosyvoice import CosyVoice2 # noqa: E402
38+
39+
40+
# ────────────────────────────────────────────────────────────────────────────────
41+
# helpers
42+
# ────────────────────────────────────────────────────────────────────────────────
43+
def _bytes_to_tensor(wav_bytes: bytes) -> torch.Tensor:
    """Convert raw int16 little-endian PCM bytes to a float tensor in [-1, 1).

    Args:
        wav_bytes: headerless PCM payload, two bytes per sample.

    Returns:
        A ``torch.float32`` tensor of shape ``[1, T]`` where ``T`` is the
        number of samples (``len(wav_bytes) // 2``).

    Raises:
        ValueError: if the payload length is not a multiple of two bytes.
    """
    if len(wav_bytes) % 2:
        raise ValueError(
            "PCM payload must contain whole int16 samples, "
            f"got {len(wav_bytes)} bytes"
        )

    # "<i2" pins little-endian decoding regardless of the host byte order.
    # np.frombuffer returns a read-only view of the bytes object; astype()
    # produces a fresh writable array, so torch.from_numpy does not emit its
    # "array is not writable" warning.
    samples = np.frombuffer(wav_bytes, dtype="<i2").astype(np.float32)
    return torch.from_numpy(samples).unsqueeze(0) / (2 ** 15)  # [1, T]
51+
52+
53+
def _yield_audio(model_output):
    """Convert CosyVoice output segments into protobuf ``Response`` messages.

    Args:
        model_output: iterable of mappings, each holding the synthesised
            waveform under the ``"tts_speech"`` key.
            NOTE(review): assumed to be a float torch tensor scaled to
            [-1, 1] -- confirm against the model API.

    Yields:
        One ``cosyvoice_pb2.Response`` per segment whose ``tts_audio`` field
        carries the samples as int16 PCM bytes.
    """
    for seg in model_output:
        # detach()/cpu() make the conversion work even if the tensor tracks
        # gradients or lives on an accelerator; both are no-ops otherwise.
        samples = seg["tts_speech"].detach().cpu().numpy()
        # Clip before the cast: a sample at +1.0 (or any overshoot) scales
        # past the int16 maximum and would not saturate on conversion.
        pcm16 = np.clip(samples * (2 ** 15), -32768, 32767).astype(np.int16)
        yield cosyvoice_pb2.Response(tts_audio=pcm16.tobytes())
61+
62+
import os, io, tempfile, requests, torch, torchaudio
63+
from urllib.parse import urlparse
64+
65+
def _load_prompt_from_url(url: str, target_sr: int = 16_000) -> torch.Tensor:
    """Download an audio file and return it as a mono float tensor.

    The file is fetched with ``requests``, written to a temporary file,
    decoded with ``torchaudio`` (falling back to ``pydub`` if that fails),
    averaged down to one channel and resampled to ``target_sr`` if needed.

    Args:
        url: location of the audio file (wav / mp3 / flac / ogg ...).
        target_sr: sample rate of the returned waveform, in Hz.

    Returns:
        The decoded waveform as a float tensor, expected to be ``[1, T]``.

    Raises:
        ValueError: if the server does not answer with HTTP 200.
        requests.RequestException: on connection errors or time-outs.
    """
    # NOTE(review): `url` originates from the client request, so this is an
    # SSRF surface -- consider restricting schemes / hosts before fetching.

    # --- 1. Download ---------------------------------------------------------
    resp = requests.get(url, timeout=10)
    if resp.status_code != 200:
        # The previous code raised FastAPI's HTTPException, which is never
        # imported in this module and therefore produced a NameError.
        raise ValueError(
            f"Failed to download audio from URL: {url} "
            f"(HTTP {resp.status_code})"
        )

    # Infer the file extension from the URL path, else from Content-Type.
    ext = os.path.splitext(urlparse(url).path)[1].lower()
    if not ext and 'content-type' in resp.headers:
        mime = resp.headers['content-type'].split(';')[0].strip()
        ext = {
            'audio/mpeg': '.mp3',
            'audio/wav': '.wav',
            'audio/x-wav': '.wav',
            'audio/flac': '.flac',
            'audio/ogg': '.ogg',
            'audio/x-m4a': '.m4a',
        }.get(mime, '.audio')  # generic fallback

    # --- 2. Decode (torchaudio first, pydub fallback) ------------------------
    fd, temp_path = tempfile.mkstemp(suffix=ext or '.audio')
    try:
        # Writing inside the try block guarantees the temp file is removed
        # even when the write itself fails.
        with os.fdopen(fd, 'wb') as f:
            f.write(resp.content)

        try:
            # Let torchaudio pick the right backend automatically.
            speech, sample_rate = torchaudio.load(temp_path)
        except Exception as err:
            logging.warning(
                "torchaudio could not decode %s (%s); falling back to pydub",
                url, err,
            )
            # Optional dependency, imported lazily; needs ffmpeg at runtime.
            from pydub import AudioSegment

            seg = AudioSegment.from_file(temp_path).set_channels(1)  # mono
            sample_rate = seg.frame_rate
            np_audio = np.array(seg.get_array_of_samples()).astype(np.float32)
            # Normalise to -1...1 based on the sample width in bytes.
            np_audio /= float(1 << (8 * seg.sample_width - 1))
            speech = torch.from_numpy(np_audio).unsqueeze(0)
    finally:
        os.unlink(temp_path)

    # --- 3. Ensure mono + correct sample rate --------------------------------
    if speech.dim() > 1 and speech.size(0) > 1:
        speech = speech.mean(dim=0, keepdim=True)  # average channels to mono

    if sample_rate != target_sr:
        speech = torchaudio.transforms.Resample(orig_freq=sample_rate,
                                                new_freq=target_sr)(speech)
    return speech
121+
122+
# ────────────────────────────────────────────────────────────────────────────────
123+
# gRPC service
124+
# ────────────────────────────────────────────────────────────────────────────────
34125
class CosyVoiceServiceImpl(cosyvoice_pb2_grpc.CosyVoiceServicer):
    """gRPC servicer that routes each ``Inference`` request to CosyVoice2."""

    def __init__(self, args):
        """Load CosyVoice2 from ``args.model_dir`` with TensorRT and FP16 on.

        Raises:
            TypeError: if the model cannot be loaded. The underlying error is
                logged and chained as ``__cause__`` instead of being dropped.
        """
        try:
            self.cosyvoice = CosyVoice2(args.model_dir,
                                        load_jit=False,
                                        load_trt=True,
                                        fp16=True)
        except Exception as err:
            logging.exception("Could not load CosyVoice2 from %s",
                              args.model_dir)
            raise TypeError("No valid CosyVoice model found!") from err
        logging.info("Loaded CosyVoice2 (TRT / FP16).")

    # ---------------------------------------------------------------------
    # helpers
    # ---------------------------------------------------------------------
    @staticmethod
    def _effective_speed(sub_request):
        """Return the requested speed, or 1.0 if absent, unset or invalid.

        A proto3 float that exists in the schema but was not set by the
        client reads as 0.0, so a plain ``getattr(..., 1.0)`` default is not
        enough; any non-positive (or NaN) value is replaced by 1.0.
        """
        speed = getattr(sub_request, "speed", 0.0)
        return speed if speed > 0 else 1.0

    @staticmethod
    def _load_prompt(prompt_audio, context, empty_message):
        """Turn a ``prompt_audio`` payload into a waveform tensor.

        The payload is either an ``http(s)://`` URL (downloaded and decoded)
        or raw int16 PCM bytes. ``context.abort`` raises, so this method
        either returns a tensor or terminates the RPC with INVALID_ARGUMENT.
        """
        # Accept both `bytes` and `string` proto scalars.
        payload = (prompt_audio.encode("utf-8")
                   if isinstance(prompt_audio, str) else prompt_audio)
        if len(payload) == 0:
            context.abort(grpc.StatusCode.INVALID_ARGUMENT, empty_message)

        try:
            # Require a full scheme prefix so PCM data that merely happens to
            # begin with the bytes "http" is not mistaken for a URL.
            if payload.startswith((b"http://", b"https://")):
                return _load_prompt_from_url(payload.decode("utf-8"))
            return _bytes_to_tensor(payload)
        except Exception as err:
            logging.exception("Could not load prompt audio")
            context.abort(grpc.StatusCode.INVALID_ARGUMENT,
                          f"Could not load 'prompt_audio': {err}")

    # ---------------------------------------------------------------------
    # single RPC: one request message in, a stream of Response messages out
    # ---------------------------------------------------------------------
    def Inference(self, request, context):
        """Route to the correct model call based on the oneof field present.

        Yields ``cosyvoice_pb2.Response`` messages carrying int16 PCM audio.
        Requests that match no known field, or that carry an unusable
        ``prompt_audio``, are aborted with ``INVALID_ARGUMENT``.
        """
        # 1. Supervised fine-tuning
        if request.HasField("sft_request"):
            logging.info("Received SFT inference request")
            model_output = self.cosyvoice.inference_sft(
                request.sft_request.tts_text,
                request.sft_request.spk_id,
            )
            yield from _yield_audio(model_output)
            return

        # 2. Zero-shot speaker cloning (raw PCM bytes OR URL)
        if request.HasField("zero_shot_request"):
            logging.info("Received zero‑shot inference request")
            zr = request.zero_shot_request
            prompt = self._load_prompt(
                zr.prompt_audio, context,
                "'prompt_audio' must not be empty for zero‑shot requests.",
            )
            model_output = self.cosyvoice.inference_zero_shot(
                zr.tts_text,
                zr.prompt_text,
                prompt,
                stream=False,
                speed=self._effective_speed(zr),
            )
            yield from _yield_audio(model_output)
            return

        # 3. Cross-lingual (raw PCM bytes OR URL)
        if request.HasField("cross_lingual_request"):
            logging.info("Received cross‑lingual inference request")
            cr = request.cross_lingual_request
            prompt = self._load_prompt(
                cr.prompt_audio, context,
                "'prompt_audio' must not be empty for cross‑lingual requests.",
            )
            model_output = self.cosyvoice.inference_cross_lingual(
                cr.tts_text,
                prompt,
            )
            yield from _yield_audio(model_output)
            return

        # 4. Instruct-2 (prompt audio instead of a speaker id)
        if request.HasField("instruct_request"):
            ir = request.instruct_request

            # The generated proto must already contain the new field;
            # otherwise the client and server schemas are out of sync.
            if 'prompt_audio' not in ir.DESCRIPTOR.fields_by_name:
                context.abort(
                    grpc.StatusCode.INVALID_ARGUMENT,
                    "Server expects instruct‑2 proto with a 'prompt_audio' field."
                )

            logging.info("Received instruct‑2 inference request")
            prompt = self._load_prompt(
                ir.prompt_audio, context,
                "'prompt_audio' must not be empty for instruct‑2 requests.",
            )
            model_output = self.cosyvoice.inference_instruct2(
                ir.tts_text,
                ir.instruct_text,
                prompt,
                stream=False,
                speed=self._effective_speed(ir),
            )
            yield from _yield_audio(model_output)
            return

        # Unknown request type (abort raises, ending the generator).
        context.abort(grpc.StatusCode.INVALID_ARGUMENT,
                      "Unsupported request type in oneof field.")
263+
264+
265+
# ────────────────────────────────────────────────────────────────────────────────
266+
# entry‑point
267+
# ────────────────────────────────────────────────────────────────────────────────
268+
def serve(args):
    """Start the CosyVoice gRPC server and block until it terminates.

    Args:
        args: parsed command-line namespace providing ``port``, ``max_conc``
            and (for the servicer) ``model_dir``.

    Raises:
        RuntimeError: if the listening port cannot be bound.
    """
    server = grpc.server(
        futures.ThreadPoolExecutor(max_workers=args.max_conc),
        maximum_concurrent_rpcs=args.max_conc,
    )
    cosyvoice_pb2_grpc.add_CosyVoiceServicer_to_server(
        CosyVoiceServiceImpl(args), server
    )

    address = f"0.0.0.0:{args.port}"
    # NOTE(review): some grpcio releases signal a failed bind by returning 0
    # instead of raising -- check explicitly so the server never "starts"
    # without actually listening.
    if server.add_insecure_port(address) == 0:
        raise RuntimeError(f"Failed to bind gRPC server to {address}")

    server.start()
    logging.info("CosyVoice gRPC server listening on 0.0.0.0:%d", args.port)
    server.wait_for_termination()
280+
281+
282+
if __name__ == "__main__":
    # Command-line entry point: collect the listening port, the concurrency
    # limit and the model location, then hand the namespace to serve().
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--port",
        type=int,
        default=8000,
    )
    cli.add_argument(
        "--max_conc",
        type=int,
        default=4,
        help="maximum concurrent requests / threads",
    )
    cli.add_argument(
        "--model_dir",
        type=str,
        default="pretrained_models/CosyVoice2-0.5B",
        help="local path or ModelScope repo id",
    )
    cli_args = cli.parse_args()
    serve(cli_args)

0 commit comments

Comments
 (0)