-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathbase_graph.py
96 lines (68 loc) · 3.15 KB
/
base_graph.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import torch
import math
import re
import copy
import random
import sympy
from dynamic_graph import *
from simple_base_graph import *
class BaseGraph(DynamicGraph):
    """Dynamic graph assembled from per-group ``SimpleBaseGraph`` and
    ``HyperHyperCube`` instances.

    The nodes ``0 .. n_nodes - 1`` are partitioned in two complementary ways
    (see :meth:`split_nodes`).  A ``SimpleBaseGraph`` is built on every group
    of the first partition and a ``HyperHyperCube`` on every group of the
    second one; their ``w_list`` entries are embedded into
    ``n_nodes x n_nodes`` tensors and concatenated.  If a single
    ``SimpleBaseGraph`` over all nodes yields a strictly shorter ``w_list``,
    that list is used instead.

    NOTE(review): ``np``, ``DynamicGraph``, ``SimpleBaseGraph`` and
    ``HyperHyperCube`` are not imported explicitly in this file; they are
    presumably provided by the wildcard imports -- confirm.
    """

    def __init__(self, n_nodes, max_degree=1, seed=0, inner_edges=True):
        """Store the configuration and hand the constructed matrix list to
        the parent class.

        Args:
            n_nodes: Total number of nodes (must be a positive integer).
            max_degree: Forwarded to the sub-graph constructors; also bounds
                the bases ``2 .. max_degree + 1`` used by :meth:`split_nodes`.
            seed: Seed for ``self.state``; also forwarded to the whole-graph
                ``SimpleBaseGraph`` candidate.
            inner_edges: Forwarded to the whole-graph ``SimpleBaseGraph``
                candidate.
        """
        self.state = np.random.RandomState(seed)
        self.inner_edges = inner_edges
        self.max_degree = max_degree
        self.n_nodes = n_nodes
        self.seed = seed
        # construct() reads the attributes set above, so it must run last.
        super().__init__(self.construct())

    def construct(self):
        """Build and return the list of ``n_nodes x n_nodes`` matrices.

        Returns:
            Either the ``w_list`` of one ``SimpleBaseGraph`` over all nodes
            (when it is strictly shorter), or the per-group
            ``SimpleBaseGraph`` matrices followed by the per-group
            ``HyperHyperCube`` matrices, each embedded into full size.
        """
        groups_a, groups_b, n_power, n_rest = self.split_nodes()

        # One sub-graph per group; the construction order is kept as-is.
        simple_adics = [
            SimpleBaseGraph(len(groups_a[i]), max_degree=self.max_degree)
            for i in range(n_power)
        ]
        hyper_cubes = [
            HyperHyperCube(len(groups_b[i]), max_degree=self.max_degree)
            for i in range(n_rest)
        ]

        # check which is better: one graph over all nodes, or the composite.
        g = SimpleBaseGraph(
            self.n_nodes,
            max_degree=self.max_degree,
            seed=self.seed,
            inner_edges=self.inner_edges,
        )
        composite_len = len(simple_adics[0].w_list) + len(hyper_cubes[0].w_list)
        if len(g.w_list) < composite_len:
            return g.w_list

        w_list = self._merge_phases(simple_adics, groups_a)
        w_list.extend(self._merge_phases(hyper_cubes, groups_b))
        return w_list

    def _merge_phases(self, graphs, groups):
        """Sum, index by index, the embedded ``w_list`` entries of *graphs*.

        ``graphs[k]`` lives on the nodes ``groups[k]``.  The number of
        entries is taken from ``graphs[0].w_list``; every graph is assumed to
        provide at least that many (all groups have the same size).
        """
        merged = []
        for m in range(len(graphs[0].w_list)):
            w = torch.zeros((self.n_nodes, self.n_nodes))
            for graph, nodes in zip(graphs, groups):
                w += self.extend(graph.w_list[m], nodes)
            merged.append(w)
        return merged

    def diag(self, X, Y):
        """Return the block-diagonal matrix ``[[X, 0], [0, Y]]``.

        Only the first dimension of *X* and *Y* is used for sizing, so both
        are expected to be square 2-D tensors.
        """
        n_x = X.size()[0]
        n_y = Y.size()[0]
        new_W = torch.zeros((n_x + n_y, n_x + n_y))
        new_W[0:n_x, 0:n_x] = X
        new_W[n_x:, n_x:] = Y
        return new_W

    def extend(self, w, node_list):
        """Embed the small matrix *w* into an ``n_nodes x n_nodes`` tensor.

        Entry ``w[i, j]`` is written to position
        ``(node_list[i], node_list[j])``; every other entry stays zero.
        """
        new_w = torch.zeros((self.n_nodes, self.n_nodes))
        for i, row in enumerate(node_list):
            for j, col in enumerate(node_list):
                new_w[row, col] = w[i, j]
        return new_w

    def split_nodes(self):
        """Partition the nodes in two complementary ways.

        ``n_power`` is the gcd of ``n_nodes`` and the product, over every
        base ``b`` in ``2 .. max_degree + 1``, of the largest power of ``b``
        not exceeding ``n_nodes``; ``n_rest = n_nodes // n_power``.

        Returns:
            ``(node_list_list1, node_list_list2, n_power, n_rest)`` where
            ``node_list_list1`` holds ``n_power`` consecutive groups of
            ``n_rest`` nodes and ``node_list_list2`` holds ``n_rest`` groups
            of ``n_power`` nodes, group ``j`` collecting the ``j``-th node of
            every group in ``node_list_list1``.

        Raises:
            ValueError: If ``n_nodes`` is not positive.
        """
        if self.n_nodes < 1:
            raise ValueError("n_nodes must be a positive integer")

        # Exact integer arithmetic: int(math.log(n_nodes, b)) can undershoot
        # on exact powers (math.log(243, 3) == 4.999999999999999), and a
        # fixed-width integer product can silently overflow.
        factor = 1
        for base in range(2, self.max_degree + 2):
            power = 1
            while power * base <= self.n_nodes:
                power *= base
            factor *= power

        n_power = math.gcd(self.n_nodes, factor)
        n_rest = self.n_nodes // n_power  # exact: n_power divides n_nodes

        node_list = list(range(self.n_nodes))
        node_list_list1 = [
            node_list[n_rest * i:n_rest * (i + 1)] for i in range(n_power)
        ]
        node_list_list2 = [
            [node_list_list1[i][j] for i in range(n_power)]
            for j in range(n_rest)
        ]
        return node_list_list1, node_list_list2, n_power, n_rest

    def get_neighbors(self, i):
        """Return ``(in_neighbors, out_neighbors)`` of node *i* and advance
        ``self.itr`` by one.

        Both lookups happen before the counter is incremented.
        """
        in_neighbors = self.get_in_neighbors(i)
        out_neighbors = self.get_out_neighbors(i)
        self.itr += 1

        # Periodic re-shuffling of node indices is currently disabled:
        #if self.itr % len(self.w_list) == 0:
        #    self.w_list = self.shuffle_node_index(self.w_list, self.n_nodes)

        return in_neighbors, out_neighbors