2222target database, if it has not been already created.
2323"""
2424
25+ import os
2526import subprocess
2627import sys
27- import os
28+ import threading
29+ import time
2830
2931import pgvector .psycopg
3032import psycopg
3537from ...util import get_bool_env_var
3638
3739
40+ METRIC_PROPERTIES = {
41+ "angular" : {
42+ "distance_operator" : "<=>" ,
43+ # A substring of e.g. vector_cosine_ops or halfvec_cosine_ops
44+ "ops_type" : "cosine" ,
45+ },
46+ "euclidean" : {
47+ "distance_operator" : "<->" ,
48+ "ops_type" : "l2" ,
49+ }
50+ }
51+
52+
3853def get_pg_param_env_var_name (pg_param_name : str ) -> str :
3954 return f'ANN_BENCHMARKS_PG_{ pg_param_name .upper ()} '
4055
@@ -49,6 +64,151 @@ def get_pg_conn_param(
4964 return env_var_value
5065
5166
67+ class IndexingProgressMonitor :
68+ """
69+ Continuously logs indexing progress, elapsed and estimated remaining
70+ indexing time.
71+ """
72+
73+ MONITORING_DELAY_SEC = 0.5
74+
75+ def __init__ (self , psycopg_connect_kwargs : Dict [str , str ]) -> None :
76+ self .psycopg_connect_kwargs = psycopg_connect_kwargs
77+ self .monitoring_condition = threading .Condition ()
78+ self .stop_requested = False
79+ self .psycopg_connect_kwargs = psycopg_connect_kwargs
80+ self .prev_phase = None
81+ self .prev_progress_pct = None
82+ self .prev_tuples_done = None
83+ self .prev_report_time_sec = None
84+ self .time_to_load_all_tuples_sec = None
85+
86+ def report_progress (
87+ self ,
88+ phase : str ,
89+ progress_pct : Any ,
90+ tuples_done : Any ) -> None :
91+ if progress_pct is None :
92+ progress_pct = 0.0
93+ progress_pct = float (progress_pct )
94+ if tuples_done is None :
95+ tuples_done = 0
96+ tuples_done = int (tuples_done )
97+ # Only report progress when phase or percentage change.
98+ if (phase == self .prev_phase and
99+ progress_pct == self .prev_progress_pct and
100+ tuples_done == self .prev_tuples_done ):
101+ return
102+ time_now_sec = time .time ()
103+
104+ elapsed_time_sec = time_now_sec - self .indexing_start_time_sec
105+ fields = [
106+ f"Phase: { phase } " ,
107+ f"progress: { progress_pct :.1f} %" ,
108+ f"elapsed time: { elapsed_time_sec :.3f} sec"
109+ ]
110+ if (self .prev_report_time_sec is not None and
111+ self .prev_tuples_done is not None and
112+ elapsed_time_sec ):
113+ overall_tuples_per_sec = tuples_done / elapsed_time_sec
114+ fields .append (
115+ f"overall tuples/sec: { overall_tuples_per_sec :%.2f} " )
116+
117+ time_since_last_report_sec = time_now_sec - self .prev_report_time_sec
118+ if time_since_last_report_sec > 0 :
119+ cur_tuples_per_sec = ((tuples_done - self .prev_tuples_done ) /
120+ time_since_last_report_sec )
121+ fields .append (
122+ f"current tuples/sec: { cur_tuples_per_sec :%.2f} " )
123+
124+
125+ remaining_pct = 100 - progress_pct
126+ if progress_pct > 0 and remaining_pct > 0 :
127+ estimated_remaining_time_sec = \
128+ elapsed_time_sec / progress_pct * remaining_pct
129+ estimated_total_time_sec = \
130+ elapsed_time_sec + estimated_remaining_time_sec
131+ fields .extend ([
132+ "estimated remaining time: " \
133+ f"{ estimated_remaining_time_sec :.3f} sec" ,
134+ f"estimated total time: { estimated_total_time_sec :.3f} sec"
135+ ])
136+ print (", " .join (fields ))
137+ sys .stdout .flush ()
138+
139+ self .prev_progress_pct = progress_pct
140+ self .prev_phase = phase
141+ self .prev_tuples_done = tuples_done
142+ self .prev_report_time_sec = time_now_sec
143+
144+ def monitoring_loop_impl (self , monitoring_cur ) -> None :
145+ while True :
146+ # Indexing progress query taken from
147+ # https://github.com/pgvector/pgvector/blob/master/README.md
148+ monitoring_cur .execute (
149+ "SELECT phase, " +
150+ "round(100.0 * blocks_done / nullif(blocks_total, 0), 1), " +
151+ "tuples_done " +
152+ "FROM pg_stat_progress_create_index" );
153+ result_rows = monitoring_cur .fetchall ()
154+
155+ if len (result_rows ) == 1 :
156+ phase , progress_pct , tuples_done = result_rows [0 ]
157+ self .report_progress (phase , progress_pct , tuples_done )
158+ if (self .time_to_load_all_tuples_sec is None and
159+ phase == 'building index: loading tuples' and
160+ progress_pct is not None and
161+ float (progress_pct ) > 100.0 - 1e-7 ):
162+ # Even after pgvector reports progress as 100%, it still spends
163+ # some time postprocessing the index and writing it to disk.
164+ # We keep track of the the time it takes to reach 100%
165+ # separately.
166+ self .time_to_load_all_tuples_sec = \
167+ time .time () - self .indexing_start_time_sec
168+ elif len (result_rows ) > 0 :
169+ # This should not happen.
170+ print (f"Expected exactly one progress result row, got: { result_rows } " )
171+ with self .monitoring_condition :
172+ if self .stop_requested :
173+ return
174+ self .monitoring_condition .wait (
175+ timeout = self .MONITORING_DELAY_SEC )
176+ if self .stop_requested :
177+ return
178+
179+ def monitor_progress (self ) -> None :
180+ prev_phase = None
181+ prev_progress_pct = None
182+ with psycopg .connect (** self .psycopg_connect_kwargs ) as monitoring_conn :
183+ with monitoring_conn .cursor () as monitoring_cur :
184+ self .monitoring_loop_impl (monitoring_cur )
185+
186+ def start_monitoring_thread (self ) -> None :
187+ self .indexing_start_time_sec = time .time ()
188+ self .monitoring_thread = threading .Thread (target = self .monitor_progress )
189+ self .monitoring_thread .start ()
190+
191+ def stop_monitoring_thread (self ) -> None :
192+ with self .monitoring_condition :
193+ self .stop_requested = True
194+ self .monitoring_condition .notify_all ()
195+ self .monitoring_thread .join ()
196+ self .indexing_time_sec = time .time () - self .indexing_start_time_sec
197+
198+ def report_timings (self ) -> None :
199+ print ("pgvector total indexing time: {:3f} sec" .format (
200+ self .indexing_time_sec ))
201+ if self .time_to_load_all_tuples_sec is not None :
202+ print (" Time to load all tuples into the index: {:.3f} sec" .format (
203+ self .time_to_load_all_tuples_sec
204+ ))
205+ postprocessing_time_sec = \
206+ self .indexing_time_sec - self .time_to_load_all_tuples_sec
207+ print (" Index postprocessing time: {:.3f} sec" .format (
208+ postprocessing_time_sec ))
209+ else :
210+ print (" Detailed breakdown of indexing time not available." )
211+
52212class PGVector (BaseANN ):
53213 def __init__ (self , metric , method_param ):
54214 self ._metric = metric
@@ -63,6 +223,21 @@ def __init__(self, metric, method_param):
63223 else :
64224 raise RuntimeError (f"unknown metric { metric } " )
65225
226+ def get_metric_properties (self ) -> Dict [str , str ]:
227+ """
228+ Get properties of the metric type associated with this index.
229+
230+ Returns:
231+ A dictionary with keys distance_operator and ops_type.
232+ """
233+ if self ._metric not in METRIC_PROPERTIES :
234+ raise ValueError (
235+ "Unknown metric: {}. Valid metrics: {}" .format (
236+ self ._metric ,
237+ ', ' .join (sorted (METRIC_PROPERTIES .keys ()))
238+ ))
239+ return METRIC_PROPERTIES [self ._metric ]
240+
66241 def ensure_pgvector_extension_created (self , conn : psycopg .Connection ) -> None :
67242 """
68243 Ensure that `CREATE EXTENSION vector` has been executed.
@@ -124,23 +299,37 @@ def fit(self, X):
124299 cur .execute ("CREATE TABLE items (id int, embedding vector(%d))" % X .shape [1 ])
125300 cur .execute ("ALTER TABLE items ALTER COLUMN embedding SET STORAGE PLAIN" )
126301 print ("copying data..." )
302+ sys .stdout .flush ()
303+ num_rows = 0
304+ insert_start_time_sec = time .time ()
127305 with cur .copy ("COPY items (id, embedding) FROM STDIN WITH (FORMAT BINARY)" ) as copy :
128306 copy .set_types (["int4" , "vector" ])
129307 for i , embedding in enumerate (X ):
130308 copy .write_row ((i , embedding ))
309+ insert_elapsed_time_sec = time .time () - insert_start_time_sec
310+ print ("inserted {} rows into table in {:.3f} seconds" .format (
311+ num_rows , insert_elapsed_time_sec ))
312+
131313 print ("creating index..." )
132- if self ._metric == "angular" :
133- cur .execute (
134- "CREATE INDEX ON items USING hnsw (embedding vector_cosine_ops) WITH (m = %d, ef_construction = %d)" % (self ._m , self ._ef_construction )
314+ sys .stdout .flush ()
315+ create_index_str = \
316+ "CREATE INDEX ON items USING hnsw (embedding vector_%s_ops) " \
317+ "WITH (m = %d, ef_construction = %d)" % (
318+ self .get_metric_properties ()["ops_type" ],
319+ self ._m ,
320+ self ._ef_construction
135321 )
136- elif self ._metric == "euclidean" :
137- cur .execute ("CREATE INDEX ON items USING hnsw (embedding vector_l2_ops) WITH (m = %d, ef_construction = %d)" % (self ._m , self ._ef_construction ))
138- else :
139- raise RuntimeError (f"unknown metric { self ._metric } " )
322+ progress_monitor = IndexingProgressMonitor (psycopg_connect_kwargs )
323+ progress_monitor .start_monitoring_thread ()
324+
325+ try :
326+ cur .execute (create_index_str )
327+ finally :
328+ progress_monitor .stop_monitoring_thread ()
140329 print ("done!" )
330+ progress_monitor .report_timings ()
141331 self ._cur = cur
142332
143-
144333 def set_query_arguments (self , ef_search ):
145334 self ._ef_search = ef_search
146335 self ._cur .execute ("SET hnsw.ef_search = %d" % ef_search )
0 commit comments