Skip to content

Commit 44d4a6c

Browse files
committed
push
1 parent 3b8eb3e commit 44d4a6c

File tree

153 files changed

+31650
-224
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

153 files changed

+31650
-224
lines changed

.Rhistory

Whitespace-only changes.

data/logs/2020-05-26.txt

Lines changed: 857 additions & 0 deletions
Large diffs are not rendered by default.

data/process_scripts/process_maggot_brain_connectome_2020-05-26.py

Lines changed: 730 additions & 0 deletions
Large diffs are not rendered by default.

notebooks/137.4-BDP-omni-clust.py

Lines changed: 733 additions & 0 deletions
Large diffs are not rendered by default.

notebooks/137.5-BDP-omni-clust.py

Lines changed: 856 additions & 0 deletions
Large diffs are not rendered by default.

notebooks/151.0-BDP-plot-all-class-pairs.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -51,15 +51,15 @@ def stashcsv(df, name, **kws):
5151

5252

5353
graph_type = "G"
54-
master_mg = load_metagraph(graph_type, version="2020-04-23")
55-
mg = preprocess(
56-
master_mg,
57-
threshold=0,
58-
sym_threshold=False,
59-
remove_pdiff=True,
60-
binarize=False,
61-
weight="weight",
62-
)
54+
mg = load_metagraph(graph_type, version="2020-05-26")
55+
# mg = preprocess(
56+
# master_mg,
57+
# threshold=0,
58+
# sym_threshold=False,
59+
# remove_pdiff=True,
60+
# binarize=False,
61+
# weight="weight",
62+
# )
6363
meta = mg.meta
6464

6565

@@ -70,7 +70,7 @@ def stashcsv(df, name, **kws):
7070
uni_labels, counts = np.unique(labels, return_counts=True)
7171
inds = np.argsort(-counts)
7272

73-
paired = meta["Pair ID"] != -1
73+
paired = meta["pair_id"] != -1
7474

7575
fig, axs = plt.subplots(1, 2, figsize=(20, 20))
7676
ax = axs[0]

notebooks/157.1-BDP-plot_b_hat.py

