diff --git a/README.md b/README.md index 828026b7..de126309 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ # AWS CloudFormation CLI -The CloudFormation CLI (cfn) allows you to author your own resource providers and modules that can be used by CloudFormation. +The CloudFormation CLI (cfn) allows you to author your own resource providers, hooks, and modules that can be used by CloudFormation. ## Usage @@ -12,7 +12,7 @@ Primary documentation for the CloudFormation CLI can be found at the [AWS Docume ### Installation -This tool can be installed using [pip](https://pypi.org/project/pip/) from the Python Package Index (PyPI). It requires Python 3. For resource types, the tool requires at least one language plugin. Language plugins are not needed to create a module type. The language plugins are also available on PyPI and as such can be installed all at once: +This tool can be installed using [pip](https://pypi.org/project/pip/) from the Python Package Index (PyPI). It requires Python 3. For resource and hook types, the tool requires at least one language plugin. Language plugins are not needed to create a module type. The language plugins are also available on PyPI and as such can be installed all at once: ```bash pip install cloudformation-cli cloudformation-cli-java-plugin cloudformation-cli-go-plugin cloudformation-cli-python-plugin cloudformation-cli-typescript-plugin @@ -38,7 +38,7 @@ cfn generate ### Command: submit -To register a resource provider or module in your account, use the `submit` command. +To register a resource provider, module, or hook in your account, use the `submit` command. ```bash cfn submit @@ -65,7 +65,7 @@ Note: To use your type configuration in contract tests, you will need to save yo To validate the schema, use the `validate` command. -This command is automatically run whenever one attempts to submit a resource or module. Errors will prevent you from submitting your resource/module. 
Module fragments will additionally be validated via [`cfn-lint`](https://github.com/aws-cloudformation/cfn-python-lint/) (but resulting warnings will not cause this step to fail). +This command is automatically run whenever one attempts to submit a resource, module, or hook. Errors will prevent you from submitting your resource/module/hook. Module fragments will additionally be validated via [`cfn-lint`](https://github.com/aws-cloudformation/cfn-python-lint/) (but resulting warnings will not cause this step to fail). ```bash cfn validate ``` @@ -99,7 +99,7 @@ pip install -e . -r requirements.txt pre-commit install ``` -If you're creating a resource type, you will also need to install a language plugin, such as [the Java language plugin](https://github.com/aws-cloudformation/cloudformation-cli-java-plugin), also via `pip install`. For example, assuming the plugin is checked out in the same parent directory as this repository: +If you're creating a resource or hook type, you will also need to install a language plugin, such as [the Java language plugin](https://github.com/aws-cloudformation/cloudformation-cli-java-plugin), also via `pip install`. 
For example, assuming the plugin is checked out in the same parent directory as this repository: ```bash pip install -e ../cloudformation-cli-java-plugin @@ -128,6 +128,7 @@ Plugins must provide the same interface as `LanguagePlugin` (in `plugin_base.py` ### Supported plugins +#### Resource Types Supported Plugins | Language | Status | Github | PyPI | | -------- | ----------------- | ----------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------- | | Java | Available | [cloudformation-cli-java-plugin](https://github.com/aws-cloudformation/cloudformation-cli-java-plugin/) | [cloudformation-cli-java-plugin](https://pypi.org/project/cloudformation-cli-java-plugin/) | @@ -135,6 +136,12 @@ Plugins must provide the same interface as `LanguagePlugin` (in `plugin_base.py` | Python | Available | [cloudformation-cli-python-plugin](https://github.com/aws-cloudformation/cloudformation-cli-python-plugin/) | [cloudformation-cli-python-plugin](https://pypi.org/project/cloudformation-cli-python-plugin/) | | TypeScript| Available | [cloudformation-cli-typescript-plugin](https://github.com/aws-cloudformation/cloudformation-cli-typescript-plugin/) | [cloudformation-cli-typescript-plugin](https://pypi.org/project/cloudformation-cli-typescript-plugin/) | +#### Hook Types Supported Plugins +| Language | Status | Github | PyPI | +| -------- | ----------------- | ----------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------- | +| Java | Available | [cloudformation-cli-java-plugin](https://github.com/aws-cloudformation/cloudformation-cli-java-plugin/) | [cloudformation-cli-java-plugin](https://pypi.org/project/cloudformation-cli-java-plugin/) | +| Python | Available | 
class ContractPlugin:
    """Pytest plugin that exposes the configured contract-test clients as fixtures.

    One of ``resource_client`` / ``hook_client`` is present in *plugin_clients*
    depending on whether a RESOURCE or HOOK type is under test; requesting the
    fixture for the missing kind raises ValueError.
    """

    def __init__(self, plugin_clients):
        # Refuse to start with nothing configured at all.
        if not plugin_clients:
            raise RuntimeError("No plugin clients are set up")
        self._plugin_clients = plugin_clients

    @pytest.fixture(scope="module")
    def resource_client(self):
        """Return the RESOURCE-type client; fail fast when absent or mistyped."""
        client = self._plugin_clients.get("resource_client")
        if not isinstance(client, ResourceClient):
            raise ValueError("Contract plugin client not setup for RESOURCE type")
        return client

    @pytest.fixture(scope="module")
    def hook_client(self):
        """Return the HOOK-type client; fail fast when absent or mistyped."""
        client = self._plugin_clients.get("hook_client")
        if not isinstance(client, HookClient):
            raise ValueError("Contract plugin client not setup for HOOK type")
        return client
def override_target_properties(document, overrides):
    """Return a copy of *document* with per-section overrides applied.

    Each top-level key (e.g. "resourceProperties", "previousResourceProperties")
    is overridden independently through ``override_properties``.
    """
    overridden = dict(document)
    for key, value in document.items():
        overridden[key] = override_properties(value, overrides.get(key, {}))
    return overridden


class HookClient:  # pylint: disable=too-many-instance-attributes
    """Drives a hook type's handlers for contract testing.

    Invocation payloads are built from the hook schema plus generated (or
    user-supplied) target models, then executed either against a Lambda
    (-compatible) endpoint or inside a docker container, and the responses
    are asserted against the expected ``HookStatus``.
    """

    def __init__(
        self,
        function_name,
        endpoint,
        region,
        schema,
        overrides,
        inputs=None,
        role_arn=None,
        timeout_in_seconds="60",  # string default; converted with int() below
        type_name=None,
        log_group_name=None,
        log_role_arn=None,
        docker_image=None,
        executable_entrypoint=None,
        target_info=None,
    ):  # pylint: disable=too-many-arguments
        # NOTE(review): this first assignment is redundant — _schema is reset to
        # None and then repopulated by _update_schema(schema) further down;
        # confirm before removing.
        self._schema = schema
        self._session = create_sdk_session(region)
        self._role_arn = role_arn
        self._type_name = type_name
        self._log_group_name = log_group_name
        self._log_role_arn = log_role_arn
        self.region = region
        self.account = get_account(
            self._session,
            get_temporary_credentials(self._session, LOWER_CAMEL_CRED_KEYS, role_arn),
        )
        self._function_name = function_name
        # A plain-HTTP endpoint means a local (e.g. SAM) emulator: skip TLS and
        # request signing entirely.
        if endpoint.startswith("http://"):
            self._client = self._session.client(
                "lambda",
                endpoint_url=endpoint,
                use_ssl=False,
                verify=False,
                config=Config(
                    signature_version=UNSIGNED,
                    # needs to be long if docker is running on a slow machine
                    read_timeout=5 * 60,
                    retries={"max_attempts": 0},
                    region_name=self._session.region_name,
                ),
            )
        else:
            self._client = self._session.client("lambda", endpoint_url=endpoint)

        self._schema = None
        self._configuration_schema = None
        self._overrides = overrides
        self._update_schema(schema)
        self._inputs = inputs
        self._timeout_in_seconds = int(timeout_in_seconds)
        self._docker_image = docker_image
        self._docker_client = docker.from_env() if self._docker_image else None
        self._executable_entrypoint = executable_entrypoint
        self._target_info = self._setup_target_info(target_info)

    @staticmethod
    def _setup_target_info(hook_target_info):
        """Attach a hypothesis generation strategy to each target's info dict.

        Returns the (falsy) input unchanged when no target info was provided.
        """
        if not hook_target_info:
            return hook_target_info

        # imported here to avoid hypothesis being loaded before pytest is loaded
        from .resource_generator import ResourceGenerator

        target_info = dict(hook_target_info)
        for _target, info in target_info.items():

            # make a copy so the original schema is never modified
            target_schema = json.loads(json.dumps(info["Schema"]))
            info["SchemaStrategy"] = ResourceGenerator(
                target_schema
            ).generate_schema_strategy(target_schema)
        return target_info

    def _update_schema(self, schema):
        """Cache *schema* and its typeConfiguration; rebuild the jinja env."""
        # TODO: resolve $ref
        self.env = Environment(
            trim_blocks=True,
            lstrip_blocks=True,
            keep_trailing_newline=True,
            loader=PackageLoader(__name__, "templates/"),
            autoescape=select_autoescape(["html", "htm", "xml", "md"]),
        )
        self._schema = schema
        self._configuration_schema = schema.get("typeConfiguration")

    def get_hook_type_name(self):
        """Return the explicit type name override, or the schema's typeName."""
        return self._type_name if self._type_name else self._schema["typeName"]

    def get_handler_targets(self, invocation_point):
        """Return the targetNames wired to *invocation_point*, or an empty set."""
        try:
            handlers = self._schema["handlers"]
            handler = handlers[generate_handler_name(invocation_point)]
            return handler["targetNames"]
        except KeyError:
            return set()

    @staticmethod
    def assert_in_progress(status, response):
        """Assert an IN_PROGRESS event shape; return its callback delay (seconds)."""
        assert status == HookStatus.IN_PROGRESS, "status should be IN_PROGRESS"
        assert (
            response.get("errorCode", 0) == 0
        ), "IN_PROGRESS events should have no error code set"
        assert (
            response.get("result") is None
        ), "IN_PROGRESS events should have no result"

        return response.get("callbackDelaySeconds", 0)

    @staticmethod
    def assert_success(status, response):
        """Assert a terminal SUCCESS event: no error code, no callback delay."""
        assert status == HookStatus.SUCCESS, "status should be SUCCESS"
        assert (
            response.get("errorCode", 0) == 0
        ), "SUCCESS events should have no error code set"
        assert (
            response.get("callbackDelaySeconds", 0) == 0
        ), "SUCCESS events should have no callback delay"

    @staticmethod
    def assert_failed(status, response):
        """Assert a terminal FAILED event; return the parsed HandlerErrorCode."""
        assert status == HookStatus.FAILED, "status should be FAILED"
        assert "errorCode" in response, "FAILED events must have an error code set"
        # raises a KeyError if the error code is invalid
        error_code = HandlerErrorCode[response["errorCode"]]
        assert (
            response.get("callbackDelaySeconds", 0) == 0
        ), "FAILED events should have no callback delay"
        assert (
            response.get("message") is not None
        ), "FAILED events should have a message"

        return error_code

    @staticmethod
    # pylint: disable=R0913
    def make_request(
        target_name,
        hook_type_name,
        account,
        invocation_point,
        creds,
        log_group_name,
        log_creds,
        token,
        target_model,
        hook_type_version="00000001",
        target_type="RESOURCE",
        callback_context=None,
        type_configuration=None,
        **kwargs,
    ):
        """Assemble the raw hook-invocation request payload as a dict.

        *token* doubles as clientRequestToken, stackId and targetLogicalId.
        Credentials may arrive pre-serialized (str) or as dicts to be dumped.
        """
        request_body = {
            "requestData": {
                "callerCredentials": creds
                if isinstance(creds, str)
                else json.dumps(creds),
                "targetName": target_name,
                "targetType": target_type,
                "targetLogicalId": token,
                "targetModel": target_model,
            },
            "requestContext": {"callbackContext": callback_context},
            "hookTypeName": hook_type_name,
            "hookTypeVersion": hook_type_version,
            "clientRequestToken": token,
            "stackId": token,
            "awsAccountId": account,
            "actionInvocationPoint": invocation_point,
            "hookModel": type_configuration,
            **kwargs,
        }
        # Provider logging data is only attached when both pieces are present.
        if log_group_name and log_creds:
            request_body["requestData"]["providerLogGroupName"] = log_group_name
            request_body["requestData"]["providerCredentials"] = (
                log_creds if isinstance(log_creds, str) else json.dumps(log_creds)
            )
        return request_body

    def _generate_target_example(self, target):
        """Draw one generated model for *target*, or {} without target info."""
        if not self._target_info:
            return {}

        return self._target_info.get(target).get("SchemaStrategy").example()

    def _generate_target_model(self, target, invocation_point):
        """Build the target model for a request: inputs file, else generated.

        For "INVALID_*" invocation points a point-specific entry is preferred,
        falling back to the generic "INVALID" entry.
        """
        if self._inputs:
            if "INVALID" in invocation_point:
                try:
                    return self._inputs[invocation_point][target]
                except KeyError:
                    return self._inputs["INVALID"][target]
            return self._inputs[invocation_point][target]

        target_model = {"resourceProperties": self._generate_target_example(target)}
        # UPDATE hooks also receive the pre-update state of the target.
        if "UPDATE_PRE_PROVISION" in invocation_point:
            target_model["previousResourceProperties"] = self._generate_target_example(
                target
            )

        if "INVALID" in invocation_point:
            try:
                return override_target_properties(
                    target_model, self._overrides[invocation_point].get(target, {})
                )
            except KeyError:
                return override_target_properties(
                    target_model, self._overrides.get("INVALID", {}).get(target, {})
                )

        return override_target_properties(
            target_model, self._overrides.get(invocation_point, {}).get(target, {})
        )

    def generate_request(self, target, invocation_point):
        """Build a full, valid request payload for *target*."""
        target_model = self._generate_target_model(target, invocation_point.name)
        return self._make_payload(invocation_point, target, target_model)

    def generate_invalid_request(self, target, invocation_point):
        """Build a full request payload using the INVALID_* model for *target*."""
        target_model = self._generate_target_model(
            target, f"INVALID_{invocation_point.name}"
        )
        return self._make_payload(invocation_point, target, target_model)

    def generate_request_example(self, target, invocation_point):
        """Return (invocation_point, target, target_model) for a valid request."""
        request = self.generate_request(target, invocation_point)
        target_model = request["requestData"]["targetModel"]

        return invocation_point, target, target_model

    def generate_invalid_request_example(self, target, invocation_point):
        """Return (invocation_point, target, target_model) for an invalid request."""
        request = self.generate_invalid_request(target, invocation_point)
        target_model = request["requestData"]["targetModel"]

        return invocation_point, target, target_model

    def generate_request_examples(self, invocation_point):
        """Valid request examples for every target of *invocation_point*."""
        return [
            self.generate_request_example(target, invocation_point)
            for target in self.get_handler_targets(invocation_point)
        ]

    def generate_invalid_request_examples(self, invocation_point):
        """Invalid request examples for every target of *invocation_point*."""
        return [
            self.generate_invalid_request_example(target, invocation_point)
            for target in self.get_handler_targets(invocation_point)
        ]

    def generate_all_request_examples(self):
        """Map every HookInvocationPoint to its list of valid request examples."""
        examples = {}
        for invoke_point in HookInvocationPoint:
            examples[invoke_point] = self.generate_request_examples(invoke_point)
        return examples

    @staticmethod
    def generate_token():
        """Return a fresh UUID4 string, used as the request/stack/logical token."""
        return str(uuid4())

    @staticmethod
    def is_update_invocation_point(invocation_point):
        """True for invocation points that carry previous state (UPDATE)."""
        return invocation_point in (HookInvocationPoint.UPDATE_PRE_PROVISION,)

    def assert_time(self, start_time, end_time, action):
        """Assert the handler returned within the configured timeout."""
        timeout_in_seconds = self._timeout_in_seconds
        assert end_time - start_time <= timeout_in_seconds, (
            "Handler %r timed out." % action
        )

    def _make_payload(
        self,
        invocation_point,
        target,
        target_model,
        type_configuration=None,
        **kwargs,
    ):
        """Build a request payload with freshly-minted temporary credentials."""
        return self.make_request(
            target,
            self.get_hook_type_name(),
            self.account,
            invocation_point,
            get_temporary_credentials(
                self._session, LOWER_CAMEL_CRED_KEYS, self._role_arn
            ),
            self._log_group_name,
            get_temporary_credentials(
                self._session, LOWER_CAMEL_CRED_KEYS, self._log_role_arn
            ),
            self.generate_token(),
            target_model,
            type_configuration=type_configuration,
            **kwargs,
        )

    def _call(self, payload):
        """Execute one invocation (docker or Lambda) and return the parsed JSON.

        Only a credential-free subset of the payload is logged.
        """
        payload_to_log = {
            "hookTypeName": payload["hookTypeName"],
            "actionInvocationPoint": payload["actionInvocationPoint"],
            "requestData": {
                "targetName": payload["requestData"]["targetName"],
                "targetLogicalId": payload["requestData"]["targetLogicalId"],
                "targetModel": payload["requestData"]["targetModel"],
            },
            "awsAccountId": payload["awsAccountId"],
            "clientRequestToken": payload["clientRequestToken"],
        }

        LOG.debug(
            "Sending request\n%s",
            json.dumps(payload_to_log, ensure_ascii=False, indent=2),
        )
        payload = json.dumps(payload, ensure_ascii=False, indent=2)
        if self._docker_image:
            if not self._executable_entrypoint:
                raise InvalidProjectError(
                    "executableEntrypoint not set in .rpdk-config. "
                    "Have you run cfn generate?"
                )
            result = (
                self._docker_client.containers.run(
                    self._docker_image,
                    self._executable_entrypoint + " '" + payload + "'",
                    environment={"AWS_REGION": self.region},
                )
                .decode()
                .strip()
            )
            LOG.debug("=== Handler execution logs ===")
            LOG.debug(result)
            # The container interleaves logs with the response; the response is
            # delimited by the start/end markers below.
            # pylint: disable=W1401
            regex = "__CFN_HOOK_START_RESPONSE__([\s\S]*)__CFN_HOOK_END_RESPONSE__"  # noqa: W605,B950 # pylint: disable=C0301
            payload = json.loads(re.search(regex, result).group(1))
        else:
            result = self._client.invoke(
                FunctionName=self._function_name, Payload=payload.encode("utf-8")
            )

            try:
                payload = json.load(result["Payload"])
            except json.decoder.JSONDecodeError as json_error:
                LOG.debug("Received invalid response\n%s", result["Payload"])
                raise ValueError(
                    "Handler Output is not a valid JSON document"
                ) from json_error

        LOG.debug("Received response\n%s", json.dumps(payload, indent=2))
        return payload

    # pylint: disable=R0913
    def call_and_assert(
        self,
        invocation_point,
        assert_status,
        target,
        target_model,
        **kwargs,
    ):
        """Invoke the hook and assert the terminal status.

        Returns (status, response, error_code); error_code is None on SUCCESS.
        Only SUCCESS/FAILED are accepted as expectations.
        """
        if assert_status not in [HookStatus.SUCCESS, HookStatus.FAILED]:
            raise ValueError("Assert status {} not supported.".format(assert_status))

        status, response = self.call(invocation_point, target, target_model, **kwargs)
        if assert_status == HookStatus.SUCCESS:
            self.assert_success(status, response)
            error_code = None
        else:
            error_code = self.assert_failed(status, response)
        return status, response, error_code

    def call(
        self,
        invocation_point,
        target,
        target_model,
        **kwargs,
    ):
        """Invoke the hook, polling IN_PROGRESS callbacks to a terminal status."""
        request = self._make_payload(
            invocation_point,
            target,
            target_model,
            TypeConfiguration.get_hook_configuration(),
            **kwargs,
        )
        start_time = time.time()
        response = self._call(request)
        self.assert_time(start_time, time.time(), invocation_point)

        # this throws a KeyError if status isn't present, or if it isn't a valid status
        status = HookStatus[response["hookStatus"]]

        while status == HookStatus.IN_PROGRESS:
            callback_delay_seconds = self.assert_in_progress(status, response)
            time.sleep(callback_delay_seconds)

            request["requestContext"]["callbackContext"] = response.get(
                "callbackContext"
            )

            response = self._call(request)
            status = HookStatus[response["hookStatus"]]

        return status, response
class HookInvocationPoint(str, AutoName):
    """Exact point in the provisioning lifecycle where a hook handler runs.

    Mixes in str so members can be embedded directly in JSON request payloads.
    AutoName presumably sets each member's value to its name — TODO confirm.
    """

    CREATE_PRE_PROVISION = auto()
    UPDATE_PRE_PROVISION = auto()
    DELETE_PRE_PROVISION = auto()


class HookStatus(AutoName):
    """Overall status reported by a hook handler invocation.

    Mirrors OperationStatus above, with SUCCESS in place of its SUCCESS/FAILED
    naming for resources; IN_PROGRESS events carry a callback delay.
    """

    IN_PROGRESS = auto()
    SUCCESS = auto()
    FAILED = auto()
# Regex strategies used to generate example string values for schema
# properties that declare a "format" keyword.  These are deliberately loose
# placeholders for contract-test input generation, not strict validators.
STRING_FORMATS = {
    "arn": "^arn:aws(-(cn|gov))?:[a-z-]+:(([a-z]+-)+[0-9])?:([0-9]{12})?:[^.]+$",
    # NOTE(review): "(:[0-9]*)*" permits empty/repeated port segments — harmless
    # for generation, but confirm before reusing this pattern as a validator.
    "uri": "^(https?|ftp|file)://[0-9a-zA-Z]([-.\\w]*[0-9a-zA-Z])(:[0-9]*)*([?/#].*)?$",
}

# Presumably the default lower bound for unconstrained numeric generation —
# TODO confirm against the number-strategy code below.
NEG_INF = float("-inf")
# --- src/rpdk/core/contract/suite/hook/handler_pre_create.py ---
# Contract-test entry points for the CREATE_PRE_PROVISION invocation point.
import logging

import pytest

from rpdk.core.contract.interface import HookInvocationPoint
from rpdk.core.contract.suite.hook.hook_handler_commons import (
    test_hook_handlers_failed,
    test_hook_handlers_success,
)

LOG = logging.getLogger(__name__)

INVOCATION_POINT = HookInvocationPoint.CREATE_PRE_PROVISION


@pytest.mark.create_pre_provision
def contract_pre_create_success(hook_client):
    """Every generated CREATE_PRE_PROVISION request must end in SUCCESS."""
    test_hook_handlers_success(hook_client, INVOCATION_POINT)


@pytest.mark.create_pre_provision
def contract_pre_create_failed(hook_client):
    """Every invalid CREATE_PRE_PROVISION request must end in FAILED."""
    test_hook_handlers_failed(hook_client, INVOCATION_POINT)


# --- src/rpdk/core/contract/suite/hook/handler_pre_delete.py ---
# Contract-test entry points for the DELETE_PRE_PROVISION invocation point.
import logging

import pytest

from rpdk.core.contract.interface import HookInvocationPoint
from rpdk.core.contract.suite.hook.hook_handler_commons import (
    test_hook_handlers_failed,
    test_hook_handlers_success,
)

LOG = logging.getLogger(__name__)

INVOCATION_POINT = HookInvocationPoint.DELETE_PRE_PROVISION


@pytest.mark.delete_pre_provision
def contract_pre_delete_success(hook_client):
    """Every generated DELETE_PRE_PROVISION request must end in SUCCESS."""
    test_hook_handlers_success(hook_client, INVOCATION_POINT)


@pytest.mark.delete_pre_provision
def contract_pre_delete_failed(hook_client):
    """Every invalid DELETE_PRE_PROVISION request must end in FAILED."""
    test_hook_handlers_failed(hook_client, INVOCATION_POINT)


# --- src/rpdk/core/contract/suite/hook/handler_pre_update.py ---
# Contract-test entry points for the UPDATE_PRE_PROVISION invocation point.
import logging

import pytest

from rpdk.core.contract.interface import HookInvocationPoint
from rpdk.core.contract.suite.hook.hook_handler_commons import (
    test_hook_handlers_failed,
    test_hook_handlers_success,
)

LOG = logging.getLogger(__name__)

INVOCATION_POINT = HookInvocationPoint.UPDATE_PRE_PROVISION


@pytest.mark.update_pre_provision
def contract_pre_update_success(hook_client):
    """Every generated UPDATE_PRE_PROVISION request must end in SUCCESS."""
    test_hook_handlers_success(hook_client, INVOCATION_POINT)


@pytest.mark.update_pre_provision
def contract_pre_update_failed(hook_client):
    """Every invalid UPDATE_PRE_PROVISION request must end in FAILED."""
    test_hook_handlers_failed(hook_client, INVOCATION_POINT)
def test_hook_success(hook_client, invocation_point, target, target_model):
    """Exercise a non-UPDATE pre-provision handler and require SUCCESS.

    UPDATE invocation points carry previous state and must go through
    test_update_hook_success instead.
    """
    if HookClient.is_update_invocation_point(invocation_point):
        raise ValueError(
            "Invocation point {} not supported for this testing operation".format(
                invocation_point
            )
        )

    _, outcome, _ = hook_client.call_and_assert(
        invocation_point, HookStatus.SUCCESS, target, target_model
    )
    return outcome


def test_update_hook_success(hook_client, invocation_point, target, target_model):
    """Exercise an UPDATE pre-provision handler and require SUCCESS."""
    if not HookClient.is_update_invocation_point(invocation_point):
        raise ValueError(
            "Invocation point {} not supported for testing UPDATE hook operation".format(
                invocation_point
            )
        )

    _, outcome, _ = hook_client.call_and_assert(
        invocation_point, HookStatus.SUCCESS, target, target_model
    )
    return outcome


def test_hook_failed(hook_client, invocation_point, target, target_model=None):
    """Exercise a handler with an invalid model; require FAILED with a message."""
    _, outcome, failure_code = hook_client.call_and_assert(
        invocation_point, HookStatus.FAILED, target, target_model
    )
    assert outcome["message"]
    return outcome, failure_code


def test_hook_handlers_success(hook_client, invocation_point):
    """Run the success scenario for every target wired to *invocation_point*."""
    # Pick the scenario runner once, outside the loop.
    runner = (
        test_update_hook_success
        if HookClient.is_update_invocation_point(invocation_point)
        else test_hook_success
    )
    for _point, target, target_model in hook_client.generate_request_examples(
        invocation_point
    ):
        runner(hook_client, invocation_point, target, target_model)


def test_hook_handlers_failed(hook_client, invocation_point):
    """Run the failure scenario for every target wired to *invocation_point*."""
    for _point, target, target_model in hook_client.generate_invalid_request_examples(
        invocation_point
    ):
        test_hook_failed(hook_client, invocation_point, target, target_model)
import pytest

from rpdk.core.contract.suite.contract_asserts_commons import decorate


@decorate()
def response_does_not_contain_write_only_properties(resource_client, response):
    """Progress-event models must omit all writeOnlyProperties."""
    resource_client.assert_write_only_property_does_not_exist(response["resourceModel"])


@decorate()
def response_contains_resource_model_equal_updated_model(
    response, current_resource_model, update_resource_model
):
    """The returned model must equal current state merged with the update."""
    assert response["resourceModel"] == {
        **current_resource_model,
        **update_resource_model,
    }, "All properties specified in the update request MUST be present in the \
        model returned, and they MUST match exactly, with the exception of \
        properties defined as writeOnlyProperties in the resource schema"


@decorate()
def response_contains_primary_identifier(resource_client, response):
    """Every progress-event model must carry the primary identifier."""
    resource_client.assert_primary_identifier(
        resource_client.primary_identifier_paths, response["resourceModel"]
    )


@decorate()
def response_contains_unchanged_primary_identifier(
    resource_client, response, current_resource_model
):
    """The primary identifier must survive the operation unchanged."""
    assert resource_client.is_primary_identifier_equal(
        resource_client.primary_identifier_paths,
        current_resource_model,
        response["resourceModel"],
    ), "PrimaryIdentifier returned in every progress event must match \
        the primaryIdentifier passed into the request"


@decorate(after=False)
def skip_not_writable_identifier(resource_client):
    """Skip tests that require at least one writable identifier."""
    if not resource_client.has_only_writable_identifiers():
        pytest.skip("No writable identifiers. Skipping test.")


@decorate(after=False)
def skip_no_tagging(resource_client):
    """Skip tests when the schema declares no tagging metadata."""
    if not resource_client.contains_tagging_metadata():
        pytest.skip("Resource does not contain tagging metadata. Skipping test.")


@decorate(after=False)
def skip_not_taggable(resource_client):
    """Skip tests when the resource is not taggable."""
    if not resource_client.is_taggable():
        pytest.skip("Resource is not taggable. Skipping test.")


@decorate(after=False)
def skip_not_tag_updatable(resource_client):
    """Skip tests when tags cannot be updated in place."""
    if not resource_client.is_tag_updatable():
        pytest.skip("Resource is not tagUpdatable. Skipping test.")
to avoid issues # when being loaded by pytest from rpdk.core.contract.interface import Action, OperationStatus -from rpdk.core.contract.suite.contract_asserts import ( +from rpdk.core.contract.suite.resource.contract_asserts import ( skip_no_tagging, skip_not_taggable, skip_not_writable_identifier, ) -from rpdk.core.contract.suite.handler_commons import ( +from rpdk.core.contract.suite.resource.handler_commons import ( test_create_failure_if_repeat_writeable_id, test_create_success, test_delete_success, diff --git a/src/rpdk/core/contract/suite/handler_delete.py b/src/rpdk/core/contract/suite/resource/handler_delete.py similarity index 98% rename from src/rpdk/core/contract/suite/handler_delete.py rename to src/rpdk/core/contract/suite/resource/handler_delete.py index 34635cac..6ea3c0c0 100644 --- a/src/rpdk/core/contract/suite/handler_delete.py +++ b/src/rpdk/core/contract/suite/resource/handler_delete.py @@ -7,7 +7,7 @@ # WARNING: contract tests should use fully qualified imports to avoid issues # when being loaded by pytest from rpdk.core.contract.interface import Action, HandlerErrorCode, OperationStatus -from rpdk.core.contract.suite.handler_commons import ( +from rpdk.core.contract.suite.resource.handler_commons import ( test_create_success, test_delete_failure_not_found, test_input_equals_output, diff --git a/src/rpdk/core/contract/suite/handler_misc.py b/src/rpdk/core/contract/suite/resource/handler_misc.py similarity index 100% rename from src/rpdk/core/contract/suite/handler_misc.py rename to src/rpdk/core/contract/suite/resource/handler_misc.py diff --git a/src/rpdk/core/contract/suite/handler_update.py b/src/rpdk/core/contract/suite/resource/handler_update.py similarity index 96% rename from src/rpdk/core/contract/suite/handler_update.py rename to src/rpdk/core/contract/suite/resource/handler_update.py index 5bfcff38..0a5df956 100644 --- a/src/rpdk/core/contract/suite/handler_update.py +++ b/src/rpdk/core/contract/suite/resource/handler_update.py @@ 
-6,11 +6,11 @@ # WARNING: contract tests should use fully qualified imports to avoid issues # when being loaded by pytest from rpdk.core.contract.interface import Action, OperationStatus -from rpdk.core.contract.suite.contract_asserts import ( +from rpdk.core.contract.suite.resource.contract_asserts import ( skip_no_tagging, skip_not_tag_updatable, ) -from rpdk.core.contract.suite.handler_commons import ( +from rpdk.core.contract.suite.resource.handler_commons import ( test_input_equals_output, test_model_in_list, test_read_success, diff --git a/src/rpdk/core/contract/suite/handler_update_invalid.py b/src/rpdk/core/contract/suite/resource/handler_update_invalid.py similarity index 93% rename from src/rpdk/core/contract/suite/handler_update_invalid.py rename to src/rpdk/core/contract/suite/resource/handler_update_invalid.py index ed4fe068..6a0ca101 100644 --- a/src/rpdk/core/contract/suite/handler_update_invalid.py +++ b/src/rpdk/core/contract/suite/resource/handler_update_invalid.py @@ -6,7 +6,7 @@ # WARNING: contract tests should use fully qualified imports to avoid issues # when being loaded by pytest from rpdk.core.contract.interface import Action, HandlerErrorCode, OperationStatus -from rpdk.core.contract.suite.contract_asserts import failed_event +from rpdk.core.contract.suite.contract_asserts_commons import failed_event @pytest.mark.update diff --git a/src/rpdk/core/contract/type_configuration.py b/src/rpdk/core/contract/type_configuration.py index 3dafb8ed..9f61a995 100644 --- a/src/rpdk/core/contract/type_configuration.py +++ b/src/rpdk/core/contract/type_configuration.py @@ -38,3 +38,16 @@ def get_type_configuration(): TYPE_CONFIGURATION_FILE_PATH, ) return TypeConfiguration.TYPE_CONFIGURATION + + @staticmethod + def get_hook_configuration(): + type_configuration = TypeConfiguration.get_type_configuration() + if type_configuration: + try: + return type_configuration.get("CloudFormationConfiguration", {})[ + "HookConfiguration" + ]["Properties"] + except 
KeyError as e: + LOG.warning("Hook configuration is invalid") + raise InvalidProjectError("Hook configuration is invalid") from e + return type_configuration diff --git a/src/rpdk/core/data/examples/hook/sse.verification.v1.json b/src/rpdk/core/data/examples/hook/sse.verification.v1.json new file mode 100644 index 00000000..513bc732 --- /dev/null +++ b/src/rpdk/core/data/examples/hook/sse.verification.v1.json @@ -0,0 +1,38 @@ +{ + "typeName": "AWS::Example::SSEVerificationHook", + "description": "Example resource SSE (Server Side Encryption) verification hook", + "sourceUrl": "https://github.com/aws-cloudformation/example-sse-hook", + "documentationUrl": "https://github.com/aws-cloudformation/example-sse-hook/blob/master/README.md", + "typeConfiguration": { + "properties": { + "EncryptionAlgorithm": { + "description": "Encryption algorithm for SSE", + "default": "AES256", + "type": "string" + } + }, + "additionalProperties": false + }, + "required": [], + "handlers": { + "preCreate": { + "targetNames": [ + "My::Example::Resource" + ], + "permissions": [] + }, + "preUpdate": { + "targetNames": [ + "My::Example::Resource" + ], + "permissions": [] + }, + "preDelete": { + "targetNames": [ + "Other::Example::Resource" + ], + "permissions": [] + } + }, + "additionalProperties": false +} diff --git a/src/rpdk/core/data/examples/resource/initech.tps.report.v1.json b/src/rpdk/core/data/examples/resource/initech.tps.report.v1.json index 6a3ca511..84199ac2 100644 --- a/src/rpdk/core/data/examples/resource/initech.tps.report.v1.json +++ b/src/rpdk/core/data/examples/resource/initech.tps.report.v1.json @@ -42,7 +42,7 @@ "Value" ], "additionalProperties": false - } + } }, "properties": { "TPSCode": { @@ -94,7 +94,7 @@ "items": { "$ref": "#/definitions/Tag" } - } + } }, "additionalProperties": false, "required": [ diff --git a/src/rpdk/core/data/managed-upload-infrastructure.yaml b/src/rpdk/core/data/managed-upload-infrastructure.yaml index d2068f9d..a8d9e830 100644 --- 
a/src/rpdk/core/data/managed-upload-infrastructure.yaml +++ b/src/rpdk/core/data/managed-upload-infrastructure.yaml @@ -90,6 +90,7 @@ Resources: Principal: Service: - resources.cloudformation.amazonaws.com + - hooks.cloudformation.amazonaws.com Action: sts:AssumeRole Condition: StringEquals: diff --git a/src/rpdk/core/data/pytest-contract.ini b/src/rpdk/core/data/pytest-contract.ini index fe0198fd..cdc322ad 100644 --- a/src/rpdk/core/data/pytest-contract.ini +++ b/src/rpdk/core/data/pytest-contract.ini @@ -14,3 +14,9 @@ markers = read: read handler related tests. update: update handler related tests. list: list handler related tests. + create_pre_provision: preCreate handler related tests. + update_pre_provision: preUpdate handler related tests. + delete_pre_provision: preDelete handler related tests. + +filterwarnings = + ignore::hypothesis.errors.NonInteractiveExampleWarning:hypothesis diff --git a/src/rpdk/core/data/schema/provider.configuration.definition.schema.hooks.v1.json b/src/rpdk/core/data/schema/provider.configuration.definition.schema.hooks.v1.json new file mode 100644 index 00000000..4ea88ddd --- /dev/null +++ b/src/rpdk/core/data/schema/provider.configuration.definition.schema.hooks.v1.json @@ -0,0 +1,49 @@ +{ + "$schema": "http://json-schema.org/draft-07/schema#", + "$id": "https://schema.cloudformation.us-east-1.amazonaws.com/provider.configuration.definition.schema.hooks.v1.json", + "title": "CloudFormation Hook Provider Configuration Definition MetaSchema", + "description": "This schema validates a CloudFormation hook provider configuration definition.", + "type": "object", + "properties": { + "additionalProperties": { + "$comment": "All properties must be expressed in the schema - arbitrary inputs are not allowed", + "type": "boolean", + "const": false + }, + "deprecatedProperties": { + "$ref": "base.definition.schema.v1.json#/properties/deprecatedProperties" + }, + "allOf": { + "$ref": "base.definition.schema.v1.json#/definitions/schemaArray" + 
}, + "anyOf": { + "$ref": "base.definition.schema.v1.json#/definitions/schemaArray" + }, + "oneOf": { + "$ref": "base.definition.schema.v1.json#/definitions/schemaArray" + }, + "required": { + "$ref": "base.definition.schema.v1.json#/properties/required" + }, + "description": { + "$comment": "A short description of the hook configuration. This will be shown in the AWS CloudFormation console.", + "$ref": "base.definition.schema.v1.json#/properties/description" + }, + "properties": { + "type": "object", + "patternProperties": { + "(?!CloudFormation)^[A-Za-z0-9]{1,64}$": { + "$comment": "TypeConfiguration properties starting with `CloudFormation` are reserved for CloudFormation use", + "$ref": "base.definition.schema.v1.json#/definitions/properties" + } + }, + "minProperties": 0, + "additionalProperties": false + } + }, + "required": [ + "properties", + "additionalProperties" + ], + "additionalProperties": false +} diff --git a/src/rpdk/core/data/schema/provider.definition.schema.hooks.v1.json b/src/rpdk/core/data/schema/provider.definition.schema.hooks.v1.json new file mode 100644 index 00000000..6763d3b5 --- /dev/null +++ b/src/rpdk/core/data/schema/provider.definition.schema.hooks.v1.json @@ -0,0 +1,136 @@ +{ + "$schema": "http://json-schema.org/draft-07/schema#", + "$id": "https://schema.cloudformation.us-east-1.amazonaws.com/provider.definition.schema.hooks.v1.json", + "title": "CloudFormation Hook Provider Definition MetaSchema", + "description": "This schema validates a CloudFormation hook provider definition.", + "definitions": { + "handlerDefinition": { + "description": "Defines any execution operations which can be performed on this hook provider", + "type": "object", + "properties": { + "targetNames": { + "type": "array", + "items": { + "type": "string" + }, + "additionalItems": false + }, + "permissions": { + "type": "array", + "items": { + "type": "string" + }, + "additionalItems": false + } + }, + "additionalProperties": false, + "required": [ + 
"targetNames", + "permissions" + ] + } + }, + "type": "object", + "patternProperties": { + "^\\$id$": { + "$ref": "http://json-schema.org/draft-07/schema#/properties/$id" + } + }, + "properties": { + "$schema": { + "$ref": "base.definition.schema.v1.json#/properties/$schema" + }, + "type": { + "$comment": "Hook Type", + "type": "string", + "const": "HOOK" + }, + "typeName": { + "$comment": "Hook Type Identifier", + "examples": [ + "Organization::Service::Hook", + "AWS::EC2::Hook", + "Initech::TPS::Hook" + ], + "$ref": "base.definition.schema.v1.json#/properties/typeName" + }, + "$comment": { + "$ref": "base.definition.schema.v1.json#/properties/$comment" + }, + "title": { + "$ref": "base.definition.schema.v1.json#/properties/title" + }, + "description": { + "$comment": "A short description of the hook provider. This will be shown in the AWS CloudFormation console.", + "$ref": "base.definition.schema.v1.json#/properties/description" + }, + "sourceUrl": { + "$comment": "The location of the source code for this hook provider, to help interested parties submit issues or improvements.", + "examples": [ + "https://github.com/aws-cloudformation/aws-cloudformation-resource-providers-s3" + ], + "$ref": "base.definition.schema.v1.json#/properties/sourceUrl" + }, + "documentationUrl": { + "$comment": "A page with supplemental documentation. The property documentation in schemas should be able to stand alone, but this is an opportunity for e.g. 
rich examples or more guided documents.", + "examples": [ + "https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/CHAP_Using.html" + ], + "$ref": "base.definition.schema.v1.json#/definitions/httpsUrl" + }, + "additionalProperties": { + "$comment": "All properties of a hook must be expressed in the schema - arbitrary inputs are not allowed", + "$ref": "base.definition.schema.v1.json#/properties/additionalProperties" + }, + "definitions": { + "$ref": "base.definition.schema.v1.json#/properties/definitions" + }, + "handlers": { + "description": "Defines the provisioning operations which can be performed on this hook type", + "type": "object", + "properties": { + "preCreate": { + "$ref": "#/definitions/handlerDefinition" + }, + "preUpdate": { + "$ref": "#/definitions/handlerDefinition" + }, + "preDelete": { + "$ref": "#/definitions/handlerDefinition" + } + }, + "additionalProperties": false + }, + "remote": { + "description": "Reserved for CloudFormation use. A namespace to inline remote schemas.", + "$ref": "base.definition.schema.v1.json#/properties/remote" + }, + "deprecatedProperties": { + "description": "A list of JSON pointers to properties that have been deprecated by the underlying service provider. These properties are still accepted in create & update operations, however they may be ignored, or converted to a consistent model on application. Deprecated properties are not guaranteed to be present in read paths.", + "$ref": "base.definition.schema.v1.json#/definitions/jsonPointerArray" + }, + "required": { + "$ref": "base.definition.schema.v1.json#/properties/required" + }, + "allOf": { + "$ref": "base.definition.schema.v1.json#/definitions/schemaArray" + }, + "anyOf": { + "$ref": "base.definition.schema.v1.json#/definitions/schemaArray" + }, + "oneOf": { + "$ref": "base.definition.schema.v1.json#/definitions/schemaArray" + }, + "typeConfiguration": { + "description": "TypeConfiguration to set the configuration data for registry types. 
This configuration data is not passed through the hook properties in template. One of the possible use cases is configuring auth keys for 3P hook providers.", + "$ref": "provider.configuration.definition.schema.hooks.v1.json" + } + }, + "required": [ + "typeName", + "typeConfiguration", + "description", + "additionalProperties" + ], + "additionalProperties": false +} diff --git a/src/rpdk/core/data_loaders.py b/src/rpdk/core/data_loaders.py index 0c797773..069cfada 100644 --- a/src/rpdk/core/data_loaders.py +++ b/src/rpdk/core/data_loaders.py @@ -118,6 +118,13 @@ def make_resource_validator_with_additional_properties_check(): return make_validator(schema) +def make_hook_validator(): + schema = resource_json( + __name__, "data/schema/provider.definition.schema.hooks.v1.json" + ) + return make_validator(schema) + + def get_file_base_uri(file): try: name = file.name @@ -360,3 +367,53 @@ def load_resource_spec(resource_spec_file): # pylint: disable=R # noqa: C901 raise InternalError() from e return inlined + + +def load_hook_spec(hook_spec_file): # pylint: disable=R # noqa: C901 + """Load a hook definition from a file, and validate it.""" + try: + hook_spec = json.load(hook_spec_file) + except ValueError as e: + LOG.debug("Hook spec decode failed", exc_info=True) + raise SpecValidationError(str(e)) from e + + # TODO: Add schema validation after we have hook schema finalized + + if hook_spec.get("properties"): + raise SpecValidationError( + "Hook types do not support 'properties' directly. Properties must be specified in the 'typeConfiguration' section." 
+ ) + + validator = make_hook_validator() + try: + validator.validate(hook_spec) + except ValidationError as e: + LOG.debug("Hook spec validation failed", exc_info=True) + raise SpecValidationError(str(e)) from e + + blocked_handler_permissions = {"cloudformation:RegisterType"} + for handler in hook_spec.get("handlers", []): + for permission in hook_spec.get("handlers", [])[handler]["permissions"]: + if "cloudformation:*" in permission: + raise SpecValidationError( + f"Wildcards for cloudformation are not allowed for hook handler permissions: '{permission}'" + ) + + if permission in blocked_handler_permissions: + raise SpecValidationError( + f"Permission is not allowed for hook handler permissions: '{permission}'" + ) + + try: + base_uri = hook_spec["$id"] + except KeyError: + base_uri = get_file_base_uri(hook_spec_file) + + inliner = RefInliner(base_uri, hook_spec) + try: + inlined = inliner.inline() + except RefResolutionError as e: + LOG.debug("Hook spec validation failed", exc_info=True) + raise SpecValidationError(str(e)) from e + + return inlined diff --git a/src/rpdk/core/generate.py b/src/rpdk/core/generate.py index bc9a0f64..69a69fe1 100644 --- a/src/rpdk/core/generate.py +++ b/src/rpdk/core/generate.py @@ -9,10 +9,10 @@ LOG = logging.getLogger(__name__) -def generate(_args): +def generate(args): project = Project() project.load() - project.generate() + project.generate(args.endpoint_url, args.region, args.target_schemas) project.generate_docs() LOG.warning("Generated files for %s", project.type_name) @@ -22,3 +22,9 @@ def setup_subparser(subparsers, parents): # see docstring of this file parser = subparsers.add_parser("generate", description=__doc__, parents=parents) parser.set_defaults(command=generate) + + parser.add_argument("--endpoint-url", help="CloudFormation endpoint to use.") + parser.add_argument("--region", help="AWS Region to submit the type.") + parser.add_argument( + "--target-schemas", help="Path to target schemas.", nargs="*", default=[] + 
) diff --git a/src/rpdk/core/hook/__init__.py b/src/rpdk/core/hook/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/rpdk/core/hook/init_hook.py b/src/rpdk/core/hook/init_hook.py new file mode 100644 index 00000000..55dee3d2 --- /dev/null +++ b/src/rpdk/core/hook/init_hook.py @@ -0,0 +1,102 @@ +import logging +import re + +from rpdk.core.exceptions import WizardAbortError, WizardValidationError +from rpdk.core.plugin_registry import get_plugin_choices +from rpdk.core.utils.init_utils import input_with_validation, print_error + +LOG = logging.getLogger(__name__) +HOOK_TYPE_NAME_REGEX = r"^[a-zA-Z0-9]{2,64}::[a-zA-Z0-9]{2,64}::[a-zA-Z0-9]{2,64}$" + + +def init_hook(args, project): + if args.type_name: + try: + type_name = validate_type_name(args.type_name) + except WizardValidationError as error: + print_error(error) + type_name = input_typename() + else: + type_name = input_typename() + + if "language" in vars(args): + language = args.language.lower() + else: + language = input_language() + + settings = { + arg: getattr(args, arg) + for arg in vars(args) + if not callable(getattr(args, arg)) + } + + project.init_hook(type_name, language, settings) + project.generate(args.endpoint_url, args.region, args.target_schemas) + project.generate_docs() + + +def input_typename(): + type_name = input_with_validation( + "What's the name of your hook type?", + validate_type_name, + "\n(Organization::Service::Hook)", + ) + LOG.debug("Hook type identifier: %s", type_name) + return type_name + + +def input_language(): + # language/plugin + if validate_plugin_choice.max < 1: + LOG.critical("No language plugins found") + raise WizardAbortError() + + if validate_plugin_choice.max == 1: + language = validate_plugin_choice.choices[0] + LOG.warning("One language plugin found, defaulting to %s", language) + else: + language = input_with_validation( + validate_plugin_choice.message, validate_plugin_choice + ) + LOG.debug("Language plugin: %s", language) + return 
language + + +def validate_type_name(value): + match = re.match(HOOK_TYPE_NAME_REGEX, value) + if match: + return value + LOG.debug("'%s' did not match '%s'", value, HOOK_TYPE_NAME_REGEX) + raise WizardValidationError( + "Please enter a value matching '{}'".format(HOOK_TYPE_NAME_REGEX) + ) + + +class ValidatePluginChoice: + def __init__(self, choices): + self.choices = tuple(choices) + self.max = len(self.choices) + + pretty = "\n".join( + "[{}] {}".format(i, choice) for i, choice in enumerate(self.choices, 1) + ) + self.message = ( + "Select a language for code generation:\n" + + pretty + + "\n(enter an integer): " + ) + + def __call__(self, value): + try: + choice = int(value) + except ValueError as e: + raise WizardValidationError("Please enter an integer") from e + choice -= 1 + if choice < 0 or choice >= self.max: + raise WizardValidationError("Please select a choice") + return self.choices[choice] + + +validate_plugin_choice = ValidatePluginChoice( # pylint: disable=invalid-name + get_plugin_choices() +) diff --git a/src/rpdk/core/init.py b/src/rpdk/core/init.py index 7742df50..ca00d77d 100644 --- a/src/rpdk/core/init.py +++ b/src/rpdk/core/init.py @@ -9,9 +9,10 @@ from colorama import Fore, Style from .exceptions import WizardAbortError, WizardValidationError +from .hook.init_hook import init_hook from .module.init_module import init_module from .plugin_registry import get_parsers, get_plugin_choices -from .project import ARTIFACT_TYPE_MODULE, Project +from .project import ARTIFACT_TYPE_HOOK, ARTIFACT_TYPE_MODULE, Project from .resource.init_resource import init_resource from .utils.init_utils import init_artifact_type, validate_yes @@ -137,7 +138,9 @@ def init(args): if artifact_type == ARTIFACT_TYPE_MODULE: init_module(args, project) - # artifact type can only be module or resource at this point + elif artifact_type == ARTIFACT_TYPE_HOOK: + init_hook(args, project) + # artifact type can only be module, hook, or resource at this point else: 
init_resource(args, project) @@ -183,5 +186,16 @@ def setup_subparser(subparsers, parents): parser.add_argument( "-a", "--artifact-type", - help="Select the type of artifact (RESOURCE or MODULE)", + help="Select the type of artifact (RESOURCE or MODULE or HOOK)", + ) + + parser.add_argument("--endpoint-url", help="CloudFormation endpoint to use.") + + parser.add_argument("--region", help="AWS Region to submit the type.") + + parser.add_argument( + "--target-schemas", + help="Path to target schemas.", + default=[], + type=lambda s: [i.strip() for i in s.split(",")], ) diff --git a/src/rpdk/core/invoke.py b/src/rpdk/core/invoke.py index f8a7d329..4f99835c 100644 --- a/src/rpdk/core/invoke.py +++ b/src/rpdk/core/invoke.py @@ -9,21 +9,58 @@ from argparse import FileType from time import sleep -from .contract.interface import Action, OperationStatus +from .contract.hook_client import HookClient +from .contract.interface import Action, HookInvocationPoint, HookStatus, OperationStatus from .contract.resource_client import ResourceClient from .exceptions import SysExitRecommendedError -from .project import Project +from .project import ARTIFACT_TYPE_HOOK, ARTIFACT_TYPE_RESOURCE, Project from .test import _sam_arguments, _validate_sam_args LOG = logging.getLogger(__name__) -def invoke(args): - _validate_sam_args(args) - project = Project() - project.load() - - client = ResourceClient( +def get_payload_to_log(payload, artifact_type): + if artifact_type == ARTIFACT_TYPE_HOOK: + return { + "hookTypeName": payload["hookTypeName"], + "actionInvocationPoint": payload["actionInvocationPoint"], + "requestData": { + "targetName": payload["requestData"]["targetName"], + "targetLogicalId": payload["requestData"]["targetLogicalId"], + "targetModel": payload["requestData"]["targetModel"], + }, + "awsAccountId": payload["awsAccountId"], + "clientRequestToken": payload["clientRequestToken"], + } + + return { + "callbackContext": payload["callbackContext"], + "action": payload["action"], + 
"requestData": { + "resourceProperties": payload["requestData"]["resourceProperties"], + "previousResourceProperties": payload["requestData"][ + "previousResourceProperties" + ], + }, + "region": payload["region"], + "awsAccountId": payload["awsAccountId"], + "bearerToken": payload["bearerToken"], + } + + +def get_contract_client(args, project): + if project.artifact_type == ARTIFACT_TYPE_HOOK: + return HookClient( + args.function_name, + args.endpoint, + args.region, + project.schema, + {}, + executable_entrypoint=project.executable_entrypoint, + docker_image=args.docker_image, + ) + + return ResourceClient( args.function_name, args.endpoint, args.region, @@ -33,17 +70,45 @@ def invoke(args): docker_image=args.docker_image, ) - action = Action[args.action] + +def prepare_payload_for_reinvocation(payload, response, artifact_type): + if artifact_type == ARTIFACT_TYPE_RESOURCE: + payload["callbackContext"] = response.get("callbackContext") + + return payload + + +def invoke(args): + _validate_sam_args(args) + project = Project() + project.load() + + client = get_contract_client(args, project) + try: request = json.load(args.request) except ValueError as e: raise SysExitRecommendedError(f"Invalid JSON: {e}") from e - payload = client._make_payload( - action, - request["desiredResourceState"], - request["previousResourceState"], - request.get("typeConfiguration"), - ) + + if project.artifact_type == ARTIFACT_TYPE_HOOK: + status_type = HookStatus + in_progress_status = HookStatus.IN_PROGRESS + action_invocation_point = HookInvocationPoint[args.action_invocation_point] + payload = client._make_payload( + action_invocation_point, + request["targetName"], + request["targetModel"], + ) + else: + status_type = OperationStatus + in_progress_status = OperationStatus.IN_PROGRESS + action = Action[args.action] + payload = client._make_payload( + action, + request["desiredResourceState"], + request["previousResourceState"], + request.get("typeConfiguration"), + ) # pylint: 
disable=too-many-function-args current_invocation = 0 @@ -51,32 +116,20 @@ def invoke(args): while _needs_reinvocation(args.max_reinvoke, current_invocation): print("=== Handler input ===") - payload_to_log = { - "callbackContext": payload["callbackContext"], - "action": payload["action"], - "requestData": { - "resourceProperties": payload["requestData"]["resourceProperties"], - "previousResourceProperties": payload["requestData"][ - "previousResourceProperties" - ], - }, - "region": payload["region"], - "awsAccountId": payload["awsAccountId"], - "bearerToken": payload["bearerToken"], - } + payload_to_log = get_payload_to_log(payload, project.artifact_type) print(json.dumps({**payload_to_log}, indent=2)) response = client._call(payload) current_invocation = current_invocation + 1 print("=== Handler response ===") print(json.dumps(response, indent=2)) - status = OperationStatus[response["status"]] + status = status_type[response["status"]] - if status != OperationStatus.IN_PROGRESS: + if status != in_progress_status: break sleep(response.get("callbackDelaySeconds", 0)) - payload["callbackContext"] = response.get("callbackContext") + prepare_payload_for_reinvocation(payload, response, project.artifact_type) except KeyboardInterrupt: pass @@ -85,21 +138,13 @@ def _needs_reinvocation(max_reinvoke, current_invocation): return max_reinvoke is None or max_reinvoke >= current_invocation -def setup_subparser(subparsers, parents): - # see docstring of this file - parser = subparsers.add_parser("invoke", description=__doc__, parents=parents) - parser.set_defaults(command=invoke) - parser.add_argument( - "action", - choices=list(Action.__members__), - help="The provisioning action, i.e. 
which handler to invoke.", - ) - parser.add_argument( +def _setup_invoke_subparser(subparser): + subparser.add_argument( "request", type=FileType("r", encoding="utf-8"), help="A JSON file that contains the request to invoke the function with.", ) - parser.add_argument( + subparser.add_argument( "--max-reinvoke", type=int, default=None, @@ -107,9 +152,34 @@ def setup_subparser(subparsers, parents): "exiting. If not specified, will continue to " "re-invoke until terminal status is reached.", ) - parser.add_argument( + subparser.add_argument( "--docker-image", help="Docker image name to run. If specified, invoke will use docker instead " "of SAM", ) + + +def setup_subparser(subparsers, parents): + # see docstring of this file + parser = subparsers.add_parser("invoke", description=__doc__, parents=parents) + parser.set_defaults(command=invoke) + + invoke_subparsers = parser.add_subparsers(dest="subparser_name") + invoke_subparsers.required = True + resource_parser = invoke_subparsers.add_parser("resource", description=__doc__) + resource_parser.add_argument( + "action", + choices=list(Action.__members__), + help="The provisioning action, i.e. which resource handler to invoke.", + ) + _setup_invoke_subparser(resource_parser) + + hook_parser = invoke_subparsers.add_parser("hook", description=__doc__) + hook_parser.add_argument( + "action_invocation_point", + choices=list(HookInvocationPoint.__members__), + help="The provisioning action invocation point, i.e. 
which hook handler to invoke.", + ) + _setup_invoke_subparser(hook_parser) + _sam_arguments(parser) diff --git a/src/rpdk/core/project.py b/src/rpdk/core/project.py index e2b62c7e..004231e2 100644 --- a/src/rpdk/core/project.py +++ b/src/rpdk/core/project.py @@ -1,3 +1,4 @@ +# pylint: disable=too-many-lines import json import logging import os @@ -14,10 +15,16 @@ from rpdk.core.fragment.generator import TemplateFragment from rpdk.core.jsonutils.flattener import JsonSchemaFlattener +from rpdk.core.type_schema_loader import TypeSchemaLoader from . import __version__ from .boto_helpers import create_sdk_session -from .data_loaders import load_resource_spec, resource_json +from .data_loaders import ( + load_hook_spec, + load_resource_spec, + make_resource_validator, + resource_json, +) from .exceptions import ( DownstreamError, FragmentValidationError, @@ -37,15 +44,20 @@ SCHEMA_UPLOAD_FILENAME = "schema.json" CONFIGURATION_SCHEMA_UPLOAD_FILENAME = "configuration-schema.json" OVERRIDES_FILENAME = "overrides.json" +TARGET_INFO_FILENAME = "target-info.json" INPUTS_FOLDER = "inputs" EXAMPLE_INPUTS_FOLDER = "example_inputs" -ROLE_TEMPLATE_FILENAME = "resource-role.yaml" +TARGET_SCHEMAS_FOLDER = "target-schemas" +HOOK_ROLE_TEMPLATE_FILENAME = "hook-role.yaml" +RESOURCE_ROLE_TEMPLATE_FILENAME = "resource-role.yaml" TYPE_NAME_RESOURCE_REGEX = "^[a-zA-Z0-9]{2,64}::[a-zA-Z0-9]{2,64}::[a-zA-Z0-9]{2,64}$" TYPE_NAME_MODULE_REGEX = ( "^[a-zA-Z0-9]{2,64}::[a-zA-Z0-9]{2,64}::[a-zA-Z0-9]{2,64}::MODULE$" ) +TYPE_NAME_HOOK_REGEX = "^[a-zA-Z0-9]{2,64}::[a-zA-Z0-9]{2,64}::[a-zA-Z0-9]{2,64}$" ARTIFACT_TYPE_RESOURCE = "RESOURCE" ARTIFACT_TYPE_MODULE = "MODULE" +ARTIFACT_TYPE_HOOK = "HOOK" DEFAULT_ROLE_TIMEOUT_MINUTES = 120 # 2 hours # min and max are according to CreateRole API restrictions @@ -96,6 +108,20 @@ } ) +HOOK_SETTINGS_VALIDATOR = Draft7Validator( + { + "properties": { + "artifact_type": {"type": "string"}, + "language": {"type": "string"}, + "typeName": {"type": "string", 
"pattern": TYPE_NAME_HOOK_REGEX}, + "runtime": {"type": "string", "enum": list(LAMBDA_RUNTIMES)}, + "entrypoint": {"type": ["string", "null"]}, + "testEntrypoint": {"type": ["string", "null"]}, + "settings": {"type": "object"}, + }, + "required": ["language", "typeName", "runtime", "entrypoint"], + } +) BASIC_TYPE_MAPPINGS = { "string": "String", @@ -135,6 +161,7 @@ def __init__(self, overwrite_enabled=False, root=None): self.test_entrypoint = None self.executable_entrypoint = None self.fragment_dir = None + self.target_info = {} self.env = Environment( trim_blocks=True, @@ -188,6 +215,14 @@ def inputs_path(self): def example_inputs_path(self): return self.root / EXAMPLE_INPUTS_FOLDER + @property + def target_schemas_path(self): + return self.root / TARGET_SCHEMAS_FOLDER + + @property + def target_info_path(self): + return self.root / TARGET_INFO_FILENAME + @staticmethod def _raise_invalid_project(msg, e): LOG.debug(msg, exc_info=e) @@ -209,9 +244,28 @@ def load_settings(self): if raw_settings["artifact_type"] == ARTIFACT_TYPE_RESOURCE: self.validate_and_load_resource_settings(raw_settings) + elif raw_settings["artifact_type"] == ARTIFACT_TYPE_HOOK: + self.validate_and_load_hook_settings(raw_settings) else: self.validate_and_load_module_settings(raw_settings) + def validate_and_load_hook_settings(self, raw_settings): + try: + HOOK_SETTINGS_VALIDATOR.validate(raw_settings) + except ValidationError as e: + self._raise_invalid_project( + f"Project file '{self.settings_path}' is invalid", e + ) + self.type_name = raw_settings["typeName"] + self.artifact_type = raw_settings["artifact_type"] + self.language = raw_settings["language"] + self.runtime = raw_settings["runtime"] + self.entrypoint = raw_settings["entrypoint"] + self.test_entrypoint = raw_settings["testEntrypoint"] + self.executable_entrypoint = raw_settings.get("executableEntrypoint") + self._plugin = load_plugin(raw_settings["language"]) + self.settings = raw_settings.get("settings", {}) + def 
validate_and_load_module_settings(self, raw_settings): try: MODULE_SETTINGS_VALIDATOR.validate(raw_settings) @@ -252,6 +306,18 @@ def _write(f): self.safewrite(self.schema_path, _write) + def _write_example_hook_schema(self): + self.schema = resource_json( + __name__, "data/examples/hook/sse.verification.v1.json" + ) + self.schema["typeName"] = self.type_name + + def _write(f): + json.dump(self.schema, f, indent=4) + f.write("\n") + + self.safewrite(self.schema_path, _write) + def _write_example_inputs(self): shutil.rmtree(self.example_inputs_path, ignore_errors=True) @@ -313,8 +379,32 @@ def _write_module_settings(f): ) f.write("\n") + def _write_hook_settings(f): + executable_entrypoint_dict = ( + {"executableEntrypoint": self.executable_entrypoint} + if self.executable_entrypoint + else {} + ) + json.dump( + { + "artifact_type": self.artifact_type, + "typeName": self.type_name, + "language": self.language, + "runtime": self.runtime, + "entrypoint": self.entrypoint, + "testEntrypoint": self.test_entrypoint, + "settings": self.settings, + **executable_entrypoint_dict, + }, + f, + indent=4, + ) + f.write("\n") + if self.artifact_type == ARTIFACT_TYPE_RESOURCE: self.overwrite(self.settings_path, _write_resource_settings) + elif self.artifact_type == ARTIFACT_TYPE_HOOK: + self.overwrite(self.settings_path, _write_hook_settings) else: self.overwrite(self.settings_path, _write_module_settings) @@ -335,6 +425,25 @@ def init_module(self, type_name): self.settings = {} self.write_settings() + def init_hook(self, type_name, language, settings=None): + self.artifact_type = ARTIFACT_TYPE_HOOK + self.type_name = type_name + self.language = language + self._plugin = load_plugin(language) + self.settings = settings or {} + self._write_example_hook_schema() + self._plugin.init(self) + self.write_settings() + + def load_hook_schema(self): + if not self.type_info: + msg = "Internal error (Must load settings first)" + LOG.critical(msg) + raise InternalError(msg) + + with 
self.schema_path.open("r", encoding="utf-8") as f: + self.schema = load_hook_spec(f) + def load_schema(self): if not self.type_info: msg = "Internal error (Must load settings first)" @@ -352,7 +461,7 @@ def load_configuration_schema(self): if "typeConfiguration" in self.schema: configuration_schema = self.schema["typeConfiguration"] - configuration_schema["definitions"] = self.schema["definitions"] + configuration_schema["definitions"] = self.schema.get("definitions", {}) configuration_schema["typeName"] = self.type_name self.configuration_schema = configuration_schema @@ -390,7 +499,7 @@ def safewrite(self, path, contents): except FileExistsError: LOG.info("File already exists, not overwriting '%s'", path) - def generate(self): + def generate(self, endpoint_url=None, region_name=None, target_schemas=None): if self.artifact_type == ARTIFACT_TYPE_MODULE: return # for Modules, the schema is already generated in cfn validate @@ -398,10 +507,14 @@ def generate(self): # to provision resources if schema has handlers defined if "handlers" in self.schema: handlers = self.schema["handlers"] - template = self.env.get_template("resource-role.yml") permission = "Allow" - path = self.root / ROLE_TEMPLATE_FILENAME - LOG.debug("Writing Resource Role CloudFormation template: %s", path) + if self.artifact_type == ARTIFACT_TYPE_HOOK: + template = self.env.get_template("hook-role.yml") + path = self.root / HOOK_ROLE_TEMPLATE_FILENAME + else: + template = self.env.get_template("resource-role.yml") + path = self.root / RESOURCE_ROLE_TEMPLATE_FILENAME + LOG.debug("Writing Execution Role CloudFormation template: %s", path) actions = { action for handler in handlers.values() @@ -438,6 +551,9 @@ def generate(self): role_session_timeout=role_session_timeout, ) self.overwrite(path, contents) + self.target_info = self._load_target_info( + endpoint_url, region_name, target_schemas + ) self._plugin.generate(self) @@ -452,6 +568,8 @@ def load(self): if self.artifact_type == ARTIFACT_TYPE_MODULE: 
self._load_modules_project() + elif self.artifact_type == ARTIFACT_TYPE_HOOK: + self._load_hooks_project() else: self._load_resources_project() @@ -479,6 +597,17 @@ def _load_modules_project(self): self.schema = template_fragment.generate_schema() self.fragment_dir = template_fragment.fragment_dir + def _load_hooks_project(self): + LOG.info("Validating your hook specification...") + try: + self.load_hook_schema() + self.load_configuration_schema() + except FileNotFoundError as e: + self._raise_invalid_project("Hook specification not found.", e) + except SpecValidationError as e: + msg = "Hook specification is invalid: " + str(e) + self._raise_invalid_project(msg, e) + def _add_modules_content_to_zip(self, zip_file): if not os.path.exists(self.root / SCHEMA_UPLOAD_FILENAME): msg = "Module schema could not be found" @@ -515,6 +644,8 @@ def submit( zip_file.write(self.settings_path, SETTINGS_FILENAME) if self.artifact_type == ARTIFACT_TYPE_MODULE: self._add_modules_content_to_zip(zip_file) + elif self.artifact_type == ARTIFACT_TYPE_HOOK: + self._add_hooks_content_to_zip(zip_file, endpoint_url, region_name) else: self._add_resources_content_to_zip(zip_file) @@ -555,6 +686,41 @@ def _add_resources_content_to_zip(self, zip_file): cli_metadata["cli-version"] = __version__ zip_file.writestr(CFN_METADATA_FILENAME, json.dumps(cli_metadata)) + def _add_hooks_content_to_zip(self, zip_file, endpoint_url=None, region_name=None): + zip_file.write(self.schema_path, SCHEMA_UPLOAD_FILENAME) + if os.path.isdir(self.inputs_path): + for filename in os.listdir(self.inputs_path): + absolute_path = self.inputs_path / filename + zip_file.write(absolute_path, INPUTS_FOLDER + "/" + filename) + LOG.debug("%s found. Writing to package.", filename) + else: + LOG.debug("%s not found. 
Not writing to package.", INPUTS_FOLDER) + + target_info = ( + self.target_info + if self.target_info + else self._load_target_info(endpoint_url, region_name) + ) + zip_file.writestr(TARGET_INFO_FILENAME, json.dumps(target_info, indent=4)) + for target_name, info in target_info.items(): + filename = "{}.json".format( + "-".join(s.lower() for s in target_name.split("::")) + ) + content = json.dumps(info.get("Schema", {}), indent=4).encode("utf-8") + zip_file.writestr(TARGET_SCHEMAS_FOLDER + "/" + filename, content) + LOG.debug("%s found. Writing to package.", filename) + + self._plugin.package(self, zip_file) + cli_metadata = {} + try: + cli_metadata = self._plugin.get_plugin_information(self) + except AttributeError: + LOG.debug( + "Version info is not available for plugins, not writing to metadata file" + ) + cli_metadata["cli-version"] = __version__ + zip_file.writestr(CFN_METADATA_FILENAME, json.dumps(cli_metadata)) + # pylint: disable=R1732 def _create_context_manager(self, dry_run): # if it's a dry run, keep the file; otherwise can delete after upload @@ -567,7 +733,10 @@ def _get_zip_file_path(self): return Path(f"{self.hypenated_name}.zip") def generate_docs(self): - if self.artifact_type == ARTIFACT_TYPE_MODULE: + if ( + self.artifact_type == ARTIFACT_TYPE_MODULE + or self.artifact_type == ARTIFACT_TYPE_HOOK + ): return # generate the docs folder that contains documentation based on the schema @@ -781,7 +950,6 @@ def __set_property_type(prop_type, single_type=True): type_json = type_yaml = type_longform = "Map" if object_properties: - subproperty_name = " ".join(proppath) subproperty_filename = "-".join(proppath).lower() + ".md" subproperty_path = docs_path / subproperty_filename @@ -852,10 +1020,15 @@ def _upload( cfn_client = session.client("cloudformation", endpoint_url=endpoint_url) s3_client = session.client("s3") uploader = Uploader(cfn_client, s3_client) + if use_role and not role_arn and "handlers" in self.schema: LOG.debug("Creating execution role 
for provider to use") + if self.artifact_type == ARTIFACT_TYPE_HOOK: + role_template_file = HOOK_ROLE_TEMPLATE_FILENAME + else: + role_template_file = RESOURCE_ROLE_TEMPLATE_FILENAME role_arn = uploader.create_or_update_role( - self.root / ROLE_TEMPLATE_FILENAME, self.hypenated_name + self.root / role_template_file, self.hypenated_name ) s3_url = uploader.upload(self.hypenated_name, fileobj) @@ -933,3 +1106,106 @@ def _wait_for_registration(cfn_client, registration_token, set_default): ) raise DownstreamError("Error setting default version") from e LOG.warning("Set default version to '%s", arn) + + # pylint: disable=R0912,R0914,R0915 + # flake8: noqa: C901 + def _load_target_info(self, endpoint_url, region_name, provided_schemas=None): + if self.artifact_type != ARTIFACT_TYPE_HOOK or not self.schema: + return {} + + if provided_schemas is None: + provided_schemas = [] + + target_names = set() + for handler in self.schema.get("handlers", []).values(): + for target_name in handler.get("targetNames", []): + target_names.add(target_name) + + loader = TypeSchemaLoader.get_type_schema_loader(endpoint_url, region_name) + + provided_target_info = {} + + if os.path.isfile(self.target_info_path): + try: + with self.target_info_path.open("r", encoding="utf-8") as f: + provided_target_info = json.load(f) + except json.JSONDecodeError as e: # pragma: no cover + self._raise_invalid_project( + f"Target info file '{self.target_info_path}' is invalid", e + ) + + if os.path.isdir(self.target_schemas_path): + for filename in os.listdir(self.target_schemas_path): + absolute_path = self.target_schemas_path / filename + if absolute_path.is_file() and absolute_path.match( + "*.json" + ): # pragma: no cover + provided_schemas.append(str(absolute_path)) + + loaded_schemas = {} + for provided_schema in provided_schemas: + loaded_schema = loader.load_type_schema(provided_schema) + if not loaded_schema: + continue + + if isinstance(loaded_schema, dict): + type_schemas = (loaded_schema,) + 
else: + type_schemas = loaded_schema + + for type_schema in type_schemas: + try: + type_name = type_schema["typeName"] + if type_name in loaded_schemas: + raise InvalidProjectError( + "Duplicate schemas for '{}' target type.".format(type_name) + ) + + loaded_schemas[type_name] = type_schema + except (KeyError, TypeError) as e: + LOG.warning( + "Error while loading a provided schema: %s", + provided_schema, + exc_info=e, + ) + + validator = make_resource_validator() + + target_info = {} + for target_name in target_names: + if target_name in loaded_schemas: + target_schema = loaded_schemas[target_name] + target_type = "RESOURCE" + provisioning_type = provided_target_info.get(target_name, {}).get( + "ProvisioningType", + loader.get_provision_type(target_name, "RESOURCE"), + ) + else: + ( + target_schema, + target_type, + provisioning_type, + ) = loader.load_schema_from_cfn_registry(target_name, "RESOURCE") + + is_registry_type = bool( + provisioning_type and provisioning_type != "NON_PROVISIONABLE" + ) + + if is_registry_type: # pragma: no cover + try: + validator.validate(target_schema) + except (SpecValidationError, ValidationError) as e: + self._raise_invalid_project( + f"Target schema for '{target_name}' is invalid: " + str(e), e + ) + + target_info[target_name] = { + "TargetName": target_name, + "TargetType": target_type, + "Schema": target_schema, + "ProvisioningType": provisioning_type, + "IsCfnRegistrySupportedType": is_registry_type, + "SchemaFileAvailable": bool(target_schema), + } + + return target_info diff --git a/src/rpdk/core/templates/hook-role.yml b/src/rpdk/core/templates/hook-role.yml new file mode 100644 index 00000000..000f2ba2 --- /dev/null +++ b/src/rpdk/core/templates/hook-role.yml @@ -0,0 +1,42 @@ +AWSTemplateFormatVersion: "2010-09-09" +Description: > + This CloudFormation template creates a role assumed by CloudFormation + during Hook operations on behalf of the customer. 
+ +Resources: + ExecutionRole: + Type: AWS::IAM::Role + Properties: + MaxSessionDuration: {{ role_session_timeout }} + AssumeRolePolicyDocument: + Version: '2012-10-17' + Statement: + - Effect: Allow + Principal: + Service: + - hooks.cloudformation.amazonaws.com + - resources.cloudformation.amazonaws.com + Action: sts:AssumeRole + Condition: + StringEquals: + aws:SourceAccount: + Ref: AWS::AccountId + StringLike: + aws:SourceArn: + Fn::Sub: arn:${AWS::Partition}:cloudformation:${AWS::Region}:${AWS::AccountId}:type/hook/{{ type_name }}/* + Path: "/" + Policies: + - PolicyName: HookTypePolicy + PolicyDocument: + Version: '2012-10-17' + Statement: + - Effect: {{ permission }} + Action: + {% for action in actions %} + - "{{ action }}" + {% endfor %} + Resource: "*" +Outputs: + ExecutionRoleArn: + Value: + Fn::GetAtt: ExecutionRole.Arn diff --git a/src/rpdk/core/templates/template_hook.yml b/src/rpdk/core/templates/template_hook.yml new file mode 100644 index 00000000..7329df46 --- /dev/null +++ b/src/rpdk/core/templates/template_hook.yml @@ -0,0 +1,25 @@ +{% macro function(name, params) %} + {{ name }}: + Type: AWS::Serverless::Function + Properties: + {% for key, value in params.items() %} + {{ key }}: {{ value }} + {% endfor %} +{% endmacro %} +AWSTemplateFormatVersion: "2010-09-09" +Transform: AWS::Serverless-2016-10-31 +Description: AWS SAM template for the {{ resource_type }} resource type + +Globals: + Function: + Timeout: 180 # docker start-up times can be long for SAM CLI + MemorySize: 256 + +Resources: +{% if functions %} +{% for name, params in functions.items() %} +{{ function(name, params) }} +{% endfor %} +{% else %} +{{ function("TypeFunction", handler_params) }} +{% endif %} diff --git a/src/rpdk/core/test.py b/src/rpdk/core/test.py index 7e3beb79..b7e24993 100644 --- a/src/rpdk/core/test.py +++ b/src/rpdk/core/test.py @@ -6,6 +6,7 @@ import logging import os from argparse import SUPPRESS +from collections import OrderedDict from contextlib import 
contextmanager from pathlib import Path from tempfile import NamedTemporaryFile @@ -15,15 +16,17 @@ from jsonschema import Draft6Validator from jsonschema.exceptions import ValidationError +from rpdk.core.contract.hook_client import HookClient from rpdk.core.jsonutils.pointer import fragment_decode +from rpdk.core.utils.handler_utils import generate_handler_name from .boto_helpers import create_sdk_session, get_temporary_credentials from .contract.contract_plugin import ContractPlugin -from .contract.interface import Action +from .contract.interface import Action, HookInvocationPoint from .contract.resource_client import ResourceClient from .data_loaders import copy_resource from .exceptions import SysExitRecommendedError -from .project import ARTIFACT_TYPE_MODULE, Project +from .project import ARTIFACT_TYPE_HOOK, ARTIFACT_TYPE_MODULE, Project LOG = logging.getLogger(__name__) @@ -33,7 +36,7 @@ DEFAULT_TIMEOUT = "30" INPUTS = "inputs" -OVERRIDES_VALIDATOR = Draft6Validator( +RESOURCE_OVERRIDES_VALIDATOR = Draft6Validator( { "properties": {"CREATE": {"type": "object"}, "UPDATE": {"type": "object"}}, "anyOf": [{"required": ["CREATE"]}, {"required": ["UPDATE"]}], @@ -41,11 +44,39 @@ } ) +HOOK_OVERRIDES_VALIDATOR = Draft6Validator( + { + "properties": { + "CREATE_PRE_PROVISION": {"type": "object"}, + "UPDATE_PRE_PROVISION": {"type": "object"}, + "DELETE_PRE_PROVISION": {"type": "object"}, + "INVALID_CREATE_PRE_PROVISION": {"type": "object"}, + "INVALID_UPDATE_PRE_PROVISION": {"type": "object"}, + "INVALID_DELETE_PRE_PROVISION": {"type": "object"}, + "INVALID": {"type": "object"}, + }, + "anyOf": [ + {"required": ["CREATE_PRE_PROVISION"]}, + {"required": ["UPDATE_PRE_PROVISION"]}, + {"required": ["DELETE_PRE_PROVISION"]}, + {"required": ["INVALID_CREATE_PRE_PROVISION"]}, + {"required": ["INVALID_UPDATE_PRE_PROVISION"]}, + {"required": ["INVALID_DELETE_PRE_PROVISION"]}, + {"required": ["INVALID"]}, + ], + "additionalProperties": False, + } +) + def empty_override(): 
return {"CREATE": {}} +def empty_hook_override(): + return {"CREATE_PRE_PROVISION": {}} + + @contextmanager def temporary_ini_file(): with NamedTemporaryFile( @@ -92,6 +123,18 @@ def render_jinja(overrides_string, region_name, endpoint_url, role_arn): return to_return +def filter_overrides(overrides, project): + if project.artifact_type == ARTIFACT_TYPE_HOOK: + actions = set(HookInvocationPoint) + else: + actions = set(Action) + + for k in set(overrides) - actions: + del overrides[k] + + return overrides + + def get_overrides(root, region_name, endpoint_url, role_arn): if not root: return empty_override() @@ -105,7 +148,7 @@ def get_overrides(root, region_name, endpoint_url, role_arn): return empty_override() try: - OVERRIDES_VALIDATOR.validate(overrides_raw) + RESOURCE_OVERRIDES_VALIDATOR.validate(overrides_raw) except ValidationError as e: LOG.warning("Override file invalid: %s\n" "No overrides will be applied", e) return empty_override() @@ -125,6 +168,60 @@ def get_overrides(root, region_name, endpoint_url, role_arn): return overrides +# pylint: disable=R0914 +# flake8: noqa: C901 +def get_hook_overrides(root, region_name, endpoint_url, role_arn): + if not root: + return empty_hook_override() + + path = root / "overrides.json" + try: + with path.open("r", encoding="utf-8") as f: + overrides_raw = render_jinja(f.read(), region_name, endpoint_url, role_arn) + except FileNotFoundError: + LOG.debug("Override file '%s' not found. No overrides will be applied", path) + return empty_hook_override() + + try: + HOOK_OVERRIDES_VALIDATOR.validate(overrides_raw) + except ValidationError as e: + LOG.warning("Override file invalid: %s\n" "No overrides will be applied", e) + return empty_hook_override() + + overrides = empty_hook_override() + for ( + operation, + operation_items_raw, + ) in overrides_raw.items(): # Hook invocation point (e.g. 
CREATE_PRE_PROVISION) + operation_items = {} + for ( + target_name, + target_items_raw, + ) in operation_items_raw.items(): # Hook targets (e.g. AWS::S3::Bucket) + target_items = {} + for ( + item, + items_raw, + ) in ( + target_items_raw.items() + ): # Target Model fields (e.g. 'resourceProperties', 'previousResourceProperties') + items = {} + for pointer, obj in items_raw.items(): + try: + pointer = fragment_decode(pointer, prefix="") + except ValueError: # pragma: no cover + LOG.warning( + "%s pointer '%s' is invalid. Skipping", operation, pointer + ) + else: + items[pointer] = obj + target_items[item] = items + operation_items[target_name] = target_items + overrides[operation] = operation_items + + return overrides + + # pylint: disable=R0914 def get_inputs(root, region_name, endpoint_url, value, role_arn): inputs = {} @@ -158,7 +255,20 @@ def get_inputs(root, region_name, endpoint_url, value, role_arn): return None +# pylint: disable=too-many-return-statements def get_type(file_name): + if "invalid_pre_create" in file_name: + return "INVALID_CREATE_PRE_PROVISION" + if "invalid_pre_update" in file_name: + return "INVALID_UPDATE_PRE_PROVISION" + if "invalid_pre_delete" in file_name: + return "INVALID_DELETE_PRE_PROVISION" + if "pre_create" in file_name: + return "CREATE_PRE_PROVISION" + if "pre_update" in file_name: + return "UPDATE_PRE_PROVISION" + if "pre_delete" in file_name: + return "DELETE_PRE_PROVISION" if "create" in file_name: return "CREATE" if "update" in file_name: @@ -168,13 +278,76 @@ def get_type(file_name): return None +def get_resource_marker_options(schema): + lowercase_actions = [action.lower() for action in Action] + handlers = schema.get("handlers", {}).keys() + return [action for action in lowercase_actions if action not in handlers] + + +def get_hook_marker_options(schema): + handlers = schema.get("handlers", {}).keys() + action_to_handler = OrderedDict() + for invocation_point in HookInvocationPoint: + handler_name = 
generate_handler_name(invocation_point) + action_to_handler[handler_name] = invocation_point.lower() + + excluded_actions = [ + action for action in action_to_handler.keys() if action not in handlers + ] + return [action_to_handler[excluded_action] for excluded_action in excluded_actions] + + def get_marker_options(schema): - lowercase_actions = {action.lower() for action in Action} - excluded_actions = lowercase_actions - schema.get("handlers", {}).keys() + excluded_actions = get_resource_marker_options(schema) + get_hook_marker_options( + schema + ) marker_list = ["not " + action for action in excluded_actions] return " and ".join(marker_list) +def get_contract_plugin_client(args, project, overrides, inputs): + plugin_clients = {} + if project.artifact_type == ARTIFACT_TYPE_HOOK: + plugin_clients["hook_client"] = HookClient( + args.function_name, + args.endpoint, + args.region, + project.schema, + overrides, + inputs, + args.role_arn, + args.enforce_timeout, + project.type_name, + args.log_group_name, + args.log_role_arn, + executable_entrypoint=project.executable_entrypoint, + docker_image=args.docker_image, + target_info=project._load_target_info( # pylint: disable=protected-access + args.cloudformation_endpoint_url, args.region + ), + ) + LOG.debug("Setup plugin for HOOK type") + return plugin_clients + + plugin_clients["resource_client"] = ResourceClient( + args.function_name, + args.endpoint, + args.region, + project.schema, + overrides, + inputs, + args.role_arn, + args.enforce_timeout, + project.type_name, + args.log_group_name, + args.log_role_arn, + executable_entrypoint=project.executable_entrypoint, + docker_image=args.docker_image, + ) + LOG.debug("Setup plugin for RESOURCE type") + return plugin_clients + + def test(args): _validate_sam_args(args) project = Project() @@ -183,9 +356,15 @@ def test(args): LOG.warning("The test command is not supported in a module project") return - overrides = get_overrides( - project.root, args.region, 
args.cloudformation_endpoint_url, args.role_arn - ) + if project.artifact_type == ARTIFACT_TYPE_HOOK: + overrides = get_hook_overrides( + project.root, args.region, args.cloudformation_endpoint_url, args.role_arn + ) + else: + overrides = get_overrides( + project.root, args.region, args.cloudformation_endpoint_url, args.role_arn + ) + filter_overrides(overrides, project) index = 1 while True: @@ -206,24 +385,8 @@ def test(args): def invoke_test(args, project, overrides, inputs): - plugin = ContractPlugin( - ResourceClient( - args.function_name, - args.endpoint, - args.region, - project.schema, - overrides, - inputs, - args.role_arn, - args.enforce_timeout, - project.type_name, - args.log_group_name, - args.log_role_arn, - executable_entrypoint=project.executable_entrypoint, - docker_image=args.docker_image, - ) - ) - + plugin_clients = get_contract_plugin_client(args, project, overrides, inputs) + plugin = ContractPlugin(plugin_clients) with temporary_ini_file() as path: pytest_args = ["-c", path, "-m", get_marker_options(project.schema)] if args.passed_to_pytest: diff --git a/src/rpdk/core/type_schema_loader.py b/src/rpdk/core/type_schema_loader.py new file mode 100644 index 00000000..806c20bf --- /dev/null +++ b/src/rpdk/core/type_schema_loader.py @@ -0,0 +1,210 @@ +import json +import logging +import os +import re +from urllib.parse import urlparse + +import requests +from botocore.exceptions import ClientError + +from .boto_helpers import create_sdk_session +from .exceptions import RPDKBaseException + +LOG = logging.getLogger(__name__) + +VALID_TYPE_SCHEMA_URI_REGEX = "^(https?|file|s3)://.+$" + + +def is_valid_type_schema_uri(uri): + if uri is None: + return False + + pattern = re.compile(VALID_TYPE_SCHEMA_URI_REGEX) + return bool(re.search(pattern, uri)) + + +class TypeSchemaLoader: + """ + This class is constructed to return schema of the target resource type + There are four options: + * Reads the schema from a JSON file + * Reads the schema from a provided 
url + * Reads the schema from file in a S3 bucket + * Calls CFN DescribeType API to retrieve the schema + """ + + @staticmethod + def get_type_schema_loader(endpoint_url=None, region_name=None): + cfn_client = None + s3_client = None + try: + session = create_sdk_session(region_name) + cfn_client = session.client("cloudformation", endpoint_url=endpoint_url) + s3_client = session.client("s3", endpoint_url=endpoint_url) + except RPDKBaseException as err: # pragma: no cover + LOG.debug("Type schema loader setup resulted in error", exc_info=err) + + return TypeSchemaLoader(cfn_client, s3_client) + + def __init__(self, cfn_client, s3_client): + self.cfn_client = cfn_client + self.s3_client = s3_client + + def load_type_schema(self, provided_schema, default_schema=None): + if not provided_schema: + return default_schema + + if provided_schema.startswith("{") and provided_schema.endswith("}"): + type_schema = self.load_type_schema_from_json( + provided_schema, default_schema + ) + elif provided_schema.startswith("[") and provided_schema.endswith("]"): + type_schema = self.load_type_schema_from_json( + provided_schema, default_schema + ) + elif os.path.isfile(provided_schema): + type_schema = self.load_type_schema_from_file( + provided_schema, default_schema + ) + elif is_valid_type_schema_uri(provided_schema): + type_schema = self.load_type_schema_from_uri( + provided_schema, default_schema + ) + else: + type_schema = default_schema + + return type_schema + + @staticmethod + def load_type_schema_from_json(schema_json, default_schema=None): + if not schema_json: + return default_schema + + try: + return json.loads(schema_json) + except json.JSONDecodeError: + LOG.debug( + "Provided schema is not valid JSON. Falling back to default schema." 
+ ) + return default_schema + + def load_type_schema_from_uri(self, schema_uri, default_schema=None): + if not is_valid_type_schema_uri(schema_uri): + return default_schema + + uri = urlparse(schema_uri) + if uri.scheme == "file": + type_schema = self.load_type_schema_from_file(uri.path, default_schema) + elif uri.scheme == "https": + type_schema = self._get_type_schema_from_url(uri.geturl(), default_schema) + elif uri.scheme == "s3": + bucket = uri.netloc + key = uri.path.lstrip("/") + type_schema = self._get_type_schema_from_s3(bucket, key, default_schema) + else: + LOG.debug( + "URI provided '%s' is not supported. Falling back to default schema", + schema_uri, + ) + type_schema = default_schema + + return type_schema + + @staticmethod + def load_type_schema_from_file(schema_path, default_schema=None): + if not schema_path: + return default_schema + + try: + with open(schema_path, "r") as file: + return TypeSchemaLoader.load_type_schema_from_json(file.read()) + except FileNotFoundError: + LOG.debug( + "Target schema file '%s' not found. 
Falling back to default schema.", + schema_path, + ) + return default_schema + + @staticmethod + def _get_type_schema_from_url(url, default_schema=None): + response = requests.get(url, timeout=60) + if response.status_code == 200: + type_schema = TypeSchemaLoader.load_type_schema_from_json( + response.content.decode("utf-8") + ) + else: + LOG.debug( + "Received status code of '%s' when calling url '%s.'", + str(response.status_code), + url, + ) + LOG.debug("Falling back to default schema.") + type_schema = default_schema + + return type_schema + + def _get_type_schema_from_s3(self, bucket, key, default_schema=None): + if self.s3_client is None: # pragma: no cover + LOG.debug("S3 client is not set up") + LOG.debug("Falling back to default schema") + return default_schema + + try: + type_schema = ( + self.s3_client.get_object(Bucket=bucket, Key=key)["Body"] + .read() + .decode("utf-8") + ) + return self.load_type_schema_from_json(type_schema) + except ClientError as err: + LOG.debug( + "Getting S3 object in bucket '%s' with key '%s' resulted in unknown ClientError", + bucket, + key, + exc_info=err, + ) + LOG.debug("Falling back to default schema") + return default_schema + + def load_schema_from_cfn_registry( + self, type_name, extension_type, default_schema=None + ): + if self.cfn_client is None: # pragma: no cover + LOG.debug("CloudFormation client is not set up") + LOG.debug("Falling back to default schema for type '%s'", type_name) + return default_schema, None, None + + try: + response = self.cfn_client.describe_type( + Type=extension_type, TypeName=type_name + ) + return ( + self.load_type_schema_from_json(response["Schema"]), + response["Type"], + response["ProvisioningType"], + ) + except ClientError as err: + LOG.debug( + "Describing type '%s' resulted in unknown ClientError", + type_name, + exc_info=err, + ) + LOG.debug("Falling back to default schema for type '%s'", type_name) + return default_schema, None, None + + def get_provision_type(self, type_name, 
extension_type): + if self.cfn_client is None: # pragma: no cover + LOG.debug("CloudFormation client is not set up") + return None + + try: + return self.cfn_client.describe_type( + Type=extension_type, TypeName=type_name + )["ProvisioningType"] + except ClientError as err: + LOG.debug( + "Describing type '%s' resulted in unknown ClientError", + type_name, + exc_info=err, + ) + return None diff --git a/src/rpdk/core/upload.py b/src/rpdk/core/upload.py index ca5c44b2..827c52f7 100644 --- a/src/rpdk/core/upload.py +++ b/src/rpdk/core/upload.py @@ -142,10 +142,11 @@ def create_or_update_role(self, template_path, resource_type): template = f.read() except FileNotFoundError: LOG.critical( - "CloudFormation template 'resource-role.yaml' " + "CloudFormation template '%s' " "for execution role not found. " "Please run `generate` or " - "provide an execution role via the --role-arn parameter." + "provide an execution role via the --role-arn parameter.", + template_path.name, ) # pylint: disable=W0707 raise InvalidProjectError() diff --git a/src/rpdk/core/utils/handler_utils.py b/src/rpdk/core/utils/handler_utils.py new file mode 100644 index 00000000..fc61cc67 --- /dev/null +++ b/src/rpdk/core/utils/handler_utils.py @@ -0,0 +1,17 @@ +import logging + +LOG = logging.getLogger(__name__) + + +def generate_handler_name(operation): + if operation.endswith("_PROVISION"): + # CREATE_PRE_PROVISION -> preCreate + *action, prefix = operation.split("_PROVISION")[0].split("_") + else: + # CREATE -> create + # SOME_OPERATION -> someOperation + prefix, *action = operation.split("_") + + handler_name = prefix.lower() + "".join(act.title() for act in action) + + return handler_name diff --git a/src/rpdk/core/utils/init_utils.py b/src/rpdk/core/utils/init_utils.py index bf7b7015..dc1fdede 100644 --- a/src/rpdk/core/utils/init_utils.py +++ b/src/rpdk/core/utils/init_utils.py @@ -3,13 +3,18 @@ from colorama import Fore, Style from rpdk.core.exceptions import WizardValidationError -from 
rpdk.core.project import ARTIFACT_TYPE_MODULE, ARTIFACT_TYPE_RESOURCE +from rpdk.core.project import ( + ARTIFACT_TYPE_HOOK, + ARTIFACT_TYPE_MODULE, + ARTIFACT_TYPE_RESOURCE, +) LOG = logging.getLogger(__name__) -INPUT_TYPES_STRING = "resource(r) or a module(m)" +INPUT_TYPES_STRING = "resource(r) or a module(m) or a hook(h)" VALID_RESOURCES_REPRESENTATION = {"r", "resource", "resources"} VALID_MODULES_REPRESENTATION = {"m", "module", "modules"} +VALID_HOOKS_REPRESENTATION = {"h", "hook", "hooks"} # NOTE this function is also in init, for compatibility with language plugins @@ -61,6 +66,8 @@ def validate_artifact_type(value): return ARTIFACT_TYPE_RESOURCE if value.lower() in VALID_MODULES_REPRESENTATION: return ARTIFACT_TYPE_MODULE + if value.lower() in VALID_HOOKS_REPRESENTATION: + return ARTIFACT_TYPE_HOOK raise WizardValidationError( "Please enter a value matching {}".format(INPUT_TYPES_STRING) ) diff --git a/tests/contract/test_contract_plugin.py b/tests/contract/test_contract_plugin.py index 34ae8a36..513f889c 100644 --- a/tests/contract/test_contract_plugin.py +++ b/tests/contract/test_contract_plugin.py @@ -1,7 +1,64 @@ +from unittest.mock import MagicMock + +import pytest + from rpdk.core.contract.contract_plugin import ContractPlugin +from rpdk.core.contract.hook_client import HookClient +from rpdk.core.contract.resource_client import ResourceClient + + +def test_contract_plugin_no_client(): + plugin_clients = None + expected_err_msg = "No plugin clients are set up" + with pytest.raises(RuntimeError) as excinfo: + ContractPlugin(plugin_clients) + + assert expected_err_msg in str(excinfo.value) + + plugin_clients = {} + with pytest.raises(RuntimeError) as excinfo: + ContractPlugin(plugin_clients) + + assert expected_err_msg in str(excinfo.value) def test_contract_plugin_fixture_resource_client(): - resource_client = object() - plugin = ContractPlugin(resource_client) + resource_client = MagicMock(spec=ResourceClient) + plugin_clients = {"resource_client": 
resource_client} + plugin = ContractPlugin(plugin_clients) assert plugin.resource_client.__wrapped__(plugin) is resource_client + + +def test_contract_plugin_fixture_resource_client_not_set(): + plugin = ContractPlugin({"client": object()}) + with pytest.raises(ValueError) as excinfo: + plugin.resource_client.__wrapped__(plugin) + assert "Contract plugin client not setup for RESOURCE type" in str(excinfo.value) + + +def test_contract_plugin_fixture_resource_client_invalid(): + plugin = ContractPlugin({"resource_client": object()}) + with pytest.raises(ValueError) as excinfo: + plugin.resource_client.__wrapped__(plugin) + assert "Contract plugin client not setup for RESOURCE type" in str(excinfo.value) + + +def test_contract_plugin_fixture_hook_client(): + hook_client = MagicMock(spec=HookClient) + plugin_clients = {"hook_client": hook_client} + plugin = ContractPlugin(plugin_clients) + assert plugin.hook_client.__wrapped__(plugin) is hook_client + + +def test_contract_plugin_fixture_hook_client_not_set(): + plugin = ContractPlugin({"client": object()}) + with pytest.raises(ValueError) as excinfo: + plugin.hook_client.__wrapped__(plugin) + assert "Contract plugin client not setup for HOOK type" in str(excinfo.value) + + +def test_contract_plugin_fixture_hook_client_invalid(): + plugin = ContractPlugin({"hook_client": object()}) + with pytest.raises(ValueError) as excinfo: + plugin.hook_client.__wrapped__(plugin) + assert "Contract plugin client not setup for HOOK type" in str(excinfo.value) diff --git a/tests/contract/test_hook_client.py b/tests/contract/test_hook_client.py new file mode 100644 index 00000000..ab3b964c --- /dev/null +++ b/tests/contract/test_hook_client.py @@ -0,0 +1,997 @@ +# fixture and parameter have the same name +# pylint: disable=redefined-outer-name,protected-access +import json +import logging +import time +from io import StringIO +from unittest import TestCase +from unittest.mock import ANY, patch + +import pytest + +from 
rpdk.core.boto_helpers import LOWER_CAMEL_CRED_KEYS +from rpdk.core.contract.hook_client import HookClient +from rpdk.core.contract.interface import ( + HandlerErrorCode, + HookInvocationPoint, + HookStatus, +) +from rpdk.core.contract.type_configuration import TypeConfiguration +from rpdk.core.exceptions import InvalidProjectError +from rpdk.core.test import DEFAULT_ENDPOINT, DEFAULT_FUNCTION, DEFAULT_REGION + +EMPTY_OVERRIDE = {} +ACCOUNT = "11111111" +LOG = logging.getLogger(__name__) + +HOOK_TYPE_NAME = "AWS::UnitTest::Hook" +HOOK_TARGET_TYPE_NAME = "AWS::Example::Resource" +OTHER_HOOK_TARGET_TYPE_NAME = "AWS::Other::Resource" + +SCHEMA_ = { + "typeName": HOOK_TYPE_NAME, + "description": "Test Hook", + "typeConfiguration": { + "properties": { + "a": {"type": "number"}, + "b": {"type": "number"}, + "c": {"type": "number"}, + "d": {"type": "number"}, + }, + }, + "additionalProperties": False, +} + +HOOK_CONFIGURATION = '{"CloudFormationConfiguration": {"HookConfiguration": {"Properties": {"key": "value"}}}}' + +HOOK_TARGET_INFO = { + "My::Example::Resource": { + "TargetName": "My::Example::Resource", + "TargetType": "RESOURCE", + "Schema": { + "typeName": "My::Example::Resource", + "additionalProperties": False, + "properties": { + "Id": {"type": "string"}, + "Tags": { + "type": "array", + "uniqueItems": False, + "items": {"$ref": "#/definitions/Tag"}, + }, + }, + "required": [], + "definitions": { + "Tag": { + "type": "object", + "additionalProperties": False, + "properties": { + "Value": {"type": "string"}, + "Key": {"type": "string"}, + }, + "required": ["Value", "Key"], + } + }, + }, + "ProvisioningType": "FULLY_MUTTABLE", + "IsCfnRegistrySupportedType": True, + "SchemaFileAvailable": True, + } +} + + +@pytest.fixture +def hook_client(): + endpoint = "https://" + patch_sesh = patch( + "rpdk.core.contract.hook_client.create_sdk_session", autospec=True + ) + patch_creds = patch( + "rpdk.core.contract.hook_client.get_temporary_credentials", + autospec=True, + 
return_value={}, + ) + patch_account = patch( + "rpdk.core.contract.hook_client.get_account", + autospec=True, + return_value=ACCOUNT, + ) + with patch_sesh as mock_create_sesh, patch_creds as mock_creds: + with patch_account as mock_account: + mock_sesh = mock_create_sesh.return_value + mock_sesh.region_name = DEFAULT_REGION + client = HookClient( + DEFAULT_FUNCTION, endpoint, DEFAULT_REGION, SCHEMA_, EMPTY_OVERRIDE + ) + + mock_sesh.client.assert_called_once_with("lambda", endpoint_url=endpoint) + mock_creds.assert_called_once_with(mock_sesh, LOWER_CAMEL_CRED_KEYS, None) + mock_account.assert_called_once_with(mock_sesh, {}) + assert client._function_name == DEFAULT_FUNCTION + assert client._schema == SCHEMA_ + assert client._configuration_schema == SCHEMA_["typeConfiguration"] + assert client._overrides == EMPTY_OVERRIDE + assert client.account == ACCOUNT + + return client + + +@pytest.fixture +def hook_client_inputs(): + endpoint = "https://" + patch_sesh = patch( + "rpdk.core.contract.hook_client.create_sdk_session", autospec=True + ) + patch_creds = patch( + "rpdk.core.contract.hook_client.get_temporary_credentials", + autospec=True, + return_value={}, + ) + patch_account = patch( + "rpdk.core.contract.hook_client.get_account", + autospec=True, + return_value=ACCOUNT, + ) + with patch_sesh as mock_create_sesh, patch_creds as mock_creds: + with patch_account as mock_account: + mock_sesh = mock_create_sesh.return_value + mock_sesh.region_name = DEFAULT_REGION + client = HookClient( + DEFAULT_FUNCTION, + endpoint, + DEFAULT_REGION, + SCHEMA_, + EMPTY_OVERRIDE, + { + "CREATE_PRE_PROVISION": { + "My::Example::Resource": {"resourceProperties": {"a": 1}} + }, + "UPDATE_PRE_PROVISION": { + "My::Example::Resource": { + "resourceProperties": {"a": 2}, + "previousResourceProperties": {"c": 4}, + } + }, + "INVALID_DELETE_PRE_PROVISION": { + "My::Example::Resource": {"resourceProperties": {"b": 2}} + }, + "INVALID": { + "My::Example::Resource": {"resourceProperties": {"b": 
1}} + }, + }, + target_info=HOOK_TARGET_INFO, + ) + + mock_sesh.client.assert_called_once_with("lambda", endpoint_url=endpoint) + mock_creds.assert_called_once_with(mock_sesh, LOWER_CAMEL_CRED_KEYS, None) + mock_account.assert_called_once_with(mock_sesh, {}) + assert client._function_name == DEFAULT_FUNCTION + assert client._schema == SCHEMA_ + assert client._configuration_schema == SCHEMA_["typeConfiguration"] + assert client._overrides == EMPTY_OVERRIDE + assert client.account == ACCOUNT + + return client + + +def test_init_sam_cli_client(): + patch_sesh = patch( + "rpdk.core.contract.hook_client.create_sdk_session", autospec=True + ) + patch_creds = patch( + "rpdk.core.contract.hook_client.get_temporary_credentials", + autospec=True, + return_value={}, + ) + patch_account = patch( + "rpdk.core.contract.hook_client.get_account", + autospec=True, + return_value=ACCOUNT, + ) + with patch_sesh as mock_create_sesh, patch_creds as mock_creds: + with patch_account as mock_account: + mock_sesh = mock_create_sesh.return_value + mock_sesh.region_name = DEFAULT_REGION + client = HookClient( + DEFAULT_FUNCTION, DEFAULT_ENDPOINT, DEFAULT_REGION, {}, EMPTY_OVERRIDE + ) + + mock_sesh.client.assert_called_once_with( + "lambda", endpoint_url=DEFAULT_ENDPOINT, use_ssl=False, verify=False, config=ANY + ) + mock_creds.assert_called_once_with(mock_sesh, LOWER_CAMEL_CRED_KEYS, None) + mock_account.assert_called_once_with(mock_sesh, {}) + assert client.account == ACCOUNT + + +def test_generate_token(): + token = HookClient.generate_token() + assert isinstance(token, str) + assert len(token) == 36 + + +@pytest.mark.parametrize("hook_type", [None, "Org::Srv::Type"]) +@pytest.mark.parametrize("log_group_name", [None, "random_name"]) +@pytest.mark.parametrize( + "log_creds", + [ + {}, + { + "AccessKeyId": "access", + "SecretAccessKey": "secret", + "SessionToken": "token", + }, + ], +) +def test_make_request(hook_type, log_group_name, log_creds): + target_model = object() + token = 
object() + request = HookClient.make_request( + HOOK_TARGET_TYPE_NAME, + hook_type, + ACCOUNT, + "CREATE_PRE_PROVISION", + {}, + log_group_name, + log_creds, + token, + target_model, + "00000001", + "RESOURCE", + ) + + expected_request = { + "requestData": { + "callerCredentials": json.dumps({}), + "targetName": HOOK_TARGET_TYPE_NAME, + "targetLogicalId": token, + "targetModel": target_model, + "targetType": "RESOURCE", + }, + "requestContext": {"callbackContext": None}, + "hookTypeName": hook_type, + "hookTypeVersion": "00000001", + "clientRequestToken": token, + "stackId": token, + "awsAccountId": ACCOUNT, + "actionInvocationPoint": "CREATE_PRE_PROVISION", + "hookModel": None, + } + if log_group_name and log_creds: + expected_request["requestData"]["providerCredentials"] = json.dumps(log_creds) + expected_request["requestData"]["providerLogGroupName"] = log_group_name + assert request == expected_request + + +def test_get_handler_target(hook_client): + targets = [HOOK_TARGET_TYPE_NAME] + schema = {"handlers": {"preCreate": {"targetNames": targets, "permissions": []}}} + hook_client._update_schema(schema) + + target_names = hook_client.get_handler_targets( + HookInvocationPoint.CREATE_PRE_PROVISION + ) + TestCase().assertCountEqual(target_names, targets) + + +def test_get_handler_target_no_targets(hook_client): + + schema = {"handlers": {"preCreate": {"permissions": []}}} + hook_client._update_schema(schema) + TestCase().assertFalse( + hook_client.get_handler_targets(HookInvocationPoint.CREATE_PRE_PROVISION) + ) + + +def test_make_payload(hook_client): + patch_creds = patch( + "rpdk.core.contract.hook_client.get_temporary_credentials", + autospec=True, + return_value={}, + ) + + token = "ecba020e-b2e6-4742-a7d0-8a06ae7c4b2f" + with patch.object(hook_client, "generate_token", return_value=token), patch_creds: + payload = hook_client._make_payload( + "CREATE_PRE_PROVISION", HOOK_TARGET_TYPE_NAME, {"foo": "bar"} + ) + + assert payload == { + "requestData": { + 
"callerCredentials": json.dumps({}), + "targetName": HOOK_TARGET_TYPE_NAME, + "targetType": "RESOURCE", + "targetLogicalId": token, + "targetModel": {"foo": "bar"}, + }, + "requestContext": {"callbackContext": None}, + "hookTypeName": HOOK_TYPE_NAME, + "hookTypeVersion": "00000001", + "clientRequestToken": token, + "stackId": token, + "awsAccountId": ACCOUNT, + "actionInvocationPoint": "CREATE_PRE_PROVISION", + "hookModel": None, + } + + +def test_generate_request(hook_client): + patch_creds = patch( + "rpdk.core.contract.hook_client.get_temporary_credentials", + autospec=True, + return_value={}, + ) + + token = "ecba020e-b2e6-4742-a7d0-8a06ae7c4b2f" + with patch.object(hook_client, "generate_token", return_value=token), patch_creds: + request = hook_client.generate_request( + HOOK_TARGET_TYPE_NAME, HookInvocationPoint.DELETE_PRE_PROVISION + ) + + assert request == { + "requestData": { + "callerCredentials": json.dumps({}), + "targetName": HOOK_TARGET_TYPE_NAME, + "targetLogicalId": token, + "targetModel": {"resourceProperties": {}}, + "targetType": "RESOURCE", + }, + "requestContext": {"callbackContext": None}, + "hookTypeName": HOOK_TYPE_NAME, + "hookTypeVersion": "00000001", + "clientRequestToken": token, + "stackId": token, + "awsAccountId": ACCOUNT, + "actionInvocationPoint": "DELETE_PRE_PROVISION", + "hookModel": None, + } + + +def test_generate_pre_update_request(hook_client): + patch_creds = patch( + "rpdk.core.contract.hook_client.get_temporary_credentials", + autospec=True, + return_value={}, + ) + + token = "ecba020e-b2e6-4742-a7d0-8a06ae7c4b2f" + with patch.object(hook_client, "generate_token", return_value=token), patch_creds: + request = hook_client.generate_request( + HOOK_TARGET_TYPE_NAME, HookInvocationPoint.UPDATE_PRE_PROVISION + ) + + assert request == { + "requestData": { + "callerCredentials": json.dumps({}), + "targetName": HOOK_TARGET_TYPE_NAME, + "targetType": "RESOURCE", + "targetLogicalId": token, + "targetModel": { + "resourceProperties": 
{}, + "previousResourceProperties": {}, + }, + }, + "requestContext": {"callbackContext": None}, + "hookTypeName": HOOK_TYPE_NAME, + "hookTypeVersion": "00000001", + "clientRequestToken": token, + "stackId": token, + "awsAccountId": ACCOUNT, + "actionInvocationPoint": "UPDATE_PRE_PROVISION", + "hookModel": None, + } + + +def test_generate_request_example(hook_client): + patch_creds = patch( + "rpdk.core.contract.hook_client.get_temporary_credentials", + autospec=True, + return_value={}, + ) + token = "ecba020e-b2e6-4742-a7d0-8a06ae7c4b2f" + with patch.object(hook_client, "generate_token", return_value=token), patch_creds: + ( + invocation_point, + target, + target_model, + ) = hook_client.generate_request_example( + HOOK_TARGET_TYPE_NAME, HookInvocationPoint.CREATE_PRE_PROVISION + ) + assert invocation_point == HookInvocationPoint.CREATE_PRE_PROVISION + assert target == HOOK_TARGET_TYPE_NAME + assert target_model == {"resourceProperties": {}} + assert not target_model.get("previousResourceProperties") + + +def test_generate_pre_update_request_example(hook_client): + patch_creds = patch( + "rpdk.core.contract.hook_client.get_temporary_credentials", + autospec=True, + return_value={}, + ) + token = "ecba020e-b2e6-4742-a7d0-8a06ae7c4b2f" + with patch.object(hook_client, "generate_token", return_value=token), patch_creds: + ( + invocation_point, + target, + target_model, + ) = hook_client.generate_request_example( + HOOK_TARGET_TYPE_NAME, HookInvocationPoint.UPDATE_PRE_PROVISION + ) + assert invocation_point == HookInvocationPoint.UPDATE_PRE_PROVISION + assert target == HOOK_TARGET_TYPE_NAME + assert target_model == {"resourceProperties": {}, "previousResourceProperties": {}} + + +def test_generate_request_examples(hook_client): + patch_creds = patch( + "rpdk.core.contract.hook_client.get_temporary_credentials", + autospec=True, + return_value={}, + ) + token = "ecba020e-b2e6-4742-a7d0-8a06ae7c4b2f" + targets = [HOOK_TARGET_TYPE_NAME, OTHER_HOOK_TARGET_TYPE_NAME] + 
schema = { + "typeName": HOOK_TYPE_NAME, + "handlers": {"preCreate": {"targetNames": targets, "permissions": []}}, + } + hook_client._update_schema(schema) + + with patch.object(hook_client, "generate_token", return_value=token), patch_creds: + examples = hook_client.generate_request_examples( + HookInvocationPoint.CREATE_PRE_PROVISION + ) + assert len(examples) == len(targets) + for i in range(len(examples)): + invoke_point, target, target_model = examples[i] + assert invoke_point == HookInvocationPoint.CREATE_PRE_PROVISION + assert target == targets[i] + assert target_model == {"resourceProperties": {}} + assert not target_model.get("previousResourceProperties") + + +def test_generate_invalid_request_examples(hook_client): + patch_creds = patch( + "rpdk.core.contract.hook_client.get_temporary_credentials", + autospec=True, + return_value={}, + ) + token = "ecba020e-b2e6-4742-a7d0-8a06ae7c4b2f" + targets = [HOOK_TARGET_TYPE_NAME, OTHER_HOOK_TARGET_TYPE_NAME] + schema = { + "typeName": HOOK_TYPE_NAME, + "handlers": {"preCreate": {"targetNames": targets, "permissions": []}}, + } + hook_client._update_schema(schema) + + with patch.object(hook_client, "generate_token", return_value=token), patch_creds: + examples = hook_client.generate_invalid_request_examples( + HookInvocationPoint.CREATE_PRE_PROVISION + ) + assert len(examples) == len(targets) + for i in range(len(examples)): + invoke_point, target, target_model = examples[i] + assert invoke_point == HookInvocationPoint.CREATE_PRE_PROVISION + assert target == targets[i] + assert target_model == {"resourceProperties": {}} + assert not target_model.get("previousResourceProperties") + + +def test_generate_update_request_examples(hook_client): + patch_creds = patch( + "rpdk.core.contract.hook_client.get_temporary_credentials", + autospec=True, + return_value={}, + ) + token = "ecba020e-b2e6-4742-a7d0-8a06ae7c4b2f" + targets = [HOOK_TARGET_TYPE_NAME, OTHER_HOOK_TARGET_TYPE_NAME] + schema = { + "typeName": HOOK_TYPE_NAME, 
+ "handlers": {"preUpdate": {"targetNames": targets, "permissions": []}}, + } + hook_client._update_schema(schema) + + with patch.object(hook_client, "generate_token", return_value=token), patch_creds: + examples = hook_client.generate_request_examples( + HookInvocationPoint.UPDATE_PRE_PROVISION + ) + assert len(examples) == len(targets) + for i in range(len(examples)): + invoke_point, target, target_model = examples[i] + assert invoke_point == HookInvocationPoint.UPDATE_PRE_PROVISION + assert target == targets[i] + assert target_model == { + "resourceProperties": {}, + "previousResourceProperties": {}, + } + + +def test_generate_all_request_examples(hook_client): + patch_creds = patch( + "rpdk.core.contract.hook_client.get_temporary_credentials", + autospec=True, + return_value={}, + ) + token = "ecba020e-b2e6-4742-a7d0-8a06ae7c4b2f" + schema = { + "typeName": HOOK_TYPE_NAME, + "handlers": { + "preCreate": {"targetNames": [HOOK_TARGET_TYPE_NAME], "permissions": []}, + "preUpdate": { + "targetNames": [OTHER_HOOK_TARGET_TYPE_NAME], + "permissions": [], + }, + "preDelete": { + "targetNames": [HOOK_TARGET_TYPE_NAME, OTHER_HOOK_TARGET_TYPE_NAME], + "permissions": [], + }, + }, + } + hook_client._update_schema(schema) + + with patch.object(hook_client, "generate_token", return_value=token), patch_creds: + examples = hook_client.generate_all_request_examples() + + pre_create_examples = examples.get(HookInvocationPoint.CREATE_PRE_PROVISION) + assert pre_create_examples + assert len(pre_create_examples) == 1 + for example in pre_create_examples: + invoke_point, target, target_model = example + assert invoke_point == HookInvocationPoint.CREATE_PRE_PROVISION + assert target == HOOK_TARGET_TYPE_NAME + assert target_model == {"resourceProperties": {}} + assert not target_model.get("previousResourceProperties") + + pre_update_examples = examples.get(HookInvocationPoint.UPDATE_PRE_PROVISION) + assert pre_update_examples + assert len(pre_update_examples) == 1 + for example in 
pre_update_examples: + invoke_point, target, target_model = example + assert invoke_point == HookInvocationPoint.UPDATE_PRE_PROVISION + assert target == OTHER_HOOK_TARGET_TYPE_NAME + assert target_model == { + "resourceProperties": {}, + "previousResourceProperties": {}, + } + + pre_delete_examples = examples.get(HookInvocationPoint.DELETE_PRE_PROVISION) + assert pre_delete_examples + assert len(pre_delete_examples) == 2 + for example in pre_delete_examples: + invoke_point, target, target_model = example + assert invoke_point == HookInvocationPoint.DELETE_PRE_PROVISION + assert target == HOOK_TARGET_TYPE_NAME or target == OTHER_HOOK_TARGET_TYPE_NAME + assert target_model == {"resourceProperties": {}} + assert not target_model.get("previousResourceProperties") + + +@pytest.mark.parametrize( + "invoke_point", + [ + HookInvocationPoint.CREATE_PRE_PROVISION, + HookInvocationPoint.UPDATE_PRE_PROVISION, + HookInvocationPoint.DELETE_PRE_PROVISION, + ], +) +def test_call_sync(hook_client, invoke_point): + patch_creds = patch( + "rpdk.core.contract.hook_client.get_temporary_credentials", + autospec=True, + return_value={}, + ) + + patch_config = patch( + "rpdk.core.contract.hook_client.TypeConfiguration.get_hook_configuration", + return_value={}, + ) + + mock_client = hook_client._client + mock_client.invoke.return_value = {"Payload": StringIO('{"hookStatus": "SUCCESS"}')} + with patch_creds, patch_config: + status, response = hook_client.call( + invoke_point, HOOK_TARGET_TYPE_NAME, {"foo": "bar"} + ) + + assert status == HookStatus.SUCCESS + assert response == {"hookStatus": HookStatus.SUCCESS.value} + + +def test_call_docker(): + patch_sesh = patch( + "rpdk.core.contract.hook_client.create_sdk_session", autospec=True + ) + patch_creds = patch( + "rpdk.core.contract.hook_client.get_temporary_credentials", + autospec=True, + return_value={}, + ) + patch_config = patch( + "rpdk.core.contract.hook_client.TypeConfiguration.get_hook_configuration", + return_value={}, + ) + 
patch_account = patch( + "rpdk.core.contract.hook_client.get_account", + autospec=True, + return_value=ACCOUNT, + ) + patch_docker = patch("rpdk.core.contract.hook_client.docker", autospec=True) + with patch_sesh as mock_create_sesh, patch_docker as mock_docker, patch_creds, patch_config: + with patch_account: + mock_client = mock_docker.from_env.return_value + mock_sesh = mock_create_sesh.return_value + mock_sesh.region_name = DEFAULT_REGION + hook_client = HookClient( + DEFAULT_FUNCTION, + "url", + DEFAULT_REGION, + {}, + EMPTY_OVERRIDE, + docker_image="docker_image", + executable_entrypoint="entrypoint", + ) + hook_client._type_name = HOOK_TYPE_NAME + response_str = ( + "__CFN_HOOK_START_RESPONSE__" + '{"hookStatus": "SUCCESS"}__CFN_HOOK_END_RESPONSE__' + ) + mock_client.containers.run.return_value = str.encode(response_str) + with patch_creds: + status, response = hook_client.call( + "CREATE_PRE_PROVISION", HOOK_TARGET_TYPE_NAME, {"foo": "bar"} + ) + + mock_client.containers.run.assert_called_once() + assert status == HookStatus.SUCCESS + assert response == {"hookStatus": HookStatus.SUCCESS.value} + + +def test_call_docker_executable_entrypoint_null(): + TypeConfiguration.TYPE_CONFIGURATION = {} + patch_sesh = patch( + "rpdk.core.contract.hook_client.create_sdk_session", autospec=True + ) + patch_creds = patch( + "rpdk.core.contract.hook_client.get_temporary_credentials", + autospec=True, + return_value={}, + ) + patch_config = patch( + "rpdk.core.contract.hook_client.TypeConfiguration.get_hook_configuration", + return_value={}, + ) + patch_account = patch( + "rpdk.core.contract.hook_client.get_account", + autospec=True, + return_value=ACCOUNT, + ) + patch_docker = patch("rpdk.core.contract.hook_client.docker", autospec=True) + with patch_sesh as mock_create_sesh, patch_docker, patch_creds, patch_config: + with patch_account: + mock_sesh = mock_create_sesh.return_value + mock_sesh.region_name = DEFAULT_REGION + hook_client = HookClient( + DEFAULT_FUNCTION, + 
"url", + DEFAULT_REGION, + {}, + EMPTY_OVERRIDE, + docker_image="docker_image", + ) + hook_client._type_name = HOOK_TYPE_NAME + + try: + with patch_creds: + hook_client.call( + "CREATE_PRE_PROVISION", HOOK_TARGET_TYPE_NAME, {"foo": "bar"} + ) + except InvalidProjectError: + pass + TypeConfiguration.TYPE_CONFIGURATION = None + + +@pytest.mark.parametrize( + "invoke_point", + [ + HookInvocationPoint.CREATE_PRE_PROVISION, + HookInvocationPoint.UPDATE_PRE_PROVISION, + HookInvocationPoint.DELETE_PRE_PROVISION, + ], +) +def test_call_async(hook_client, invoke_point): + mock_client = hook_client._client + + patch_creds = patch( + "rpdk.core.contract.hook_client.get_temporary_credentials", + autospec=True, + return_value={}, + ) + + patch_config = patch( + "rpdk.core.contract.hook_client.TypeConfiguration.get_hook_configuration", + return_value={}, + ) + + mock_client.invoke.side_effect = [ + {"Payload": StringIO('{"hookStatus": "IN_PROGRESS"}')}, + {"Payload": StringIO('{"hookStatus": "SUCCESS"}')}, + ] + + with patch_creds, patch_config: + status, response = hook_client.call(invoke_point, HOOK_TARGET_TYPE_NAME, {}) + + assert status == HookStatus.SUCCESS + assert response == {"hookStatus": HookStatus.SUCCESS.value} + + +def test_call_and_assert_success(hook_client): + patch_creds = patch( + "rpdk.core.contract.hook_client.get_temporary_credentials", + autospec=True, + return_value={}, + ) + patch_config = patch( + "rpdk.core.contract.hook_client.TypeConfiguration.get_hook_configuration", + return_value={}, + ) + mock_client = hook_client._client + mock_client.invoke.return_value = {"Payload": StringIO('{"hookStatus": "SUCCESS"}')} + with patch_creds, patch_config: + status, response, error_code = hook_client.call_and_assert( + HookInvocationPoint.CREATE_PRE_PROVISION, + HookStatus.SUCCESS, + HOOK_TARGET_TYPE_NAME, + {}, + ) + assert status == HookStatus.SUCCESS + assert response == {"hookStatus": HookStatus.SUCCESS.value} + assert error_code is None + + +def 
test_call_and_assert_failed_invalid_payload(hook_client): + patch_creds = patch( + "rpdk.core.contract.hook_client.get_temporary_credentials", + autospec=True, + return_value={}, + ) + patch_config = patch( + "rpdk.core.contract.hook_client.TypeConfiguration.get_hook_configuration", + return_value={}, + ) + mock_client = hook_client._client + mock_client.invoke.return_value = {"Payload": StringIO("invalid json document")} + with pytest.raises(ValueError), patch_creds, patch_config: + _status, _response, _error_code = hook_client.call_and_assert( + HookInvocationPoint.CREATE_PRE_PROVISION, + HookStatus.SUCCESS, + HOOK_TARGET_TYPE_NAME, + {}, + ) + + +def test_call_and_assert_failed(hook_client): + patch_creds = patch( + "rpdk.core.contract.hook_client.get_temporary_credentials", + autospec=True, + return_value={}, + ) + patch_config = patch( + "rpdk.core.contract.hook_client.TypeConfiguration.get_hook_configuration", + return_value={}, + ) + mock_client = hook_client._client + mock_client.invoke.return_value = { + "Payload": StringIO( + '{"hookStatus": "FAILED","errorCode": "NotFound", "message": "I have failed you"}' + ) + } + with patch_creds, patch_config: + status, response, error_code = hook_client.call_and_assert( + HookInvocationPoint.DELETE_PRE_PROVISION, + HookStatus.FAILED, + HOOK_TARGET_TYPE_NAME, + {}, + ) + assert status == HookStatus.FAILED + assert response == { + "hookStatus": HookStatus.FAILED.value, + "errorCode": "NotFound", + "message": "I have failed you", + } + assert error_code == HandlerErrorCode.NotFound + + +def test_call_and_assert_exception_unsupported_status(hook_client): + mock_client = hook_client._client + mock_client.invoke.return_value = { + "Payload": StringIO('{"hookStatus": "FAILED","errorCode": "NotFound"}') + } + with pytest.raises(ValueError): + hook_client.call_and_assert( + HookInvocationPoint.DELETE_PRE_PROVISION, + "OtherStatus", + HOOK_TARGET_TYPE_NAME, + {}, + ) + + +def 
test_call_and_assert_exception_assertion_mismatch(hook_client): + patch_creds = patch( + "rpdk.core.contract.hook_client.get_temporary_credentials", + autospec=True, + return_value={}, + ) + patch_config = patch( + "rpdk.core.contract.hook_client.TypeConfiguration.get_hook_configuration", + return_value={}, + ) + mock_client = hook_client._client + mock_client.invoke.return_value = {"Payload": StringIO('{"hookStatus": "SUCCESS"}')} + with pytest.raises(AssertionError), patch_creds, patch_config: + hook_client.call_and_assert( + HookInvocationPoint.CREATE_PRE_PROVISION, + HookStatus.FAILED, + HOOK_TARGET_TYPE_NAME, + {}, + ) + + +@pytest.mark.parametrize("status", [HookStatus.SUCCESS, HookStatus.FAILED]) +def test_assert_in_progress_wrong_status(status): + with pytest.raises(AssertionError): + HookClient.assert_in_progress(status, {}) + + +def test_assert_in_progress_error_code_set(): + with pytest.raises(AssertionError): + HookClient.assert_in_progress( + HookStatus.IN_PROGRESS, + {"errorCode": HandlerErrorCode.AccessDenied.value}, + ) + + +def test_assert_in_progress_result_set(): + with pytest.raises(AssertionError): + HookClient.assert_in_progress(HookStatus.IN_PROGRESS, {"result": ""}) + + +def test_assert_in_progress_callback_delay_seconds_unset(): + callback_delay_seconds = HookClient.assert_in_progress( + HookStatus.IN_PROGRESS, {"result": None} + ) + assert callback_delay_seconds == 0 + + +def test_assert_in_progress_callback_delay_seconds_set(): + callback_delay_seconds = HookClient.assert_in_progress( + HookStatus.IN_PROGRESS, {"callbackDelaySeconds": 5} + ) + assert callback_delay_seconds == 5 + + +@pytest.mark.parametrize("status", [HookStatus.IN_PROGRESS, HookStatus.FAILED]) +def test_assert_success_wrong_status(status): + with pytest.raises(AssertionError): + HookClient.assert_success(status, {}) + + +def test_assert_success_error_code_set(): + with pytest.raises(AssertionError): + HookClient.assert_success( + HookStatus.SUCCESS, {"errorCode": 
HandlerErrorCode.AccessDenied.value} + ) + + +def test_assert_success_callback_delay_seconds_set(): + with pytest.raises(AssertionError): + HookClient.assert_success(HookStatus.SUCCESS, {"callbackDelaySeconds": 5}) + + +@pytest.mark.parametrize("status", [HookStatus.IN_PROGRESS, HookStatus.SUCCESS]) +def test_assert_failed_wrong_status(status): + with pytest.raises(AssertionError): + HookClient.assert_failed(status, {}) + + +def test_assert_failed_error_code_unset(): + with pytest.raises(AssertionError): + HookClient.assert_failed(HookStatus.FAILED, {}) + + +def test_assert_failed_error_code_invalid(): + with pytest.raises(KeyError): + HookClient.assert_failed(HookStatus.FAILED, {"errorCode": "XXX"}) + + +def test_assert_failed_callback_delay_seconds_set(): + with pytest.raises(AssertionError): + HookClient.assert_failed( + HookStatus.FAILED, + { + "errorCode": HandlerErrorCode.AccessDenied.value, + "callbackDelaySeconds": 5, + }, + ) + + +def test_assert_failed_returns_error_code(): + error_code = HookClient.assert_failed( + HookStatus.FAILED, + { + "errorCode": HandlerErrorCode.AccessDenied.value, + "message": "I have failed you", + }, + ) + assert error_code == HandlerErrorCode.AccessDenied + + +@pytest.mark.parametrize( + "invoke_point", + [ + HookInvocationPoint.CREATE_PRE_PROVISION, + HookInvocationPoint.UPDATE_PRE_PROVISION, + HookInvocationPoint.DELETE_PRE_PROVISION, + ], +) +def test_assert_time(hook_client, invoke_point): + hook_client.assert_time(time.time() - 59, time.time(), invoke_point) + + +@pytest.mark.parametrize( + "invoke_point", + [ + HookInvocationPoint.CREATE_PRE_PROVISION, + HookInvocationPoint.UPDATE_PRE_PROVISION, + HookInvocationPoint.DELETE_PRE_PROVISION, + ], +) +def test_assert_time_fail(hook_client, invoke_point): + with pytest.raises(AssertionError): + hook_client.assert_time(time.time() - 61, time.time(), invoke_point) + + +@pytest.mark.parametrize( + "invoke_point", + [HookInvocationPoint.UPDATE_PRE_PROVISION], +) +def 
test_is_update_invocation_point_true(invoke_point): + assert HookClient.is_update_invocation_point(invoke_point) + + +@pytest.mark.parametrize( + "invoke_point", + [ + HookInvocationPoint.CREATE_PRE_PROVISION, + HookInvocationPoint.DELETE_PRE_PROVISION, + ], +) +def test_is_update_invocation_point_false(invoke_point): + assert not HookClient.is_update_invocation_point(invoke_point) + + +def test_generate_pre_create_target_model_inputs(hook_client_inputs): + assert hook_client_inputs._generate_target_model( + "My::Example::Resource", "CREATE_PRE_PROVISION" + ) == {"resourceProperties": {"a": 1}} + + +def test_generate_pre_update_target_model_inputs(hook_client_inputs): + assert hook_client_inputs._generate_target_model( + "My::Example::Resource", "UPDATE_PRE_PROVISION" + ) == {"resourceProperties": {"a": 2}, "previousResourceProperties": {"c": 4}} + + +def test_generate_invalid_pre_create_target_model_inputs(hook_client_inputs): + assert hook_client_inputs._generate_target_model( + "My::Example::Resource", "INVALID_CREATE_PRE_PROVISION" + ) == {"resourceProperties": {"b": 1}} + + +def test_generate_invalid_pre_delete_target_model_inputs(hook_client_inputs): + assert hook_client_inputs._generate_target_model( + "My::Example::Resource", "INVALID_DELETE_PRE_PROVISION" + ) == {"resourceProperties": {"b": 2}} + + +def test_generate_invalid_target_model_inputs(hook_client_inputs): + assert hook_client_inputs._generate_target_model( + "My::Example::Resource", "INVALID" + ) == {"resourceProperties": {"b": 1}} diff --git a/tests/contract/test_type_configuration.py b/tests/contract/test_type_configuration.py index 4b38a495..d8795e48 100644 --- a/tests/contract/test_type_configuration.py +++ b/tests/contract/test_type_configuration.py @@ -1,5 +1,7 @@ from unittest.mock import mock_open, patch +import pytest + from rpdk.core.contract.type_configuration import TypeConfiguration from rpdk.core.exceptions import InvalidProjectError @@ -9,6 +11,15 @@ TYPE_CONFIGURATION_INVALID = 
'{"Credentials" :{"ApiKey": "123", xxxx}}' +HOOK_CONFIGURATION_TEST_SETTING = '{"CloudFormationConfiguration": {"HookConfiguration": {"Properties": {"Credentials" :{"ApiKey": "123", "ApplicationKey": "123"}}}}}' + +HOOK_CONFIGURATION_INVALID = '{"CloudFormationConfiguration": {"TypeConfiguration": {"Properties": {"Credentials" :{"ApiKey": "123", "ApplicationKey": "123"}}}}}' + + +def setup_function(): + # Resetting before each test + TypeConfiguration.TYPE_CONFIGURATION = None + def test_get_type_configuration_with_not_exist_file(): with patch("builtins.open", mock_open()) as f: @@ -37,3 +48,32 @@ def test_get_type_configuration_with_invalid_json(): TypeConfiguration.get_type_configuration() except InvalidProjectError: pass + + +@patch("builtins.open", mock_open(read_data=HOOK_CONFIGURATION_TEST_SETTING)) +def test_get_hook_configuration(): + hook_configuration = TypeConfiguration.get_hook_configuration() + assert hook_configuration["Credentials"]["ApiKey"] == "123" + assert hook_configuration["Credentials"]["ApplicationKey"] == "123" + + # get type config again, should be the same config + hook_configuration = TypeConfiguration.get_hook_configuration() + assert hook_configuration["Credentials"]["ApiKey"] == "123" + assert hook_configuration["Credentials"]["ApplicationKey"] == "123" + + +@patch("builtins.open", mock_open(read_data=HOOK_CONFIGURATION_INVALID)) +def test_get_hook_configuration_with_invalid_json(): + with pytest.raises(InvalidProjectError) as execinfo: + TypeConfiguration.get_hook_configuration() + + assert "Hook configuration is invalid" in str(execinfo.value) + + +def test_get_hook_configuration_with_not_exist_file(): + with patch("builtins.open", mock_open()) as f: + f.side_effect = FileNotFoundError() + try: + TypeConfiguration.get_hook_configuration() + except FileNotFoundError: + pass diff --git a/tests/hook/test_init_hook.py b/tests/hook/test_init_hook.py new file mode 100644 index 00000000..85bdb289 --- /dev/null +++ 
b/tests/hook/test_init_hook.py @@ -0,0 +1,84 @@ +from unittest.mock import patch + +import pytest + +from rpdk.core.exceptions import WizardAbortError, WizardValidationError +from rpdk.core.hook.init_hook import ( + ValidatePluginChoice, + input_language, + input_typename, + validate_type_name, +) +from tests.test_init import PROMPT + + +def test_input_typename(): + type_name = "AWS::CFN::HOOK" + patch_input = patch( + "rpdk.core.hook.init_hook.input_with_validation", return_value=type_name + ) + with patch_input as mock_input: + assert input_typename() == type_name + mock_input.assert_called_once() + + +def test_input_language_no_plugins(): + validator = ValidatePluginChoice([]) + with patch("rpdk.core.hook.init_hook.validate_plugin_choice", validator): + with pytest.raises(WizardAbortError): + input_language() + + +def test_input_language_one_plugin(): + validator = ValidatePluginChoice([PROMPT]) + with patch("rpdk.core.hook.init_hook.validate_plugin_choice", validator): + assert input_language() == PROMPT + + +def test_input_language_several_plugins(): + validator = ValidatePluginChoice(["1", PROMPT, "2"]) + patch_validator = patch( + "rpdk.core.hook.init_hook.validate_plugin_choice", validator + ) + patch_input = patch("rpdk.core.utils.init_utils.input", return_value="2") + with patch_validator, patch_input as mock_input: + assert input_language() == PROMPT + + mock_input.assert_called_once() + + +def test_validate_plugin_choice_not_an_int(): + validator = ValidatePluginChoice(["test"]) + with pytest.raises(WizardValidationError) as excinfo: + validator("a") + assert "integer" in str(excinfo.value) + + +def test_validate_plugin_choice_less_than_zero(): + validator = ValidatePluginChoice(["test"]) + with pytest.raises(WizardValidationError) as excinfo: + validator("-1") + assert "select" in str(excinfo.value) + + +def test_validate_plugin_choice_greater_than_choice(): + choices = range(3) + validator = ValidatePluginChoice(choices) + with 
pytest.raises(WizardValidationError) as excinfo: + validator(str(len(choices) + 1)) # index is 1 based for input + assert "select" in str(excinfo.value) + + +def test_validate_plugin_choice_valid(): + choices = ["1", PROMPT, "2"] + validator = ValidatePluginChoice(choices) + assert validator("2") == PROMPT + + +def test_validate_type_name_invalid(): + with pytest.raises(WizardValidationError): + validate_type_name("AWS-CFN-HOOK") + + +def test_validate_type_name_valid(): + assert validate_type_name("AWS::CFN::HOOK") == "AWS::CFN::HOOK" diff --git a/tests/test_data_loaders.py b/tests/test_data_loaders.py index 44e9188d..23a2e83c 100644 --- a/tests/test_data_loaders.py +++ b/tests/test_data_loaders.py @@ -17,6 +17,7 @@ STDIN_NAME, get_file_base_uri, get_schema_store, + load_hook_spec, load_resource_spec, resource_json, resource_stream, @@ -40,6 +41,16 @@ "additionalProperties": False, } +HOOK_BASIC_SCHEMA = { + "typeName": "AWS::FOO::BAR", + "description": "test schema", + "typeConfiguration": { + "properties": {"foo": {"type": "string"}}, + "additionalProperties": False, + }, + "additionalProperties": False, +} + def json_s(obj): return StringIO(json.dumps(obj)) @@ -78,6 +89,14 @@ def test_load_resource_spec_empty_object_is_invalid(): load_resource_spec(json_s({})) +def test_load_hook_spec_invalid_json(): + with pytest.raises(SpecValidationError) as excinfo: + load_hook_spec(StringIO('{"foo": "aaaaa}')) + + assert "line 1" in str(excinfo.value) + assert "column 9" in str(excinfo.value) + + def json_files_params(path, glob="*.json"): return tuple(pytest.param(p, id=p.name) for p in path.glob(glob)) @@ -207,6 +226,47 @@ def test_load_resource_spec_remote_key_is_invalid(): assert "remote" in str(excinfo.value) +@pytest.mark.parametrize( + "permission", ("cloudformation:RegisterType", "cloudformation:*") +) +def test_load_hook_spec_hook_permissions_invalid(permission): + schema = { + "typeName": "AWS::FOO::BAR", + "description": "test schema", + "typeConfiguration": { + 
"properties": {"foo": {"type": "string"}}, + "additionalProperties": False, + }, + "handlers": { + "preCreate": {"targetNames": ["AWS::BAZ::ZAZ"], "permissions": [permission]} + }, + "additionalProperties": False, + } + with pytest.raises(SpecValidationError) as excinfo: + load_hook_spec(json_s(schema)) + assert "not allowed for hook handler permissions" in str(excinfo.value) + + +def test_load_hook_spec_hook_permissions_valid(): + schema = { + "typeName": "AWS::FOO::BAR", + "description": "test schema", + "typeConfiguration": { + "properties": {"foo": {"type": "string"}}, + "additionalProperties": False, + }, + "handlers": { + "preDelete": { + "targetNames": ["AWS::S3::Bucket"], + "permissions": ["s3:GetObject"], + } + }, + "additionalProperties": False, + } + result = load_hook_spec(json_s(schema)) + assert result == schema + + def test_argparse_stdin_name(): """By default, pytest messes with stdin and stdout, which prevents me from writing a test to check we have the right magic name that argparse uses @@ -297,6 +357,18 @@ def test_load_resource_spec_invalid_ref(): assert "bar" in str(cause) +def test_load_hook_spec_invalid_ref(): + copy = json.loads(json.dumps(HOOK_BASIC_SCHEMA)) + copy["typeConfiguration"]["properties"]["foo"] = {"$ref": "#/bar"} + with pytest.raises(SpecValidationError) as excinfo: + load_hook_spec(json_s(copy)) + + cause = excinfo.value.__cause__ + assert cause + assert isinstance(cause, RefResolutionError) + assert "bar" in str(cause) + + @pytest.fixture def plugin(): mock_plugin = create_autospec(LanguagePlugin) @@ -367,7 +439,7 @@ def test_get_schema_store_schemas_with_id(): schema_store = get_schema_store( BASEDIR.parent / "src" / "rpdk" / "core" / "data" / "schema" ) - assert len(schema_store) == 5 + assert len(schema_store) == 7 assert "http://json-schema.org/draft-07/schema#" in schema_store assert ( "https://schema.cloudformation.us-east-1.amazonaws.com/base.definition.schema.v1.json" @@ -381,6 +453,14 @@ def 
test_get_schema_store_schemas_with_id(): "https://schema.cloudformation.us-east-1.amazonaws.com/provider.definition.schema.v1.json" in schema_store ) + assert ( + "https://schema.cloudformation.us-east-1.amazonaws.com/provider.definition.schema.hooks.v1.json" + in schema_store + ) + assert ( + "https://schema.cloudformation.us-east-1.amazonaws.com/provider.configuration.definition.schema.hooks.v1.json" + in schema_store + ) def test_get_schema_store_schemas_with_out_id(): diff --git a/tests/test_generate.py b/tests/test_generate.py index f10ad1f7..963f089a 100644 --- a/tests/test_generate.py +++ b/tests/test_generate.py @@ -12,7 +12,38 @@ def test_generate_command_generate(capsys): main(args_in=["generate"]) mock_project.load.assert_called_once_with() - mock_project.generate.assert_called_once_with() + mock_project.generate.assert_called_once_with(None, None, []) + mock_project.generate_docs.assert_called_once_with() + + out, err = capsys.readouterr() + assert not err + assert "foo" in out + + +def test_generate_command_generate_with_args(capsys): + mock_project = Mock(spec=Project) + mock_project.type_name = "foo" + + with patch("rpdk.core.generate.Project", autospec=True, return_value=mock_project): + main( + args_in=[ + "generate", + "--endpoint-url", + "http://localhost/3001", + "--region", + "us-east-1", + "--target-schemas", + "/files/target-schema.json", + "/files/other-target-schema", + ] + ) + + mock_project.load.assert_called_once_with() + mock_project.generate.assert_called_once_with( + "http://localhost/3001", + "us-east-1", + ["/files/target-schema.json", "/files/other-target-schema"], + ) mock_project.generate_docs.assert_called_once_with() out, err = capsys.readouterr() diff --git a/tests/test_init.py b/tests/test_init.py index 7658f553..133d10c5 100644 --- a/tests/test_init.py +++ b/tests/test_init.py @@ -6,6 +6,7 @@ from rpdk.core.cli import main from rpdk.core.exceptions import WizardAbortError, WizardValidationError +from rpdk.core.hook.init_hook 
import input_typename as input_typename_hook from rpdk.core.init import ( ValidatePluginChoice, check_for_existing_project, @@ -56,6 +57,9 @@ def test_init_resource_method_interactive(): "force": False, "type_name": None, "artifact_type": None, + "endpoint_url": None, + "region": None, + "target_schemas": [], }, ) mock_project.generate.assert_called_once_with() @@ -89,6 +93,41 @@ def test_init_module_method_interactive(): mock_project.generate.assert_not_called() +def test_init_hook_method_interactive(): + type_name = object() + language = object() + + mock_project, patch_project = get_mock_project() + patch_tn = patch("rpdk.core.hook.init_hook.input_typename", return_value=type_name) + patch_l = patch("rpdk.core.hook.init_hook.input_language", return_value=language) + patch_at = patch("rpdk.core.init.init_artifact_type", return_value="HOOK") + + with patch_project, patch_at as mock_t, patch_tn as mock_tn, patch_l as mock_l: + main(args_in=["init"]) + + mock_tn.assert_called_once_with() + mock_l.assert_called_once_with() + mock_t.assert_called_once() + + mock_project.load_settings.assert_called_once_with() + mock_project.init_hook.assert_called_once_with( + type_name, + language, + { + "version": False, + "subparser_name": None, + "verbose": 0, + "force": False, + "type_name": None, + "artifact_type": None, + "endpoint_url": None, + "region": None, + "target_schemas": [], + }, + ) + mock_project.generate.assert_called_once_with(None, None, []) + + def test_init_resource_method_noninteractive(): add_dummy_language_plugin() artifact_type = "RESOURCE" @@ -127,11 +166,60 @@ def test_init_resource_method_noninteractive(): "language": args.language, "dummy": True, "artifact_type": artifact_type, + "endpoint_url": None, + "region": None, + "target_schemas": [], }, ) mock_project.generate.assert_called_once_with() +def test_init_hook_method_noninteractive(): + add_dummy_language_plugin() + artifact_type = "HOOK" + args = get_args("dummy", "Test::Test::Test", artifact_type) 
+ mock_project, patch_project = get_mock_project() + + patch_get_parser = patch( + "rpdk.core.init.get_parsers", return_value={"dummy": dummy_parser} + ) + + with patch_project, patch_get_parser as mock_parser: + main( + args_in=[ + "init", + "--type-name", + args.type_name, + "--artifact-type", + args.artifact_type, + args.language, + "--dummy", + ] + ) + + mock_parser.assert_called_once() + + mock_project.load_settings.assert_called_once_with() + mock_project.init_hook.assert_called_once_with( + args.type_name, + args.language, + { + "version": False, + "subparser_name": args.language, + "verbose": 0, + "force": False, + "type_name": args.type_name, + "language": args.language, + "dummy": True, + "artifact_type": artifact_type, + "endpoint_url": None, + "region": None, + "target_schemas": [], + }, + ) + mock_project.generate.assert_called_once_with(None, None, []) + + def test_init_resource_method_noninteractive_invalid_type_name(): add_dummy_language_plugin() type_name = object() @@ -177,11 +265,118 @@ def test_init_resource_method_noninteractive_invalid_type_name(): "artifact_type": artifact_type, "dummy": True, "language": args.language, + "endpoint_url": None, + "region": None, + "target_schemas": [], }, ) mock_project.generate.assert_called_once_with() +def test_init_hook_method_noninteractive_invalid_type_name(): + add_dummy_language_plugin() + type_name = object() + artifact_type = "HOOK" + + args = get_args("dummy", "invalid_type_name", "HOOK") + mock_project, patch_project = get_mock_project() + + patch_tn = patch("rpdk.core.hook.init_hook.input_typename", return_value=type_name) + patch_t = patch("rpdk.core.init.init_artifact_type", return_value=artifact_type) + patch_get_parser = patch( + "rpdk.core.init.get_parsers", return_value={"dummy": dummy_parser} + ) + + with patch_project, patch_t, patch_tn as mock_tn, patch_get_parser as mock_parser: + main( + args_in=[ + "init", + "-t", + args.type_name, + "-a", + args.artifact_type, + args.language, + 
"--dummy", + ] + ) + + mock_tn.assert_called_once_with() + mock_parser.assert_called_once() + + mock_project.load_settings.assert_called_once_with() + mock_project.init_hook.assert_called_once_with( + type_name, + args.language, + { + "version": False, + "subparser_name": args.language, + "verbose": 0, + "force": False, + "type_name": args.type_name, + "artifact_type": artifact_type, + "dummy": True, + "language": args.language, + "endpoint_url": None, + "region": None, + "target_schemas": [], + }, + ) + mock_project.generate.assert_called_once_with(None, None, []) + + +def test_init_hook_method_noninteractive_target_schemas(): + add_dummy_language_plugin() + artifact_type = "HOOK" + args = get_args("dummy", "Test::Test::Test", artifact_type) + mock_project, patch_project = get_mock_project() + + patch_get_parser = patch( + "rpdk.core.init.get_parsers", return_value={"dummy": dummy_parser} + ) + + with patch_project, patch_get_parser as mock_parser: + main( + args_in=[ + "init", + "--type-name", + args.type_name, + "--artifact-type", + args.artifact_type, + "--target-schemas", + "/files/target-schema.json,/files/other-target-schema.json", + args.language, + "--dummy", + ] + ) + + mock_parser.assert_called_once() + + mock_project.load_settings.assert_called_once_with() + mock_project.init_hook.assert_called_once_with( + args.type_name, + args.language, + { + "version": False, + "subparser_name": args.language, + "verbose": 0, + "force": False, + "type_name": args.type_name, + "language": args.language, + "dummy": True, + "artifact_type": artifact_type, + "endpoint_url": None, + "region": None, + "target_schemas": [ + "/files/target-schema.json", + "/files/other-target-schema.json", + ], + }, + ) + mock_project.generate.assert_called_once_with( + None, None, ["/files/target-schema.json", "/files/other-target-schema.json"] + ) + + def test_input_with_validation_valid_first_try(capsys): sentinel1 = object() sentinel2 = object() @@ -398,6 +593,14 @@ def 
test_input_typename_module(): mock_input.assert_called_once() +def test_input_typename_hook(): + type_name = "AWS::CFN::HOOK" + patch_input = patch("rpdk.core.utils.init_utils.input", return_value=type_name) + with patch_input as mock_input: + assert input_typename_hook() == type_name + mock_input.assert_called_once() + + def test_input_language_no_plugins(): validator = ValidatePluginChoice([]) with patch("rpdk.core.init.validate_plugin_choice", validator): diff --git a/tests/test_invoke.py b/tests/test_invoke.py index 16113607..2b2133ca 100644 --- a/tests/test_invoke.py +++ b/tests/test_invoke.py @@ -6,16 +6,83 @@ import pytest -from rpdk.core.cli import main -from rpdk.core.contract.interface import Action +from rpdk.core.cli import EXIT_UNHANDLED_EXCEPTION, main +from rpdk.core.contract.interface import Action, HookInvocationPoint from rpdk.core.invoke import _needs_reinvocation -from rpdk.core.project import Project +from rpdk.core.project import ARTIFACT_TYPE_HOOK, ARTIFACT_TYPE_RESOURCE, Project ACTIONS = list(Action.__members__) +HOOK_INVOCATION_POINTS = list(HookInvocationPoint.__members__) + + +def _setup_resource_test(): + mock_project = Mock(spec=Project) + mock_project.schema = {} + mock_project.root = None + mock_project.executable_entrypoint = None + mock_project.artifact_type = ARTIFACT_TYPE_RESOURCE + + patch_project = patch( + "rpdk.core.invoke.Project", autospec=True, return_value=mock_project + ) + patch_session = patch( + "rpdk.core.contract.resource_client.create_sdk_session", autospec=True + ) + patch_creds = patch( + "rpdk.core.contract.resource_client.get_temporary_credentials", + autospec=True, + return_value={}, + ) + patch_account = patch( + "rpdk.core.contract.resource_client.get_account", + autospec=True, + return_value="", + ) + + return mock_project, patch_project, patch_session, patch_creds, patch_account + + +def _setup_hook_test(): + mock_project = Mock(spec=Project) + mock_project.schema = {} + mock_project.root = None + 
mock_project.executable_entrypoint = None + mock_project.artifact_type = ARTIFACT_TYPE_HOOK + + patch_project = patch( + "rpdk.core.invoke.Project", autospec=True, return_value=mock_project + ) + patch_session = patch( + "rpdk.core.contract.hook_client.create_sdk_session", autospec=True + ) + patch_creds = patch( + "rpdk.core.contract.hook_client.get_temporary_credentials", + autospec=True, + return_value="{}", + ) + patch_account = patch( + "rpdk.core.contract.hook_client.get_account", + autospec=True, + return_value="", + ) + patch_type_name = patch( + "rpdk.core.contract.hook_client.HookClient.get_hook_type_name", + autospec=True, + return_value="AWS::Testing::Hook", + ) + + return ( + mock_project, + patch_project, + patch_session, + patch_creds, + patch_account, + patch_type_name, + ) @pytest.fixture -def payload_path(tmp_path): +def resource_payload_path(tmp_path): path = tmp_path / "payload.json" with path.open("w", encoding="utf-8") as f: json.dump( @@ -29,6 +96,21 @@ def payload_path(tmp_path): return path +@pytest.fixture +def hook_payload_path(tmp_path): + path = tmp_path / "payload.json" + with path.open("w", encoding="utf-8") as f: + json.dump( + { + "targetName": "AWS::Testing::Resource", + "targetModel": {"foo": "bar"}, + "logicalResourceIdentifier": None, + }, + f, + ) + return path + + @pytest.fixture def invalid_payload(tmp_path): path = tmp_path / "payload.json" @@ -37,9 +119,33 @@ def invalid_payload(tmp_path): return path +@pytest.mark.parametrize("command", ["invalid"]) +def test_command_with_invalid_subcommand(capsys, command): + with patch("rpdk.core.invoke.invoke", autospec=True) as mock_func: + with pytest.raises(SystemExit) as excinfo: + main(args_in=["invoke", command]) + assert excinfo.value.code != EXIT_UNHANDLED_EXCEPTION + _, err = capsys.readouterr() + assert "invalid choice:" in err + mock_func.assert_not_called() + + +@pytest.mark.parametrize("command", ["resource", "hook"]) +def test_subcommand_with_required_params(capsys, 
command): + with patch("rpdk.core.invoke.invoke", autospec=True) as mock_func: + with pytest.raises(SystemExit) as excinfo: + main(args_in=["invoke", command]) + assert excinfo.value.code != EXIT_UNHANDLED_EXCEPTION + _, err = capsys.readouterr() + assert "the following arguments are required" in err + mock_func.assert_not_called() + + @pytest.mark.parametrize("command", ACTIONS) -def test_invoke_command_happy_path(capsys, payload_path, command): - mock_project, mock_invoke = _invoke_and_expect("SUCCESS", payload_path, command) +def test_invoke_command_happy_path_resource(capsys, resource_payload_path, command): + mock_project, mock_invoke = _invoke_and_expect_resource( + "SUCCESS", resource_payload_path, command + ) mock_project.load.assert_called_once_with() mock_invoke.assert_called_once() @@ -47,9 +153,11 @@ def test_invoke_command_happy_path(capsys, payload_path, command): assert not err -@pytest.mark.parametrize("command", ACTIONS) -def test_invoke_command_sad_path(capsys, payload_path, command): - mock_project, mock_invoke = _invoke_and_expect("FAILED", payload_path, command) +@pytest.mark.parametrize("command", HOOK_INVOCATION_POINTS) +def test_invoke_command_happy_path_hook(capsys, hook_payload_path, command): + mock_project, mock_invoke = _invoke_and_expect_hook( + "SUCCESS", hook_payload_path, command + ) mock_project.load.assert_called_once_with() mock_invoke.assert_called_once() @@ -58,22 +166,21 @@ def test_invoke_command_sad_path(capsys, payload_path, command): @pytest.mark.parametrize("command", ACTIONS) -def test_invoke_command_in_progress_with_reinvoke(capsys, payload_path, command): - mock_project, mock_invoke = _invoke_and_expect( - "IN_PROGRESS", payload_path, command, "--max-reinvoke", "2" +def test_invoke_command_sad_path_resource(capsys, resource_payload_path, command): + mock_project, mock_invoke = _invoke_and_expect_resource( + "FAILED", resource_payload_path, command ) - assert mock_invoke.call_count == 3 - 
mock_project.load.assert_called_once_with() + mock_invoke.assert_called_once() _out, err = capsys.readouterr() assert not err -@pytest.mark.parametrize("command", ACTIONS) -def test_invoke_command_in_progress_with_no_reinvocation(capsys, payload_path, command): - mock_project, mock_invoke = _invoke_and_expect( - "IN_PROGRESS", payload_path, command, "--max-reinvoke", "0" +@pytest.mark.parametrize("command", HOOK_INVOCATION_POINTS) +def test_invoke_command_sad_path_hook(capsys, hook_payload_path, command): + mock_project, mock_invoke = _invoke_and_expect_hook( + "FAILED", hook_payload_path, command ) mock_project.load.assert_called_once_with() @@ -83,60 +190,110 @@ def test_invoke_command_in_progress_with_no_reinvocation(capsys, payload_path, c @pytest.mark.parametrize("command", ACTIONS) -def test_value_error_on_json_load(capsys, invalid_payload, command): - mock_project = Mock(spec=Project) - mock_project.schema = {} - mock_project.root = None - mock_project.executable_entrypoint = None - - patch_project = patch( - "rpdk.core.invoke.Project", autospec=True, return_value=mock_project +def test_invoke_command_in_progress_with_reinvoke_resource( + capsys, resource_payload_path, command +): + mock_project, mock_invoke = _invoke_and_expect_resource( + "IN_PROGRESS", resource_payload_path, command, "--max-reinvoke", "2" ) - patch_session = patch( - "rpdk.core.contract.resource_client.create_sdk_session", autospec=True + + assert mock_invoke.call_count == 3 + + mock_project.load.assert_called_once_with() + _out, err = capsys.readouterr() + assert not err + + +@pytest.mark.parametrize("command", HOOK_INVOCATION_POINTS) +def test_invoke_command_in_progress_with_reinvoke_hook( + capsys, hook_payload_path, command +): + mock_project, mock_invoke = _invoke_and_expect_hook( + "IN_PROGRESS", hook_payload_path, command, "--max-reinvoke", "2" ) - patch_creds = patch( - "rpdk.core.contract.resource_client.get_temporary_credentials", - autospec=True, - return_value={}, + + assert 
mock_invoke.call_count == 3 + + mock_project.load.assert_called_once_with() + _out, err = capsys.readouterr() + assert not err + + +@pytest.mark.parametrize("command", ACTIONS) +def test_invoke_command_in_progress_with_no_reinvocation_resource( + capsys, resource_payload_path, command +): + mock_project, mock_invoke = _invoke_and_expect_resource( + "IN_PROGRESS", resource_payload_path, command, "--max-reinvoke", "0" ) - patch_account = patch( - "rpdk.core.contract.resource_client.get_account", - autospec=True, - return_value="", + + mock_project.load.assert_called_once_with() + mock_invoke.assert_called_once() + _out, err = capsys.readouterr() + assert not err + + +@pytest.mark.parametrize("command", HOOK_INVOCATION_POINTS) +def test_invoke_command_in_progress_with_no_reinvocation_hook( + capsys, hook_payload_path, command +): + mock_project, mock_invoke = _invoke_and_expect_hook( + "IN_PROGRESS", hook_payload_path, command, "--max-reinvoke", "0" ) + mock_project.load.assert_called_once_with() + mock_invoke.assert_called_once() + _out, err = capsys.readouterr() + assert not err + + +@pytest.mark.parametrize("command", ACTIONS) +def test_value_error_on_json_load_resource(capsys, invalid_payload, command): + ( + _mock_project, + patch_project, + patch_session, + patch_creds, + patch_account, + ) = _setup_resource_test() + with patch_project, patch_session, patch_creds, patch_account: with pytest.raises(SystemExit): - main(args_in=["invoke", command, str(invalid_payload)]) + main(args_in=["invoke", "resource", command, str(invalid_payload)]) + + out, _err = capsys.readouterr() + assert "Invalid JSON" in out + + +@pytest.mark.parametrize("command", HOOK_INVOCATION_POINTS) +def test_value_error_on_json_load_hook(capsys, invalid_payload, command): + ( + _mock_project, + patch_project, + patch_session, + patch_creds, + patch_account, + patch_type_name, + ) = _setup_hook_test() + + with patch_project, patch_session, patch_creds, patch_account, patch_type_name: + with 
pytest.raises(SystemExit): + main(args_in=["invoke", "hook", command, str(invalid_payload)]) out, _err = capsys.readouterr() assert "Invalid JSON" in out @pytest.mark.parametrize("command", ACTIONS) -def test_keyboard_interrupt(capsys, payload_path, command): - mock_project = Mock(spec=Project) - mock_project.schema = {} - mock_project.root = None - mock_project.executable_entrypoint = None +def test_keyboard_interrupt_resource(capsys, resource_payload_path, command): + ( + mock_project, + patch_project, + patch_session, + patch_creds, + patch_account, + ) = _setup_resource_test() - patch_project = patch( - "rpdk.core.invoke.Project", autospec=True, return_value=mock_project - ) - patch_session = patch( - "rpdk.core.contract.resource_client.create_sdk_session", autospec=True - ) - patch_creds = patch( - "rpdk.core.contract.resource_client.get_temporary_credentials", - autospec=True, - return_value={}, - ) - patch_account = patch( - "rpdk.core.contract.resource_client.get_account", - autospec=True, - return_value="", - ) patch_dumps = patch.object(json, "dumps", side_effect=KeyboardInterrupt) # fmt: off @@ -146,7 +303,36 @@ def test_keyboard_interrupt(capsys, payload_path, command): patch_account, \ patch_session as mock_session: mock_client = mock_session.return_value.client.return_value - main(args_in=["invoke", command, str(payload_path)]) + main(args_in=["invoke", "resource", command, str(resource_payload_path)]) + # fmt: on + + mock_project.load.assert_called_once_with() + mock_client.invoke.assert_not_called() + _out, err = capsys.readouterr() + assert not err + + +@pytest.mark.parametrize("command", HOOK_INVOCATION_POINTS) +def test_keyboard_interrupt_hook(capsys, hook_payload_path, command): + ( + mock_project, + patch_project, + patch_session, + patch_creds, + patch_account, + patch_type_name, + ) = _setup_hook_test() + patch_dumps = patch.object(json, "dumps", side_effect=KeyboardInterrupt) + + # fmt: off + with patch_project, \ + patch_creds, \ + 
patch_dumps, \ + patch_account, \ + patch_type_name, \ + patch_session as mock_session: + mock_client = mock_session.return_value.client.return_value + main(args_in=["invoke", "hook", command, str(hook_payload_path)]) # fmt: on mock_project.load.assert_called_once_with() @@ -165,39 +351,52 @@ def test_needs_reinvocation(max_reinvoke, current_invocation, result): assert _needs_reinvocation(max_reinvoke, current_invocation) is result -def _invoke_and_expect(status, payload_path, command, *args): - mock_project = Mock(spec=Project) - mock_project.schema = {} - mock_project.root = None - mock_project.executable_entrypoint = None +def _invoke_and_expect_resource(status, resource_payload_path, command, *args): + ( + mock_project, + patch_project, + patch_session, + patch_creds, + patch_account, + ) = _setup_resource_test() - patch_project = patch( - "rpdk.core.invoke.Project", autospec=True, return_value=mock_project - ) - patch_session = patch( - "rpdk.core.contract.resource_client.create_sdk_session", autospec=True - ) - patch_creds = patch( - "rpdk.core.contract.resource_client.get_temporary_credentials", - autospec=True, - return_value={}, - ) - patch_account = patch( - "rpdk.core.contract.resource_client.get_account", - autospec=True, - return_value="", - ) + # fmt: off + with patch_project, \ + patch_account, \ + patch_session as mock_session, \ + patch_creds as mock_creds: + mock_client = mock_session.return_value.client.return_value + mock_client.invoke.side_effect = lambda **_kwargs: { + "Payload": StringIO(json.dumps({"status": status})) + } + main(args_in=["invoke", "resource", command, str(resource_payload_path), *args]) + # fmt: on + mock_creds.assert_called() + + return mock_project, mock_client.invoke + + +def _invoke_and_expect_hook(status, hook_payload_path, command, *args): + ( + mock_project, + patch_project, + patch_session, + patch_creds, + patch_account, + patch_type_name, + ) = _setup_hook_test() # fmt: off with patch_project, \ patch_account, \ + 
patch_type_name, \ patch_session as mock_session, \ patch_creds as mock_creds: mock_client = mock_session.return_value.client.return_value mock_client.invoke.side_effect = lambda **_kwargs: { "Payload": StringIO(json.dumps({"status": status})) } - main(args_in=["invoke", command, str(payload_path), *args]) + main(args_in=["invoke", "hook", command, str(hook_payload_path), *args]) # fmt: on mock_creds.assert_called() diff --git a/tests/test_project.py b/tests/test_project.py index 315dd823..f262a3fa 100644 --- a/tests/test_project.py +++ b/tests/test_project.py @@ -11,6 +11,7 @@ from io import StringIO from pathlib import Path from shutil import copyfile +from unittest import TestCase from unittest.mock import ANY, MagicMock, call, patch import pytest @@ -33,19 +34,23 @@ OVERRIDES_FILENAME, SCHEMA_UPLOAD_FILENAME, SETTINGS_FILENAME, + TARGET_INFO_FILENAME, Project, escape_markdown, ) -from rpdk.core.test import empty_override +from rpdk.core.test import empty_hook_override, empty_override +from rpdk.core.type_schema_loader import TypeSchemaLoader from rpdk.core.upload import Uploader from .utils import CONTENTS_UTF8, UnclosingBytesIO ARTIFACT_TYPE_RESOURCE = "RESOURCE" ARTIFACT_TYPE_MODULE = "MODULE" +ARTIFACT_TYPE_HOOK = "HOOK" LANGUAGE = "BQHDBC" TYPE_NAME = "AWS::Color::Red" MODULE_TYPE_NAME = "AWS::Color::Red::MODULE" +HOOK_TYPE_NAME = "AWS::CFN::HOOK" REGION = "us-east-1" ENDPOINT = "cloudformation.beta.com" RUNTIME = random.choice(list(LAMBDA_RUNTIMES)) @@ -69,6 +74,9 @@ CREATE_INPUTS_FILE = "inputs/inputs_1_create.json" UPDATE_INPUTS_FILE = "inputs/inputs_1_update.json" INVALID_INPUTS_FILE = "inputs/inputs_1_invalid.json" +PRE_CREATE_INPUTS_FILE = "inputs/inputs_1_pre_create.json" +PRE_UPDATE_INPUTS_FILE = "inputs/inputs_1_pre_update.json" +INVALID_PRE_DELETE_INPUTS_FILE = "inputs/inputs_1_invalid_pre_delete.json" PLUGIN_INFORMATION = { "plugin-version": "2.1.3", @@ -126,6 +134,13 @@ def test_load_settings_invalid_modules_settings(project): 
mock_open.assert_called_once_with("r", encoding="utf-8") +def test_load_settings_invalid_hooks_settings(project): + with patch_settings(project, '{"artifact_type": "HOOK"}') as mock_open: + with pytest.raises(InvalidProjectError): + project.load_settings() + mock_open.assert_called_once_with("r", encoding="utf-8") + + def test_load_settings_valid_json_for_resource(project): plugin = object() data = json.dumps( @@ -216,11 +231,45 @@ def test_generate_for_modules_succeeds(project): project.generate_docs() +def test_load_settings_valid_json_for_hook(project): + plugin = object() + data = json.dumps( + { + "artifact_type": "HOOK", + "typeName": HOOK_TYPE_NAME, + "language": LANGUAGE, + "runtime": RUNTIME, + "entrypoint": None, + "testEntrypoint": None, + } + ) + patch_load = patch( + "rpdk.core.project.load_plugin", autospec=True, return_value=plugin + ) + + with patch_settings(project, data) as mock_open, patch_load as mock_load: + project.load_settings() + + mock_open.assert_called_once_with("r", encoding="utf-8") + mock_load.assert_called_once_with(LANGUAGE) + assert project.type_info == ("AWS", "CFN", "HOOK") + assert project.type_name == HOOK_TYPE_NAME + assert project.language == LANGUAGE + assert project.artifact_type == ARTIFACT_TYPE_HOOK + assert project._plugin is plugin + assert project.settings == {} + + def test_load_schema_settings_not_loaded(project): with pytest.raises(InternalError): project.load_schema() +def test_load_hook_schema_settings_not_loaded(project): + with pytest.raises(InternalError): + project.load_hook_schema() + + def test_load_schema_example(project): project.type_name = "AWS::Color::Blue" project._write_example_schema() @@ -661,6 +710,141 @@ def test_init_resource(project): assert f.read() == b"\n" +def test_generate_hook_handlers(project, tmpdir): + project.type_name = "Test::Handler::Test" + project.artifact_type = ARTIFACT_TYPE_HOOK + expected_actions = {"preCreateAction", "preDeleteAction"} + project.schema = { + "handlers": { + 
"preCreate": {"permissions": ["preCreateAction", "preDeleteAction"]}, + "preDelete": {"permissions": ["preDeleteAction", ""]}, + } + } + project.root = tmpdir + mock_plugin = MagicMock(spec=["generate"]) + with patch.object(project, "_plugin", mock_plugin): + project.generate() + + role_path = project.root / "hook-role.yaml" + with role_path.open("r", encoding="utf-8") as f: + template = yaml.safe_load(f.read()) + + action_list = template["Resources"]["ExecutionRole"]["Properties"]["Policies"][0][ + "PolicyDocument" + ]["Statement"][0]["Action"] + + assert all(action in expected_actions for action in action_list) + assert len(action_list) == len(expected_actions) + assert template["Outputs"]["ExecutionRoleArn"] + mock_plugin.generate.assert_called_once_with(project) + + +@pytest.mark.parametrize( + "schema", + ( + {"handlers": {"preCreate": {"permissions": [""]}}}, + {"handlers": {"preCreate": {}}}, + ), +) +def test_generate_hook_handlers_deny_all(project, tmpdir, schema): + project.type_name = "Test::Handler::Test" + project.artifact_type = ARTIFACT_TYPE_HOOK + project.schema = schema + project.root = tmpdir + mock_plugin = MagicMock(spec=["generate"]) + with patch.object(project, "_plugin", mock_plugin): + project.generate() + + role_path = project.root / "hook-role.yaml" + with role_path.open("r", encoding="utf-8") as f: + template = yaml.safe_load(f.read()) + + statement = template["Resources"]["ExecutionRole"]["Properties"]["Policies"][0][ + "PolicyDocument" + ]["Statement"][0] + assert statement["Effect"] == "Deny" + assert statement["Action"][0] == "*" + mock_plugin.generate.assert_called_once_with(project) + + +@pytest.mark.parametrize( + "schema,result", + ( + ({"handlers": {"preCreate": {"timeoutInMinutes": 720}}}, 43200), + ({"handlers": {"preCreate": {"timeoutInMinutes": 2}}}, 3600), + ({"handlers": {"preCreate": {"timeoutInMinutes": 90}}}, 6300), + ( + { + "handlers": { + "preCreate": {"timeoutInMinutes": 70}, + "preUpdate": {"timeoutInMinutes": 90}, 
+ } + }, + 6300, + ), + ({"handlers": {"preCreate": {}}}, 8400), + ({"handlers": {"preCreate": {"timeoutInMinutes": 90}, "preDelete": {}}}, 8400), + ), +) +def test_generate__hook_handlers_role_session_timeout(project, tmpdir, schema, result): + project.type_name = "Test::Handler::Test" + project.artifact_type = ARTIFACT_TYPE_HOOK + project.schema = schema + project.root = tmpdir + mock_plugin = MagicMock(spec=["generate"]) + with patch.object(project, "_plugin", mock_plugin): + project.generate() + + role_path = project.root / "hook-role.yaml" + with role_path.open("r", encoding="utf-8") as f: + template = yaml.safe_load(f.read()) + + max_session_timeout = template["Resources"]["ExecutionRole"]["Properties"][ + "MaxSessionDuration" + ] + assert max_session_timeout == result + + mock_plugin.generate.assert_called_once_with(project) + + +def test_init_hook(project): + type_name = "AWS::CFN::HOOK" + + mock_plugin = MagicMock(spec=["init"]) + patch_load_plugin = patch( + "rpdk.core.project.load_plugin", autospec=True, return_value=mock_plugin + ) + + with patch_load_plugin as mock_load_plugin: + project.init_hook(type_name, LANGUAGE) + + mock_load_plugin.assert_called_once_with(LANGUAGE) + mock_plugin.init.assert_called_once_with(project) + + assert project.type_info == ("AWS", "CFN", "HOOK") + assert project.type_name == type_name + assert project.language == LANGUAGE + assert project.artifact_type == ARTIFACT_TYPE_HOOK + assert project._plugin is mock_plugin + assert project.settings == {} + + with project.settings_path.open("r", encoding="utf-8") as f: + assert json.load(f) + + # ends with newline + with project.settings_path.open("rb") as f: + f.seek(-1, os.SEEK_END) + assert f.read() == b"\n" + + with project.schema_path.open("r", encoding="utf-8") as f: + assert json.load(f) + + # ends with newline + with project.schema_path.open("rb") as f: + f.seek(-1, os.SEEK_END) + assert f.read() == b"\n" + + def test_init_module(project): type_name = "AWS::Color::Red" @@ 
-707,6 +891,26 @@ def test_load_invalid_schema(project): assert "invalid" in str(excinfo.value) +def test_load_invalid_hook_schema(project): + project.artifact_type = "HOOK" + project.type_name = "AWS::CFN::HOOK" + patch_settings = patch.object( + project, "load_settings", return_value={"artifact_type": "HOOK"} + ) + patch_schema = patch.object( + project, "load_hook_schema", side_effect=SpecValidationError("") + ) + with patch_settings as mock_settings, patch_schema as mock_schema, pytest.raises( + InvalidProjectError + ) as excinfo: + project.load() + + mock_settings.assert_called_once_with() + mock_schema.assert_called_once_with() + + assert "invalid" in str(excinfo.value) + + def test_load_module_project_succeeds(project, tmp_path_factory): project.artifact_type = "MODULE" project.type_name = "Unit::Test::Malik::MODULE" @@ -740,6 +944,17 @@ def test_load_resource_succeeds(project): project.load() +def test_load_hook_succeeds(project): + project.artifact_type = "HOOK" + project.type_name = "AWS::CFN::HOOK" + patch_load_settings = patch.object( + project, "load_settings", return_values={"artifact_type": "HOOK"} + ) + project._write_example_hook_schema() + with patch_load_settings: + project.load() + + def test_load_module_project_with_invalid_fragments(project): project.artifact_type = "MODULE" project.type_name = "Unit::Test::Malik::MODULE" @@ -767,6 +982,26 @@ def test_schema_not_found(project): assert "not found" in str(excinfo.value) +def test_hook_schema_not_found(project): + project.artifact_type = "HOOK" + project.type_name = "AWS::CFN::HOOK" + patch_settings = patch.object( + project, "load_settings", return_value={"artifact_type": "HOOK"} + ) + patch_schema = patch.object( + project, "load_hook_schema", side_effect=FileNotFoundError + ) + with patch_settings as mock_settings, patch_schema as mock_schema, pytest.raises( + InvalidProjectError + ) as excinfo: + project.load() + + mock_settings.assert_called_once_with() + 
mock_schema.assert_called_once_with() + + assert "not found" in str(excinfo.value) + + def test_settings_not_found(project): patch_settings = patch.object( project, "load_settings", side_effect=FileNotFoundError @@ -802,6 +1037,51 @@ def create_input_file(base): f.write("{}") +def create_hook_input_file(base): + path = base / "inputs" + os.mkdir(path, mode=0o777) + + path_pre_create = base / PRE_CREATE_INPUTS_FILE + with path_pre_create.open("w", encoding="utf-8") as f: + f.write(json.dumps({TYPE_NAME: {"resourceProperties": {}}})) + + path_pre_update = base / PRE_UPDATE_INPUTS_FILE + with path_pre_update.open("w", encoding="utf-8") as f: + f.write( + json.dumps( + { + TYPE_NAME: { + "resourceProperties": {}, + "previousResourceProperties": {}, + } + } + ) + ) + + path_invalid_pre_delete = base / INVALID_PRE_DELETE_INPUTS_FILE + with path_invalid_pre_delete.open("w", encoding="utf-8") as f: + f.write(json.dumps({TYPE_NAME: {"resourceProperties": {}}})) + + path_invalid = base / INVALID_INPUTS_FILE + with path_invalid.open("w", encoding="utf-8") as f: + f.write(json.dumps({TYPE_NAME: {"resourceProperties": {}}})) + + +def _get_target_schema_filename(target_name): + return "{}.json".format("-".join(s.lower() for s in target_name.split("::"))) + + +def create_target_schema_file(base, target_schema): + path = base / "target-schemas" + os.mkdir(path, mode=0o777) + + schema_filename = _get_target_schema_filename(target_schema["typeName"]) + + path_target_schema = base / "target-schemas" / schema_filename + with path_target_schema.open("w", encoding="utf-8") as f: + f.write(json.dumps(target_schema, indent=4)) + + # pylint: disable=too-many-arguments, too-many-locals, too-many-statements @pytest.mark.parametrize("is_type_configuration_available", (False, True)) def test_submit_dry_run(project, is_type_configuration_available): @@ -967,56 +1247,270 @@ def test_submit_dry_run_modules(project): assert zip_file.testzip() is None -def test_submit_live_run(project): +# pylint: 
disable=too-many-arguments, too-many-locals, too-many-statements +def test_submit_dry_run_hooks(project): project.type_name = TYPE_NAME project.runtime = RUNTIME project.language = LANGUAGE - project.artifact_type = ARTIFACT_TYPE_RESOURCE + project.artifact_type = ARTIFACT_TYPE_HOOK + zip_path = project.root / "test.zip" with project.schema_path.open("w", encoding="utf-8") as f: f.write(CONTENTS_UTF8) - project.write_settings() + project.configuration_schema = { + "CloudFormationConfiguration": {"HookConfiguration": {"Properties": {}}} + } - temp_file = UnclosingBytesIO() + with project.overrides_path.open("w", encoding="utf-8") as f: + f.write(json.dumps(empty_hook_override())) + + create_input_file(project.root) + + project.write_settings() patch_plugin = patch.object(project, "_plugin", spec=LanguagePlugin) patch_upload = patch.object(project, "_upload", autospec=True) - patch_path = patch("rpdk.core.project.Path", autospec=True) - patch_temp = patch("rpdk.core.project.TemporaryFile", return_value=temp_file) + patch_path = patch("rpdk.core.project.Path", return_value=zip_path) + patch_temp = patch("rpdk.core.project.TemporaryFile", autospec=True) # fmt: off # these context managers can't be wrapped by black, but it removes the \ with patch_plugin as mock_plugin, patch_path as mock_path, \ patch_temp as mock_temp, patch_upload as mock_upload: + mock_plugin.get_plugin_information = MagicMock(return_value=PLUGIN_INFORMATION) + project.submit( - False, + True, endpoint_url=ENDPOINT, region_name=REGION, role_arn=None, use_role=True, - set_default=True + set_default=False ) # fmt: on - mock_path.assert_not_called() - mock_temp.assert_called_once_with("w+b") + mock_temp.assert_not_called() + mock_path.assert_called_with("{}.zip".format(project.hypenated_name)) mock_plugin.package.assert_called_once_with(project, ANY) + mock_upload.assert_not_called() - # zip file construction is tested by the dry-run test + file_set = { + SCHEMA_UPLOAD_FILENAME, + SETTINGS_FILENAME, + 
OVERRIDES_FILENAME, + CREATE_INPUTS_FILE, + INVALID_INPUTS_FILE, + UPDATE_INPUTS_FILE, + CFN_METADATA_FILENAME, + CONFIGURATION_SCHEMA_UPLOAD_FILENAME, + TARGET_INFO_FILENAME, + } + with zipfile.ZipFile(zip_path, mode="r") as zip_file: + assert set(zip_file.namelist()) == file_set - assert temp_file.tell() == 0 # file was rewound before upload - mock_upload.assert_called_once_with( - temp_file, - region_name=REGION, - endpoint_url=ENDPOINT, - role_arn=None, - use_role=True, - set_default=True, - ) + schema_contents = zip_file.read(SCHEMA_UPLOAD_FILENAME).decode("utf-8") + assert schema_contents == CONTENTS_UTF8 - assert temp_file._was_closed - temp_file._close() + configuration_schema_contents = zip_file.read( + CONFIGURATION_SCHEMA_UPLOAD_FILENAME + ).decode("utf-8") + assert configuration_schema_contents == json.dumps( + project.configuration_schema, indent=4 + ) + + settings = json.loads(zip_file.read(SETTINGS_FILENAME).decode("utf-8")) + assert settings["runtime"] == RUNTIME + overrides = json.loads(zip_file.read(OVERRIDES_FILENAME).decode("utf-8")) + assert "CREATE_PRE_PROVISION" in overrides + # https://docs.python.org/3/library/zipfile.html#zipfile.ZipFile.testzip + assert zip_file.testzip() is None + metadata_info = json.loads(zip_file.read(CFN_METADATA_FILENAME).decode("utf-8")) + assert "cli-version" in metadata_info + assert "plugin-version" in metadata_info + assert "plugin-tool-version" in metadata_info + + +# pylint: disable=too-many-arguments, too-many-locals, too-many-statements +def test_submit_dry_run_hooks_with_target_info(project): + schema = { + "typeName": "AWS::FOO::BAR", + "description": "test schema", + "typeConfiguration": { + "properties": {"foo": {"type": "string"}}, + "additionalProperties": False, + }, + "handlers": { + "preCreate": { + "targetNames": [TYPE_NAME], + } + }, + "additionalProperties": False, + } + + target_info = { + TYPE_NAME: { + "TargetName": TYPE_NAME, + "TargetType": "RESOURCE", + "Schema": { + "typeName": TYPE_NAME, 
+ "description": "test description", + "additionalProperties": False, + "properties": { + "Id": {"type": "string"}, + }, + "required": [], + "primaryIdentifier": ["/properties/Id"], + }, + "ProvisioningType": "FULLY_MUTTABLE", + "IsCfnRegistrySupportedType": True, + "SchemaFileAvailable": True, + }, + } + + project.type_name = TYPE_NAME + project.runtime = RUNTIME + project.language = LANGUAGE + project.artifact_type = ARTIFACT_TYPE_HOOK + project.schema = schema + zip_path = project.root / "test.zip" + + with project.schema_path.open("w", encoding="utf-8") as f: + f.write(json.dumps(schema, indent=4)) + + project.configuration_schema = { + "CloudFormationConfiguration": {"HookConfiguration": {"Properties": {}}} + } + + with project.overrides_path.open("w", encoding="utf-8") as f: + f.write(json.dumps(empty_hook_override())) + + with project.target_info_path.open("w", encoding="utf-8") as f: + f.write(json.dumps({TYPE_NAME: {"ProvisioningType": "FULLY_MUTTABLE"}})) + + create_hook_input_file(project.root) + + create_target_schema_file(project.root, target_info[TYPE_NAME]["Schema"]) + + project.write_settings() + + patch_plugin = patch.object(project, "_plugin", spec=LanguagePlugin) + patch_upload = patch.object(project, "_upload", autospec=True) + patch_path = patch("rpdk.core.project.Path", return_value=zip_path) + patch_temp = patch("rpdk.core.project.TemporaryFile", autospec=True) + + # fmt: off + # these context managers can't be wrapped by black, but it removes the \ + with patch_plugin as mock_plugin, patch_path as mock_path, \ + patch_temp as mock_temp, patch_upload as mock_upload: + mock_plugin.get_plugin_information = MagicMock(return_value=PLUGIN_INFORMATION) + + project.submit( + True, + endpoint_url=None, + region_name=REGION, + role_arn=None, + use_role=True, + set_default=False + ) + # fmt: on + + mock_temp.assert_not_called() + mock_path.assert_called_with("{}.zip".format(project.hypenated_name)) + mock_plugin.package.assert_called_once_with(project, 
ANY) + mock_upload.assert_not_called() + + file_set = { + SCHEMA_UPLOAD_FILENAME, + SETTINGS_FILENAME, + OVERRIDES_FILENAME, + PRE_CREATE_INPUTS_FILE, + PRE_UPDATE_INPUTS_FILE, + INVALID_PRE_DELETE_INPUTS_FILE, + INVALID_INPUTS_FILE, + CFN_METADATA_FILENAME, + CONFIGURATION_SCHEMA_UPLOAD_FILENAME, + TARGET_INFO_FILENAME, + "target-schemas/aws-color-red.json", + } + with zipfile.ZipFile(zip_path, mode="r") as zip_file: + assert set(zip_file.namelist()) == file_set + + schema_contents = zip_file.read(SCHEMA_UPLOAD_FILENAME).decode("utf-8") + assert json.loads(schema_contents) == schema + + configuration_schema_contents = zip_file.read( + CONFIGURATION_SCHEMA_UPLOAD_FILENAME + ).decode("utf-8") + assert configuration_schema_contents == json.dumps( + project.configuration_schema, indent=4 + ) + zip_file.printdir() + settings = json.loads(zip_file.read(SETTINGS_FILENAME).decode("utf-8")) + assert settings["runtime"] == RUNTIME + overrides = json.loads(zip_file.read(OVERRIDES_FILENAME).decode("utf-8")) + assert "CREATE_PRE_PROVISION" in overrides + assert target_info == json.loads( + zip_file.read(TARGET_INFO_FILENAME).decode("utf-8") + ) + # https://docs.python.org/3/library/zipfile.html#zipfile.ZipFile.testzip + assert zip_file.testzip() is None + metadata_info = json.loads(zip_file.read(CFN_METADATA_FILENAME).decode("utf-8")) + assert "cli-version" in metadata_info + assert "plugin-version" in metadata_info + assert "plugin-tool-version" in metadata_info + + +def test_submit_live_run(project): + project.type_name = TYPE_NAME + project.runtime = RUNTIME + project.language = LANGUAGE + project.artifact_type = ARTIFACT_TYPE_RESOURCE + + with project.schema_path.open("w", encoding="utf-8") as f: + f.write(CONTENTS_UTF8) + + project.write_settings() + + temp_file = UnclosingBytesIO() + + patch_plugin = patch.object(project, "_plugin", spec=LanguagePlugin) + patch_upload = patch.object(project, "_upload", autospec=True) + patch_path = patch("rpdk.core.project.Path", 
autospec=True) + patch_temp = patch("rpdk.core.project.TemporaryFile", return_value=temp_file) + + # fmt: off + # these context managers can't be wrapped by black, but it removes the \ + with patch_plugin as mock_plugin, patch_path as mock_path, \ + patch_temp as mock_temp, patch_upload as mock_upload: + project.submit( + False, + endpoint_url=ENDPOINT, + region_name=REGION, + role_arn=None, + use_role=True, + set_default=True + ) + # fmt: on + + mock_path.assert_not_called() + mock_temp.assert_called_once_with("w+b") + mock_plugin.package.assert_called_once_with(project, ANY) + + # zip file construction is tested by the dry-run test + + assert temp_file.tell() == 0 # file was rewound before upload + mock_upload.assert_called_once_with( + temp_file, + region_name=REGION, + endpoint_url=ENDPOINT, + role_arn=None, + use_role=True, + set_default=True, + ) + + assert temp_file._was_closed + temp_file._close() def test_submit_live_run_for_module(project): @@ -1057,6 +1551,62 @@ def test_submit_live_run_for_module(project): temp_file._close() +def test_submit_live_run_for_hooks(project): + project.type_name = TYPE_NAME + project.runtime = RUNTIME + project.language = LANGUAGE + project.artifact_type = ARTIFACT_TYPE_HOOK + + with project.schema_path.open("w", encoding="utf-8") as f: + f.write(CONTENTS_UTF8) + + project.configuration_schema = { + "CloudFormationConfiguration": {"HookConfiguration": {"Properties": {}}} + } + + project.write_settings() + + temp_file = UnclosingBytesIO() + + patch_plugin = patch.object(project, "_plugin", spec=LanguagePlugin) + patch_upload = patch.object(project, "_upload", autospec=True) + patch_path = patch("rpdk.core.project.Path", autospec=True) + patch_temp = patch("rpdk.core.project.TemporaryFile", return_value=temp_file) + + # fmt: off + # these context managers can't be wrapped by black, but it removes the \ + with patch_plugin as mock_plugin, patch_path as mock_path, \ + patch_temp as mock_temp, patch_upload as mock_upload: + 
project.submit( + False, + endpoint_url=ENDPOINT, + region_name=REGION, + role_arn=None, + use_role=True, + set_default=True + ) + # fmt: on + + mock_path.assert_not_called() + mock_temp.assert_called_once_with("w+b") + mock_plugin.package.assert_called_once_with(project, ANY) + + # zip file construction is tested by the dry-run test + + assert temp_file.tell() == 0 # file was rewound before upload + mock_upload.assert_called_once_with( + temp_file, + region_name=REGION, + endpoint_url=ENDPOINT, + role_arn=None, + use_role=True, + set_default=True, + ) + + assert temp_file._was_closed + temp_file._close() + + def test__upload_good_path_create_role_and_set_default(project): project.type_name = TYPE_NAME project.artifact_type = ARTIFACT_TYPE_RESOURCE @@ -1110,6 +1660,59 @@ def test__upload_good_path_create_role_and_set_default(project): mock_wait.assert_called_once_with(mock_cfn_client, "foo", True) +def test__upload_good_path_create_role_and_set_default_hook(project): + project.type_name = TYPE_NAME + project.artifact_type = ARTIFACT_TYPE_HOOK + project.schema = {"handlers": {}} + + mock_cfn_client = MagicMock(spec=["register_type"]) + mock_cfn_client.register_type.return_value = {"RegistrationToken": "foo"} + fileobj = object() + + patch_sdk = patch("rpdk.core.project.create_sdk_session", autospec=True) + patch_uploader = patch.object(Uploader, "upload", return_value="url") + patch_exec_role_arn = patch.object( + Uploader, "create_or_update_role", return_value="some-execution-role-arn" + ) + patch_logging_role_arn = patch.object( + Uploader, "get_log_delivery_role_arn", return_value="some-log-role-arn" + ) + patch_uuid = patch("rpdk.core.project.uuid4", autospec=True, return_value="foo") + patch_wait = patch.object(project, "_wait_for_registration", autospec=True) + + with patch_sdk as mock_sdk, patch_uploader as mock_upload_method, patch_logging_role_arn as mock_role_arn_method, patch_exec_role_arn as mock_exec_role_method: # noqa: B950 as it conflicts with 
formatting rules # pylint: disable=C0301 + mock_sdk.return_value.client.side_effect = [mock_cfn_client, MagicMock()] + with patch_uuid as mock_uuid, patch_wait as mock_wait: + project._upload( + fileobj, + endpoint_url=None, + region_name=None, + role_arn=None, + use_role=True, + set_default=True, + ) + + mock_sdk.assert_called_once_with(None) + mock_exec_role_method.assert_called_once_with( + project.root / "hook-role.yaml", project.hypenated_name + ) + mock_upload_method.assert_called_once_with(project.hypenated_name, fileobj) + mock_role_arn_method.assert_called_once_with() + mock_uuid.assert_called_once_with() + mock_cfn_client.register_type.assert_called_once_with( + Type="HOOK", + TypeName=project.type_name, + SchemaHandlerPackage="url", + ClientRequestToken=mock_uuid.return_value, + LoggingConfig={ + "LogRoleArn": "some-log-role-arn", + "LogGroupName": "aws-color-red-logs", + }, + ExecutionRoleArn="some-execution-role-arn", + ) + mock_wait.assert_called_once_with(mock_cfn_client, "foo", True) + + @pytest.mark.parametrize( ("use_role,expected_additional_args"), [(True, {"ExecutionRoleArn": "someArn"}), (False, {})], @@ -1164,6 +1767,60 @@ def test__upload_good_path_skip_role_creation( ) +@pytest.mark.parametrize( + ("use_role,expected_additional_args"), + [(True, {"ExecutionRoleArn": "someArn"}), (False, {})], +) +def test__upload_good_path_skip_role_creation_hook( + project, use_role, expected_additional_args +): + project.type_name = TYPE_NAME + project.artifact_type = ARTIFACT_TYPE_HOOK + project.schema = {"handlers": {}} + + mock_cfn_client = MagicMock(spec=["register_type"]) + fileobj = object() + mock_cfn_client.register_type.return_value = {"RegistrationToken": "foo"} + + patch_sdk = patch("rpdk.core.project.create_sdk_session", autospec=True) + patch_uploader = patch.object(Uploader, "upload", return_value="url") + patch_logging_role_arn = patch.object( + Uploader, "get_log_delivery_role_arn", return_value="some-log-role-arn" + ) + patch_uuid = 
patch("rpdk.core.project.uuid4", autospec=True, return_value="foo") + patch_wait = patch.object(project, "_wait_for_registration", autospec=True) + + with patch_sdk as mock_sdk, patch_uploader as mock_upload_method, patch_logging_role_arn as mock_role_arn_method: # noqa: B950 as it conflicts with formatting rules # pylint: disable=C0301 + mock_sdk.return_value.client.side_effect = [mock_cfn_client, MagicMock()] + with patch_uuid as mock_uuid, patch_wait as mock_wait: + project._upload( + fileobj, + endpoint_url=None, + region_name=None, + role_arn="someArn", + use_role=use_role, + set_default=True, + ) + + mock_sdk.assert_called_once_with(None) + mock_upload_method.assert_called_once_with(project.hypenated_name, fileobj) + mock_role_arn_method.assert_called_once_with() + mock_uuid.assert_called_once_with() + mock_wait.assert_called_once_with(mock_cfn_client, "foo", True) + + mock_cfn_client.register_type.assert_called_once_with( + Type="HOOK", + TypeName=project.type_name, + SchemaHandlerPackage="url", + ClientRequestToken=mock_uuid.return_value, + LoggingConfig={ + "LogRoleArn": "some-log-role-arn", + "LogGroupName": "aws-color-red-logs", + }, + **expected_additional_args, + ) + + def test__upload_clienterror(project): project.type_name = TYPE_NAME project.artifact_type = ARTIFACT_TYPE_RESOURCE @@ -1258,6 +1915,53 @@ def test__upload_clienterror_module(project): ) +def test__upload_clienterror_hook(project): + project.type_name = TYPE_NAME + project.artifact_type = ARTIFACT_TYPE_HOOK + project.schema = {} + + mock_cfn_client = MagicMock(spec=["register_type"]) + mock_cfn_client.register_type.side_effect = ClientError( + BLANK_CLIENT_ERROR, "RegisterType" + ) + fileobj = object() + + patch_sdk = patch("rpdk.core.project.create_sdk_session", autospec=True) + patch_uploader = patch.object(Uploader, "upload", return_value="url") + patch_role_arn = patch.object( + Uploader, "get_log_delivery_role_arn", return_value="some-log-role-arn" + ) + patch_uuid = 
patch("rpdk.core.project.uuid4", autospec=True, return_value="foo") + + with patch_sdk as mock_sdk, patch_uploader as mock_upload_method, patch_role_arn as mock_role_arn_method: # noqa: B950 as it conflicts with formatting rules # pylint: disable=C0301 + mock_session = mock_sdk.return_value + mock_session.client.side_effect = [mock_cfn_client, MagicMock()] + with patch_uuid as mock_uuid, pytest.raises(DownstreamError): + project._upload( + fileobj, + endpoint_url=None, + region_name=None, + role_arn=None, + use_role=False, + set_default=True, + ) + + mock_sdk.assert_called_once_with(None) + mock_upload_method.assert_called_once_with(project.hypenated_name, fileobj) + mock_role_arn_method.assert_called_once_with() + mock_uuid.assert_called_once_with() + mock_cfn_client.register_type.assert_called_once_with( + Type="HOOK", + TypeName=project.type_name, + SchemaHandlerPackage="url", + ClientRequestToken=mock_uuid.return_value, + LoggingConfig={ + "LogRoleArn": "some-log-role-arn", + "LogGroupName": "aws-color-red-logs", + }, + ) + + def test__wait_for_registration_set_default(project): mock_cfn_client = MagicMock( spec=["describe_type_registration", "set_type_default_version", "get_waiter"] @@ -1475,3 +2179,315 @@ def test__write_settings_nonnull_executable_entrypoint(project): settings = json.load(f) assert "executableEntrypoint" in settings assert settings["executableEntrypoint"] == "executable_entrypoint" + + +def test__load_target_info_for_resource(project): + project.type_name = TYPE_NAME + project.artifact_type = ARTIFACT_TYPE_RESOURCE + project.schema = {"handlers": {}} + + target_info = project._load_target_info(endpoint_url=None, region_name=None) + + assert not target_info + + +def test__load_target_info_for_hooks(project): + project.type_name = HOOK_TYPE_NAME + project.artifact_type = ARTIFACT_TYPE_HOOK + project.schema = { + "handlers": { + "preCreate": { + "targetNames": ["AWS::TestHook::Target", "AWS::TestHook::OtherTarget"] + }, + "preUpdate": { + 
"targetNames": [ + "AWS::TestHookOne::Target", + "AWS::TestHookTwo::Target", + "AWS::ArrayHook::Target", + ] + }, + "preDelete": {"targetNames": ["AWS::TestHook::Target"]}, + } + } + + patch_sdk = patch("rpdk.core.type_schema_loader.create_sdk_session", autospec=True) + patch_loader_method = patch.object( + TypeSchemaLoader, + "load_type_schema", + side_effect=[ + { + "typeName": "AWS::TestHook::Target", + "description": "descript", + "properties": {"Name": {"type": "string"}}, + "primaryIdentifier": ["/properties/Name"], + "additionalProperties": False, + }, + { + "typeName": "AWS::AnotherDiffHook::Target", + "description": "descript", + "properties": {"Name": {"type": "string"}}, + "primaryIdentifier": ["/properties/Name"], + "additionalProperties": False, + }, + [ + { + "typeName": "AWS::TestHookOne::Target", + "description": "descript", + "properties": {"Name": {"type": "string"}}, + "primaryIdentifier": ["/properties/Name"], + "additionalProperties": False, + }, + { + "typeName": "AWS::TestHookTwo::Target", + "description": "descript", + "properties": {"Name": {"type": "string"}}, + "primaryIdentifier": ["/properties/Name"], + "additionalProperties": False, + }, + ], + [ + { + "typeName": "AWS::ArrayHook::Target", + "description": "descript", + "properties": {"Name": {"type": "string"}}, + "primaryIdentifier": ["/properties/Name"], + "additionalProperties": False, + } + ], + ], + ) + patch_loader_cfn = patch.object( + TypeSchemaLoader, + "load_schema_from_cfn_registry", + return_value=( + { + "typeName": "AWS::TestHook::OtherTarget", + "description": "descript", + "properties": {"Name": {"type": "string"}}, + "primaryIdentifier": ["/properties/Name"], + "additionalProperties": False, + }, + "RESOURCE", + "FULLY_MUTTABLE", + ), + ) + + patch_is_registry_checker = patch.object( + TypeSchemaLoader, + "get_provision_type", + side_effect={ + "AWS::TestHook::Target": "FULLY_MUTABLE", + "AWS::TestHookOne::Target": "IMMUTABLE", + "AWS::TestHookTwo::Target": 
"IMMUTABLE", + "AWS::ArrayHook::Target": "FULLY_MUTTABLE", + }.get, + ) + + # pylint: disable=line-too-long + with patch_sdk as mock_sdk, patch_loader_method as mock_loader_method, patch_loader_cfn as mock_loader_cfn, patch_is_registry_checker as mock_is_registry_checker: + mock_sdk.return_value.client.side_effect = [MagicMock(), MagicMock()] + target_info = project._load_target_info( + endpoint_url=None, + region_name=None, + provided_schemas=[ + "/files/target-schema.json", + "/files/target-schema-not-for-this-project.json", + "/files/list-of-target-schemas.json", + "/files/file-of-valid-json-array-with-a-target-schema.json", + ], + ) + + assert target_info == { + "AWS::TestHook::Target": { + "TargetName": "AWS::TestHook::Target", + "TargetType": "RESOURCE", + "Schema": { + "typeName": "AWS::TestHook::Target", + "description": "descript", + "properties": {"Name": {"type": "string"}}, + "primaryIdentifier": ["/properties/Name"], + "additionalProperties": False, + }, + "ProvisioningType": "FULLY_MUTABLE", + "IsCfnRegistrySupportedType": True, + "SchemaFileAvailable": True, + }, + "AWS::TestHook::OtherTarget": { + "TargetName": "AWS::TestHook::OtherTarget", + "TargetType": "RESOURCE", + "Schema": { + "typeName": "AWS::TestHook::OtherTarget", + "description": "descript", + "properties": {"Name": {"type": "string"}}, + "primaryIdentifier": ["/properties/Name"], + "additionalProperties": False, + }, + "ProvisioningType": "FULLY_MUTTABLE", + "IsCfnRegistrySupportedType": True, + "SchemaFileAvailable": True, + }, + "AWS::TestHookOne::Target": { + "TargetName": "AWS::TestHookOne::Target", + "TargetType": "RESOURCE", + "Schema": { + "typeName": "AWS::TestHookOne::Target", + "description": "descript", + "properties": {"Name": {"type": "string"}}, + "primaryIdentifier": ["/properties/Name"], + "additionalProperties": False, + }, + "ProvisioningType": "IMMUTABLE", + "IsCfnRegistrySupportedType": True, + "SchemaFileAvailable": True, + }, + "AWS::TestHookTwo::Target": { + 
"TargetName": "AWS::TestHookTwo::Target", + "TargetType": "RESOURCE", + "Schema": { + "typeName": "AWS::TestHookTwo::Target", + "description": "descript", + "properties": {"Name": {"type": "string"}}, + "primaryIdentifier": ["/properties/Name"], + "additionalProperties": False, + }, + "ProvisioningType": "IMMUTABLE", + "IsCfnRegistrySupportedType": True, + "SchemaFileAvailable": True, + }, + "AWS::ArrayHook::Target": { + "TargetName": "AWS::ArrayHook::Target", + "TargetType": "RESOURCE", + "Schema": { + "typeName": "AWS::ArrayHook::Target", + "description": "descript", + "properties": {"Name": {"type": "string"}}, + "primaryIdentifier": ["/properties/Name"], + "additionalProperties": False, + }, + "ProvisioningType": "FULLY_MUTTABLE", + "IsCfnRegistrySupportedType": True, + "SchemaFileAvailable": True, + }, + } + + mock_sdk.assert_called_once_with(None) + assert mock_loader_method.call_args_list == [ + call("/files/target-schema.json"), + call("/files/target-schema-not-for-this-project.json"), + call("/files/list-of-target-schemas.json"), + call("/files/file-of-valid-json-array-with-a-target-schema.json"), + ] + TestCase().assertCountEqual( + mock_is_registry_checker.call_args_list, + [ + call("AWS::TestHook::Target", "RESOURCE"), + call("AWS::TestHookOne::Target", "RESOURCE"), + call("AWS::TestHookTwo::Target", "RESOURCE"), + call("AWS::ArrayHook::Target", "RESOURCE"), + ], + ) + mock_loader_cfn.assert_called_once_with("AWS::TestHook::OtherTarget", "RESOURCE") + + +def test__load_target_info_for_hooks_invalid_target_schema(project): + project.type_name = HOOK_TYPE_NAME + project.artifact_type = ARTIFACT_TYPE_HOOK + project.schema = { + "handlers": { + "preCreate": { + "targetNames": ["AWS::TestHook::Target", "AWS::TestHook::OtherTarget"] + }, + "preUpdate": { + "targetNames": [ + "AWS::TestHookOne::Target", + "AWS::TestHookTwo::Target", + "AWS::ArrayHook::Target", + ] + }, + "preDelete": {"targetNames": ["AWS::TestHook::Target"]}, + } + } + + patch_sdk = 
patch("rpdk.core.type_schema_loader.create_sdk_session", autospec=True) + patch_loader_method = patch.object( + TypeSchemaLoader, + "load_type_schema", + side_effect=[ + { + "typeName": "AWS::TestHook::Target", + "properties": {"Name": {"type": "string"}}, + "primaryIdentifier": ["/properties/Name"], + "additionalProperties": False, + }, + ], + ) + patch_loader_cfn = patch.object( + TypeSchemaLoader, + "load_schema_from_cfn_registry", + return_value=( + { + "typeName": "AWS::TestHook::OtherTarget", + "description": "descript", + "properties": {"Name": {"type": "string"}}, + "primaryIdentifier": ["/properties/Name"], + "additionalProperties": False, + }, + "RESOURCE", + "FULLY_MUTTABLE", + ), + ) + + with patch_loader_cfn, pytest.raises( + InvalidProjectError + ), patch_sdk as mock_sdk, patch_loader_method as mock_loader_method: + mock_sdk.return_value.client.side_effect = [MagicMock(), MagicMock()] + project._load_target_info( + endpoint_url=None, + region_name=None, + provided_schemas=["/files/target-schema.json"], + ) + + mock_sdk.assert_called_once_with(None) + assert mock_loader_method.call_args_list == [ + call("/files/target-schema.json"), + ] + + +def test__load_target_info_for_hooks_duplicate_schemas(project): + project.type_name = HOOK_TYPE_NAME + project.artifact_type = ARTIFACT_TYPE_HOOK + project.schema = { + "handlers": { + "preCreate": { + "targetNames": ["AWS::TestHook::Target", "AWS::TestHook::OtherTarget"] + } + } + } + + patch_sdk = patch("rpdk.core.type_schema_loader.create_sdk_session", autospec=True) + patch_loader_method = patch.object( + TypeSchemaLoader, + "load_type_schema", + side_effect=[ + {"typeName": "AWS::TestHook::Target"}, + {"typeName": "AWS::TestHook::Target"}, + ], + ) + + with patch_sdk as mock_sdk, patch_loader_method as mock_loader_method: + mock_sdk.return_value.client.side_effect = [MagicMock(), MagicMock()] + with pytest.raises(InvalidProjectError): + project._load_target_info( + endpoint_url=None, + region_name=None, + 
provided_schemas=[ + "/files/target-schema.json", + "/files/target-schema-not-for-this-project.json", + ], + ) + + mock_sdk.assert_called_once_with(None) + assert mock_loader_method.call_args_list == [ + call("/files/target-schema.json"), + call("/files/target-schema-not-for-this-project.json"), + ] diff --git a/tests/test_test.py b/tests/test_test.py index a483dc2d..e8ad173a 100644 --- a/tests/test_test.py +++ b/tests/test_test.py @@ -1,5 +1,5 @@ # fixture and parameter have the same name -# pylint: disable=redefined-outer-name +# pylint: disable=protected-access,redefined-outer-name import json import os from contextlib import contextmanager @@ -9,23 +9,32 @@ import pytest from rpdk.core.cli import EXIT_UNHANDLED_EXCEPTION, main -from rpdk.core.contract.interface import Action +from rpdk.core.contract.interface import Action, HookInvocationPoint from rpdk.core.exceptions import SysExitRecommendedError -from rpdk.core.project import ARTIFACT_TYPE_MODULE, ARTIFACT_TYPE_RESOURCE, Project +from rpdk.core.project import ( + ARTIFACT_TYPE_HOOK, + ARTIFACT_TYPE_MODULE, + ARTIFACT_TYPE_RESOURCE, + Project, +) from rpdk.core.test import ( DEFAULT_ENDPOINT, DEFAULT_FUNCTION, DEFAULT_REGION, _validate_sam_args, + empty_hook_override, empty_override, + get_hook_overrides, get_inputs, get_marker_options, get_overrides, temporary_ini_file, ) +from rpdk.core.utils.handler_utils import generate_handler_name RANDOM_INI = "pytest_SOYPKR.ini" -EMPTY_OVERRIDE = empty_override() +EMPTY_RESOURCE_OVERRIDE = empty_override() +EMPTY_HOOK_OVERRIDE = empty_hook_override() ROLE_ARN = "role_arn" CREDENTIALS = { "AccessKeyId": object(), @@ -34,7 +43,46 @@ } -SCHEMA = {"handlers": {action.lower(): [] for action in Action}} +RESOURCE_SCHEMA = {"handlers": {generate_handler_name(action): [] for action in Action}} +HOOK_SCHEMA = { + "handlers": { + generate_handler_name(invoke_point): [] for invoke_point in HookInvocationPoint + } +} + +HOOK_TARGET_INFO = { + "My::Example::Resource": { + 
"TargetName": "My::Example::Resource", + "TargetType": "RESOURCE", + "Schema": { + "typeName": "My::Example::Resource", + "additionalProperties": False, + "properties": { + "Id": {"type": "string"}, + "Tags": { + "type": "array", + "uniqueItems": False, + "items": {"$ref": "#/definitions/Tag"}, + }, + }, + "required": [], + "definitions": { + "Tag": { + "type": "object", + "additionalProperties": False, + "properties": { + "Value": {"type": "string"}, + "Key": {"type": "string"}, + }, + "required": ["Value", "Key"], + } + }, + }, + "ProvisioningType": "FULLY_MUTTABLE", + "IsCfnRegistrySupportedType": True, + "SchemaFileAvailable": True, + } +} @pytest.fixture @@ -47,6 +95,21 @@ def mock_temporary_ini_file(): yield RANDOM_INI +def _get_expected_marker_options(artifact_type): + resource_actions = [op.lower() for op in Action] + hook_actions = [op.lower() for op in HookInvocationPoint] + all_actions = resource_actions + hook_actions + + if artifact_type == ARTIFACT_TYPE_HOOK: + included_actions = set(hook_actions) + else: + included_actions = set(resource_actions) + + return " and ".join( + ["not " + action for action in all_actions if action not in included_actions] + ) + + def create_input_file(base, create_string, update_string, invalid_string): path = base / "inputs" os.mkdir(path, mode=0o777) @@ -95,21 +158,22 @@ def create_invalid_input_file(base): ), ], ) -def test_test_command_happy_path( +def test_test_command_happy_path_resource( base, capsys, args_in, pytest_args, plugin_args ): # pylint: disable=too-many-locals create_input_file(base, '{"a": 1}', '{"a": 2}', '{"b": 1}') mock_project = Mock(spec=Project) - mock_project.schema = SCHEMA + mock_project.schema = RESOURCE_SCHEMA mock_project.root = base mock_project.executable_entrypoint = None mock_project.artifact_type = ARTIFACT_TYPE_RESOURCE + marker_options = _get_expected_marker_options(mock_project.artifact_type) patch_project = patch( "rpdk.core.test.Project", autospec=True, return_value=mock_project ) 
patch_plugin = patch("rpdk.core.test.ContractPlugin", autospec=True) - patch_client = patch("rpdk.core.test.ResourceClient", autospec=True) + patch_resource_client = patch("rpdk.core.test.ResourceClient", autospec=True) patch_pytest = patch("rpdk.core.test.pytest.main", autospec=True, return_value=0) patch_ini = patch( "rpdk.core.test.temporary_ini_file", side_effect=mock_temporary_ini_file @@ -117,7 +181,7 @@ def test_test_command_happy_path( # fmt: off with patch_project, \ patch_plugin as mock_plugin, \ - patch_client as mock_client, \ + patch_resource_client as mock_resource_client, \ patch_pytest as mock_pytest, \ patch_ini as mock_ini: main(args_in=["test"] + args_in) @@ -125,12 +189,12 @@ def test_test_command_happy_path( mock_project.load.assert_called_once_with() function_name, endpoint, region, enforce_timeout = plugin_args - mock_client.assert_called_once_with( + mock_resource_client.assert_called_once_with( function_name, endpoint, region, mock_project.schema, - EMPTY_OVERRIDE, + EMPTY_RESOURCE_OVERRIDE, {"CREATE": {"a": 1}, "UPDATE": {"a": 2}, "INVALID": {"b": 1}}, None, enforce_timeout, @@ -140,10 +204,93 @@ def test_test_command_happy_path( None, None, ) - mock_plugin.assert_called_once_with(mock_client.return_value) + mock_plugin.assert_called_once_with( + {"resource_client": mock_resource_client.return_value} + ) mock_ini.assert_called_once_with() mock_pytest.assert_called_once_with( - ["-c", RANDOM_INI, "-m", ""] + pytest_args, plugins=[mock_plugin.return_value] + ["-c", RANDOM_INI, "-m", marker_options] + pytest_args, + plugins=[mock_plugin.return_value], + ) + + _out, err = capsys.readouterr() + assert not err + + +@pytest.mark.parametrize( + "args_in,pytest_args,plugin_args", + [ + ([], [], [DEFAULT_FUNCTION, DEFAULT_ENDPOINT, DEFAULT_REGION, "30"]), + (["--endpoint", "foo"], [], [DEFAULT_FUNCTION, "foo", DEFAULT_REGION, "30"]), + ( + ["--function-name", "bar", "--enforce-timeout", "60"], + [], + ["bar", DEFAULT_ENDPOINT, DEFAULT_REGION, "60"], 
+ ), + ( + ["--", "-k", "create"], + ["-k", "create"], + [DEFAULT_FUNCTION, DEFAULT_ENDPOINT, DEFAULT_REGION, "30"], + ), + ( + ["--region", "us-west-2", "--", "--collect-only"], + ["--collect-only"], + [DEFAULT_FUNCTION, DEFAULT_ENDPOINT, "us-west-2", "30"], + ), + ], +) +def test_test_command_happy_path_hook( + base, capsys, args_in, pytest_args, plugin_args +): # pylint: disable=too-many-locals + mock_project = Mock(spec=Project) + mock_project.schema = HOOK_SCHEMA + mock_project.root = base + mock_project.artifact_type = ARTIFACT_TYPE_HOOK + mock_project.executable_entrypoint = None + mock_project._load_target_info.return_value = HOOK_TARGET_INFO + marker_options = _get_expected_marker_options(mock_project.artifact_type) + + patch_project = patch( + "rpdk.core.test.Project", autospec=True, return_value=mock_project + ) + patch_plugin = patch("rpdk.core.test.ContractPlugin", autospec=True) + patch_hook_client = patch("rpdk.core.test.HookClient", autospec=True) + patch_pytest = patch("rpdk.core.test.pytest.main", autospec=True, return_value=0) + patch_ini = patch( + "rpdk.core.test.temporary_ini_file", side_effect=mock_temporary_ini_file + ) + # fmt: off + with patch_project, \ + patch_plugin as mock_plugin, \ + patch_hook_client as mock_hook_client, \ + patch_pytest as mock_pytest, \ + patch_ini as mock_ini: + main(args_in=["test"] + args_in) + # fmt: on + + mock_project.load.assert_called_once_with() + function_name, endpoint, region, enforce_timeout = plugin_args + mock_hook_client.assert_called_once_with( + function_name, + endpoint, + region, + mock_project.schema, + EMPTY_HOOK_OVERRIDE, + None, + None, + enforce_timeout, + mock_project.type_name, + None, + None, + None, + None, + HOOK_TARGET_INFO, + ) + mock_plugin.assert_called_once_with({"hook_client": mock_hook_client.return_value}) + mock_ini.assert_called_once_with() + mock_pytest.assert_called_once_with( + ["-c", RANDOM_INI, "-m", marker_options] + pytest_args, + plugins=[mock_plugin.return_value], ) 
_out, err = capsys.readouterr() @@ -154,7 +301,7 @@ def test_test_command_return_code_on_error(): mock_project = Mock(spec=Project) mock_project.root = None - mock_project.schema = SCHEMA + mock_project.schema = RESOURCE_SCHEMA mock_project.executable_entrypoint = None mock_project.artifact_type = ARTIFACT_TYPE_RESOURCE patch_project = patch( @@ -193,7 +340,7 @@ def test_temporary_ini_file(): def test_get_overrides_no_root(): - assert get_overrides(None, DEFAULT_REGION, "", None) == EMPTY_OVERRIDE + assert get_overrides(None, DEFAULT_REGION, "", None) == EMPTY_RESOURCE_OVERRIDE def test_get_overrides_file_not_found(base): @@ -202,20 +349,20 @@ def test_get_overrides_file_not_found(base): path.unlink() except FileNotFoundError: pass - assert get_overrides(path, DEFAULT_REGION, "", None) == EMPTY_OVERRIDE + assert get_overrides(path, DEFAULT_REGION, "", None) == EMPTY_RESOURCE_OVERRIDE def test_get_overrides_invalid_file(base): path = base / "overrides.json" path.write_text("{}") - assert get_overrides(base, DEFAULT_REGION, "", None) == EMPTY_OVERRIDE + assert get_overrides(base, DEFAULT_REGION, "", None) == EMPTY_RESOURCE_OVERRIDE def test_get_overrides_empty_overrides(base): path = base / "overrides.json" with path.open("w", encoding="utf-8") as f: - json.dump(EMPTY_OVERRIDE, f) - assert get_overrides(base, DEFAULT_REGION, "", None) == EMPTY_OVERRIDE + json.dump(EMPTY_RESOURCE_OVERRIDE, f) + assert get_overrides(base, DEFAULT_REGION, "", None) == EMPTY_RESOURCE_OVERRIDE def test_get_overrides_invalid_pointer_skipped(base): @@ -225,7 +372,7 @@ def test_get_overrides_invalid_pointer_skipped(base): path = base / "overrides.json" with path.open("w", encoding="utf-8") as f: json.dump(overrides, f) - assert get_overrides(base, DEFAULT_REGION, "", None) == EMPTY_OVERRIDE + assert get_overrides(base, DEFAULT_REGION, "", None) == EMPTY_RESOURCE_OVERRIDE def test_get_overrides_good_path(base): @@ -240,6 +387,41 @@ def test_get_overrides_good_path(base): } +def 
test_get_hook_overrides_no_root(): + assert get_hook_overrides(None, DEFAULT_REGION, "", None) == EMPTY_HOOK_OVERRIDE + + +def test_get_hook_overrides_file_not_found(base): + path = base / "overrides.json" + try: + path.unlink() + except FileNotFoundError: + pass + assert get_hook_overrides(path, DEFAULT_REGION, "", None) == EMPTY_HOOK_OVERRIDE + + +def test_get_hook_overrides_invalid_file(base): + path = base / "overrides.json" + path.write_text("{}") + assert get_hook_overrides(base, DEFAULT_REGION, "", None) == EMPTY_HOOK_OVERRIDE + + +def test_get_hook_overrides_good_path(base): + overrides = empty_hook_override() + overrides["CREATE_PRE_PROVISION"]["My::Example::Resource"] = { + "resourceProperties": {"/foo/bar": {}} + } + + path = base / "overrides.json" + with path.open("w", encoding="utf-8") as f: + json.dump(overrides, f) + assert get_hook_overrides(base, DEFAULT_REGION, "", None) == { + "CREATE_PRE_PROVISION": { + "My::Example::Resource": {"resourceProperties": {("foo", "bar"): {}}} + } + } + + @pytest.mark.parametrize( "overrides_string,list_exports_return_value,expected_overrides", [ @@ -298,7 +480,7 @@ def test_get_overrides_with_jinja( @pytest.mark.parametrize( "schema,expected_marker_keywords", [ - (SCHEMA, ""), + (RESOURCE_SCHEMA, ""), ( {"handlers": {"create": [], "read": [], "update": [], "delete": []}}, ("not list",), diff --git a/tests/test_type_schema_loader.py b/tests/test_type_schema_loader.py new file mode 100644 index 00000000..81a7da2d --- /dev/null +++ b/tests/test_type_schema_loader.py @@ -0,0 +1,542 @@ +# fixture and parameter have the same name +# pylint: disable=redefined-outer-name,useless-super-delegation,protected-access +import json +import unittest +from io import BytesIO +from unittest.mock import Mock, mock_open, patch + +import pytest +from botocore.exceptions import ClientError + +from rpdk.core.type_schema_loader import TypeSchemaLoader, is_valid_type_schema_uri + +TEST_TARGET_TYPE_NAME = "AWS::Test::Target" 
+TEST_TARGET_SCHEMA = { + "typeName": TEST_TARGET_TYPE_NAME, + "additionalProperties": False, + "properties": {}, + "required": [], +} +TEST_TARGET_SCHEMA_JSON = json.dumps(TEST_TARGET_SCHEMA) +TEST_TARGET_SCHEMA_JSON_ARRAY = json.dumps([TEST_TARGET_SCHEMA]) +OTHER_TEST_TARGET_FALLBACK_SCHEMA = { + "typeName": "AWS::Test::Backup", + "additionalProperties": False, + "properties": {}, + "required": [], +} +OTHER_TEST_TARGET_FALLBACK_SCHEMA_JSON = json.dumps(OTHER_TEST_TARGET_FALLBACK_SCHEMA) + +TEST_TARGET_SCHEMA_BUCKET = "TestTargetSchemaBucket" +TEST_TARGET_SCHEMA_KEY = "test-target-schema.json" +TEST_TARGET_SCHEMA_FILE_PATH = "/files/{}".format(TEST_TARGET_SCHEMA_KEY) +TEST_TARGET_SCHEMA_FILE_URI = "file://{}".format(TEST_TARGET_SCHEMA_FILE_PATH) +TEST_S3_TARGET_SCHEMA_URI = "s3://{}/{}".format( + TEST_TARGET_SCHEMA_BUCKET, TEST_TARGET_SCHEMA_KEY +) +TEST_HTTPS_TARGET_SCHEMA_URI = "https://{}.s3.us-west-2.amazonaws.com/{}".format( + TEST_TARGET_SCHEMA_BUCKET, TEST_TARGET_SCHEMA_KEY +) + + +# pylint: disable=C0103 +def assert_dict_equals(d1, d2): + unittest.TestCase().assertDictEqual(d1, d2) + + +@pytest.fixture +def loader(): + loader = TypeSchemaLoader(Mock(), Mock()) + return loader + + +def test_load_type_schema_from_json(loader): + with patch.object( + loader, "load_type_schema_from_json", wraps=loader.load_type_schema_from_json + ) as mock_load_json: + type_schema = loader.load_type_schema(TEST_TARGET_SCHEMA_JSON) + + assert_dict_equals(TEST_TARGET_SCHEMA, type_schema) + mock_load_json.assert_called_with(TEST_TARGET_SCHEMA_JSON, None) + + +def test_load_type_schema_from_invalid_json(loader): + with patch.object( + loader, "load_type_schema_from_json", wraps=loader.load_type_schema_from_json + ) as mock_load_json: + type_schema = loader.load_type_schema( + '{"Credentials" :{"ApiKey": "123", xxxx}}' + ) + + assert not type_schema + mock_load_json.assert_called_with('{"Credentials" :{"ApiKey": "123", xxxx}}', None) + + +def 
test_load_type_schema_from_invalid_json_fallback_to_default(loader): + with patch.object( + loader, "load_type_schema_from_json", wraps=loader.load_type_schema_from_json + ) as mock_load_json: + type_schema = loader.load_type_schema( + '{"Credentials" :{"ApiKey": "123", xxxx}}', + OTHER_TEST_TARGET_FALLBACK_SCHEMA, + ) + + assert_dict_equals(OTHER_TEST_TARGET_FALLBACK_SCHEMA, type_schema) + mock_load_json.assert_called_with( + '{"Credentials" :{"ApiKey": "123", xxxx}}', OTHER_TEST_TARGET_FALLBACK_SCHEMA + ) + + +def test_load_type_schema_from_json_array(loader): + with patch.object( + loader, "load_type_schema_from_json", wraps=loader.load_type_schema_from_json + ) as mock_load_json: + type_schema = loader.load_type_schema(TEST_TARGET_SCHEMA_JSON_ARRAY) + + assert [TEST_TARGET_SCHEMA] == type_schema + mock_load_json.assert_called_with(TEST_TARGET_SCHEMA_JSON_ARRAY, None) + + +def test_load_type_schema_from_invalid_json_array(loader): + with patch.object( + loader, "load_type_schema_from_json", wraps=loader.load_type_schema_from_json + ) as mock_load_json: + type_schema = loader.load_type_schema('[{"Credentials" :{"ApiKey": "123"}}]]') + + assert not type_schema + mock_load_json.assert_called_with('[{"Credentials" :{"ApiKey": "123"}}]]', None) + + +def test_load_type_schema_from_invalid_json_array_fallback_to_default(loader): + with patch.object( + loader, "load_type_schema_from_json", wraps=loader.load_type_schema_from_json + ) as mock_load_json: + type_schema = loader.load_type_schema( + '[{"Credentials" :{"ApiKey": "123"}}]]', + OTHER_TEST_TARGET_FALLBACK_SCHEMA, + ) + + assert_dict_equals(OTHER_TEST_TARGET_FALLBACK_SCHEMA, type_schema) + mock_load_json.assert_called_with( + '[{"Credentials" :{"ApiKey": "123"}}]]', OTHER_TEST_TARGET_FALLBACK_SCHEMA + ) + + +def test_load_type_schema_from_file(loader): + patch_file = patch("builtins.open", mock_open(read_data=TEST_TARGET_SCHEMA_JSON)) + patch_path_is_file = patch( + "rpdk.core.type_schema_loader.os.path.isfile", 
return_value=True + ) + patch_load_file = patch.object( + loader, "load_type_schema_from_file", wraps=loader.load_type_schema_from_file + ) + + with patch_file as mock_file, patch_path_is_file as mock_path_is_file, patch_load_file as mock_load_file: + type_schema = loader.load_type_schema(TEST_TARGET_SCHEMA_FILE_PATH) + + assert_dict_equals(TEST_TARGET_SCHEMA, type_schema) + mock_path_is_file.assert_called_with(TEST_TARGET_SCHEMA_FILE_PATH) + mock_load_file.assert_called_with(TEST_TARGET_SCHEMA_FILE_PATH, None) + mock_file.assert_called_with(TEST_TARGET_SCHEMA_FILE_PATH, "r") + + +def test_load_type_schema_from_file_file_not_found(loader): + patch_file = patch("builtins.open", mock_open()) + patch_path_is_file = patch( + "rpdk.core.type_schema_loader.os.path.isfile", return_value=True + ) + patch_load_file = patch.object( + loader, "load_type_schema_from_file", wraps=loader.load_type_schema_from_file + ) + + with patch_file as mock_file, patch_path_is_file as mock_path_is_file, patch_load_file as mock_load_file: + mock_file.side_effect = FileNotFoundError() + type_schema = loader.load_type_schema(TEST_TARGET_SCHEMA_FILE_PATH) + + assert not type_schema + mock_path_is_file.assert_called_with(TEST_TARGET_SCHEMA_FILE_PATH) + mock_load_file.assert_called_with(TEST_TARGET_SCHEMA_FILE_PATH, None) + mock_file.assert_called_with(TEST_TARGET_SCHEMA_FILE_PATH, "r") + + +def test_load_type_schema_from_file_error_fallback_to_default(loader): + patch_file = patch("builtins.open", mock_open()) + patch_path_is_file = patch( + "rpdk.core.type_schema_loader.os.path.isfile", return_value=True + ) + patch_load_file = patch.object( + loader, "load_type_schema_from_file", wraps=loader.load_type_schema_from_file + ) + + with patch_file as mock_file, patch_path_is_file as mock_path_is_file, patch_load_file as mock_load_file: + mock_file.side_effect = FileNotFoundError() + type_schema = loader.load_type_schema( + TEST_TARGET_SCHEMA_FILE_PATH, OTHER_TEST_TARGET_FALLBACK_SCHEMA + ) + + 
assert_dict_equals(OTHER_TEST_TARGET_FALLBACK_SCHEMA, type_schema) + mock_path_is_file.assert_called_with(TEST_TARGET_SCHEMA_FILE_PATH) + mock_load_file.assert_called_with( + TEST_TARGET_SCHEMA_FILE_PATH, OTHER_TEST_TARGET_FALLBACK_SCHEMA + ) + mock_file.assert_called_with(TEST_TARGET_SCHEMA_FILE_PATH, "r") + + +def test_load_type_schema_from_file_uri(loader): + patch_file = patch("builtins.open", mock_open(read_data=TEST_TARGET_SCHEMA_JSON)) + patch_load_from_uri = patch.object( + loader, "load_type_schema_from_uri", wraps=loader.load_type_schema_from_uri + ) + patch_load_file = patch.object( + loader, "load_type_schema_from_file", wraps=loader.load_type_schema_from_file + ) + + with patch_file as mock_file, patch_load_from_uri as mock_load_from_uri, patch_load_file as mock_load_file: + type_schema = loader.load_type_schema(TEST_TARGET_SCHEMA_FILE_URI) + + assert_dict_equals(TEST_TARGET_SCHEMA, type_schema) + mock_load_from_uri.assert_called_with(TEST_TARGET_SCHEMA_FILE_URI, None) + mock_load_file.assert_called_with(TEST_TARGET_SCHEMA_FILE_PATH, None) + mock_file.assert_called_with(TEST_TARGET_SCHEMA_FILE_PATH, "r") + + +def test_load_type_schema_from_file_uri_file_not_found(loader): + patch_file = patch("builtins.open", mock_open()) + patch_load_from_uri = patch.object( + loader, "load_type_schema_from_uri", wraps=loader.load_type_schema_from_uri + ) + patch_load_file = patch.object( + loader, "load_type_schema_from_file", wraps=loader.load_type_schema_from_file + ) + + with patch_file as mock_file, patch_load_from_uri as mock_load_from_uri, patch_load_file as mock_load_file: + mock_file.side_effect = FileNotFoundError() + type_schema = loader.load_type_schema(TEST_TARGET_SCHEMA_FILE_URI) + + assert not type_schema + mock_load_from_uri.assert_called_with(TEST_TARGET_SCHEMA_FILE_URI, None) + mock_load_file.assert_called_with(TEST_TARGET_SCHEMA_FILE_PATH, None) + mock_file.assert_called_with(TEST_TARGET_SCHEMA_FILE_PATH, "r") + + +def 
test_load_type_schema_from_file_uri_error_fallback_to_default(loader): + patch_file = patch("builtins.open", mock_open()) + patch_load_from_uri = patch.object( + loader, "load_type_schema_from_uri", wraps=loader.load_type_schema_from_uri + ) + patch_load_file = patch.object( + loader, "load_type_schema_from_file", wraps=loader.load_type_schema_from_file + ) + + with patch_file as mock_file, patch_load_from_uri as mock_load_from_uri, patch_load_file as mock_load_file: + mock_file.side_effect = FileNotFoundError() + type_schema = loader.load_type_schema( + TEST_TARGET_SCHEMA_FILE_URI, OTHER_TEST_TARGET_FALLBACK_SCHEMA + ) + + assert_dict_equals(OTHER_TEST_TARGET_FALLBACK_SCHEMA, type_schema) + mock_load_from_uri.assert_called_with( + TEST_TARGET_SCHEMA_FILE_URI, OTHER_TEST_TARGET_FALLBACK_SCHEMA + ) + mock_load_file.assert_called_with( + TEST_TARGET_SCHEMA_FILE_PATH, OTHER_TEST_TARGET_FALLBACK_SCHEMA + ) + mock_file.assert_called_with(TEST_TARGET_SCHEMA_FILE_PATH, "r") + + +def test_load_type_schema_from_https_url(loader): + mock_request = Mock() + mock_request.status_code = 200 + mock_request.content = TEST_TARGET_SCHEMA_JSON.encode("utf-8") + + patch_get_request = patch( + "rpdk.core.type_schema_loader.requests.get", return_value=mock_request + ) + patch_load_from_uri = patch.object( + loader, "load_type_schema_from_uri", wraps=loader.load_type_schema_from_uri + ) + patch_get_from_url = patch.object( + loader, "_get_type_schema_from_url", wraps=loader._get_type_schema_from_url + ) + + with patch_get_request as mock_get_request, patch_load_from_uri as mock_load_from_uri, patch_get_from_url as mock_get_from_url: + type_schema = loader.load_type_schema(TEST_HTTPS_TARGET_SCHEMA_URI) + + assert_dict_equals(TEST_TARGET_SCHEMA, type_schema) + mock_load_from_uri.assert_called_with(TEST_HTTPS_TARGET_SCHEMA_URI, None) + mock_get_from_url.assert_called_with(TEST_HTTPS_TARGET_SCHEMA_URI, None) + mock_get_request.assert_called_with(TEST_HTTPS_TARGET_SCHEMA_URI, timeout=60) + + 
+def test_load_type_schema_from_https_url_unsuccessful(loader): + mock_request = Mock() + mock_request.status_code = 404 + + patch_get_request = patch( + "rpdk.core.type_schema_loader.requests.get", return_value=mock_request + ) + patch_load_from_uri = patch.object( + loader, "load_type_schema_from_uri", wraps=loader.load_type_schema_from_uri + ) + patch_get_from_url = patch.object( + loader, "_get_type_schema_from_url", wraps=loader._get_type_schema_from_url + ) + + with patch_get_request as mock_get_request, patch_load_from_uri as mock_load_from_uri, patch_get_from_url as mock_get_from_url: + type_schema = loader.load_type_schema(TEST_HTTPS_TARGET_SCHEMA_URI) + + assert not type_schema + mock_load_from_uri.assert_called_with(TEST_HTTPS_TARGET_SCHEMA_URI, None) + mock_get_from_url.assert_called_with(TEST_HTTPS_TARGET_SCHEMA_URI, None) + mock_get_request.assert_called_with(TEST_HTTPS_TARGET_SCHEMA_URI, timeout=60) + + +def test_load_type_schema_from_https_url_unsuccessful_fallback_to_default(loader): + mock_request = Mock() + mock_request.status_code = 404 + + patch_get_request = patch( + "rpdk.core.type_schema_loader.requests.get", return_value=mock_request + ) + patch_load_from_uri = patch.object( + loader, "load_type_schema_from_uri", wraps=loader.load_type_schema_from_uri + ) + patch_get_from_url = patch.object( + loader, "_get_type_schema_from_url", wraps=loader._get_type_schema_from_url + ) + + with patch_get_request as mock_get_request, patch_load_from_uri as mock_load_from_uri, patch_get_from_url as mock_get_from_url: + type_schema = loader.load_type_schema( + TEST_HTTPS_TARGET_SCHEMA_URI, OTHER_TEST_TARGET_FALLBACK_SCHEMA + ) + + assert_dict_equals(OTHER_TEST_TARGET_FALLBACK_SCHEMA, type_schema) + mock_load_from_uri.assert_called_with( + TEST_HTTPS_TARGET_SCHEMA_URI, OTHER_TEST_TARGET_FALLBACK_SCHEMA + ) + mock_get_from_url.assert_called_with( + TEST_HTTPS_TARGET_SCHEMA_URI, OTHER_TEST_TARGET_FALLBACK_SCHEMA + ) + 
mock_get_request.assert_called_with(TEST_HTTPS_TARGET_SCHEMA_URI, timeout=60) + + +def test_load_type_schema_from_s3(loader): + loader.s3_client.get_object.return_value = { + "Body": BytesIO(TEST_TARGET_SCHEMA_JSON.encode("utf-8")) + } + + patch_load_from_uri = patch.object( + loader, "load_type_schema_from_uri", wraps=loader.load_type_schema_from_uri + ) + patch_get_from_s3 = patch.object( + loader, "_get_type_schema_from_s3", wraps=loader._get_type_schema_from_s3 + ) + + with patch_load_from_uri as mock_load_from_uri, patch_get_from_s3 as mock_get_from_s3: + type_schema = loader.load_type_schema(TEST_S3_TARGET_SCHEMA_URI) + + assert_dict_equals(TEST_TARGET_SCHEMA, type_schema) + mock_load_from_uri.assert_called_with(TEST_S3_TARGET_SCHEMA_URI, None) + mock_get_from_s3.assert_called_with( + TEST_TARGET_SCHEMA_BUCKET, TEST_TARGET_SCHEMA_KEY, None + ) + loader.s3_client.get_object.assert_called_once_with( + Bucket=TEST_TARGET_SCHEMA_BUCKET, Key=TEST_TARGET_SCHEMA_KEY + ) + + +def test_load_type_schema_from_s3_client_error(loader): + loader.s3_client.get_object.side_effect = ClientError( + {"Error": {"Code": "", "Message": "Bucket does not exist"}}, + "get_object", + ) + + patch_load_from_uri = patch.object( + loader, "load_type_schema_from_uri", wraps=loader.load_type_schema_from_uri + ) + patch_get_from_s3 = patch.object( + loader, "_get_type_schema_from_s3", wraps=loader._get_type_schema_from_s3 + ) + + with patch_load_from_uri as mock_load_from_uri, patch_get_from_s3 as mock_get_from_s3: + type_schema = loader.load_type_schema(TEST_S3_TARGET_SCHEMA_URI) + + assert not type_schema + mock_load_from_uri.assert_called_with(TEST_S3_TARGET_SCHEMA_URI, None) + mock_get_from_s3.assert_called_with( + TEST_TARGET_SCHEMA_BUCKET, TEST_TARGET_SCHEMA_KEY, None + ) + loader.s3_client.get_object.assert_called_once_with( + Bucket=TEST_TARGET_SCHEMA_BUCKET, Key=TEST_TARGET_SCHEMA_KEY + ) + + +def test_load_type_schema_from_s3_error_fallback_to_default(loader): + 
loader.s3_client.get_object.side_effect = ClientError( + {"Error": {"Code": "", "Message": "Bucket does not exist"}}, + "get_object", + ) + + patch_load_from_uri = patch.object( + loader, "load_type_schema_from_uri", wraps=loader.load_type_schema_from_uri + ) + patch_get_from_s3 = patch.object( + loader, "_get_type_schema_from_s3", wraps=loader._get_type_schema_from_s3 + ) + + with patch_load_from_uri as mock_load_from_uri, patch_get_from_s3 as mock_get_from_s3: + type_schema = loader.load_type_schema( + TEST_S3_TARGET_SCHEMA_URI, OTHER_TEST_TARGET_FALLBACK_SCHEMA + ) + + assert_dict_equals(OTHER_TEST_TARGET_FALLBACK_SCHEMA, type_schema) + mock_load_from_uri.assert_called_with( + TEST_S3_TARGET_SCHEMA_URI, OTHER_TEST_TARGET_FALLBACK_SCHEMA + ) + mock_get_from_s3.assert_called_with( + TEST_TARGET_SCHEMA_BUCKET, + TEST_TARGET_SCHEMA_KEY, + OTHER_TEST_TARGET_FALLBACK_SCHEMA, + ) + loader.s3_client.get_object.assert_called_once_with( + Bucket=TEST_TARGET_SCHEMA_BUCKET, Key=TEST_TARGET_SCHEMA_KEY + ) + + +def test_load_type_schema_from_cfn_registry(loader): + loader.cfn_client.describe_type.return_value = { + "Schema": TEST_TARGET_SCHEMA_JSON, + "Type": "RESOURCE", + "ProvisioningType": "FULLY_MUTABLE", + } + + type_schema, target_type, provisioning_type = loader.load_schema_from_cfn_registry( + TEST_TARGET_TYPE_NAME, "RESOURCE" + ) + + assert_dict_equals(TEST_TARGET_SCHEMA, type_schema) + assert target_type == "RESOURCE" + assert provisioning_type == "FULLY_MUTABLE" + loader.cfn_client.describe_type.assert_called_once_with( + Type="RESOURCE", TypeName=TEST_TARGET_TYPE_NAME + ) + + +def test_load_type_schema_from_cfn_registry_client_error(loader): + loader.cfn_client.describe_type.side_effect = ClientError( + {"Error": {"Code": "", "Message": "Type does not exist"}}, + "get_object", + ) + + type_schema, target_type, provisioning_type = loader.load_schema_from_cfn_registry( + TEST_TARGET_TYPE_NAME, "RESOURCE" + ) + + assert not type_schema + assert not target_type + 
assert not provisioning_type + loader.cfn_client.describe_type.assert_called_once_with( + Type="RESOURCE", TypeName=TEST_TARGET_TYPE_NAME + ) + + +def test_load_type_schema_from_cfn_registry_error_fallback_to_default(loader): + loader.cfn_client.describe_type.side_effect = ClientError( + {"Error": {"Code": "", "Message": "Type does not exist"}}, + "get_object", + ) + + type_schema, target_type, provisioning_type = loader.load_schema_from_cfn_registry( + TEST_TARGET_TYPE_NAME, "RESOURCE", OTHER_TEST_TARGET_FALLBACK_SCHEMA + ) + + assert_dict_equals(OTHER_TEST_TARGET_FALLBACK_SCHEMA, type_schema) + assert not target_type + assert not provisioning_type + loader.cfn_client.describe_type.assert_called_once_with( + Type="RESOURCE", TypeName=TEST_TARGET_TYPE_NAME + ) + + +def test_get_provision_type(loader): + loader.cfn_client.describe_type.return_value = { + "Schema": TEST_TARGET_SCHEMA_JSON, + "Type": "RESOURCE", + "ProvisioningType": "IMMUTABLE", + } + + provisioning_type = loader.get_provision_type(TEST_TARGET_TYPE_NAME, "RESOURCE") + + assert provisioning_type == "IMMUTABLE" + loader.cfn_client.describe_type.assert_called_once_with( + Type="RESOURCE", TypeName=TEST_TARGET_TYPE_NAME + ) + + +def test_get_provision_type_client_error(loader): + loader.cfn_client.describe_type.side_effect = ClientError( + {"Error": {"Code": "", "Message": "Type does not exist"}}, + "get_object", + ) + + provisioning_type = loader.get_provision_type(TEST_TARGET_TYPE_NAME, "RESOURCE") + + assert not provisioning_type + loader.cfn_client.describe_type.assert_called_once_with( + Type="RESOURCE", TypeName=TEST_TARGET_TYPE_NAME + ) + + +def test_load_type_schema_null_input(loader): + type_schema = loader.load_type_schema(None, OTHER_TEST_TARGET_FALLBACK_SCHEMA) + assert_dict_equals(OTHER_TEST_TARGET_FALLBACK_SCHEMA, type_schema) + + type_schema = loader.load_type_schema_from_json( + None, OTHER_TEST_TARGET_FALLBACK_SCHEMA + ) + assert_dict_equals(OTHER_TEST_TARGET_FALLBACK_SCHEMA, 
type_schema) + + type_schema = loader.load_type_schema_from_uri( + None, OTHER_TEST_TARGET_FALLBACK_SCHEMA + ) + assert_dict_equals(OTHER_TEST_TARGET_FALLBACK_SCHEMA, type_schema) + + type_schema = loader.load_type_schema_from_file( + None, OTHER_TEST_TARGET_FALLBACK_SCHEMA + ) + assert_dict_equals(OTHER_TEST_TARGET_FALLBACK_SCHEMA, type_schema) + + +def test_load_type_schema_invalid_input(loader): + type_schema = loader.load_type_schema( + "This is invalid input", OTHER_TEST_TARGET_FALLBACK_SCHEMA + ) + assert_dict_equals(OTHER_TEST_TARGET_FALLBACK_SCHEMA, type_schema) + + with patch( + "rpdk.core.type_schema_loader.is_valid_type_schema_uri", return_value=True + ): + type_schema = loader.load_type_schema_from_uri( + "ftp://unsupportedurlschema.com/test-schema.json", + OTHER_TEST_TARGET_FALLBACK_SCHEMA, + ) + assert_dict_equals(OTHER_TEST_TARGET_FALLBACK_SCHEMA, type_schema) + + +@pytest.mark.parametrize( + "uri", + [ + TEST_TARGET_SCHEMA_FILE_URI, + TEST_HTTPS_TARGET_SCHEMA_URI, + TEST_S3_TARGET_SCHEMA_URI, + ], +) +def test_is_valid_type_schema_uri(uri): + assert is_valid_type_schema_uri(uri) + + +@pytest.mark.parametrize( + "uri", [None, "ftp://unsupportedurlschema.com/test-schema.json"] +) +def test_is_invalid_type_schema_uri(uri): + assert not is_valid_type_schema_uri(uri) diff --git a/tests/utils/test_handler_utils.py b/tests/utils/test_handler_utils.py new file mode 100644 index 00000000..df67d595 --- /dev/null +++ b/tests/utils/test_handler_utils.py @@ -0,0 +1,45 @@ +import pytest + +from rpdk.core.contract.interface import Action, HookInvocationPoint +from rpdk.core.utils.handler_utils import generate_handler_name + +RESOURCE_HANDLERS = { + Action.CREATE: "create", + Action.UPDATE: "update", + Action.DELETE: "delete", + Action.READ: "read", + Action.LIST: "list", +} + +HOOK_HANDLERS = { + HookInvocationPoint.CREATE_PRE_PROVISION: "preCreate", + HookInvocationPoint.UPDATE_PRE_PROVISION: "preUpdate", + HookInvocationPoint.DELETE_PRE_PROVISION: "preDelete", 
+} + + +def test_generate_handler_name(): + operation = "SOME_HANDLER_OPERATION" + expected_handler_name = "someHandlerOperation" + + handler_name = generate_handler_name(operation) + assert handler_name == expected_handler_name + + +@pytest.mark.parametrize( + "action", [Action.CREATE, Action.UPDATE, Action.DELETE, Action.READ, Action.LIST] +) +def test_generate_resource_handler_name(action): + assert generate_handler_name(action) == RESOURCE_HANDLERS[action] + + +@pytest.mark.parametrize( + "invoke_point", + [ + HookInvocationPoint.CREATE_PRE_PROVISION, + HookInvocationPoint.UPDATE_PRE_PROVISION, + HookInvocationPoint.DELETE_PRE_PROVISION, + ], +) +def test_generate_hook_handler_name(invoke_point): + assert generate_handler_name(invoke_point) == HOOK_HANDLERS[invoke_point] diff --git a/tests/utils/test_init_utils.py b/tests/utils/test_init_utils.py index 5e71220f..e0624ca2 100644 --- a/tests/utils/test_init_utils.py +++ b/tests/utils/test_init_utils.py @@ -70,11 +70,21 @@ def mock_validator(value): assert ERROR in out +def test_input_artifact_type_hook(): + artifact_type = "HOOK" + patch_input = patch("rpdk.core.utils.init_utils.input", return_value=artifact_type) + with patch_input as mock_input: + assert init_artifact_type() == artifact_type + mock_input.assert_called_once() + + def test_validate_artifact_type_valid(): assert validate_artifact_type("m") == "MODULE" assert validate_artifact_type("module") == "MODULE" assert validate_artifact_type("r") == "RESOURCE" assert validate_artifact_type("resource") == "RESOURCE" + assert validate_artifact_type("h") == "HOOK" + assert validate_artifact_type("hook") == "HOOK" def test_validate_artifact_type_invalid():