
Commit 3846d08

update scripts
1 parent 0c74fd4 commit 3846d08

12 files changed (+1431, -158 lines)

experiments/evaluate_clustering/evaluate_clustering.py

Lines changed: 92 additions & 22 deletions
@@ -60,7 +60,7 @@ def stashcsv(df, name, **kws):
     savecsv(df, name, pathname=save_path / "outs", **kws)
 
 
-def sort_mg(mg, level_names, class_order=CLASS_ORDER):
+def sort_mg(mg, level_names, class_order=CLASS_ORDER, ascending=True):
     """Required sorting prior to plotting the dendrograms
 
     Parameters
@@ -85,7 +85,7 @@ def sort_mg(mg, level_names, class_order=CLASS_ORDER):
             meta[f"{sc}_{co}_order"] = meta[sc].map(class_value)
             total_sort_by.append(f"{sc}_{co}_order")
         total_sort_by.append(sc)
-    mg = mg.sort_values(total_sort_by, ascending=False)
+    mg = mg.sort_values(total_sort_by, ascending=ascending)  # TODO used to be False!
     return mg
 
 
@@ -96,13 +96,14 @@ def plot_adjacencies(full_mg, axs, lowest_level=7):
     for level in np.arange(lowest_level + 1):
         ax = axs[0, level]
         adj = binarize(full_mg.adj)
+        # [f"lvl{level}_labels", f"merge_class_sf_order", "merge_class"]
         _, _, top, _ = adjplot(
             adj,
             ax=ax,
             plot_type="scattermap",
             sizes=(0.5, 0.5),
             sort_class=level_names[: level + 1],
-            item_order=[f"{CLASS_KEY}_{CLASS_ORDER}_order", CLASS_KEY, CLASS_ORDER],
+            # item_order=[f"{CLASS_KEY}_{CLASS_ORDER}_order", CLASS_KEY, CLASS_ORDER],
             class_order=CLASS_ORDER,
             meta=full_mg.meta,
             palette=CLASS_COLOR_DICT,
@@ -495,37 +496,106 @@ def plot_clustering_results(
 n_init = 256
 max_hops = 16
 allow_loops = False
+include_reverse = False
 walk_spec = f"gt={graph_type}-n_init={n_init}-hops={max_hops}-loops={allow_loops}"
 walk_meta = pd.read_csv(
-    f"maggot_models/experiments/walk_sort/outs/meta_w_order-{walk_spec}.csv",
+    f"maggot_models/experiments/walk_sort/outs/meta_w_order-{walk_spec}-include_reverse={include_reverse}.csv",
     index_col=0,
 )
 meta["median_node_visits"] = walk_meta["median_node_visits"]
 
 # %%
 # plot results
 lowest_level = 7  # last level to show for dendrograms, adjacencies
-plot_clustering_results(
-    adj,
-    meta,
-    basename,
-    lowest_level=lowest_level,
-    show_adjs=True,
-    show_singles=False,
-    make_flippable=False,
-)
+# plot_clustering_results(
+#     adj,
+#     meta,
+#     basename,
+#     lowest_level=lowest_level,
+#     show_adjs=True,
+#     show_singles=False,
+#     make_flippable=False,
+# )
 
 #%%
-# lowest_level = 7
-# mg = MetaGraph(adj, meta)
-# level_names = [f"lvl{i}_labels" for i in range(lowest_level + 1)]
-# mg = sort_mg(mg, level_names)
-# fig, axs = plt.subplots(
-#     2, lowest_level + 1, figsize=10 * np.array([lowest_level + 1, 2])
-# )
+lowest_level = 7
+mg = MetaGraph(adj, meta)
+level_names = [f"lvl{i}_labels" for i in range(lowest_level + 1)]
+mg = sort_mg(mg, level_names)
+fig, axs = plt.subplots(
+    2, lowest_level + 1, figsize=10 * np.array([lowest_level + 1, 2])
+)
 # for level in np.arange(lowest_level + 1):
-# plot_adjacencies(mg, axs, lowest_level=lowest_level)
-# stashfig(f"adjplots-lowest={lowest_level}" + basename, fmt="png")
+plot_adjacencies(mg, axs, lowest_level=lowest_level)
+stashfig(f"adjplots-lowest={lowest_level}" + basename, fmt="png")
+#%%
+from matplotlib.colors import ListedColormap
+
+sort_meta = mg.meta.copy()
+fig, axs = plt.subplots(
+    1, 2 * (lowest_level + 1), figsize=(10, 10), gridspec_kw=dict(wspace=0)
+)
+
+# meta = mg.meta
+# sort_class = level_names + ["merge_class"]
+# class_order = [class_order]
+# total_sort_by = []
+# for sc in sort_class:
+#     for co in class_order:
+#         class_value = meta.groupby(sc)[co].mean()
+#         meta[f"{sc}_{co}_order"] = meta[sc].map(class_value)
+#         total_sort_by.append(f"{sc}_{co}_order")
+#     total_sort_by.append(sc)
+# mg = mg.sort_values(total_sort_by, ascending=False)
+
+
+for level in np.arange(lowest_level + 1)[::-1]:
+    # sort_meta = sort_meta.sort_values(
+    #     [
+    #         f"lvl{level}_labels_{CLASS_ORDER}_order",
+    #         f"lvl{level}_labels",
+    #         f"{CLASS_KEY}_{CLASS_ORDER}_order",
+    #         CLASS_KEY,
+    #     ],
+    #     ascending=True,
+    # )
+    sort_meta["inds"] = range(len(sort_meta))
+    firsts = sort_meta.groupby(f"lvl{level}_labels", sort=False)["inds"].first()
+
+    # mean_visits = sort_meta.groupby(
+    #     [
+    #         f"lvl{level}_labels",
+    #         f"{CLASS_KEY}_{CLASS_ORDER}_order",
+    #     ]
+    # )["median_node_visit"].mean()
+    # meta.groupby([leaf_key, "merge_class"], sort=False).size()
+
+    sort_meta[CLASS_KEY].values
+    color_dict = CLASS_COLOR_DICT
+    classes = sort_meta["merge_class"].values
+    uni_classes = np.unique(sort_meta["merge_class"])
+    class_map = dict(zip(uni_classes, range(len(uni_classes))))
+    color_sorted = np.vectorize(color_dict.get)(uni_classes)
+    lc = ListedColormap(color_sorted)
+    class_indicator = np.vectorize(class_map.get)(classes)
+    class_indicator = class_indicator.reshape(len(classes), 1)
+    ax = axs[2 * level + 1]
+    sns.heatmap(
+        class_indicator,
+        cmap=lc,
+        cbar=False,
+        yticklabels=False,
+        # xticklabels=False,
+        square=False,
+        ax=ax,
+    )
+    ax.set(xlabel=level, xticks=[])
+
+    ax = axs[2 * level]
+    ax.axis("off")
+    ax.set(ylim=axs[2 * level + 1].get_ylim())
+    for first_ind in firsts:
+        ax.axhline(first_ind, color="grey", linestyle="--", alpha=1, linewidth=1)
 
 # %% [markdown]
 # # ##
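Aside: the new per-level loop builds a categorical color strip by hand: class labels are mapped to integer indices, the per-class colors are packed into a ListedColormap, and the indices are drawn as a one-column heatmap. Below is a minimal standalone sketch of that trick; the class names and colors are made up for illustration and stand in for CLASS_COLOR_DICT and sort_meta["merge_class"].

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.colors import ListedColormap

classes = np.array(["KC", "KC", "MBON", "PN", "PN", "PN"])  # toy class labels
color_dict = {"KC": "tab:blue", "MBON": "tab:orange", "PN": "tab:green"}  # toy colors

uni_classes = np.unique(classes)  # sorted unique labels
class_map = dict(zip(uni_classes, range(len(uni_classes))))  # label -> integer index
lc = ListedColormap([color_dict[c] for c in uni_classes])  # colors in label order

class_indicator = np.vectorize(class_map.get)(classes).reshape(-1, 1)  # one column of ints

fig, ax = plt.subplots(figsize=(1, 4))
sns.heatmap(class_indicator, cmap=lc, cbar=False, xticklabels=False, yticklabels=False, ax=ax)
plt.show()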

experiments/nblast/nblast.py

Lines changed: 122 additions & 0 deletions
@@ -0,0 +1,122 @@
+#%%
+import logging
+import time
+
+import numpy as np
+import pandas as pd
+import pymaid
+from sklearn.preprocessing import QuantileTransformer
+from pathlib import Path
+
+from graspologic.utils import symmetrize
+from navis import NeuronList, TreeNeuron, nblast_allbyall
+from src.data import load_metagraph
+from src.pymaid import start_instance
+
+# REF: https://stackoverflow.com/questions/35326814/change-level-logged-to-ipython-jupyter-notebook
+# logger = logging.getLogger()
+# # assert len(logger.handlers) == 1
+# handler = logger.handlers[0]
+# handler.setLevel(logging.ERROR)
+
+t0 = time.time()
+
+# for pymaid to pull neurons
+start_instance()
+
+out_dir = Path("maggot_models/experiments/nblast/outs")
+
+#%% load connectivity data
+mg = load_metagraph("G")
+meta = mg.meta
+
+#%% define some functions
+
+
+def pairwise_nblast(neuron_ids, point_thresh=5):
+    neuron_ids = [int(n) for n in neuron_ids]
+    neurons = pymaid.get_neuron(neuron_ids)  # load in with pymaid
+
+    # HACK: I am guessing there is a better way to do the below?
+    # TODO: I was also getting some errors about neurons with more than one soma, so I
+    # threw them out for now.
+    treenode_tables = []
+    for neuron_id, neuron in zip(neuron_ids, neurons):
+        treenode_table = pymaid.get_treenode_table(neuron, include_details=False)
+        treenode_tables.append(treenode_table)
+
+    success_neurons = []
+    tree_neurons = []
+    for neuron_id, treenode_table in zip(neuron_ids, treenode_tables):
+        treenode_table.rename(columns={"parent_node_id": "parent_id"}, inplace=True)
+
+        tree_neuron = TreeNeuron(treenode_table)
+        if (tree_neuron.soma is not None) and (len(tree_neuron.soma) > 1):
+            print(f"Neuron {neuron_id} has more than one soma, removing")
+        elif len(treenode_table) < point_thresh:
+            print(f"Neuron {neuron_id} has fewer than {point_thresh} points, removing")
+        else:
+            tree_neurons.append(tree_neuron)
+            success_neurons.append(neuron_id)
+
+    tree_neurons = NeuronList(tree_neurons)
+    print(f"{len(tree_neurons)} neurons ready for NBLAST")
+
+    currtime = time.time()
+    # NOTE: I've had to modify the original code to allow smat=None
+    # NOTE: this only works when normalized=False also
+    scores = nblast_allbyall(tree_neurons, smat=None, normalized=False, progress=True)
+    print(f"{time.time() - currtime:.3f} elapsed to run NBLAST.")
+
+    scores = pd.DataFrame(
+        data=scores.values, index=success_neurons, columns=success_neurons
+    )
+
+    return scores
+
+
+def postprocess_nblast(scores):
+    distance = scores.values  # the raw nblast scores are dissimilarities/distances
+    sym_distance = symmetrize(distance)  # the raw scores are not symmetric
+    # make the distances between 0 and 1
+    sym_distance /= sym_distance.max()
+    sym_distance -= sym_distance.min()
+    # and then convert to similarity
+    morph_sim = 1 - sym_distance
+
+    # rank transform the similarities
+    # NOTE this is very different from what native NBLAST does and could likely be
+    # improved upon a lot. I did this because it seemed like a quick way of accounting
+    # for differences in scale for different neurons as well as the fact that the raw
+    # distribution of similarities was skewed low (very few small values)
+    quant = QuantileTransformer()
+    indices = np.triu_indices_from(morph_sim, k=1)
+    transformed_vals = quant.fit_transform(morph_sim[indices].reshape(-1, 1))
+    transformed_vals = np.squeeze(transformed_vals)
+    # this is a discrete version of PTR basically
+    ptr_morph_sim = np.ones_like(morph_sim)
+    ptr_morph_sim[indices] = transformed_vals
+    ptr_morph_sim[indices[::-1]] = transformed_vals
+
+    ptr_morph_sim = pd.DataFrame(
+        data=ptr_morph_sim, index=scores.index, columns=scores.columns
+    )
+
+    return ptr_morph_sim
+
+
+#%% run nblast
+for side in ["left", "right"]:
+    print(f"Processing side: {side}")
+    side_meta = meta[meta[side]]
+
+    scores = pairwise_nblast(side_meta.index.values)
+    scores.to_csv(out_dir / f"{side}-nblast-scores.csv")
+
+    similarity = postprocess_nblast(scores)
+    similarity.to_csv(out_dir / f"{side}-nblast-similarities.csv")
+    print()
+
+#%%
+print("\n\n")
+print(f"{time.time() - t0:.3f} elapsed for whole script.")

experiments/walk_sort/generate_walks.py

Lines changed: 1 addition & 1 deletion
@@ -102,7 +102,7 @@ def main(
         ("dVNC",),
         ("dSEZ",),
         ("RGN",),
-        ("m`otor-PaN", "motor-MN", "motor-AN", "motor-VAN"),
+        ("motor-PaN", "motor-MN", "motor-AN", "motor-VAN"),
     ]
     stop_names = ["dVNC", "dSEZ", "RGN", "motor"]
 
experiments/walk_sort/walk_sort.py

Lines changed: 21 additions & 18 deletions
@@ -54,6 +54,7 @@ def stashcsv(df, name, **kws):
 n_init = 256
 max_hops = 16
 allow_loops = False
+include_reverse = False
 walk_path = "maggot_models/experiments/walk_sort/outs/walks-"
 walk_spec = f"gt={graph_type}-n_init={n_init}-hops={max_hops}-loops={allow_loops}"
 forward_walk_path = walk_path + walk_spec + "-reverse=False" + ".txt"
@@ -71,15 +72,15 @@ def process_paths(walk_path):
     paths.remove("")
     print(f"# of paths after removing duplicates: {len(paths)}")
 
-    n_subsample = len(paths)  # 2 ** 14
-    choice_inds = np.random.choice(len(paths), n_subsample, replace=False)
-    new_paths = []
-    for i in range(len(paths)):
-        if i in choice_inds:
-            new_paths.append(paths[i])
-    paths = new_paths
+    # n_subsample = len(paths)  # 2 ** 14
+    # choice_inds = np.random.choice(len(paths), n_subsample, replace=False)
+    # new_paths = []
+    # for i in range(len(paths)):
+    #     if i in choice_inds:
+    #         new_paths.append(paths[i])
+    # paths = new_paths
 
-    print(f"# of paths after subsampling: {len(paths)}")
+    # print(f"# of paths after subsampling: {len(paths)}")
     paths = [path.split(" ") for path in paths]
     paths = [[int(node) for node in path] for path in paths]
     # all_nodes = set()
@@ -95,23 +96,23 @@ def process_paths(walk_path):
 
 
 # %%
-
+all_nodes = set()
 node_visits = {}
 for path in forward_paths:
     for i, node in enumerate(path):
         if node not in node_visits:
             node_visits[node] = []
         node_visits[node].append(i / (len(path) - 1))
+[[all_nodes.add(node) for node in path] for path in forward_paths]
 
-for path in backward_paths:
-    for i, node in enumerate(path):
-        if node not in node_visits:
-            node_visits[node] = []
-        node_visits[node].append(1 - (i / (len(path) - 1)))
+if include_reverse:
+    for path in backward_paths:
+        for i, node in enumerate(path):
+            if node not in node_visits:
+                node_visits[node] = []
+            node_visits[node].append(1 - (i / (len(path) - 1)))
+    [[all_nodes.add(node) for node in path] for path in backward_paths]
 
-all_nodes = set()
-[[all_nodes.add(node) for node in path] for path in forward_paths]
-[[all_nodes.add(node) for node in path] for path in backward_paths]
 uni_nodes = np.unique(list(all_nodes))
 
 median_node_visits = {}
@@ -128,7 +129,9 @@ def process_paths(walk_path):
     median_class_visits[node_class] = np.median(all_visits_flat)
 meta["median_class_visits"] = meta["merge_class"].map(median_class_visits)
 
-meta.to_csv(f"maggot_models/experiments/walk_sort/outs/meta_w_order-{walk_spec}.csv")
+meta.to_csv(
+    f"maggot_models/experiments/walk_sort/outs/meta_w_order-{walk_spec}-include_reverse={include_reverse}.csv"
+)
 
 print(f"# of nodes: {len(meta)}")
 unvisit_meta = meta[meta["median_node_visits"].isna()]
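The ordering statistic this script writes out works as follows: within each forward walk, a node visited at position i gets the fractional position i / (len(path) - 1) (0 = start of the walk, 1 = end of the walk), and each node is then summarized by the median over all of its visits; with include_reverse=False the backward walks no longer contribute. A toy sketch with invented walks, only to illustrate the arithmetic:

import numpy as np

forward_paths = [[1, 2, 3], [1, 3, 4, 5], [2, 3, 5]]  # made-up walks over node ids

node_visits = {}
for path in forward_paths:
    for i, node in enumerate(path):
        node_visits.setdefault(node, []).append(i / (len(path) - 1))  # fractional position

median_node_visits = {node: np.median(v) for node, v in node_visits.items()}
print(median_node_visits)
# node 1 is always visited first -> 0.0; node 5 is always visited last -> 1.0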
