Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: aws sfn added #26784

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyright/master/requirements-pinned.txt
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,7 @@ mypy-boto3-emr==1.35.68
mypy-boto3-emr-serverless==1.35.79
mypy-boto3-glue==1.35.80
mypy-boto3-logs==1.35.81
mypy-boto3-stepfunctions==1.35.68
mypy-boto3-s3==1.35.81
mypy-extensions==1.0.0
mypy-protobuf==3.6.0
Expand Down
160 changes: 160 additions & 0 deletions python_modules/libraries/dagster-aws/dagster_aws/pipes/clients/sfn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
import time
from typing import TYPE_CHECKING, Any, Dict, Optional, Union, cast

import boto3
import dagster._check as check
from botocore.exceptions import ClientError
from dagster import MetadataValue, PipesClient
from dagster._annotations import experimental, public
from dagster._core.definitions.metadata import RawMetadataMapping
from dagster._core.definitions.resource_annotation import TreatAsResourceParam
from dagster._core.errors import DagsterExecutionInterruptedError
from dagster._core.execution.context.asset_execution_context import AssetExecutionContext
from dagster._core.execution.context.compute import OpExecutionContext
from dagster._core.pipes.client import (
PipesClientCompletedInvocation,
PipesContextInjector,
PipesMessageReader,
)
from dagster._core.pipes.utils import open_pipes_session

from dagster_aws.pipes.message_readers import PipesCloudWatchMessageReader

if TYPE_CHECKING:
from mypy_boto3_stepfunctions import SFNClient
from mypy_boto3_stepfunctions.type_defs import (
DescribeExecutionOutputTypeDef,
StartExecutionInputRequestTypeDef,
)


