
updated script headers
Daniel Wong committed Jan 4, 2022
1 parent 98a1844 commit a5c9ea6
Showing 6 changed files with 88 additions and 62 deletions.
3 changes: 3 additions & 0 deletions core.py
@@ -1,3 +1,6 @@
"""
Contains much of the baseline code for the study and helper functions
"""
import csv
import glob, os
import cv2
2 changes: 1 addition & 1 deletion detect.py
@@ -1,5 +1,5 @@
"""
https://github.com/eriklindernoren/PyTorch-YOLOv3 with small edits
pulled from https://github.com/eriklindernoren/PyTorch-YOLOv3 with small edits
"""
from __future__ import division
from models import *
66 changes: 32 additions & 34 deletions prospective.py
@@ -1,5 +1,7 @@
"""
Script to run analyses on the prospective validation with four neuropathologists
Analyses for:
Model version 1 and version 2 benchmarking
The prospective validation with four neuropathologists
"""
from __future__ import division
from models import *
@@ -75,7 +77,6 @@ def runModelOnValidationImages(val_type="prospective"):
save_path = path.replace("prospective_validation_images/", "")
else:
save_path = path.replace("data/amyloid_test/", "")
# print("(%d) Image: '%s'" % (img_i, path))
validation_predictions_dict[save_path] = []
img = np.array(Image.open(path))
if detections is None:
@@ -389,7 +390,6 @@ def getInterraterAgreement(iou_threshold=0.50):
for a2 in annotators:
if a1 != a2 and (a1, a2) not in pairs and (a2, a1) not in pairs:
pairs.append((a1, a2))
print(pairs)
pair_map = {pair: {amyloid_class: -1 for amyloid_class in ["Cored", "CAA"]} for pair in pairs}
for annotator1, annotator2 in pairs:

@@ -437,7 +437,6 @@ def getInterraterAgreement(iou_threshold=0.50):
assert(a1_cored_count + a2_cored_count - overlaps_counts["Cored"] == len(a1_final_annotations["Cored"]) == len(a2_final_annotations["Cored"]))
assert(a1_CAA_count + a2_CAA_count - overlaps_counts["CAA"] == len(a1_final_annotations["CAA"]) == len(a2_final_annotations["CAA"]) )
for amyloid_class in ["Cored", "CAA"]:
print(" {}: ".format(amyloid_class), getAccuracy(a1_final_annotations[amyloid_class], a2_final_annotations[amyloid_class]))
pair_map[(annotator1, annotator2)][amyloid_class] = getAccuracy(a1_final_annotations[amyloid_class], a2_final_annotations[amyloid_class])
pickle.dump(pair_map, open("pickles/annotator_interrater_map_iou_{}.pkl".format(iou_threshold), "wb"))
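
The agreement and benchmark functions in this file key off a box-level IOU threshold when matching one annotator's boxes against another's. A minimal sketch of the standard intersection-over-union check they presumably rely on (the helper name and the (x1, y1, x2, y2) box format are assumptions for illustration, not the repository's actual implementation):

```python
def box_iou(box_a, box_b):
    """Intersection-over-union of two axis-aligned boxes given as (x1, y1, x2, y2)."""
    ix1, iy1 = max(box_a[0], box_b[0]), max(box_a[1], box_b[1])
    ix2, iy2 = min(box_a[2], box_b[2]), min(box_a[3], box_b[3])
    inter = max(0, ix2 - ix1) * max(0, iy2 - iy1)   # overlap area, 0 if the boxes are disjoint
    area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0

# e.g. box_iou((0, 0, 10, 10), (5, 5, 15, 15)) ≈ 0.14, below a 0.5 IOU threshold
```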

@@ -446,7 +445,6 @@ def plotInterraterAgreement(iou_threshold=0.5):
Plots a heatmap from annotator_interrater_map.pkl
"""
pair_map = pickle.load(open("pickles/annotator_interrater_map_iou_{}.pkl".format(iou_threshold), "rb"))
print(pair_map)
annotators = ["NP{}".format(i) for i in range(1, 5)]
for amyloid_class in ["Cored", "CAA"]:
grid = []
@@ -461,6 +459,8 @@
except:
l.append(pair_map[(a2, a1)][amyloid_class])
grid.append(l)

print("Average agreement for {}:{}, std:{}".format(amyloid_class, np.mean([x for x in l if x != 1.0]),np.std([x for x in l if x != 1.0])))
fig, ax = plt.subplots()
im = ax.imshow(grid,vmin=0, vmax=1)
ax.set_xticks(np.arange(len(annotators)))
@@ -643,7 +643,6 @@ def plotTimeChart(iou_threshold=0.5):
for annotator in annotators:
x = time_map[annotator]
y = AP_map[annotator][amyloid_class][iou_threshold]
print(amyloid_class, x, y)
marker = "$*$" if amyloid_class == "Cored" else "$@$"
if amyloid_class == "Cored":
ax.scatter(x, y, s=120, marker=marker, color=color_dict[annotator], label=annotator)
@@ -746,7 +745,6 @@ def plotImageComparisons(val_type="prospective", overlay_labels=True, overlay_pr
y2 = int(dictionary['y2'])
color = (0,0,0)
cv2.rectangle(img, (x1, y1), (x2, y2), color, 3)
# cv2.putText(img, class_label, (x1,y1), font, 1.5,(0,0,0),2,cv2.LINE_AA)
cv2.putText(img, class_symbols[class_label], (x1,y1), font, 1.5,(0,0,0),2,cv2.LINE_AA)
cv2.imwrite("output/{}/".format(annotator) + val_type + "_" + img_name, img)

@@ -886,36 +884,36 @@ def createMergedOrConsensusBenchmark(benchmark="consensus", iou_threshold=0.5):
shutil.rmtree("output/")
os.mkdir("output/")

#analysis of model v1 and v2 predictions
for phase in ["phase1", "phase2"]:
convertPreProspectiveAnnotationsToPickle(phase=phase)
runModelOnValidationImages(val_type=phase)
for iou_threshold in np.arange(0.1, 1.0, 0.1):
compareAnnotationsToPredictions(iou_threshold=iou_threshold, annotator=phase, val_type=phase)
plotPRC(annotator=phase, val_type=phase, separate_legend=False)
plotImageComparisons(val_type=phase, overlay_labels=True, overlay_predictions=True)
plotAPsForPhases()

#prospective validation of model v2 predictions
runModelOnValidationImages(val_type="prospective")
for iou_threshold in np.arange(.1, 1, .1):
createMergedOrConsensusBenchmark(benchmark="consensus", iou_threshold=iou_threshold)
for annotator in ["consensus"] + ["NP{}".format(i) for i in range(1, 5)]:
for iou_threshold in np.arange(0.1, 1.0, 0.1):
compareAnnotationsToPredictions(iou_threshold=iou_threshold, annotator=annotator, val_type="prospective")
plotPRC(annotator=annotator, val_type="prospective")
getAnnotationOverlaps(annotator, iou_threshold=0.05)
getPrecisionsOfAnnotatorsRelativeToEachOther()
plotPrecisionsOfAnnotatorsRelativeToEachOther(plotType="aggregate")
plotAPsForProspective(plotAvgOverlay=True)
plotImageComparisons(val_type="prospective", overlay_labels=True, overlay_predictions=True)
# #analysis of model v1 and v2 predictions
# for phase in ["phase1", "phase2"]:
# convertPreProspectiveAnnotationsToPickle(phase=phase)
# runModelOnValidationImages(val_type=phase)
# for iou_threshold in np.arange(0.1, 1.0, 0.1):
# compareAnnotationsToPredictions(iou_threshold=iou_threshold, annotator=phase, val_type=phase)
# plotPRC(annotator=phase, val_type=phase, separate_legend=False)
# plotImageComparisons(val_type=phase, overlay_labels=True, overlay_predictions=True)
# plotAPsForPhases()

# #prospective validation of model v2 predictions
# runModelOnValidationImages(val_type="prospective")
# for iou_threshold in np.arange(.1, 1, .1):
# createMergedOrConsensusBenchmark(benchmark="consensus", iou_threshold=iou_threshold)
# for annotator in ["consensus"] + ["NP{}".format(i) for i in range(1, 5)]:
# for iou_threshold in np.arange(0.1, 1.0, 0.1):
# compareAnnotationsToPredictions(iou_threshold=iou_threshold, annotator=annotator, val_type="prospective")
# plotPRC(annotator=annotator, val_type="prospective")
# getAnnotationOverlaps(annotator, iou_threshold=0.05)
# getPrecisionsOfAnnotatorsRelativeToEachOther()
# plotPrecisionsOfAnnotatorsRelativeToEachOther(plotType="aggregate")
# plotAPsForProspective(plotAvgOverlay=True)
# plotImageComparisons(val_type="prospective", overlay_labels=True, overlay_predictions=True)

##other analyses
findLowPerformanceImages("Cored", "consensus", iou_threshold=0.5)
plotAllAnnotations()
getInterraterAgreement(iou_threshold=0.5)
# findLowPerformanceImages("Cored", "consensus", iou_threshold=0.5)
# plotAllAnnotations()
# getInterraterAgreement(iou_threshold=0.5)
plotInterraterAgreement()
plotTimeChart()
# plotTimeChart()



8 changes: 3 additions & 5 deletions train.py
@@ -1,5 +1,7 @@
"""
pulled from https://github.com/eriklindernoren/PyTorch-YOLOv3 with small edits
"""
from __future__ import division

from models import *
from utils.logger import *
from utils.utils import *
@@ -8,24 +10,20 @@
from utils.transforms import *
from utils.parse_config import *
from test import evaluate

from terminaltables import AsciiTable

import os
import sys
import time
import datetime
import argparse
import tqdm

import torch
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from torch.autograd import Variable
import torch.optim as optim


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--epochs", type=int, default=100, help="number of epochs")
3 changes: 3 additions & 0 deletions unit_test.py
@@ -1,3 +1,6 @@
"""
Unit tests for key functions and analyses
"""
from __future__ import division
from models import *
from utils.utils import *
68 changes: 46 additions & 22 deletions validation.py
@@ -1,15 +1,5 @@
"""
Script pertaining to the validation of the study after all model training
For Lise's dataset:
consensus annotations: /srv/home/lminaud/tile_seg/consensus_csv/consensus_experts/consensus_2_complete.csv
WSIs:
UCI - 24: .svs rescanned: /srv/nas/mk2/projects/alzheimers/images/UCI/UCI_svs/
UCLA/UCD: /srv/nas/mk2/projects/alzheimers/images/ADBrain/
WSIs renamed with random ID: /srv/nas/mk2/projects/alzheimers/images/processed/processed_wsi/svs/
1536 tiles:
/srv/home/lminaud/tiles_backup/
The csv name is: CAA_img_Daniel_project.csv
image_details: /srv/home/lminaud/tile_seg/image_details.csv
Script pertaining to CERAD-like analysis and speed runs
"""
from __future__ import division
from models import *
Expand All @@ -34,6 +24,7 @@
import pickle
import socket
from scipy.stats import ttest_ind
import statsmodels.stats.power as smp
from core import *

def calculatePlaqueCountsPerWSI(task, save_images=False):
@@ -236,13 +227,15 @@ def plotCERADStatisticalSignificance(plaque_type="Cored"):
print("{} not found in WSI plaque counts dictionary".format(WSI_name))
continue
cerad_scores_map[row["CERAD"]].append(WSI_plaque_counts[WSI_name][plaque_type])
print(cerad_scores_map)
t_test_map = {(cat1, cat2): -1 for cat1 in categories for cat2 in categories} #key: (CERAD category1, CERAD category2), value: (t-statistic, p-value)
grid = []
for key in cerad_scores_map:
l = []
for key2 in cerad_scores_map:
t, p = ttest_ind(cerad_scores_map[key], cerad_scores_map[key2])
effect_size = (np.mean(cerad_scores_map[key]) - np.mean(cerad_scores_map[key2])) / float(np.sqrt((np.std(cerad_scores_map[key])**2 + np.std(cerad_scores_map[key2])**2) / float(2))) ##Cohen's d
nobs = len(cerad_scores_map[key]) + len(cerad_scores_map[key2])
power = smp.ttest_power(effect_size, nobs=nobs, alpha=0.05, alternative='two-sided')
t_test_map[key, key2] = float(t), float(p)
l.append(float(p))
grid.append(l)
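
The hunk above pairs each t-test with a pooled-standard-deviation Cohen's d and a post-hoc power estimate from statsmodels. A standalone sketch of that calculation, using made-up plaque counts purely for illustration:

```python
import numpy as np
import statsmodels.stats.power as smp
from scipy.stats import ttest_ind

# hypothetical plaque counts for two CERAD categories (illustrative values only)
group_a = np.array([12, 18, 25, 31, 22], dtype=float)
group_b = np.array([45, 52, 38, 61, 49], dtype=float)

t, p = ttest_ind(group_a, group_b)
# Cohen's d with a pooled standard deviation, mirroring the hunk above
pooled_sd = np.sqrt((np.std(group_a) ** 2 + np.std(group_b) ** 2) / 2.0)
effect_size = (np.mean(group_a) - np.mean(group_b)) / pooled_sd
# post-hoc power of a two-sided t-test at alpha = 0.05 for this effect size
power = smp.ttest_power(effect_size, nobs=len(group_a) + len(group_b),
                        alpha=0.05, alternative='two-sided')
print(t, p, effect_size, power)
```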
@@ -262,15 +255,12 @@ def plotCERADStatisticalSignificance(plaque_type="Cored"):
text = ax.text(j, i, "{:.2e}".format(grid[i][j]), ha="center", va="center", color="white", fontsize=11)
else:
text = ax.text(j, i, str(round(grid[i][j], 3)), ha="center", va="center", color="white", fontsize=11)

cbar = ax.figure.colorbar(im, ax=ax)
cbar.ax.tick_params(labelsize=11)
fig.tight_layout()
ax.set_title("t-test p-values", fontsize=12)
plt.savefig("figures/CERAD-t-test-p-values.png", dpi=300)



def getStain(string):
"""
Given string, will return the stain
@@ -373,9 +363,43 @@ def speedCheck(use_gpu=True, include_merge_and_filter=True):
print("model time spent: ", model_time_spent)
print("avg time per WSI: ", model_time_spent / float(len(WSI_directories)))
print("avg time per 1536 image: ", model_time_spent / float(num_1536))
pickle.dump(time_dict, open("pickles/run_times_use_gpu_{}_{}.pkl".format(use_gpu, hostname), "wb"))

pickle.dump(time_dict, open("pickles/run_times_use_gpu_{}_merge_and_filter_{}_{}.pkl".format(use_gpu, include_merge_and_filter, hostname), "wb"))

def calculateAvgSpeedOfTangSlidingWindow():
"""
As first described in https://www.nature.com/articles/s41467-019-10212-1,
one valid approach to counting the number of Cored and CAA pathologies is to perform a sliding window approach, and then segment the resulting heatmap:
https://github.com/keiserlab/plaquebox-paper/blob/master/3)%20Visualization%20-%20Prediction%20Confidence%20Heatmaps.ipynb
This method calculates and prints the average time to draw a heatmap per WSI according to the tqdm output
"""
tqdms = ["28/28 [2:24:17<00:00, 309.19s/it]",
"49/49 [8:55:56<00:00, 656.25s/it]",
"28/28 [2:40:59<00:00, 345.00s/it]",
"28/28 [2:45:04<00:00, 353.73s/it]",
"28/28 [2:45:54<00:00, 355.52s/it]",
"28/28 [2:33:49<00:00, 329.63s/it]",
"27/27 [2:35:41<00:00, 345.99s/it]",
"26/26 [3:13:59<00:00, 447.68s/it]",
"28/28 [3:08:24<00:00, 403.74s/it]",
"21/21 [2:04:40<00:00, 356.22s/it]",
"25/25 [2:35:53<00:00, 374.13s/it]",
"27/27 [2:26:48<00:00, 326.22s/it]",
"31/31 [3:14:01<00:00, 375.53s/it]",
"28/28 [3:23:10<00:00, 435.36s/it]",
"26/26 [2:35:15<00:00, 358.28s/it]",
"28/28 [2:47:33<00:00, 359.04s/it]",
"28/28 [2:43:00<00:00, 349.29s/it]",
"31/31 [3:25:29<00:00, 397.72s/it]",
"25/25 [2:08:55<00:00, 309.42s/it]",
"28/28 [2:56:10<00:00, 377.54s/it]"]
total_seconds = 0
for tqdm in tqdms:
time = tqdm[tqdm.find("[") + 1:tqdm.find("<")]
hours, minutes, seconds = time.split(":")
total_seconds += float(hours)*60*60 + float(minutes)*60 + float(seconds)
avg_seconds = total_seconds / float(len(tqdms))
print(avg_seconds, len(tqdms))
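
For context on the comparison being timed here: the referenced plaquebox-paper pipeline scores a CNN over a sliding window and then segments the resulting confidence heatmap. Below is a minimal sketch of the generic sliding-window scoring idea only, with a placeholder mean-intensity scorer standing in for the trained classifier; the window size and stride are arbitrary and not taken from the cited work:

```python
import numpy as np

def sliding_window_heatmap(image, window=256, stride=128, score_fn=None):
    """Score a fixed-size window at each stride over a 2-D image and return the grid of scores."""
    if score_fn is None:
        # placeholder scorer; the cited pipeline uses a trained CNN's class confidence here
        score_fn = lambda patch: float(patch.mean())
    rows = (image.shape[0] - window) // stride + 1
    cols = (image.shape[1] - window) // stride + 1
    heatmap = np.zeros((rows, cols), dtype=np.float32)
    for i in range(rows):
        for j in range(cols):
            patch = image[i * stride:i * stride + window,
                          j * stride:j * stride + window]
            heatmap[i, j] = score_fn(patch)
    return heatmap

# e.g. heatmap = sliding_window_heatmap(np.random.rand(1536, 1536))
```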




@@ -385,13 +409,13 @@ def speedCheck(use_gpu=True, include_merge_and_filter=True):
# comparePreMergeLabelsWithPostMerge(sample_size=100)
# calculatePlaqueCountsPerWSI(task="CERAD all", save_images=False)
# calculatePlaqueCountsPerWSI(task="lise dataset")
plotCERADVsCounts(plaque_type = "Cored", CERAD_type="CERAD")
# plotCERADVsCounts(plaque_type = "Cored", CERAD_type="CERAD")
# plotCERADVsCounts(plaque_type = "Cored", CERAD_type="Cored_MTG")
# plotCERADVsCounts(plaque_type = "CAA", CERAD_type="CAA_MTG")
plotCERADStatisticalSignificance()
# speedCheck(use_gpu=True)
# speedCheck(use_gpu=False)

# plotCERADStatisticalSignificance()
speedCheck(use_gpu=True, include_merge_and_filter=True)
speedCheck(use_gpu=False, include_merge_and_filter=True)
# calculateAvgSpeedOfTangSlidingWindow()



