Skip to content

Commit ef29144

Browse files
committed
Refactor infer.py
1 parent 0e4a4b6 commit ef29144

File tree

3 files changed

+327
-311
lines changed

3 files changed

+327
-311
lines changed

infer.py

Lines changed: 19 additions & 311 deletions
Original file line numberDiff line numberDiff line change
@@ -2,292 +2,14 @@
22

33
import click
44
import lightning as pl
5-
import numpy as np
6-
import pandas as pd
7-
import textgrid
85
import torch
96

107
import modules.AP_detector
118
import modules.g2p
9+
from modules.utils.export_tool import Exporter
10+
from modules.utils.post_processing import post_processing
1211
from train import LitForcedAlignmentTask
1312

14-
# Post-processing thresholds, in seconds.
MIN_SP_LENGTH = 0.1  # silences shorter than this are snapped onto a neighbour
SP_MERGE_LENGTH = 0.3  # largest gap that may be folded into an adjacent "AP"
16-
17-
18-
def add_SP(word_seq, word_intervals, wav_length):
    """Insert explicit "SP" (silence/pause) entries into alignment gaps.

    Args:
        word_seq: sequence of labels (words or phones).
        word_intervals: 2-column indexable of [start, end] times, one row
            per label (indexed as ``word_intervals[i, j]`` — e.g. a numpy
            array; TODO confirm exact type against callers).
        wav_length: total audio duration in seconds.

    Returns:
        (seq, intervals) as plain Python lists, with an "SP" entry covering
        every uncovered span: before the first label, between labels, and
        after the last label.  An empty input yields a single "SP" covering
        the whole file.  A leading "SP" of zero (or negative) width is
        dropped.
    """
    word_seq_res = []
    word_intervals_res = []
    # Empty alignment: the whole file is one silence.
    if len(word_seq) == 0:
        word_seq_res.append("SP")
        word_intervals_res.append([0, wav_length])
        return word_seq_res, word_intervals_res

    # Tentative leading silence; removed below if it has no width.
    word_seq_res.append("SP")
    word_intervals_res.append([0, word_intervals[0, 0]])
    for word, (start, end) in zip(word_seq, word_intervals):
        if word_intervals_res[-1][1] < start:
            # Gap between the previous unit's end and this start.
            word_seq_res.append("SP")
            word_intervals_res.append([word_intervals_res[-1][1], start])
        word_seq_res.append(word)
        word_intervals_res.append([start, end])
    if word_intervals_res[-1][1] < wav_length:
        # Trailing silence up to the end of the audio.
        word_seq_res.append("SP")
        word_intervals_res.append([word_intervals_res[-1][1], wav_length])
    if word_intervals[0, 0] <= 0:
        # The provisional leading "SP" has no extent; drop it.
        word_seq_res = word_seq_res[1:]
        word_intervals_res = word_intervals_res[1:]

    return word_seq_res, word_intervals_res
