Skip to content

Commit 482df6a

Browse files
committed
stash
1 parent 85c220f commit 482df6a

27 files changed

+34065
-6789
lines changed

data/process_scripts/process_maggot_brain_connectome_2022-11-03.py

Lines changed: 1059 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
#%%
2+
3+
import datetime
4+
import time
5+
6+
import numpy as np
7+
import pandas as pd
8+
from graspologic.utils import is_fully_connected
9+
from src.data import join_node_meta, load_maggot_graph
10+
11+
t0 = time.time()
12+
13+
#%%
14+
def join_index_as_indicator(index, name):
15+
series = pd.Series(index=index, data=np.ones(len(index), dtype=bool), name=name)
16+
join_node_meta(series, overwrite=True, fillna=False)
17+
18+
19+
#%%
20+
21+
mg = load_maggot_graph()
22+
nodes = mg.nodes.copy()
23+
24+
nodes = nodes[nodes["brain_and_inputs"] | nodes["accessory_neurons"]]
25+
print("Number of brain, input, and accessory nodes:", len(nodes))
26+
join_index_as_indicator(nodes.index, "considered")
27+
28+
nodes = nodes[~nodes["very_incomplete"]]
29+
nodes = nodes[~nodes["partially_differentiated"]]
30+
nodes = nodes[~nodes["motor"]]
31+
print(
32+
"Number of nodes after removing incomplete, partially differentiated, and motor neurons:",
33+
len(nodes),
34+
)
35+
join_index_as_indicator(nodes.index, "selected")
36+
37+
mg = mg.node_subgraph(nodes.index)
38+
39+
mg.to_largest_connected_component()
40+
41+
print("Number of nodes after taking LCC:", len(mg))
42+
43+
44+
print("Is fully connected:", is_fully_connected(mg.sum.adj))
45+
46+
join_index_as_indicator(mg.nodes.index, "selected_lcc")
47+
48+
#%%
49+
elapsed = time.time() - t0
50+
delta = datetime.timedelta(seconds=elapsed)
51+
print("----")
52+
print(f"Script took {delta}")
53+
print(f"Completed at {datetime.datetime.now()}")
54+
print("----")

experiments/flow/flow.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import pandas as pd
66

7-
from giskard.flow import rank_graph_match_flow, rank_signal_flow, signal_flow
7+
from giskard.flow import rank_signal_flow, signal_flow
88
from src.data import join_node_meta, load_maggot_graph
99

1010
t0 = time.time()
@@ -13,12 +13,15 @@
1313
#%%
1414
print("Loading data...")
1515
mg = load_maggot_graph()
16-
nodes = mg.nodes.copy()
16+
mg = mg.node_subgraph(mg.nodes[mg.nodes["selected_lcc"]].index)
1717
mg = mg.sum
18-
mg.to_largest_connected_component(verbose=True)
1918
index = mg.nodes.index
2019
adj = mg.adj
2120
meta = mg.nodes
21+
22+
# nodes = mg.nodes.copy()
23+
# mg.to_largest_connected_component(verbose=True)
24+
2225
# #%%
2326
# sort_meta = meta.copy()
2427
# sort_meta.sort_values(

