general updates
Tonks684 committed Aug 15, 2024
1 parent f8a00c7 commit 0c514a1
Showing 1 changed file with 20 additions and 20 deletions.
40 changes: 20 additions & 20 deletions solution.py
@@ -54,11 +54,14 @@
- Configure the Pix2PixHD GAN to train for phase-to-nuclei translation.
"""
# %%
- from pathlib import Path
# TODO: Change the path to the directory where the data and code are stored.
import os
import sys
- parent_dir = os.path.abspath(os.path.join(os.getcwd(), '..'))
+ parent_dir = os.path.abspath(os.path.expanduser("~/data/06_image_translation/part2/"))  # expanduser is needed to resolve "~"
sys.path.append(parent_dir)

# %%
+ from pathlib import Path
import torch
import numpy as np
import pandas as pd
@@ -215,19 +218,15 @@
## A heads up of what to expect from the training...
<br><br>
- **Visualise Phase, Fluorescence and Virtual Stain for Validation Examples**<br>
- - We can observe how the performance improves over time using the images tab and the sliding window.
+ **- Visualise results**: We can observe how the performance improves over time using the images tab and the sliding window.
<br><br>
- **Discriminator Predicted Probabilities**<br>
- - We plot the discriminator's predicted probabilities that the phase with fluorescence is phase and fluorescence and that the phase with virtual stain is phase with virtual stain. It is typically trained until the discriminator can no longer classify whether or not the generated images are real or fake better than a random guess (p(0.5)). We plot this for both the training and validation datasets.
+ **- Discriminator Predicted Probabilities**: We plot the discriminator's predicted probability that the real (phase, fluorescence) pair is real and that the fake (phase, virtual stain) pair is real. The GAN is typically trained until the discriminator can no longer distinguish real from generated images better than a random guess (p = 0.5). We plot this for both the training and validation datasets.
<br><br>
- **Adversarial Loss**<br>
- - We can formulate the adversarial loss as a Least Squared Error Loss in which for real data the discriminator should output a value close to 1 and for fake data a value close to 0. The generator's goal is to make the discriminator output a value as close to 1 for fake data. We plot the least squared error loss.
+ **- Adversarial Loss**: We formulate the adversarial loss as a least-squares (LSGAN) loss: for real data the discriminator should output a value close to 1, and for fake data a value close to 0. The generator's goal is to make the discriminator output a value as close to 1 as possible for fake data. We plot this least-squares loss (see the sketch after this cell).
<br><br>
- **Feature Matching Loss**<br>
- - Both networks are also trained using the generator feature matching loss which encourages the generator to produce images that contain similar statistics to the real images at each scale. We also plot the feature matching L1 loss for the training and validation sets together to observe the performance and how the model is fitting the data.<br><br>
+ **- Feature Matching Loss**: The generator is also trained with a feature-matching loss, which encourages it to produce images whose discriminator features match those of the real images at each scale. We plot this L1 feature-matching loss for the training and validation sets together to observe how well the model fits the data (also covered in the sketch below).<br><br>
<br><br>
- This implementation allows for the turning on/off of the least-square loss term by setting the opt.no_lsgan flag to the model options. As well as the turning off of the feature matching loss term by setting the opt.no_ganFeat_loss flag to the model options. Something you might want to explore in the next section!<br><br>
+ - This implementation lets you turn off the least-squares loss term by passing the opt.no_lsgan flag to the model options, and turn off the feature-matching loss term by passing the opt.no_ganFeat_loss flag. Something you might want to explore in the next section!<br><br>
</div>
"""
# %% [markdown]
@@ -369,7 +368,7 @@
test_data_loader = CreateDataLoader(opt)
test_dataset = test_data_loader.load_data()
visualizer = Visualizer(opt)

print(f"Total Test Images = {len(test_data_loader)}")
# Load pre-trained model
model = create_model(opt)

@@ -411,7 +410,7 @@
If you can incorporate the crop function below to zoom in on the images, that would be great!
</div>
"""

+ # %%
# Define a function to crop the images so we can zoom in.
def crop(img, crop_size, type=None):
"""
@@ -564,9 +563,9 @@ def min_max_scale(input):
column=["psnr_nuc"],
rot=30,
)
- test_pixel_metrics.head()
+ # %% [markdown]
"""
########## TODO ##############
- What do these metrics tell us about the performance of the model?
- How do the pixel-level metrics compare to the regression-based approach?
- Could these metrics be skewed by the presence of hallucinations or background pixels in the virtual stains?
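# %%
# Reference sketch for the pixel-level metrics discussed above: PSNR computed
# with scikit-image. Illustrative only -- the arrays below are random stand-ins
# for the fluorescence targets and virtual stains, and `psnr_nuc` mirrors the
# column name used above.
import numpy as np
import pandas as pd
from skimage.metrics import peak_signal_noise_ratio

target_stains = np.random.rand(5, 256, 256)
virtual_stains = np.random.rand(5, 256, 256)
psnr_nuc = [
    peak_signal_noise_ratio(t, p, data_range=1.0)
    for t, p in zip(target_stains, virtual_stains)
]
print(pd.DataFrame({"psnr_nuc": psnr_nuc}).describe())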
@@ -752,15 +751,16 @@ def visualise_both_methods():
##########################

def visualise_both_methods(
-    phase_images: np.array, target_stains: np.array, pix2pixHD_results: np.array, viscy_results: np.array,crop_size=None
+    phase_images: np.ndarray, target_stains: np.ndarray, pix2pixHD_results: np.ndarray, viscy_results: np.ndarray, crop_size=None, crop_type='center'
):
    fig, axes = plt.subplots(5, 4, figsize=(15, 15))
    sample_indices = np.random.choice(len(phase_images), 5, replace=False)  # 5 distinct samples
-    if crop is not None:
-        phase_images = phase_images[:,:crop_size,:crop_size]
-        target_stains = target_stains[:,:crop_size,:crop_size]
-        pix2pixHD_results = pix2pixHD_results[:,:crop_size,:crop_size]
-        viscy_results = viscy_results[:,:crop_size,:crop_size]
+    if crop_size is not None:
+        # Crop the full (N, H, W) stacks with the helper defined earlier.
+        phase_images = crop(phase_images, crop_size, crop_type)
+        target_stains = crop(target_stains, crop_size, crop_type)
+        pix2pixHD_results = crop(pix2pixHD_results, crop_size, crop_type)
+        viscy_results = crop(viscy_results, crop_size, crop_type)

    for i, idx in enumerate(sample_indices):
        axes[i, 0].imshow(phase_images[idx], cmap="gray")
