forked from noahchalifour/rnnt-speech-recognition
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtranscribe_file.py
59 lines (38 loc) · 1.42 KB
/
transcribe_file.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
from argparse import ArgumentParser
import os
import tensorflow as tf
# Quiet TensorFlow before the project modules below are imported: the TF
# logger is limited to ERROR-level messages and autograph verbosity is set to 0.
tf.get_logger().setLevel('ERROR')
tf.autograph.set_verbosity(0)
from utils import preprocessing, encoding, decoding
from utils import model as model_utils
from model import build_keras_model
from hparams import *
def main(args):
    """Transcribe one audio file with a saved model and print the result.

    The hyperparameters and the encoder are both loaded from the directory
    that contains the checkpoint file, so they are expected to live next to
    the checkpoint.

    Args:
        args: Parsed command-line namespace. Reads ``args.checkpoint`` (path
            of the weights to load) and ``args.input`` (path of the audio
            file to transcribe).

    Returns:
        None. The transcription is written to stdout.
    """
    checkpoint_path = args.checkpoint

    # Resolve symlinks / relative segments first so the directory lookup is
    # performed on the real location of the checkpoint.
    model_dir = os.path.dirname(os.path.realpath(checkpoint_path))
    hparams = model_utils.load_hparams(model_dir)

    # get_encoder returns three values; the first (the encode function) is
    # not used here, only the token-to-text function and the vocab size.
    _, tok_to_text, vocab_size = encoding.get_encoder(
        encoder_dir=model_dir, hparams=hparams)

    # The vocab size is stored into hparams before the model is built.
    hparams[HP_VOCAB_SIZE.name] = vocab_size

    model = build_keras_model(hparams)
    model.load_weights(checkpoint_path)

    waveform, sample_rate = preprocessing.tf_load_audio(args.input)
    features = preprocessing.preprocess_audio(
        audio=waveform, sample_rate=sample_rate, hparams=hparams)

    # Insert a leading axis of size 1 (presumably the batch axis the decoder
    # expects -- confirm against decoding.greedy_decode_fn).
    batched_features = tf.expand_dims(features, axis=0)

    decode = decoding.greedy_decode_fn(model, hparams)

    # Only one input was supplied, so take the first entry of the output.
    first_result = decode(batched_features)[0]

    # tok_to_text yields a tensor; .numpy() gives bytes that are decoded
    # as UTF-8 for printing.
    text = tok_to_text(first_result).numpy().decode('utf8')
    print('Transcription:', text)
def parse_args(argv=None):
    """Parse the command-line options of the transcription script.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. When ``None`` (the default) argparse reads
            ``sys.argv[1:]``, which is exactly what the zero-argument call
            did before this parameter existed.

    Returns:
        argparse.Namespace with two string attributes: ``checkpoint`` (from
        ``--checkpoint``) and ``input`` (from ``-i`` / ``--input``).

    Raises:
        SystemExit: Raised by argparse when a required option is missing or
            an unknown option is given.
    """
    ap = ArgumentParser()
    ap.add_argument('--checkpoint', type=str, required=True,
                    help='Checkpoint to load.')
    ap.add_argument('-i', '--input', type=str, required=True,
                    help='Wav file to transcribe.')
    # Forwarding None keeps the default "read sys.argv" behaviour.
    return ap.parse_args(argv)
if __name__ == '__main__':
    # Script entry point: parse the command line and run the transcription.
    main(parse_args())