modified augmentations, loss calc to only sox2, and latent space explorer
ijan780 committed Sep 4, 2024
1 parent 59263d8 commit 7e71c6d
Showing 4 changed files with 595 additions and 14 deletions.
196 changes: 184 additions & 12 deletions notebooks/time_series_subgroup/investigate_latent_space_ij.py
@@ -26,6 +26,7 @@

# %%
base_dir = "/mnt/efs/dlmbl/G-et/checkpoints/time-series"
# checkpoint_dir = Path(base_dir) / "2024-09-03_Resnet18_VAE_norm_+aug_sox2only_02_ij_checkpoints"
checkpoint_dir = Path(base_dir) / "2024-09-02_Resnet18_VAE_norm_01_ij_checkpoints"
print(checkpoint_dir)
checkpoint_dir.mkdir(exist_ok=True)
@@ -65,6 +66,7 @@
# %%
# `dict` is the checkpoint dictionary loaded in the folded-out cell above
# (the name shadows the built-in `dict`); restore the saved weights.
model_params = dict['model']
model.load_state_dict(model_params)
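
# %%
# A standard follow-up (sketch, not part of the original diff): put the
# restored model in inference mode so dropout and batch-norm statistics
# stay frozen during the exploration below.
model.eval()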

# %%
dataloader = DataLoader(dataset_w_t, batch_size=1, shuffle=False, pin_memory=True, num_workers=8)

@@ -82,8 +84,6 @@

ax[0].imshow(test_image.squeeze(0).numpy()[0])
ax[1].imshow(test_image.squeeze(0).numpy()[1])

# %%
test_image.shape

# %%
@@ -92,7 +92,6 @@
result = result[0].detach().cpu().squeeze().numpy()
result.shape
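
# %%
# Hedged sketch of the full reconstruction cell, since the forward pass is
# folded out of the diff above. Assumes the VAE returns a (reconstruction,
# mu, logvar) tuple — consistent with the `result[0]` indexing — and that
# `device` was set in an earlier cell.
with torch.no_grad():
    recon, mu, logvar = model(test_image.to(device))
recon = recon.cpu().squeeze().numpy()
print(recon.shape)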


# %%
fig, ax = plt.subplots(2,2,figsize=(2*plot_size,2*plot_size))

@@ -154,25 +153,26 @@

# %%
print(latents.shape)

# %%
flat_lat = np.array([lat.flatten() for lat in latents])
print(flat_lat.size)
# rows: 5 samples x 529 rafts x 9 timepoints
# cols: (21 x 21 px) x 20 z_dim

# %%
tabular_data = "/mnt/efs/dlmbl/G-et/tabular_data"
if not os.path.isdir(tabular_data):
    os.mkdir(tabular_data)
df_lat = pd.DataFrame(
-    flat_lat,
-    columns = [f"LD_mu_{i+1}" for i in range(flat_lat.shape[1])]
+    latents,
+    columns = [f"LD_mu_{i+1}" for i in range(latents.shape[1])]
)
df_lat['Time'] = timepoints
df_lat['Raft'] = rafts

-df_lat.to_csv(Path(tabular_data) / "20240902_Resnet_20z_LatentSpace_Norm_ij.csv")
+df_lat.to_csv(Path(tabular_data) / "20240903_Resnet_20z_LatentSpace_Norm_+aug_sox2only_ij.csv",
+              index=False)

# %%

# Only run this cell when loading df_lat from a saved csv file.
tabular_data = "/mnt/efs/dlmbl/G-et/tabular_data"
df_lat = pd.read_csv(Path(tabular_data) / "20240902_Resnet_20z_LatentSpace_Norm_ij.csv")

@@ -184,7 +184,9 @@

# %%
df_lat = StandardScaler().fit_transform(
-    df_lat.drop(columns=['Unnamed: 0', 'Time', 'Raft']))
+    df_lat.drop(columns=['Time', 'Raft']))
+# df_lat = StandardScaler().fit_transform(
+#     df_lat.drop(columns=['Unnamed: 0', 'Time', 'Raft']))

# %%
components=5
@@ -248,6 +250,176 @@
palette="viridis")

# %%
-umap_df.to_csv(Path(tabular_data) / "20240902_Resnet_20z_LatentSpace_Norm_ij_umap.csv")
+print(len(rafts))

# %%
from skimage import io
from scipy import stats

raft_dict = {}

# Per-image intensity statistics, keyed by the three-character raft ID
# parsed from the filename (characters [-7:-4], just before the extension).
for img_name in os.listdir(folder_imgs):
    raft_num = str(img_name)[-7:-4]
    img = io.imread(Path(folder_imgs) / img_name)
    img = img[:, 1].flatten()  # index 1 on axis 1 — presumably the SOX2 channel

    per_25 = np.percentile(img, 25)
    stdev = np.std(img)
    mean = np.mean(img)

    # Histogram entropy of the normalised intensities (256 bins on [0, 1]).
    histo = np.histogram(img, bins=256, range=(0, 1), density=True)[0]
    ent = stats.entropy(histo, base=2)

    raft_dict[raft_num] = (per_25, stdev, mean, ent)
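
# %%
# Sanity check for the entropy measure above (sketch on synthetic data):
# a uniform image should land near the 8-bit maximum, a constant image at 0.
rng = np.random.default_rng(0)
h_u = np.histogram(rng.uniform(0, 1, 10_000), bins=256, range=(0, 1), density=True)[0]
h_c = np.histogram(np.full(10_000, 0.5), bins=256, range=(0, 1), density=True)[0]
print(stats.entropy(h_u, base=2))  # close to log2(256) = 8
print(stats.entropy(h_c, base=2))  # 0.0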


# %%
per_25_list = []
stdev_list = []
mean_list = []
ent_list = []

# Re-order the per-raft statistics to match the raft order of the latents.
for raft in rafts:
    raft_num = str(raft)[-7:-4]
    per_25_list.append(raft_dict[raft_num][0])
    stdev_list.append(raft_dict[raft_num][1])
    mean_list.append(raft_dict[raft_num][2])
    ent_list.append(raft_dict[raft_num][3])
print(len(ent_list))

# %%

# Create UMAP embedding of the scaled latents.
umap_transformer = umap.UMAP(n_neighbors=30)
umap_out = umap_transformer.fit_transform(df_lat)

umap_df = pd.DataFrame(umap_out, columns=["UMAP_1", "UMAP_2"])
umap_df["Time"] = timepoints
umap_df["Rafts"] = rafts
umap_df["Per_25"] = per_25_list
umap_df["Stdev"] = stdev_list
umap_df["Mean"] = mean_list
umap_df["Ent"] = ent_list

umap_df.head()
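
# %%
# UMAP is stochastic, so the embedding differs between runs. Pinning
# random_state (a standard umap-learn parameter; the seed value here is
# arbitrary) makes it reproducible — a hedged drop-in for the transformer
# above:
umap_transformer_seeded = umap.UMAP(n_neighbors=30, random_state=42)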

# %%

label = 'Per_25'

plt.figure(figsize=(8, 6))
# Shuffle the rows so overlapping points are not drawn in column order.
scatter = sns.scatterplot(shuffle(umap_df),
                          x="UMAP_1",
                          y="UMAP_2",
                          hue=label,
                          alpha=0.5,
                          palette="viridis",
                          legend=False)

norm = plt.Normalize(umap_df[label].min(), umap_df[label].max())
sm = plt.cm.ScalarMappable(cmap="viridis", norm=norm)
sm.set_array([])
cbar = plt.colorbar(sm, ax=scatter.axes)
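
# %%
# The scatter-plus-colorbar pattern above repeats below for Stdev, Mean,
# COV, Ent and Time, changing only `label`. A hedged helper that could
# replace the repetition (column names follow the umap_df built above):
def plot_umap(df, label, cmap="viridis"):
    plt.figure(figsize=(8, 6))
    ax = sns.scatterplot(shuffle(df), x="UMAP_1", y="UMAP_2", hue=label,
                         alpha=0.5, palette=cmap, legend=False)
    norm = plt.Normalize(df[label].min(), df[label].max())
    sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
    sm.set_array([])
    plt.colorbar(sm, ax=ax)

# e.g. plot_umap(umap_df, "Stdev")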

# %%

label = 'Stdev'

plt.figure(figsize=(8, 6))
scatter = sns.scatterplot(shuffle(umap_df),
                          x="UMAP_1",
                          y="UMAP_2",
                          hue=label,
                          alpha=0.5,
                          palette="viridis",
                          legend=False)

norm = plt.Normalize(umap_df[label].min(), umap_df[label].max())
sm = plt.cm.ScalarMappable(cmap="viridis", norm=norm)
sm.set_array([])
cbar = plt.colorbar(sm, ax=scatter.axes)

# %%

label = 'Mean'

plt.figure(figsize=(8, 6))
scatter = sns.scatterplot(shuffle(umap_df),
                          x="UMAP_1",
                          y="UMAP_2",
                          hue=label,
                          alpha=0.5,
                          palette="viridis",
                          legend=False)

norm = plt.Normalize(umap_df[label].min(), umap_df[label].max())
sm = plt.cm.ScalarMappable(cmap="viridis", norm=norm)
sm.set_array([])
cbar = plt.colorbar(sm, ax=scatter.axes)

# %%

# Coefficient of variation: per-raft intensity stdev normalised by the mean.
cov_list = np.divide(np.array(stdev_list), np.array(mean_list))
umap_df["COV"] = cov_list

label = 'COV'

plt.figure(figsize=(8, 6))
scatter = sns.scatterplot(shuffle(umap_df),
                          x="UMAP_1",
                          y="UMAP_2",
                          hue=label,
                          alpha=0.5,
                          palette="viridis",
                          legend=False)

norm = plt.Normalize(umap_df[label].min(), umap_df[label].max())
sm = plt.cm.ScalarMappable(cmap="viridis", norm=norm)
sm.set_array([])
cbar = plt.colorbar(sm, ax=scatter.axes)

# %%

label = 'Ent'

plt.figure(figsize=(8, 6))
scatter = sns.scatterplot(shuffle(umap_df),
                          x="UMAP_1",
                          y="UMAP_2",
                          hue=label,
                          alpha=0.5,
                          palette="viridis",
                          legend=False)

norm = plt.Normalize(umap_df[label].min(), umap_df[label].max())
sm = plt.cm.ScalarMappable(cmap="viridis", norm=norm)
sm.set_array([])
cbar = plt.colorbar(sm, ax=scatter.axes)

# %%

label = 'Time'

plt.figure(figsize=(8, 6))
scatter = sns.scatterplot(shuffle(umap_df),
                          x="UMAP_1",
                          y="UMAP_2",
                          hue=label,
                          alpha=0.5,
                          palette="viridis",
                          legend=False)

norm = plt.Normalize(umap_df[label].min(), umap_df[label].max())
sm = plt.cm.ScalarMappable(cmap="viridis", norm=norm)
sm.set_array([])
cbar = plt.colorbar(sm, ax=scatter.axes)


# %%
umap_df.to_csv(Path(tabular_data) / "20240903_Resnet_20z_LatentSpace_Norm_+aug_sox2only_ij_umap.csv",
               index=False)

# %%
68 changes: 68 additions & 0 deletions notebooks/time_series_subgroup/test_transforms_ij.py
@@ -0,0 +1,68 @@
# %%
from pathlib import Path
import os
import pandas as pd
import numpy as np

import torch
from torch.utils.data import DataLoader
from torch.nn import functional as F

import torchvision.transforms as trans
from torchvision.transforms import v2

from embed_time.dataloader_ij import LiveGastruloidDataset
from embed_time.transforms_ij import CustomToTensor, ShiftIntensity, SelectRandomTPNumpy

from skimage import io


# %%
folder_imgs = r"/mnt/efs/dlmbl/G-et/data/live_gastruloid/240722_R2GLR_1.8e6_0-48hrBMP4_0%aneu_2_Analysis/Individual Raft Images Norm/"
loading_transforms = trans.Compose([
    SelectRandomTPNumpy(0),        # presumably selects one random timepoint along axis 0
    CustomToTensor(),
    v2.Resize((336, 336)),
    ShiftIntensity(bf_factor=2),   # intensity augmentation; bf_factor presumably scales the brightfield channel
    v2.RandomAffine(
        degrees=90,
        translate=[0.1, 0.1],
    ),
    v2.RandomHorizontalFlip(),
    v2.RandomVerticalFlip(),
    v2.GaussianBlur(kernel_size=15, sigma=(0.1, 20.0)),
])
# Same pipeline without the intensity shift, kept for comparison:
# loading_transforms = trans.Compose([
#     SelectRandomTPNumpy(0),
#     CustomToTensor(),
#     v2.Resize((336, 336)),
#     v2.RandomAffine(
#         degrees=90,
#         translate=[0.1, 0.1],
#     ),
#     v2.RandomHorizontalFlip(),
#     v2.RandomVerticalFlip(),
#     v2.GaussianBlur(kernel_size=15, sigma=(0.1, 20.0)),
# ])
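
# Every transform after CustomToTensor is stochastic; seeding torch
# (standard PyTorch idiom, not in the original file) makes a test run
# repeatable.
torch.manual_seed(0)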

dataset_w_t = LiveGastruloidDataset(
    img_dir = folder_imgs,
    transform = loading_transforms,
)


dataloader_train = DataLoader(dataset_w_t, batch_size=5, shuffle=True, pin_memory=True, num_workers=8)

# %%
# Grab a single augmented batch for visual inspection.
for data in dataloader_train:
    example_tensor = data
    break

# %%
batch_idx = 3
io.imshow(np.array(example_tensor[batch_idx][1]))  # channel 1 of one augmented sample
np.max(np.array(example_tensor[batch_idx][1]))     # check the post-augmentation intensity range
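
# %%
# Hedged sketch: show channel 1 for every sample in the batch to eyeball
# the augmentation variety. Assumes the batch is a (B, C, H, W) tensor,
# as the indexing above suggests.
import matplotlib.pyplot as plt

fig, axes = plt.subplots(1, len(example_tensor), figsize=(3 * len(example_tensor), 3))
for i, ax in enumerate(axes):
    ax.imshow(np.array(example_tensor[i][1]))
    ax.set_axis_off()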

# %%

# %%