@@ -1,4 +1,5 @@
 import sys
+import os
 import pandas as pd
 import tensorflow as tf
 from tensorflow.keras.preprocessing.text import Tokenizer
@@ -17,6 +18,7 @@
 from tensorflow.keras.callbacks import EarlyStopping
 from sklearn.model_selection import KFold
 from sklearn.metrics import accuracy_score, precision_score, recall_score
+from sklearn.utils.class_weight import compute_class_weight
 import numpy as np
 import matplotlib.pyplot as plt
@@ -54,11 +56,7 @@ def build_model(input_dim, output_dim=128):
     model.compile(
         loss="binary_crossentropy",
         optimizer="adam",
-        metrics=[
-            "accuracy",
-            tf.keras.metrics.Precision(name="precision"),
-            tf.keras.metrics.Recall(name="recall"),
-        ],
+        metrics=["accuracy", tf.keras.metrics.Precision(), tf.keras.metrics.Recall()],
     )
     return model
 
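Reviewer note on the now-unnamed metrics: Keras auto-names metric instances per session, so because build_model() is called once per fold, later folds can log history keys like "precision_1" rather than "precision". A minimal standalone sketch of the behavior (not part of this PR):

import tensorflow as tf

# Two unnamed Precision instances created in the same session receive
# distinct auto-generated names.
p1 = tf.keras.metrics.Precision()
p2 = tf.keras.metrics.Precision()
print(p1.name, p2.name)  # typically "precision" and "precision_1"

This is presumably why plot_history() below now guards on the keys actually present in history.history.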
@@ -75,31 +73,43 @@ def calculate_f1_f2(precision, recall, beta=1):
 
 def plot_history(history):
     """Plot the training and validation loss, accuracy, precision, and recall."""
+    available_metrics = history.history.keys()  # Check which metrics are available
     plt.figure(figsize=(12, 8))
-    for i, metric in enumerate(["loss", "accuracy", "precision", "recall"], start=1):
-        plt.subplot(2, 2, i)
-        plt.plot(history.history[metric], label=f"Training {metric.capitalize()}")
-        plt.plot(
-            history.history[f"val_{metric}"], label=f"Validation {metric.capitalize()}"
-        )
-        plt.title(metric.capitalize())
-        plt.xlabel("Epochs")
-        plt.ylabel(metric.capitalize())
-        plt.legend()
+
+    # Define metrics to plot
+    metrics_to_plot = ["loss", "accuracy", "precision", "recall"]
+    for i, metric in enumerate(metrics_to_plot, start=1):
+        if metric in available_metrics:
+            plt.subplot(2, 2, i)
+            plt.plot(history.history[metric], label=f"Training {metric.capitalize()}")
+            plt.plot(
+                history.history[f"val_{metric}"],
+                label=f"Validation {metric.capitalize()}",
+            )
+            plt.title(metric.capitalize())
+            plt.xlabel("Epochs")
+            plt.ylabel(metric.capitalize())
+            plt.legend()
+
     plt.tight_layout()
     plt.savefig("training_history.png")
 
 
-# Main function
 if __name__ == "__main__":
     if len(sys.argv) != 3:
         print("Usage: python train.py <input_file> <output_dir>")
         sys.exit(1)
 
+    # Constants
+    MAX_WORDS = 10000
+    MAX_LEN = 100
+    EPOCHS = 50
+    BATCH_SIZE = 32
+
     # Load and preprocess data
     data = load_data(sys.argv[1])
     X, tokenizer = preprocess_text(data)
-    y = data["Label"]
+    y = data["Label"].values  # Convert to NumPy array to avoid KeyError in KFold
 
     # Initialize cross-validation
     k_folds = 5
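The diff only shows the signature of calculate_f1_f2, which presumably computes the F-beta score: F_beta = (1 + beta^2) * P * R / (beta^2 * P + R), where beta=1 weights precision and recall equally (F1) and beta=2 favors recall (F2). One plausible shape for the helper, sketched under the assumption that it returns a single score per call:

def calculate_f1_f2(precision, recall, beta=1):
    """F-beta score: beta=1 gives F1, beta=2 gives the recall-weighted F2."""
    denominator = beta**2 * precision + recall
    if denominator == 0:  # guard against division by zero
        return 0.0
    return (1 + beta**2) * precision * recall / denominator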
@@ -111,7 +121,13 @@ def plot_history(history):
 
         # Split the data
         X_train, X_val = X[train_idx], X[val_idx]
-        y_train, y_val = y.iloc[train_idx], y.iloc[val_idx]
+        y_train, y_val = y[train_idx], y[val_idx]
+
+        # Compute class weights to handle imbalance
+        class_weights = compute_class_weight(
+            "balanced", classes=np.unique(y_train), y=y_train
+        )
+        class_weight_dict = {i: class_weights[i] for i in range(len(class_weights))}
 
         # Build and train the model
         model = build_model(input_dim=len(tokenizer.word_index) + 1)
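For context, compute_class_weight("balanced", ...) weights each class inversely to its frequency: n_samples / (n_classes * count(class)). A toy illustration (separate from the PR):

import numpy as np
from sklearn.utils.class_weight import compute_class_weight

# 6 negatives vs. 2 positives: the minority class gets the larger weight.
y_toy = np.array([0, 0, 0, 0, 0, 0, 1, 1])
weights = compute_class_weight("balanced", classes=np.unique(y_toy), y=y_toy)
print(dict(enumerate(weights)))  # {0: 0.666..., 1: 2.0}

Note that the {i: class_weights[i] ...} comprehension assumes the classes are labeled 0..n-1, which holds for the 0/1 labels implied by the binary cross-entropy loss.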
@@ -121,15 +137,16 @@ def plot_history(history):
         history = model.fit(
             X_train,
             y_train,
-            epochs=50,
-            batch_size=32,
+            epochs=EPOCHS,
+            batch_size=BATCH_SIZE,
             validation_data=(X_val, y_val),
+            class_weight=class_weight_dict,
             callbacks=[early_stopping],
             verbose=1,
         )
 
-        # Make predictions to manually calculate metrics
-        y_val_pred = (model.predict(X_val) > 0.5).astype(int)
+        # Make predictions to calculate metrics
+        y_val_pred = (model.predict(X_val) > 0.8).astype(int)
         accuracy = accuracy_score(y_val, y_val_pred)
         precision = precision_score(y_val, y_val_pred)
         recall = recall_score(y_val, y_val_pred)
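On the 0.5 -> 0.8 threshold change: a higher decision threshold makes positive predictions stricter, which typically raises precision at the cost of recall. Keep in mind that tf.keras.metrics.Precision() and Recall() still evaluate at their default 0.5 threshold, so the Keras-logged metrics and these sklearn scores can diverge. A tiny standalone illustration:

import numpy as np

probs = np.array([0.30, 0.55, 0.70, 0.90])  # sketched model outputs
print((probs > 0.5).astype(int))  # [0 1 1 1]
print((probs > 0.8).astype(int))  # [0 0 0 1]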
@@ -143,12 +160,17 @@ def plot_history(history):
         fold_metrics["f1"].append(f1_score)
         fold_metrics["f2"].append(f2_score)
 
-    # Calculate average metrics across folds
+    # Calculate and display average metrics across folds
     avg_metrics = {metric: np.mean(scores) for metric, scores in fold_metrics.items()}
     print("\nCross-validation results:")
     for metric, value in avg_metrics.items():
         print(f"{metric.capitalize()}: {value:.2f}")
 
     # Save the final model trained on the last fold
-    model.export(sys.argv[2])
+    output_dir = sys.argv[2]
+    if not os.path.exists(output_dir):
+        os.makedirs(output_dir)
+    model.export(output_dir)
+
+    # Plot training history of the last fold
     plot_history(history)
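A last note on model.export(): in recent TF/Keras versions this writes an inference-only SavedModel artifact (suitable for TF Serving), unlike model.save(), which preserves a reloadable Keras model. A hedged usage sketch for consuming the exported directory ("saved_model_dir" is a hypothetical path standing in for sys.argv[2]):

import tensorflow as tf

output_dir = "saved_model_dir"  # hypothetical; sys.argv[2] in the script
reloaded = tf.saved_model.load(output_dir)  # load the SavedModel for inference

Incidentally, os.makedirs(output_dir, exist_ok=True) would collapse the existence check above into a single call.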