-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathMultiIndexHash.py
120 lines (87 loc) · 3.58 KB
/
MultiIndexHash.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import numpy as np
class MultiIndexHash(object):
    """Multi-index hashing (MIH) for Hamming-space search over binary codes.

    The Q bit positions are split into ``m`` disjoint, roughly equal
    substrings (``np.array_split``). One hash table per substring maps each
    distinct substring value to the row indices of ``codes`` holding it. By
    the pigeonhole principle, a code within Hamming distance ``r`` of a query
    must be close to it on at least one substring. So only codes found by
    probing the substring tables need a full distance check.

    Attributes:
        N, Q:   number of codes and bits per code (``codes.shape``).
        codes:  the indexed array (stored by reference, not copied).
        m:      number of substrings / hash tables.
        s:      list of ``m`` index arrays, the bit positions of each substring.
        tables: list of ``m`` dicts mapping substring tuple -> list of row indices.
        lookup: list of ``m`` 2-D arrays; row ``n`` of ``lookup[j]`` is the
                n-th key of ``tables[j]`` (dict insertion order).
    """

    def __init__(self, codes, m=None):
        """Index ``codes``.

        Args:
            codes: N x Q binary (0/1 or boolean) array, one code per row.
            m: number of substrings/tables. If falsy (None or 0), uses
               ``Q // log2(N)``, clamped to at least 1.

        Raises:
            ValueError: if an explicit ``m`` is negative.
        """
        self.N = codes.shape[0]
        self.Q = codes.shape[1]
        self.codes = codes
        if not m:
            # log2(N) <= 0 when N <= 1, and Q // log2(N) is 0 when Q < log2(N).
            # Either case would break np.array_split, so fall back to one table.
            m = self.Q // np.log2(self.N) if self.N > 1 else 1
            m = max(int(m), 1)
        self.m = int(m)
        if self.m < 1:
            raise ValueError("number of substrings m must be positive, got %r" % (m,))
        self.s = np.array_split(np.arange(self.Q), self.m)
        self.tables = self.init_tables()
        # Stack each table's keys into a (num_keys, substring_len) array, so a
        # query substring can be compared against every key in one vectorised
        # op. The explicit reshape keeps it 2-D even for an empty table or a
        # zero-width substring (m > Q).
        self.lookup = [
            np.asarray(list(table.keys())).reshape(len(table), len(positions))
            for table, positions in zip(self.tables, self.s)
        ]

    def init_tables(self):
        """Build one hash table per substring.

        Returns:
            list of ``self.m`` dicts. ``tables[j]`` maps the tuple of bits of a
            code at positions ``self.s[j]`` to the list of row indices (in
            increasing order) of all codes with that substring.
        """
        tables = []
        for positions in self.s:
            table = {}
            for i, row in enumerate(self.codes[:, positions]):
                table.setdefault(tuple(row), []).append(i)
            tables.append(table)
        return tables

    def _probe(self, query, j, radius, exact=False):
        """Return row indices whose substring ``j`` lies within ``radius`` of the query's.

        With ``exact=True``, only keys at substring distance exactly ``radius``
        are used (k_nn has already probed the smaller radii).
        """
        look_up = self.lookup[j]
        q_sub = np.reshape(query[self.s[j]], (1, -1))
        dist = np.count_nonzero(np.logical_xor(q_sub, look_up), axis=1)  # Hamming
        hit = (dist == radius) if exact else (dist <= radius)
        table = self.tables[j]
        found = []
        for n in np.flatnonzero(hit):
            found.extend(table[tuple(look_up[n])])
        return found

    def _hamming(self, query, rows):
        """Return ``(rows, distances)``: full Hamming distances of ``codes[rows]`` to ``query``.

        ``rows`` is converted to an integer index array, so an empty selection
        is still a valid index.
        """
        rows = np.asarray(rows, dtype=np.intp)
        dist = np.count_nonzero(np.logical_xor(query, self.codes[rows]), axis=1)
        return rows, dist

    def r_search(self, query, r):
        """Return every indexed code within Hamming distance ``r`` of ``query``.

        Args:
            query: length-Q binary array.
            r: search radius.

        Returns:
            list of ``(index, distance)`` tuples sorted by distance (ties by
            index). Empty if nothing lies within ``r``.
        """
        query = np.asarray(query)
        if r < 0:
            return []
        r_sub, a = divmod(r, self.m)
        candidates = set()
        for j in range(self.m):
            # Pigeonhole bound (Norouzi et al.): write r = m*r_sub + a. Any code
            # within distance r is within r_sub of the query on one of the first
            # a+1 substrings (j <= a), or within r_sub-1 on one of the rest.
            # Using j < a here would miss matches, e.g. exact ones when r < m.
            radius = r_sub if j <= a else r_sub - 1
            if radius >= 0:
                candidates.update(self._probe(query, j, radius))
        # Candidates are only a superset; verify with the full distance.
        rows, dist = self._hamming(query, sorted(candidates))
        keep = dist <= r
        results = zip(rows[keep].tolist(), dist[keep].tolist())
        return sorted(results, key=lambda item: (item[1], item[0]))

    def k_nn(self, query, k):
        """Return the ``k`` nearest indexed codes to ``query`` in Hamming distance.

        Substring tables are probed round-robin with a growing substring
        radius. After table ``j`` is probed at substring radius ``r_sub``,
        every code within full distance ``r_sub*m + j`` is guaranteed to have
        been found (the same pigeonhole bound as ``r_search``). The search
        stops once that radius holds at least ``k`` codes, or covers all Q bits.

        Args:
            query: length-Q binary array.
            k: number of neighbours wanted.

        Returns:
            list of ``min(k, N)`` ``(index, distance)`` tuples sorted by
            distance, ties broken by index.
        """
        query = np.asarray(query)
        if k <= 0 or self.N == 0:
            return []
        # by_dist[d] holds verified indices at full distance d; d spans 0..Q,
        # so Q + 1 buckets are needed.
        by_dist = [[] for _ in range(self.Q + 1)]
        seen = set()
        r_sub = 0
        while True:
            for j in range(self.m):
                # Table j was probed at r_sub - 1 in the previous round, so only
                # keys at exactly r_sub can contribute new candidates.
                new = [i for i in self._probe(query, j, r_sub, exact=True) if i not in seen]
                seen.update(new)
                rows, dist = self._hamming(query, new)
                for i, d in zip(rows.tolist(), dist.tolist()):
                    by_dist[d].append(i)
                complete = r_sub * self.m + j  # all codes within this distance found
                if complete >= self.Q or sum(len(b) for b in by_dist[:complete + 1]) >= k:
                    out = [
                        (i, d)
                        for d in range(min(complete, self.Q) + 1)
                        for i in sorted(by_dist[d])
                    ]
                    return out[:k]
            r_sub += 1