Merge pull request #38 from TailUFPB/cancel-button
Cancel button
tahaluh authored May 18, 2024
2 parents 7a9214e + 448c8d0 commit f8d7e20
Showing 3 changed files with 94 additions and 44 deletions.
94 changes: 53 additions & 41 deletions api/Neural_Network2.py
@@ -249,6 +249,13 @@ def train_epoch(model, optimizer, scheduler, train_loader, criterion, curr_epoch
for inputs, target, text in progress_bar:
target = target.to(device)

# Check whether cancellation was requested on every batch
with open('training_progress.json', 'r') as f:
data = json.load(f)
if data.get('cancel_requested', False):
print("Training canceled during epoch:", curr_epoch)
return total_loss / max(total, 1), True  # Return the average loss and the cancellation status

# Clean old gradients
optimizer.zero_grad()

@@ -270,15 +277,17 @@ def train_epoch(model, optimizer, scheduler, train_loader, criterion, curr_epoch
total += len(target)
num_iters += 1
if num_iters % 20 == 0:
with open('training_progress.json', 'w') as f:
progress = 100 * (curr_epoch + num_iters/len(train_loader)) / num_total_epochs
training_progress = {
with open('training_progress.json', 'r+') as f:
progress = 100 * (curr_epoch + num_iters / len(train_loader)) / num_total_epochs
data.update({
'training_progress': progress,
'training_in_progress': True
}
json.dump(training_progress, f)
})
f.seek(0)
json.dump(data, f)
f.truncate()

return total_loss / total
return total_loss / max(total, 1), False

def validate_epoch(model, valid_loader, criterion):
model.eval()
@@ -300,8 +309,9 @@ def validate_epoch(model, valid_loader, criterion):

return total_loss / total

def create_and_train_model(df, name, epochs = 10, batch_size = 32, learning_rate = 0.001):

def create_and_train_model(df, name, epochs=10, batch_size=32, learning_rate=0.001):
# Initial model configuration and setup
dropout_probability = 0.2
n_rnn_layers = 1
embedding_dimension = 128
@@ -312,20 +322,19 @@ def create_and_train_model(df, name, epochs = 10, batch_size = 32, learning_rate
valid_ratio = 0.05
test_ratio = 0.05

# Dataset preparation
dataset = CustomDataset(df, max_vocab, max_len, name)

train_dataset, valid_dataset, test_dataset = split_train_valid_test(
dataset, valid_ratio=valid_ratio, test_ratio=test_ratio)
len(train_dataset), len(valid_dataset), len(test_dataset)


# DataLoader setup
train_loader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=collate)
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, collate_fn=collate)
test_loader = DataLoader(test_dataset, batch_size=batch_size, collate_fn=collate)


# Model initialization
model = RNNClassifier(
output_size=len(df.labels),
output_size=len(df['labels'].unique()),
hidden_size=hidden_size,
embedding_dimension=embedding_dimension,
vocab_size=len(dataset.token2idx),
@@ -334,46 +343,49 @@ def create_and_train_model(df, name, epochs = 10, batch_size = 32, learning_rate
bidirectional=is_bidirectional,
n_layers=n_rnn_layers,
device=device,
batch_size=batch_size,
batch_size=batch_size
)
model = model.to(device)
model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(
filter(lambda p: p.requires_grad, model.parameters()),
lr=learning_rate,
)
scheduler = CosineAnnealingLR(optimizer, 1)
# Loss function and optimizer definition
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 1)

n_epochs = 0
train_losses, valid_losses = [], []
canceled = False
for curr_epoch in range(epochs):
train_loss = train_epoch(model, optimizer, scheduler, train_loader, criterion, curr_epoch, epochs)
valid_loss = validate_epoch(model, valid_loader, criterion)
train_loss, canceled = train_epoch(model, optimizer, scheduler, train_loader, criterion, curr_epoch, epochs)
if canceled:
print(f"Training canceled during epoch {curr_epoch + 1}")
break

valid_loss = validate_epoch(model, valid_loader, criterion)
tqdm.write(
f'epoch #{n_epochs + 1:3d}\ttrain_loss: {train_loss:.2e}'
f'\tvalid_loss: {valid_loss:.2e}\n',
f'Epoch #{curr_epoch + 1:3d}\ttrain_loss: {train_loss:.2e}'
f'\tvalid_loss: {valid_loss:.2e}'
)

# Early stopping if the current valid_loss is greater than the last three valid losses
if len(valid_losses) > 2 and all(valid_loss >= loss
for loss in valid_losses[-3:]):
print('Stopping early')
if len(valid_losses) > 2 and all(valid_loss >= loss for loss in valid_losses[-3:]):
print('Stopping early due to lack of improvement in validation loss.')
break

train_losses.append(train_loss)
valid_losses.append(valid_loss)

n_epochs += 1

model_path = os.path.join('api', 'models', name)
os.makedirs(os.path.dirname(model_path), exist_ok=True)
torch.save(model, model_path)

training_progress = {
'training_progress': 0,
'training_in_progress': True
}
with open('training_progress.json', 'w') as file:
json.dump(training_progress, file)
# Finalize and save the model if training was not canceled
if not canceled:
model_path = os.path.join('api', 'models', name)
os.makedirs(os.path.dirname(model_path), exist_ok=True)
torch.save(model.state_dict(), model_path)

# Update and save the final training state
training_progress = {
'training_progress': 100,
'training_in_progress': False,
'cancel_requested': False
}
with open('training_progress.json', 'w') as file:
json.dump(training_progress, file)

print("Training complete.")
21 changes: 18 additions & 3 deletions api/app.py
@@ -90,7 +90,8 @@ def train_model():
# reset status
training_progress = {
'training_progress': 0,
'training_in_progress': True
'training_in_progress': True,
'cancel_requested': False
}
with open('training_progress.json', 'w') as file:
json.dump(training_progress, file)
@@ -120,9 +121,23 @@ def get_training_status():
return jsonify({'training_in_progress': True, 'training_progress': 0})
training_status = data.get('training_in_progress', False)
progress = data.get('training_progress', 0)
return jsonify({'training_in_progress': training_status, 'training_progress': progress})
cancel_request = data.get('cancel_requested', False)
return jsonify({'training_in_progress': training_status, 'training_progress': progress, 'cancel_requested': cancel_request})
except FileNotFoundError:
return jsonify({'training_in_progress': False, 'training_progress': 0})
return jsonify({'training_in_progress': False, 'training_progress': 0, 'cancel_requested': False})

@app.route('/cancel-training', methods=['POST'])
def cancel_training():
try:
with open('training_progress.json', 'r+') as file:
data = json.load(file)
data['cancel_requested'] = True
file.seek(0)
json.dump(data, file)
file.truncate()
return jsonify({'message': 'Cancellation requested.'}), 200
except Exception as e:
return jsonify({'error': str(e)}), 500


if __name__ == '__main__':
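Note: with the new route above, cancellation can be requested from any HTTP client, not just the frontend. A minimal usage sketch with the requests library, assuming the API is served at http://localhost:5000 (the address used by the frontend's axios call in src/pages/train.tsx):

import requests

BASE_URL = 'http://localhost:5000'  # assumed development address

# Ask the running training job to stop; train_epoch picks the flag up at the next batch.
resp = requests.post(f'{BASE_URL}/cancel-training')
print(resp.status_code, resp.json())  # expected: 200 {'message': 'Cancellation requested.'}
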
23 changes: 23 additions & 0 deletions src/pages/train.tsx
@@ -14,6 +14,7 @@ export default function Train() {
const [selectedLabel, setSelectedLabel] = useState<number>(0);

const [isLoading, setIsLoading] = useState(false);
const [isCancelling, setIsCancelling] = useState(false);

const handleChangeSelectedColumn = (event: any) => {
setSelectedColumn(event.target.value);
@@ -23,6 +24,18 @@
setSelectedLabel(event.target.value);
};

const handleCancelTraining = async () => {
setIsCancelling(true); // Mark cancellation as in progress
try {
await axios.post('http://localhost:5000/cancel-training');
alert('Treinamento cancelado com sucesso!');
} catch (error) {
console.error('Erro ao cancelar o treinamento:', error);
alert('Falha ao cancelar o treinamento.');
}
setIsCancelling(false); // Clear the cancelling state
};

const handleSubmit = async () => {
setIsLoading(true);
setLoadingProgress(0);
@@ -331,6 +344,16 @@
>
{isLoading ? "Carregando..." : "Treinar"}
</button>

{isLoading && (
<button
className="mt-3 bg-red-500 hover:bg-red-700 text-white font-bold py-2 px-4 rounded-lg"
onClick={handleCancelTraining}
disabled={isCancelling}
>
{isCancelling ? 'Cancelando...' : 'Cancelar Treinamento'}
</button>
)}
</div>
</div>
</>
