Skip to content

Commit

Permalink
Fixes various bugs relating to Windows installation and WebUI, multith…
Browse files Browse the repository at this point in the history
…readed data loading, update gradio
  • Loading branch information
fgdfgfthgr-fox committed Sep 23, 2024
1 parent ce39835 commit e7703d5
Show file tree
Hide file tree
Showing 8 changed files with 202 additions and 173 deletions.
8 changes: 4 additions & 4 deletions Augmentation Parameters.csv
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
Augmentation,Probability,Low Bound,High Bound,Value
Rotate xy,0.25,0,360,
Rotate xz,0.0,0,360,
Rotate yz,0.0,0,360,
Rotate xz,0,0,360,
Rotate yz,0,0,360,
Rescaling,0.25,0.75,1.25,
Edge Replicate Pad,0.0,0,0,0.075
Edge Replicate Pad,0,0,0,0.075
Vertical Flip,0.5,,,
Horizontal Flip,0.5,,,
Depth Flip,0.5,,,
Expand All @@ -15,5 +15,5 @@ Adjust Contrast,0.75,0.5,1.5,
Adjust Gamma,0.75,0.5,1.5,
Adjust Brightness,0.75,0.75,1.4,
Salt And Pepper,0.25,0.005,0.01,
Label Blur,1,1,1,5
Label Blur,0,1,1,5
Contour Blur,0,0.5,0.5,7
64 changes: 32 additions & 32 deletions Components/DataComponents.py
Original file line number Diff line number Diff line change
Expand Up @@ -627,38 +627,6 @@ def __getitem__(self, idx):
return self.dataset_negative[idx]


class CollectedDatasetAlt(torch.utils.data.Dataset):
    """
    For handling unsupervised learning without dataset pairing.

    Concatenates an unsupervised dataset (truncated down to an even length)
    in front of a supervised dataset: indices in
    ``[0, rectified_unsupervised_size)`` map into ``dataset_unsupervised``,
    the remaining indices map into ``dataset``.
    """
    def __init__(self, dataset, dataset_unsupervised):
        self.dataset = dataset
        self.dataset_unsupervised = dataset_unsupervised
        # Round down to an even count with pure integer arithmetic; the
        # original math.floor(len(...) / 2) * 2 went through float division,
        # which is lossy for very large sizes and needlessly imports math.
        self.rectified_unsupervised_size = len(self.dataset_unsupervised) // 2 * 2

    def __len__(self):
        return self.rectified_unsupervised_size + len(self.dataset)

    def __getitem__(self, idx):
        if idx < self.rectified_unsupervised_size:
            return self.dataset_unsupervised[idx]
        # Shift the index into the supervised dataset's range.
        return self.dataset[idx - self.rectified_unsupervised_size]


class CollectedSamplerAlt(torch.utils.data.Sampler):
    """Sampler yielding every index of ``data_source`` in a fresh random
    order on each iteration pass."""

    def __init__(self, data_source):
        super().__init__(data_source)
        self.data_source = data_source

    def __iter__(self):
        # A new permutation of all indices is drawn each time iteration starts.
        shuffled = np.random.permutation(len(self.data_source))
        return iter(shuffled)

    def __len__(self):
        return len(self.data_source)


class CollectedSampler(torch.utils.data.Sampler):
def __init__(self, data_source, batch_size, dataset_unsupervised=None):
super(CollectedSampler, self).__init__(data_source)
Expand Down Expand Up @@ -715,6 +683,38 @@ def __len__(self):
return len(self.data_source)


class CollectedDatasetAlt(torch.utils.data.Dataset):
    """
    For handling unsupervised learning without dataset pairing.

    Concatenates an unsupervised dataset (truncated down to an even length)
    in front of a supervised dataset: indices in
    ``[0, rectified_unsupervised_size)`` map into ``dataset_unsupervised``,
    the remaining indices map into ``dataset``.
    """
    def __init__(self, dataset, dataset_unsupervised):
        self.dataset = dataset
        self.dataset_unsupervised = dataset_unsupervised
        # Round down to an even count with pure integer arithmetic; the
        # original math.floor(len(...) / 2) * 2 went through float division,
        # which is lossy for very large sizes and needlessly imports math.
        self.rectified_unsupervised_size = len(self.dataset_unsupervised) // 2 * 2

    def __len__(self):
        return self.rectified_unsupervised_size + len(self.dataset)

    def __getitem__(self, idx):
        if idx < self.rectified_unsupervised_size:
            return self.dataset_unsupervised[idx]
        # Shift the index into the supervised dataset's range.
        return self.dataset[idx - self.rectified_unsupervised_size]


