From 202296bebbe88e6931b6ed20e6d127d37b6b837c Mon Sep 17 00:00:00 2001 From: "Jeremy R. Manning" Date: Mon, 6 Nov 2023 12:19:07 -0500 Subject: [PATCH] fix clustering ranks when all dists are equal --- quail/analysis/clustering.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/quail/analysis/clustering.py b/quail/analysis/clustering.py index ae6db08..f89274b 100644 --- a/quail/analysis/clustering.py +++ b/quail/analysis/clustering.py @@ -72,7 +72,11 @@ def _get_weight_exact(egg, feature, distdict, permute, n_perms): dists = distmat[pres.index(c),:] di = dists[pres.index(n)] dists_filt = np.array([dist for idx, dist in enumerate(dists) if idx not in past_idxs]) - ranks.append(np.mean(np.where(np.sort(dists_filt)[::-1] == di)[0]+1) / len(dists_filt)) + + if len(np.unique(dists_filt)) == 1: + ranks.append(0.5) + else: + ranks.append(np.mean(np.where(np.sort(dists_filt)[::-1] == di)[0]+1) / len(dists_filt)) past_idxs.append(pres.index(c)) past_words.append(c) return np.nanmean(ranks) @@ -97,7 +101,10 @@ def _get_weight_best(egg, feature, distdict, permute, n_perms, distance): dists = distmat[cdx, :] di = dists[ndx] dists_filt = np.array([dist for idx, dist in enumerate(dists)]) - ranks.append(np.mean(np.where(np.sort(dists_filt)[::-1] == di)[0]+1) / len(dists_filt)) + if len(np.unique(dists_filt)) == 1: + ranks.append(0.5) + else: + ranks.append(np.mean(np.where(np.sort(dists_filt)[::-1] == di)[0] + 1) / len(dists_filt)) return np.nanmean(ranks) def _get_weight_smooth(egg, feature, distdict, permute, n_perms, distance): @@ -120,7 +127,10 @@ def _get_weight_smooth(egg, feature, distdict, permute, n_perms, distance): dists = distmat[cdx, :] di = dists[ndx] dists_filt = np.array([dist for idx, dist in enumerate(dists)]) - ranks.append(np.mean(np.where(np.sort(dists_filt)[::-1] == di)[0]+1) / len(dists_filt)) + if len(np.unique(dists_filt)) == 1: + ranks.append(0.5) + else: + ranks.append(np.mean(np.where(np.sort(dists_filt)[::-1] == di)[0] + 1) / len(dists_filt)) return np.nanmean(ranks) def get_distmat(egg, feature, distdict):