Add non-targeting experiments

adjavon committed Sep 3, 2024
1 parent 85430e2 commit 6f05288
Showing 11 changed files with 1,235 additions and 1 deletion.
3 changes: 2 additions & 1 deletion .gitignore
@@ -138,4 +138,5 @@ pyrepo
 # Log Folders
 embed_time_runs/
 embed_time_static_runs/
-notebooks/
+# notebooks/
+*.ipynb
134 changes: 134 additions & 0 deletions notebooks/nontargeting_experiments/get_vgg_results.py
@@ -0,0 +1,134 @@
# %% [markdown]
# Load the results of the VGG experiments and show their losses, accuracies, and confusion matrices.
#
# %%
from pathlib import Path
import matplotlib.pyplot as plt
import pandas as pd
import torch
from sklearn.metrics import confusion_matrix
from tqdm import tqdm
import numpy as np
from torch.utils.data import DataLoader
from embed_time.dataset_static import ZarrCellDataset
from embed_time.dataloader_static import collate_wrapper
from funlib.learn.torch.models import Vgg2D
from embed_time.static_utils import read_config
from torchvision.transforms import v2
import seaborn as sns

# %% Utilities
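# plot_metrics draws one subplot per column of the metrics DataFrame,
# so each logged quantity (loss, accuracy, ...) gets its own axis.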
def plot_metrics(metrics):
    metrics.plot(subplots=True, figsize=(10, 10))
    plt.show()

def load_best_checkpoint(directory, metrics):
    # Pick the epoch with the highest validation accuracy; the metrics CSV is
    # expected to hold at least 'epoch' and 'val_accuracy' columns.
    best_index = metrics['val_accuracy'].idxmax()
    best_epoch = metrics['epoch'][best_index]
    checkpoint = directory / f"{best_epoch}.pth"
    return checkpoint

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def get_confusion_matrix(model, val_dataloader, class_names, label_type, normalize='true'):
    model.eval()
    predictions = []
    labels = []

    # No gradients are needed for evaluation; this saves memory and time.
    with torch.no_grad():
        for batch in tqdm(val_dataloader, desc="Validation", total=len(val_dataloader)):
            images, batch_labels = batch['cell_image'], batch[label_type]
            # Map string labels to integer class indices
            batch_labels = torch.tensor(
                [class_names.index(label) for label in batch_labels]
            )
            images = images.to(device)
            batch_labels = batch_labels.to(device)

            output = model(images)
            predictions.append(output.argmax(dim=1).cpu().numpy())
            labels.append(batch_labels.cpu().numpy())

    cm = confusion_matrix(np.concatenate(labels), np.concatenate(predictions), normalize=normalize)
    return cm


def create_dataloader(dataset, label_type, batch_size=16, num_workers=8, balance_dataset=True):
    csv_file = f"/mnt/efs/dlmbl/G-et/csv/dataset_split_{dataset}.csv"
    subdir = Path(f"/mnt/efs/dlmbl/G-et/da_testing/vgg2d_{dataset}/{label_type}_{balance_dataset}")
    df = pd.read_csv(csv_file)
    class_names = df[label_type].sort_values().unique().tolist()
    num_classes = len(class_names)

    metadata_keys = ['gene', 'barcode', 'stage']
    images_keys = ['cell_image']
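    # Center-crop each cell image to 96x96; this matches the Vgg2D input_size used below.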
    crop_size = 96
    normalizations = v2.Compose([v2.CenterCrop(crop_size)])
    yaml_file_path = "/mnt/efs/dlmbl/G-et/yaml/dataset_info_20240901_155625.yaml"
    dataset_mean, dataset_std = read_config(yaml_file_path)

    val_dataset = ZarrCellDataset(
        parent_dir='/mnt/efs/dlmbl/S-md/',
        csv_file=csv_file,
        split='val',
        channels=[0, 1, 2, 3],
        mask='min',
        normalizations=normalizations,
        interpolations=None,
        mean=dataset_mean,
        std=dataset_std,
    )

    # Create a DataLoader for the validation dataset
    val_dataloader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        collate_fn=collate_wrapper(metadata_keys, images_keys),
        drop_last=False,
    )
    return subdir, val_dataloader, class_names, num_classes

# %% Setup happens here
dataset = "benchmark_nontargeting_barcode"
label_type = 'barcode'
batch_size = 16
num_workers = 8
balance_dataset = True
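# Note: balance_dataset only selects which training-run directory is read;
# the validation split itself is not re-balanced here.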

subdir, val_dataloader, class_names, num_classes = create_dataloader(dataset, label_type, batch_size, num_workers, balance_dataset)

metrics = pd.read_csv(subdir / "metrics.csv")
plot_metrics(metrics)
# %% Get the model to load the best checkpoint, create a confusion matrix
checkpoint = load_best_checkpoint(subdir, metrics)
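# Rebuild the training-time architecture: 4 input feature maps, one per image
# channel (channels=[0, 1, 2, 3] above), on 96x96 center-cropped cells.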
model = Vgg2D(
    input_size=(96, 96),
    input_fmaps=4,
    output_classes=num_classes,
)
model = model.to(device)
model.load_state_dict(torch.load(checkpoint, map_location=device)["model_state_dict"])
model.eval()

cm = get_confusion_matrix(model, val_dataloader, class_names, label_type)

# %% Plot the confusion matrix as a heatmap
sns.heatmap(cm, annot=True, fmt='.2f', cmap='Blues')
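# With normalize='true' (the default above), each row sums to 1, so the
# diagonal entries read directly as per-class recall.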
plt.xlabel('Predicted')
plt.ylabel('True')
# Set tick labels
# plt.xticks(np.arange(num_classes) + 0.5, class_names)
# plt.yticks(np.arange(num_classes) + 0.5, class_names)
plt.show()

# %% Sanity check: number of classes
len(class_names)
# %% Per-barcode counts in the validation split
df = pd.read_csv(f"/mnt/efs/dlmbl/G-et/csv/dataset_split_{dataset}_{balance_dataset}.csv")
df = df[df.split == 'val']
df.barcode.value_counts()
# %% Which dataset this run used
dataset
# %%
18 changes: 18 additions & 0 deletions notebooks/nontargeting_experiments/make_nontargeting_benchmark.py
@@ -0,0 +1,18 @@

# %% Make an intermediate dataset
import pandas as pd

location = "/mnt/efs/dlmbl/G-et/csv/dataset_split_1168.csv"
benchmark_location = "/mnt/efs/dlmbl/G-et/csv/dataset_split_benchmark.csv"

metadata = pd.read_csv(location)
benchmark_metadata = pd.read_csv(benchmark_location)

# %% Randomly sample a subset of the nontargeting metadata that is the same size as the benchmark data
sample = metadata[metadata['gene'] == "nontargeting"]
sample = sample.sample(n=benchmark_metadata.shape[0])
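# Note: no random_state is passed, so this subset changes from run to run;
# sample(n=..., random_state=42) would make it reproducible.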

# %%
sample.to_csv("/mnt/efs/dlmbl/G-et/csv/dataset_split_benchmark_nontargeting.csv", index=False)

# %%
@@ -0,0 +1,33 @@

# %% Make an intermediate dataset
# This includes *only* a subset of barcodes that are nontargeting
import pandas as pd
import numpy as np

# %%
location = "/mnt/efs/dlmbl/G-et/csv/dataset_split_1168.csv"

metadata = pd.read_csv(location)
# %%
sample = metadata[metadata['gene'] == "nontargeting"]
np.random.seed(42)
barcodes = np.random.choice(
    sample["barcode"].sort_values().unique(),
    size=10,
    replace=False,
)
# %% Keep every cell whose barcode is in the sampled nontargeting set
sample = metadata[metadata['barcode'].isin(barcodes)]

# %%
sample["split"].value_counts()
# %%
# Make sure every sampled barcode appears in each split
for split in ["train", "val", "test"]:
    assert set(barcodes) == set(sample[sample["split"] == split]["barcode"].unique())

# %%
sample.to_csv("/mnt/efs/dlmbl/G-et/csv/dataset_split_benchmark_nontargeting_barcode.csv", index=False)


# %%
@@ -0,0 +1,29 @@

# %% Make an intermediate dataset
# This includes *only* a subset of barcodes that are nontargeting and *all* barcodes that are CCT2
import pandas as pd
import numpy as np

# %%
location = "/mnt/efs/dlmbl/G-et/csv/dataset_split_1168.csv"
nontargeting_location = "/mnt/efs/dlmbl/G-et/csv/dataset_split_benchmark_nontargeting_barcode.csv"

metadata = pd.read_csv(location)
nontargeting_metadata = pd.read_csv(nontargeting_location)
# %%
cct2 = metadata[metadata['gene'] == "CCT2"]
# %%
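# Combine the sampled nontargeting barcodes with every CCT2 cell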
sample = pd.concat([nontargeting_metadata, cct2])
sample["split"].value_counts()
# %%
barcodes = sample["barcode"].sort_values().unique()
genes = sample["gene"].sort_values().unique()
# %%
# Make sure every barcode appears in each split
for split in ["train", "val", "test"]:
    assert set(barcodes) == set(sample[sample["split"] == split]["barcode"].unique())

# %%
sample.to_csv("/mnt/efs/dlmbl/G-et/csv/dataset_split_benchmark_nontargeting_barcode_with_cct2.csv", index=False)

# %%
17 changes: 17 additions & 0 deletions notebooks/nontargeting_experiments/make_nontargeting_dataset.py
@@ -0,0 +1,17 @@

# %% Make an intermediate dataset
import pandas as pd

location = "/mnt/efs/dlmbl/G-et/csv/dataset_split_1168.csv"

metadata = pd.read_csv(location)

# %%
assert "nontargeting" in metadata['gene'].values
# %% Keep only the nontargeting gene
sample = metadata[metadata['gene'] == "nontargeting"]

# %%
sample.to_csv("/mnt/efs/dlmbl/G-et/csv/dataset_split_nontargeting.csv", index=False)

# %%