@experimental
class PipesSFNClient(PipesClient, TreatAsResourceParam):
    """A pipes client for invoking AWS Step Functions.

    Args:
        context_injector (PipesContextInjector): A context injector to use to inject
            context into the Step Function execution. Required; there is no default.
        message_reader (Optional[PipesMessageReader]): A message reader to use to read messages
            from the Step Function execution. Defaults to :py:class:`PipesCloudWatchMessageReader`.
        client (Optional[boto3.client]): The boto Step Functions client used to start the execution.
            Defaults to ``boto3.client("stepfunctions")``.
        forward_termination (bool): Whether to cancel the Step Function execution when the Dagster process receives a termination signal.
    """

    def __init__(
        self,
        context_injector: PipesContextInjector,
        message_reader: Optional[PipesMessageReader] = None,
        client: Optional["SFNClient"] = None,
        forward_termination: bool = True,
    ):
        self._client: SFNClient = client or boto3.client("stepfunctions")
        self._context_injector = context_injector
        self._message_reader = message_reader or PipesCloudWatchMessageReader()
        self.forward_termination = check.bool_param(forward_termination, "forward_termination")

    @classmethod
    def _is_dagster_maintained(cls) -> bool:
        return True

    @public
    def run(
        self,
        *,
        context: Union[OpExecutionContext, AssetExecutionContext],
        start_execution_input: "StartExecutionInputRequestTypeDef",
        extras: Optional[Dict[str, Any]] = None,
    ) -> PipesClientCompletedInvocation:
        """Start a Step Function execution, enriched with the pipes protocol.

        See also: `AWS API Documentation <https://docs.aws.amazon.com/step-functions/latest/apireference/API_StartExecution.html>`_

        Args:
            context (Union[OpExecutionContext, AssetExecutionContext]): The context of the currently executing Dagster op or asset.
            start_execution_input (Dict): Parameters for the ``start_execution`` boto3 Step Functions client call.
            extras (Optional[Dict[str, Any]]): Additional Dagster metadata to pass to the Step Function execution.

        Returns:
            PipesClientCompletedInvocation: Wrapper containing results reported by the external
            process.

        Raises:
            ClientError: If ``start_execution`` fails (logged, then re-raised).
            DagsterExecutionInterruptedError: Re-raised after optionally stopping the execution.
            RuntimeError: If the execution finishes with any status other than ``SUCCEEDED``.
        """
        params = start_execution_input

        state_machine_arn = cast(str, params["stateMachineArn"])

        with open_pipes_session(
            context=context,
            message_reader=self._message_reader,
            context_injector=self._context_injector,
            extras=extras,
        ) as session:
            # NOTE(review): the bootstrap arguments are fetched but never added to the
            # execution input, so the state machine does not receive them from this
            # client -- confirm how the pipes context is meant to reach the execution.
            _ = session.get_bootstrap_cli_arguments()

            try:
                execution_arn = self._client.start_execution(**params)["executionArn"]
            except ClientError as err:
                error_info = err.response.get("Error", {})
                context.log.error(
                    "Couldn't start execution %s. Here's why: %s: %s",
                    state_machine_arn,
                    error_info.get("Code", "Unknown error"),
                    error_info.get("Message", "No error message available"),
                )
                raise

            context.log.info(
                f"Started AWS Step Function execution {state_machine_arn} run: {execution_arn}"
            )

            try:
                response = self._wait_for_execution_completion(execution_arn)
            except DagsterExecutionInterruptedError:
                # Stop the remote execution (when configured to) before propagating.
                if self.forward_termination:
                    self._terminate_execution(context=context, execution_arn=execution_arn)
                raise

            status = response["status"]
            if status != "SUCCEEDED":
                # DescribeExecution reports failure details in `error` and `cause`.
                raise RuntimeError(
                    f"Step Function execution {state_machine_arn} run {execution_arn} completed with status {status} :\n{response.get('error')}: {response.get('cause')}"
                )

            context.log.info(
                f"Step Function execution {state_machine_arn} run {execution_arn} completed successfully"
            )

        return PipesClientCompletedInvocation(
            session, metadata=self._extract_dagster_metadata(response)
        )

    def _wait_for_execution_completion(
        self, execution_arn: str
    ) -> "DescribeExecutionOutputTypeDef":
        """Poll ``describe_execution`` every 5 seconds until a terminal status is reported."""
        terminal_statuses = ("FAILED", "SUCCEEDED", "TIMED_OUT", "ABORTED")
        while True:
            response = self._client.describe_execution(executionArn=execution_arn)
            if response["status"] in terminal_statuses:
                return response
            time.sleep(5)

    def _extract_dagster_metadata(
        self, response: "DescribeExecutionOutputTypeDef"
    ) -> RawMetadataMapping:
        """Convert every field of the describe response into a text metadata value."""
        metadata: RawMetadataMapping = {
            key: MetadataValue.text(str(value)) for key, value in response.items()
        }
        return metadata

    def _terminate_execution(
        self, context: Union[OpExecutionContext, AssetExecutionContext], execution_arn: str
    ):
        """Stop the given Step Function execution and log before and after doing so.

        Called from ``run`` when the Dagster execution is interrupted and
        ``forward_termination`` is enabled.
        """
        context.log.warning(
            f"[pipes] execution interrupted, stopping Step Function execution {execution_arn}..."
        )
        self._client.stop_execution(executionArn=execution_arn)
        context.log.warning(f"Successfully stopped Step Function execution {execution_arn}.")
Empty file.
155 changes: 155 additions & 0 deletions python_modules/libraries/dagster-aws/dagster_aws/sfn/sfn_launcher.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
import logging
import time
from functools import cached_property
from typing import TYPE_CHECKING, Optional

import boto3
from botocore.exceptions import ClientError
from dagster import (
Field,
StringSource,
_check as check,
)
from dagster._core.errors import DagsterExecutionInterruptedError
from dagster._core.instance import T_DagsterInstance
from dagster._core.launcher.base import LaunchRunContext, RunLauncher
from dagster._serdes import ConfigurableClass
from tenacity import retry, retry_if_exception_type, stop_after_delay, wait_fixed

if TYPE_CHECKING:
from dagster._serdes.config_class import ConfigurableClassData
from mypy_boto3_stepfunctions import SFNClient
from mypy_boto3_stepfunctions.type_defs import (
DescribeExecutionOutputTypeDef,
StartExecutionInputRequestTypeDef,
StopExecutionOutputTypeDef,
)


