Skip to content

Commit 64efb57

Browse files
authored
Improve package category ML model (#2809)
Signed-off-by: Sergio Castaño Arteaga <[email protected]>
1 parent 35097b2 commit 64efb57

File tree

5 files changed

+45
-254
lines changed

5 files changed

+45
-254
lines changed

ml/category/category.py

Lines changed: 45 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -3,22 +3,13 @@
33
import shutil
44
import tempfile
55
import tensorflow as tf
6+
import time
67
from tensorflow import keras
78
from keras import layers
89

910
# Set TF log level to INFO
1011
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
1112

12-
# Paths configuration
13-
TRAIN_CSV_PATH = "data/csv/train.csv"
14-
TEST_CSV_PATH = "data/csv/test.csv"
15-
TRAIN_DS_PATH = "data/generated/train"
16-
TEST_DS_PATH = "data/generated/test"
17-
MODEL_PATH = "model"
18-
19-
# Maximum vocabulary size
20-
VOCABULARY_SIZE = 2500
21-
2213
# Categories
2314
CATEGORIES = [
2415
"0-unknown",
@@ -32,52 +23,77 @@
3223
"8-streaming-messaging"
3324
]
3425

26+
# Datasets configuration
27+
TRAIN_CSV_PATH = "data/csv/train.csv"
28+
TEST_CSV_PATH = "data/csv/test.csv"
29+
TRAIN_DS_PATH = "data/generated/train"
30+
TEST_DS_PATH = "data/generated/test"
31+
MODEL_PATH = "model"
32+
VOCABULARY_SIZE = 2500
33+
34+
# Model configuration
35+
LAYER_SIZE = 32
36+
LEARNING_RATE = 0.001
37+
EPOCHS = 20
3538

36-
def build_model():
39+
40+
def build_model(use_validation_data=True, save=False):
3741
"""Build and train model"""
3842

39-
# Load train and test datasets
40-
(train_ds, test_ds) = load_datasets()
43+
# Load datasets
44+
(train_ds, validation_ds, test_ds) = load_datasets(validation_split=use_validation_data)
4145

4246
# Create, train and save model
4347
model = keras.Sequential([
4448
keras.Input(shape=(1,), dtype="string"),
4549
setup_vectorizer(train_ds),
4650
keras.Input(shape=(VOCABULARY_SIZE,)),
47-
layers.Dense(32, activation="relu"),
51+
layers.Dense(LAYER_SIZE, activation="relu"),
4852
layers.Dense(len(CATEGORIES), activation="softmax")
4953
])
5054
model.compile(
51-
optimizer="rmsprop",
55+
optimizer=keras.optimizers.RMSprop(learning_rate=LEARNING_RATE),
5256
loss="categorical_crossentropy",
5357
metrics=["accuracy"]
5458
)
5559
model.fit(
56-
train_ds.cache(),
57-
epochs=30
60+
train_ds,
61+
validation_data=validation_ds,
62+
epochs=EPOCHS
5863
)
59-
model.save(MODEL_PATH)
64+
if save:
65+
model.save(MODEL_PATH)
6066

61-
# Evaluate accuracy with test data
67+
# Evaluate model accuracy with test data
6268
print(f"Test accuracy: {model.evaluate(test_ds)[1]:.3f}")
6369

6470

65-
def load_datasets():
66-
"""Load train and test datasets"""
67-
68-
# Load train dataset
69-
train_ds = keras.utils.text_dataset_from_directory(
70-
TRAIN_DS_PATH,
71-
label_mode="categorical"
72-
)
71+
def load_datasets(validation_split):
72+
"""Load train, validation and test datasets"""
73+
74+
# Load train and validation datasets
75+
if validation_split:
76+
(train_ds, validation_ds) = keras.utils.text_dataset_from_directory(
77+
TRAIN_DS_PATH,
78+
label_mode="categorical",
79+
validation_split=0.2,
80+
subset="both",
81+
seed=int(time.time())
82+
)
83+
else:
84+
train_ds = keras.utils.text_dataset_from_directory(
85+
TRAIN_DS_PATH,
86+
label_mode="categorical"
87+
)
88+
validation_ds = None
7389

7490
# Load test dataset
7591
test_ds = keras.utils.text_dataset_from_directory(
7692
TEST_DS_PATH,
7793
label_mode="categorical"
7894
)
7995

80-
return (train_ds, test_ds)
96+
return (train_ds, validation_ds, test_ds)
8197

8298

8399
def setup_vectorizer(train_ds):
@@ -159,4 +175,4 @@ def predict(raw_text):
159175

160176
if __name__ == "__main__":
161177
build_data_trees()
162-
build_model()
178+
build_model(use_validation_data=False, save=False)

0 commit comments

Comments
 (0)