From a4cf82bda02f2e16ff65aff01d3f1fe71236b70c Mon Sep 17 00:00:00 2001 From: jonasgabriel18 Date: Fri, 7 Jun 2024 11:54:16 -0300 Subject: [PATCH] fixed loadind model as ordered dict --- api/Neural_Network2.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/api/Neural_Network2.py b/api/Neural_Network2.py index 80440fe8..caa027a1 100644 --- a/api/Neural_Network2.py +++ b/api/Neural_Network2.py @@ -18,6 +18,7 @@ from torch.utils.data import Dataset, DataLoader from torch.utils.data.dataset import random_split from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence +from torchtext.vocab import build_vocab_from_iterator tqdm.pandas() @@ -332,7 +333,7 @@ def create_and_train_rnn_model(df, name, epochs = 10, batch_size = 32, learning_ # Preparação dos DataLoader train_loader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=collate) valid_loader = DataLoader(valid_dataset, batch_size=batch_size, collate_fn=collate) - test_loader = DataLoader(test_dataset, batch_size=batch_size, collate_fn=collate) + #test_loader = DataLoader(test_dataset, batch_size=batch_size, collate_fn=collate) # Inicialização do modelo model = RNNClassifier( @@ -391,7 +392,8 @@ def create_and_train_rnn_model(df, name, epochs = 10, batch_size = 32, learning_ if not canceled: model_path = os.path.join('api', 'models', name) os.makedirs(os.path.dirname(model_path), exist_ok=True) - torch.save(model.state_dict(), model_path) + #torch.save(model.state_dict(), model_path) + torch.save(model, model_path) # Atualizar e salvar o estado de treinamento final training_progress = {