Skip to content

Commit

Permalink
fixed loadind model as ordered dict
Browse files Browse the repository at this point in the history
  • Loading branch information
jonasgabriel18 authored and cmaloney111 committed Jun 8, 2024
1 parent faeba09 commit a4cf82b
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions api/Neural_Network2.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.dataset import random_split
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from torchtext.vocab import build_vocab_from_iterator

tqdm.pandas()

Expand Down Expand Up @@ -332,7 +333,7 @@ def create_and_train_rnn_model(df, name, epochs = 10, batch_size = 32, learning_
# Preparação dos DataLoader
train_loader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=collate)
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, collate_fn=collate)
test_loader = DataLoader(test_dataset, batch_size=batch_size, collate_fn=collate)
#test_loader = DataLoader(test_dataset, batch_size=batch_size, collate_fn=collate)

# Inicialização do modelo
model = RNNClassifier(
Expand Down Expand Up @@ -391,7 +392,8 @@ def create_and_train_rnn_model(df, name, epochs = 10, batch_size = 32, learning_
if not canceled:
model_path = os.path.join('api', 'models', name)
os.makedirs(os.path.dirname(model_path), exist_ok=True)
torch.save(model.state_dict(), model_path)
#torch.save(model.state_dict(), model_path)
torch.save(model, model_path)

# Atualizar e salvar o estado de treinamento final
training_progress = {
Expand Down

0 comments on commit a4cf82b

Please sign in to comment.