Skip to content

Commit 179a8f7

Browse files
committed
Add ability to specify database connection credentials
1 parent 0e32628 commit 179a8f7

File tree

4 files changed

+64
-7
lines changed

4 files changed

+64
-7
lines changed

ann_benchmarks/algorithms/base/module.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,4 +76,10 @@ def get_additional(self) -> Dict[str, Any]:
7676
return {}
7777

7878
def __str__(self) -> str:
79-
return self.name
79+
return self.name
80+
81+
def set_conn_params(self, conn_params: Dict[str, str]) -> None:
    """Receive the parameters needed to reach the system under test.

    This base implementation deliberately ignores ``conn_params``.
    Algorithms that talk to an external system, such as a database
    server, override this method to keep the parameters.

    Args:
        conn_params: Connection parameters keyed by name.
    """
    return None

ann_benchmarks/algorithms/pgvector/module.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,28 @@
11
import subprocess
22
import sys
33

4+
from typing import Any, Dict
5+
46
import pgvector.psycopg
57
import psycopg
68

79
from ..base.module import BaseANN
810

911

12+
DEFAULT_POSTGRES_USER = 'ann'
13+
DEFAULT_POSTGRES_PASSWORD = 'ann'
14+
DEFAULT_POSTGRES_DB_NAME = 'ann'
15+
16+
1017
class PGVector(BaseANN):
18+
_conn_params: Dict[str, str]
19+
1120
def __init__(self, metric, method_param):
1221
self._metric = metric
1322
self._m = method_param['M']
1423
self._ef_construction = method_param['efConstruction']
1524
self._cur = None
25+
self._conn_params = {}
1626

1727
if metric == "angular":
1828
self._query = "SELECT id FROM items ORDER BY embedding <=> %s LIMIT %s"
@@ -21,9 +31,29 @@ def __init__(self, metric, method_param):
2131
else:
2232
raise RuntimeError(f"unknown metric {metric}")
2333

34+
def set_conn_params(self, conn_params: Dict[str, str]) -> None:
35+
self._conn_params = conn_params
36+
37+
def get_conn_param(self, key: str, default_value: str) -> str:
38+
value = self._conn_params.get(key)
39+
if value is None:
40+
return default_value
41+
return value
42+
2443
def fit(self, X):
25-
subprocess.run("service postgresql start", shell=True, check=True, stdout=sys.stdout, stderr=sys.stderr)
26-
conn = psycopg.connect(user="ann", password="ann", dbname="ann", autocommit=True)
44+
psycopg_connect_kwargs: Dict[str, Any] = dict(
45+
autocommit=True,
46+
user=self.get_conn_param('user', DEFAULT_POSTGRES_USER),
47+
password=self.get_conn_param('password', DEFAULT_POSTGRES_PASSWORD),
48+
dbname=self.get_conn_param('dbname', DEFAULT_POSTGRES_DB_NAME)
49+
)
50+
for arg_name in ['host', 'port']:
51+
# For these arguments, if they are not specified, leave the default
52+
# choice to the psycopg driver.
53+
if self._conn_params.get(arg_name) is not None:
54+
psycopg_connect_kwargs[arg_name] = self._conn_params[arg_name]
55+
56+
conn = psycopg.connect(**psycopg_connect_kwargs)
2757
pgvector.psycopg.register_vector(conn)
2858
cur = conn.cursor()
2959
cur.execute("DROP TABLE IF EXISTS items")

ann_benchmarks/main.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from .constants import INDEX_DIR
1919
from .datasets import DATASETS, get_dataset
2020
from .results import build_result_filepath
21-
from .runner import run, run_docker
21+
from .runner import run, run_docker, get_conn_params_from_args
2222

2323

