Skip to content

Add ability to specify Postgres connection credentials #564

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jan 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 97 additions & 2 deletions ann_benchmarks/algorithms/pgvector/module.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,52 @@
"""
This module supports connecting to a PostgreSQL instance and performing vector
indexing and search using the pgvector extension. By default, the PostgreSQL
user name, password, and database name are all "ann", and the host and port
are left to the defaults of the psycopg driver.

If PostgreSQL is managed externally, e.g. in a cloud DBaaS environment, the
environment variable overrides listed below are available for setting PostgreSQL
connection parameters:

ANN_BENCHMARKS_PG_USER
ANN_BENCHMARKS_PG_PASSWORD
ANN_BENCHMARKS_PG_DBNAME
ANN_BENCHMARKS_PG_HOST
ANN_BENCHMARKS_PG_PORT

This module starts the PostgreSQL service automatically using the "service"
command. Set the environment variable ANN_BENCHMARKS_PG_START_SERVICE to
"false" (or e.g. "0" or "no") to disable this behavior.

This module will also attempt to create the pgvector extension inside the
target database if it has not already been created.
"""

import subprocess
import sys
import os

import pgvector.psycopg
import psycopg

from typing import Dict, Any, Optional

from ..base.module import BaseANN
from ...util import get_bool_env_var


def get_pg_param_env_var_name(pg_param_name: str) -> str:
    """
    Build the name of the environment variable that overrides the given
    PostgreSQL parameter: the upper-cased parameter name prefixed with
    "ANN_BENCHMARKS_PG_".
    """
    suffix = pg_param_name.upper()
    return 'ANN_BENCHMARKS_PG_' + suffix


def get_pg_conn_param(
        pg_param_name: str,
        default_value: Optional[str] = None) -> Optional[str]:
    """
    Look up a PostgreSQL connection parameter in the environment.

    The environment variable name is derived from `pg_param_name` via
    `get_pg_param_env_var_name`. If that variable is unset, empty, or
    contains only whitespace, `default_value` is returned. Otherwise the
    variable's value is returned exactly as stored (it is not stripped).
    """
    raw_value = os.environ.get(get_pg_param_env_var_name(pg_param_name))
    if raw_value is not None and raw_value.strip():
        return raw_value
    return default_value


class PGVector(BaseANN):
Expand All @@ -21,9 +63,61 @@ def __init__(self, metric, method_param):
else:
raise RuntimeError(f"unknown metric {metric}")

def ensure_pgvector_extension_created(self, conn: psycopg.Connection) -> None:
    """
    Make sure the "vector" extension exists in the connected database.

    Queries pg_extension for an entry named 'vector' and, when none is
    found, issues `CREATE EXTENSION vector`. A status message is printed
    in both cases.
    """
    # A dedicated cursor is used for this check on purpose. Reusing the
    # same cursor for later operations might fail with:
    # KeyError: "couldn't find the type 'vector' in the types registry"
    with conn.cursor() as check_cur:
        check_cur.execute(
            "SELECT EXISTS(SELECT 1 FROM pg_extension WHERE extname = 'vector')")
        row = check_cur.fetchone()
        if not row[0]:
            print("vector extension does not exist, creating")
            check_cur.execute("CREATE EXTENSION vector")
        else:
            print("vector extension already exists")

def fit(self, X):
subprocess.run("service postgresql start", shell=True, check=True, stdout=sys.stdout, stderr=sys.stderr)
conn = psycopg.connect(user="ann", password="ann", dbname="ann", autocommit=True)
psycopg_connect_kwargs: Dict[str, Any] = dict(
autocommit=True,
)
for arg_name in ['user', 'password', 'dbname']:
# The default value is "ann" for all of these parameters.
psycopg_connect_kwargs[arg_name] = get_pg_conn_param(
arg_name, 'ann')

# If host/port are not specified, leave the default choice to the
# psycopg driver.
pg_host: Optional[str] = get_pg_conn_param('host')
if pg_host is not None:
psycopg_connect_kwargs['host'] = pg_host

pg_port_str: Optional[str] = get_pg_conn_param('port')
if pg_port_str is not None:
psycopg_connect_kwargs['port'] = int(pg_port_str)