Lines changed: 364 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,364 @@
1+
# %% [markdown]
2+
# ##
3+
import os
4+
import warnings
5+
from itertools import chain
6+
7+
import colorcet as cc
8+
import matplotlib as mpl
9+
import matplotlib.gridspec as gridspec
10+
import matplotlib.patches as patches
11+
import matplotlib.pyplot as plt
12+
import matplotlib.transforms as transforms
13+
import networkx as nx
14+
import numpy as np
15+
import pandas as pd
16+
import seaborn as sns
17+
from joblib import Parallel, delayed
18+
from mpl_toolkits.axes_grid1 import make_axes_locatable
19+
from scipy.stats import poisson
20+
from sklearn.exceptions import ConvergenceWarning
21+
from sklearn.manifold import MDS, TSNE, Isomap
22+
from sklearn.metrics import pairwise_distances
23+
from sklearn.neighbors import NearestNeighbors
24+
from sklearn.utils.testing import ignore_warnings
25+
from tqdm.autonotebook import tqdm
26+
from umap import UMAP
27+
28+
from graspy.embed import (
29+
AdjacencySpectralEmbed,
30+
ClassicalMDS,
31+
LaplacianSpectralEmbed,
32+
OmnibusEmbed,
33+
select_dimension,
34+
selectSVD,
35+
)
36+
from graspy.models import DCSBMEstimator, SBMEstimator
37+
from graspy.plot import pairplot
38+
from graspy.utils import (
39+
augment_diagonal,
40+
binarize,
41+
pass_to_ranks,
42+
remove_loops,
43+
symmetrize,
44+
to_laplace,
45+
)
46+
from src.align import Procrustes
47+
from src.cluster import BinaryCluster, MaggotCluster, get_paired_inds
48+
from src.data import load_metagraph
49+
from src.flow import fit_gm_exp, make_exp_match
50+
from src.graph import MetaGraph, preprocess
51+
from src.hierarchy import signal_flow
52+
from src.io import readcsv, savecsv, savefig
53+
from src.pymaid import start_instance
54+
from src.traverse import Cascade, RandomWalk, to_markov_matrix, to_transmission_matrix
55+
from src.utils import get_blockmodel_df
56+
from src.visualization import (
57+
CLASS_COLOR_DICT,
58+
add_connections,
59+
adjplot,
60+
barplot_text,
61+
draw_networkx_nice,
62+
gridmap,
63+
matrixplot,
64+
palplot,
65+
remove_shared_ax,
66+
remove_spines,
67+
screeplot,
68+
set_axes_equal,
69+
stacked_barplot,
70+
)
71+
72+
73+
FNAME = os.path.basename(__file__)[:-3]
74+
print(FNAME)
75+
76+
77+
def stashfig(name, **kws):
78+
savefig(name, foldername=FNAME, save_on=True, **kws)
79+
80+
81+
def stashcsv(df, name, **kws):
82+
savecsv(df, name, foldername=FNAME, **kws)
83+
84+
85+
rc_dict = {
86+
"axes.spines.right": False,
87+
"axes.spines.top": False,
88+
"axes.formatter.limits": (-3, 3),
89+
"figure.figsize": (6, 3),
90+
"figure.dpi": 100,
91+
"axes.edgecolor": "lightgrey",
92+
"ytick.color": "dimgrey",
93+
"xtick.color": "dimgrey",
94+
"axes.labelcolor": "dimgrey",
95+
"text.color": "dimgrey",
96+
}
97+
for key, val in rc_dict.items():
98+
mpl.rcParams[key] = val
99+
context = sns.plotting_context(context="talk", font_scale=1, rc=rc_dict)
100+
sns.set_context(context)
101+
102+
np.random.seed(8888)
103+
104+
105+
# %% [markdown]
106+
# ##
107+
108+
# parameters of clustering
109+
metric = "bic"
110+
bic_ratio = 1
111+
d = 8 # embedding dimension
112+
method = "color_iso"
113+
114+
basename = f"-method={method}-d={d}-bic_ratio={bic_ratio}"
115+
title = f"Method={method}, d={d}, BIC ratio={bic_ratio}"
116+
117+
exp = "137.3-BDP-omni-clust"
118+
119+
# load data
120+
cluster_meta = readcsv("meta" + basename, foldername=exp, index_col=0)
121+
cluster_meta["lvl0_labels"] = cluster_meta["lvl0_labels"].astype(str)
122+
123+
# parameters of Bhat plotting
124+
graph_type = "Gad"
125+
level = 3
126+
label_key = f"lvl{level}_labels"
127+
use_sides = False
128+
use_super = False
129+
use_sens = False
130+
if use_sides:
131+
label_key += "_side"
132+
basename = (
133+
f"-{graph_type}-lvl={level}-sides={use_sides}-sup={use_super}-sens={use_sens}"
134+
)
135+
136+
full_mg = load_metagraph(graph_type)
137+
if use_super:
138+
super_mg = load_metagraph("Gs")
139+
super_mg = super_mg.reindex(full_mg.meta.index, use_ids=True)
140+
full_mg = MetaGraph(full_mg.adj + super_mg.adj, full_mg.meta)
141+
142+
meta = full_mg.meta
143+
144+
# get labels from the clustering
145+
cluster_labels = cluster_meta[label_key]
146+
meta["label"] = cluster_labels
147+
148+
# deal with sensories - will overwrite cluster labels for ORN
149+
if use_sens:
150+
sens_meta = meta[meta["class1"] == "sens"]
151+
meta.loc[sens_meta.index, "label"] = sens_meta["merge_class"]
152+
153+
# deal with supers
154+
if use_super:
155+
brain_names = ["Brain Hemisphere left", "Brain Hemisphere right"]
156+
super_meta = meta[meta["class1"] == "super"].copy()
157+
super_meta = super_meta[~super_meta["name"].isin(brain_names)]
158+
if not use_sides:
159+
160+
def strip_side(x):
161+
x = x.replace("_left", "")
162+
x = x.replace("_right", "")
163+
return x
164+
165+
super_meta["name"] = super_meta["name"].map(strip_side)
166+
167+
meta.loc[super_meta.index, "label"] = super_meta["name"]
168+
169+
labeled_inds = meta[~meta["label"].isna()].index
170+
full_mg = full_mg.reindex(labeled_inds, use_ids=True)
171+
full_mg.meta.loc[:, "inds"] = range(len(full_mg))
172+
print(len(full_mg))
173+
print(full_mg["label"].unique())
174+
# unlabeled_meta = full_mg.meta[full_mg.meta["label"].isna()]
175+
176+
177+
def calc_bar_params(sizes, label, mid, palette=None):
178+
if palette is None:
179+
palette = CLASS_COLOR_DICT
180+
heights = sizes.loc[label]
181+
n_in_bar = heights.sum()
182+
offset = mid - n_in_bar / 2
183+
starts = heights.cumsum() - heights + offset
184+
colors = np.vectorize(palette.get)(heights.index)
185+
return heights, starts, colors
186+
187+
188+
def plot_bar(meta, mid, ax, orientation="horizontal", width=0.7):
189+
if orientation == "horizontal":
190+
method = ax.barh
191+
ax.xaxis.set_visible(False)
192+
remove_spines(ax)
193+
elif orientation == "vertical":
194+
method = ax.bar
195+
ax.yaxis.set_visible(False)
196+
remove_spines(ax)
197+
sizes = meta.groupby("merge_class").size()
198+
sizes /= sizes.sum()
199+
starts = sizes.cumsum() - sizes
200+
colors = np.vectorize(CLASS_COLOR_DICT.get)(starts.index)
201+
for i in range(len(sizes)):
202+
method(mid, sizes[i], width, starts[i], color=colors[i])
203+
204+
205+
meta = full_mg.meta
206+
adj = full_mg.adj
207+
fig, axs = plt.subplots(1, 2, figsize=(20, 10))
208+
209+
210+
labels = meta["label"]
211+
bar_ratio = 0.05
212+
use_weights = True
213+
use_counts = True
214+
sort_method = "sf"
215+
alpha = 0.05
216+
width = 0.9
217+
log = False
218+
basename += f"-weights={use_weights}-counts={use_counts}-sort={sort_method}-log={log}"
219+
220+
blockmodel_df = get_blockmodel_df(
221+
adj, labels, return_counts=use_counts, use_weights=use_weights
222+
)
223+
heatmap_kws = dict(square=True, cmap="Reds", cbar_kws=dict(shrink=0.7))
224+
data = blockmodel_df.values
225+
# data = pass_to_ranks(data)
226+
# data = np.log10(data + 1)
227+
if log:
228+
data = np.log10(data)
229+
data[~np.isfinite(data)] = 0
230+
blockmodel_df = pd.DataFrame(
231+
data=data, index=blockmodel_df.index, columns=blockmodel_df.columns
232+
)
233+
234+
if sort_method == "sf":
235+
sf = signal_flow(data)
236+
perm = np.argsort(-sf)
237+
if sort_method == "gm":
238+
perm, score = fit_gm_exp(data, alpha, 1, 0, n_init=20, return_best=True)
239+
240+
241+
blockmodel_df = blockmodel_df.iloc[perm, perm]
242+
data = blockmodel_df.values
243+
244+
uni_labels = blockmodel_df.index.values
245+
246+
ax = axs[0]
247+
adjplot(data, ax=ax, cbar=False)
248+
# sns.heatmap(data, ax=ax, cbar=False, xticklabels=False, yticklabels=False)
249+
divider = make_axes_locatable(ax)
250+
top_ax = divider.append_axes("top", size=f"{bar_ratio*100}%", pad=0, sharex=ax)
251+
left_ax = divider.append_axes("left", size=f"{bar_ratio*100}%", pad=0, sharey=ax)
252+
remove_shared_ax(top_ax)
253+
remove_shared_ax(left_ax)
254+
mids = np.arange(len(data)) + 0.5
255+
256+
for i, label in enumerate(uni_labels):
257+
temp_meta = meta[meta["label"] == label]
258+
plot_bar(temp_meta, mids[i], left_ax, orientation="horizontal", width=width)
259+
plot_bar(temp_meta, mids[i], top_ax, orientation="vertical", width=width)
260+
261+
ax.yaxis.set_visible(True)
262+
ax.yaxis.tick_right()
263+
ax.yaxis.set_ticks(np.arange(len(data)) + 0.5)
264+
ax.yaxis.set_ticklabels(uni_labels, fontsize=10, color="dimgrey", va="center")
265+
ax.yaxis.set_tick_params(rotation=0, color="dimgrey")
266+
267+
if len(uni_labels) <= 10:
268+
pal = sns.color_palette("tab10")
269+
elif len(uni_labels) <= 20:
270+
pal = sns.color_palette("tab20")
271+
else:
272+
pal = cc.glasbey_light
273+
color_map = dict(zip(uni_labels, pal))
274+
ticklabels = axs[0].get_yticklabels()
275+
for t in ticklabels:
276+
text = t.get_text()
277+
t.set_color(color_map[text])
278+
279+
280+
remove_diag = True
281+
282+
# convert the adjacency and a partition to a minigraph based on SBM probs
283+
prob_df = blockmodel_df
284+
if remove_diag:
285+
adj = prob_df.values
286+
adj -= np.diag(np.diag(adj))
287+
prob_df = pd.DataFrame(data=adj, index=prob_df.index, columns=prob_df.columns)
288+
289+
g = nx.from_pandas_adjacency(prob_df, create_using=nx.DiGraph())
290+
uni_labels, counts = np.unique(labels, return_counts=True)
291+
292+
# add size attribute base on number of vertices
293+
size_map = dict(zip(uni_labels, counts))
294+
nx.set_node_attributes(g, size_map, name="Size")
295+
296+
# add signal flow attribute (for the minigraph itself)
297+
mini_adj = nx.to_numpy_array(g, nodelist=uni_labels)
298+
node_signal_flow = signal_flow(mini_adj)
299+
sf_map = dict(zip(uni_labels, node_signal_flow))
300+
nx.set_node_attributes(g, sf_map, name="Signal Flow")
301+
302+
# rank signal flow
303+
sort_inds = np.argsort(node_signal_flow)
304+
rank_inds = np.argsort(sort_inds)
305+
rank_sf_map = dict(zip(uni_labels, rank_inds))
306+
nx.set_node_attributes(g, rank_sf_map, name="rank_sf")
307+
308+
# add spectral properties
309+
sym_adj = symmetrize(mini_adj)
310+
n_components = 5
311+
latent = AdjacencySpectralEmbed(n_components=n_components).fit_transform(sym_adj)
312+
for i in range(n_components):
313+
latent_dim = latent[:, i]
314+
lap_map = dict(zip(uni_labels, latent_dim))
315+
nx.set_node_attributes(g, lap_map, name=f"AdjEvec-{i}")
316+
317+
# add spring layout properties
318+
pos = nx.spring_layout(g)
319+
spring_x = {}
320+
spring_y = {}
321+
for key, val in pos.items():
322+
spring_x[key] = val[0]
323+
spring_y[key] = val[1]
324+
nx.set_node_attributes(g, spring_x, name="Spring-x")
325+
nx.set_node_attributes(g, spring_y, name="Spring-y")
326+
327+
# add colors
328+
329+
nx.set_node_attributes(g, color_map, name="Color")
330+
331+
ax = axs[1]
332+
333+
x_pos_key = "AdjEvec-1"
334+
y_pos_key = "rank_sf"
335+
x_pos = nx.get_node_attributes(g, x_pos_key)
336+
y_pos = nx.get_node_attributes(g, y_pos_key)
337+
338+
# all_x_pos = list(x_pos.items())
339+
# all_y_pos = list(y_pos.items())
340+
# y_max = max(all_y_pos)
341+
342+
if use_counts:
343+
vmin = 1000
344+
weight_scale = 1 / 2000
345+
else:
346+
weight_scale = 1
347+
vmin = 0.01
348+
349+
draw_networkx_nice(
350+
g,
351+
x_pos_key,
352+
y_pos_key,
353+
colors="Color",
354+
sizes="Size",
355+
weight_scale=weight_scale,
356+
vmin=vmin,
357+
ax=ax,
358+
y_boost=0.3,
359+
)
360+
stashfig(f"layout-x={x_pos_key}-y={y_pos_key}" + basename)
361+
362+
363+
364+
# %%

0 commit comments

Comments
 (0)