import argparse
import random
import warnings

import tensorflow as tf

from utils import get_dataset, create_model
# Suppress librosa's UserWarning about n_fft being too large for the input signal.
# The underlying issue is still unresolved: specifying a smaller n_fft doesn't make the warning go away.
warnings.filterwarnings("ignore", category=UserWarning)
# Defaults
EPOCHS = None
BATCH_SIZE = None
model_type = None
NUM_LABELS = None
# Argument Parser
help_message = "Check the documentation available at https://www.github.com/AnkushMalaker/speech-emotion-recognition for more info on usage."
parser = argparse.ArgumentParser(description=help_message)
parser.add_argument("epochs", type=int, help="Specify number of epochs")
parser.add_argument(
    "-B",
    "--batch_size",
    type=int,
    help="Default batch size is 32. Reduce this if the data doesn't fit in your GPU's memory.",
)
parser.add_argument(
    "-C",
    "--cache",
    action="store_true",
    help="Default behaviour is to not use cache. Caching greatly speeds up training after the first epoch but may require a lot of memory.",
)
parser.add_argument(
    "-LR", "--learning_rate", type=float, help="Default learning rate is 1e-5."
)
parser.add_argument("--train_dir", help="Default data directory is ./train_data")
parser.add_argument(
    "--val_dir",
    help="Default behaviour is to take the given split from train_dir for validation. Specify the split using --val_split.",
)
parser.add_argument(
    "--val_split", type=float, help="Default val_split is 0.2 of train data"
)
parser.add_argument(
    "--model_type",
    help='Specifies the specific architecture to be used. Check README for more info. Defaults to "ravdess".',
)
# This logic has to be improved to adapt to the other datasets mentioned in the paper,
# or the whole code has to be split into separate scripts for each, like train_emodb.py, train_xyz, etc.
# This argument is currently not handled.
parser.add_argument("--num_labels", type=int, help="Specify number of labels")
parser.add_argument(
    "--random_state",
    type=int,
    help="Specify random state for consistency in experiments. Use -1 to randomize.",
)
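# Example invocation (illustrative values, not from the original script):
#   python train.py 50 -B 16 -C --train_dir ./train_data --model_type ravdess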
args = parser.parse_args()
EPOCHS = args.epochs
print(f"Training for {EPOCHS} epochs")
if args.cache:
    CACHE = True
else:
    CACHE = False

if args.batch_size:
    BATCH_SIZE = args.batch_size
else:
    BATCH_SIZE = 32

if args.train_dir:
    train_dir = args.train_dir
else:
    train_dir = "./train_data"

if args.val_dir:
    val_dir = args.val_dir
else:
    val_dir = None
if args.random_state == -1:
    RANDOM_STATE = random.randint(0, 10000)
elif args.random_state is not None:
    RANDOM_STATE = args.random_state
else:
    RANDOM_STATE = 42
if args.val_split:
    val_split = args.val_split
else:
    val_split = 0.2
if args.model_type:
    model_type = args.model_type
else:
    model_type = "ravdess"
if args.num_labels:
    NUM_LABELS = args.num_labels
else:
    NUM_LABELS = 8
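
# Build the training and validation datasets from the audio data directories.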
train_ds, val_ds = get_dataset(
    training_dir=train_dir,
    validation_dir=val_dir,
    val_split=val_split,
    batch_size=BATCH_SIZE,
    random_state=RANDOM_STATE,
    cache=CACHE,
)
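
# Create the model for the chosen architecture and number of output labels.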
model = create_model(NUM_LABELS, model_type)
ESCallback = tf.keras.callbacks.EarlyStopping(
    patience=2, restore_best_weights=True, verbose=3
)
# Add checkpoint callback
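# A minimal sketch of the checkpoint callback suggested above (not part of the
# original script; the "checkpoints/best_model" path is an assumed example).
# It saves the best model seen so far, judged by validation loss. To use it,
# append it to the callbacks list passed to model.fit below.
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath="checkpoints/best_model",
    monitor="val_loss",
    save_best_only=True,
)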
model.fit(
    train_ds,
    validation_data=val_ds,
    callbacks=[ESCallback],
    epochs=EPOCHS,
)
model.save(f"saved_model/{EPOCHS}_trained_model")
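# Illustrative: the saved model can later be restored with
# tf.keras.models.load_model("saved_model/<EPOCHS>_trained_model").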