42-
43-
44-
def fill_small_gaps(
    word_seq,
    word_intervals,
    wav_length,
    min_sp_length=0.1,
    sp_merge_length=0.3,
):
    """Absorb small gaps between adjacent units, mutating ``word_intervals``.

    Args:
        word_seq: sequence of labels; "AP" marks an aspiration/breath unit.
        word_intervals: 2-column indexable of [start, end] times (mutated
            in place).
        wav_length: total audio duration in seconds.
        min_sp_length: gaps shorter than this are always merged; defaults
            mirror the module constant MIN_SP_LENGTH.
        sp_merge_length: largest gap considered for AP absorption; defaults
            mirror the module constant SP_MERGE_LENGTH.

    Returns:
        (word_seq, word_intervals), the same objects passed in.
    """
    if len(word_seq) == 0:
        # Nothing to merge (the unguarded version would index row 0 here).
        return word_seq, word_intervals

    # Snap a tiny leading silence to time zero.
    if 0 < word_intervals[0, 0] < min_sp_length:
        word_intervals[0, 0] = 0

    for idx in range(len(word_seq) - 1):
        gap = word_intervals[idx + 1, 0] - word_intervals[idx, 1]
        if gap <= 0 or gap >= sp_merge_length:
            continue
        left_is_ap = word_seq[idx] == "AP"
        right_is_ap = word_seq[idx + 1] == "AP"
        if left_is_ap and right_is_ap:
            # Case 1: AP on both sides of the gap -- split it in the middle.
            mid = (word_intervals[idx, 1] + word_intervals[idx + 1, 0]) / 2
            word_intervals[idx, 1] = mid
            word_intervals[idx + 1, 0] = mid
        elif left_is_ap:
            # Case 2: AP only on the left -- extend it over the gap.
            word_intervals[idx, 1] = word_intervals[idx + 1, 0]
        elif right_is_ap:
            # Case 3: AP only on the right -- extend it backwards.
            word_intervals[idx + 1, 0] = word_intervals[idx, 1]
        elif gap < min_sp_length:
            # Case 4: no AP on either side -- merge only very small gaps.
            mid = (word_intervals[idx, 1] + word_intervals[idx + 1, 0]) / 2
            word_intervals[idx, 1] = mid
            word_intervals[idx + 1, 0] = mid

    # Snap a tiny trailing silence to the end of the audio.
    if (
        word_intervals[-1, 1] < wav_length
        and wav_length - word_intervals[-1, 1] < min_sp_length
    ):
        word_intervals[-1, 1] = wav_length

    return word_seq, word_intervals
79-
80-
81-
def post_processing(predictions):
    """Clean up raw alignment predictions before export.

    For each prediction, small gaps between adjacent units are absorbed
    (``fill_small_gaps``) and the remaining gaps are then labelled as
    explicit "SP" pauses (``add_SP``), at both word and phone level.

    Args:
        predictions: iterable of [wav_path, wav_length, confidence, ph_seq,
            ph_intervals, word_seq, word_intervals].

    Returns:
        List of the same 7-element records with cleaned sequences/intervals.

    Raises:
        Re-raises any per-item exception with the wav path appended to its
        args so the failing file can be identified.
    """
    print("Post-processing...")

    res = []
    for (
        wav_path,
        wav_length,
        confidence,
        ph_seq,
        ph_intervals,
        word_seq,
        word_intervals,
    ) in predictions:
        try:
            # Absorb tiny gaps first, then mark what remains as "SP".
            word_seq, word_intervals = fill_small_gaps(
                word_seq, word_intervals, wav_length
            )
            ph_seq, ph_intervals = fill_small_gaps(ph_seq, ph_intervals, wav_length)
            word_seq, word_intervals = add_SP(word_seq, word_intervals, wav_length)
            ph_seq, ph_intervals = add_SP(ph_seq, ph_intervals, wav_length)

            res.append(
                [
                    wav_path,
                    wav_length,
                    confidence,
                    ph_seq,
                    ph_intervals,
                    word_seq,
                    word_intervals,
                ]
            )
        except Exception as e:
            # Tag the exception with the offending file before re-raising;
            # bare `raise` preserves the original traceback.
            e.args += (wav_path,)
            raise
    return res