should_start_service = get_bool_env_var(
get_pg_param_env_var_name('start_service'),
default_value=True)
if should_start_service:
subprocess.run(
"service postgresql start",
shell=True,
check=True,
stdout=sys.stdout,
stderr=sys.stderr)
else:
print(
"Assuming that PostgreSQL service is managed externally. "
"Not attempting to start the service.")

conn = psycopg.connect(**psycopg_connect_kwargs)
self.ensure_pgvector_extension_created(conn)

pgvector.psycopg.register_vector(conn)
cur = conn.cursor()
cur.execute("DROP TABLE IF EXISTS items")
Expand All @@ -46,6 +140,7 @@ def fit(self, X):
print("done!")
self._cur = cur


def set_query_arguments(self, ef_search):
self._ef_search = ef_search
self._cur.execute("SET hnsw.ef_search = %d" % ef_search)
Expand Down
48 changes: 26 additions & 22 deletions ann_benchmarks/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import importlib
import os
import glob
import logging
from enum import Enum
from itertools import product
from typing import Any, Dict, List, Optional, Union
Expand All @@ -21,18 +22,19 @@ class Definition:
query_argument_groups: List[List[Any]]
disabled: bool


def instantiate_algorithm(definition: Definition) -> BaseANN:
"""
Create a `BaseANN` from a definition.

Args:
definition (Definition): An object containing information about the algorithm.

Returns:
BaseANN: Instantiated algorithm

Note:
The constructors for the algorithm definition are generally located at
The constructors for the algorithm definition are generally located at
ann_benchmarks/algorithms/*/module.py.
"""
print(f"Trying to instantiate {definition.module}.{definition.constructor}({definition.arguments})")
Expand All @@ -52,7 +54,7 @@ def algorithm_status(definition: Definition) -> InstantiationStatus:
"""
Determine the instantiation status of the algorithm based on its python module and constructor.

Attempts to find the Python class constructor based on the definition's module path and
Attempts to find the Python class constructor based on the definition's module path and
constructor name.

Args:
Expand All @@ -68,6 +70,8 @@ def algorithm_status(definition: Definition) -> InstantiationStatus:
else:
return InstantiationStatus.NO_CONSTRUCTOR
except ImportError:
logging.exception("Could not import algorithm module for %s",
definition.module)
return InstantiationStatus.NO_MODULE


Expand Down Expand Up @@ -103,7 +107,7 @@ def _generate_combinations(args: Union[List[Any], Dict[Any, Any]]) -> List[Union
def _substitute_variables(arg: Any, vs: Dict[str, Any]) -> Any:
"""
Substitutes any string variables present in the argument structure with provided values.

Support for nested substitution in the case `arg` is a List or Dict.