# Execution statuses at which SFNLauncher._wait_for_execution_completion stops polling.
SFN_FINISHED_STATUSES = ["FAILED", "SUCCEEDED", "TIMED_OUT", "ABORTED"]


class SFNFinishedExecutionError(Exception):
    """Raised when a Step Function execution finishes with a status other than SUCCEEDED."""


# Backward-compatible alias for the original (misspelled) public name.
SFNFinishedExecutioinError = SFNFinishedExecutionError


class SFNLauncher(RunLauncher[T_DagsterInstance], ConfigurableClass):
    """Run launcher that starts an AWS Step Functions execution and waits for it to finish.

    Args:
        inst_data: Configurable-class data, exposed through ``inst_data``.
        sfn_arn: ARN of the state machine to execute.
        name: Optional execution name passed to ``start_execution``.
        input: Optional execution input string passed to ``start_execution``.
        trace_header: Optional trace header passed to ``start_execution``.
    """

    def __init__(
        self,
        inst_data: Optional["ConfigurableClassData"],
        sfn_arn: str,
        name: Optional[str] = None,
        input: Optional[str] = None,  # noqa: A002
        trace_header: Optional[str] = None,
    ) -> None:
        self._client: SFNClient = boto3.client("stepfunctions")
        self._inst_data = inst_data
        self._sfn_arn = check.str_param(sfn_arn, "sfn_arn")
        # opt_str_param accepts None itself, so each option is validated on its own.
        # The previous code gated `input` on `name`, dropping input when name was unset.
        self._name = check.opt_str_param(name, "name")
        self._input = check.opt_str_param(input, "input")
        self._trace_header = check.opt_str_param(trace_header, "trace_header")

    @cached_property
    def _start_execution_input(self) -> "StartExecutionInputRequestTypeDef":
        """Build the ``start_execution`` keyword arguments, omitting unset options."""
        input_dict: "StartExecutionInputRequestTypeDef" = {"stateMachineArn": self._sfn_arn}
        if self._name is not None:
            input_dict["name"] = self._name
        if self._input is not None:
            input_dict["input"] = self._input
        if self._trace_header is not None:
            input_dict["traceHeader"] = self._trace_header
        return input_dict

    def launch_run(self, context: LaunchRunContext) -> None:
        """Start the configured state machine and block until the execution finishes.

        Raises:
            ClientError: If ``start_execution`` fails (logged, then re-raised).
            DagsterExecutionInterruptedError: Re-raised after attempting to stop the execution.
            SFNFinishedExecutioinError: If the execution ends in a status other than SUCCEEDED.
        """
        try:
            execution_arn = self._client.start_execution(**self._start_execution_input)[
                "executionArn"
            ]
        except ClientError as err:
            error_info = err.response.get("Error", {})
            logging.error(
                "Couldn't start execution %s. Here's why: %s: %s",
                self._sfn_arn,
                error_info.get("Code", "Unknown error"),
                error_info.get("Message", "No error message available"),
            )
            raise

        try:
            response = self._wait_for_execution_completion(execution_arn)
        except DagsterExecutionInterruptedError as e:
            logging.error(f"Error waiting for execution completion: {e}")
            _ = self._stop_execution(execution_arn)
            raise

        status = response["status"]
        if status != "SUCCEEDED":
            # DescribeExecution reports failure details in `error` and `cause`.
            raise SFNFinishedExecutioinError(
                f"Step Function execution {self._sfn_arn} run {execution_arn} completed with status {status} :\n{response.get('error')}: {response.get('cause')}",
            )
        # Only report success once the terminal status has been checked.
        logging.info(f"Execution {execution_arn} completed successfully, Status: {status}")

    # NOTE(review): without reraise=True, tenacity is expected to surface exhausted
    # retries as RetryError rather than the underlying ClientError -- confirm intent.
    @retry(
        retry=retry_if_exception_type(ClientError),
        stop=stop_after_delay(5),
        wait=wait_fixed(5),
    )
    def _stop_execution(self, execution_arn: str) -> "StopExecutionOutputTypeDef":
        """Stop the execution, logging the outcome; ClientError triggers the retry policy."""
        logging.info(f"Terminating execution {execution_arn}")
        try:
            response = self._client.stop_execution(executionArn=execution_arn)
        except ClientError as err:
            logging.error(f"Couldn't terminate execution {execution_arn}. Here's why: {err}")
            raise
        else:
            logging.info(
                f"Execution {execution_arn} terminated, Stop execution response: {response}"
            )
            return response

    def terminate(self, run_id):
        """Termination by run id is not supported by this launcher."""
        check.not_implemented("Termination not supported.")

    def _wait_for_execution_completion(
        self, execution_arn: str
    ) -> "DescribeExecutionOutputTypeDef":
        """Poll the execution every 5 seconds until its status is in SFN_FINISHED_STATUSES."""
        while True:
            response = self._describe_execution(execution_arn)
            if response["status"] in SFN_FINISHED_STATUSES:
                return response
            time.sleep(5)

    @retry(
        retry=retry_if_exception_type(ClientError),
        stop=stop_after_delay(5),
        wait=wait_fixed(5),
    )
    def _describe_execution(self, execution_arn: str) -> "DescribeExecutionOutputTypeDef":
        """Describe the execution; ClientError is logged and triggers the retry policy."""
        try:
            response = self._client.describe_execution(executionArn=execution_arn)
        except ClientError as err:
            logging.error(f"Couldn't describe execution {execution_arn}. Here's why: {err}")
            raise
        else:
            return response

    @property
    def inst_data(self) -> Optional["ConfigurableClassData"]:
        return self._inst_data

    @classmethod
    def config_type(cls):
        """Config schema: required ``sfn_arn`` plus optional ``name``, ``input``, ``trace_header``."""
        return {
            "sfn_arn": Field(StringSource, is_required=True),
            "name": Field(StringSource, is_required=False),
            "input": Field(StringSource, is_required=False),
            "trace_header": Field(StringSource, is_required=False),
        }

    @classmethod
    def from_config_value(cls, inst_data: Optional["ConfigurableClassData"], config_value):
        """Construct the launcher from validated config, passing each key as a keyword argument."""
        return cls(inst_data=inst_data, **config_value)
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,7 @@ def test_run_task(ecs, ec2, subnet):
)
response = ecs.run_task(taskDefinition="container")
assert response["tasks"][0]["containers"]
# ECS does not expose the task definition's environment when
# ECS does not expose the task definition's environment when
# describing tasks
assert "FOO" not in response

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from datetime import datetime
from typing import Any, Dict, Optional