class CollectedSamplerAlt(torch.utils.data.Sampler):
    """Sampler yielding every index of ``data_source`` in a fresh random
    order on each iteration pass."""

    def __init__(self, data_source):
        super().__init__(data_source)
        self.data_source = data_source

    def __iter__(self):
        # A new permutation of all indices is drawn each time iteration starts.
        shuffled = np.random.permutation(len(self.data_source))
        return iter(shuffled)

    def __len__(self):
        return len(self.data_source)


def custom_collate(batch):
if len(batch) == 1:
return batch[0].unsqueeze(0)
Expand Down
4 changes: 2 additions & 2 deletions Components/Metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,8 +226,8 @@ def forward(self, inputs: torch.Tensor, targets: torch.Tensor, sparse_label=Fals
tp, fn, tn, fp = self.calculate_other_metrices(inputs, targets)
return F_loss.mean(), intersection, union, tp, fn, tn, fp
elif self.loss_mode == "bce_no_dice":
# Scale down to 10% since it's used for unsupervised learning and is often much higher than supervised
BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none') * 0.1
# Scale down to 20% since it's used for unsupervised learning and is often much higher than supervised
BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none') * 0.2
return BCE_loss.mean(), torch.nan, torch.nan, torch.nan, torch.nan, torch.nan, torch.nan
elif self.loss_mode == "dice":
inputs = torch.sigmoid(inputs)
Expand Down
182 changes: 94 additions & 88 deletions WebUI.py

Large diffs are not rendered by default.

19 changes: 16 additions & 3 deletions install_dependencies_Windows.bat
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,24 @@ set ERROR_REPORTING=FALSE
:: Do not reinstall existing pip packages on Windows
set PIP_IGNORE_INSTALLED=0

:: Check GPU info
:: Check GPU info, if neither AMD or NVIDIA gpu found, install CPU only version of PyTorch
set gpu_info=
for /f "tokens=*" %%i in ('"wmic path win32_videocontroller get caption"') do set gpu_info=!gpu_info! %%i
if not "!gpu_info:~-8!"=="NVIDIA" (
if not "!gpu_info:~-3!"=="AMD" if not defined TORCH_COMMAND set TORCH_COMMAND=pip install torch torchvision

echo !gpu_info! | findstr /i "NVIDIA" >nul
if %ERRORLEVEL% neq 0 (
echo !gpu_info! | findstr /i "AMD" >nul
if %ERRORLEVEL% equ 0 (
echo.
echo Warning: AMD GPU on Windows is not yet supported by PyTorch and will resort to CPU-only installation.
echo.
if not defined TORCH_COMMAND set TORCH_COMMAND=pip install torch torchvision
) else (
echo.
rem Neither NVIDIA nor AMD GPUs are found, proceed with CPU-only installation
echo.
if not defined TORCH_COMMAND set TORCH_COMMAND=pip install torch torchvision
)
)


Expand Down
47 changes: 23 additions & 24 deletions pl_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def __init__(self, network_arch, enable_val, enable_mid_visual, mid_visual_image
self.use_sparse_label_test = use_sparse_label_test
self.logging = logging
self.train_metrics, self.val_metrics, self.test_metrics = [], [], []
self.initial_lr = 0.0005
self.lr = 1 # Not the actual LR since it's automatically computed, but the ratio of lr as ReduceLROnPlateau works.
self.p_loss_fn = Metrics.BinaryMetrics("focal")
self.u_p_loss_fn = Metrics.BinaryMetrics("bce_no_dice")
self.c_loss_fn = Metrics.BinaryMetrics("dice+bce")
Expand Down Expand Up @@ -113,7 +113,7 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0):

def configure_optimizers(self):
#fused = True if device == "cuda" else False
optimizer = AdEMAMix(self.parameters(), lr=self.initial_lr, weight_decay=0.0001)
optimizer = AdEMAMix(self.parameters(), lr=self.lr, weight_decay=0.0001)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',
factor=0.5, patience=15,
threshold_mode='rel',
Expand Down Expand Up @@ -202,31 +202,30 @@ def on_before_optimizer_step(self, optimizer):


if __name__ == "__main__":
torch.set_num_threads(12)
train_label_mean = DataComponents.path_to_tensor("Datasets/train/Labels_0.tif", label=True).to(torch.float32).mean()
model = PLModule(Semantic_General.UNet(base_channels=16, z_to_xy_ratio=1, depth=4, type='Residual', se=True, unsupervised=False, label_mean=train_label_mean),
train_label_mean = DataComponents.path_to_tensor("Datasets/train/Labels_image.tif", label=True).to(torch.float32).mean()
model = PLModule(Semantic_General.UNet(base_channels=16, z_to_xy_ratio=1, depth=4, type='Residual', se=True, unsupervised=True, label_mean=train_label_mean),
True, False, 'Datasets/mid_visualiser/Ts-4c_ref_patch.tif', False,
False, False, False, True)
val_dataset = DataComponents.ValDataset("Datasets/val", 128, 128, False, "Augmentation Parameters.csv")
predict_dataset = DataComponents.Predict_Dataset("Datasets/predict", 112, 112, 8, 8, True)
val_dataset = DataComponents.ValDataset("Datasets/val", 256, 32, False, "Augmentation Parameters.csv")
predict_dataset = DataComponents.Predict_Dataset("Datasets/predict", 232, 24, 12, 4, True)
train_dataset_pos = DataComponents.TrainDataset("Datasets/train", "Augmentation Parameters.csv",
64,
128, 128, False, False, 0,
256, 32, False, False, 0,
0,
0, 'positive')
train_dataset_neg = DataComponents.TrainDataset("Datasets/train", "Augmentation Parameters.csv",
64,
128, 128, False, False, 0,
256, 32, False, False, 0,
0,
0, 'negative')
#unsupervised_train_dataset = DataComponents.UnsupervisedDataset("Datasets/unsupervised_train",
# "Augmentation Parameters.csv",
# 64,
# 128, 128)
#train_dataset = DataComponents.CollectedDataset(train_dataset_pos, train_dataset_neg, unsupervised_train_dataset)
train_dataset = DataComponents.CollectedDataset(train_dataset_pos, train_dataset_neg)
#sampler = DataComponents.CollectedSampler(train_dataset, 2, unsupervised_train_dataset)
sampler = DataComponents.CollectedSampler(train_dataset, 2)
unsupervised_train_dataset = DataComponents.UnsupervisedDataset("Datasets/unsupervised_train",
"Augmentation Parameters.csv",
128,
256, 32)
train_dataset = DataComponents.CollectedDataset(train_dataset_pos, train_dataset_neg, unsupervised_train_dataset)
#train_dataset = DataComponents.CollectedDataset(train_dataset_pos, train_dataset_neg)
sampler = DataComponents.CollectedSampler(train_dataset, 2, unsupervised_train_dataset)
#sampler = DataComponents.CollectedSampler(train_dataset, 2)
collate_fn = DataComponents.custom_collate
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=2,
collate_fn=collate_fn, sampler=sampler,
Expand All @@ -235,9 +234,9 @@ def on_before_optimizer_step(self, optimizer):
predict_loader = torch.utils.data.DataLoader(dataset=predict_dataset, batch_size=1, num_workers=0)
val_loader = torch.utils.data.DataLoader(dataset=val_dataset, batch_size=1)
#model_checkpoint = pl.callbacks.ModelCheckpoint(dirpath="", filename="{epoch}-{Val_epoch_dice:.2f}", mode="max", save_weights_only=True)
model_checkpoint = pl.callbacks.ModelCheckpoint(dirpath="", filename="16-true", mode="max",
model_checkpoint = pl.callbacks.ModelCheckpoint(dirpath="", filename="test", mode="max",
monitor="Val_epoch_dice", save_weights_only=True, enable_version_counter=False)
trainer = pl.Trainer(max_epochs=5, log_every_n_steps=1, logger=TensorBoardLogger(f'lightning_logs', name='16-true'),
trainer = pl.Trainer(max_epochs=5, log_every_n_steps=1, logger=TensorBoardLogger(f'lightning_logs', name='test'),
accelerator="gpu", enable_checkpointing=True, gradient_clip_val=0.3,
precision="32", enable_progress_bar=True, num_sanity_val_steps=0, callbacks=[model_checkpoint,
FineTuneLearningRateFinder(min_lr=0.00001, max_lr=0.1, attr_name='initial_lr')])
Expand All @@ -246,13 +245,13 @@ def on_before_optimizer_step(self, optimizer):
trainer.fit(model,
val_dataloaders=val_loader,
train_dataloaders=train_loader)
model = PLModule.load_from_checkpoint('16-true.ckpt')
model = PLModule.load_from_checkpoint('test.ckpt')
predictions = trainer.predict(model, predict_loader)
#del predict_loader, predict_dataset
DataComponents.predictions_to_final_img(predictions, meta_info, direc='Datasets/result',
hw_size=112, depth_size=112,
hw_overlap=8,
depth_overlap=8,
hw_size=232, depth_size=24,
hw_overlap=12,
depth_overlap=4,
TTA_hw=True)
#end_time = time.time()
#elapsed_time = end_time - start_time
Expand All @@ -274,4 +273,4 @@ def on_before_optimizer_step(self, optimizer):
DataComponents.predictions_to_final_img_instance(predictions, meta_info, direc='Datasets/result',
hw_size=256, depth_size=64,
hw_overlap=32, depth_overlap=8)
'''
'''
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ pyimagej==1.4.1
scikit-image==0.21.0
matplotlib==3.8.0
lightning==2.3.3
gradio==4.29.0
gradio==4.44.0
scipy==1.11.2
tensorboard==2.15.1
joblib==1.3.2
Expand Down
Loading

0 comments on commit e7703d5

Please sign in to comment.