experiments/flow_row/flow_row.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#%%
2+
import pymaid
23
import matplotlib.pyplot as plt
34
import numpy as np
45
import seaborn as sns
Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
#%%
2+
import pymaid
3+
import matplotlib.pyplot as plt
4+
import numpy as np
5+
import seaborn as sns
6+
from scipy.ndimage import gaussian_filter1d
7+
from src.data import load_maggot_graph, load_palette
8+
from src.io import savefig
9+
from src.visualization import adjplot, set_theme
10+
11+
set_theme(font_scale=1.25)
12+
13+
14+
def stashfig(name, **kws):
15+
savefig(
16+
name,
17+
pathname="./maggot_models/experiments/flow_row/figs",
18+
**kws,
19+
)
20+
savefig(
21+
name,
22+
pathname="./maggot_models/experiments/flow_row/figs",
23+
format="pdf",
24+
**kws,
25+
)
26+
27+
28+
def diag_indices(length, k=0):
29+
neg = False
30+
if k < 0:
31+
neg = True
32+
k = np.abs(k)
33+
inds = (np.arange(length - k), np.arange(k, length))
34+
if neg:
35+
return (inds[1], inds[0])
36+
else:
37+
return inds
38+
39+
40+
def calc_mean_by_k(ks, perm_adj):
41+
length = len(perm_adj)
42+
ps = []
43+
for k in ks:
44+
p = perm_adj[diag_indices(length, k)].mean()
45+
ps.append(p)
46+
return np.array(ps)
47+
48+
49+
def get_vals_by_k(ks, perm_adj):
50+
ys = []
51+
xs = []
52+
for k in ks:
53+
y = perm_adj[diag_indices(len(perm_adj), k)]
54+
ys.append(y)
55+
x = np.full(len(y), k)
56+
xs.append(x)
57+
return np.concatenate(ys), np.concatenate(xs)
58+
59+
60+
remove_missing = True
61+
mg = load_maggot_graph()
62+
63+
# %%
64+
graph_types = ["ad", "aa", "dd", "da"]
65+
graph_names = dict(
66+
zip(graph_types, [r"A $\to$ D", r"A $\to$ A", r"D $\to$ D", r"D $\to$ A"])
67+
)
68+
colors = sns.color_palette("deep", n_colors=len(graph_types))
69+
graph_type_colors = dict(zip(graph_types, colors))
70+
71+
#%%
72+
nodes = mg.nodes.copy()
73+
nodes = nodes[~nodes["sum_rank_signal_flow"].isna()]
74+
nodes.sort_values("sum_rank_signal_flow", inplace=True)
75+
nodes["order"] = range(len(nodes))
76+
mg = mg.node_subgraph(nodes.index)
77+
mg.nodes = nodes
78+
79+
80+
#%%
81+
82+
line_kws = dict(linewidth=1, linestyle="--", color="grey")
83+
84+
palette = load_palette()
85+
86+
87+
def plot_sorted_adj(graph_type, ax):
88+
adj = mg.to_edge_type_graph(graph_type).adj
89+
meta = mg.nodes
90+
_, _, top, _, = adjplot(
91+
adj,
92+
meta=meta,
93+
item_order="order",
94+
colors="simple_group",
95+
palette=palette,
96+
plot_type="scattermap",
97+
sizes=(0.5, 1),
98+
ax=ax,
99+
color=graph_type_colors[graph_type],
100+
)
101+
top.set_title(
102+
graph_names[graph_type], color=graph_type_colors[graph_type], fontsize="x-large"
103+
)
104+
ax.plot([0, len(adj)], [0, len(adj)], **line_kws)
105+
106+
107+
def plot_diag_vals(graph_type, ax, mode="values", sigma=25):
108+
adj = mg.to_edge_type_graph(graph_type).adj
109+
ks = np.arange(-len(adj) + 1, len(adj))
110+
vals = calc_mean_by_k(ks, adj)
111+
if mode == "values":
112+
sns.scatterplot(
113+
x=ks,
114+
y=vals,
115+
s=10,
116+
alpha=0.4,
117+
linewidth=0,
118+
ax=ax,
119+
color=graph_type_colors[graph_type],
120+
)
121+
elif mode == "kde":
122+
kde_vals = gaussian_filter1d(vals, sigma=sigma, mode="constant")
123+
sns.lineplot(x=ks, y=kde_vals, ax=ax, color=graph_type_colors[graph_type])
124+
upper_mass = adj[np.triu_indices_from(adj, k=1)].mean()
125+
lower_mass = adj[np.tril_indices_from(adj, k=1)].mean()
126+
upper_mass_prop = upper_mass / (upper_mass + lower_mass)
127+
lower_mass_prop = lower_mass / (upper_mass + lower_mass)
128+
upper_text = f"{upper_mass_prop:.2f}"
129+
lower_text = f"{lower_mass_prop:.2f}"
130+
ax.text(0.1, 0.8, lower_text, transform=ax.transAxes, color="black")
131+
ax.text(
132+
0.9,
133+
0.8,
134+
upper_text,
135+
ha="right",
136+
transform=ax.transAxes,
137+
color="black",
138+
)
139+
ax.axvline(0, **line_kws)
140+
ax.yaxis.set_major_locator(plt.NullLocator())
141+
ax.set_xticks([-3000, 0, 3000])
142+
ax.set_xticklabels([-3000, 0, 3000])
143+
144+
145+
fig, axs = plt.subplots(2, 4, figsize=(20, 10))
146+
for i, graph_type in enumerate(graph_types):
147+
plot_sorted_adj(graph_type, axs[0, i])
148+
plot_diag_vals(graph_type, axs[1, i], mode="kde", sigma=75)
149+
150+
ax = axs[1, 0]
151+
ax.text(0.1, 0.9, r"$p$ back", transform=ax.transAxes, color="black")
152+
ax.text(0.9, 0.9, r"$p$ fwd", transform=ax.transAxes, color="black", ha="right")
153+
fig.text(0.47, 0.05, "Distance in sorting")
154+
axs[1, 0].set_ylabel("Mean synapse mass")
155+
stashfig("adj-row-sort-by-sf")
156+
157+
#%%
158+
# from graspologic.match import GraphMatch
159+
160+
# adj = mg.sum.adj
161+
# # constructing the match matrix
162+
# match_mat = np.zeros_like(adj)
163+
# triu_inds = np.triu_indices(len(match_mat), k=1)
164+
# match_mat[triu_inds] = 1
165+
166+
# # running graph matching
167+
# np.random.seed(8888)
168+
# gm = GraphMatch(n_init=1, max_iter=100, eps=1e-6)
169+
# gm.fit(match_mat, adj)
170+
# perm_inds = gm.perm_inds_
171+
172+
# adj_matched = adj[perm_inds][:, perm_inds]
173+
# upsets = adj_matched[triu_inds[::-1]].sum()
174+
# upset_ration = upsets / adj_matched.sum()

