From 17e154ff877fadc2a22ce028318f834c2618fe6e Mon Sep 17 00:00:00 2001 From: David Salvisberg Date: Tue, 3 Sep 2024 15:41:54 +0200 Subject: [PATCH] Improves error messages on `RunnerResults` for various scenarios. Fixes handling of `async_` keyword parameter. --- CHANGELOG.rst | 3 +++ docs/index.rst | 17 +++++++++++++++-- scripts/generate_module_hints.py | 6 ++++++ src/suitable/_module_types.py | 26 +++++++++++++------------- src/suitable/module_runner.py | 11 +++++++++-- src/suitable/runner_results.py | 32 ++++++++++++++++++++++++-------- tests/test_api.py | 22 ++++++++++++++++++---- 7 files changed, 88 insertions(+), 29 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index cccf5b6..3305106 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,6 +1,9 @@ Changelog --------- +- Improves error messages for various scenarios on `RunnerResults` + [Daverball] + - Only sets `ansible_connection` to `local` when `ansible_port` is `22`, since anything else is likely a SSH tunnel [Daverball] diff --git a/docs/index.rst b/docs/index.rst index a1e407f..b8ec456 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -71,7 +71,7 @@ Connect to a server using a username and a password:: remote_pass=password ) - print api.command('whoami').stdout() # prints 'admin' + print(api.command('whoami').stdout()) # prints 'admin' Run a command on multiple servers and get the output for each:: @@ -81,7 +81,20 @@ Run a command on multiple servers and get the output for each:: result = api.command('whoami') for server in servers: - print result.stdout(server) + print(result.stdout(server)) + +Or alternatively:: + + api = Api(['a.example.org', 'b.example.org']) + results = api.command('whoami') + + for server, result in results['contacted'].items(): + if 'stdout' in result: + print(server, result['stdout']) + +The latter is more robust for optional result components, since not +every server's result may contain it. + Which Modules are Available? ---------------------------- diff --git a/scripts/generate_module_hints.py b/scripts/generate_module_hints.py index c23b932..670109a 100755 --- a/scripts/generate_module_hints.py +++ b/scripts/generate_module_hints.py @@ -389,6 +389,12 @@ def write_return_type(returns: dict[str, Any] | None) -> None: return_type += '[Incomplete]' elif return_type == 'dict': return_type += '[str, Incomplete]' + elif return_type == 'complex': + # TODO: This seems to be more or less an alias to dict + # but it contains a schema for the contents. If it + # is always dict, then try to merge this with dict + # and generate a TypedDict using `contains`. + return_type = 'Incomplete' suffix = ' # type:ignore[override]' if name == 'values' else '' if len(name) + len(return_type) + len(suffix) > 33: # signature doesn't fit on one line diff --git a/src/suitable/_module_types.py b/src/suitable/_module_types.py index 73152ea..f4783ac 100644 --- a/src/suitable/_module_types.py +++ b/src/suitable/_module_types.py @@ -1467,7 +1467,7 @@ class PackageFactsResults(RunnerResults): """ - def ansible_facts(self, server: str | None = None) -> complex: + def ansible_facts(self, server: str | None = None) -> Incomplete: """ Facts to add to ansible_facts. @@ -1763,7 +1763,7 @@ class ServiceFactsResults(RunnerResults): """ - def ansible_facts(self, server: str | None = None) -> complex: + def ansible_facts(self, server: str | None = None) -> Incomplete: """ Facts to add to ansible_facts about the services on the system. @@ -2072,7 +2072,7 @@ class SysvinitResults(RunnerResults): """ - def results(self, server: str | None = None) -> complex: + def results(self, server: str | None = None) -> Incomplete: """ results from actions taken. @@ -3126,7 +3126,7 @@ def stdout_lines(self, server: str | None = None) -> list[Incomplete]: """ return self.acquire(server, 'stdout_lines') - def output(self, server: str | None = None) -> complex: + def output(self, server: str | None = None) -> Incomplete: """ Based on the value of display option will return either the set of transformed XML to JSON format from the RPC response with type dict or @@ -3177,7 +3177,7 @@ def stdout_lines(self, server: str | None = None) -> list[Incomplete]: """ return self.acquire(server, 'stdout_lines') - def output(self, server: str | None = None) -> complex: + def output(self, server: str | None = None) -> Incomplete: """ Based on the value of display option will return either the set of transformed XML to JSON format from the RPC response with type dict or @@ -3528,7 +3528,7 @@ def undefined_zones(self, server: str | None = None) -> list[Incomplete]: """ return self.acquire(server, 'undefined_zones') - def firewalld_info(self, server: str | None = None) -> complex: + def firewalld_info(self, server: str | None = None) -> Incomplete: """ Returns various information about firewalld configuration. @@ -3590,7 +3590,7 @@ class RhelFactsResults(RunnerResults): """ - def ansible_facts(self, server: str | None = None) -> complex: + def ansible_facts(self, server: str | None = None) -> Incomplete: """ Relevant Ansible Facts. @@ -4340,7 +4340,7 @@ def exitcode(self, server: str | None = None) -> str: """ return self.acquire(server, 'exitcode') - def feature_result(self, server: str | None = None) -> complex: + def feature_result(self, server: str | None = None) -> Incomplete: """ List of features that were installed or removed. @@ -4411,7 +4411,7 @@ def matched(self, server: str | None = None) -> int: """ return self.acquire(server, 'matched') - def files(self, server: str | None = None) -> complex: + def files(self, server: str | None = None) -> Incomplete: """ Information on the files/folders that match the criteria returned as a list of dictionary elements for each file matched. The entries are @@ -4764,7 +4764,7 @@ class WinPowershellResults(RunnerResults): """ - def result(self, server: str | None = None) -> complex: + def result(self, server: str | None = None) -> Incomplete: """ The values that were set by `$Ansible.Result` in the script. @@ -5276,7 +5276,7 @@ def changed(self, server: str | None = None) -> bool: """ return self.acquire(server, 'changed') - def stat(self, server: str | None = None) -> complex: + def stat(self, server: str | None = None) -> Incomplete: """ dictionary containing all the stat data. @@ -5716,7 +5716,7 @@ def privileges(self, server: str | None = None) -> dict[str, Incomplete]: """ return self.acquire(server, 'privileges') - def label(self, server: str | None = None) -> complex: + def label(self, server: str | None = None) -> Incomplete: """ The mandatory label set to the logon session. @@ -5750,7 +5750,7 @@ def groups(self, server: str | None = None) -> list[Incomplete]: """ return self.acquire(server, 'groups') - def account(self, server: str | None = None) -> complex: + def account(self, server: str | None = None) -> Incomplete: """ The running account SID details. diff --git a/src/suitable/module_runner.py b/src/suitable/module_runner.py index 6ccc865..5cdc0c7 100644 --- a/src/suitable/module_runner.py +++ b/src/suitable/module_runner.py @@ -172,7 +172,7 @@ def get_module_args( args_str = ' '.join(args).replace('=', '\\=') kwargs_str = ' '.join( - '{}="{}"'.format(k.rstrip('_'), v.replace('"', '\\"')) + '{}="{}"'.format(k, v.replace('"', '\\"')) for k, v in kwargs.items() ) @@ -189,6 +189,13 @@ def execute(self, *args: Any, **kwargs: Any) -> RunnerResults: if set_global_context: set_global_context(self.api.options) + # translate parameters that use a reserved keyword + # TODO: For now async is the only one we know about + # but there may be other ones + if 'async_' in kwargs: + # with conflicts prefer the real name + kwargs.setdefault('async', kwargs.pop('async_')) + # legacy key=value pairs shorthand approach module_args: dict[str, Any] | str if args: @@ -392,4 +399,4 @@ def evaluate_results( server: result for server, result in callback.unreachable.items() } - }) + }, dry_run=self.api.options.check) diff --git a/src/suitable/runner_results.py b/src/suitable/runner_results.py index c82af18..68d8389 100644 --- a/src/suitable/runner_results.py +++ b/src/suitable/runner_results.py @@ -34,23 +34,39 @@ class RunnerResults(_Base): """ - def __init__(self, results: _RunnerResults) -> None: + def __init__(self, results: _RunnerResults, dry_run: bool = False) -> None: + self.dry_run = dry_run self.update(results) # type:ignore[arg-type] def __getattr__(self, key: str) -> ResultsCallback: return lambda server=None: self.acquire(server, key) def acquire(self, server: str | None, key: str) -> Any: + contacted = self['contacted'] # if no server is given and exactly one contacted server exists # return the value of said server directly - if server is None and len(self['contacted']) == 1: - server = next((k for k in self['contacted'].keys()), None) - - if server not in self['contacted']: + if server is None: + if len(contacted) == 1: + server = next((k for k in contacted.keys()), None) + elif contacted: + raise ValueError( + "When contacting multiple servers you need to " + "specify which server's result you want" + ) + elif self.dry_run: + raise ValueError('Results are not available in dry run') + elif (unreachable := self['unreachable']): + raise ValueError( + f"{', '.join(unreachable)} could not be contacted" + ) + + if server not in contacted: + if self.dry_run: + raise ValueError('Results are not available in dry run') raise KeyError(f"{server} could not be contacted") - if key not in self['contacted'][server]: - raise AttributeError + if key not in (result := contacted[server]): + raise AttributeError(key) - return self['contacted'][server][key] + return result[key] diff --git a/tests/test_api.py b/tests/test_api.py index af831c3..b396215 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -81,6 +81,16 @@ def test_results(): result.rc('localhost') +def test_results_dry_run(): + result = Api('localhost', dry_run=True).command('whoami') + assert not result['contacted'] + with pytest.raises(ValueError, match=r'not available in dry run'): + result.rc() + + with pytest.raises(ValueError, match=r'not available in dry run'): + result.rc('localhost') + + @pytest.mark.parametrize("server", ('localhost',)) def test_results_single_server(server): result = Api(server).command('whoami') @@ -92,15 +102,17 @@ def test_results_multiple_servers(): result = RunnerResults({ 'contacted': { 'web.seantis.dev': {'rc': 0}, - 'db.seantis.dev': {'rc': 1} + 'db.seantis.dev': {'rc': 1}, + 'buggy.result.dev': {}, } }) - with pytest.raises(KeyError): - result.rc() - assert result.rc('web.seantis.dev') == 0 assert result.rc('db.seantis.dev') == 1 + with pytest.raises(AttributeError, match=r'rc'): + result.rc('buggy.result.dev') + with pytest.raises(ValueError, match=r'When contacting multiple'): + result.rc() @pytest.mark.parametrize("server", (('localhost', 'localhost:22'),)) @@ -109,6 +121,8 @@ def test_whoami_multiple_servers(server): results = host.command('whoami') assert results.rc(server[0]) == 0 assert results.rc(server[1]) == 0 + with pytest.raises(ValueError, match=r'When contacting multiple'): + results.rc() def test_non_scalar_parameter():