2
2
3
3
import click
4
4
import lightning as pl
5
- import numpy as np
6
- import pandas as pd
7
- import textgrid
8
5
import torch
9
6
10
7
import modules .AP_detector
11
8
import modules .g2p
9
+ from modules .utils .export_tool import Exporter
10
+ from modules .utils .post_processing import post_processing
12
11
from train import LitForcedAlignmentTask
13
12
14
# Gap thresholds in seconds: a leading/trailing gap shorter than
# MIN_SP_LENGTH is snapped shut; an inter-word gap shorter than
# SP_MERGE_LENGTH may be merged into a neighbouring "AP" interval.
MIN_SP_LENGTH = 0.1
SP_MERGE_LENGTH = 0.3


def add_SP(word_seq, word_intervals, wav_length):
    """Insert "SP" (silence) entries into every uncovered gap.

    Args:
        word_seq: sequence of labels aligned with ``word_intervals``.
        word_intervals: (N, 2) array-like of [start, end] times in seconds
            (indexed as ``word_intervals[0, 0]`` — assumes numpy-style
            2-D indexing).
        wav_length: total duration of the audio in seconds.

    Returns:
        Tuple ``(labels, intervals)`` where every gap — including the
        leading and trailing ones — is covered by an "SP" entry.
    """
    # An empty alignment means the whole file is one big silence.
    if not len(word_seq):
        return ["SP"], [[0, wav_length]]

    out_seq = ["SP"]
    out_intervals = [[0, word_intervals[0, 0]]]
    for label, (begin, finish) in zip(word_seq, word_intervals):
        # Cover any gap between the previous entry and this word.
        if out_intervals[-1][1] < begin:
            out_seq.append("SP")
            out_intervals.append([out_intervals[-1][1], begin])
        out_seq.append(label)
        out_intervals.append([begin, finish])
    # Cover the trailing gap up to the end of the audio.
    if out_intervals[-1][1] < wav_length:
        out_seq.append("SP")
        out_intervals.append([out_intervals[-1][1], wav_length])

    # Drop the zero/negative-length leading SP when the first word
    # already starts at (or before) time 0.
    if word_intervals[0, 0] <= 0:
        out_seq = out_seq[1:]
        out_intervals = out_intervals[1:]

    return out_seq, out_intervals
42
-
43
-
44
def fill_small_gaps(word_seq, word_intervals, wav_length):
    """Close small gaps between adjacent intervals (mutates in place).

    A gap shorter than ``SP_MERGE_LENGTH`` is absorbed by a neighbouring
    "AP" (breath) interval when one exists; when neither neighbour is "AP"
    the gap is only closed (split at its midpoint) if it is shorter than
    ``MIN_SP_LENGTH``.  Leading/trailing gaps against 0 / ``wav_length``
    are snapped shut when shorter than ``MIN_SP_LENGTH``.

    Args:
        word_seq: labels aligned with ``word_intervals``.
        word_intervals: (N, 2) mutable array of [start, end] seconds
            (numpy-style 2-D indexing; modified in place).
        wav_length: total audio duration in seconds.

    Returns:
        The same ``word_seq`` and the mutated ``word_intervals``.
    """
    # Snap the first interval to 0 when the leading gap is negligible.
    if 0 < word_intervals[0, 0] < MIN_SP_LENGTH:
        word_intervals[0, 0] = 0

    for i in range(len(word_seq) - 1):
        left_end = word_intervals[i, 1]
        right_start = word_intervals[i + 1, 0]
        gap = right_start - left_end
        # Only genuine gaps below the merge threshold are candidates.
        if gap <= 0 or gap >= SP_MERGE_LENGTH:
            continue

        left_is_ap = word_seq[i] == "AP"
        right_is_ap = word_seq[i + 1] == "AP"
        if left_is_ap and right_is_ap:
            # Case 1: breath on both sides — split the gap at its midpoint.
            midpoint = (left_end + right_start) / 2
            word_intervals[i, 1] = midpoint
            word_intervals[i + 1, 0] = midpoint
        elif left_is_ap:
            # Case 2: breath only on the left — extend it over the gap.
            word_intervals[i, 1] = right_start
        elif right_is_ap:
            # Case 3: breath only on the right — pull it back over the gap.
            word_intervals[i + 1, 0] = left_end
        elif gap < MIN_SP_LENGTH:
            # Case 4: no breath on either side — close only truly tiny gaps.
            midpoint = (left_end + right_start) / 2
            word_intervals[i, 1] = midpoint
            word_intervals[i + 1, 0] = midpoint

    # Snap the last interval to the end when the trailing gap is negligible.
    if word_intervals[-1, 1] < wav_length:
        if wav_length - word_intervals[-1, 1] < MIN_SP_LENGTH:
            word_intervals[-1, 1] = wav_length

    return word_seq, word_intervals
