Skip to content

Commit

Permalink
swap to type dispatching for graph & W
Browse files Browse the repository at this point in the history
  • Loading branch information
ljwolf committed Aug 15, 2024
1 parent 0b62548 commit 9ca259d
Showing 1 changed file with 37 additions and 24 deletions.
61 changes: 37 additions & 24 deletions esda/moran_local_mv.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy as np
import esda
from libpysal.weights import lag_spatial
from libpysal.graph import Graph

try:
from tqdm import tqdm
Expand Down Expand Up @@ -48,7 +49,7 @@ def __init__(
Attributes
----------
W : The weights matrix inputted, but row standardized
connectivity : The weights matrix inputted, but row standardized
D : The "design" matrix used in computation. If X is
not None, this will be [1 y X]
R : the "response" matrix used in computation. Will
Expand Down Expand Up @@ -79,9 +80,10 @@ def __init__(
"""
self._mvquads = mvquads
y = np.asarray(y).reshape(-1, 1)
if hasattr(W, "to_W"):
W = W.to_W()
W.transform = "r"
if isinstance(W, Graph):
W = W.transform("R")
else:
W.transform = "r" # TODO: as a function for graph
y = y - y.mean()
if unit_scale:
y /= y.std()
Expand All @@ -97,17 +99,17 @@ def __init__(
self.DtDi = np.linalg.inv(
self.D.T @ self.D
) # this is only PxP, so not too bad...
self._left_component_ = (self.D @ self.DtDi) * (W.n - 1)
self._left_component_ = (self.D @ self.DtDi) * (self.N - 1)
self._lmos_ = self._left_component_ * self.R
self.W = W
self.connectivity = W
self.permutations = permutations
if permutations is not None: # NOQA necessary to avoid None > 0
if permutations > 0:
self._crand(y, X, W)


self._rlmos_ *= W.n - 1
self._p_sim_ = np.zeros((W.n, self.P + 1))
self._rlmos_ *= self.N - 1
self._p_sim_ = np.zeros((self.N, self.P + 1))
# TODO: this should be changed to the general p-value framework
for permutation in range(self.permutations):
self._p_sim_ += (
Expand Down Expand Up @@ -167,12 +169,18 @@ def _crand(self, y, X, W):
N = W.n
N_permutations = self.permutations
prange = range(N_permutations)
max_neighbs = W.max_neighbors + 1
if isinstance(W, Graph):
max_neighbs = W.cardinalities.max() + 1
else:
max_neighbs = W.max_neighbors + 1
pre_permutations = np.array(
[np.random.permutation(N - 1)[0:max_neighbs] for i in prange]
)
straight_ids = np.arange(N)
id_order = W.id_order
if isinstance(W, Graph): # NOQA
id_order = W.unique_ids
else:
id_order = W.id_order
DtDi = self.DtDi
ordered_weights = [W.weights[id_order[i]] for i in straight_ids]
ordered_cardinalities = [W.cardinalities[id_order[i]] for i in straight_ids]
Expand Down Expand Up @@ -309,12 +317,14 @@ def __init__(
X /= X.std(axis=0)
self.y = y
self.X = X
if hasattr(W, "to_W"):
W = W.to_W()
W.transform = "r"
self.W = W
y_filtered_ = self.y_filtered_ = self._part_regress_transform(y, X)
Wyf = lag_spatial(self.W, y_filtered_)
if isinstance(W, Graph):
W = W.transform("R")
Wyf = W.lag(y_filtered_)
else:
W.transform = "r"
Wyf = lag_spatial(W, y_filtered_) # TODO: graph
self.connectivity = W
self.partials_ = np.column_stack((y_filtered_, Wyf))
self.permutations = permutations
y_out = self.y_filtered_
Expand Down Expand Up @@ -355,18 +365,21 @@ def _crand(self):
neighbors to i in each randomization.
"""
_, z = self.partials_.T
lisas = np.zeros((self.W.n, self.permutations))
n_1 = self.W.n - 1
lisas = np.zeros((self.connectivity.n, self.permutations))
n_1 = self.connectivity.n - 1
prange = list(range(self.permutations))
k = self.W.max_neighbors + 1
nn = self.W.n - 1
k = self.connectivity.cardinalities.max() + 1
nn = self.connectivity.n - 1
rids = np.array([np.random.permutation(nn)[0:k] for i in prange])
ids = np.arange(self.W.n)
ido = self.W.id_order
w = [self.W.weights[ido[i]] for i in ids]
wc = [self.W.cardinalities[ido[i]] for i in ids]
ids = np.arange(self.connectivity.n)
if hasattr(self.connectivity, "id_order"):
ido = self.connectivity.id_order
else:
ido = self.connectivity.unique_ids.values
w = [self.connectivity.weights[ido[i]] for i in ids]
wc = [self.connectivity.cardinalities[ido[i]] for i in ids]

for i in tqdm(range(self.W.n), desc="Simulating by site"):
for i in tqdm(range(self.connectivity.n), desc="Simulating by site"):
idsi = ids[ids != i]
np.random.shuffle(idsi)
tmp = z[idsi[rids[:, 0 : wc[i]]]]
Expand Down

0 comments on commit 9ca259d

Please sign in to comment.