1+ """
This module supports connecting to a PostgreSQL instance and performing vector
indexing and search using the pgvector extension. By default, the PostgreSQL
user name, password, and database name are all "ann", and the host and port
are left to the defaults of the psycopg driver.
6+
7+ If PostgreSQL is managed externally, e.g. in a cloud DBaaS environment, the
8+ environment variable overrides listed below are available for setting PostgreSQL
9+ connection parameters:
10+
11+ ANN_BENCHMARKS_PG_USER
12+ ANN_BENCHMARKS_PG_PASSWORD
13+ ANN_BENCHMARKS_PG_DBNAME
14+ ANN_BENCHMARKS_PG_HOST
15+ ANN_BENCHMARKS_PG_PORT
16+
This module starts the PostgreSQL service automatically using the "service"
command. Set the environment variable ANN_BENCHMARKS_PG_START_SERVICE to
"false" (or e.g. "0" or "no") to disable this behavior.
20+
21+ This module will also attempt to create the pgvector extension inside the
22+ target database, if it has not been already created.
23+ """
24+
125import subprocess
226import sys
27+ import os
328
429import pgvector .psycopg
530import psycopg
631
32+ from typing import Dict , Any , Optional
33+
734from ..base .module import BaseANN
35+ from ...util import get_bool_env_var
36+
37+
def get_pg_param_env_var_name(pg_param_name: str) -> str:
    """
    Build the name of the environment variable for a PostgreSQL parameter.

    The parameter name is upper-cased and appended to the
    ``ANN_BENCHMARKS_PG_`` prefix, e.g. ``"host"`` becomes
    ``"ANN_BENCHMARKS_PG_HOST"``.
    """
    return 'ANN_BENCHMARKS_PG_{}'.format(pg_param_name.upper())
40+
41+
def get_pg_conn_param(
        pg_param_name: str,
        default_value: Optional[str] = None) -> Optional[str]:
    """
    Read a PostgreSQL connection parameter from the environment.

    The variable consulted is the one named by
    ``get_pg_param_env_var_name(pg_param_name)``.

    :param pg_param_name: parameter name, e.g. ``"user"`` or ``"port"``.
    :param default_value: value to fall back to.
    :return: the variable's value exactly as found (not stripped) when it is
        set and contains something other than whitespace; otherwise
        ``default_value``.
    """
    raw_value = os.environ.get(
        get_pg_param_env_var_name(pg_param_name), default_value)
    # An unset variable and a blank/whitespace-only one are treated alike.
    if raw_value is not None and raw_value.strip():
        return raw_value
    return default_value
850
951
1052class PGVector (BaseANN ):
@@ -21,9 +63,61 @@ def __init__(self, metric, method_param):
2163 else :
2264 raise RuntimeError (f"unknown metric { metric } " )
2365
def ensure_pgvector_extension_created(self, conn: psycopg.Connection) -> None:
    """
    Create the ``vector`` extension in the connected database if missing.

    Checks ``pg_extension`` for an entry named ``vector`` and runs
    ``CREATE EXTENSION vector`` only when none is found. Prints which of the
    two cases applied.

    :param conn: connection on which the check and the creation are executed.
    """
    # A dedicated cursor is opened for this step on purpose. Reusing the
    # same cursor for later operations might fail with:
    #   KeyError: "couldn't find the type 'vector' in the types registry"
    with conn.cursor() as extension_cursor:
        extension_cursor.execute(
            "SELECT EXISTS(SELECT 1 FROM pg_extension WHERE extname = 'vector')")
        first_row = extension_cursor.fetchone()
        already_installed = first_row[0]
        if not already_installed:
            print("vector extension does not exist, creating")
            extension_cursor.execute("CREATE EXTENSION vector")
        else:
            print("vector extension already exists")
83+
2484 def fit (self , X ):
25- subprocess .run ("service postgresql start" , shell = True , check = True , stdout = sys .stdout , stderr = sys .stderr )
26- conn = psycopg .connect (user = "ann" , password = "ann" , dbname = "ann" , autocommit = True )
85+ psycopg_connect_kwargs : Dict [str , Any ] = dict (
86+ autocommit = True ,
87+ )
88+ for arg_name in ['user' , 'password' , 'dbname' ]:
89+ # The default value is "ann" for all of these parameters.
90+ psycopg_connect_kwargs [arg_name ] = get_pg_conn_param (
91+ arg_name , 'ann' )
92+
93+ # If host/port are not specified, leave the default choice to the
94+ # psycopg driver.
95+ pg_host : Optional [str ] = get_pg_conn_param ('host' )
96+ if pg_host is not None :
97+ psycopg_connect_kwargs ['host' ] = pg_host
98+
99+ pg_port_str : Optional [str ] = get_pg_conn_param ('port' )
100+ if pg_port_str is not None :
101+ psycopg_connect_kwargs ['port' ] = int (pg_port_str )
102+
103+ should_start_service = get_bool_env_var (
104+ get_pg_param_env_var_name ('start_service' ),
105+ default_value = True )
106+ if should_start_service :
107+ subprocess .run (
108+ "service postgresql start" ,
109+ shell = True ,
110+ check = True ,
111+ stdout = sys .stdout ,
112+ stderr = sys .stderr )
113+ else :
114+ print (
115+ "Assuming that PostgreSQL service is managed externally. "
116+ "Not attempting to start the service." )
117+
118+ conn = psycopg .connect (** psycopg_connect_kwargs )
119+ self .ensure_pgvector_extension_created (conn )
120+
27121 pgvector .psycopg .register_vector (conn )
28122 cur = conn .cursor ()
29123 cur .execute ("DROP TABLE IF EXISTS items" )
@@ -46,6 +140,7 @@ def fit(self, X):
46140 print ("done!" )
47141 self ._cur = cur
48142
143+
def set_query_arguments(self, ef_search):
    """
    Store ``ef_search`` and apply it on the cursor kept in ``self._cur``.

    The value is remembered as ``self._ef_search`` and then sent as a
    ``SET hnsw.ef_search = <value>`` statement. ``%d`` formatting renders
    the value as an integer.
    """
    self._ef_search = ef_search
    statement = "SET hnsw.ef_search = %d" % ef_search
    self._cur.execute(statement)
0 commit comments