
Commit 95f2f41

Update model to detect class imbalance and retrain it
1 parent 226f4b2 commit 95f2f41

File tree

5 files changed: +47 -25 lines changed


sqli_model/3/fingerprint.pb

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
(binary fingerprint content changed; diff not human-readable)

sqli_model/3/saved_model.pb

0 Bytes
Binary file not shown.

training/train_v3.py

Lines changed: 46 additions & 24 deletions
@@ -1,4 +1,5 @@
 import sys
+import os
 import pandas as pd
 import tensorflow as tf
 from tensorflow.keras.preprocessing.text import Tokenizer
@@ -17,6 +18,7 @@
 from tensorflow.keras.callbacks import EarlyStopping
 from sklearn.model_selection import KFold
 from sklearn.metrics import accuracy_score, precision_score, recall_score
+from sklearn.utils.class_weight import compute_class_weight
 import numpy as np
 import matplotlib.pyplot as plt

@@ -54,11 +56,7 @@ def build_model(input_dim, output_dim=128):
     model.compile(
         loss="binary_crossentropy",
         optimizer="adam",
-        metrics=[
-            "accuracy",
-            tf.keras.metrics.Precision(name="precision"),
-            tf.keras.metrics.Recall(name="recall"),
-        ],
+        metrics=["accuracy", tf.keras.metrics.Precision(), tf.keras.metrics.Recall()],
     )
     return model

@@ -75,31 +73,43 @@ def calculate_f1_f2(precision, recall, beta=1):

 def plot_history(history):
     """Plot the training and validation loss, accuracy, precision, and recall."""
+    available_metrics = history.history.keys()  # Check which metrics are available
     plt.figure(figsize=(12, 8))
-    for i, metric in enumerate(["loss", "accuracy", "precision", "recall"], start=1):
-        plt.subplot(2, 2, i)
-        plt.plot(history.history[metric], label=f"Training {metric.capitalize()}")
-        plt.plot(
-            history.history[f"val_{metric}"], label=f"Validation {metric.capitalize()}"
-        )
-        plt.title(metric.capitalize())
-        plt.xlabel("Epochs")
-        plt.ylabel(metric.capitalize())
-        plt.legend()
+
+    # Define metrics to plot
+    metrics_to_plot = ["loss", "accuracy", "precision", "recall"]
+    for i, metric in enumerate(metrics_to_plot, start=1):
+        if metric in available_metrics:
+            plt.subplot(2, 2, i)
+            plt.plot(history.history[metric], label=f"Training {metric.capitalize()}")
+            plt.plot(
+                history.history[f"val_{metric}"],
+                label=f"Validation {metric.capitalize()}",
+            )
+            plt.title(metric.capitalize())
+            plt.xlabel("Epochs")
+            plt.ylabel(metric.capitalize())
+            plt.legend()
+
     plt.tight_layout()
     plt.savefig("training_history.png")


-# Main function
 if __name__ == "__main__":
     if len(sys.argv) != 3:
         print("Usage: python train.py <input_file> <output_dir>")
         sys.exit(1)

+    # Constants
+    MAX_WORDS = 10000
+    MAX_LEN = 100
+    EPOCHS = 50
+    BATCH_SIZE = 32
+
     # Load and preprocess data
     data = load_data(sys.argv[1])
     X, tokenizer = preprocess_text(data)
-    y = data["Label"]
+    y = data["Label"].values  # Convert to NumPy array to avoid KeyError in KFold

     # Initialize cross-validation
     k_folds = 5
@@ -111,7 +121,13 @@ def plot_history(history):

         # Split the data
         X_train, X_val = X[train_idx], X[val_idx]
-        y_train, y_val = y.iloc[train_idx], y.iloc[val_idx]
+        y_train, y_val = y[train_idx], y[val_idx]
+
+        # Compute class weights to handle imbalance
+        class_weights = compute_class_weight(
+            "balanced", classes=np.unique(y_train), y=y_train
+        )
+        class_weight_dict = {i: class_weights[i] for i in range(len(class_weights))}

         # Build and train the model
         model = build_model(input_dim=len(tokenizer.word_index) + 1)
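
For context on the weighting step in this hunk, here is a minimal standalone sketch (not part of the diff, using made-up labels) of what compute_class_weight("balanced", ...) returns and how the class_weight_dict is assembled:

import numpy as np
from sklearn.utils.class_weight import compute_class_weight

# Toy, heavily imbalanced labels: 8 negatives, 2 positives (illustrative only).
y_train = np.array([0, 0, 0, 0, 0, 0, 0, 0, 1, 1])

# "balanced" assigns each class a weight of n_samples / (n_classes * count(class)),
# so the minority class receives the larger weight.
class_weights = compute_class_weight("balanced", classes=np.unique(y_train), y=y_train)
class_weight_dict = {i: class_weights[i] for i in range(len(class_weights))}

print(class_weight_dict)  # expected weights: roughly {0: 0.625, 1: 2.5}

Passing this dict to Keras model.fit(class_weight=...) scales each sample's contribution to the loss by its class weight, which is how the retrained model compensates for the imbalance.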
@@ -121,15 +137,16 @@ def plot_history(history):
         history = model.fit(
             X_train,
             y_train,
-            epochs=50,
-            batch_size=32,
+            epochs=EPOCHS,
+            batch_size=BATCH_SIZE,
             validation_data=(X_val, y_val),
+            class_weight=class_weight_dict,
             callbacks=[early_stopping],
             verbose=1,
         )

-        # Make predictions to manually calculate metrics
-        y_val_pred = (model.predict(X_val) > 0.5).astype(int)
+        # Make predictions to calculate metrics
+        y_val_pred = (model.predict(X_val) > 0.8).astype(int)
         accuracy = accuracy_score(y_val, y_val_pred)
         precision = precision_score(y_val, y_val_pred)
         recall = recall_score(y_val, y_val_pred)
@@ -143,12 +160,17 @@ def plot_history(history):
         fold_metrics["f1"].append(f1_score)
         fold_metrics["f2"].append(f2_score)

-    # Calculate average metrics across folds
+    # Calculate and display average metrics across folds
     avg_metrics = {metric: np.mean(scores) for metric, scores in fold_metrics.items()}
     print("\nCross-validation results:")
     for metric, value in avg_metrics.items():
         print(f"{metric.capitalize()}: {value:.2f}")

     # Save the final model trained on the last fold
-    model.export(sys.argv[2])
+    output_dir = sys.argv[2]
+    if not os.path.exists(output_dir):
+        os.makedirs(output_dir)
+    model.export(output_dir)
+
+    # Plot training history of the last fold
     plot_history(history)
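
As a side note on the prediction cut-off changing from 0.5 to 0.8 in this commit, a short self-contained sketch (made-up probabilities, not from this dataset) of how raising the threshold typically trades recall for precision:

import numpy as np
from sklearn.metrics import precision_score, recall_score

# Made-up validation labels and predicted probabilities, for illustration only.
y_true = np.array([0, 0, 0, 0, 1, 1, 1, 1])
y_prob = np.array([0.10, 0.55, 0.70, 0.20, 0.60, 0.85, 0.90, 0.40])

for threshold in (0.5, 0.8):
    y_pred = (y_prob > threshold).astype(int)
    print(
        f"threshold={threshold}: "
        f"precision={precision_score(y_true, y_pred):.2f}, "
        f"recall={recall_score(y_true, y_pred):.2f}"
    )

# threshold=0.5: precision=0.60, recall=0.75
# threshold=0.8: precision=1.00, recall=0.50

Whether 0.8 is the right operating point for this model depends on the validation data; the committed value reflects the author's choice.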

0 commit comments
