Skip to content

Commit

Permalink
solution changes
Browse files Browse the repository at this point in the history
  • Loading branch information
Tonks684 committed Aug 20, 2024
1 parent ff33a70 commit 97fe2a6
Showing 1 changed file with 12 additions and 12 deletions.
24 changes: 12 additions & 12 deletions solution.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,11 +272,12 @@
If you are having issues loading the tensorboard session click "Launch TensorBoard session". You should then be able to add the log_dir path below and a tensorboard session shouls then load.
"""

# %%
log_dir = f"{top_dir}/model_tensorboard/{opt.name}/"
%reload_ext tensorboard
# %%
%tensorboard --logdir $log_dir

# %% [markdown]
"""
<div class="alert alert-info">
Expand Down Expand Up @@ -614,7 +615,6 @@ def cellpose_segmentation(prediction:ArrayLike,target:ArrayLike)->Tuple[torch.Sh
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
cp_nuc_kwargs = {
"diameter": 65,
"channels": [0],
"cellprob_threshold": 0.0,
}
cellpose_model = models.CellposeModel(
Expand Down Expand Up @@ -681,19 +681,20 @@ def cellpose_segmentation(prediction:ArrayLike,target:ArrayLike)->Tuple[torch.Sh
test_segmentation_metrics.head()
# %%
# Define function to visualize the segmentation results.
def visualise_results_and_masks(segmentation_results, test_segmentation_metrics: Tuple[dict], rows: int = 5, crop_size: int = None, crop_type: str = 'center'):
def visualise_results_and_masks(segmentation_results: Tuple[dict], segmentation_metrics: pd.DataFrame, rows: int = 5, crop_size: int = None, crop_type: str = 'center'):

# Sample a subset of the segmentation results.
sample_indices = np.random.choice(len(phase_images),rows)
segmentation_metrics_subset = segmentation_metrics_subset.iloc[sample_indices,:]
print(sample_indices)
segmentation_metrics = segmentation_metrics.iloc[sample_indices,:]
segmentation_results = [segmentation_results[i] for i in sample_indices]
# Define the figure and axes.
fig, axes = plt.subplots(rows, 5, figsize=(rows*3, 15))

# Visualize the segmentation results.
for i, idx in enumerate(test_segmentation_metrics):
result = segmentation_results[idx]
segmentation_metrics = segmentation_metrics_subset.iloc[i]
for i in range(len((segmentation_results))):
segmentation_metric = segmentation_metrics.iloc[i]
result = segmentation_results[i]
phase_image = result["phase_image"]
target_stain = result["target_stain"]
target_label = result["target_label"]
Expand All @@ -706,6 +707,7 @@ def visualise_results_and_masks(segmentation_results, test_segmentation_metrics:
target_label = crop(target_label, crop_size, crop_type)
pred_stain = crop(pred_stain, crop_size, crop_type)
pred_label = crop(pred_label, crop_size, crop_type)

axes[i, 0].imshow(phase_image, cmap="gray")
axes[i, 0].set_title("Phase")
axes[i, 1].imshow(
Expand All @@ -721,15 +723,15 @@ def visualise_results_and_masks(segmentation_results, test_segmentation_metrics:
axes[i, 3].set_title("Target Fluorescence Mask")
axes[i, 4].imshow(pred_label, cmap="inferno")
# Add Metric values to the title
axes[i, 4].set_title(f"Virtual Stain Mask\nAcc:{segmentation_metrics['accuracy']:.2f} Dice:{segmentation_metrics['dice']:.2f} Jaccard:{segmentation_metrics['jaccard']:.2f} MAP:{segmentation_metrics['mAP']:.2f}")
axes[i, 4].set_title(f"Virtual Stain Mask\nAcc:{segmentation_metric['accuracy']:.2f} Dice:{segmentation_metric['dice']:.2f}\nJaccard:{segmentation_metric['jaccard']:.2f} MAP:{segmentation_metric['mAP']:.2f}")
# Turn off the axes.
for ax in axes.flatten():
ax.axis("off")

plt.tight_layout()
plt.show()

visualise_results_and_masks(test_segmentation_metrics, crop_size=256, crop_type='center')
visualise_results_and_masks(segmentation_results,test_segmentation_metrics, crop_size=256, crop_type='center')
# %% [markdown]
"""
<div class="alert alert-success">
Expand Down Expand Up @@ -865,7 +867,6 @@ def visualise_both_methods(
# Append the images to the arrays.
sample_images[index] = sample_image

# Plot the phase image, the target image, the variance of samples and 3 samples
# %%
# Create a matplotlib plot with animation through images.
def animate_images(images):
Expand Down Expand Up @@ -896,7 +897,6 @@ def animate_images(images):
## Checkpoint 5
Congratulations! You have generated predictions from a pre-trained model and evaluated the performance of the model on unseen data. You have computed pixel-level metrics and instance-level metrics to evaluate the performance of the model. You may have also began training your own Pix2PixHD GAN models with alternative hyperparameters.
Congratulations! This is the end of the conditional generative modelling approach to image translation notebook. You have trained and examined the loss components of Pix2PixHD GAN. You have compared the results of a regression-based approach vs. generative modelling approach and explored the variability in virtual staining solutions. I hope you have enjoyed learning experience!
</div>
"""

0 comments on commit 97fe2a6

Please sign in to comment.