Skip to content

Commit

Permalink
Prevent hostname evaluating to None in sqlserver check (#18237)
Browse files Browse the repository at this point in the history
* Prevent hostname evaluating to None in sqlserver check

* always reference resolved_hostname property

* add changelog

* fix tests

* reload hostname when engine edition static info expires

---------

Co-authored-by: Zhengda Lu <[email protected]>
  • Loading branch information
jmeunier28 and lu-zhengda authored Oct 28, 2024
1 parent 3135f32 commit e50672b
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 23 deletions.
1 change: 1 addition & 0 deletions sqlserver/changelog.d/18237.fixed
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Prevent hostname evaluating to None in sqlserver check
3 changes: 1 addition & 2 deletions sqlserver/datadog_checks/sqlserver/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,8 +152,7 @@ class Connection(object):

VALID_ADOPROVIDERS = ['SQLOLEDB', 'MSOLEDBSQL', 'MSOLEDBSQL19', 'SQLNCLI11']

def __init__(self, host, init_config, instance_config, service_check_handler):
self.host = host
def __init__(self, init_config, instance_config, service_check_handler):
self.instance = instance_config
self.service_check_handler = service_check_handler
self.log = get_check_logger()
Expand Down
13 changes: 8 additions & 5 deletions sqlserver/datadog_checks/sqlserver/sqlserver.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def __init__(self, name, init_config, instances):
# go through the agent internal metrics submission processing those tags
self.non_internal_tags = copy.deepcopy(self.tags)
self.check_initializations.append(self.initialize_connection)
self.check_initializations.append(self.set_resolved_hostname)
self.check_initializations.append(self.load_static_information)
self.check_initializations.append(self.set_resolved_hostname_metadata)
self.check_initializations.append(self.config_checks)
self.check_initializations.append(self.make_metric_list_to_collect)
Expand Down Expand Up @@ -244,7 +244,6 @@ def set_resource_tags(self):

def set_resolved_hostname(self):
# load static information cache
self.load_static_information()
if self._resolved_hostname is None:
if self._config.reported_hostname:
self._resolved_hostname = self._config.reported_hostname
Expand Down Expand Up @@ -280,6 +279,7 @@ def resolved_hostname(self):
return self._resolved_hostname

def load_static_information(self):
engine_edition_reloaded = False
expected_keys = {STATIC_INFO_VERSION, STATIC_INFO_MAJOR_VERSION, STATIC_INFO_ENGINE_EDITION, STATIC_INFO_RDS}
missing_keys = expected_keys - set(self.static_info_cache.keys())
if missing_keys:
Expand Down Expand Up @@ -309,6 +309,7 @@ def load_static_information(self):
result = cursor.fetchone()
if result:
self.static_info_cache[STATIC_INFO_ENGINE_EDITION] = result[0]
engine_edition_reloaded = True
else:
self.log.warning("failed to load version static information due to empty results")
if STATIC_INFO_RDS not in self.static_info_cache:
Expand All @@ -318,9 +319,11 @@ def load_static_information(self):
self.static_info_cache[STATIC_INFO_RDS] = True
else:
self.static_info_cache[STATIC_INFO_RDS] = False
# re-initialize resolved_hostname to ensure we take into consideration the static information
# re-initialize resolved_hostname to ensure we take into consideration the egine edition
# after it's loaded
self._resolved_hostname = None
if engine_edition_reloaded:
self._resolved_hostname = None
self.set_resolved_hostname()

def debug_tags(self):
return self.tags + ["agent_hostname:{}".format(self.agent_hostname)]
Expand All @@ -342,7 +345,6 @@ def agent_hostname(self):

def initialize_connection(self):
self.connection = Connection(
host=self.resolved_hostname,
init_config=self.init_config,
instance_config=self.instance,
service_check_handler=self.handle_service_check,
Expand Down Expand Up @@ -709,6 +711,7 @@ def _check_database_conns(self):

def check(self, _):
if self.do_check:
self.load_static_information()
# configure custom queries for the check
if self._query_manager is None:
# use QueryManager to process custom queries
Expand Down
24 changes: 8 additions & 16 deletions sqlserver/tests/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,7 @@ def test_warn_trusted_connection_username_pass(instance_minimal_defaults, cs, us
instance_minimal_defaults["connection_string"] = cs
instance_minimal_defaults["username"] = username
instance_minimal_defaults["password"] = password
check = SQLServer(CHECK_NAME, {}, [instance_minimal_defaults])
connection = Connection(check.resolved_hostname, {}, instance_minimal_defaults, None)
connection = Connection({}, instance_minimal_defaults, None)
connection.log = mock.MagicMock()
connection._connection_options_validation('somekey', 'somedb')
if expect_warning:
Expand All @@ -97,8 +96,7 @@ def test_warn_trusted_connection_username_pass(instance_minimal_defaults, cs, us
)
def test_will_warn_parameters_for_the_wrong_connection(instance_minimal_defaults, connector, param):
instance_minimal_defaults.update({'connector': connector, param: 'foo'})
check = SQLServer(CHECK_NAME, {}, [instance_minimal_defaults])
connection = Connection(check.resolved_hostname, {}, instance_minimal_defaults, None)
connection = Connection({}, instance_minimal_defaults, None)
connection.log = mock.MagicMock()
connection._connection_options_validation('somekey', 'somedb')
connection.log.warning.assert_called_once_with(
Expand Down Expand Up @@ -130,8 +128,7 @@ def test_will_warn_parameters_for_the_wrong_connection(instance_minimal_defaults
)
def test_will_fail_for_duplicate_parameters(instance_minimal_defaults, connector, cs, param, should_fail):
instance_minimal_defaults.update({'connector': connector, param: 'foo', 'connection_string': cs + "=foo"})
check = SQLServer(CHECK_NAME, {}, [instance_minimal_defaults])
connection = Connection(check.resolved_hostname, {}, instance_minimal_defaults, None)
connection = Connection({}, instance_minimal_defaults, None)
if should_fail:
match = (
"%s has been provided both in the connection string and as a configuration option (%s), "
Expand Down Expand Up @@ -162,8 +159,7 @@ def test_will_fail_for_duplicate_parameters(instance_minimal_defaults, connector
def test_will_fail_for_wrong_parameters_in_the_connection_string(instance_minimal_defaults, connector, cs):
instance_minimal_defaults.update({'connector': connector, 'connection_string': cs + '=foo'})
other_connector = 'odbc' if connector != 'odbc' else 'adodbapi'
check = SQLServer(CHECK_NAME, {}, [instance_minimal_defaults])
connection = Connection(check.resolved_hostname, {}, instance_minimal_defaults, None)
connection = Connection({}, instance_minimal_defaults, None)
match = (
"%s has been provided in the connection string. "
"This option is only available for %s connections, however %s has been selected"
Expand Down Expand Up @@ -226,8 +222,7 @@ def test_managed_auth_config_valid(instance_minimal_defaults, name, managed_iden
for k, v in managed_identity_config.items():
instance_minimal_defaults[k] = v
instance_minimal_defaults.update({'connector': 'odbc'})
check = SQLServer(CHECK_NAME, {}, [instance_minimal_defaults])
connection = Connection(check.resolved_hostname, {}, instance_minimal_defaults, None)
connection = Connection({}, instance_minimal_defaults, None)
if should_fail:
with pytest.raises(ConfigurationError, match=re.escape(expected_err)):
connection._connection_options_validation('somekey', 'somedb')
Expand Down Expand Up @@ -287,8 +282,7 @@ def test_managed_auth_config_valid(instance_minimal_defaults, name, managed_iden
def test_config_with_and_without_port(instance_minimal_defaults, host, port, expected_host):
instance_minimal_defaults["host"] = host
instance_minimal_defaults["port"] = port
check = SQLServer(CHECK_NAME, {}, [instance_minimal_defaults])
connection = Connection(check.resolved_hostname, {}, instance_minimal_defaults, None)
connection = Connection({}, instance_minimal_defaults, None)
_, result_host, _, _, _, _ = connection._get_access_info('somekey', 'somedb')
assert result_host == expected_host

Expand Down Expand Up @@ -369,9 +363,7 @@ def test_connection_failure(aggregator, dd_run_check, instance_docker):

try:
# Break the connection
check.connection = Connection(
check.resolved_hostname, {}, {'host': '', 'username': '', 'password': ''}, check.handle_service_check
)
check.connection = Connection({}, {'host': '', 'username': '', 'password': ''}, check.handle_service_check)
dd_run_check(check)
except Exception:
aggregator.assert_service_check(
Expand Down Expand Up @@ -495,7 +487,7 @@ def test_connection_error_reporting(
expected_error_pattern = matching_patterns[0]

check = SQLServer(CHECK_NAME, {}, [instance_docker])
connection = Connection(check.resolved_hostname, check.init_config, check.instance, check.handle_service_check)
connection = Connection(check.init_config, check.instance, check.handle_service_check)
with pytest.raises(SQLConnectionError) as excinfo:
with connection.open_managed_default_connection():
pytest.fail("connection should not have succeeded")
Expand Down
24 changes: 24 additions & 0 deletions sqlserver/tests/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -873,3 +873,27 @@ def test_propagate_agent_tags(
status=SQLServer.OK,
tags=expected_tags,
)


@pytest.mark.integration
@pytest.mark.usefixtures('dd_environment')
def test_check_static_information_expire(aggregator, dd_run_check, init_config, instance_docker):
sqlserver_check = SQLServer(CHECK_NAME, init_config, [instance_docker])
dd_run_check(sqlserver_check)
assert sqlserver_check.static_info_cache is not None
assert len(sqlserver_check.static_info_cache.keys()) == 4
assert sqlserver_check.resolved_hostname == 'stubbed.hostname'

# manually clear static information cache
sqlserver_check.static_info_cache.clear()
dd_run_check(sqlserver_check)
assert sqlserver_check.static_info_cache is not None
assert len(sqlserver_check.static_info_cache.keys()) == 4
assert sqlserver_check.resolved_hostname == 'stubbed.hostname'

# manually pop STATIC_INFO_ENGINE_EDITION to make sure it is reloaded
sqlserver_check.static_info_cache.pop(STATIC_INFO_ENGINE_EDITION)
dd_run_check(sqlserver_check)
assert sqlserver_check.static_info_cache is not None
assert len(sqlserver_check.static_info_cache.keys()) == 4
assert sqlserver_check.resolved_hostname == 'stubbed.hostname'

0 comments on commit e50672b

Please sign in to comment.