|
| 1 | +# %% [markdown] |
| 2 | +# ## |
| 3 | +import os |
| 4 | +import warnings |
| 5 | +from itertools import chain |
| 6 | + |
| 7 | +import colorcet as cc |
| 8 | +import matplotlib as mpl |
| 9 | +import matplotlib.gridspec as gridspec |
| 10 | +import matplotlib.patches as patches |
| 11 | +import matplotlib.pyplot as plt |
| 12 | +import matplotlib.transforms as transforms |
| 13 | +import networkx as nx |
| 14 | +import numpy as np |
| 15 | +import pandas as pd |
| 16 | +import seaborn as sns |
| 17 | +from joblib import Parallel, delayed |
| 18 | +from mpl_toolkits.axes_grid1 import make_axes_locatable |
| 19 | +from scipy.stats import poisson |
| 20 | +from sklearn.exceptions import ConvergenceWarning |
| 21 | +from sklearn.manifold import MDS, TSNE, Isomap |
| 22 | +from sklearn.metrics import pairwise_distances |
| 23 | +from sklearn.neighbors import NearestNeighbors |
| 24 | +from sklearn.utils.testing import ignore_warnings |
| 25 | +from tqdm.autonotebook import tqdm |
| 26 | +from umap import UMAP |
| 27 | + |
| 28 | +from graspy.embed import ( |
| 29 | + AdjacencySpectralEmbed, |
| 30 | + ClassicalMDS, |
| 31 | + LaplacianSpectralEmbed, |
| 32 | + OmnibusEmbed, |
| 33 | + select_dimension, |
| 34 | + selectSVD, |
| 35 | +) |
| 36 | +from graspy.models import DCSBMEstimator, SBMEstimator |
| 37 | +from graspy.plot import pairplot |
| 38 | +from graspy.utils import ( |
| 39 | + augment_diagonal, |
| 40 | + binarize, |
| 41 | + pass_to_ranks, |
| 42 | + remove_loops, |
| 43 | + symmetrize, |
| 44 | + to_laplace, |
| 45 | +) |
| 46 | +from src.align import Procrustes |
| 47 | +from src.cluster import BinaryCluster, MaggotCluster, get_paired_inds |
| 48 | +from src.data import load_metagraph |
| 49 | +from src.flow import fit_gm_exp, make_exp_match |
| 50 | +from src.graph import MetaGraph, preprocess |
| 51 | +from src.hierarchy import signal_flow |
| 52 | +from src.io import readcsv, savecsv, savefig |
| 53 | +from src.pymaid import start_instance |
| 54 | +from src.traverse import Cascade, RandomWalk, to_markov_matrix, to_transmission_matrix |
| 55 | +from src.utils import get_blockmodel_df |
| 56 | +from src.visualization import ( |
| 57 | + CLASS_COLOR_DICT, |
| 58 | + add_connections, |
| 59 | + adjplot, |
| 60 | + barplot_text, |
| 61 | + draw_networkx_nice, |
| 62 | + gridmap, |
| 63 | + matrixplot, |
| 64 | + palplot, |
| 65 | + remove_shared_ax, |
| 66 | + remove_spines, |
| 67 | + screeplot, |
| 68 | + set_axes_equal, |
| 69 | + stacked_barplot, |
| 70 | +) |
| 71 | + |
| 72 | + |
| 73 | +FNAME = os.path.basename(__file__)[:-3] |
| 74 | +print(FNAME) |
| 75 | + |
| 76 | + |
| 77 | +def stashfig(name, **kws): |
| 78 | + savefig(name, foldername=FNAME, save_on=True, **kws) |
| 79 | + |
| 80 | + |
| 81 | +def stashcsv(df, name, **kws): |
| 82 | + savecsv(df, name, foldername=FNAME, **kws) |
| 83 | + |
| 84 | + |
| 85 | +rc_dict = { |
| 86 | + "axes.spines.right": False, |
| 87 | + "axes.spines.top": False, |
| 88 | + "axes.formatter.limits": (-3, 3), |
| 89 | + "figure.figsize": (6, 3), |
| 90 | + "figure.dpi": 100, |
| 91 | + "axes.edgecolor": "lightgrey", |
| 92 | + "ytick.color": "dimgrey", |
| 93 | + "xtick.color": "dimgrey", |
| 94 | + "axes.labelcolor": "dimgrey", |
| 95 | + "text.color": "dimgrey", |
| 96 | +} |
| 97 | +for key, val in rc_dict.items(): |
| 98 | + mpl.rcParams[key] = val |
| 99 | +context = sns.plotting_context(context="talk", font_scale=1, rc=rc_dict) |
| 100 | +sns.set_context(context) |
| 101 | + |
| 102 | +np.random.seed(8888) |
| 103 | + |
| 104 | + |
| 105 | +# %% [markdown] |
| 106 | +# ## |
| 107 | + |
| 108 | +# parameters of clustering |
| 109 | +metric = "bic" |
| 110 | +bic_ratio = 1 |
| 111 | +d = 8 # embedding dimension |
| 112 | +method = "color_iso" |
| 113 | + |
| 114 | +basename = f"-method={method}-d={d}-bic_ratio={bic_ratio}" |
| 115 | +title = f"Method={method}, d={d}, BIC ratio={bic_ratio}" |
| 116 | + |
| 117 | +exp = "137.3-BDP-omni-clust" |
| 118 | + |
| 119 | +# load data |
| 120 | +cluster_meta = readcsv("meta" + basename, foldername=exp, index_col=0) |
| 121 | +cluster_meta["lvl0_labels"] = cluster_meta["lvl0_labels"].astype(str) |
| 122 | + |
| 123 | +# parameters of Bhat plotting |
| 124 | +graph_type = "Gad" |
| 125 | +level = 3 |
| 126 | +label_key = f"lvl{level}_labels" |
| 127 | +use_sides = False |
| 128 | +use_super = False |
| 129 | +use_sens = False |
| 130 | +if use_sides: |
| 131 | + label_key += "_side" |
| 132 | +basename = ( |
| 133 | + f"-{graph_type}-lvl={level}-sides={use_sides}-sup={use_super}-sens={use_sens}" |
| 134 | +) |
| 135 | + |
| 136 | +full_mg = load_metagraph(graph_type) |
| 137 | +if use_super: |
| 138 | + super_mg = load_metagraph("Gs") |
| 139 | + super_mg = super_mg.reindex(full_mg.meta.index, use_ids=True) |
| 140 | + full_mg = MetaGraph(full_mg.adj + super_mg.adj, full_mg.meta) |
| 141 | + |
| 142 | +meta = full_mg.meta |
| 143 | + |
| 144 | +# get labels from the clustering |
| 145 | +cluster_labels = cluster_meta[label_key] |
| 146 | +meta["label"] = cluster_labels |
| 147 | + |
| 148 | +# deal with sensories - will overwrite cluster labels for ORN |
| 149 | +if use_sens: |
| 150 | + sens_meta = meta[meta["class1"] == "sens"] |
| 151 | + meta.loc[sens_meta.index, "label"] = sens_meta["merge_class"] |
| 152 | + |
| 153 | +# deal with supers |
| 154 | +if use_super: |
| 155 | + brain_names = ["Brain Hemisphere left", "Brain Hemisphere right"] |
| 156 | + super_meta = meta[meta["class1"] == "super"].copy() |
| 157 | + super_meta = super_meta[~super_meta["name"].isin(brain_names)] |
| 158 | + if not use_sides: |
| 159 | + |
| 160 | + def strip_side(x): |
| 161 | + x = x.replace("_left", "") |
| 162 | + x = x.replace("_right", "") |
| 163 | + return x |
| 164 | + |
| 165 | + super_meta["name"] = super_meta["name"].map(strip_side) |
| 166 | + |
| 167 | + meta.loc[super_meta.index, "label"] = super_meta["name"] |
| 168 | + |
| 169 | +labeled_inds = meta[~meta["label"].isna()].index |
| 170 | +full_mg = full_mg.reindex(labeled_inds, use_ids=True) |
| 171 | +full_mg.meta.loc[:, "inds"] = range(len(full_mg)) |
| 172 | +print(len(full_mg)) |
| 173 | +print(full_mg["label"].unique()) |
| 174 | +# unlabeled_meta = full_mg.meta[full_mg.meta["label"].isna()] |
| 175 | + |
| 176 | + |
| 177 | +def calc_bar_params(sizes, label, mid, palette=None): |
| 178 | + if palette is None: |
| 179 | + palette = CLASS_COLOR_DICT |
| 180 | + heights = sizes.loc[label] |
| 181 | + n_in_bar = heights.sum() |
| 182 | + offset = mid - n_in_bar / 2 |
| 183 | + starts = heights.cumsum() - heights + offset |
| 184 | + colors = np.vectorize(palette.get)(heights.index) |
| 185 | + return heights, starts, colors |
| 186 | + |
| 187 | + |
| 188 | +def plot_bar(meta, mid, ax, orientation="horizontal", width=0.7): |
| 189 | + if orientation == "horizontal": |
| 190 | + method = ax.barh |
| 191 | + ax.xaxis.set_visible(False) |
| 192 | + remove_spines(ax) |
| 193 | + elif orientation == "vertical": |
| 194 | + method = ax.bar |
| 195 | + ax.yaxis.set_visible(False) |
| 196 | + remove_spines(ax) |
| 197 | + sizes = meta.groupby("merge_class").size() |
| 198 | + sizes /= sizes.sum() |
| 199 | + starts = sizes.cumsum() - sizes |
| 200 | + colors = np.vectorize(CLASS_COLOR_DICT.get)(starts.index) |
| 201 | + for i in range(len(sizes)): |
| 202 | + method(mid, sizes[i], width, starts[i], color=colors[i]) |
| 203 | + |
| 204 | + |
| 205 | +meta = full_mg.meta |
| 206 | +adj = full_mg.adj |
| 207 | +fig, axs = plt.subplots(1, 2, figsize=(20, 10)) |
| 208 | + |
| 209 | + |
| 210 | +labels = meta["label"] |
| 211 | +bar_ratio = 0.05 |
| 212 | +use_weights = True |
| 213 | +use_counts = True |
| 214 | +sort_method = "sf" |
| 215 | +alpha = 0.05 |
| 216 | +width = 0.9 |
| 217 | +log = False |
| 218 | +basename += f"-weights={use_weights}-counts={use_counts}-sort={sort_method}-log={log}" |
| 219 | + |
| 220 | +blockmodel_df = get_blockmodel_df( |
| 221 | + adj, labels, return_counts=use_counts, use_weights=use_weights |
| 222 | +) |
| 223 | +heatmap_kws = dict(square=True, cmap="Reds", cbar_kws=dict(shrink=0.7)) |
| 224 | +data = blockmodel_df.values |
| 225 | +# data = pass_to_ranks(data) |
| 226 | +# data = np.log10(data + 1) |
| 227 | +if log: |
| 228 | + data = np.log10(data) |
| 229 | + data[~np.isfinite(data)] = 0 |
| 230 | + blockmodel_df = pd.DataFrame( |
| 231 | + data=data, index=blockmodel_df.index, columns=blockmodel_df.columns |
| 232 | + ) |
| 233 | + |
| 234 | +if sort_method == "sf": |
| 235 | + sf = signal_flow(data) |
| 236 | + perm = np.argsort(-sf) |
| 237 | +if sort_method == "gm": |
| 238 | + perm, score = fit_gm_exp(data, alpha, 1, 0, n_init=20, return_best=True) |
| 239 | + |
| 240 | + |
| 241 | +blockmodel_df = blockmodel_df.iloc[perm, perm] |
| 242 | +data = blockmodel_df.values |
| 243 | + |
| 244 | +uni_labels = blockmodel_df.index.values |
| 245 | + |
| 246 | +ax = axs[0] |
| 247 | +adjplot(data, ax=ax, cbar=False) |
| 248 | +# sns.heatmap(data, ax=ax, cbar=False, xticklabels=False, yticklabels=False) |
| 249 | +divider = make_axes_locatable(ax) |
| 250 | +top_ax = divider.append_axes("top", size=f"{bar_ratio*100}%", pad=0, sharex=ax) |
| 251 | +left_ax = divider.append_axes("left", size=f"{bar_ratio*100}%", pad=0, sharey=ax) |
| 252 | +remove_shared_ax(top_ax) |
| 253 | +remove_shared_ax(left_ax) |
| 254 | +mids = np.arange(len(data)) + 0.5 |
| 255 | + |
| 256 | +for i, label in enumerate(uni_labels): |
| 257 | + temp_meta = meta[meta["label"] == label] |
| 258 | + plot_bar(temp_meta, mids[i], left_ax, orientation="horizontal", width=width) |
| 259 | + plot_bar(temp_meta, mids[i], top_ax, orientation="vertical", width=width) |
| 260 | + |
| 261 | +ax.yaxis.set_visible(True) |
| 262 | +ax.yaxis.tick_right() |
| 263 | +ax.yaxis.set_ticks(np.arange(len(data)) + 0.5) |
| 264 | +ax.yaxis.set_ticklabels(uni_labels, fontsize=10, color="dimgrey", va="center") |
| 265 | +ax.yaxis.set_tick_params(rotation=0, color="dimgrey") |
| 266 | + |
| 267 | +if len(uni_labels) <= 10: |
| 268 | + pal = sns.color_palette("tab10") |
| 269 | +elif len(uni_labels) <= 20: |
| 270 | + pal = sns.color_palette("tab20") |
| 271 | +else: |
| 272 | + pal = cc.glasbey_light |
| 273 | +color_map = dict(zip(uni_labels, pal)) |
| 274 | +ticklabels = axs[0].get_yticklabels() |
| 275 | +for t in ticklabels: |
| 276 | + text = t.get_text() |
| 277 | + t.set_color(color_map[text]) |
| 278 | + |
| 279 | + |
| 280 | +remove_diag = True |
| 281 | + |
| 282 | +# convert the adjacency and a partition to a minigraph based on SBM probs |
| 283 | +prob_df = blockmodel_df |
| 284 | +if remove_diag: |
| 285 | + adj = prob_df.values |
| 286 | + adj -= np.diag(np.diag(adj)) |
| 287 | + prob_df = pd.DataFrame(data=adj, index=prob_df.index, columns=prob_df.columns) |
| 288 | + |
| 289 | +g = nx.from_pandas_adjacency(prob_df, create_using=nx.DiGraph()) |
| 290 | +uni_labels, counts = np.unique(labels, return_counts=True) |
| 291 | + |
| 292 | +# add size attribute base on number of vertices |
| 293 | +size_map = dict(zip(uni_labels, counts)) |
| 294 | +nx.set_node_attributes(g, size_map, name="Size") |
| 295 | + |
| 296 | +# add signal flow attribute (for the minigraph itself) |
| 297 | +mini_adj = nx.to_numpy_array(g, nodelist=uni_labels) |
| 298 | +node_signal_flow = signal_flow(mini_adj) |
| 299 | +sf_map = dict(zip(uni_labels, node_signal_flow)) |
| 300 | +nx.set_node_attributes(g, sf_map, name="Signal Flow") |
| 301 | + |
| 302 | +# rank signal flow |
| 303 | +sort_inds = np.argsort(node_signal_flow) |
| 304 | +rank_inds = np.argsort(sort_inds) |
| 305 | +rank_sf_map = dict(zip(uni_labels, rank_inds)) |
| 306 | +nx.set_node_attributes(g, rank_sf_map, name="rank_sf") |
| 307 | + |
| 308 | +# add spectral properties |
| 309 | +sym_adj = symmetrize(mini_adj) |
| 310 | +n_components = 5 |
| 311 | +latent = AdjacencySpectralEmbed(n_components=n_components).fit_transform(sym_adj) |
| 312 | +for i in range(n_components): |
| 313 | + latent_dim = latent[:, i] |
| 314 | + lap_map = dict(zip(uni_labels, latent_dim)) |
| 315 | + nx.set_node_attributes(g, lap_map, name=f"AdjEvec-{i}") |
| 316 | + |
| 317 | +# add spring layout properties |
| 318 | +pos = nx.spring_layout(g) |
| 319 | +spring_x = {} |
| 320 | +spring_y = {} |
| 321 | +for key, val in pos.items(): |
| 322 | + spring_x[key] = val[0] |
| 323 | + spring_y[key] = val[1] |
| 324 | +nx.set_node_attributes(g, spring_x, name="Spring-x") |
| 325 | +nx.set_node_attributes(g, spring_y, name="Spring-y") |
| 326 | + |
| 327 | +# add colors |
| 328 | + |
| 329 | +nx.set_node_attributes(g, color_map, name="Color") |
| 330 | + |
| 331 | +ax = axs[1] |
| 332 | + |
| 333 | +x_pos_key = "AdjEvec-1" |
| 334 | +y_pos_key = "rank_sf" |
| 335 | +x_pos = nx.get_node_attributes(g, x_pos_key) |
| 336 | +y_pos = nx.get_node_attributes(g, y_pos_key) |
| 337 | + |
| 338 | +# all_x_pos = list(x_pos.items()) |
| 339 | +# all_y_pos = list(y_pos.items()) |
| 340 | +# y_max = max(all_y_pos) |
| 341 | + |
| 342 | +if use_counts: |
| 343 | + vmin = 1000 |
| 344 | + weight_scale = 1 / 2000 |
| 345 | +else: |
| 346 | + weight_scale = 1 |
| 347 | + vmin = 0.01 |
| 348 | + |
| 349 | +draw_networkx_nice( |
| 350 | + g, |
| 351 | + x_pos_key, |
| 352 | + y_pos_key, |
| 353 | + colors="Color", |
| 354 | + sizes="Size", |
| 355 | + weight_scale=weight_scale, |
| 356 | + vmin=vmin, |
| 357 | + ax=ax, |
| 358 | + y_boost=0.3, |
| 359 | +) |
| 360 | +stashfig(f"layout-x={x_pos_key}-y={y_pos_key}" + basename) |
| 361 | + |
| 362 | + |
| 363 | + |
| 364 | +# %% |
0 commit comments