79
-
80
-
81
def post_processing(predictions):
    """Clean up raw alignment predictions.

    For every prediction, small gaps between intervals are closed and the
    remaining gaps are labelled "SP", for both the word-level and the
    phone-level alignments.

    Args:
        predictions: iterable of tuples ``(wav_path, wav_length,
            confidence, ph_seq, ph_intervals, word_seq, word_intervals)``.

    Returns:
        List of lists with the same layout, post-processed.

    Raises:
        Exception: any processing error is re-raised with the offending
            ``wav_path`` appended to ``e.args`` for easier debugging.
    """
    print("Post-processing...")

    processed = []
    for item in predictions:
        (
            wav_path,
            wav_length,
            confidence,
            ph_seq,
            ph_intervals,
            word_seq,
            word_intervals,
        ) = item
        try:
            # Close tiny gaps first, then label whatever gaps remain as SP.
            words, word_spans = fill_small_gaps(word_seq, word_intervals, wav_length)
            phones, ph_spans = fill_small_gaps(ph_seq, ph_intervals, wav_length)
            words, word_spans = add_SP(words, word_spans, wav_length)
            phones, ph_spans = add_SP(phones, ph_spans, wav_length)

            processed.append(
                [
                    wav_path,
                    wav_length,
                    confidence,
                    phones,
                    ph_spans,
                    words,
                    word_spans,
                ]
            )
        except Exception as e:
            # Tag the failing file so the traceback identifies it.
            e.args += (wav_path,)
            raise
    return processed
119
-
120
-
121
def save_textgrids(predictions):
    """Write one Praat TextGrid per prediction, beside the input wav.

    Each TextGrid gets a "words" tier and a "phones" tier and is written
    to a "TextGrid" sub-folder next to the source wav file.
    """
    print("Saving TextGrids...")

    for (
        wav_path,
        wav_length,
        confidence,
        ph_seq,
        ph_intervals,
        word_seq,
        word_intervals,
    ) in predictions:
        tg = textgrid.TextGrid()
        word_tier = textgrid.IntervalTier(name="words")
        ph_tier = textgrid.IntervalTier(name="phones")

        for word, (start, end) in zip(word_seq, word_intervals):
            word_tier.add(start, end, word)
        for ph, (start, end) in zip(ph_seq, ph_intervals):
            # NOTE(review): only the phone start is cast to float here,
            # mirroring the original behaviour — confirm whether both
            # bounds should be cast for the textgrid library.
            ph_tier.add(minTime=float(start), maxTime=end, mark=ph)

        tg.append(word_tier)
        tg.append(ph_tier)

        out_path = (
            wav_path.parent / "TextGrid" / wav_path.with_suffix(".TextGrid").name
        )
        out_path.parent.mkdir(parents=True, exist_ok=True)
        tg.write(out_path)
151
-
152
-
153
def _format_htk_label(seq, intervals):
    """Render (label, [start, end]) pairs as HTK lab text (100 ns units)."""
    lines = []
    for mark, (start, end) in zip(seq, intervals):
        # HTK lab timestamps are integers in units of 100 nanoseconds.
        start_time = int(float(start) * 10000000)
        end_time = int(float(end) * 10000000)
        lines.append(f"{start_time} {end_time} {mark}\n")
    return "".join(lines)


def save_htk(predictions):
    """Write HTK .lab files for phones and words, beside the input wav.

    For each prediction two label files are written:
    ``<wav dir>/htk/phones/<name>.lab`` and ``<wav dir>/htk/words/<name>.lab``.
    """
    print("Saving htk labels...")

    for (
        wav_path,
        wav_length,
        confidence,
        ph_seq,
        ph_intervals,
        word_seq,
        word_intervals,
    ) in predictions:
        # The phone and word tiers only differ by sub-folder and data,
        # so one loop replaces the original's two copy-pasted blocks.
        for subdir, seq, intervals in (
            ("phones", ph_seq, ph_intervals),
            ("words", word_seq, word_intervals),
        ):
            label_path = (
                wav_path.parent / "htk" / subdir / wav_path.with_suffix(".lab").name
            )
            label_path.parent.mkdir(parents=True, exist_ok=True)
            # "with" closes the file; the original's explicit f.close()
            # inside the with-block was redundant.
            with open(label_path, "w", encoding="utf-8") as f:
                f.write(_format_htk_label(seq, intervals))
190
-
191
-
192
def _cumulative_durations(intervals):
    """Durations measured end-to-end over *intervals*.

    Each duration runs from the previous interval's (accumulated) end to
    this interval's end, so gaps are folded into the following entry.
    Values are rounded to 5 decimals and the rounded values themselves are
    accumulated, matching the DiffSinger transcription convention.
    """
    durations = []
    last_end = 0
    for start, end in intervals:
        dur = np.round(end - last_end, 5)
        durations.append(dur)
        last_end += dur
    return durations


