Skip to content

Commit

Permalink
brain
Browse files Browse the repository at this point in the history
  • Loading branch information
cmaloney111 committed Apr 1, 2024
1 parent 2227d99 commit 2f386c3
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 23 deletions.
16 changes: 13 additions & 3 deletions api/Neural_Network2.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,17 @@ def __init__(self):
self.batch_count = 0

def on_batch_end(self, batch, logs=None):
print("batch end")
self.batch_count += 1
if self.batch_count % 50 == 0:
self.update_progress(logs)

def on_epoch_end(self, epoch, logs=None):
print("epoch end")
self.update_progress(logs)

def update_progress(self, logs):
print("updating progress")
total_epochs = self.params['epochs']
current_batch = self.model._train_counter
total_batches = self.params['steps'] * total_epochs
Expand Down Expand Up @@ -99,12 +102,20 @@ def create_and_train_model(train_texts, train_labels, name, epochs=5, batch_size
tf.keras.layers.Dense(num_classes, activation='softmax')
])


model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])

try:
progress_callback = TrainingProgressCallback()

history = model.fit(train_dataset, epochs=epochs, batch_size=batch_size)

print("train")
print(train_dataset)
print("epochs")
print(epochs)
print("batch_size")
print(batch_size)
# history = model.fit(train_dataset, epochs=epochs, batch_size=batch_size, verbose=2, callbacks=[progress_callback])
history = model.fit(train_dataset, epochs=epochs, batch_size=batch_size, verbose=2)

model_filename = f"api/models/{str(num_classes)}-Trained-Model-{name}.weights.h5"
model.save_weights(model_filename)
Expand All @@ -119,4 +130,3 @@ def create_and_train_model(train_texts, train_labels, name, epochs=5, batch_size
except Exception as e:
return f"Error during model creation/training: {str(e)}"


10 changes: 8 additions & 2 deletions api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,13 @@
import nltk
import json
import asyncio
import logging
nltk.download('wordnet')


# log = logging.getLogger('werkzeug')
# log.setLevel(logging.ERROR)

app = Flask(__name__)
server_thread = None
CORS(app) # Permite todas as origens por padrão (não recomendado para produção)
Expand Down Expand Up @@ -83,12 +88,13 @@ def train_model():

# reseta status
training_progress = {
'training_progress': 1,
'training_progress': 0,
'training_in_progress': True
}
with open('training_progress.json', 'w') as file:
json.dump(training_progress, file)

print("Beginning training")
create_and_train_model(selected_data, selected_label, name, epochs, batch_size)

return jsonify({"message": "Model train started successfully."}), 200
Expand All @@ -113,6 +119,6 @@ def get_training_status():
#shutdown_server()

if __name__ == '__main__':
app.run(host='0.0.0.0', port=5000)
app.run(host='127.0.0.1', port=5000, debug=True)
#server_thread = threading.Thread(target=run_flask_app)
#server_thread.start()
36 changes: 18 additions & 18 deletions src/pages/train.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ export default function Train() {

const handleSubmit = async () => {
setIsLoading(true);
setLoadingProgress(1);
setLoadingProgress(0);

let selectedData = data.map((row) => ({
value: row[selectedColumn],
Expand All @@ -51,23 +51,23 @@ export default function Train() {
let retryCount = 0;


const url = "http://localhost:5000/neural-network";
// const url = "http://localhost:5000/neural-network";


async function postData(url: string, data: { data: any[]; label: any[]; batch_size: number; epochs: number; learning_rate: number; name: string; }) {
try {
const response = await axios.post(url, data);
} catch (error) {
if (retryCount < maxRetries) {
retryCount++;
console.error(`Error occurred, retrying (attempt ${retryCount})...`);
postData(url, data); // Retry recursively
} else {
console.error("Max retry limit reached. Unable to post data.");
throw error; // Throw the error after maximum retries
}
}
}
// async function postData(url: string, data: { data: any[]; label: any[]; batch_size: number; epochs: number; learning_rate: number; name: string; }) {
// try {
// const response = await axios.post(url, data);
// } catch (error) {
// if (retryCount < maxRetries) {
// retryCount++;
// console.error(`Error occurred, retrying (attempt ${retryCount})...`);
// postData(url, data); // Retry recursively
// } else {
// console.error("Max retry limit reached. Unable to post data.");
// throw error; // Throw the error after maximum retries
// }
// }
// }

await axios
.post("http://localhost:5000/neural-network", sendData)
Expand Down Expand Up @@ -141,7 +141,7 @@ export default function Train() {
setLoadingProgress(
training_in_progress || training_progress === 100
? training_progress
: 1
: 0
);
} catch (error) {
console.error("Erro ao buscar progresso:", error);
Expand Down Expand Up @@ -341,4 +341,4 @@ export default function Train() {
</div>
</div>
);
}
}

0 comments on commit 2f386c3

Please sign in to comment.