Skip to content

Commit 97b15f1

Browse files
committed
simplify
1 parent 505f2f2 commit 97b15f1

File tree

8 files changed

+56
-110
lines changed

8 files changed

+56
-110
lines changed

Utils/JDC/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ def forward(self, x):
134134
# sizes: (b, 31, 722), (b, 31, 2)
135135
# classifier output consists of predicted pitch classes per frame
136136
# detector output consists of: (isvoice, notvoice) estimates per frame
137-
return torch.abs(classifier_out.squeeze()), GAN_feature, poolblock_out
137+
return torch.abs(classifier_out.squeeze(-1)), GAN_feature, poolblock_out
138138

139139
@staticmethod
140140
def init_weights(m):

convert_mp.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
from env import AttrDict
1717
from meldataset import mel_spectrogram, MAX_WAV_VALUE
1818
from models import Generator
19-
from stft import TorchSTFT
2019
from Utils.JDC.model import JDCNet
2120
from asv import compute_similarity2, compute_embedding, get_asv_models
2221

@@ -37,8 +36,8 @@ def get_sim(y, emb_tgts, embedding_models, feature_extractor):
3736
return similarity
3837

3938

40-
def get_best_wav(x, initial_f0, wav_tgt, generator, stft, embedding_models, feature_extractor, search):
41-
y = generator.infer(x, initial_f0, stft)
39+
def get_best_wav(x, initial_f0, wav_tgt, generator, embedding_models, feature_extractor, search):
40+
y = generator.infer(x, initial_f0)
4241
if not search:
4342
return y
4443