class LocalSfnMockClient:
    """In-memory stand-in for the boto3 Step Functions client that returns canned responses.

    Every execution is reported as ``SUCCEEDED`` with a fixed ARN and timestamp.
    """

    def __init__(self):
        # Fixed values so responses are deterministic.
        self.default_dt = datetime(2024, 1, 1, 12, 0, 0)
        self.execution_arn = (
            "arn:aws:states:us-east-1:123456789012:execution:StateMachine:execution-id"
        )
        self.state_machine_arn = "arn:aws:states:us-east-1:123456789012:stateMachine:StateMachine"

    def start_execution(
        self,
        stateMachineArn: str,
        input: Optional[str] = None,  # noqa: A002  # AWS API parameter
        name: Optional[str] = None,
        traceHeader: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Return the fixed execution ARN and start date; all arguments are ignored.

        ``traceHeader`` is accepted so callers that pass the optional StartExecution
        trace header do not fail with a TypeError.
        """
        return {"executionArn": self.execution_arn, "startDate": self.default_dt}

    def describe_execution(self, executionArn: str) -> Dict[str, Any]:
        """Return a canned description with status ``SUCCEEDED``; ``executionArn`` is ignored."""
        return {
            "executionArn": self.execution_arn,
            "stateMachineArn": self.state_machine_arn,
            "name": "execution-id",
            "status": "SUCCEEDED",
            "startDate": self.default_dt,
            "input": '{"key": "value"}',
        }

    def stop_execution(self, executionArn: str) -> Dict[str, Any]:
        """Echo the given ARN back together with the fixed stop date."""
        return {"executionArn": executionArn, "stopDate": self.default_dt}
Loading