diff --git a/.ddev/config.toml b/.ddev/config.toml index 2574dd8a6d013..89f129b9bebbe 100644 --- a/.ddev/config.toml +++ b/.ddev/config.toml @@ -88,6 +88,8 @@ paramiko = ['LGPL-2.1-only'] oracledb = ['Apache-2.0'] # https://github.com/psycopg/psycopg/blob/master/LICENSE.txt psycopg = ['LGPL-3.0-only'] +# https://github.com/psycopg/psycopg/blob/master/psycopg_pool/LICENSE.txt +psycopg-pool = ['LGPL-3.0-only'] # https://github.com/psycopg/psycopg2/blob/master/LICENSE # https://github.com/psycopg/psycopg2/blob/master/doc/COPYING.LESSER psycopg2-binary = ['LGPL-3.0-only', 'BSD-3-Clause'] diff --git a/LICENSE-3rdparty.csv b/LICENSE-3rdparty.csv index f547490297f8e..59ad24cf2cd99 100644 --- a/LICENSE-3rdparty.csv +++ b/LICENSE-3rdparty.csv @@ -64,6 +64,7 @@ protobuf,PyPI,BSD-3-Clause,Copyright 2008 Google Inc. protobuf,PyPI,BSD-3-Clause,Copyright 2008 Google Inc. All rights reserved. psutil,PyPI,BSD-3-Clause,"Copyright (c) 2009, Jay Loden, Dave Daeschler, Giampaolo Rodola'" psycopg,PyPI,LGPL-3.0-only,Copyright (C) 2020 The Psycopg Team +psycopg-pool,PyPI,LGPL-3.0-only,Copyright (C) 2020 The Psycopg Team psycopg2-binary,PyPI,BSD-3-Clause,Copyright 2013 Federico Di Gregorio psycopg2-binary,PyPI,LGPL-3.0-only,Copyright (C) 2013 Federico Di Gregorio pyasn1,PyPI,BSD-3-Clause,"Copyright (c) 2005-2019, Ilya Etingof " diff --git a/datadog_checks_base/CHANGELOG.md b/datadog_checks_base/CHANGELOG.md index c10da43773845..f3c1ada4b73df 100644 --- a/datadog_checks_base/CHANGELOG.md +++ b/datadog_checks_base/CHANGELOG.md @@ -19,6 +19,7 @@ * Downgrade pydantic to 2.0.2 ([#15596](https://github.com/DataDog/integrations-core/pull/15596)) * Bump cryptography to 41.0.3 ([#15517](https://github.com/DataDog/integrations-core/pull/15517)) +* Prevent `command already in progress` errors in the Postgres integration ([#15489](https://github.com/DataDog/integrations-core/pull/15489)) ## 32.7.0 / 2023-08-10 diff --git a/datadog_checks_base/datadog_checks/base/data/agent_requirements.in b/datadog_checks_base/datadog_checks/base/data/agent_requirements.in index d9adb8e1761fc..23f15ee71509c 100644 --- a/datadog_checks_base/datadog_checks/base/data/agent_requirements.in +++ b/datadog_checks_base/datadog_checks/base/data/agent_requirements.in @@ -59,6 +59,7 @@ protobuf==3.20.2; python_version > '3.0' psutil==5.9.0 psycopg2-binary==2.8.6; sys_platform != 'darwin' or platform_machine != 'arm64' psycopg[binary]==3.1.10; python_version > '3.0' +psycopg-pool==3.1.7; python_version > '3.0' pyasn1==0.4.6 pycryptodomex==3.10.1 pydantic==2.0.2; python_version > '3.0' diff --git a/datadog_checks_dev/CHANGELOG.md b/datadog_checks_dev/CHANGELOG.md index 0571de8a8882d..f297fc4ac6399 100644 --- a/datadog_checks_dev/CHANGELOG.md +++ b/datadog_checks_dev/CHANGELOG.md @@ -18,6 +18,7 @@ * Ignore `pydantic` when bumping the dependencies ([#15597](https://github.com/DataDog/integrations-core/pull/15597)) * Stop using the TOX_ENV_NAME variable ([#15528](https://github.com/DataDog/integrations-core/pull/15528)) +* Prevent `command already in progress` errors in the Postgres integration ([#15489](https://github.com/DataDog/integrations-core/pull/15489)) ## 23.0.0 / 2023-08-10 diff --git a/datadog_checks_dev/datadog_checks/dev/tooling/commands/validate/licenses.py b/datadog_checks_dev/datadog_checks/dev/tooling/commands/validate/licenses.py index 44b5d89b18cd5..a18ce7e9c0c2e 100644 --- a/datadog_checks_dev/datadog_checks/dev/tooling/commands/validate/licenses.py +++ b/datadog_checks_dev/datadog_checks/dev/tooling/commands/validate/licenses.py @@ 
-54,6 +54,8 @@
     'psycopg2-binary': ['LGPL-3.0-only', 'BSD-3-Clause'],
     # https://github.com/psycopg/psycopg/blob/master/LICENSE.txt
     'psycopg': ['LGPL-3.0-only'],
+    # https://github.com/psycopg/psycopg/blob/master/psycopg_pool/LICENSE.txt
+    'psycopg-pool': ['LGPL-3.0-only'],
     # https://github.com/Legrandin/pycryptodome/blob/master/LICENSE.rst
     'pycryptodomex': ['Unlicense', 'BSD-2-Clause'],
     # https://github.com/requests/requests-kerberos/pull/123
