Skip to content

Commit 4078c80

Browse files
committed
Index build progress tracking in the pgvector module
1 parent 33ecd5f commit 4078c80

File tree

1 file changed

+198
-9
lines changed

1 file changed

+198
-9
lines changed

ann_benchmarks/algorithms/pgvector/module.py

Lines changed: 198 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,11 @@
2222
target database, if it has not been already created.
2323
"""
2424

25+
import os
2526
import subprocess
2627
import sys
27-
import os
28+
import threading
29+
import time
2830

2931
import pgvector.psycopg
3032
import psycopg
@@ -35,6 +37,19 @@
3537
from ...util import get_bool_env_var
3638

3739

40+
METRIC_PROPERTIES = {
41+
"angular": {
42+
"distance_operator": "<=>",
43+
# A substring of e.g. vector_cosine_ops or halfvec_cosine_ops
44+
"ops_type": "cosine",
45+
},
46+
"euclidean": {
47+
"distance_operator": "<->",
48+
"ops_type": "l2",
49+
}
50+
}
51+
52+
3853
def get_pg_param_env_var_name(pg_param_name: str) -> str:
3954
return f'ANN_BENCHMARKS_PG_{pg_param_name.upper()}'
4055

@@ -49,6 +64,151 @@ def get_pg_conn_param(
4964
return env_var_value
5065

5166

67+
class IndexingProgressMonitor:
68+
"""
69+
Continuously logs indexing progress, elapsed and estimated remaining
70+
indexing time.
71+
"""
72+
73+
MONITORING_DELAY_SEC = 0.5
74+
75+
def __init__(self, psycopg_connect_kwargs: Dict[str, str]) -> None:
76+
self.psycopg_connect_kwargs = psycopg_connect_kwargs
77+
self.monitoring_condition = threading.Condition()
78+
self.stop_requested = False
79+
self.psycopg_connect_kwargs = psycopg_connect_kwargs
80+
self.prev_phase = None
81+
self.prev_progress_pct = None
82+
self.prev_tuples_done = None
83+
self.prev_report_time_sec = None
84+
self.time_to_load_all_tuples_sec = None
85+
86+
def report_progress(
87+
self,
88+
phase: str,
89+
progress_pct: Any,
90+
tuples_done: Any) -> None:
91+
if progress_pct is None:
92+
progress_pct = 0.0
93+
progress_pct = float(progress_pct)
94+
if tuples_done is None:
95+
tuples_done = 0
96+
tuples_done = int(tuples_done)
97+
# Only report progress when phase or percentage change.
98+
if (phase == self.prev_phase and
99+
progress_pct == self.prev_progress_pct and
100+
tuples_done == self.prev_tuples_done):
101+
return
102+
time_now_sec = time.time()
103+
104+
elapsed_time_sec = time_now_sec - self.indexing_start_time_sec
105+
fields = [
106+
f"Phase: {phase}",
107+
f"progress: {progress_pct:.1f}%",
108+
f"elapsed time: {elapsed_time_sec:.3f} sec"
109+
]
110+
if (self.prev_report_time_sec is not None and
111+
self.prev_tuples_done is not None and
112+
elapsed_time_sec):
113+
overall_tuples_per_sec = tuples_done / elapsed_time_sec
114+
fields.append(
115+
f"overall tuples/sec: {overall_tuples_per_sec:%.2f}")
116+
117+
time_since_last_report_sec = time_now_sec - self.prev_report_time_sec
118+
if time_since_last_report_sec > 0:
119+
cur_tuples_per_sec = ((tuples_done - self.prev_tuples_done) /
120+
time_since_last_report_sec)
121+
fields.append(
122+
f"current tuples/sec: {cur_tuples_per_sec:%.2f}")
123+
124+
125+
remaining_pct = 100 - progress_pct
126+
if progress_pct > 0 and remaining_pct > 0:
127+
estimated_remaining_time_sec = \
128+
elapsed_time_sec / progress_pct * remaining_pct
129+
estimated_total_time_sec = \
130+
elapsed_time_sec + estimated_remaining_time_sec
131+
fields.extend([
132+
"estimated remaining time: " \
133+
f"{estimated_remaining_time_sec:.3f} sec" ,
134+
f"estimated total time: {estimated_total_time_sec:.3f} sec"
135+
])
136+
print(", ".join(fields))
137+
sys.stdout.flush()
138+
139+
self.prev_progress_pct = progress_pct
140+
self.prev_phase = phase
141+
self.prev_tuples_done = tuples_done
142+
self.prev_report_time_sec = time_now_sec
143+
144+
def monitoring_loop_impl(self, monitoring_cur) -> None:
145+
while True:
146+
# Indexing progress query taken from
147+
# https://github.com/pgvector/pgvector/blob/master/README.md
148+
monitoring_cur.execute(
149+
"SELECT phase, " +
150+
"round(100.0 * blocks_done / nullif(blocks_total, 0), 1), " +
151+
"tuples_done " +
152+
"FROM pg_stat_progress_create_index");
153+
result_rows = monitoring_cur.fetchall()
154+
155+
if len(result_rows) == 1:
156+
phase, progress_pct, tuples_done = result_rows[0]
157+
self.report_progress(phase, progress_pct, tuples_done)
158+
if (self.time_to_load_all_tuples_sec is None and
159+
phase == 'building index: loading tuples' and
160+
progress_pct is not None and
161+
float(progress_pct) > 100.0 - 1e-7):
162+
# Even after pgvector reports progress as 100%, it still spends
163+
# some time postprocessing the index and writing it to disk.
164+
# We keep track of the the time it takes to reach 100%
165+
# separately.
166+
self.time_to_load_all_tuples_sec = \
167+
time.time() - self.indexing_start_time_sec
168+
elif len(result_rows) > 0:
169+
# This should not happen.
170+
print(f"Expected exactly one progress result row, got: {result_rows}")
171+
with self.monitoring_condition:
172+
if self.stop_requested:
173+
return
174+
self.monitoring_condition.wait(
175+
timeout=self.MONITORING_DELAY_SEC)
176+
if self.stop_requested:
177+
return
178+
179+
def monitor_progress(self) -> None:
180+
prev_phase = None
181+
prev_progress_pct = None
182+
with psycopg.connect(**self.psycopg_connect_kwargs) as monitoring_conn:
183+
with monitoring_conn.cursor() as monitoring_cur:
184+
self.monitoring_loop_impl(monitoring_cur)
185+
186+
def start_monitoring_thread(self) -> None:
187+
self.indexing_start_time_sec = time.time()
188+
self.monitoring_thread = threading.Thread(target=self.monitor_progress)
189+
self.monitoring_thread.start()
190+
191+
def stop_monitoring_thread(self) -> None:
192+
with self.monitoring_condition:
193+
self.stop_requested = True
194+
self.monitoring_condition.notify_all()
195+
self.monitoring_thread.join()
196+
self.indexing_time_sec = time.time() - self.indexing_start_time_sec
197+
198+
def report_timings(self) -> None:
199+
print("pgvector total indexing time: {:3f} sec".format(
200+
self.indexing_time_sec))
201+
if self.time_to_load_all_tuples_sec is not None:
202+
print(" Time to load all tuples into the index: {:.3f} sec".format(
203+
self.time_to_load_all_tuples_sec
204+
))
205+
postprocessing_time_sec = \
206+
self.indexing_time_sec - self.time_to_load_all_tuples_sec
207+
print(" Index postprocessing time: {:.3f} sec".format(
208+
postprocessing_time_sec))
209+
else:
210+
print(" Detailed breakdown of indexing time not available.")
211+
52212
class PGVector(BaseANN):
53213
def __init__(self, metric, method_param):
54214
self._metric = metric
@@ -63,6 +223,21 @@ def __init__(self, metric, method_param):
63223
else:
64224
raise RuntimeError(f"unknown metric {metric}")
65225

226+
def get_metric_properties(self) -> Dict[str, str]:
227+
"""
228+
Get properties of the metric type associated with this index.
229+
230+
Returns:
231+
A dictionary with keys distance_operator and ops_type.
232+
"""
233+
if self._metric not in METRIC_PROPERTIES:
234+
raise ValueError(
235+
"Unknown metric: {}. Valid metrics: {}".format(
236+
self._metric,
237+
', '.join(sorted(METRIC_PROPERTIES.keys()))
238+
))
239+
return METRIC_PROPERTIES[self._metric]
240+
66241
def ensure_pgvector_extension_created(self, conn: psycopg.Connection) -> None:
67242
"""
68243
Ensure that `CREATE EXTENSION vector` has been executed.
@@ -124,23 +299,37 @@ def fit(self, X):
124299
cur.execute("CREATE TABLE items (id int, embedding vector(%d))" % X.shape[1])
125300
cur.execute("ALTER TABLE items ALTER COLUMN embedding SET STORAGE PLAIN")
126301
print("copying data...")
302+
sys.stdout.flush()
303+
num_rows = 0
304+
insert_start_time_sec = time.time()
127305
with cur.copy("COPY items (id, embedding) FROM STDIN WITH (FORMAT BINARY)") as copy:
128306
copy.set_types(["int4", "vector"])
129307
for i, embedding in enumerate(X):
130308
copy.write_row((i, embedding))
309+
insert_elapsed_time_sec = time.time() - insert_start_time_sec
310+
print("inserted {} rows into table in {:.3f} seconds".format(
311+
num_rows, insert_elapsed_time_sec))
312+
131313
print("creating index...")
132-
if self._metric == "angular":
133-
cur.execute(
134-
"CREATE INDEX ON items USING hnsw (embedding vector_cosine_ops) WITH (m = %d, ef_construction = %d)" % (self._m, self._ef_construction)
314+
sys.stdout.flush()
315+
create_index_str = \
316+
"CREATE INDEX ON items USING hnsw (embedding vector_%s_ops) " \
317+
"WITH (m = %d, ef_construction = %d)" % (
318+
self.get_metric_properties()["ops_type"],
319+
self._m,
320+
self._ef_construction
135321
)
136-
elif self._metric == "euclidean":
137-
cur.execute("CREATE INDEX ON items USING hnsw (embedding vector_l2_ops) WITH (m = %d, ef_construction = %d)" % (self._m, self._ef_construction))
138-
else:
139-
raise RuntimeError(f"unknown metric {self._metric}")
322+
progress_monitor = IndexingProgressMonitor(psycopg_connect_kwargs)
323+
progress_monitor.start_monitoring_thread()
324+
325+
try:
326+
cur.execute(create_index_str)
327+
finally:
328+
progress_monitor.stop_monitoring_thread()
140329
print("done!")
330+
progress_monitor.report_timings()
141331
self._cur = cur
142332

143-
144333
def set_query_arguments(self, ef_search):
145334
self._ef_search = ef_search
146335
self._cur.execute("SET hnsw.ef_search = %d" % ef_search)

0 commit comments

Comments
 (0)