experiments/gaussian_cluster/gaussian_cluster.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@
22
import datetime
33
import time
44
from ast import literal_eval
5-
from json import load
6-
from operator import index
75
from pathlib import Path
86

97
import matplotlib.pyplot as plt
@@ -55,8 +53,14 @@ def stashfig(name, **kws):
5553
#%% load the embedding, get the correct subset of data
5654

5755
mg = load_maggot_graph()
58-
mg = mg[mg.nodes["has_embedding"]]
56+
mg = mg.node_subgraph(mg.nodes.query("has_embedding").index)
5957
nodes = mg.nodes.copy()
58+
print(len(nodes))
59+
60+
# #%%
61+
# nodes.groupby("predicted_pair_id").size().sort_values()
62+
63+
#%%
6064
nodes = nodes[nodes.index.isin(embedding_df.index)]
6165
embedding_df = embedding_df[embedding_df.index.isin(nodes.index)]
6266
nodes = nodes.reindex(embedding_df.index)
@@ -67,7 +71,7 @@ def stashfig(name, **kws):
6771
if symmetrize_pairs:
6872
pair_groups = nodes.groupby("pair_id")
6973
for pair_id, pair_group in pair_groups:
70-
if pair_id > 1:
74+
if pair_id > 1 and len(pair_group) == 2:
7175
inds = pair_group["inds"].values
7276
pair_embeddings = embedding[inds]
7377
mean_embedding = pair_embeddings.mean(axis=0)

0 commit comments

Comments
 (0)