diff --git a/postgres/CHANGELOG.md b/postgres/CHANGELOG.md
index d71c7bb8a1a26..21d222fec3aee 100644
--- a/postgres/CHANGELOG.md
+++ b/postgres/CHANGELOG.md
@@ -12,6 +12,7 @@
 ***Fixed***:
 
 * Update datadog-checks-base dependency version to 32.6.0 ([#15604](https://github.com/DataDog/integrations-core/pull/15604))
+* Prevent `command already in progress` errors in the Postgres integration ([#15489](https://github.com/DataDog/integrations-core/pull/15489))
 
 ## 14.1.0 / 2023-08-10
 
diff --git a/postgres/datadog_checks/postgres/connections.py b/postgres/datadog_checks/postgres/connections.py
index fa56b2918ace5..e2fbcc7f173df 100644
--- a/postgres/datadog_checks/postgres/connections.py
+++ b/postgres/datadog_checks/postgres/connections.py
@@ -8,7 +8,7 @@
 import time
 from typing import Callable, Dict
 
-import psycopg
+from psycopg_pool import ConnectionPool
 
 from datadog_checks.base import AgentCheck
 
@@ -25,18 +25,16 @@ def __str__(self):
 
 class ConnectionInfo:
     def __init__(
         self,
-        connection: psycopg.Connection,
+        connection: ConnectionPool,
         deadline: int,
         active: bool,
         last_accessed: int,
-        thread: threading.Thread,
         persistent: bool,
     ):
         self.connection = connection
         self.deadline = deadline
         self.active = active
         self.last_accessed = last_accessed
-        self.thread = thread
         self.persistent = persistent
 
@@ -68,41 +66,47 @@ def __repr__(self):
         def reset(self):
             self.__init__()
 
-    def __init__(self, check: AgentCheck, connect_fn: Callable[[str], None], max_conns: int = None):
-        self.log = check.log
+    def __init__(self, check: AgentCheck, connect_fn: Callable[[str, int, int], None], max_conns: int = None):
+        self._check = check
+        self._log = check.log
+        self._config = check._config
         self.max_conns: int = max_conns
         self._stats = self.Stats()
         self._mu = threading.RLock()
+        self._query_lock = threading.Lock()
         self._conns: Dict[str, ConnectionInfo] = {}
 
         if hasattr(inspect, 'signature'):
             connect_sig = inspect.signature(connect_fn)
-            if len(connect_sig.parameters) != 1:
+            if not (len(connect_sig.parameters) >= 1):
                 raise ValueError(
                     "Invalid signature for the connection function. "
-                    "A single parameter for dbname is expected, got signature: {}".format(connect_sig)
+                    "Expected parameters: dbname, min_pool_size, max_pool_size. "
+                    "Got signature: {}".format(connect_sig)
                 )
         self.connect_fn = connect_fn
 
-    def _get_connection_raw(
+    def _get_connection_pool(
         self,
         dbname: str,
         ttl_ms: int,
         timeout: int = None,
-        startup_fn: Callable[[psycopg.Connection], None] = None,
+        min_pool_size: int = 1,
+        max_pool_size: int = None,
+        startup_fn: Callable[[ConnectionPool], None] = None,
         persistent: bool = False,
-    ) -> psycopg.Connection:
+    ) -> ConnectionPool:
         """
-        Return a connection from the pool.
+        Return a connection pool for the requested database from the managed pool.
         Pass a function to startup_fn if there is an action needed with the connection when re-establishing it.
""" start = datetime.datetime.now() self.prune_connections() with self._mu: - conn = self._conns.pop(dbname, ConnectionInfo(None, None, None, None, None, None)) - db = conn.connection - if db is None or db.closed: + conn = self._conns.pop(dbname, ConnectionInfo(None, None, None, None, None)) + db_pool = conn.connection + if db_pool is None or db_pool.closed: if self.max_conns is not None: # try to free space until we succeed while len(self._conns) >= self.max_conns: @@ -113,27 +117,22 @@ def _get_connection_raw( time.sleep(0.01) continue self._stats.connection_opened += 1 - db = self.connect_fn(dbname) + db_pool = self.connect_fn(dbname, min_pool_size, max_pool_size) if startup_fn: - startup_fn(db) + startup_fn(db_pool) else: # if already in pool, retain persistence status persistent = conn.persistent - if db.info.status != psycopg.pq.ConnStatus.OK: - # Some transaction went wrong and the connection is in an unhealthy state. Let's fix that - db.rollback() - deadline = datetime.datetime.now() + datetime.timedelta(milliseconds=ttl_ms) self._conns[dbname] = ConnectionInfo( - connection=db, + connection=db_pool, deadline=deadline, active=True, last_accessed=datetime.datetime.now(), - thread=threading.current_thread(), persistent=persistent, ) - return db + return db_pool @contextlib.contextmanager def get_connection(self, dbname: str, ttl_ms: int, timeout: int = None, persistent: bool = False): @@ -144,16 +143,19 @@ def get_connection(self, dbname: str, ttl_ms: int, timeout: int = None, persiste Blocks until a connection can be added to the pool, and optionally takes a timeout in seconds. """ + with self._mu: + pool = self._get_connection_pool(dbname=dbname, ttl_ms=ttl_ms, timeout=timeout, persistent=persistent) + db = pool.getconn(timeout=timeout) try: - with self._mu: - db = self._get_connection_raw(dbname=dbname, ttl_ms=ttl_ms, timeout=timeout, persistent=persistent) yield db finally: with self._mu: try: - self._conns[dbname].active = False + pool.putconn(db) + if not self._conns[dbname].persistent: + self._conns[dbname].active = False except KeyError: - # if self._get_connection_raw hit an exception, self._conns[dbname] didn't get populated + # if self._get_connection_raw hit an exception, self._conns[conn_name] didn't get populated pass def prune_connections(self): @@ -166,10 +168,10 @@ def prune_connections(self): """ with self._mu: now = datetime.datetime.now() - for dbname, conn in list(self._conns.items()): - if conn.deadline < now: + for conn_name, conn in list(self._conns.items()): + if conn.deadline < now and not conn.active and not conn.persistent: self._stats.connection_pruned += 1 - self._terminate_connection_unsafe(dbname) + self._terminate_connection_unsafe(conn_name) def close_all_connections(self, timeout=None): """ @@ -202,14 +204,29 @@ def evict_lru(self) -> str: return None def _terminate_connection_unsafe(self, dbname: str): - db = self._conns.pop(dbname, ConnectionInfo(None, None, None, None, None, None)).connection + db = self._conns.pop(dbname, ConnectionInfo(None, None, None, None, None)).connection if db is not None: try: - self._stats.connection_closed += 1 if not db.closed: db.close() + self._stats.connection_closed += 1 except Exception: self._stats.connection_closed_failed += 1 - self.log.exception("failed to close DB connection for db=%s", dbname) + self._log.exception("failed to close DB connection for db=%s", dbname) return False return True + + def get_main_db_pool(self, max_pool_conn_size: int = 3): + """ + Returns a memoized, persistent psycopg connection 
+        It is meant to be shared across multiple threads, and opens a preconfigured max number of connections.
+        :return: a psycopg connection pool
+        """
+        conn = self._get_connection_pool(
+            dbname=self._config.dbname,
+            ttl_ms=self._config.idle_connection_timeout,
+            max_pool_size=max_pool_conn_size,
+            startup_fn=self._check.load_pg_settings,
+            persistent=True,
+        )
+        return conn
diff --git a/postgres/datadog_checks/postgres/explain_parameterized_queries.py b/postgres/datadog_checks/postgres/explain_parameterized_queries.py
index 454e65e625c35..e50083877b1c6 100644
--- a/postgres/datadog_checks/postgres/explain_parameterized_queries.py
+++ b/postgres/datadog_checks/postgres/explain_parameterized_queries.py
@@ -72,27 +72,28 @@ def __init__(self, check, config):
 
     def explain_statement(self, dbname, statement, obfuscated_statement):
         if self._check.version < V12:
             return None
-        self._set_plan_cache_mode(dbname)
         query_signature = compute_sql_signature(obfuscated_statement)
-        if not self._create_prepared_statement(dbname, statement, obfuscated_statement, query_signature):
-            return None
+        with self._check.db_pool.get_connection(dbname, self._check._config.idle_connection_timeout) as conn:
+            self._set_plan_cache_mode(conn)
 
-        result = self._explain_prepared_statement(dbname, statement, obfuscated_statement, query_signature)
-        self._deallocate_prepared_statement(dbname, query_signature)
-        if result:
-            return result[0]['explain_statement'][0]
+            if not self._create_prepared_statement(conn, statement, obfuscated_statement, query_signature):
+                return None
+
+            result = self._explain_prepared_statement(conn, statement, obfuscated_statement, query_signature)
+            self._deallocate_prepared_statement(conn, query_signature)
+            if result:
+                return result[0]['explain_statement'][0]
         return None
 
-    def _set_plan_cache_mode(self, dbname):
-        self._execute_query(dbname, "SET plan_cache_mode = force_generic_plan")
+    def _set_plan_cache_mode(self, conn):
+        self._execute_query(conn, "SET plan_cache_mode = force_generic_plan")
 
     @tracked_method(agent_check_getter=agent_check_getter)
-    def _create_prepared_statement(self, dbname, statement, obfuscated_statement, query_signature):
+    def _create_prepared_statement(self, conn, statement, obfuscated_statement, query_signature):
         try:
             self._execute_query(
-                dbname,
-                PREPARE_STATEMENT_QUERY.format(query_signature=query_signature, statement=statement),
+                conn, PREPARE_STATEMENT_QUERY.format(query_signature=query_signature, statement=statement)
             )
             return True
         except Exception as e:
@@ -108,26 +109,24 @@ def _create_prepared_statement(self, dbname, statement, obfuscated_statement, qu
             return False
 
     @tracked_method(agent_check_getter=agent_check_getter)
-    def _get_number_of_parameters_for_prepared_statement(self, dbname, query_signature):
-        rows = self._execute_query_and_fetch_rows(
-            dbname, PARAM_TYPES_COUNT_QUERY.format(query_signature=query_signature)
-        )
+    def _get_number_of_parameters_for_prepared_statement(self, conn, query_signature):
+        rows = self._execute_query_and_fetch_rows(conn, PARAM_TYPES_COUNT_QUERY.format(query_signature=query_signature))
         count = 0
         if rows and 'count' in rows[0]:
             count = rows[0]['count']
         return count
 
     @tracked_method(agent_check_getter=agent_check_getter)
-    def _explain_prepared_statement(self, dbname, statement, obfuscated_statement, query_signature):
+    def _explain_prepared_statement(self, conn, statement, obfuscated_statement, query_signature):
         null_parameter = ','.join(
-            'null' for _ in
range(self._get_number_of_parameters_for_prepared_statement(dbname, query_signature)) + 'null' for _ in range(self._get_number_of_parameters_for_prepared_statement(conn, query_signature)) ) execute_prepared_statement_query = EXECUTE_PREPARED_STATEMENT_QUERY.format( prepared_statement=query_signature, generic_values=null_parameter ) try: return self._execute_query_and_fetch_rows( - dbname, + conn, EXPLAIN_QUERY.format( explain_function=self._config.statement_samples_config.get( 'explain_function', 'datadog.explain_statement' @@ -147,11 +146,9 @@ def _explain_prepared_statement(self, dbname, statement, obfuscated_statement, q ) return None - def _deallocate_prepared_statement(self, dbname, query_signature): + def _deallocate_prepared_statement(self, conn, query_signature): try: - self._execute_query( - dbname, "DEALLOCATE PREPARE dd_{query_signature}".format(query_signature=query_signature) - ) + self._execute_query(conn, "DEALLOCATE PREPARE dd_{query_signature}".format(query_signature=query_signature)) except Exception as e: logger.warning( 'Failed to deallocate prepared statement query_signature=[%s] | err=[%s]', @@ -159,15 +156,13 @@ def _deallocate_prepared_statement(self, dbname, query_signature): e, ) - def _execute_query(self, dbname, query): - with self._check.db_pool.get_connection(dbname, self._check._config.idle_connection_timeout) as conn: - with conn.cursor(row_factory=dict_row) as cursor: - logger.debug('Executing query=[%s]', query) - cursor.execute(query) + def _execute_query(self, conn, query): + with conn.cursor(row_factory=dict_row) as cursor: + logger.debug('Executing query=[%s]', query) + cursor.execute(query) - def _execute_query_and_fetch_rows(self, dbname, query): - with self._check.db_pool.get_connection(dbname, self._check._config.idle_connection_timeout) as conn: - with conn.cursor(row_factory=dict_row) as cursor: - logger.debug('Executing query=[%s]', query) - cursor.execute(query) - return cursor.fetchall() + def _execute_query_and_fetch_rows(self, conn, query): + with conn.cursor(row_factory=dict_row) as cursor: + logger.debug('Executing query=[%s]', query) + cursor.execute(query) + return cursor.fetchall() diff --git a/postgres/datadog_checks/postgres/metadata.py b/postgres/datadog_checks/postgres/metadata.py index 0e23694a05ced..5584115de2bb5 100644 --- a/postgres/datadog_checks/postgres/metadata.py +++ b/postgres/datadog_checks/postgres/metadata.py @@ -214,7 +214,6 @@ def run_job(self): self.tags = [t for t in self._tags if not t.startswith('dd.internal')] self._tags_no_db = [t for t in self.tags if not t.startswith('db:')] self.report_postgres_metadata() - self._check.db_pool.prune_connections() @tracked_method(agent_check_getter=agent_check_getter) def report_postgres_metadata(self): @@ -223,19 +222,19 @@ def report_postgres_metadata(self): elapsed_s = time.time() - self._time_since_last_settings_query if elapsed_s >= self.pg_settings_collection_interval and self._collect_pg_settings_enabled: self._pg_settings_cached = self._collect_postgres_settings() - event = { - "host": self._check.resolved_hostname, - "agent_version": datadog_agent.get_version(), - "dbms": "postgres", - "kind": "pg_settings", - "collection_interval": self.collection_interval, - 'dbms_version': payload_pg_version(self._check.version), - "tags": self._tags_no_db, - "timestamp": time.time() * 1000, - "cloud_metadata": self._config.cloud_metadata, - "metadata": self._pg_settings_cached, - } - self._check.database_monitoring_metadata(json.dumps(event, default=default_json_event_encoding)) + 
event = { + "host": self._check.resolved_hostname, + "agent_version": datadog_agent.get_version(), + "dbms": "postgres", + "kind": "pg_settings", + "collection_interval": self.collection_interval, + 'dbms_version': payload_pg_version(self._check.version), + "tags": self._tags_no_db, + "timestamp": time.time() * 1000, + "cloud_metadata": self._config.cloud_metadata, + "metadata": self._pg_settings_cached, + } + self._check.database_monitoring_metadata(json.dumps(event, default=default_json_event_encoding)) elapsed_s_schemas = time.time() - self._time_since_last_schemas_query if elapsed_s_schemas >= self.schemas_collection_interval and self._collect_schemas_enabled: @@ -313,7 +312,7 @@ def _get_table_info(self, cursor, dbname, schema_id): """ limit = self._config.schemas_metadata_config.get('max_tables', 1000) if self._config.relations: - if VersionUtils.transform_version(str(self._check._version))['version.major'] == "9": + if VersionUtils.transform_version(str(self._check.version))['version.major'] == "9": cursor.execute(PG_TABLES_QUERY_V9.format(schema_oid=schema_id)) else: cursor.execute(PG_TABLES_QUERY_V10_PLUS.format(schema_oid=schema_id)) @@ -350,7 +349,7 @@ def sort_tables(info): # so we have to grab the total partition activity # note: partitions don't exist in V9, so we have to check this first if ( - VersionUtils.transform_version(str(self._check._version))['version.major'] == "9" + VersionUtils.transform_version(str(self._check.version))['version.major'] == "9" or not info["has_partitions"] ): return ( @@ -405,7 +404,7 @@ def _query_table_information_for_schema( idxs = [dict(row) for row in rows] this_payload.update({'indexes': idxs}) - if VersionUtils.transform_version(str(self._check._version))['version.major'] != "9": + if VersionUtils.transform_version(str(self._check.version))['version.major'] != "9": if table['has_partitions']: cursor.execute(PARTITION_KEY_QUERY.format(parent=name)) row = cursor.fetchone() @@ -464,10 +463,11 @@ def _collect_metadata_for_database(self, dbname): @tracked_method(agent_check_getter=agent_check_getter) def _collect_postgres_settings(self): - with self._check.get_main_db().cursor(row_factory=dict_row) as cursor: - self._log.debug("Running query [%s]", PG_SETTINGS_QUERY) - self._time_since_last_settings_query = time.time() - cursor.execute(PG_SETTINGS_QUERY) - rows = cursor.fetchall() - self._log.debug("Loaded %s rows from pg_settings", len(rows)) - return [dict(row) for row in rows] + with self.db_pool.get_main_db_pool().connection() as conn: + with conn.cursor(row_factory=dict_row) as cursor: + self._log.debug("Running query [%s]", PG_SETTINGS_QUERY) + self._time_since_last_settings_query = time.time() + cursor.execute(PG_SETTINGS_QUERY) + rows = cursor.fetchall() + self._log.debug("Loaded %s rows from pg_settings", len(rows)) + return [dict(row) for row in rows] diff --git a/postgres/datadog_checks/postgres/postgres.py b/postgres/datadog_checks/postgres/postgres.py index 6af6829ee4088..9adfd0cdd4477 100644 --- a/postgres/datadog_checks/postgres/postgres.py +++ b/postgres/datadog_checks/postgres/postgres.py @@ -8,8 +8,8 @@ import psycopg from cachetools import TTLCache -from psycopg import ClientCursor from psycopg.rows import dict_row +from psycopg_pool import ConnectionPool from six import iteritems from datadog_checks.base import AgentCheck @@ -84,8 +84,8 @@ def __init__(self, name, init_config, instances): self.persistent_conn = None self._resolved_hostname = None self._agent_hostname = None - self._version = None - self._is_aurora = None + 
self.version = None + self.is_aurora = None self._version_utils = VersionUtils() # Deprecate custom_metrics in favor of custom_queries if 'custom_metrics' in self.instance: @@ -112,6 +112,9 @@ def __init__(self, name, init_config, instances): self._clean_state() self.check_initializations.append(lambda: RelationsManager.validate_relations_config(self._config.relations)) self.check_initializations.append(self.set_resolved_hostname_metadata) + self.check_initializations.append(self._connect) + self.check_initializations.append(self.load_version) + self.check_initializations.append(self.initialize_is_aurora) self.tags_without_db = [t for t in copy.copy(self.tags) if not t.startswith("db:")] self.autodiscovery = self._build_autodiscovery() self._dynamic_queries = None @@ -182,17 +185,18 @@ def _new_query_executor(self, queries): ) def execute_query_raw(self, query): - with self.db.cursor() as cursor: - cursor.execute(query) - rows = cursor.fetchall() - return rows + with self.db.connection() as conn: + with conn.cursor() as cursor: + cursor.execute(query) + rows = cursor.fetchall() + return rows @property def dynamic_queries(self): if self._dynamic_queries: return self._dynamic_queries - if self._version is None: + if self.version is None: self.log.debug("Version set to None due to incorrect identified version, aborting dynamic queries") return None @@ -280,8 +284,6 @@ def cancel(self): def _clean_state(self): self.log.debug("Cleaning state") - self._version = None - self._is_aurora = None self.metrics_cache.clean_state() self._dynamic_queries = None @@ -294,11 +296,12 @@ def _get_service_check_tags(self): return list(service_check_tags) def _get_replication_role(self): - with self.db.cursor() as cursor: - cursor.execute('SELECT pg_is_in_recovery();') - role = cursor.fetchone()[0] - # value fetched for role is of - return "standby" if role else "master" + with self.db.connection() as conn: + with conn.cursor() as cursor: + cursor.execute('SELECT pg_is_in_recovery();') + role = cursor.fetchone()[0] + # value fetched for role is of + return "standby" if role else "master" def _collect_wal_metrics(self, instance_tags): if self.version >= V10: @@ -343,19 +346,16 @@ def _get_local_wal_file_age(self): oldest_file_age = now - os.path.getctime(oldest_file) return oldest_file_age - @property - def version(self): - if self._version is None: - raw_version = self._version_utils.get_raw_version(self.db) - self._version = self._version_utils.parse_version(raw_version) - self.set_metadata('version', raw_version) - return self._version + def load_version(self): + raw_version = self._version_utils.get_raw_version(self.db) + self.version = self._version_utils.parse_version(raw_version) + self.set_metadata('version', raw_version) + return self.version - @property - def is_aurora(self): - if self._is_aurora is None: - self._is_aurora = self._version_utils.is_aurora(self.db) - return self._is_aurora + def initialize_is_aurora(self): + if self.is_aurora is None: + self.is_aurora = self._version_utils.is_aurora(self.db) + return self.is_aurora @property def resolved_hostname(self): @@ -413,7 +413,6 @@ def _run_query_scope(self, cursor, scope, is_custom_metrics, cols, descriptors): except psycopg.errors.FeatureNotSupported as e: # This happens for example when trying to get replication metrics from readers in Aurora. Let's ignore it. 
log_func(e) - self.db.rollback() self.log.debug("Disabling replication metrics") self._is_aurora = False self.metrics_cache.replication_metrics = {} @@ -421,13 +420,11 @@ def _run_query_scope(self, cursor, scope, is_custom_metrics, cols, descriptors): log_func(e) log_func( "It seems the PG version has been incorrectly identified as %s. " - "A reattempt to identify the right version will happen on next agent run." % self._version + "A reattempt to identify the right version will happen on next agent run." % self.version ) self._clean_state() - self.db.rollback() except (psycopg.ProgrammingError, psycopg.errors.QueryCanceled) as e: log_func("Not all metrics may be available: %s" % str(e)) - self.db.rollback() if not results: return None @@ -591,7 +588,6 @@ def _collect_stats(self, instance_tags): # Do we need relation-specific metrics? if self._config.relations: relations_scopes = list(RELATION_METRICS) - if self._config.collect_bloat_metrics: relations_scopes.extend([INDEX_BLOAT, TABLE_BLOAT]) @@ -612,30 +608,33 @@ def _collect_stats(self, instance_tags): if replication_stats_metrics: metric_scope.append(replication_stats_metrics) - with self.db.cursor() as cursor: - results_len = self._query_scope(cursor, db_instance_metrics, instance_tags, False) - if results_len is not None: - self.gauge( - "postgresql.db.count", - results_len, - tags=copy.copy(self.tags_without_db), - hostname=self.resolved_hostname, - ) + with self.db.connection() as conn: + with conn.cursor() as cursor: + results_len = self._query_scope(cursor, db_instance_metrics, instance_tags, False) + if results_len is not None: + self.gauge( + "postgresql.db.count", + results_len, + tags=copy.copy(self.tags_without_db), + hostname=self.resolved_hostname, + ) - self._query_scope(cursor, bgw_instance_metrics, instance_tags, False) - self._query_scope(cursor, archiver_instance_metrics, instance_tags, False) + self._query_scope(cursor, bgw_instance_metrics, instance_tags, False) + self._query_scope(cursor, archiver_instance_metrics, instance_tags, False) - if self._config.collect_activity_metrics: - activity_metrics = self.metrics_cache.get_activity_metrics(self.version) - self._query_scope(cursor, activity_metrics, instance_tags, False) + if self._config.collect_activity_metrics: + activity_metrics = self.metrics_cache.get_activity_metrics(self.version) + self._query_scope(cursor, activity_metrics, instance_tags, False) - for scope in list(metric_scope) + self._config.custom_metrics: - self._query_scope(cursor, scope, instance_tags, scope in self._config.custom_metrics) + for scope in list(metric_scope) + self._config.custom_metrics: + self._query_scope(cursor, scope, instance_tags, scope in self._config.custom_metrics) - if self.dynamic_queries: - self.dynamic_queries.execute() + if self.dynamic_queries: + self.dynamic_queries.execute() - def _new_connection(self, dbname): + def _new_connection(self, dbname: str, min_pool_size: int = 1, max_pool_size: int = None): + # required for autocommit as well as using params in queries + args = {"autocommit": True, "cursor_factory": psycopg.ClientCursor} if self._config.host == 'localhost' and self._config.password == '': # Use ident method connection_string = "user=%s dbname=%s application_name=%s" % ( @@ -645,7 +644,14 @@ def _new_connection(self, dbname): ) if self._config.query_timeout: connection_string += " options='-c statement_timeout=%s'" % self._config.query_timeout - conn = psycopg.connect(conninfo=connection_string, autocommit=True, cursor_factory=ClientCursor) + pool = 
ConnectionPool( + conninfo=connection_string, + min_size=min_pool_size, + max_size=max_pool_size, + kwargs=args, + open=True, + name=dbname, + ) else: password = self._config.password region = self._config.cloud_metadata.get('aws', {}).get('region', None) @@ -661,7 +667,7 @@ def _new_connection(self, dbname): if client_id is not None: password = azure.generate_managed_identity_token(client_id=client_id, scope=scope) - args = { + conn_args = { 'host': self._config.host, 'user': self._config.user, 'password': password, @@ -670,20 +676,20 @@ def _new_connection(self, dbname): 'application_name': self._config.application_name, } if self._config.port: - args['port'] = self._config.port + conn_args['port'] = self._config.port if self._config.query_timeout: - args['options'] = '-c statement_timeout=%s' % self._config.query_timeout + conn_args['options'] = '-c statement_timeout=%s' % self._config.query_timeout if self._config.ssl_cert: - args['sslcert'] = self._config.ssl_cert + conn_args['sslcert'] = self._config.ssl_cert if self._config.ssl_root_cert: - args['sslrootcert'] = self._config.ssl_root_cert + conn_args['sslrootcert'] = self._config.ssl_root_cert if self._config.ssl_key: - args['sslkey'] = self._config.ssl_key + conn_args['sslkey'] = self._config.ssl_key if self._config.ssl_password: - args['sslpassword'] = self._config.ssl_password - - conn = psycopg.connect(**args, autocommit=True, cursor_factory=ClientCursor) - return conn + conn_args['sslpassword'] = self._config.ssl_password + args.update(conn_args) + pool = ConnectionPool(min_size=min_pool_size, max_size=max_pool_size, kwargs=args, open=True, name=dbname) + return pool def _connect(self): """ @@ -695,28 +701,26 @@ def _connect(self): if self.db and self.db.closed: # Reset the connection object to retry to connect self.db = None - if self.db: - if self.db.info.status != psycopg.pq.ConnStatus.OK: - # Some transaction went wrong and the connection is in an unhealthy state. Let's fix that - self.db.rollback() - else: - self.db = self._new_connection(self._config.dbname) + + if not self.db: + self.db = self._new_connection(self._config.dbname, max_pool_size=1) # Reload pg_settings on a new connection to the main db - def _load_pg_settings(self, db): + def load_pg_settings(self, db): try: - with db.cursor(row_factory=dict_row) as cursor: - self.log.debug("Running query [%s]", PG_SETTINGS_QUERY) - cursor.execute( - PG_SETTINGS_QUERY, - ("pg_stat_statements.max", "track_activity_query_size", "track_io_timing"), - ) - rows = cursor.fetchall() - self.pg_settings.clear() - for setting in rows: - name = setting['name'] - val = setting['setting'] - self.pg_settings[name] = val + with db.connection() as conn: + with conn.cursor(row_factory=dict_row) as cursor: + self.log.debug("Running query [%s]", PG_SETTINGS_QUERY) + cursor.execute( + PG_SETTINGS_QUERY, + ("pg_stat_statements.max", "track_activity_query_size", "track_io_timing"), + ) + rows = cursor.fetchall() + self.pg_settings.clear() + for setting in rows: + name = setting['name'] + val = setting['setting'] + self.pg_settings[name] = val except (psycopg.DatabaseError, psycopg.OperationalError) as err: self.log.warning("Failed to query for pg_settings: %s", repr(err)) self.count( @@ -729,21 +733,6 @@ def _load_pg_settings(self, db): def get_pg_settings(self): return self.pg_settings - def get_main_db(self): - """ - Returns a memoized, persistent psycopg connection to `self.dbname`. - Utilizes the db connection pool, and is meant to be shared across multiple threads. 
- :return: a psycopg connection - """ - # reload settings for the main DB only once every time the connection is reestablished - conn = self.db_pool._get_connection_raw( - dbname=self._config.dbname, - ttl_ms=self._config.idle_connection_timeout, - startup_fn=self._load_pg_settings, - persistent=True, - ) - return conn - def _close_db_pool(self): self.db_pool.close_all_connections(timeout=self._config.min_collection_interval) @@ -768,79 +757,81 @@ def _collect_custom_queries(self, tags): self.log.error("custom query field `columns` is required for metric_prefix `%s`", metric_prefix) continue - with self.db.cursor() as cursor: - try: - self.log.debug("Running query: %s", query) - cursor.execute(query) - except (psycopg.ProgrammingError, psycopg.errors.QueryCanceled) as e: - self.log.error("Error executing query for metric_prefix %s: %s", metric_prefix, str(e)) - self.db.rollback() - continue - - for row in cursor: - if not row: - self.log.debug("query result for metric_prefix %s: returned an empty result", metric_prefix) - continue - - if len(columns) != len(row): - self.log.error( - "query result for metric_prefix %s: expected %s columns, got %s", - metric_prefix, - len(columns), - len(row), - ) + with self.db.connection() as conn: + with conn.cursor() as cursor: + try: + self.log.debug("Running query: %s", query) + cursor.execute(query) + except (psycopg.ProgrammingError, psycopg.errors.QueryCanceled) as e: + self.log.error("Error executing query for metric_prefix %s: %s", metric_prefix, str(e)) continue - metric_info = [] - query_tags = list(custom_query.get('tags', [])) - query_tags.extend(tags) - - for column, value in zip(columns, row): - # Columns can be ignored via configuration. - if not column: + for row in cursor: + if not row: + self.log.debug("query result for metric_prefix %s: returned an empty result", metric_prefix) continue - name = column.get('name') - if not name: - self.log.error("column field `name` is required for metric_prefix `%s`", metric_prefix) - break - - column_type = column.get('type') - if not column_type: + if len(columns) != len(row): self.log.error( - "column field `type` is required for column `%s` of metric_prefix `%s`", - name, + "query result for metric_prefix %s: expected %s columns, got %s", metric_prefix, + len(columns), + len(row), ) - break + continue - if column_type == 'tag': - query_tags.append('{}:{}'.format(name, value)) - else: - if not hasattr(self, column_type): - self.log.error( - "invalid submission method `%s` for column `%s` of metric_prefix `%s`", - column_type, - name, - metric_prefix, - ) + metric_info = [] + query_tags = list(custom_query.get('tags', [])) + query_tags.extend(tags) + + for column, value in zip(columns, row): + # Columns can be ignored via configuration. + if not column: + continue + + name = column.get('name') + if not name: + self.log.error("column field `name` is required for metric_prefix `%s`", metric_prefix) break - try: - metric_info.append(('{}.{}'.format(metric_prefix, name), float(value), column_type)) - except (ValueError, TypeError): + + column_type = column.get('type') + if not column_type: self.log.error( - "non-numeric value `%s` for metric column `%s` of metric_prefix `%s`", - value, + "column field `type` is required for column `%s` of metric_prefix `%s`", name, metric_prefix, ) break - # Only submit metrics if there were absolutely no errors - all or nothing. 
- else: - for info in metric_info: - metric, value, method = info - getattr(self, method)(metric, value, tags=set(query_tags), hostname=self.resolved_hostname) + if column_type == 'tag': + query_tags.append('{}:{}'.format(name, value)) + else: + if not hasattr(self, column_type): + self.log.error( + "invalid submission method `%s` for column `%s` of metric_prefix `%s`", + column_type, + name, + metric_prefix, + ) + break + try: + metric_info.append(('{}.{}'.format(metric_prefix, name), float(value), column_type)) + except (ValueError, TypeError): + self.log.error( + "non-numeric value `%s` for metric column `%s` of metric_prefix `%s`", + value, + name, + metric_prefix, + ) + break + + # Only submit metrics if there were absolutely no errors - all or nothing. + else: + for info in metric_info: + metric, value, method = info + getattr(self, method)( + metric, value, tags=set(query_tags), hostname=self.resolved_hostname + ) def record_warning(self, code, message): # type: (DatabaseConfigurationError, str) -> None @@ -881,6 +872,7 @@ def check(self, _): try: # Check version self._connect() + self.load_version() # We don't want to cache versions between runs to capture minor updates for metadata if self._config.tag_replication_role: replication_role_tag = "replication_role:{}".format(self._get_replication_role()) tags.append(replication_role_tag) @@ -921,12 +913,6 @@ def check(self, _): tags=self._get_service_check_tags(), hostname=self.resolved_hostname, ) - try: - # commit to close the current query transaction - self.db.commit() - except Exception as e: - self.log.warning("Unable to commit: %s", e) - self._version = None # We don't want to cache versions between runs to capture minor updates for metadata finally: # Add the warnings saved during the execution of the check self._report_warnings() diff --git a/postgres/datadog_checks/postgres/statement_samples.py b/postgres/datadog_checks/postgres/statement_samples.py index f507a0877af41..405ecfd4be8d5 100644 --- a/postgres/datadog_checks/postgres/statement_samples.py +++ b/postgres/datadog_checks/postgres/statement_samples.py @@ -269,11 +269,11 @@ def _get_active_connections(self): query = PG_ACTIVE_CONNECTIONS_QUERY.format( pg_stat_activity_view=self._config.pg_stat_activity_view, extra_filters=extra_filters ) - with self._check.get_main_db().cursor(row_factory=dict_row) as cursor: - self._log.debug("Running query [%s] %s", query, params) - cursor.execute(query, params) - rows = cursor.fetchall() - + with self.db_pool.get_main_db_pool().connection() as conn: + with conn.cursor(row_factory=dict_row) as cursor: + self._log.debug("Running query [%s] %s", query, params) + cursor.execute(query, params) + rows = cursor.fetchall() self._report_check_hist_metrics(start_time, len(rows), "get_active_connections") self._log.debug("Loaded %s rows from %s", len(rows), self._config.pg_stat_activity_view) return [dict(row) for row in rows] @@ -302,11 +302,11 @@ def _get_new_pg_stat_activity(self, available_activity_columns): pg_stat_activity_view=self._config.pg_stat_activity_view, extra_filters=extra_filters, ) - with self._check.get_main_db().cursor(row_factory=dict_row) as cursor: - self._log.debug("Running query [%s] %s", query, params) - cursor.execute(query, params) - rows = cursor.fetchall() - + with self.db_pool.get_main_db_pool().connection() as conn: + with conn.cursor(row_factory=dict_row) as cursor: + self._log.debug("Running query [%s] %s", query, params) + cursor.execute(query, params) + rows = cursor.fetchall() 
self._report_check_hist_metrics(start_time, len(rows), "get_new_pg_stat_activity") self._log.debug("Loaded %s rows from %s", len(rows), self._config.pg_stat_activity_view) return rows @@ -320,18 +320,19 @@ def _get_pg_stat_activity_cols_cached(self, expected_cols): @tracked_method(agent_check_getter=agent_check_getter, track_result_length=True) def _get_available_activity_columns(self, all_expected_columns): - with self._check.get_main_db().cursor(row_factory=dict_row) as cursor: - cursor.execute( - "select * from {pg_stat_activity_view} LIMIT 0".format( - pg_stat_activity_view=self._config.pg_stat_activity_view + with self.db_pool.get_main_db_pool().connection() as conn: + with conn.cursor(row_factory=dict_row) as cursor: + cursor.execute( + "select * from {pg_stat_activity_view} LIMIT 0".format( + pg_stat_activity_view=self._config.pg_stat_activity_view + ) ) - ) - all_columns = {i[0] for i in cursor.description} - available_columns = [c for c in all_expected_columns if c in all_columns] - missing_columns = set(all_expected_columns) - set(available_columns) - if missing_columns: - self._log.debug("missing the following expected columns from pg_stat_activity: %s", missing_columns) - self._log.debug("found available pg_stat_activity columns: %s", available_columns) + all_columns = {i[0] for i in cursor.description} + available_columns = [c for c in all_expected_columns if c in all_columns] + missing_columns = set(all_expected_columns) - set(available_columns) + if missing_columns: + self._log.debug("missing the following expected columns from pg_stat_activity: %s", missing_columns) + self._log.debug("found available pg_stat_activity columns: %s", available_columns) return available_columns def _filter_and_normalize_statement_rows(self, rows): diff --git a/postgres/datadog_checks/postgres/statements.py b/postgres/datadog_checks/postgres/statements.py index 446a653295d93..b3e96ac65d1e9 100644 --- a/postgres/datadog_checks/postgres/statements.py +++ b/postgres/datadog_checks/postgres/statements.py @@ -144,6 +144,7 @@ def __init__(self, check, config): maxsize=config.full_statement_text_cache_max_size, ttl=60 * 60 / config.full_statement_text_samples_per_hour_per_query, ) + self._thread_id = "query-metrics" def _execute_query(self, cursor, query, params=()): try: @@ -170,11 +171,12 @@ def _get_pg_stat_statements_columns(self): query = STATEMENTS_QUERY.format( cols='*', pg_stat_statements_view=self._config.pg_stat_statements_view, extra_clauses="LIMIT 0", filters="" ) - with self._check.get_main_db().cursor() as cursor: - self._execute_query(cursor, query, params=()) - col_names = [desc[0] for desc in cursor.description] if cursor.description else [] - self._stat_column_cache = col_names - return col_names + with self.db_pool.get_main_db_pool().connection() as conn: + with conn.cursor() as cursor: + self._execute_query(cursor, query, params=()) + col_names = [desc[0] for desc in cursor.description] if cursor.description else [] + self._stat_column_cache = col_names + return col_names def run_job(self): # do not emit any dd.internal metrics for DBM specific check code @@ -276,17 +278,18 @@ def _load_pg_stat_statements(self): "pg_database.datname NOT ILIKE %s" for _ in self._config.ignore_databases ) params = params + tuple(self._config.ignore_databases) - with self._check.get_main_db().cursor(row_factory=dict_row) as cursor: - return self._execute_query( - cursor, - STATEMENTS_QUERY.format( - cols=', '.join(query_columns), - pg_stat_statements_view=self._config.pg_stat_statements_view, - 
filters=filters, - extra_clauses="", - ), - params=params, - ) + with self.db_pool.get_main_db_pool().connection() as conn: + with conn.cursor(row_factory=dict_row) as cursor: + return self._execute_query( + cursor, + STATEMENTS_QUERY.format( + cols=', '.join(query_columns), + pg_stat_statements_view=self._config.pg_stat_statements_view, + filters=filters, + extra_clauses="", + ), + params=params, + ) except psycopg.Error as e: error_tag = "error:database-{}".format(type(e).__name__) @@ -349,11 +352,12 @@ def _emit_pg_stat_statements_dealloc(self): if self._check.version < V14: return try: - with self._check.get_main_db().cursor(row_factory=dict_row) as cursor: - rows = self._execute_query( - cursor, - PG_STAT_STATEMENTS_DEALLOC, - ) + with self.db_pool.get_main_db_pool().connection() as conn: + with conn.cursor(row_factory=dict_row) as cursor: + rows = self._execute_query( + cursor, + PG_STAT_STATEMENTS_DEALLOC, + ) if rows: dealloc = list(rows[0].values())[0] self._check.monotonic_count( @@ -369,11 +373,12 @@ def _emit_pg_stat_statements_dealloc(self): def _emit_pg_stat_statements_metrics(self): query = PG_STAT_STATEMENTS_COUNT_QUERY_LT_9_4 if self._check.version < V9_4 else PG_STAT_STATEMENTS_COUNT_QUERY try: - with self._check.get_main_db().cursor(row_factory=dict_row) as cursor: - rows = self._execute_query( - cursor, - query, - ) + with self.db_pool.get_main_db_pool().connection() as conn: + with conn.cursor(row_factory=dict_row) as cursor: + rows = self._execute_query( + cursor, + query, + ) count = 0 if rows and 'count' in rows[0]: count = rows[0]['count'] diff --git a/postgres/datadog_checks/postgres/version_utils.py b/postgres/datadog_checks/postgres/version_utils.py index 991717e501ead..54b99eab8181e 100644 --- a/postgres/datadog_checks/postgres/version_utils.py +++ b/postgres/datadog_checks/postgres/version_utils.py @@ -26,22 +26,24 @@ def __init__(self): @staticmethod def get_raw_version(db): - with db.cursor() as cursor: - cursor.execute('SHOW SERVER_VERSION;') - raw_version = cursor.fetchone()[0] - return raw_version + with db.connection() as conn: + with conn.cursor() as cursor: + cursor.execute('SHOW SERVER_VERSION;') + raw_version = cursor.fetchone()[0] + return raw_version def is_aurora(self, db): if self._seen_aurora_exception: return False try: - with db.cursor() as cursor: - # This query will pollute PG logs in non aurora versions but is the only reliable way to detect aurora - cursor.execute('select AURORA_VERSION();') - return True + with db.connection() as conn: + with conn.cursor() as cursor: + # This query will pollute PG logs in non aurora versions, + # but is the only reliable way to detect aurora + cursor.execute('select AURORA_VERSION();') + return True except Exception as e: self.log.debug("Captured exception %s while determining if the DB is aurora. 
Assuming is not", str(e)) - db.rollback() self._seen_aurora_exception = True return False diff --git a/postgres/pyproject.toml b/postgres/pyproject.toml index 9720f75c09a06..04b6a0f8e0197 100644 --- a/postgres/pyproject.toml +++ b/postgres/pyproject.toml @@ -41,6 +41,7 @@ deps = [ "azure-identity==1.14.0; python_version > '3.0'", "cachetools==5.3.1; python_version > '3.0'", "psycopg[binary]==3.1.10; python_version > '3.0'", + "psycopg-pool==3.1.7; python_version > '3.0'", "semver==3.0.1; python_version > '3.0'", ] diff --git a/postgres/tests/common.py b/postgres/tests/common.py index ccbe7cce86590..dd3368f175bf9 100644 --- a/postgres/tests/common.py +++ b/postgres/tests/common.py @@ -173,7 +173,7 @@ def check_connection_metrics(aggregator, expected_tags, count=1): aggregator.assert_metric(name, count=count, tags=db_tags) -def check_activity_metrics(aggregator, tags, hostname=None, count=1): +def check_activity_metrics(aggregator, tags): activity_metrics = [ 'postgresql.transactions.open', 'postgresql.transactions.idle_in_transaction', @@ -186,7 +186,7 @@ def check_activity_metrics(aggregator, tags, hostname=None, count=1): # Query won't have xid assigned so postgresql.activity.backend_xid_age won't be emitted activity_metrics.append('postgresql.activity.backend_xmin_age') for name in activity_metrics: - aggregator.assert_metric(name, count=1, tags=tags, hostname=hostname) + assert_metric_at_least(aggregator, name, tags=tags) def check_stat_replication(aggregator, expected_tags, count=1): diff --git a/postgres/tests/conftest.py b/postgres/tests/conftest.py index e83bae7b86aba..f61e9acfb2520 100644 --- a/postgres/tests/conftest.py +++ b/postgres/tests/conftest.py @@ -26,6 +26,7 @@ 'dbname': DB_NAME, 'tags': ['foo:bar'], 'disable_generic_tags': True, + 'dbm': False, } @@ -109,16 +110,21 @@ def e2e_instance(): @pytest.fixture() def mock_cursor_for_replica_stats(): - with mock.patch('psycopg.connect') as connect: - cursor = mock.MagicMock() + with mock.patch('psycopg_pool.ConnectionPool.connection') as pooled_conn: data = deque() - connect.return_value = mock.MagicMock(cursor=mock.MagicMock(return_value=cursor)) + mocked_cursor = mock.MagicMock() + mocked_conn = mock.MagicMock() + mocked_conn.cursor.return_value = mocked_cursor + + pooled_conn.return_value.__enter__.return_value = mocked_conn def cursor_execute(query, second_arg=""): + print(query) if "FROM pg_stat_replication" in query: data.appendleft(['app1', 'streaming', 'async', '1.1.1.1', 12, 12, 12, 12]) data.appendleft(['app2', 'backup', 'sync', '1.1.1.1', 13, 13, 13, 13]) elif query == 'SHOW SERVER_VERSION;': + print("SHOW SERVER_VERSION") data.appendleft(['10.15']) def cursor_fetchall(): @@ -126,10 +132,11 @@ def cursor_fetchall(): yield data.pop() def cursor_fetchone(): + print("fetchone") return data.pop() - cursor.__enter__().execute = cursor_execute - cursor.__enter__().fetchall = cursor_fetchall - cursor.__enter__().fetchone = cursor_fetchone + mocked_cursor.__enter__().execute = cursor_execute + mocked_cursor.__enter__().fetchall = cursor_fetchall + mocked_cursor.__enter__().fetchone = cursor_fetchone yield diff --git a/postgres/tests/test_connections.py b/postgres/tests/test_connections.py index 69310cceb64d3..67cf99caac8ad 100644 --- a/postgres/tests/test_connections.py +++ b/postgres/tests/test_connections.py @@ -7,9 +7,9 @@ import time import uuid -import psycopg import pytest from psycopg.rows import dict_row +from psycopg_pool import ConnectionPool from datadog_checks.postgres import PostgreSql from 
datadog_checks.postgres.connections import ConnectionPoolFullError, MultiDatabaseConnectionPool @@ -28,25 +28,24 @@ def test_conn_pool(pg_instance): check = PostgreSql('postgres', {}, [pg_instance]) pool = MultiDatabaseConnectionPool(check, check._new_connection) - db = pool._get_connection_raw('postgres', 1) - assert pool._stats.connection_opened == 1 - pool.prune_connections() - assert len(pool._conns) == 1 - assert pool._stats.connection_closed == 0 - - with db.cursor(row_factory=dict_row) as cursor: - cursor.execute("select 1") - rows = cursor.fetchall() - assert len(rows) == 1 and list(rows[0].values())[0] + with pool.get_connection('postgres', 1): + assert pool._stats.connection_opened == 1 - time.sleep(0.001) + # exiting the context block should set the connection to inactive + # and it should be pruned pool.prune_connections() assert len(pool._conns) == 0 assert pool._stats.connection_closed == 1 - assert pool._stats.connection_closed_failed == 0 assert pool._stats.connection_pruned == 1 + assert pool._stats.connection_closed_failed == 0 - db = pool._get_connection_raw('postgres', 999 * 1000) + db = pool._get_connection_pool('postgres', 999 * 1000) + # run a simple query, and return conn object to the pool + with db.connection() as conn: + with conn.cursor(row_factory=dict_row) as cursor: + cursor.execute("select 1") + rows = cursor.fetchall() + assert len(rows) == 1 and list(rows[0].values())[0] assert len(pool._conns) == 1 assert pool._stats.connection_opened == 2 success = pool.close_all_connections(timeout=5) @@ -70,20 +69,20 @@ def test_conn_pool_no_leaks_on_close(pg_instance): # Used to make verification queries pool2 = MultiDatabaseConnectionPool( - check, lambda dbname: psycopg.connect(host=HOST, dbname=dbname, user=USER_ADMIN, password=PASSWORD_ADMIN) + check, lambda dbname, min_pool_size, max_pool_size: local_pool(dbname, min_pool_size, max_pool_size) ) # Iterate in the test many times to detect flakiness for _ in range(20): def exec_connection(pool, wg, dbname): - db = pool._get_connection_raw(dbname, 10 * 1000) - with db.cursor(row_factory=dict_row) as cursor: - cursor.execute("select current_database()") - rows = cursor.fetchall() - assert len(rows) == 1 - assert list(rows[0].values())[0] == dbname - wg.done() + with pool._get_connection_pool(dbname, 10 * 1000).connection() as conn: + with conn.cursor(row_factory=dict_row) as cursor: + cursor.execute("select current_database()") + rows = cursor.fetchall() + assert len(rows) == 1 + assert list(rows[0].values())[0] == dbname + wg.done() conn_count = 100 threadpool = [] @@ -131,7 +130,7 @@ def test_conn_pool_no_leaks_on_prune(pg_instance): pool = MultiDatabaseConnectionPool(check, check._new_connection) # Used to make verification queries pool2 = MultiDatabaseConnectionPool( - check, lambda dbname: psycopg.connect(host=HOST, dbname=dbname, user=USER_ADMIN, password=PASSWORD_ADMIN) + check, lambda dbname, min_pool_size, max_pool_size: local_pool(dbname, min_pool_size, max_pool_size) ) ttl_long = 90 * 1000 ttl_short = 1 @@ -140,14 +139,17 @@ def get_many_connections(count, ttl): """ Retrieves the number of connections from the pool with the specified TTL """ + conn_pids = [] for i in range(0, count): dbname = 'dogs_{}'.format(i) - db = pool._get_connection_raw(dbname, ttl) - with db.cursor(row_factory=dict_row) as cursor: - cursor.execute("select current_database()") - rows = cursor.fetchall() - assert len(rows) == 1 - assert list(rows[0].values())[0] == dbname + with pool.get_connection(dbname, ttl) as conn: + with 
conn.cursor(row_factory=dict_row) as cursor: + cursor.execute("select current_database()") + rows = cursor.fetchall() + assert len(rows) == 1 + assert list(rows[0].values())[0] == dbname + conn_pids.append(conn.info.backend_pid) + return set(conn_pids) pool.close_all_connections(timeout=5) @@ -170,18 +172,16 @@ def get_many_connections(count, ttl): < approximate_deadline + datetime.timedelta(seconds=1) ) assert not db.closed - assert db.info.status == psycopg.pq.ConnStatus.OK # Check that those pooled connections do exist on the database rows = get_activity(pool2, unique_id) assert len(rows) == 50 assert len({row['datname'] for row in rows}) == 50 assert all(row['state'] == 'idle' for row in rows) - pool._stats.reset() # Repeat this process many times and expect that only one connection is created per database for _ in range(100): - get_many_connections(51, ttl_long) + conn_pids = get_many_connections(51, ttl_long) assert pool._stats.connection_opened == 1 attempts_to_verify = 10 @@ -190,8 +190,6 @@ def get_many_connections(count, ttl): for attempt in range(attempts_to_verify): rows = get_activity(pool2, unique_id) server_pids = {row['pid'] for row in rows} - conns = [c.connection for c in pool._conns.values()] - conn_pids = {db.info.backend_pid for db in conns} leaked_rows = [row for row in rows if row['pid'] in server_pids - conn_pids] if not leaked_rows: break @@ -243,7 +241,7 @@ def test_conn_pool_single_connection(pg_instance): # Used to make verification queries pool2 = MultiDatabaseConnectionPool( - check, lambda dbname: psycopg.connect(host=HOST, dbname=dbname, user=USER_ADMIN, password=PASSWORD_ADMIN) + check, lambda dbname, min_pool_size, max_pool_size: local_pool(dbname, min_pool_size, max_pool_size) ) pool = MultiDatabaseConnectionPool(check, check._new_connection) @@ -295,7 +293,7 @@ def pretend_to_run_query(pool, dbname): # ask for one more connection with pytest.raises(ConnectionPoolFullError): - with pool.get_connection('dogs_{}'.format(limit + 1), 1, 1): + with pool.get_connection(dbname='dogs_{}'.format(limit + 1), ttl_ms=1, timeout=1): pass # join threads @@ -313,6 +311,16 @@ def pretend_to_run_query(pool, dbname): assert pool._stats.connection_closed == limit + 1 +def local_pool(dbname, min_pool_size, max_pool_size): + args = { + 'host': HOST, + 'user': USER_ADMIN, + 'password': PASSWORD_ADMIN, + 'dbname': dbname, + } + return ConnectionPool(min_size=min_pool_size, max_size=max_pool_size, kwargs=args, open=True, name=dbname) + + def get_activity(db_pool, unique_id): """ Fetches all pg_stat_activity rows generated by this test and connection to a "dogs%" database diff --git a/postgres/tests/test_deadlock.py b/postgres/tests/test_deadlock.py index a238464c6021f..66b07d3c89f21 100644 --- a/postgres/tests/test_deadlock.py +++ b/postgres/tests/test_deadlock.py @@ -12,14 +12,17 @@ from .common import DB_NAME, HOST, PORT, POSTGRES_VERSION -def wait_on_result(cursor=None, sql=None, binds=None, expected_value=None): - for _i in range(300): - cursor.execute(sql, binds) - result = cursor.fetchone()[0] - if result == expected_value: - break - - time.sleep(0.1) +def wait_on_result(sql=None, binds=None, expected_value=None): + for _i in range(5): + with psycopg.connect( + host=HOST, dbname=DB_NAME, user="bob", password="bob", cursor_factory=ClientCursor + ) as tconn: + with tconn.cursor() as cursor: + cursor.execute(sql, binds) + result = cursor.fetchone()[0] + if result == expected_value: + break + time.sleep(0.1) else: return False @@ -33,8 +36,6 @@ def wait_on_result(cursor=None, 
sql=None, binds=None, expected_value=None): def test_deadlock(aggregator, dd_run_check, integration_check, pg_instance): check = integration_check(pg_instance) check._connect() - conn = check._new_connection(pg_instance['dbname']) - cursor = conn.cursor() def execute_in_thread(q, args): with psycopg.connect( @@ -55,8 +56,10 @@ def execute_in_thread(q, args): update_sql = "update personsdup1 set address = 'changed' where personid = %s" deadlock_count_sql = "select deadlocks from pg_stat_database where datname = %s" - cursor.execute(deadlock_count_sql, (DB_NAME,)) - deadlocks_before = cursor.fetchone()[0] + with check._new_connection(pg_instance['dbname']).connection() as conn: + with conn.cursor() as cursor: + cursor.execute(deadlock_count_sql, (DB_NAME,)) + deadlocks_before = cursor.fetchone()[0] conn_args = {'host': HOST, 'dbname': DB_NAME, 'user': "bob", 'password': "bob"} conn1 = psycopg.connect(**conn_args, autocommit=False, cursor_factory=ClientCursor) @@ -98,7 +101,7 @@ def execute_in_thread(q, args): AND blocking_activity.application_name = %s AND blocked_activity.application_name = %s """ - is_locked = wait_on_result(cursor=cursor, sql=lock_count_sql, binds=(appname1, appname2), expected_value=1) + is_locked = wait_on_result(sql=lock_count_sql, binds=(appname1, appname2), expected_value=1) if not is_locked: raise Exception("ERROR: Couldn't reproduce a deadlock. That can happen on an extremely overloaded system.") @@ -111,7 +114,7 @@ def execute_in_thread(q, args): dd_run_check(check) - wait_on_result(cursor=cursor, sql=deadlock_count_sql, binds=(DB_NAME,), expected_value=deadlocks_before + 1) + wait_on_result(sql=deadlock_count_sql, binds=(DB_NAME,), expected_value=deadlocks_before + 1) aggregator.assert_metric( 'postgresql.deadlocks.count', diff --git a/postgres/tests/test_explain_parameterized_queries.py b/postgres/tests/test_explain_parameterized_queries.py index 06a152f6a3b9b..db4980a5403d8 100644 --- a/postgres/tests/test_explain_parameterized_queries.py +++ b/postgres/tests/test_explain_parameterized_queries.py @@ -52,16 +52,6 @@ def test_explain_parameterized_queries(integration_check, dbm_instance, query, e assert explain_err_code == expected_explain_err_code assert err is None - explain_param_queries = check.statement_samples._explain_parameterized_queries - # check that we deallocated the prepared statement after explaining - rows = explain_param_queries._execute_query_and_fetch_rows( - DB_NAME, - "SELECT * FROM pg_prepared_statements WHERE name = 'dd_{query_signature}'".format( - query_signature=compute_sql_signature(query) - ), - ) - assert len(rows) == 0 - @pytest.mark.parametrize( "query,expected_generic_values", @@ -85,7 +75,8 @@ def test_explain_parameterized_queries_generic_params(integration_check, dbm_ins query_signature = compute_sql_signature(query) explain_param_queries = check.statement_samples._explain_parameterized_queries - assert explain_param_queries._create_prepared_statement(DB_NAME, query, query, query_signature) is True - assert expected_generic_values == explain_param_queries._get_number_of_parameters_for_prepared_statement( - DB_NAME, query_signature - ) + with check._new_connection(DB_NAME).connection() as conn: + assert explain_param_queries._create_prepared_statement(conn, query, query, query_signature) is True + assert expected_generic_values == explain_param_queries._get_number_of_parameters_for_prepared_statement( + conn, query_signature + ) diff --git a/postgres/tests/test_pg_integration.py b/postgres/tests/test_pg_integration.py index 
e2343cc9f10dc..b2b7a7c613c78 100644 --- a/postgres/tests/test_pg_integration.py +++ b/postgres/tests/test_pg_integration.py @@ -7,7 +7,6 @@ import mock import psycopg import pytest -from semver import VersionInfo from datadog_checks.postgres import PostgreSql from datadog_checks.postgres.__about__ import __version__ @@ -52,7 +51,7 @@ ) def test_common_metrics(aggregator, integration_check, pg_instance, is_aurora): check = integration_check(pg_instance) - check._is_aurora = is_aurora + check.is_aurora = is_aurora check.check(pg_instance) expected_tags = _get_expected_tags(check, pg_instance) @@ -404,12 +403,14 @@ def test_backend_transaction_age(aggregator, integration_check, pg_instance): @requires_over_10 def test_wrong_version(aggregator, integration_check, pg_instance): check = integration_check(pg_instance) - # Enforce to cache wrong version - check._version = VersionInfo(*[9, 6, 0]) + # Enforce the wrong version + check._version_utils.get_raw_version = mock.MagicMock(return_value="9.6.0") check.check(pg_instance) assert_state_clean(check) + # Reset the mock to a good version + check._version_utils.get_raw_version = mock.MagicMock(return_value="13.0.0") check.check(pg_instance) assert_state_set(check) @@ -491,9 +492,10 @@ def test_query_timeout(integration_check, pg_instance): pg_instance['query_timeout'] = 1000 check = integration_check(pg_instance) check._connect() - cursor = check.db.cursor() with pytest.raises(psycopg.errors.QueryCanceled): - cursor.execute("select pg_sleep(2000)") + with check.db.connection() as conn: + with conn.cursor() as cursor: + cursor.execute("select pg_sleep(2000)") @requires_over_10 @@ -599,7 +601,7 @@ def test_correct_hostname(dbm_enabled, reported_hostname, expected_hostname, agg c_metrics = c_metrics + DBM_MIGRATED_METRICS for name in c_metrics: aggregator.assert_metric(name, count=1, tags=expected_tags_with_db, hostname=expected_hostname) - check_activity_metrics(aggregator, tags=expected_activity_tags, hostname=expected_hostname) + check_activity_metrics(aggregator, tags=expected_activity_tags) for name in CONNECTION_METRICS: aggregator.assert_metric(name, count=1, tags=expected_tags_no_db, hostname=expected_hostname) @@ -626,6 +628,9 @@ def test_correct_hostname(dbm_enabled, reported_hostname, expected_hostname, agg @pytest.mark.usefixtures('dd_environment') def test_database_instance_metadata(aggregator, pg_instance, dbm_enabled, reported_hostname): pg_instance['dbm'] = dbm_enabled + # this will block on cancel and wait for the coll interval of 600 seconds, + # unless the collection_interval is set to a short amount of time + pg_instance['collect_resources'] = {'collection_interval': 0.1} if reported_hostname: pg_instance['reported_hostname'] = reported_hostname expected_host = reported_hostname if reported_hostname else 'stubbed.hostname' @@ -661,7 +666,6 @@ def assert_state_clean(check): assert check.metrics_cache.archiver_metrics is None assert check.metrics_cache.replication_metrics is None assert check.metrics_cache.activity_metrics is None - assert check._is_aurora is None def assert_state_set(check): @@ -670,4 +674,3 @@ def assert_state_set(check): if POSTGRES_VERSION != '9.3': assert check.metrics_cache.archiver_metrics assert check.metrics_cache.replication_metrics - assert check._is_aurora is False diff --git a/postgres/tests/test_pg_replication.py b/postgres/tests/test_pg_replication.py index bc5245d488856..886fdba328b2b 100644 --- a/postgres/tests/test_pg_replication.py +++ b/postgres/tests/test_pg_replication.py @@ -31,6 +31,7 @@ 
@requires_over_10 def test_common_replica_metrics(aggregator, integration_check, metrics_cache_replica, pg_replica_instance): check = integration_check(pg_replica_instance) + check.initialize_is_aurora() check.check(pg_replica_instance) expected_tags = _get_expected_tags(check, pg_replica_instance) @@ -54,6 +55,7 @@ def test_common_replica_metrics(aggregator, integration_check, metrics_cache_rep @requires_over_10 def test_wal_receiver_metrics(aggregator, integration_check, pg_instance, pg_replica_instance): check = integration_check(pg_replica_instance) + check.initialize_is_aurora() expected_tags = _get_expected_tags(check, pg_replica_instance, status='streaming') with _get_superconn(pg_instance) as conn: with conn.cursor() as cur: diff --git a/postgres/tests/test_statements.py b/postgres/tests/test_statements.py index cc6a463556333..b739a5da9f839 100644 --- a/postgres/tests/test_statements.py +++ b/postgres/tests/test_statements.py @@ -99,12 +99,12 @@ def test_statement_samples_enabled_config( def test_statement_metrics_version(integration_check, dbm_instance, version, expected_payload_version): if version: check = integration_check(dbm_instance) - check._version = version + check.version = version check._connect() assert payload_pg_version(check.version) == expected_payload_version else: with mock.patch( - 'datadog_checks.postgres.postgres.PostgreSql.version', new_callable=mock.PropertyMock + 'datadog_checks.postgres.postgres.PostgreSql.load_version', new_callable=mock.MagicMock ) as patched_version: patched_version.return_value = None check = integration_check(dbm_instance) @@ -362,28 +362,29 @@ def obfuscate_sql(query, options=None): check = integration_check(dbm_instance) check._connect() - cursor = check.db.cursor() - - # Execute the query once to begin tracking it. Execute again between checks to track the difference. - # This should result in a single metric for that query_signature having a value of 2 - with mock.patch.object(datadog_agent, 'obfuscate_sql', passthrough=True) as mock_agent: - mock_agent.side_effect = obfuscate_sql - cursor.execute(query, (['app1', 'app2'],)) - cursor.execute(query, (['app1', 'app2', 'app3'],)) - run_one_check(check, dbm_instance) + # Get a connection separate from the one used by the check to avoid hitting the connection pool limit + with check._new_connection('postgres', max_pool_size=1).connection() as conn: + with conn.cursor() as cursor: + # Execute the query once to begin tracking it. Execute again between checks to track the difference. 
+ # This should result in a single metric for that query_signature having a value of 2 + with mock.patch.object(datadog_agent, 'obfuscate_sql', passthrough=True) as mock_agent: + mock_agent.side_effect = obfuscate_sql + cursor.execute(query, (['app1', 'app2'],)) + cursor.execute(query, (['app1', 'app2', 'app3'],)) + run_one_check(check, dbm_instance) - cursor.execute(query, (['app1', 'app2'],)) - cursor.execute(query, (['app1', 'app2', 'app3'],)) - run_one_check(check, dbm_instance) + cursor.execute(query, (['app1', 'app2'],)) + cursor.execute(query, (['app1', 'app2', 'app3'],)) + run_one_check(check, dbm_instance) - events = aggregator.get_event_platform_events("dbm-metrics") - assert len(events) == 1 - event = events[0] + events = aggregator.get_event_platform_events("dbm-metrics") + assert len(events) == 1 + event = events[0] - matching = [e for e in event['postgres_rows'] if e['query_signature'] == query_signature] - assert len(matching) == 1 - row = matching[0] - assert row['calls'] == 2 + matching = [e for e in event['postgres_rows'] if e['query_signature'] == query_signature] + assert len(matching) == 1 + row = matching[0] + assert row['calls'] == 2 @pytest.fixture @@ -750,6 +751,7 @@ def test_statement_metadata( """Tests for metadata in both samples and metrics""" dbm_instance['pg_stat_statements_view'] = pg_stat_statements_view dbm_instance['query_samples']['run_sync'] = True + dbm_instance['query_samples']['explain_parameterized_queries'] = False dbm_instance['query_metrics']['run_sync'] = True # If query or normalized_query changes, the query_signatures for both will need to be updated as well. @@ -1006,7 +1008,7 @@ def execute_in_thread(q): for key in expected_out: assert expected_out[key] == bobs_query[key] if POSTGRES_VERSION.split('.')[0] == "9": - # pg v < 10 does not have a backend_type column + # pg v < 10 does not have a backend_type column, # so we shouldn't see this key in our activity rows expected_keys.remove('backend_type') if POSTGRES_VERSION == '9.5': @@ -1070,7 +1072,6 @@ def execute_in_thread(q): assert bobs_query['state'] == "idle in transaction" finally: blocking_conn.close() - conn.close() @pytest.mark.parametrize( @@ -1338,7 +1339,7 @@ def test_load_pg_settings(aggregator, integration_check, dbm_instance, db_user): dbm_instance["dbname"] = "postgres" check = integration_check(dbm_instance) check._connect() - check._load_pg_settings(check.db) + check.load_pg_settings(check.db) if db_user == 'datadog_no_catalog': aggregator.assert_metric( "dd.postgres.error", @@ -1354,16 +1355,16 @@ def test_load_pg_settings(aggregator, integration_check, dbm_instance, db_user): assert len(aggregator.metrics("dd.postgres.error")) == 0 -def test_pg_settings_caching(aggregator, integration_check, dbm_instance): +def test_pg_settings_caching(integration_check, dbm_instance): dbm_instance["username"] = "datadog" dbm_instance["dbname"] = "postgres" check = integration_check(dbm_instance) assert not check.pg_settings, "pg_settings should not have been initialized yet" check._connect() - check.get_main_db() + check.db_pool.get_main_db_pool() assert "track_activity_query_size" in check.pg_settings check.pg_settings["test_key"] = True - check.get_main_db() + check.db_pool.get_main_db_pool() assert ( "test_key" in check.pg_settings ), "key should not have been blown away. 
If it was then pg_settings was not cached correctly" @@ -1395,8 +1396,8 @@ def test_statement_samples_main_collection_rate_limit(aggregator, integration_ch check_frequency = collection_interval / 5.0 _check_until_time(check, dbm_instance, sleep_time, check_frequency) max_collections = int(1 / collection_interval * sleep_time) + 1 - check.cancel() metrics = aggregator.metrics("dd.postgres.collect_statement_samples.time") + check.cancel() assert max_collections / 2.0 <= len(metrics) <= max_collections diff --git a/postgres/tests/test_unit.py b/postgres/tests/test_unit.py index 2f6bffd0e92c2..2add2f6aa459a 100644 --- a/postgres/tests/test_unit.py +++ b/postgres/tests/test_unit.py @@ -79,8 +79,8 @@ def test_get_instance_with_default(pg_instance, collect_default_database): """ pg_instance['collect_default_database'] = collect_default_database check = PostgreSql('postgres', {}, [pg_instance]) - check._version = VersionInfo(9, 2, 0) - res = check.metrics_cache.get_instance_metrics(check._version) + check.version = VersionInfo(9, 2, 0) + res = check.metrics_cache.get_instance_metrics(check.version) dbfilter = " AND psd.datname not ilike 'postgres'" if collect_default_database: assert dbfilter not in res['query'] @@ -93,8 +93,7 @@ def test_malformed_get_custom_queries(check): Test early-exit conditions for _get_custom_queries() """ check.log = MagicMock() - db = MagicMock() - check.db = db + check.db = MagicMock() check._config.custom_queries = [{}] @@ -124,7 +123,7 @@ def test_malformed_get_custom_queries(check): # Make sure we gracefully handle an error while performing custom queries malformed_custom_query_column = {} malformed_custom_query['columns'] = [malformed_custom_query_column] - db.cursor().__enter__().execute.side_effect = psycopg.ProgrammingError('FOO') + check.db.connection().__enter__().cursor().__enter__().execute.side_effect = psycopg.ProgrammingError('FOO') check._collect_custom_queries([]) check.log.error.assert_called_once_with( "Error executing query for metric_prefix %s: %s", malformed_custom_query['metric_prefix'], 'FOO' @@ -135,8 +134,8 @@ def test_malformed_get_custom_queries(check): malformed_custom_query_column = {} malformed_custom_query['columns'] = [malformed_custom_query_column] query_return = ['num', 1337] - db.cursor().__enter__().execute.side_effect = None - db.cursor().__enter__().__iter__.return_value = iter([query_return]) + check.db.connection().__enter__().cursor().__enter__().execute.side_effect = None + check.db.connection().__enter__().cursor().__enter__().__iter__.return_value = iter([query_return]) check._collect_custom_queries([]) check.log.error.assert_called_once_with( "query result for metric_prefix %s: expected %s columns, got %s", @@ -147,7 +146,7 @@ def test_malformed_get_custom_queries(check): check.log.reset_mock() # Make sure the query does not return an empty result - db.cursor().__enter__().__iter__.return_value = iter([[]]) + check.db.connection().__enter__().cursor().__enter__().__iter__.return_value = iter([[]]) check._collect_custom_queries([]) check.log.debug.assert_called_with( "query result for metric_prefix %s: returned an empty result", malformed_custom_query['metric_prefix'] @@ -156,7 +155,7 @@ def test_malformed_get_custom_queries(check): # Make sure 'name' is defined in each column malformed_custom_query_column['some_key'] = 'some value' - db.cursor().__enter__().__iter__.return_value = iter([[1337]]) + check.db.connection().__enter__().cursor().__enter__().__iter__.return_value = iter([[1337]]) check._collect_custom_queries([]) 
check.log.error.assert_called_once_with( "column field `name` is required for metric_prefix `%s`", malformed_custom_query['metric_prefix'] @@ -165,7 +164,7 @@ def test_malformed_get_custom_queries(check): # Make sure 'type' is defined in each column malformed_custom_query_column['name'] = 'num' - db.cursor().__enter__().__iter__.return_value = iter([[1337]]) + check.db.connection().__enter__().cursor().__enter__().__iter__.return_value = iter([[1337]]) check._collect_custom_queries([]) check.log.error.assert_called_once_with( "column field `type` is required for column `%s` of metric_prefix `%s`", @@ -176,7 +175,7 @@ def test_malformed_get_custom_queries(check): # Make sure 'type' is a valid metric type malformed_custom_query_column['type'] = 'invalid_type' - db.cursor().__enter__().__iter__.return_value = iter([[1337]]) + check.db.connection().__enter__().cursor().__enter__().__iter__.return_value = iter([[1337]]) check._collect_custom_queries([]) check.log.error.assert_called_once_with( "invalid submission method `%s` for column `%s` of metric_prefix `%s`", @@ -190,7 +189,7 @@ def test_malformed_get_custom_queries(check): malformed_custom_query_column['type'] = 'gauge' query_return = MagicMock() query_return.__float__.side_effect = ValueError('Mocked exception') - db.cursor().__enter__().__iter__.return_value = iter([[query_return]]) + check.db.connection().__enter__().cursor().__enter__().__iter__.return_value = iter([[query_return]]) check._collect_custom_queries([]) check.log.error.assert_called_once_with( "non-numeric value `%s` for metric column `%s` of metric_prefix `%s`",