Skip to content

Commit 4d734fb

Browse files
Add ability to specify Postgres connection credentials (#564)
* Rework the approach for specifying PostgreSQL connection parameters * Remove trailing whitespace * Address review comments * Fix coding style
1 parent 2331417 commit 4d734fb

File tree

3 files changed

+138
-24
lines changed

3 files changed

+138
-24
lines changed

ann_benchmarks/algorithms/pgvector/module.py

Lines changed: 97 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,52 @@
1+
"""
2+
This module supports connecting to a PostgreSQL instance and performing vector
3+
indexing and search using the pgvector extension. The default behavior uses
4+
the value "ann" for the PostgreSQL user name, password, and database name, as well
5+
as the default host and port values of the psycopg driver.
6+
7+
If PostgreSQL is managed externally, e.g. in a cloud DBaaS environment, the
8+
environment variable overrides listed below are available for setting PostgreSQL
9+
connection parameters:
10+
11+
ANN_BENCHMARKS_PG_USER
12+
ANN_BENCHMARKS_PG_PASSWORD
13+
ANN_BENCHMARKS_PG_DBNAME
14+
ANN_BENCHMARKS_PG_HOST
15+
ANN_BENCHMARKS_PG_PORT
16+
17+
This module starts the PostgreSQL service automatically using the "service"
18+
command. The environment variable ANN_BENCHMARKS_PG_START_SERVICE can be set
19+
to "false" (or e.g. "0" or "no") in order to disable this behavior.
20+
21+
This module will also attempt to create the pgvector extension inside the
22+
target database if it has not already been created.
23+
"""
24+
125
import subprocess
226
import sys
27+
import os
328

429
import pgvector.psycopg
530
import psycopg
631

32+
from typing import Dict, Any, Optional
33+
734
from ..base.module import BaseANN
35+
from ...util import get_bool_env_var
36+
37+
38+
def get_pg_param_env_var_name(pg_param_name: str) -> str:
    """
    Build the name of the environment variable that overrides a PostgreSQL
    connection parameter.

    Args:
        pg_param_name: Parameter name such as "user", "host" or
            "start_service". It is upper-cased before being appended to the
            common prefix.

    Returns:
        The environment variable name, e.g. "ANN_BENCHMARKS_PG_HOST" for
        the parameter "host".
    """
    return f'ANN_BENCHMARKS_PG_{pg_param_name.upper()}'
40+
41+
42+
def get_pg_conn_param(
        pg_param_name: str,
        default_value: Optional[str] = None) -> Optional[str]:
    """
    Look up a PostgreSQL connection parameter from the environment.

    Args:
        pg_param_name: Parameter name (e.g. "user", "port"); it is mapped to
            an environment variable name by get_pg_param_env_var_name.
        default_value: Value returned when the environment variable is
            unset, empty, or contains only whitespace.

    Returns:
        The environment variable's value exactly as set (it is not stripped)
        when it holds at least one non-whitespace character, otherwise
        `default_value`.
    """
    env_var_value = os.getenv(get_pg_param_env_var_name(pg_param_name))
    # A blank value is treated the same as an unset variable, so that
    # exporting an empty override does not clobber the default.
    if env_var_value is None or not env_var_value.strip():
        return default_value
    return env_var_value
850

951

1052
class PGVector(BaseANN):
@@ -21,9 +63,61 @@ def __init__(self, metric, method_param):
2163
else:
2264
raise RuntimeError(f"unknown metric {metric}")
2365

66+
def ensure_pgvector_extension_created(self, conn: psycopg.Connection) -> None:
    """
    Ensure that `CREATE EXTENSION vector` has been executed.

    Queries pg_extension for an extension named 'vector', prints which case
    was found, and creates the extension when it is missing.

    Args:
        conn: Open psycopg connection to the target database.
    """
    # We have to use a separate cursor for this operation.
    # If we reuse the same cursor for later operations, we might get
    # the following error:
    # KeyError: "couldn't find the type 'vector' in the types registry"
    with conn.cursor() as cur:
        cur.execute(
            "SELECT EXISTS(SELECT 1 FROM pg_extension WHERE extname = 'vector')")
        pgvector_exists = cur.fetchone()[0]
        if pgvector_exists:
            print("vector extension already exists")
        else:
            print("vector extension does not exist, creating")
            # IF NOT EXISTS keeps this from failing if another session
            # creates the extension between the check above and this
            # statement.
            cur.execute("CREATE EXTENSION IF NOT EXISTS vector")
83+
2484
def fit(self, X):
25-
subprocess.run("service postgresql start", shell=True, check=True, stdout=sys.stdout, stderr=sys.stderr)
26-
conn = psycopg.connect(user="ann", password="ann", dbname="ann", autocommit=True)
85+
psycopg_connect_kwargs: Dict[str, Any] = dict(
86+
autocommit=True,
87+
)
88+
for arg_name in ['user', 'password', 'dbname']:
89+
# The default value is "ann" for all of these parameters.
90+
psycopg_connect_kwargs[arg_name] = get_pg_conn_param(
91+
arg_name, 'ann')
92+
93+
# If host/port are not specified, leave the default choice to the
94+
# psycopg driver.
95+
pg_host: Optional[str] = get_pg_conn_param('host')
96+
if pg_host is not None:
97+
psycopg_connect_kwargs['host'] = pg_host
98+
99+
pg_port_str: Optional[str] = get_pg_conn_param('port')
100+
if pg_port_str is not None:
101+
psycopg_connect_kwargs['port'] = int(pg_port_str)
102+
103+
should_start_service = get_bool_env_var(
104+
get_pg_param_env_var_name('start_service'),
105+
default_value=True)
106+
if should_start_service:
107+
subprocess.run(
108+
"service postgresql start",
109+
shell=True,
110+
check=True,
111+
stdout=sys.stdout,
112+
stderr=sys.stderr)
113+
else:
114+
print(
115+
"Assuming that PostgreSQL service is managed externally. "
116+
"Not attempting to start the service.")
117+
118+
conn = psycopg.connect(**psycopg_connect_kwargs)
119+
self.ensure_pgvector_extension_created(conn)
120+
27121
pgvector.psycopg.register_vector(conn)
28122
cur = conn.cursor()
29123
cur.execute("DROP TABLE IF EXISTS items")
@@ -46,6 +140,7 @@ def fit(self, X):
46140
print("done!")
47141
self._cur = cur
48142

143+
49144
def set_query_arguments(self, ef_search):
50145
self._ef_search = ef_search
51146
self._cur.execute("SET hnsw.ef_search = %d" % ef_search)

ann_benchmarks/definitions.py

Lines changed: 26 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import importlib
44
import os
55
import glob
6+
import logging
67
from enum import Enum
78
from itertools import product
89
from typing import Any, Dict, List, Optional, Union
@@ -21,18 +22,19 @@ class Definition:
2122
query_argument_groups: List[List[Any]]
2223
disabled: bool
2324

25+
2426
def instantiate_algorithm(definition: Definition) -> BaseANN:
2527
"""
2628
Create a `BaseANN` from a definition.
27-
29+
2830
Args:
2931
definition (Definition): An object containing information about the algorithm.
3032
3133
Returns:
3234
BaseANN: Instantiated algorithm
3335
3436
Note:
35-
The constructors for the algorithm definition are generally located at
37+
The constructors for the algorithm definition are generally located at
3638
ann_benchmarks/algorithms/*/module.py.
3739
"""
3840
print(f"Trying to instantiate {definition.module}.{definition.constructor}({definition.arguments})")
@@ -52,7 +54,7 @@ def algorithm_status(definition: Definition) -> InstantiationStatus:
5254
"""
5355
Determine the instantiation status of the algorithm based on its python module and constructor.
5456
55-
Attempts to find the Python class constructor based on the definition's module path and
57+
Attempts to find the Python class constructor based on the definition's module path and
5658
constructor name.
5759
5860
Args:
@@ -68,6 +70,8 @@ def algorithm_status(definition: Definition) -> InstantiationStatus:
6870
else:
6971
return InstantiationStatus.NO_CONSTRUCTOR
7072
except ImportError:
73+
logging.exception("Could not import algorithm module for %s",
74+
definition.module)
7175
return InstantiationStatus.NO_MODULE
7276

7377

@@ -103,7 +107,7 @@ def _generate_combinations(args: Union[List[Any], Dict[Any, Any]]) -> List[Union
103107
def _substitute_variables(arg: Any, vs: Dict[str, Any]) -> Any:
104108
"""
105109
Substitutes any string variables present in the argument structure with provided values.
106-
110+
107111
Support for nested substitution in the case `arg` is a List or Dict.
108112
109113
Args:
@@ -160,8 +164,8 @@ def _get_definitions(base_dir: str = "ann_benchmarks/algorithms") -> List[Dict[s
160164

161165
def _get_algorithm_definitions(point_type: str, distance_metric: str, base_dir: str = "ann_benchmarks/algorithms") -> Dict[str, Dict[str, Any]]:
162166
"""Get algorithm definitions for a specific point type and distance metric.
163-
164-
A specific algorithm folder can have multiple algorithm definitions for a given point type and
167+
168+
A specific algorithm folder can have multiple algorithm definitions for a given point type and
165169
metric. For example, `ann_benchmarks.algorithms.nmslib` has two definitions for euclidean float
166170
data: specifically `SW-graph(nmslib)` and `hnsw(nmslib)`, even though the module is named nmslib.
167171
@@ -176,7 +180,7 @@ def _get_algorithm_definitions(point_type: str, distance_metric: str, base_dir:
176180
"disabled": false,
177181
"docker_tag": ann-benchmarks-nmslib,
178182
...
179-
},
183+
},
180184
'SW-graph(nmslib)': {
181185
"base_args": ['@metric', sw-graph],
182186
"constructor": NmslibReuseIndex,
@@ -205,9 +209,9 @@ def _get_algorithm_definitions(point_type: str, distance_metric: str, base_dir:
205209
def list_algorithms(base_dir: str = "ann_benchmarks/algorithms") -> None:
206210
"""
207211
Output (to stdout), a list of all algorithms, with their supported point types and metrics.
208-
212+
209213
Args:
210-
base_dir (str, optional): The base directory where the algorithms are stored.
214+
base_dir (str, optional): The base directory where the algorithms are stored.
211215
Defaults to "ann_benchmarks/algorithms".
212216
"""
213217
all_configs = _get_definitions(base_dir)
@@ -236,7 +240,7 @@ def list_algorithms(base_dir: str = "ann_benchmarks/algorithms") -> None:
236240

237241
def generate_arg_combinations(run_group: Dict[str, Any], arg_type: str) -> List:
238242
"""Generate combinations of arguments from a run group for a specific argument type.
239-
243+
240244
Args:
241245
run_group (Dict[str, Any]): The run group containing argument definitions.
242246
arg_type (str): The type of argument group to generate combinations for.
@@ -262,10 +266,10 @@ def generate_arg_combinations(run_group: Dict[str, Any], arg_type: str) -> List:
262266

263267

264268
def prepare_args(run_group: Dict[str, Any]) -> List:
265-
"""For an Algorithm's run group, prepare arguments.
266-
269+
"""For an Algorithm's run group, prepare arguments.
270+
267271
An `arg_groups` key takes precedence over an `args` key.
268-
272+
269273
Args:
270274
run_group (Dict[str, Any]): The run group containing argument definitions.
271275
@@ -283,7 +287,7 @@ def prepare_args(run_group: Dict[str, Any]) -> List:
283287

284288
def prepare_query_args(run_group: Dict[str, Any]) -> List:
285289
"""For an algorithm's run group, prepare query args/ query arg groups.
286-
290+
287291
Args:
288292
run_group (Dict[str, Any]): The run group containing argument definitions.
289293
@@ -299,28 +303,28 @@ def prepare_query_args(run_group: Dict[str, Any]) -> List:
299303
def create_definitions_from_algorithm(name: str, algo: Dict[str, Any], dimension: int, distance_metric: str = "euclidean", count: int = 10) -> List[Definition]:
300304
"""
301305
Create definitions from an individual algorithm. An algorithm (e.g. annoy) can have multiple
302-
definitions based on various run groups (see config.ymls for clear examples).
303-
306+
definitions based on various run groups (see config.ymls for clear examples).
307+
304308
Args:
305309
name (str): Name of the algorithm.
306310
algo (Dict[str, Any]): Dictionary with algorithm parameters.
307311
dimension (int): Dimension of the algorithm.
308312
distance_metric (str, optional): Distance metric used by the algorithm. Defaults to "euclidean".
309313
count (int, optional): Count of the definitions to be created. Defaults to 10.
310-
314+
311315
Raises:
312316
Exception: If the algorithm does not define "docker_tag", "module" or "constructor" properties.
313-
317+
314318
Returns:
315319
List[Definition]: A list of definitions created from the algorithm.
316320
"""
317321
required_properties = ["docker_tag", "module", "constructor"]
318322
missing_properties = [prop for prop in required_properties if prop not in algo]
319323
if missing_properties:
320324
raise ValueError(f"Algorithm {name} is missing the following properties: {', '.join(missing_properties)}")
321-
325+
322326
base_args = algo.get("base_args", [])
323-
327+
324328
definitions = []
325329
for run_group in algo["run_groups"].values():
326330
args = prepare_args(run_group)
@@ -336,7 +340,7 @@ def create_definitions_from_algorithm(name: str, algo: Dict[str, Any], dimension
336340

337341
vs = {"@count": count, "@metric": distance_metric, "@dimension": dimension}
338342
current_args = [_substitute_variables(arg, vs) for arg in current_args]
339-
343+
340344
definitions.append(
341345
Definition(
342346
algorithm=name,
@@ -369,6 +373,6 @@ def get_definitions(
369373
definitions.extend(
370374
create_definitions_from_algorithm(name, algo, dimension, distance_metric, count)
371375
)
372-
376+
373377

374378
return definitions

ann_benchmarks/util.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
import os
2+
3+
4+
def get_bool_env_var(env_var_name: str, default_value: bool) -> bool:
    """
    Interpret the given environment variable's value as a boolean flag.

    Args:
        env_var_name: Name of the environment variable to read.
        default_value: Result to return when the variable is unset, empty,
            or contains only whitespace.

    Returns:
        True when the value, after trimming whitespace and lower-casing, is
        one of 'y', 'yes', '1', 'true', 't' or 'on'; False for any other
        non-blank value; `default_value` when the variable is not specified
        or blank.
    """
    # Treat "unset" and "blank" uniformly by normalizing both to "".
    normalized = os.getenv(env_var_name, '').strip().lower()
    if not normalized:
        return default_value
    return normalized in ('y', 'yes', '1', 'true', 't', 'on')

0 commit comments

Comments
 (0)