1- import psutil
1+ # import psutil
22import os
3+ import multiprocessing
34from time import time
4- from sklearn import preprocessing
55
6- import pykgn as kgn
76import numpy as np
87import faiss
98from faiss import Kmeans
9+ from sklearn import preprocessing
10+
11+ import pykgn as kgn
1012
1113from ..base .module import BaseANN
1214
1315
16+
17+
1418class EPSearcher :
1519 def __init__ (self , data : np .ndarray , cur_ep : int ) -> None :
1620 self .data = data
@@ -74,72 +78,39 @@ def metric_mapping(metric):
7478class Kgn (BaseANN ):
7579 def __init__ (self , metric , dim , method_param ):
7680 self .metric = metric_mapping (metric )
77- self .R = method_param ['R' ]
78- self .L = method_param ['L' ]
79- self .index_type = method_param ['index_type' ]
80- self .optimize = method_param ['optimize' ]
81- self .batch = method_param ['batch' ]
82- self .kmeans_ep = method_param ['kmeans_ep' ]
83- self .kmeans_type = method_param ['kmeans_type' ]
84- self .level = method_param ['level' ]
8581 self .name = 'kgn_(%s)' % (method_param )
8682 self .dir = 'indices'
87- self .path = f'{ metric } _{ dim } _{ self .index_type } _R_{ self .R } _L_{ self .L } .kgn'
88-
83+ self .path = f'{ metric } _{ dim } .kgn'
84+ self .R = method_param ['R' ] # [128, 160]
85+ self .level = method_param ['level' ] # [1, 2]
86+
87+ def build (self , X ):
88+ Index = kgn .Index (nb = self .n , dim = self .d , base = X , topK = 10 , metric = self .metric , level = self .level , R = self .R )
89+ full_path = os .path .join (self .dir , self .path )
90+ Index .build (full_path )
91+
8992 def fit (self , X ):
90- print (self .name , self .level , self .metric )
9193 if self .metric == "IP" :
9294 X = preprocessing .normalize (X , "l2" , axis = 1 )
9395 self .d = X .shape [1 ]
96+ self .n = X .shape [0 ]
9497 if not os .path .exists (self .dir ):
9598 os .mkdir (self .dir )
9699 if self .path not in os .listdir (self .dir ):
97- print ("build Index" )
98- p = kgn .Index (self .index_type , dim = self .d ,
99- metric = self .metric , R = self .R , L = self .L )
100- g = p .build (X ,20 )
101- g .save (os .path .join (self .dir , self .path ))
102- del p
103- del g
104-
105- # find kmeans centers -- RI
106- if (self .kmeans_type == 0 ):
107- RI = np .array ([])
108- elif (self .kmeans_type == 2 ):
109- t = time ()
110- kmeans_ep_searcher = EPSearcherKmeans_re (X , 0 , self .kmeans_ep , self .metric )
111- T = time () - t
112- print ("Time of bi_kmeans = " , T , " k=" , self .kmeans_ep )
113- RI = kmeans_ep_searcher .get_cent ()
114- else :
115- print ("Error: no such kmeans algorithm in main_opt.py" )
116- print ("kmeans_ep" , self .kmeans_ep )
117- g = kgn .Graph ()
118- g .load (os .path .join (self .dir , self .path ))
119- if self .level == 1 :
120- self .searcher = kgn .Searcher (g , X , self .metric , "SQ8U" ,20 )
121- elif self .level == 2 :
122- self .searcher = kgn .Searcher (g , X , self .metric , "SQ4U" ,20 )
123- print ("Make Searcher" )
124-
125- if self .optimize :
126- if self .batch :
127- if self .level <= 4 :
128- self .searcher .optimize ()
129- else :
130- print (self .level , "no needs optimized" )
131- pass
100+ full_path = os .path .join (self .dir , self .path )
101+ self .Index = kgn .Index (nb = self .n , dim = self .d , base = X , topK = 10 , metric = self .metric , level = self .level , R = self .R )
102+ if os .path .exists (full_path ) and os .path .isfile (full_path ):
103+ print ("load Index" )
104+ self .Index .load (full_path )
132105 else :
133- if self .level <= 4 :
134- self .searcher .optimize (1 )
135- else :
136- print (self .level , "no needs optimized" )
137- pass
138- print ("Optimize Parameters" )
139-
106+ print ("build Index" )
107+ p = multiprocessing .Process (target = self .build , args = (X , ))
108+ p .start ()
109+ p .join ()
110+ self .Index .load (full_path )
111+
140112
141113 def set_query_arguments (self , ef ):
142- self .searcher .set_ef (ef )
143114 self .ef = ef
144115
145116 def prepare_query (self , q , n ):
@@ -149,15 +120,10 @@ def prepare_query(self, q, n):
149120 self .n = n
150121
151122 def run_prepared_query (self ):
152- if self .level <= 3 :
153- self .res = self .searcher .search (
154- self .q , self .n )
155- else :
156- self .res = self .searcher .search (
157- self .q , self .n )
123+ self .res = self .Index .search (self .ef , self .q )
158124
159125 def get_prepared_query_results (self ):
160126 return self .res
161127
162128 def freeIndex (self ):
163- del self .searcher
129+ del self .Index
0 commit comments