@@ -62,7 +61,7 @@ def get_best_wav(x, initial_f0, wav_tgt, generator, stft, embedding_models, feat
6261
lf0 = initial_lf0 + step * i
6362
f0 = torch.exp(lf0)
6463
f0 = torch.where(voiced, f0, initial_f0)
65-
y = generator.infer(x, initial_f0, stft)
64+
y = generator.infer(x, initial_f0)
6665

6766
similarity = get_sim(y, emb_tgts, embedding_models, feature_extractor)
6867

@@ -78,7 +77,7 @@ def get_best_wav(x, initial_f0, wav_tgt, generator, stft, embedding_models, feat
7877
return best_wav
7978

8079

81-
def process_one(line, generator, stft, wavlm, embedding_models, feature_extractor, device, args, h, spk2id, f0_stats):
80+
def process_one(line, generator, wavlm, embedding_models, feature_extractor, device, args, h, spk2id, f0_stats):
8281
with torch.no_grad():
8382
title, src_wav, tgt_wav, tgt_spk, tgt_emb = line.strip().split("|")
8483

@@ -90,6 +89,7 @@ def process_one(line, generator, stft, wavlm, embedding_models, feature_extracto
9089
spk_emb = torch.from_numpy(spk_emb).unsqueeze(0).to(device)
9190

9291
f0_mean_tgt = f0_stats[tgt_spk]["mean"]
92+
f0_mean_tgt = torch.FloatTensor([f0_mean_tgt]).unsqueeze(0).to(device)
9393

9494
wav_tgt, sr = librosa.load(tgt_wav, sr=16000)
9595
wav_tgt = torch.FloatTensor(wav_tgt).to(device)
@@ -106,7 +106,7 @@ def process_one(line, generator, stft, wavlm, embedding_models, feature_extracto
106106
# cvt
107107
f0 = generator.get_f0(mel, f0_mean_tgt)
108108
x = generator.get_x(x, spk_emb, spk_id)
109-
y = get_best_wav(x, f0, wav_tgt, generator, stft, embedding_models, feature_extractor, search=args.search)
109+
y = get_best_wav(x, f0, wav_tgt, generator, embedding_models, feature_extractor, search=args.search)
110110

111111
audio = y.squeeze()
112112
audio = audio / torch.max(torch.abs(audio)) * 0.95
@@ -128,7 +128,6 @@ def process_batch(batch, args, h, spk2id, f0_stats):
128128
# load models
129129
F0_model = JDCNet(num_class=1, seq_len=192)
130130
generator = Generator(h, F0_model).to(device)
131-
stft = TorchSTFT(filter_length=h.gen_istft_n_fft, hop_length=h.gen_istft_hop_size, win_length=h.gen_istft_n_fft).to(device)
132131

133132
state_dict_g = torch.load(args.ptfile, map_location=device)
134133
generator.load_state_dict(state_dict_g['generator'], strict=True)
@@ -156,7 +155,7 @@ def process_batch(batch, args, h, spk2id, f0_stats):
156155
rank = rank[0] if len(rank) > 0 else 0
157156

158157
for line in tqdm(batch, position=rank):
159-
process_one(line, generator, stft, wavlm, embedding_models, feature_extractor, device, args, h, spk2id, f0_stats)
158+
process_one(line, generator, wavlm, embedding_models, feature_extractor, device, args, h, spk2id, f0_stats)
160159

161160

162161
if __name__ == "__main__":

convert_sp.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
from env import AttrDict
1515
from meldataset import mel_spectrogram, MAX_WAV_VALUE
1616
from models import Generator
17-
from stft import TorchSTFT
1817
from Utils.JDC.model import JDCNet
1918
from asv import compute_similarity2, compute_embedding, get_asv_models
2019

@@ -35,8 +34,8 @@ def get_sim(y, emb_tgts, embedding_models, feature_extractor):
3534
return similarity
3635

3736

38-
def get_best_wav(x, initial_f0, wav_tgt, generator, stft, embedding_models, feature_extractor, search):
39-
y = generator.infer(x, initial_f0, stft)
37+
def get_best_wav(x, initial_f0, wav_tgt, generator, embedding_models, feature_extractor, search):
38+
y = generator.infer(x, initial_f0)
4039
if not search:
4140
return y
4241

@@ -60,7 +59,7 @@ def get_best_wav(x, initial_f0, wav_tgt, generator, stft, embedding_models, feat
6059
lf0 = initial_lf0 + step * i
6160
f0 = torch.exp(lf0)
6261
f0 = torch.where(voiced, f0, initial_f0)
63-
y = generator.infer(x, initial_f0, stft)
62+
y = generator.infer(x, initial_f0)
6463

6564
similarity = get_sim(y, emb_tgts, embedding_models, feature_extractor)
6665

@@ -104,7 +103,6 @@ def get_best_wav(x, initial_f0, wav_tgt, generator, stft, embedding_models, feat
104103
# load models
105104
F0_model = JDCNet(num_class=1, seq_len=192)
106105
generator = Generator(h, F0_model).to(device)
107-
stft = TorchSTFT(filter_length=h.gen_istft_n_fft, hop_length=h.gen_istft_hop_size, win_length=h.gen_istft_n_fft).to(device)
108106

109107
state_dict_g = torch.load(args.ptfile, map_location=device)
110108
generator.load_state_dict(state_dict_g['generator'], strict=True)
@@ -150,6 +148,7 @@ def get_best_wav(x, initial_f0, wav_tgt, generator, stft, embedding_models, feat
150148
spk_emb = torch.from_numpy(spk_emb).unsqueeze(0).to(device)
151149

152150
f0_mean_tgt = f0_stats[tgt_spk]["mean"]
151+
f0_mean_tgt = torch.FloatTensor([f0_mean_tgt]).unsqueeze(0).to(device)
153152

154153
wav_tgt, sr = librosa.load(tgt_wav, sr=16000)
155154
wav_tgt = torch.FloatTensor(wav_tgt).to(device)
@@ -166,7 +165,7 @@ def get_best_wav(x, initial_f0, wav_tgt, generator, stft, embedding_models, feat
166165
# cvt
167166
f0 = generator.get_f0(mel, f0_mean_tgt)
168167
x = generator.get_x(x, spk_emb, spk_id)
169-
y = get_best_wav(x, f0, wav_tgt, generator, stft, embedding_models, feature_extractor, search=args.search)
168+
y = get_best_wav(x, f0, wav_tgt, generator, embedding_models, feature_extractor, search=args.search)
170169

171170
audio = y.squeeze()
172171
audio = audio / torch.max(torch.abs(audio)) * 0.95

metrics/macs/macs.py

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,18 +14,16 @@
1414
from env import AttrDict
1515
from meldataset import mel_spectrogram, MAX_WAV_VALUE
1616
from models import Generator
17-
from stft import TorchSTFT
1817
from Utils.JDC.model import JDCNet
1918
from asv import compute_similarity2, compute_embedding, get_asv_models
2019

2120
from thop import profile, clever_format
2221

2322

2423
class Model(torch.nn.Module):
25-
def __init__(self, generator, stft, wavlm):
24+
def __init__(self, generator, wavlm):
2625
super().__init__()
2726
self.generator = generator
28-
self.stft = stft
2927
self.wavlm = wavlm
3028

3129
def forward(self, wav, mel, f0_mean_tgt, spk_emb, spk_id):
@@ -35,17 +33,25 @@ def forward(self, wav, mel, f0_mean_tgt, spk_emb, spk_id):
3533

3634
f0 = self.generator.get_f0(mel, f0_mean_tgt)
3735
x = self.generator.get_x(x, spk_emb, spk_id)
38-
y = self.generator.infer(x, f0, self.stft)
36+
y = self.generator.infer(x, f0)
37+
38+
39+
class F0(torch.nn.Module):
40+
def __init__(self, generator):
41+
super().__init__()
42+
self.generator = generator
43+
44+
def forward(self, mel, f0_mean_tgt):
45+
f0 = self.generator.get_f0(mel, f0_mean_tgt)
3946

4047

4148
class Voc(torch.nn.Module):
42-
def __init__(self, generator, stft):
49+
def __init__(self, generator):
4350
super().__init__()
4451
self.generator = generator
45-
self.stft = stft
4652

4753
def forward(self, x, f0):
48-
y = self.generator.infer(x, f0, self.stft)
54+
y = self.generator.infer(x, f0)
4955

5056

5157
class Enc(torch.nn.Module):
@@ -104,7 +110,6 @@ def forward(self, spk_id, spk_emb):
104110
# load models
105111
F0_model = JDCNet(num_class=1, seq_len=192)
106112
generator = Generator(h, F0_model)#.to(device)
107-
stft = TorchSTFT(filter_length=h.gen_istft_n_fft, hop_length=h.gen_istft_hop_size, win_length=h.gen_istft_n_fft)#.to(device)
108113

109114
# state_dict_g = torch.load(args.ptfile, map_location=device)
110115
# generator.load_state_dict(state_dict_g['generator'], strict=True)
@@ -126,8 +131,9 @@ def forward(self, spk_id, spk_emb):
126131
lines = f.readlines()
127132

128133
# define model & modules
129-
model = Model(generator, stft, wavlm)#.to(device)
130-
mvoc = Voc(generator, stft)#.to(device)
134+
model = Model(generator, wavlm)#.to(device)
135+
mf0 = F0(generator)#.to(device)
136+
mvoc = Voc(generator)#.to(device)
131137
menc = Enc(generator)#.to(device)
132138
mdec = Dec(generator)#.to(device)
133139
mspk = Spk(generator)#.to(device)
@@ -145,6 +151,7 @@ def forward(self, spk_id, spk_emb):
145151
spk_emb = torch.from_numpy(spk_emb).unsqueeze(0)#.to(device)
146152

147153
f0_mean_tgt = f0_stats[tgt_spk]["mean"]
154+
f0_mean_tgt = torch.FloatTensor([f0_mean_tgt]).unsqueeze(0)#.to(device)
148155

149156
wav_tgt, sr = librosa.load(tgt_wav, sr=16000)
150157
wav_tgt = torch.FloatTensor(wav_tgt)#.to(device)
@@ -160,6 +167,12 @@ def forward(self, spk_id, spk_emb):
160167
macs, params = clever_format([macs, params], "%.3f")
161168
print(macs, params)
162169

170+
# macs: f0
171+
print("--- f0 ---")
172+
macs, params = profile(mf0, inputs=(mel, f0_mean_tgt))
173+
macs, params = clever_format([macs, params], "%.3f")
174+
print(macs, params)
175+
163176
# macs: wavlm
164177
print("--- wavlm ---")
165178
macs, params = profile(wavlm, inputs=(wav.unsqueeze(0),))
@@ -205,7 +218,7 @@ def forward(self, spk_id, spk_emb):
205218
# cvt
206219
# f0 = generator.get_f0(mel, f0_mean_tgt)
207220
# x = generator.get_x(x, spk_emb, spk_id)
208-
# y = get_best_wav(x, f0, wav_tgt, generator, stft, embedding_models, feature_extractor, search=args.search)
221+
# y = get_best_wav(x, f0, wav_tgt, generator, embedding_models, feature_extractor, search=args.search)
209222

210223
# audio = y.squeeze()
211224
# audio = audio / torch.max(torch.abs(audio)) * 0.95

metrics/rtf/rtf.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
from env import AttrDict
1616
from meldataset import mel_spectrogram, MAX_WAV_VALUE
1717
from models import Generator
18-
from stft import TorchSTFT
1918
from Utils.JDC.model import JDCNet
2019
from asv import compute_similarity2, compute_embedding, get_asv_models
2120

@@ -36,8 +35,8 @@ def get_sim(y, emb_tgts, embedding_models, feature_extractor):
3635
return similarity
3736

3837

39-
def get_best_wav(x, initial_f0, wav_tgt, generator, stft, embedding_models, feature_extractor, search):
40-
y = generator.infer(x, initial_f0, stft)
38+
def get_best_wav(x, initial_f0, wav_tgt, generator, embedding_models, feature_extractor, search):
39+
y = generator.infer(x, initial_f0)
4140
if not search:
4241
return y
4342

@@ -61,7 +60,7 @@ def get_best_wav(x, initial_f0, wav_tgt, generator, stft, embedding_models, feat
6160
lf0 = initial_lf0 + step * i
6261
f0 = torch.exp(lf0)
6362
f0 = torch.where(voiced, f0, initial_f0)
64-
y = generator.infer(x, initial_f0, stft)
63+
y = generator.infer(x, initial_f0)
6564

6665
similarity = get_sim(y, emb_tgts, embedding_models, feature_extractor)
6766

@@ -105,7 +104,6 @@ def get_best_wav(x, initial_f0, wav_tgt, generator, stft, embedding_models, feat
105104
# load models
106105
F0_model = JDCNet(num_class=1, seq_len=192)
107106
generator = Generator(h, F0_model).to(device)
108-
stft = TorchSTFT(filter_length=h.gen_istft_n_fft, hop_length=h.gen_istft_hop_size, win_length=h.gen_istft_n_fft).to(device)
109107

110108
state_dict_g = torch.load(args.ptfile, map_location=device)
111109
generator.load_state_dict(state_dict_g['generator'], strict=True)
@@ -153,6 +151,7 @@ def get_best_wav(x, initial_f0, wav_tgt, generator, stft, embedding_models, feat
153151
spk_emb = torch.from_numpy(spk_emb).unsqueeze(0).to(device)
154152

155153
f0_mean_tgt = f0_stats[tgt_spk]["mean"]
154+
f0_mean_tgt = torch.FloatTensor([f0_mean_tgt]).unsqueeze(0).to(device)
156155

157156
wav_tgt, sr = librosa.load(tgt_wav, sr=16000)
158157
wav_tgt = torch.FloatTensor(wav_tgt).to(device)
@@ -172,7 +171,7 @@ def get_best_wav(x, initial_f0, wav_tgt, generator, stft, embedding_models, feat
172171
# cvt
173172
f0 = generator.get_f0(mel, f0_mean_tgt)
174173
x = generator.get_x(x, spk_emb, spk_id)
175-
y = get_best_wav(x, f0, wav_tgt, generator, stft, embedding_models, feature_extractor, search=args.search)
174+
y = get_best_wav(x, f0, wav_tgt, generator, embedding_models, feature_extractor, search=args.search)
176175

177176
rtf = (time.time() - start) / length_y
178177
total_rtf += rtf

metrics/rtf/run.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
mv rtf.py ../..
22
cd ../..
33

4-
CUDA_VISIBLE_DEVICES=0 python rtf.py --hpfile config_v1_16k.json --ptfile exp/default/g_00700000 --txtpath test/txts/u2s.txt
4+
CUDA_VISIBLE_DEVICES=-1 python rtf.py --hpfile config_v1_16k.json --ptfile exp/default/g_00700000 --txtpath test/txts/s2s.txt
55

66
mv rtf.py metrics/rtf
77
cd metrics/rtf

0 commit comments

Comments
 (0)