
Commit 285efe3

updated evaluation

mat10d committed Sep 4, 2024
1 parent 79469cb
Showing 2 changed files with 260 additions and 36 deletions.
75 changes: 49 additions & 26 deletions scripts/evaluate_md.py
@@ -1,12 +1,26 @@
#%%
import os
import re

import pandas as pd
import matplotlib.pyplot as plt
import torch

from embed_time.evaluate_static import ModelEvaluator


def get_checkpoint_dirs():
    parent_dir = '/mnt/efs/dlmbl/G-et/checkpoints/static/Matteo/'
    checkpoint_dirs = os.listdir(parent_dir)
    checkpoint_dirs = [os.path.join(parent_dir, d) for d in checkpoint_dirs]
    checkpoint_dirs = [d for d in checkpoint_dirs if os.path.isdir(d)]

    def get_timestamp(checkpoint_dir):
        filename = checkpoint_dir.split('/')[-1]
        match = re.search(r'(\d{8}_\d{4})', filename)
        if match:
            return match.group(1)
        return ''

    # Sort by timestamp and keep only runs newer than 20240903_2130
    checkpoint_dirs = sorted(checkpoint_dirs, key=get_timestamp)
    checkpoint_dirs = [d for d in checkpoint_dirs if get_timestamp(d) > '20240903_2130']
    print("number of checkpoints:", len(checkpoint_dirs))

    return checkpoint_dirs

def parse_checkpoint_dir(checkpoint_dir):
    filename = checkpoint_dir.split('/')[-1]
@@ -17,10 +31,9 @@ def parse_checkpoint_dir(checkpoint_dir):
    if model_match:
        result['model'] = model_match.group(1)

    # Extract other parameters
    for param in params:
        if param == 'model':
            continue  # we've already handled this
        match = re.search(rf'{param}_([^_]+)', filename)
        if match:
            value = match.group(1)
@@ -38,23 +51,33 @@ def parse_checkpoint_dir(checkpoint_dir)

    return result

def generate_config(checkpoint_dir):
    config = parse_checkpoint_dir(checkpoint_dir)

    # Add invariant parameters
    config.update({
        'checkpoint_dir': checkpoint_dir,
        'parent_dir': '/mnt/efs/dlmbl/S-md/',
        'channels': [0, 1, 2, 3],
        'yaml_file_path': '/mnt/efs/dlmbl/G-et/yaml/dataset_info_20240901_155625.yaml',
        'output_dir': os.path.join('/home/S-md/embed_time/scripts/latent', checkpoint_dir.split('/')[-1]),
        'sampling_number': 3,
        'csv_file': '/mnt/efs/dlmbl/G-et/csv/' + config['csv_file'],
        'batch_size': 16,
        'num_workers': 8,
        'metadata_keys': ['gene', 'barcode', 'stage', 'cell_idx'],
        'images_keys': ['cell_image']
    })

    return config


def run_evaluator(checkpoint_dir):
    config = generate_config(checkpoint_dir)
    return ModelEvaluator(config)


# Example usage
if __name__ == "__main__":
    # checkpoint_dir = '/mnt/efs/dlmbl/G-et/checkpoints/static/Matteo/20240903_2130_VAE_ResNet18_crop_size_64_nc_4_z_dim_30_lr_0.0001_beta_1e-05_transform_min_loss_L1_benchmark'
    checkpoint_dirs = get_checkpoint_dirs()
    for checkpoint_dir in checkpoint_dirs:
        run_evaluator(checkpoint_dir)
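
Aside: parse_checkpoint_dir is only partially shown in this diff, but the regex extraction it performs on a directory name can be sketched as below. The params list here is an assumption (the real one is truncated above), and the real function may further cast numeric values; this is an illustration, not the repository's code.

import re

name = "20240903_2130_VAE_ResNet18_crop_size_64_nc_4_z_dim_30_lr_0.0001_beta_1e-05_transform_min_loss_L1_benchmark"
params = ["crop_size", "nc", "z_dim", "lr", "beta", "transform", "loss"]  # assumed list

parsed = {}
model_match = re.search(r"(VAE_ResNet18(?:_Linear)?)", name)
if model_match:
    parsed["model"] = model_match.group(1)
for param in params:
    match = re.search(rf"{param}_([^_]+)", name)
    if match:
        parsed[param] = match.group(1)  # values left as raw strings in this sketch

print(parsed)
# {'model': 'VAE_ResNet18', 'crop_size': '64', 'nc': '4', 'z_dim': '30',
#  'lr': '0.0001', 'beta': '1e-05', 'transform': 'min', 'loss': 'L1'}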
221 changes: 211 additions & 10 deletions src/embed_time/evaluate_static.py
@@ -10,6 +11,11 @@
import yaml
import argparse
import piq
from sklearn.decomposition import PCA
from matplotlib.colors import ListedColormap
import umap
from sklearn.preprocessing import StandardScaler
import seaborn as sns

loss_ssim = piq.SSIMLoss()

@@ -26,8 +31,22 @@ def __init__(self, config):
        self.model = self._init_model()
        self.dataset_mean, self.dataset_std = self._read_config()
        self.output_dir = self._create_output_dir()
        self.train_df, train_loss, train_mse, train_kld = self._evaluate('train')
        self.val_df, val_loss, val_mse, val_kld = self._evaluate('val')
        self.create_pca_plots(self.train_df, self.val_df)
        self.create_umap_plots(self.train_df, self.val_df)
        accuracy = self.classifier(self.train_df, self.val_df)
        # Write a CSV file summarising the evaluation results
        results = pd.DataFrame({
            'train_loss': [train_loss],
            'train_mse': [train_mse],
            'train_kld': [train_kld],
            'val_loss': [val_loss],
            'val_mse': [val_mse],
            'val_kld': [val_kld],
            'classification_accuracy': [accuracy]
        })
        results.to_csv(os.path.join(self.config['output_dir'], 'results.csv'), index=False)

    def _init_model(self):
        model = None  # Initialize model to None
@@ -60,7 +79,7 @@ def _load_checkpoint(self, checkpoint_path, model):
        model.load_state_dict(checkpoint['model_state_dict'])
        return model, checkpoint['epoch']

    def _create_dataloader(self, split, drop_last=True):
        dataset = ZarrCellDataset(
            self.config['parent_dir'],
            self.config['csv_file'],
@@ -77,12 +96,12 @@ def _create_dataloader(self, split):
            batch_size=self.config['batch_size'],
            shuffle=False,
            num_workers=self.config['num_workers'],
            drop_last=drop_last,
            collate_fn=collate_wrapper(self.config['metadata_keys'], self.config['images_keys'])
        )

    def _create_output_dir(self):
        output_dir = self.config['output_dir']
        os.makedirs(output_dir, exist_ok=True)  # os.makedirs returns None, so return the path explicitly
        return output_dir

    def _evaluate_model(self, dataloader):
@@ -95,7 +114,7 @@ def _evaluate_model(self, dataloader):
        for batch_idx, batch in enumerate(dataloader):
            data = batch['cell_image'].to(self.device)
            metadata = [batch[key] for key in self.config['metadata_keys']]

            if self.config['model'] == 'VAE_ResNet18_Linear':
                recon_batch, _, mu, logvar = self.model(data)
            elif self.config['model'] == 'VAE_ResNet18':
@@ -108,7 +127,7 @@
            elif self.config['loss'] == "SSIM":
                # normalize x for SSIM (remember shape is BxCxHxW)
                x_norm = (data - data.min()) / (data.max() - data.min())
                recon_x_norm = (recon_batch - recon_batch.min()) / (recon_batch.max() - recon_batch.min())
                ssim = loss_ssim(recon_x_norm, x_norm)
                RECON = F.l1_loss(recon_batch, data, reduction='mean') + ssim * 0.5
            KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
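
Aside: the SSIM branch above can be exercised in isolation with a sketch like the following. The tensor shapes are illustrative assumptions, and how RECON and KLD are combined into the total loss is not shown in this hunk.

import torch
import torch.nn.functional as F
import piq

loss_ssim = piq.SSIMLoss()

data = torch.rand(16, 4, 64, 64)         # B x C x H x W, e.g. a batch of 4-channel crops (assumed shape)
recon_batch = torch.rand(16, 4, 64, 64)  # model output of the same shape

# piq.SSIMLoss expects inputs in [0, 1], hence the min-max normalisation
x_norm = (data - data.min()) / (data.max() - data.min())
recon_x_norm = (recon_batch - recon_batch.min()) / (recon_batch.max() - recon_batch.min())

ssim = loss_ssim(recon_x_norm, x_norm)
recon_term = F.l1_loss(recon_batch, data, reduction='mean') + 0.5 * ssim
print(float(recon_term))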
@@ -141,7 +160,11 @@ def _evaluate_model(self, dataloader):
        return avg_loss, avg_mse, avg_kld, latent_vectors, all_metadata

    def _evaluate(self, split):
        # Drop the last (partial) batch for the train split, keep all samples for validation
        if split == 'val':
            drop_last = False
        else:
            drop_last = True
        dataloader = self._create_dataloader(split, drop_last)
        print(f"Evaluating on {split} data...")
        loss, mse, kld, latents, metadata = self._evaluate_model(dataloader)
        print(f"{split.capitalize()} - Loss: {loss:.4f}, MSE: {mse:.4f}, KLD: {kld:.4f}")
@@ -159,7 +182,7 @@ def _evaluate(self, split):
        # Save the latent vectors
        df.to_csv(os.path.join(self.config['output_dir'], f"{split}_{self.config['sampling_number']}_latent_vectors.csv"), index=False)

        return df, loss, mse, kld

    def _save_image(self, data, recon, output_dir):
        image_idx = np.random.randint(data.shape[0])
@@ -192,7 +215,185 @@ def _save_image(self, data, recon, output_dir):
        plt.savefig(os.path.join(output_dir, filename), dpi=300, bbox_inches='tight')
        plt.close(fig)  # Close the figure to free up memory

    # Latent-space visualization: PCA and UMAP
    def create_pca_plots(self, train_latents, val_latents):
        # Step 0: split the DataFrames into label columns and latent dimensions
        # (drop every metadata column, including 'cell_idx' when present, so PCA only sees latents)
        label_cols = ['gene', 'barcode', 'stage', 'cell_idx']
        train_df = train_latents[['gene', 'barcode', 'stage']].copy()
        val_df = val_latents[['gene', 'barcode', 'stage']].copy()
        train_latents = train_latents.drop(columns=label_cols, errors='ignore')
        val_latents = val_latents.drop(columns=label_cols, errors='ignore')

        # Step 1: fit PCA on the training latents and project both splits
        pca = PCA(n_components=2)
        train_latents_pca = pca.fit_transform(train_latents)
        val_latents_pca = pca.transform(val_latents)

        # Step 2: prepare the plot
        fig, axes = plt.subplots(1, 2, figsize=(25, 10))

        # Helper function to create a discrete colour map with one colour per category
        def create_color_map(n):
            return ListedColormap(plt.cm.viridis(np.linspace(0, 1, n)))

        # Convert 'gene' to categorical and get integer codes for colouring
        train_df['gene'] = pd.Categorical(train_df['gene'])
        val_df['gene'] = pd.Categorical(val_df['gene'])
        train_gene_codes = train_df['gene'].cat.codes
        val_gene_codes = val_df['gene'].cat.codes

        # Step 3: plot PCA for the training set
        ax = axes[0]
        scatter = ax.scatter(train_latents_pca[:, 0], train_latents_pca[:, 1],
                             c=train_gene_codes,
                             cmap=create_color_map(len(train_df['gene'].cat.categories)),
                             s=25, alpha=0.5)
        ax.set_title('PCA of Training Latents', fontsize=40)
        ax.set_xlabel('PCA Component 1', fontsize=20)
        ax.set_ylabel('PCA Component 2', fontsize=20)
        cbar = fig.colorbar(scatter, ax=ax)
        cbar.set_ticks(range(len(train_df['gene'].cat.categories)))
        cbar.set_ticklabels(train_df['gene'].cat.categories, fontsize=20)

        # Step 4: plot PCA for the validation set
        ax = axes[1]
        scatter = ax.scatter(val_latents_pca[:, 0], val_latents_pca[:, 1],
                             c=val_gene_codes,
                             cmap=create_color_map(len(val_df['gene'].cat.categories)),
                             s=25, alpha=0.5)
        ax.set_title('PCA of Validation Latents', fontsize=40)
        ax.set_xlabel('PCA Component 1', fontsize=20)
        ax.set_ylabel('PCA Component 2', fontsize=20)
        cbar = fig.colorbar(scatter, ax=ax)
        cbar.set_ticks(range(len(val_df['gene'].cat.categories)))
        cbar.set_ticklabels(val_df['gene'].cat.categories, fontsize=20)

        print(f"Unique labels in training set: {np.unique(train_df['gene'])}")
        print(f"Unique labels in validation set: {np.unique(val_df['gene'])}")

        # Adjust layout to prevent overlap
        plt.tight_layout()

        # Step 5: save the plot in the output directory
        plt.savefig(os.path.join(self.config['output_dir'], 'pca_plot.png'))
        plt.close(fig)  # Close the figure to free up memory
    def create_umap_plots(self, train_latents, val_latents):
        # Step 0: split the DataFrames into label columns and latent dimensions
        label_cols = ['gene', 'barcode', 'stage', 'cell_idx']
        train_df = train_latents[['gene', 'barcode', 'stage']].copy()
        val_df = val_latents[['gene', 'barcode', 'stage']].copy()
        train_latents = train_latents.drop(columns=label_cols, errors='ignore')
        val_latents = val_latents.drop(columns=label_cols, errors='ignore')

        # Scale the latents (fit on train, apply to val)
        scaler = StandardScaler()
        train_latents = scaler.fit_transform(train_latents)
        val_latents = scaler.transform(val_latents)

        # Initialize UMAP
        umap_reducer = umap.UMAP(n_neighbors=15, min_dist=0.1, n_components=2, random_state=42)

        # Fit on the training data and embed the validation data with the same model
        train_latents_umap = umap_reducer.fit_transform(train_latents)
        val_latents_umap = umap_reducer.transform(val_latents)

        fig, axes = plt.subplots(1, 2, figsize=(25, 10))

        def create_color_map(n):
            return ListedColormap(plt.cm.viridis(np.linspace(0, 1, n)))

        # Convert 'gene' to categorical and get integer codes for colouring
        train_df['gene'] = pd.Categorical(train_df['gene'])
        val_df['gene'] = pd.Categorical(val_df['gene'])
        train_gene_codes = train_df['gene'].cat.codes
        val_gene_codes = val_df['gene'].cat.codes

        # Plot UMAP for the training set
        ax = axes[0]
        scatter = ax.scatter(train_latents_umap[:, 0], train_latents_umap[:, 1],
                             c=train_gene_codes,
                             cmap=create_color_map(len(train_df['gene'].cat.categories)),
                             s=25, alpha=0.5)
        ax.set_title('UMAP of Training Latents', fontsize=40)
        ax.set_xlabel('UMAP Component 1', fontsize=20)
        ax.set_ylabel('UMAP Component 2', fontsize=20)
        cbar = fig.colorbar(scatter, ax=ax)
        cbar.set_ticks(range(len(train_df['gene'].cat.categories)))
        cbar.set_ticklabels(train_df['gene'].cat.categories, fontsize=20)

        # Plot UMAP for the validation set
        ax = axes[1]
        scatter = ax.scatter(val_latents_umap[:, 0], val_latents_umap[:, 1],
                             c=val_gene_codes,
                             cmap=create_color_map(len(val_df['gene'].cat.categories)),
                             s=25, alpha=0.5)
        ax.set_title('UMAP of Validation Latents', fontsize=40)
        ax.set_xlabel('UMAP Component 1', fontsize=20)
        ax.set_ylabel('UMAP Component 2', fontsize=20)
        cbar = fig.colorbar(scatter, ax=ax)
        cbar.set_ticks(range(len(val_df['gene'].cat.categories)))
        cbar.set_ticklabels(val_df['gene'].cat.categories, fontsize=20)

        # Adjust layout to prevent overlap
        plt.tight_layout()

        # Save the plot in the output directory
        plt.savefig(os.path.join(self.config['output_dir'], 'umap_plot.png'))
        plt.close(fig)  # Close the figure to free up memory

    # Random forest classifier on the latent space, predicting the 'gene' label
    def classifier(self, train_latents, val_latents):
        from sklearn.ensemble import RandomForestClassifier
        from sklearn.metrics import accuracy_score, confusion_matrix

        # Step 0: split the DataFrames into label columns and latent dimensions
        label_cols = ['gene', 'barcode', 'stage', 'cell_idx']
        train_df = train_latents[['gene', 'barcode', 'stage']].copy()
        val_df = val_latents[['gene', 'barcode', 'stage']].copy()
        train_latents = train_latents.drop(columns=label_cols, errors='ignore')
        val_latents = val_latents.drop(columns=label_cols, errors='ignore')

        # Scale the latents (fit on train, apply to val)
        scaler = StandardScaler()
        train_latents = scaler.fit_transform(train_latents)
        val_latents = scaler.transform(val_latents)

        # Initialize the Random Forest Classifier and fit it on the training latents
        clf = RandomForestClassifier(n_estimators=100, random_state=42)
        clf.fit(train_latents, train_df['gene'])

        # Predict the labels for the validation data
        val_predictions = clf.predict(val_latents)

        # Calculate validation accuracy and the confusion matrix
        accuracy = accuracy_score(val_df['gene'], val_predictions)
        cm = confusion_matrix(val_df['gene'], val_predictions)

        # Convert 'gene' to categorical so its sorted categories can label the matrix axes
        val_df['gene'] = pd.Categorical(val_df['gene'])

        # Plot and save the confusion matrix
        plt.figure()
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                    xticklabels=val_df['gene'].cat.categories,
                    yticklabels=val_df['gene'].cat.categories)
        plt.title('Confusion Matrix', fontsize=20)
        plt.xlabel('Predicted Labels', fontsize=15)
        plt.ylabel('True Labels', fontsize=15)
        plt.tight_layout()
        plt.savefig(os.path.join(self.config['output_dir'], 'rf_confusion_matrix.png'))
        plt.close()

        return accuracy


def parse_args():
    parser = argparse.ArgumentParser(description="Model Evaluation Script")
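
Aside: because every evaluated checkpoint now writes a results.csv into its own output directory, the runs can be compared afterwards with a small aggregation script along these lines. This is not part of this commit; the latent_root path is assumed from generate_config in evaluate_md.py.

import os
import pandas as pd

latent_root = '/home/S-md/embed_time/scripts/latent'  # output_dir parent used in generate_config (assumed)

rows = []
for name in sorted(os.listdir(latent_root)):
    results_path = os.path.join(latent_root, name, 'results.csv')
    if os.path.exists(results_path):
        row = pd.read_csv(results_path)      # single-row summary written by ModelEvaluator
        row.insert(0, 'checkpoint', name)    # keep track of which run it came from
        rows.append(row)

summary = pd.concat(rows, ignore_index=True)
print(summary.sort_values('classification_accuracy', ascending=False).head())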