119-
120-
121-
def save_textgrids(predictions):
    """Write one Praat .TextGrid per wav, with "words" and "phones" tiers.

    Output path: ``<wav folder>/TextGrid/<wav name>.TextGrid`` (the folder
    is created if missing).  ``wav_length`` and ``confidence`` in each
    record are unused here.
    """
    print("Saving TextGrids...")

    for (
        wav_path,
        wav_length,
        confidence,
        ph_seq,
        ph_intervals,
        word_seq,
        word_intervals,
    ) in predictions:
        tg = textgrid.TextGrid()
        word_tier = textgrid.IntervalTier(name="words")
        ph_tier = textgrid.IntervalTier(name="phones")

        # Cast boundaries to plain floats for both tiers (the original only
        # cast the phone tier's start time, which was inconsistent).
        for word, (start, end) in zip(word_seq, word_intervals):
            word_tier.add(float(start), float(end), word)

        for ph, (start, end) in zip(ph_seq, ph_intervals):
            ph_tier.add(minTime=float(start), maxTime=float(end), mark=ph)

        tg.append(word_tier)
        tg.append(ph_tier)

        label_path = (
            wav_path.parent / "TextGrid" / wav_path.with_suffix(".TextGrid").name
        )
        label_path.parent.mkdir(parents=True, exist_ok=True)
        tg.write(label_path)
151-
152-
153-
def _write_htk_tier(wav_path, tier_name, seq, intervals):
    # Write one HTK .lab file: "<start> <end> <label>" per line, times in
    # 100-nanosecond units, to <wav folder>/htk/<tier_name>/<name>.lab.
    lines = []
    for mark, (start, end) in zip(seq, intervals):
        start_time = int(float(start) * 10000000)
        end_time = int(float(end) * 10000000)
        lines.append(f"{start_time} {end_time} {mark}\n")
    label_path = wav_path.parent / "htk" / tier_name / wav_path.with_suffix(".lab").name
    label_path.parent.mkdir(parents=True, exist_ok=True)
    # `with` already closes the file; the explicit f.close() that the old
    # code called afterwards was redundant.
    with open(label_path, "w", encoding="utf-8") as f:
        f.write("".join(lines))


def save_htk(predictions):
    """Write HTK label files for each wav: one phone tier and one word tier.

    Output paths: ``<wav folder>/htk/phones/<name>.lab`` and
    ``<wav folder>/htk/words/<name>.lab``.  ``wav_length`` and
    ``confidence`` in each record are unused here.
    """
    print("Saving htk labels...")

    for (
        wav_path,
        wav_length,
        confidence,
        ph_seq,
        ph_intervals,
        word_seq,
        word_intervals,
    ) in predictions:
        _write_htk_tier(wav_path, "phones", ph_seq, ph_intervals)
        _write_htk_tier(wav_path, "words", word_seq, word_intervals)
190-
191-
192-
def save_transcriptions(predictions):
    """Write one transcriptions.csv per input folder (DiffSinger format).

    Columns: name, word_seq, word_dur, ph_seq, ph_dur.  Durations are
    cumulative: each entry is ``round(end - previous_end, 5)``, so a row's
    durations always sum to the last interval's end time and any gaps are
    folded into the following unit (interval starts are intentionally
    ignored).  Output: ``<folder>/transcriptions/transcriptions.csv``.
    """
    print("Saving transcriptions.csv...")

    folder_to_data = {}

    def _cumulative_durs(intervals):
        # Duration of each unit measured from the previous unit's end.
        durs = []
        last_end = 0
        for _start, end in intervals:
            dur = np.round(end - last_end, 5)
            durs.append(dur)
            last_end += dur
        return " ".join(str(d) for d in durs)

    for (
        wav_path,
        wav_length,
        confidence,
        ph_seq,
        ph_intervals,
        word_seq,
        word_intervals,
    ) in predictions:
        folder = wav_path.parent
        curr_data = folder_to_data.setdefault(
            folder,
            {
                "name": [],
                "word_seq": [],
                "word_dur": [],
                "ph_seq": [],
                "ph_dur": [],
            },
        )

        curr_data["name"].append(wav_path.with_suffix("").name)
        curr_data["word_seq"].append(" ".join(word_seq))
        curr_data["word_dur"].append(_cumulative_durs(word_intervals))
        curr_data["ph_seq"].append(" ".join(ph_seq))
        curr_data["ph_dur"].append(_cumulative_durs(ph_intervals))

    for folder, data in folder_to_data.items():
        out_dir = folder / "transcriptions"
        # mkdir(exist_ok=True) makes the old `if not path.exists()` guard
        # unnecessary.
        out_dir.mkdir(parents=True, exist_ok=True)
        pd.DataFrame(data).to_csv(out_dir / "transcriptions.csv", index=False)