def save_transcriptions(predictions):
    """Write one transcriptions.csv per input folder (DiffSinger format).

    Rows are grouped by the wav file's parent folder; each folder gets a
    ``transcriptions/transcriptions.csv`` with columns name, word_seq,
    word_dur, ph_seq, ph_dur (sequences and durations space-joined).
    """
    print("Saving transcriptions.csv...")

    folder_to_data = {}

    for (
        wav_path,
        wav_length,
        confidence,
        ph_seq,
        ph_intervals,
        word_seq,
        word_intervals,
    ) in predictions:
        folder = wav_path.parent
        # setdefault replaces the original manual lookup-or-create branch.
        curr_data = folder_to_data.setdefault(
            folder,
            {
                "name": [],
                "word_seq": [],
                "word_dur": [],
                "ph_seq": [],
                "ph_dur": [],
            },
        )

        curr_data["name"].append(wav_path.with_suffix("").name)
        curr_data["word_seq"].append(" ".join(word_seq))
        curr_data["ph_seq"].append(" ".join(ph_seq))
        curr_data["word_dur"].append(
            " ".join(str(d) for d in _cumulative_durations(word_intervals))
        )
        curr_data["ph_dur"].append(
            " ".join(str(d) for d in _cumulative_durations(ph_intervals))
        )

    for folder, data in folder_to_data.items():
        out_dir = folder / "transcriptions"
        # exist_ok=True makes the original's separate exists() check redundant.
        out_dir.mkdir(parents=True, exist_ok=True)
        pd.DataFrame(data).to_csv(out_dir / "transcriptions.csv", index=False)
253
-
254
-
255
def save_confidence_fn(predictions):
    """Write one confidence.csv per input folder.

    Rows are grouped by the wav file's parent folder; each folder gets a
    ``confidence/confidence.csv`` with columns name and confidence.
    """
    print("saving confidence...")

    folder_to_data = {}

    for (
        wav_path,
        wav_length,
        confidence,
        ph_seq,
        ph_intervals,
        word_seq,
        word_intervals,
    ) in predictions:
        folder = wav_path.parent
        # setdefault replaces the original manual lookup-or-create branch.
        curr_data = folder_to_data.setdefault(
            folder,
            {
                "name": [],
                "confidence": [],
            },
        )

        curr_data["name"].append(wav_path.with_suffix("").name)
        curr_data["confidence"].append(confidence)

    for folder, data in folder_to_data.items():
        out_dir = folder / "confidence"
        # exist_ok=True makes the original's separate exists() check redundant.
        out_dir.mkdir(parents=True, exist_ok=True)
        pd.DataFrame(data).to_csv(out_dir / "confidence.csv", index=False)
290
-
291
13
292
14
@click .command ()
293
15
@click .option (
@@ -329,9 +51,9 @@ def save_confidence_fn(predictions):
329
51
required = False ,
330
52
type = str ,
331
53
help = "Types of output file, separated by comma. Supported types:"
332
- "textgrid(praat),"
333
- " htk(lab,nnsvs,sinsy),"
334
- " transcriptions.csv(diffsinger,trans,transcription,transcriptions)" ,
54
+ "textgrid(praat),"
55
+ " htk(lab,nnsvs,sinsy),"
56
+ " transcriptions.csv(diffsinger,trans,transcription,transcriptions)" ,
335
57
)
336
58
@click .option (
337
59
"--save_confidence" ,
@@ -349,15 +71,15 @@ def save_confidence_fn(predictions):
349
71
help = "(only used when --g2p=='Dictionary') path to the dictionary" ,
350
72
)
351
73
def main (
352
- ckpt ,
353
- folder ,
354
- mode ,
355
- g2p ,
356
- ap_detector ,
357
- in_format ,
358
- out_formats ,
359
- save_confidence ,
360
- ** kwargs ,
74
+ ckpt ,
75
+ folder ,
76
+ mode ,
77
+ g2p ,
78
+ ap_detector ,
79
+ in_format ,
80
+ out_formats ,
81
+ save_confidence ,
82
+ ** kwargs ,
361
83
):
362
84
if not g2p .endswith ("G2P" ):
363
85
g2p += "G2P"
@@ -380,27 +102,13 @@ def main(
380
102
predictions = trainer .predict (model , dataloaders = dataset , return_predictions = True )
381
103
382
104
predictions = get_AP .process (predictions )
383
- predictions = post_processing (predictions )
384
- if "textgrid" in out_formats or "praat" in out_formats :
385
- save_textgrids (predictions )
386
- if (
387
- "htk" in out_formats
388
- or "lab" in out_formats
389
- or "nnsvs" in out_formats
390
- or "sinsy" in out_formats
391
- ):
392
- save_htk (predictions )
393
- if (
394
- "trans" in out_formats
395
- or "transcription" in out_formats
396
- or "transcriptions" in out_formats
397
- or "transcriptions.csv" in out_formats
398
- or "diffsinger" in out_formats
399
- ):
400
- save_transcriptions (predictions )
105
+ predictions , log = post_processing (predictions )
106
+ exporter = Exporter (predictions , log )
401
107
402
108
if save_confidence :
403
- save_confidence_fn (predictions )
109
+ out_formats .append ('confidence' )
110
+
111
+ exporter .export (out_formats )
404
112
405
113
print ("Output files are saved to the same folder as the input wav files." )
406
114
0 commit comments