
Commit e699e3e

Added DNSMOS in evaluation, updated Chinese README

1 parent 7994a0f commit e699e3e

File tree: 7 files changed, +225 −23 lines changed

README-CN.md

Lines changed: 47 additions & 11 deletions
````diff
@@ -9,6 +9,41 @@
 
 We will keep improving the model quality and adding more features.
 
+## Evaluation📊
+
+We performed a series of objective evaluations on Seed-VC's voice conversion capabilities.
+For ease of reproduction, source audios are 100 random utterances from LibriTTS-test-clean, and reference audios are 12 randomly picked in-the-wild voices with unique characteristics. <br>
+
+Source audios can be found under `./examples/libritts-test-clean` <br>
+Reference audios can be found under `./examples/reference` <br>
+
+We evaluated the conversion results in terms of speaker embedding cosine similarity (SECS), word error rate (WER) and character error rate (CER), and compared our results with two strong open-sourced baselines, namely [OpenVoice](https://github.com/myshell-ai/OpenVoice) and [CosyVoice](https://github.com/FunAudioLLM/CosyVoice).
+Results in the table below show that our Seed-VC model significantly outperforms the baseline models in both intelligibility and speaker similarity.<br>
+
+| Models\Metrics | SECS↑      | WER↓       | CER↓       | SIG↑     | BAK↑     | OVRL↑    |
+|----------------|------------|------------|------------|----------|----------|----------|
+| Ground Truth   | 1.0000     | 0.0802     | 0.0157     | ~        | ~        | ~        |
+| OpenVoice      | 0.7547     | 0.1546     | 0.0473     | **3.56** | **4.02** | **3.27** |
+| CosyVoice      | 0.8440     | 0.1898     | 0.0729     | 3.51     | **4.02** | 3.21     |
+| Seed-VC(Ours)  | **0.8676** | **0.1199** | **0.0292** | 3.42     | 3.97     | 3.11     |
+
+*ASR results computed by the [facebook/hubert-large-ls960-ft](https://huggingface.co/facebook/hubert-large-ls960-ft) model*
+*Speaker embeddings computed by the [resemblyzer](https://github.com/resemble-ai/Resemblyzer) model* <br>
+
+You can reproduce the evaluation by running the `eval.py` script.
+```bash
+python eval.py
+--source ./examples/libritts-test-clean
+--target ./examples/reference
+--output ./examples/eval/converted
+--diffusion-steps 25
+--length-adjust 1.0
+--inference-cfg-rate 0.7
+--xvector-extractor "resemblyzer"
+--baseline ""  # fill in openvoice or cosyvoice to compute baseline results
+--max-samples 100  # maximum number of source utterances to process
+```
+Before that, make sure you have the openvoice and cosyvoice repos correctly installed under `../OpenVoice/` and `../CosyVoice/` if you want to run the baseline evaluation.
 ## Installation 📥
 Python 3.10 is recommended on Windows or Linux:
 ```bash
@@ -20,15 +55,14 @@ pip install -r requirements.txt
 
 Command line inference:
 ```bash
-python inference.py --source <source-audio-path> \
---target <reference-audio-path> \
---output <output-dir> \
---diffusion-steps 25 \ # recommended 50~100 for singing voice conversion
---length-adjust 1.0 \
---inference-cfg-rate 0.7 \
---n-quantizers 3 \
---f0-condition False \ # set to True for singing voice conversion
---auto-f0-condition False \ # set to True to automatically adjust source pitch to the target pitch level, normally not used in singing voice conversion
+python inference.py --source <source-audio-path>
+--target <reference-audio-path>
+--output <output-dir>
+--diffusion-steps 25  # recommended 50~100 for singing voice conversion
+--length-adjust 1.0
+--inference-cfg-rate 0.7
+--f0-condition False  # set to True for singing voice conversion
+--auto-f0-condition False  # set to True to automatically adjust source pitch to the target pitch level, normally not used in singing voice conversion
 --semi-tone-shift 0  # semitone shift for singing voice conversion
 ```
 Where:
@@ -38,7 +72,6 @@ python inference.py --source <source-audio-path> \
 - `diffusion-steps` number of diffusion steps to use, default 25; use 50-100 for best quality, 4-10 for fastest inference
 - `length-adjust` length adjustment factor, default 1.0; <1.0 speeds up speech, >1.0 slows it down
 - `inference-cfg-rate` has a subtle influence on the output, default 0.7
-- `n-quantizers` number of FAcodec codebooks to use, default 3; the fewer codebooks used, the less prosody of the source audio is preserved
 - `f0-condition` whether to condition the output pitch on the pitch of the source audio, default False; set to True for singing voice conversion
 - `auto-f0-condition` whether to automatically adjust the source pitch to the target pitch level, default False; normally not used in singing voice conversion
 - `semi-tone-shift` semitone shift for singing voice conversion, default 0
@@ -59,13 +92,16 @@ python app.py
   - [x] This is enabled in the f0-conditioned model, but not sure whether it works well...
 - [ ] Potential architecture improvements
   - [x] U-ViT-style skip connections
-  - [x] Changed input to [FAcodec](https://github.com/Plachtaa/FAcodec) tokens
+  - [x] Changed input to OpenAI Whisper
 - [ ] Code for training on custom data
 - [x] Changed singing voice decoder to NVIDIA's BigVGAN
 - [ ] 44kHz singing voice conversion model
 - [ ] More to be added
 
 ## Changelog 🗒️
+- 2024-09-26:
+  - Added objective metric evaluation results
+  - Changed speech content encoder to OpenAI Whisper
 - 2024-09-22:
   - Changed the decoder of the singing voice conversion model to BigVGAN, fixing most cases where high-pitched singing failed to convert correctly
   - Support chunked processing and streaming output for long input audio in the Web UI
````
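The SECS numbers in the table above are cosine similarities between resemblyzer speaker embeddings of the converted and reference audio. Below is a minimal sketch of that computation, assuming resemblyzer's public API; the file paths are hypothetical placeholders and `eval.py` may differ in details:

```python
# Hedged sketch: SECS as cosine similarity of resemblyzer speaker embeddings.
# File paths below are hypothetical placeholders.
import numpy as np
from resemblyzer import VoiceEncoder, preprocess_wav

encoder = VoiceEncoder()  # pretrained speaker encoder
emb_converted = encoder.embed_utterance(preprocess_wav("converted.wav"))
emb_reference = encoder.embed_utterance(preprocess_wav("reference.wav"))

# embed_utterance returns L2-normalized embeddings, so a dot product is the cosine similarity.
secs = float(np.dot(emb_converted, emb_reference))
print(f"SECS: {secs:.4f}")
```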

README.md

Lines changed: 12 additions & 10 deletions
````diff
@@ -10,7 +10,7 @@ We are keeping on improving the model quality and adding more features.
 
 ## Evaluation📊
 We have performed a series of objective evaluations on our Seed-VC's voice conversion capabilities.
-For ease for reproduction, source audios are 100 random utterances from LibriTTS-test-clean, and reference audios are 12 randomly picked in-the-wild voices with unique characteristics. <br>
+For ease of reproduction, source audios are 100 random utterances from LibriTTS-test-clean, and reference audios are 12 randomly picked in-the-wild voices with unique characteristics. <br>
 
 Source audios can be found under `./examples/libritts-test-clean` <br>
 Reference audios can be found under `./examples/reference` <br>
@@ -19,20 +19,22 @@ We evaluate the conversion results in terms of speaker embedding cosine similari
 our results with two strong open-sourced baselines, namely [OpenVoice](https://github.com/myshell-ai/OpenVoice) and [CosyVoice](https://github.com/FunAudioLLM/CosyVoice).
 Results in the table below show that our Seed-VC model significantly outperforms the baseline models in both intelligibility and speaker similarity.<br>
 
-| Models\Metrics | SECS↑      | WER↓       | CER↓       |
-|----------------|------------|------------|------------|
-| OpenVoice      | 0.7547     | 0.1546     | 0.0473     |
-| CosyVoice      | 0.8440     | 0.1898     | 0.0729     |
-| Seed-VC(Ours)  | **0.8676** | **0.1199** | **0.0292** |
+| Models\Metrics | SECS↑      | WER↓       | CER↓       | SIG↑     | BAK↑     | OVRL↑    |
+|----------------|------------|------------|------------|----------|----------|----------|
+| Ground Truth   | 1.0000     | 0.0802     | 0.0157     | ~        | ~        | ~        |
+| OpenVoice      | 0.7547     | 0.1546     | 0.0473     | **3.56** | **4.02** | **3.27** |
+| CosyVoice      | 0.8440     | 0.1898     | 0.0729     | 3.51     | **4.02** | 3.21     |
+| Seed-VC(Ours)  | **0.8676** | **0.1199** | **0.0292** | 3.42     | 3.97     | 3.11     |
 
 *ASR result computed by facebook/hubert-large-ls960-ft model*
 *Speaker embedding computed by resemblyzer model* <br>
 
 You can reproduce the evaluation by running the `eval.py` script.
 ```bash
-python eval.py --source ./examples/libritts-test-clean \
---reference ./examples/reference
---output ./examples/eval/converted/
+python eval.py
+--source ./examples/libritts-test-clean
+--target ./examples/reference
+--output ./examples/eval/converted
 --diffusion-steps 25
 --length-adjust 1.0
 --inference-cfg-rate 0.7
@@ -53,7 +55,7 @@ Checkpoints of the latest model release will be downloaded automatically when fi
 
 Command line inference:
 ```bash
-python inference.py --source <source-wav> \
+python inference.py --source <source-wav>
 --target <reference-wav>
 --output <output-dir>
 --diffusion-steps 25  # recommended 50~100 for singing voice conversion
````
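Similarly, the WER/CER columns come from transcribing the converted audio with the hubert-large-ls960-ft checkpoint and scoring against the ground-truth transcript with jiwer. A minimal sketch using the transformers ASR pipeline; the text normalization here is an assumption rather than the exact `eval.py` logic, and the clip path and transcript are placeholders:

```python
# Hedged sketch: WER/CER via HuBERT ASR + jiwer.
# Normalization details and file paths are assumptions, not the repo's exact logic.
import string

import jiwer
from transformers import pipeline

asr = pipeline("automatic-speech-recognition", model="facebook/hubert-large-ls960-ft")

def normalize(text: str) -> str:
    # Lowercase and drop punctuation so scoring ignores casing/punctuation differences.
    return text.lower().translate(str.maketrans("", "", string.punctuation)).strip()

hypothesis = normalize(asr("converted.wav")["text"])  # placeholder converted clip
reference = normalize("the ground truth transcript of the source utterance")

print("WER:", jiwer.wer(reference, hypothesis))
print("CER:", jiwer.cer(reference, hypothesis))
```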

baselines/dnsmos/dnsmos_computor.py

Lines changed: 130 additions & 0 deletions
```python
import warnings

warnings.filterwarnings("ignore")

import librosa
import numpy as np
import onnxruntime as ort
import torch
import torchaudio

SAMPLING_RATE = 16000  # both DNSMOS models expect 16 kHz input
INPUT_LENGTH = 9.01    # seconds of audio scored per sliding window


class DNSMOSComputer:
    def __init__(
        self, primary_model_path, p808_model_path, device="cuda", device_id=0
    ) -> None:
        # Primary model predicts raw (SIG, BAK, OVRL); the P.808 model predicts a single MOS.
        self.onnx_sess = ort.InferenceSession(
            primary_model_path, providers=["CUDAExecutionProvider"]
        )
        self.p808_onnx_sess = ort.InferenceSession(
            p808_model_path, providers=["CUDAExecutionProvider"]
        )
        self.onnx_sess.set_providers(
            ["CUDAExecutionProvider"], [{"device_id": device_id}]
        )
        self.p808_onnx_sess.set_providers(
            ["CUDAExecutionProvider"], [{"device_id": device_id}]
        )
        kwargs = {
            "sample_rate": 16000,
            "hop_length": 160,
            "n_fft": 320 + 1,
            "n_mels": 120,
            "mel_scale": "slaney",
        }
        self.mel_transform = torchaudio.transforms.MelSpectrogram(**kwargs).to(
            f"cuda:{device_id}"
        )

    def audio_melspec(
        self, audio, n_mels=120, frame_size=320, hop_length=160, sr=16000, to_db=True
    ):
        # Log-mel spectrogram used as input to the P.808 model.
        mel_specgram = self.mel_transform(torch.Tensor(audio).cuda())
        mel_spec = mel_specgram.cpu()
        if to_db:
            mel_spec = (librosa.power_to_db(mel_spec, ref=np.max) + 40) / 40
        return mel_spec.T

    def get_polyfit_val(self, sig, bak, ovr, is_personalized_MOS):
        # Map raw model outputs onto the calibrated MOS scale with fitted polynomials.
        if is_personalized_MOS:
            p_ovr = np.poly1d([-0.00533021, 0.005101, 1.18058466, -0.11236046])
            p_sig = np.poly1d([-0.01019296, 0.02751166, 1.19576786, -0.24348726])
            p_bak = np.poly1d([-0.04976499, 0.44276479, -0.1644611, 0.96883132])
        else:
            p_ovr = np.poly1d([-0.06766283, 1.11546468, 0.04602535])
            p_sig = np.poly1d([-0.08397278, 1.22083953, 0.0052439])
            p_bak = np.poly1d([-0.13166888, 1.60915514, -0.39604546])
        return p_sig(sig), p_bak(bak), p_ovr(ovr)

    def compute(self, audio, sampling_rate, is_personalized_MOS=False):
        fs = SAMPLING_RATE
        if isinstance(audio, str):
            audio, _ = librosa.load(audio, sr=fs)
        elif sampling_rate != fs:
            audio = librosa.resample(audio, orig_sr=sampling_rate, target_sr=fs)
        actual_audio_len = len(audio)
        len_samples = int(INPUT_LENGTH * fs)
        # Tile short clips until they cover at least one full input window.
        while len(audio) < len_samples:
            audio = np.append(audio, audio)
        # Slide a 9.01-second window over the clip with a 1-second hop.
        num_hops = int(np.floor(len(audio) / fs) - INPUT_LENGTH) + 1
        hop_len_samples = fs
        predicted_mos_sig_seg_raw = []
        predicted_mos_bak_seg_raw = []
        predicted_mos_ovr_seg_raw = []
        predicted_mos_sig_seg = []
        predicted_mos_bak_seg = []
        predicted_mos_ovr_seg = []
        predicted_p808_mos = []

        for idx in range(num_hops):
            audio_seg = audio[
                int(idx * hop_len_samples) : int((idx + INPUT_LENGTH) * hop_len_samples)
            ]
            if len(audio_seg) < len_samples:
                continue
            input_features = np.array(audio_seg).astype("float32")[np.newaxis, :]
            # The P.808 model takes a mel spectrogram; the last hop is dropped to match its input shape.
            p808_input_features = np.array(
                self.audio_melspec(audio=audio_seg[:-160])
            ).astype("float32")[np.newaxis, :, :]
            oi = {"input_1": input_features}
            p808_oi = {"input_1": p808_input_features}
            p808_mos = self.p808_onnx_sess.run(None, p808_oi)[0][0][0]
            mos_sig_raw, mos_bak_raw, mos_ovr_raw = self.onnx_sess.run(None, oi)[0][0]
            mos_sig, mos_bak, mos_ovr = self.get_polyfit_val(
                mos_sig_raw, mos_bak_raw, mos_ovr_raw, is_personalized_MOS
            )
            predicted_mos_sig_seg_raw.append(mos_sig_raw)
            predicted_mos_bak_seg_raw.append(mos_bak_raw)
            predicted_mos_ovr_seg_raw.append(mos_ovr_raw)
            predicted_mos_sig_seg.append(mos_sig)
            predicted_mos_bak_seg.append(mos_bak)
            predicted_mos_ovr_seg.append(mos_ovr)
            predicted_p808_mos.append(p808_mos)

        # Average per-window predictions over the whole clip.
        clip_dict = {
            "filename": "audio_clip",
            "len_in_sec": actual_audio_len / fs,
            "sr": fs,
            "num_hops": num_hops,
            "OVRL_raw": np.mean(predicted_mos_ovr_seg_raw),
            "SIG_raw": np.mean(predicted_mos_sig_seg_raw),
            "BAK_raw": np.mean(predicted_mos_bak_seg_raw),
            "OVRL": np.mean(predicted_mos_ovr_seg),
            "SIG": np.mean(predicted_mos_sig_seg),
            "BAK": np.mean(predicted_mos_bak_seg),
            "P808_MOS": np.mean(predicted_p808_mos),
        }
        return clip_dict
```
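A minimal usage sketch for the class above, assuming the two ONNX checkpoints added in this commit and a CUDA device; the wav path is a placeholder:

```python
# Hedged sketch: score one clip with the DNSMOS models from this commit.
from baselines.dnsmos.dnsmos_computor import DNSMOSComputer, SAMPLING_RATE

computer = DNSMOSComputer(
    "baselines/dnsmos/sig_bak_ovr.onnx",  # primary model -> SIG/BAK/OVRL
    "baselines/dnsmos/model_v8.onnx",     # P.808 model -> single MOS
    device="cuda",
    device_id=0,
)
# compute() accepts a file path (loaded at 16 kHz) or a float array plus its sample rate.
scores = computer.compute("some_clip.wav", SAMPLING_RATE)  # placeholder path
print(scores["SIG"], scores["BAK"], scores["OVRL"], scores["P808_MOS"])
```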

baselines/dnsmos/model_v8.onnx

220 KB
Binary file not shown.

baselines/dnsmos/sig_bak_ovr.onnx

1.1 MB
Binary file not shown.

eval.py

Lines changed: 36 additions & 0 deletions
```diff
@@ -31,6 +31,29 @@
 import jiwer
 import string
 
+from baselines.dnsmos.dnsmos_computor import DNSMOSComputer
+
+def calc_mos(computor, audio, orin_sr):
+    # only 16k audio is supported
+    target_sr = 16000
+    if orin_sr != 16000:
+        audio = librosa.resample(
+            audio, orig_sr=orin_sr, target_sr=target_sr, res_type="kaiser_fast"
+        )
+    result = computor.compute(audio, target_sr, False)
+    sig, bak, ovr = result["SIG"], result["BAK"], result["OVRL"]
+
+    if ovr == 0:
+        print("calculate dns mos failed")
+    return sig, bak, ovr
+
+mos_computer = DNSMOSComputer(
+    "baselines/dnsmos/sig_bak_ovr.onnx",
+    "baselines/dnsmos/model_v8.onnx",
+    device="cuda",
+    device_id=0,
+)
+
 def load_models(args):
     dit_checkpoint_path, dit_config_path = load_custom_model_from_hf("Plachta/Seed-VC",
                                                                      "DiT_seed_v2_uvit_whisper_small_wavenet_bigvgan_pruned.pth",
```
```diff
@@ -175,6 +198,7 @@ def main(args):
     gt_cer_list = []
     vc_wer_list = []
     vc_cer_list = []
+    dnsmos_list = []
    for source_i, source_line in enumerate(tqdm(source_audio_list)):
        if source_i >= max_samples:
            break
@@ -287,12 +311,20 @@
             vc_wer_list.append(vc_wer)
             vc_cer_list.append(vc_cer)
 
+            # calculate dnsmos
+            sig, bak, ovr = calc_mos(mos_computer, vc_wave_16k.squeeze(0).cpu().numpy(), 16000)
+            dnsmos_list.append((sig, bak, ovr))
+
         print(f"Average GT WER: {sum(gt_wer_list) / len(gt_wer_list)}")
         print(f"Average GT CER: {sum(gt_cer_list) / len(gt_cer_list)}")
         print(f"Average VC WER: {sum(vc_wer_list) / len(vc_wer_list)}")
         print(f"Average VC CER: {sum(vc_cer_list) / len(vc_cer_list)}")
         print(f"Average similarity: {sum(similarity_list) / len(similarity_list)}")
 
+        print(f"Average DNS MOS SIG: {sum([x[0] for x in dnsmos_list]) / len(dnsmos_list)}")
+        print(f"Average DNS MOS BAK: {sum([x[1] for x in dnsmos_list]) / len(dnsmos_list)}")
+        print(f"Average DNS MOS OVR: {sum([x[2] for x in dnsmos_list]) / len(dnsmos_list)}")
+
         # save wer and cer result into this directory as a txt
         with open(osp.join(conversion_result_dir, source_index, "result.txt"), 'w') as f:
             f.write(f"GT WER: {sum(gt_wer_list[-len(target_audio_list):]) / len(target_audio_list)}\n")
@@ -316,6 +348,10 @@
         f.write(f"VC WER: {sum(vc_wer_list) / len(vc_wer_list)}\n")
         f.write(f"VC CER: {sum(vc_cer_list) / len(vc_cer_list)}\n")
 
+    print(f"Average DNS MOS SIG: {sum([x[0] for x in dnsmos_list]) / len(dnsmos_list)}")
+    print(f"Average DNS MOS BAK: {sum([x[1] for x in dnsmos_list]) / len(dnsmos_list)}")
+    print(f"Average DNS MOS OVR: {sum([x[2] for x in dnsmos_list]) / len(dnsmos_list)}")
+
 
 def convert(
     source_path,
```

inference.py

Lines changed: 0 additions & 2 deletions
```diff
@@ -253,8 +253,6 @@ def main(args):
     # Convert to waveform
     # if f0_condition:
     vc_wave = bigvgan_model(vc_target).squeeze(1)  # wav_gen is FloatTensor with shape [B(1), 1, T_time] and values in [-1, 1]
-    # else:
-    #     vc_wave = hift_gen.inference(vc_target, f0=None)
 
     time_vc_end = time.time()
     print(f"RTF: {(time_vc_end - time_vc_start) / vc_wave.size(-1) * sr}")
```
