Skip to content

Commit c00ba37

Browse files
authored
feat(webui): support external spk_emb (#467)
1 parent cc58be2 commit c00ba37

File tree

2 files changed

+61
-30
lines changed

2 files changed

+61
-30
lines changed

examples/web/funcs.py

Lines changed: 38 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@
2121
custom_path: Optional[str] = None
2222

2323
has_interrupted = False
24+
is_in_generate = False
25+
26+
seed_min = 1
27+
seed_max = 4294967295
2428

2529
# 音色选项:用于预置合适的音色
2630
voices = {
@@ -38,13 +42,18 @@
3842

3943

4044
def generate_seed():
41-
return gr.update(value=random.randint(1, 100000000))
45+
return gr.update(value=random.randint(seed_min, seed_max))
4246

4347

4448
# 返回选择音色对应的seed
4549
def on_voice_change(vocie_selection):
4650
return voices.get(vocie_selection)["seed"]
4751

52+
def on_audio_seed_change(audio_seed_input):
53+
with TorchSeedContext(audio_seed_input):
54+
rand_spk = chat.sample_random_speaker()
55+
return rand_spk
56+
4857

4958
def load_chat(cust_path: Optional[str], coef: Optional[str]) -> bool:
5059
if cust_path == None:
@@ -79,6 +88,12 @@ def load_chat(cust_path: Optional[str], coef: Optional[str]) -> bool:
7988

8089

8190
def reload_chat(coef: Optional[str]) -> str:
91+
global is_in_generate
92+
93+
if is_in_generate:
94+
gr.Warning("Cannot reload when generating!")
95+
return coef
96+
8297
chat.unload()
8398
gr.Info("Model unloaded.")
8499
if len(coef) != 230:
@@ -119,37 +134,33 @@ def refine_text(
119134

120135
return text[0] if isinstance(text, list) else text
121136

122-
def generate_audio(text, temperature, top_P, top_K, audio_seed_input, stream):
137+
def generate_audio(text, temperature, top_P, top_K, spk_emb_text: str, stream):
123138
global chat, has_interrupted
124139

125-
if not text or has_interrupted:
140+
if not text or has_interrupted or not spk_emb_text.startswith("蘁淰"):
126141
return None
127142

128-
with TorchSeedContext(audio_seed_input):
129-
rand_spk = chat.sample_random_speaker()
130-
131143
params_infer_code = ChatTTS.Chat.InferCodeParams(
132-
spk_emb=rand_spk,
144+
spk_emb=spk_emb_text,
133145
temperature=temperature,
134146
top_P=top_P,
135147
top_K=top_K,
136148
)
137149

138-
with TorchSeedContext(audio_seed_input):
139-
wav = chat.infer(
140-
text,
141-
skip_refine_text=True,
142-
params_infer_code=params_infer_code,
143-
stream=stream,
144-
)
145-
if stream:
146-
for gen in wav:
147-
audio = gen[0]
148-
if audio is not None and len(audio) > 0:
149-
yield wav_arr_to_mp3_view(audio[0]).tobytes()
150-
del audio
151-
else:
152-
yield wav_arr_to_mp3_view(np.array(wav[0]).flatten()).tobytes()
150+
wav = chat.infer(
151+
text,
152+
skip_refine_text=True,
153+
params_infer_code=params_infer_code,
154+
stream=stream,
155+
)
156+
if stream:
157+
for gen in wav:
158+
audio = gen[0]
159+
if audio is not None and len(audio) > 0:
160+
yield wav_arr_to_mp3_view(audio[0]).tobytes()
161+
del audio
162+
else:
163+
yield wav_arr_to_mp3_view(np.array(wav[0]).flatten()).tobytes()
153164

154165

155166
def interrupt_generate():
@@ -159,17 +170,20 @@ def interrupt_generate():
159170
chat.interrupt()
160171

161172
def set_buttons_before_generate(generate_button, interrupt_button):
162-
global has_interrupted
173+
global has_interrupted, is_in_generate
163174

164175
has_interrupted = False
176+
is_in_generate = True
165177

166178
return _set_generate_buttons(
167179
generate_button,
168180
interrupt_button,
169181
)
170182

171183
def set_buttons_after_generate(generate_button, interrupt_button, audio_output):
172-
global has_interrupted
184+
global has_interrupted, is_in_generate
185+
186+
is_in_generate = False
173187

174188
return _set_generate_buttons(
175189
generate_button,

examples/web/webui.py

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -55,17 +55,31 @@ def main():
5555
voice_selection = gr.Dropdown(
5656
label="Timbre", choices=voices.keys(), value="Default"
5757
)
58-
audio_seed_input = gr.Number(value=2, label="Audio Seed", interactive=True)
58+
audio_seed_input = gr.Number(
59+
value=2, label="Audio Seed", interactive=True,
60+
minimum=seed_min, maximum=seed_max,
61+
)
5962
generate_audio_seed = gr.Button("\U0001F3B2")
60-
text_seed_input = gr.Number(value=42, label="Text Seed")
63+
text_seed_input = gr.Number(
64+
value=42, label="Text Seed", interactive=True,
65+
minimum=seed_min, maximum=seed_max,
66+
)
6167
generate_text_seed = gr.Button("\U0001F3B2")
6268

6369
with gr.Row():
70+
spk_emb_text = gr.Textbox(
71+
label="Speaker Embedding",
72+
max_lines=3,
73+
show_copy_button=True,
74+
interactive=True,
75+
scale=2,
76+
)
6477
dvae_coef_text = gr.Textbox(
6578
label="DVAE Coefficient",
6679
max_lines=3,
6780
show_copy_button=True,
68-
scale=4,
81+
interactive=True,
82+
scale=2,
6983
)
7084
reload_chat_button = gr.Button("Reload", scale=1)
7185

@@ -88,9 +102,11 @@ def main():
88102
fn=on_voice_change, inputs=voice_selection, outputs=audio_seed_input
89103
)
90104

91-
generate_audio_seed.click(generate_seed, inputs=[], outputs=audio_seed_input)
105+
generate_audio_seed.click(generate_seed, outputs=audio_seed_input)
106+
107+
generate_text_seed.click(generate_seed, outputs=text_seed_input)
92108

93-
generate_text_seed.click(generate_seed, inputs=[], outputs=text_seed_input)
109+
audio_seed_input.change(on_audio_seed_change, inputs=audio_seed_input, outputs=spk_emb_text)
94110

95111
reload_chat_button.click(
96112
reload_chat, inputs=dvae_coef_text, outputs=dvae_coef_text
@@ -123,7 +139,7 @@ def make_audio(autoplay, stream):
123139
temperature_slider,
124140
top_p_slider,
125141
top_k_slider,
126-
audio_seed_input,
142+
spk_emb_text,
127143
stream_mode_checkbox,
128144
],
129145
outputs=audio_output,
@@ -168,6 +184,7 @@ def make_audio(autoplay, stream):
168184
logger.error("Models load failed.")
169185
sys.exit(1)
170186

187+
spk_emb_text.value = on_audio_seed_change(audio_seed_input.value)
171188
dvae_coef_text.value = chat.coef
172189

173190
demo.launch(

0 commit comments

Comments
 (0)