2424
logging.config.fileConfig("logging.conf")
@@ -68,7 +68,7 @@ def run_worker(cpu: int, mem_limit: int, args: argparse.Namespace, queue: multip
6868
while not queue.empty():
6969
definition = queue.get()
7070
if args.local:
71-
run(definition, args.dataset, args.count, args.runs, args.batch)
71+
run(definition, args.dataset, args.count, args.runs, args.batch, get_conn_params_from_args(args))
7272
else:
7373
cpu_limit = str(cpu) if not args.batch else f"0-{multiprocessing.cpu_count() - 1}"
7474

@@ -122,6 +122,11 @@ def parse_arguments() -> argparse.Namespace:
122122
)
123123
parser.add_argument("--run-disabled", help="run algorithms that are disabled in algos.yml", action="store_true")
124124
parser.add_argument("--parallelism", type=positive_int, help="Number of Docker containers in parallel", default=1)
125+
parser.add_argument("--user", help="Username to connect to server")
126+
parser.add_argument("--password", help="Password to connect to server")
127+
parser.add_argument("--dbname", help="Database name to use when connecting to server")
128+
parser.add_argument("--host", help="Server to which to connect")
129+
parser.add_argument("--port", type=int, help="Port to use to connect to server")
125130

126131
args = parser.parse_args()
127132
if args.timeout == -1:

ann_benchmarks/runner.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,20 @@ def build_index(algo: BaseANN, X_train: numpy.ndarray) -> Tuple:
194194
return build_time, index_size
195195

196196

197-
def run(definition: Definition, dataset_name: str, count: int, run_count: int, batch: bool) -> None:
197+
def get_conn_params_from_args(args: argparse.Namespace) -> Dict[str, str]:
    """Extract server connection parameters from the given arguments object.

    Only the options ``user``, ``password``, ``dbname``, ``host`` and
    ``port`` are considered. An option is left out of the result when it is
    ``None`` or when ``args`` does not define it at all, so this helper is
    safe to call on a namespace built by a parser that never registered
    the connection options.

    Values are passed through unchanged; a ``port`` parsed with
    ``type=int`` therefore remains an ``int`` in the returned mapping.

    Args:
        args: Parsed command-line arguments.

    Returns:
        A new dict containing only the connection parameters that were set.
    """
    conn_params = {}
    for key in ('user', 'password', 'dbname', 'host', 'port'):
        # Default to None so a namespace lacking the attribute is treated
        # like an unset option instead of raising AttributeError.
        value = getattr(args, key, None)
        if value is not None:
            conn_params[key] = value
    return conn_params
204+
205+
def run(definition: Definition,
206+
dataset_name: str,
207+
count: int,
208+
run_count: int,
209+
batch: bool,
210+
conn_params: Dict[str, str]) -> None:
198211
"""Run the algorithm benchmarking.
199212
200213
Args:
@@ -203,6 +216,7 @@ def run(definition: Definition, dataset_name: str, count: int, run_count: int, b
203216
count (int): The number of results to return.
204217
run_count (int): The number of runs.
205218
batch (bool): If true, runs in batch mode.
219+
conn_params (dict): Parameters for connecting to the server.
206220
"""
207221
algo = instantiate_algorithm(definition)
208222
assert not definition.query_argument_groups or hasattr(
@@ -211,6 +225,7 @@ def run(definition: Definition, dataset_name: str, count: int, run_count: int, b
211225
error: query argument groups have been specified for {definition.module}.{definition.constructor}({definition.arguments}), but the \
212226
algorithm instantiated from it does not implement the set_query_arguments \
213227
function"""
228+
algo.set_conn_params(conn_params)
214229

215230
X_train, X_test, distance = load_and_transform_dataset(dataset_name)
216231

@@ -288,7 +303,8 @@ def run_from_cmdline():
288303
query_argument_groups=query_args,
289304
disabled=False,
290305
)
291-
run(definition, args.dataset, args.count, args.runs, args.batch)
306+
run(definition, args.dataset, args.count, args.runs, args.batch,
307+
get_conn_params_from_args(args))
292308

293309

294310
def run_docker(

0 commit comments

Comments
 (0)