Args:
Expand Down Expand Up @@ -160,8 +164,8 @@ def _get_definitions(base_dir: str = "ann_benchmarks/algorithms") -> List[Dict[s

def _get_algorithm_definitions(point_type: str, distance_metric: str, base_dir: str = "ann_benchmarks/algorithms") -> Dict[str, Dict[str, Any]]:
"""Get algorithm definitions for a specific point type and distance metric.
A specific algorithm folder can have multiple algorithm definitions for a given point type and

A specific algorithm folder can have multiple algorithm definitions for a given point type and
metric. For example, `ann_benchmarks.algorithms.nmslib` has two definitions for euclidean float
data: specifically `SW-graph(nmslib)` and `hnsw(nmslib)`, even though the module is named nmslib.

Expand All @@ -176,7 +180,7 @@ def _get_algorithm_definitions(point_type: str, distance_metric: str, base_dir:
"disabled": false,
"docker_tag": ann-benchmarks-nmslib,
...
},
},
'SW-graph(nmslib)': {
"base_args": ['@metric', sw-graph],
"constructor": NmslibReuseIndex,
Expand Down Expand Up @@ -205,9 +209,9 @@ def _get_algorithm_definitions(point_type: str, distance_metric: str, base_dir:
def list_algorithms(base_dir: str = "ann_benchmarks/algorithms") -> None:
"""
Output (to stdout), a list of all algorithms, with their supported point types and metrics.

Args:
base_dir (str, optional): The base directory where the algorithms are stored.
base_dir (str, optional): The base directory where the algorithms are stored.
Defaults to "ann_benchmarks/algorithms".
"""
all_configs = _get_definitions(base_dir)
Expand Down Expand Up @@ -236,7 +240,7 @@ def list_algorithms(base_dir: str = "ann_benchmarks/algorithms") -> None:

def generate_arg_combinations(run_group: Dict[str, Any], arg_type: str) -> List:
"""Generate combinations of arguments from a run group for a specific argument type.

Args:
run_group (Dict[str, Any]): The run group containing argument definitions.
arg_type (str): The type of argument group to generate combinations for.
Expand All @@ -262,10 +266,10 @@ def generate_arg_combinations(run_group: Dict[str, Any], arg_type: str) -> List:


def prepare_args(run_group: Dict[str, Any]) -> List:
"""For an Algorithm's run group, prepare arguments.
"""For an Algorithm's run group, prepare arguments.

An `arg_groups` key takes precedence over an `args` key.

Args:
run_group (Dict[str, Any]): The run group containing argument definitions.

Expand All @@ -283,7 +287,7 @@ def prepare_args(run_group: Dict[str, Any]) -> List:

def prepare_query_args(run_group: Dict[str, Any]) -> List:
"""For an algorithm's run group, prepare query args/ query arg groups.

Args:
run_group (Dict[str, Any]): The run group containing argument definitions.

Expand All @@ -299,28 +303,28 @@ def prepare_query_args(run_group: Dict[str, Any]) -> List:
def create_definitions_from_algorithm(name: str, algo: Dict[str, Any], dimension: int, distance_metric: str = "euclidean", count: int = 10) -> List[Definition]:
"""
Create definitions from an individual algorithm. An algorithm (e.g. annoy) can have multiple
definitions based on various run groups (see config.ymls for clear examples).
definitions based on various run groups (see config.ymls for clear examples).

Args:
name (str): Name of the algorithm.
algo (Dict[str, Any]): Dictionary with algorithm parameters.
dimension (int): Dimension of the algorithm.
distance_metric (str, optional): Distance metric used by the algorithm. Defaults to "euclidean".
count (int, optional): Count of the definitions to be created. Defaults to 10.

Raises:
Exception: If the algorithm does not define "docker_tag", "module" or "constructor" properties.

Returns:
List[Definition]: A list of definitions created from the algorithm.
"""
required_properties = ["docker_tag", "module", "constructor"]
missing_properties = [prop for prop in required_properties if prop not in algo]
if missing_properties:
raise ValueError(f"Algorithm {name} is missing the following properties: {', '.join(missing_properties)}")

base_args = algo.get("base_args", [])

definitions = []
for run_group in algo["run_groups"].values():
args = prepare_args(run_group)
Expand All @@ -336,7 +340,7 @@ def create_definitions_from_algorithm(name: str, algo: Dict[str, Any], dimension

vs = {"@count": count, "@metric": distance_metric, "@dimension": dimension}
current_args = [_substitute_variables(arg, vs) for arg in current_args]

definitions.append(
Definition(
algorithm=name,
Expand Down Expand Up @@ -369,6 +373,6 @@ def get_definitions(
definitions.extend(
create_definitions_from_algorithm(name, algo, dimension, distance_metric, count)
)


return definitions
15 changes: 15 additions & 0 deletions ann_benchmarks/util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import os


def get_bool_env_var(env_var_name: str, default_value: bool) -> bool:
    """
    Read an environment variable as a boolean flag.

    The value is stripped of surrounding whitespace and lower-cased before
    being compared. When the variable is unset, empty, or only whitespace,
    `default_value` is returned. Otherwise the result is True only for
    'y', 'yes', '1', 'true', 't', or 'on'; any other value yields False.
    """
    normalized = (os.environ.get(env_var_name) or '').strip().lower()
    if not normalized:
        return default_value
    truthy_values = ('y', 'yes', '1', 'true', 't', 'on')
    return normalized in truthy_values
Loading