|
3 | 3 | import shutil
|
4 | 4 | import tempfile
|
5 | 5 | import tensorflow as tf
|
| 6 | +import time |
6 | 7 | from tensorflow import keras
|
7 | 8 | from keras import layers
|
8 | 9 |
|
9 | 10 | # Set TF log level to INFO
|
10 | 11 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
|
11 | 12 |
|
12 |
| -# Paths configuration |
13 |
| -TRAIN_CSV_PATH = "data/csv/train.csv" |
14 |
| -TEST_CSV_PATH = "data/csv/test.csv" |
15 |
| -TRAIN_DS_PATH = "data/generated/train" |
16 |
| -TEST_DS_PATH = "data/generated/test" |
17 |
| -MODEL_PATH = "model" |
18 |
| - |
19 |
| -# Maximum vocabulary size |
20 |
| -VOCABULARY_SIZE = 2500 |
21 |
| - |
22 | 13 | # Categories
|
23 | 14 | CATEGORIES = [
|
24 | 15 | "0-unknown",
|
|
32 | 23 | "8-streaming-messaging"
|
33 | 24 | ]
|
34 | 25 |
|
# Datasets configuration
TRAIN_CSV_PATH = "data/csv/train.csv"    # raw training data (CSV)
TEST_CSV_PATH = "data/csv/test.csv"      # raw test data (CSV)
TRAIN_DS_PATH = "data/generated/train"   # generated train dataset directory
TEST_DS_PATH = "data/generated/test"     # generated test dataset directory
MODEL_PATH = "model"                     # where the trained model is saved
VOCABULARY_SIZE = 2500                   # maximum vocabulary size (vectorizer output width)

# Model configuration
LAYER_SIZE = 32        # units in the hidden Dense layer
LEARNING_RATE = 0.001  # RMSprop learning rate
EPOCHS = 20            # number of training epochs
35 | 38 |
|
36 |
def build_model(use_validation_data=True, save=False):
    """Build, train and evaluate the text-classification model.

    Args:
        use_validation_data: when True, load_datasets holds out part of the
            training data as a validation set, which is passed to fit().
        save: when True, persist the trained model to MODEL_PATH.
    """

    # Load datasets
    (train_ds, validation_ds, test_ds) = load_datasets(validation_split=use_validation_data)

    # Create, train and save model.
    # The pipeline maps a raw string input through the text vectorizer
    # (VOCABULARY_SIZE-wide output) into a small dense classifier with one
    # softmax unit per category.
    model = keras.Sequential([
        keras.Input(shape=(1,), dtype="string"),
        setup_vectorizer(train_ds),
        keras.Input(shape=(VOCABULARY_SIZE,)),
        layers.Dense(LAYER_SIZE, activation="relu"),
        layers.Dense(len(CATEGORIES), activation="softmax")
    ])
    model.compile(
        optimizer=keras.optimizers.RMSprop(learning_rate=LEARNING_RATE),
        loss="categorical_crossentropy",
        metrics=["accuracy"]
    )
    model.fit(
        train_ds,
        # validation_ds is None when use_validation_data is False; fit()
        # then simply trains without a validation pass.
        validation_data=validation_ds,
        epochs=EPOCHS
    )
    if save:
        model.save(MODEL_PATH)

    # Evaluate model accuracy with test data
    print(f"Test accuracy: {model.evaluate(test_ds)[1]:.3f}")
|
63 | 69 |
|
64 | 70 |
|
65 |
def load_datasets(validation_split, validation_fraction=0.2):
    """Load train, validation and test datasets.

    Args:
        validation_split: when truthy, hold out a fraction of the training
            data as a separate validation dataset.
        validation_fraction: fraction of the training data reserved for
            validation when validation_split is enabled. Defaults to 0.2,
            the previously hard-coded value, so existing callers are
            unaffected.

    Returns:
        A (train_ds, validation_ds, test_ds) tuple of tf.data datasets;
        validation_ds is None when validation_split is falsy.
    """

    # Load train and validation datasets
    if validation_split:
        # NOTE(review): the time-based seed makes the train/validation
        # split different on every run — pass a fixed seed here if
        # reproducible splits are needed.
        (train_ds, validation_ds) = keras.utils.text_dataset_from_directory(
            TRAIN_DS_PATH,
            label_mode="categorical",
            validation_split=validation_fraction,
            subset="both",
            seed=int(time.time())
        )
    else:
        train_ds = keras.utils.text_dataset_from_directory(
            TRAIN_DS_PATH,
            label_mode="categorical"
        )
        validation_ds = None

    # Load test dataset
    test_ds = keras.utils.text_dataset_from_directory(
        TEST_DS_PATH,
        label_mode="categorical"
    )

    return (train_ds, validation_ds, test_ds)
81 | 97 |
|
82 | 98 |
|
83 | 99 | def setup_vectorizer(train_ds):
|
@@ -159,4 +175,4 @@ def predict(raw_text):
|
159 | 175 |
|
if __name__ == "__main__":
    # Regenerate the on-disk datasets, then train without a validation
    # split and without persisting the model to MODEL_PATH.
    build_data_trees()
    build_model(use_validation_data=False, save=False)
0 commit comments