Updated dependencies, Stochastic Weight Averaging
fgdfgfthgr-fox committed Nov 13, 2024
1 parent 395d3dd commit 0421c24
Showing 4 changed files with 32 additions and 26 deletions.
29 changes: 14 additions & 15 deletions pl_module.py
@@ -5,7 +5,6 @@
 import torch.utils.data
 import time
 import torch.utils.tensorboard
-from sympy.core.evalf import fastlog
 
 from Components import DataComponents
 from Components import Metrics
@@ -309,24 +308,24 @@ def on_test_epoch_end(self):
 
 if __name__ == "__main__":
 
-val_dataset = DataComponents.ValDataset("Datasets/val", 256, 64, False, "Augmentation Parameters.csv")
-predict_dataset = DataComponents.Predict_Dataset("Datasets/predict", 232, 24, 12, 4, True)
+#val_dataset = DataComponents.ValDataset("Datasets/val", 96, 96, False, "Augmentation Parameters.csv")
+predict_dataset = DataComponents.Predict_Dataset("Datasets/predict", 160, 56, 16, 4, True)
 train_dataset_pos = DataComponents.TrainDataset("Datasets/train", "Augmentation Parameters.csv",
-64,
-256, 64, True, False, 0,
+32,
+192, 64, False, False, 0,
 0,
 1, 'positive')
 train_dataset_neg = DataComponents.TrainDataset("Datasets/train", "Augmentation Parameters.csv",
-64,
-256, 64, True, False, 0,
+32,
+192, 64, False, False, 0,
 0,
 1, 'negative')
 train_label_mean = train_dataset_pos.get_label_mean()
-train_contour_mean = train_dataset_pos.get_contour_mean()
+#train_contour_mean = train_dataset_pos.get_contour_mean()
 unsupervised_train_dataset = DataComponents.UnsupervisedDataset("Datasets/unsupervised_train",
 "Augmentation Parameters.csv",
-128,
-256, 64)
+64,
+192, 64)
 train_dataset = DataComponents.CollectedDataset(train_dataset_pos, train_dataset_neg, unsupervised_train_dataset)
 #train_dataset = DataComponents.CollectedDataset(train_dataset_pos, train_dataset_neg)
 sampler = DataComponents.CollectedSampler(train_dataset, 2, unsupervised_train_dataset)
@@ -337,22 +336,22 @@ def on_test_epoch_end(self):
 num_workers=8, pin_memory=True, persistent_workers=True)
 meta_info = predict_dataset.__getmetainfo__()
 predict_loader = torch.utils.data.DataLoader(dataset=predict_dataset, batch_size=1, num_workers=0)
-val_loader = torch.utils.data.DataLoader(dataset=val_dataset, batch_size=1)
+#val_loader = torch.utils.data.DataLoader(dataset=val_dataset, batch_size=1)
 #model_checkpoint = pl.callbacks.ModelCheckpoint(dirpath="", filename="{epoch}-{Val_epoch_dice:.2f}", mode="max", save_weights_only=True)
-arch_args = ('InstanceResidual', 8, 4, 1, True, train_label_mean, train_contour_mean)
+arch_args = ('UNetResidualBottleneck', 32, 4, 1, True, train_label_mean, torch.tensor(0.5))
 model = PLModule(arch_args,
-True, True, 'Datasets/mid_visualiser/Ts-4c_visualiser.tif', True,
+False, False, 'Datasets/mid_visualiser/Ts-4c_visualiser.tif', False,
 False, False, False, True)
 model_checkpoint = pl.callbacks.ModelCheckpoint(dirpath="", filename="test", mode="max",
 monitor="Val_epoch_dice", save_weights_only=True, enable_version_counter=False)
 trainer = pl.Trainer(max_epochs=5, log_every_n_steps=1, logger=TensorBoardLogger(f'lightning_logs', name='test'),
 accelerator="gpu", enable_checkpointing=True, gradient_clip_val=0.3,
-precision="bf16-mixed", enable_progress_bar=True, num_sanity_val_steps=0, callbacks=[model_checkpoint,])
+precision="bf16-mixed", enable_progress_bar=True, num_sanity_val_steps=0)#, callbacks=[model_checkpoint,])
 #FineTuneLearningRateFinder(min_lr=0.00001, max_lr=0.1, attr_name='initial_lr')])
 # print(subprocess.run("tensorboard --logdir='lightning_logs'", shell=True))
 start_time = time.time()
 trainer.fit(model,
-val_dataloaders=val_loader,
+#val_dataloaders=val_loader,
 train_dataloaders=train_loader)
 model = PLModule.load_from_checkpoint('test.ckpt')
 '''predictions = trainer.predict(model, predict_loader)
5 changes: 3 additions & 2 deletions requirements.txt
@@ -6,10 +6,11 @@ pyimagej==1.4.1
 scikit-image==0.21.0
 matplotlib==3.8.0
 lightning==2.3.3
-gradio==5.0.2
+gradio==5.5.0
 h5py==3.12.1
 scipy==1.11.2
 tensorboard==2.15.1
 joblib==1.3.2
 opencv-python==4.10.0.82
-imagecodecs==2024.1.1
+imagecodecs==2024.1.1
+overrides==7.7.0
10 changes: 5 additions & 5 deletions test_scripts.py
@@ -38,7 +38,7 @@ def start_tensorboard():
 2, 'negative')
 loader_for_lr = torch.utils.data.DataLoader(dataset=train_dataset_pos, batch_size=2, num_workers=16, pin_memory=True, persistent_workers=True)
 train_label_mean = train_dataset_pos.get_label_mean()
-train_contour_mean = train_dataset_pos.get_contour_mean()
+#train_contour_mean = train_dataset_pos.get_contour_mean()
 unsupervised_train_dataset = DataComponents.UnsupervisedDataset("Datasets/unsupervised_train",
 "Augmentation Parameters.csv",
 64,
@@ -53,12 +53,12 @@ def start_tensorboard():
 model_checkpoint = pl.callbacks.ModelCheckpoint(dirpath="", filename="test", mode="max",
 monitor="Val_epoch_dice", save_weights_only=True,
 enable_version_counter=False)
-arch_args = ('InstanceResidual', 8, 4, 1, True, train_label_mean, train_contour_mean)
+arch_args = ('InstanceResidual', 8, 4, 1, True, train_label_mean, 0.5)
 
-def train_model(weight):
+def train_model():
 model = PLModule(arch_args,
 True, True, 'Datasets/mid_visualiser/Ts-4c_visualiser.tif', False,
-False, False, False, True, weight)
+False, False, False, True)
 trainer = pl.Trainer(max_epochs=100, log_every_n_steps=1, logger=TensorBoardLogger(f'lightning_logs', name='test'),
 accelerator="gpu", enable_checkpointing=True,
 precision='bf16-mixed', enable_progress_bar=True, num_sanity_val_steps=0, callbacks=[model_checkpoint, LearningRateMonitor(logging_interval='epoch')])
@@ -72,5 +72,5 @@ def train_model(weight):
 train_dataloaders=train_loader)
 #i = 0.1
 #while i <= 1.5:
-train_model(0.1)
+train_model()
 # i += 0.1
14 changes: 10 additions & 4 deletions workflow.py
@@ -17,6 +17,7 @@
 from lightning.pytorch.loggers.tensorboard import TensorBoardLogger
 from lightning.pytorch.callbacks import LearningRateMonitor
 from lightning.pytorch.callbacks import LearningRateFinder
+from lightning.pytorch.callbacks import StochasticWeightAveraging
 from lightning.pytorch.tuner import Tuner
 
 
@@ -35,7 +36,6 @@ def on_train_epoch_start(self, trainer, pl_module):
 self.lr_find(trainer, pl_module)
 
-
 
 def create_logger(args):
 logger = TensorBoardLogger(f'{args.tensorboard_path}', name='Run')
 return logger
@@ -199,13 +199,19 @@ def find_max_channel(min_channel, max_channel):
 else:
 to_monitor = 'Train_epoch_dice'
 callbacks = []
-model_checkpoint = pl.callbacks.ModelCheckpoint(dirpath=f"{args.save_model_path}",
+'''model_checkpoint = pl.callbacks.ModelCheckpoint(dirpath=f"{args.save_model_path}",
 filename=f"{args.save_model_name}",
 mode="max", monitor=to_monitor,
-save_weights_only=True, enable_version_counter=False)
+save_weights_only=True, enable_version_counter=False)'''
+model_checkpoint_last = pl.callbacks.ModelCheckpoint(dirpath=f"{args.save_model_path}",
+filename=f"{args.save_model_name}",
+save_weights_only=True, enable_version_counter=False)
+swa_callback = StochasticWeightAveraging(5e-5, 0.8, int(0.2*args.num_epochs-1))
+print(f'SWA starts at {int(0.8*args.num_epochs)}\n')
 if logger:
 callbacks.append(LearningRateMonitor(logging_interval='epoch'))
-callbacks.append(model_checkpoint)
+callbacks.append(model_checkpoint_last)
+callbacks.append(swa_callback)
 trainer = pl.Trainer(max_epochs=args.num_epochs, log_every_n_steps=1, logger=logger,
 accelerator="gpu", enable_checkpointing=True,
 precision=args.precision, enable_progress_bar=True, num_sanity_val_steps=0,
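
For context, the new callback uses Lightning's built-in Stochastic Weight Averaging support: passed positionally, the arguments are `swa_lrs=5e-5`, `swa_epoch_start=0.8` (averaging begins at epoch `int(0.8 * max_epochs)`, matching the print statement above), and `annealing_epochs` sized to roughly the remaining 20% of training. A minimal standalone sketch of the same setup follows; `num_epochs` and the commented-out `fit` call are illustrative stand-ins, not code from this repository:

```python
import lightning.pytorch as pl
from lightning.pytorch.callbacks import StochasticWeightAveraging

# Hypothetical epoch budget for illustration; the commit uses args.num_epochs.
num_epochs = 100

swa = StochasticWeightAveraging(
    swa_lrs=5e-5,                             # learning rate held during averaging
    swa_epoch_start=0.8,                      # start at epoch int(0.8 * num_epochs) = 80
    annealing_epochs=int(0.2 * num_epochs - 1),  # anneal over the remaining epochs
)

trainer = pl.Trainer(max_epochs=num_epochs, callbacks=[swa])
# trainer.fit(model, train_dataloaders=train_loader)
```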