253-
254-
255-
def save_confidence_fn(predictions):
    """Write one confidence.csv per input folder.

    Columns: name, confidence (one row per wav).  Output:
    ``<folder>/confidence/confidence.csv``.  All other fields of each
    prediction record are unused here.
    """
    print("saving confidence...")

    folder_to_data = {}

    for (
        wav_path,
        wav_length,
        confidence,
        ph_seq,
        ph_intervals,
        word_seq,
        word_intervals,
    ) in predictions:
        curr_data = folder_to_data.setdefault(
            wav_path.parent,
            {"name": [], "confidence": []},
        )
        curr_data["name"].append(wav_path.with_suffix("").name)
        curr_data["confidence"].append(confidence)

    for folder, data in folder_to_data.items():
        out_dir = folder / "confidence"
        # mkdir(exist_ok=True) makes the old `if not path.exists()` guard
        # unnecessary.
        out_dir.mkdir(parents=True, exist_ok=True)
        pd.DataFrame(data).to_csv(out_dir / "confidence.csv", index=False)
290-
29113

29214
@click.command()
29315
@click.option(
@@ -329,9 +51,9 @@ def save_confidence_fn(predictions):
32951
required=False,
33052
type=str,
33153
help="Types of output file, separated by comma. Supported types:"
332-
"textgrid(praat),"
333-
" htk(lab,nnsvs,sinsy),"
334-
" transcriptions.csv(diffsinger,trans,transcription,transcriptions)",
54+
"textgrid(praat),"
55+
" htk(lab,nnsvs,sinsy),"
56+
" transcriptions.csv(diffsinger,trans,transcription,transcriptions)",
33557
)
33658
@click.option(
33759
"--save_confidence",
@@ -349,15 +71,15 @@ def save_confidence_fn(predictions):
34971
help="(only used when --g2p=='Dictionary') path to the dictionary",
35072
)
35173
def main(
352-
ckpt,
353-
folder,
354-
mode,
355-
g2p,
356-
ap_detector,
357-
in_format,
358-
out_formats,
359-
save_confidence,
360-
**kwargs,
74+
ckpt,
75+
folder,
76+
mode,
77+
g2p,
78+
ap_detector,
79+
in_format,
80+
out_formats,
81+
save_confidence,
82+
**kwargs,
36183
):
36284
if not g2p.endswith("G2P"):
36385
g2p += "G2P"
@@ -380,27 +102,13 @@ def main(
380102
predictions = trainer.predict(model, dataloaders=dataset, return_predictions=True)
381103

382104
predictions = get_AP.process(predictions)
383-
predictions = post_processing(predictions)
384-
if "textgrid" in out_formats or "praat" in out_formats:
385-
save_textgrids(predictions)
386-
if (
387-
"htk" in out_formats
388-
or "lab" in out_formats
389-
or "nnsvs" in out_formats
390-
or "sinsy" in out_formats
391-
):
392-
save_htk(predictions)
393-
if (
394-
"trans" in out_formats
395-
or "transcription" in out_formats
396-
or "transcriptions" in out_formats
397-
or "transcriptions.csv" in out_formats
398-
or "diffsinger" in out_formats
399-
):
400-
save_transcriptions(predictions)
105+
predictions, log = post_processing(predictions)
106+
exporter = Exporter(predictions, log)
401107

402108
if save_confidence:
403-
save_confidence_fn(predictions)
109+
out_formats.append('confidence')
110+
111+
exporter.export(out_formats)
404112

405113
print("Output files are saved to the same folder as the input wav files.")
406114

0 commit comments

Comments
 (0)