Skip to content

Commit

Permalink
Add signalCommand and also reset/signal/skipTimer client APIs (#17)
Browse files Browse the repository at this point in the history
* skip timer and reset

* test skip timer command

* add type hint

* done all

* Done signal channel

* fix timer

* fix wf type conflict in test
  • Loading branch information
longquanzheng committed Sep 12, 2023
1 parent bee7f4b commit 0756648
Show file tree
Hide file tree
Showing 10 changed files with 241 additions and 26 deletions.
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,9 @@ See more in https://github.com/indeedeng/iwf#what-is-iwf
- [x] Improve workflow uncompleted error return(canceled, failed, timeout, terminated)
- [x] Support execute API failure policy
- [x] Support workflow RPC
- [ ] Signal command
- [x] Signal command
- [x] Reset workflow API
- [x] Skip timer API for testing/operation

## Future -- the advanced features that already supported in server. Contributions are welcome to implement them in this SDK!
- [ ] Atomic conditional complete workflow by checking signal/internal channel emptiness
Expand All @@ -54,8 +56,6 @@ See more in https://github.com/indeedeng/iwf#what-is-iwf
- [ ] Describe workflow API
- [ ] TryGetWorkflowResults API
- [ ] Consume N messages in a single command
- [ ] Reset workflow API
- [ ] Skip timer API for testing/operation
- [ ] Decider trigger type: any command combination
- [ ] Failing workflow with results
- [ ] Wait_until API failure policy
Expand Down
63 changes: 61 additions & 2 deletions iwf/client.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,20 @@
import inspect
from typing import Any, Callable, Optional, Type, TypeVar
from typing import Any, Callable, Optional, Type, TypeVar, Union

from iwf.client_options import ClientOptions
from iwf.errors import InvalidArgumentError
from iwf.registry import Registry
from iwf.reset_workflow_type_and_options import ResetWorkflowTypeAndOptions
from iwf.stop_workflow_options import StopWorkflowOptions
from iwf.unregistered_client import UnregisteredClient, UnregisteredWorkflowOptions
from iwf.workflow import ObjectWorkflow, get_workflow_type_by_class
from iwf.workflow_options import WorkflowOptions
from iwf.workflow_state import get_state_id, should_skip_wait_until
from iwf.workflow_state import (
WorkflowState,
get_state_id,
get_state_id_by_class,
should_skip_wait_until,
)
from iwf.workflow_state_options import _to_idl_state_options

T = TypeVar("T")
Expand Down Expand Up @@ -131,3 +137,56 @@ def invoke_rpc(
all_defined_search_attribute_types=[],
return_type_hint=return_type_hint,
)

def signal_workflow(
self,
workflow_id: str,
signal_channel_name: str,
signal_value: Optional[Any] = None,
):
return self._unregistered_client.signal_workflow(
workflow_id, "", signal_channel_name, signal_value
)

def reset_workflow(
self,
workflow_id: str,
reset_workflow_type_and_options: ResetWorkflowTypeAndOptions,
):
return self._unregistered_client.reset_workflow(
workflow_id, "", reset_workflow_type_and_options
)

def skip_timer_by_command_id(
self,
workflow_id: str,
workflow_state_id: str,
timer_command_id: str,
state_execution_number: int = 1,
):
return self._unregistered_client.skip_timer_by_command_id(
workflow_id,
"",
workflow_state_id,
timer_command_id=timer_command_id,
state_execution_number=state_execution_number,
)

def skip_timer_at_command_index(
self,
workflow_id: str,
workflow_state_id: Union[str, type[WorkflowState]],
state_execution_number: int = 1,
timer_command_index: int = 0,
):
if isinstance(workflow_state_id, type):
state_id = get_state_id_by_class(workflow_state_id)
else:
state_id = workflow_state_id
return self._unregistered_client.skip_timer_at_command_index(
workflow_id,
"",
state_id,
state_execution_number,
timer_command_index,
)
27 changes: 24 additions & 3 deletions iwf/command_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from iwf_api.models.inter_state_channel_command import (
InterStateChannelCommand as IdlInternalChannelCommand,
)
from iwf_api.models.signal_command import SignalCommand as IdlSignalCommand
from iwf_api.models.timer_command import TimerCommand as IdlTimerCommand


Expand Down Expand Up @@ -38,6 +39,18 @@ def by_name(cls, channel_name: str, command_id: Optional[str] = None):
)


@dataclass
class SignalChannelCommand:
command_id: str
channel_name: str

@classmethod
def by_name(cls, channel_name: str, command_id: Optional[str] = None):
return SignalChannelCommand(
command_id if command_id is not None else "", channel_name
)


BaseCommand = Union[TimerCommand, InternalChannelCommand]


Expand Down Expand Up @@ -72,14 +85,22 @@ def _to_idl_command_request(request: CommandRequest) -> IdlCommandRequest:
if isinstance(t, TimerCommand)
]

internal_channel_command = [
internal_channel_commands = [
IdlInternalChannelCommand(i.command_id, i.channel_name)
for i in request.commands
if isinstance(i, InternalChannelCommand)
]

signal_commands = [
IdlSignalCommand(i.command_id, i.channel_name)
for i in request.commands
if isinstance(i, SignalChannelCommand)
]

if len(timer_commands) > 0:
req.timer_commands = timer_commands
if len(internal_channel_command) > 0:
req.inter_state_channel_commands = internal_channel_command
if len(internal_channel_commands) > 0:
req.inter_state_channel_commands = internal_channel_commands
if len(signal_commands) > 0:
req.signal_commands = signal_commands
return req
26 changes: 25 additions & 1 deletion iwf/command_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,18 +26,28 @@ class InternalChannelCommandResult:
command_id: str


@dataclass
class SignalChannelCommandResult:
channel_name: str
value: Any
status: ChannelRequestStatus
command_id: str


@dataclass
class CommandResults:
timer_commands: list[TimerCommandResult]
internal_channel_commands: list[InternalChannelCommandResult]
signal_channel_commands: list[SignalChannelCommandResult]


def from_idl_command_results(
idl_results: Union[Unset, IdlCommandResults],
internal_channel_types: dict[str, typing.Optional[type]],
signal_channel_types: dict[str, typing.Optional[type]],
object_encoder: ObjectEncoder,
) -> CommandResults:
results = CommandResults(list(), list())
results = CommandResults(list(), list(), list())
if isinstance(idl_results, Unset):
return results
if not isinstance(idl_results.timer_results, Unset):
Expand All @@ -58,4 +68,18 @@ def from_idl_command_results(
inter.command_id,
)
)

if not isinstance(idl_results.signal_results, Unset):
for sig in idl_results.signal_results:
results.signal_channel_commands.append(
SignalChannelCommandResult(
sig.signal_channel_name,
object_encoder.decode(
sig.signal_value,
signal_channel_types.get(sig.signal_channel_name),
),
sig.signal_request_status,
sig.command_id,
)
)
return results
14 changes: 14 additions & 0 deletions iwf/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ class Registry:
_starting_state_store: dict[str, WorkflowState]
_state_store: dict[str, dict[str, WorkflowState]]
_internal_channel_type_store: dict[str, dict[str, Optional[type]]]
_signal_channel_type_store: dict[str, dict[str, Optional[type]]]
_data_attribute_types: dict[str, dict[str, Optional[type]]]
_rpc_infos: dict[str, dict[str, RPCInfo]]

Expand All @@ -21,13 +22,15 @@ def __init__(self):
self._starting_state_store = dict()
self._state_store = dict()
self._internal_channel_type_store = dict()
self._signal_channel_type_store = dict()
self._data_attribute_types = dict()
self._rpc_infos = dict()

def add_workflow(self, wf: ObjectWorkflow):
self._register_workflow_type(wf)
self._register_workflow_state(wf)
self._register_internal_channels(wf)
self._register_signal_channels(wf)
self._register_data_attributes(wf)
self._register_workflow_rpcs(wf)

Expand Down Expand Up @@ -63,6 +66,9 @@ def get_state_store(self, wf_type: str) -> dict[str, WorkflowState]:
def get_internal_channel_types(self, wf_type: str) -> dict[str, Optional[type]]:
return self._internal_channel_type_store[wf_type]

def get_signal_channel_types(self, wf_type: str) -> dict[str, Optional[type]]:
return self._signal_channel_type_store[wf_type]

def get_data_attribute_types(self, wf_type: str) -> dict[str, Optional[type]]:
return self._data_attribute_types[wf_type]

Expand All @@ -83,6 +89,14 @@ def _register_internal_channels(self, wf: ObjectWorkflow):
types[method.name] = method.value_type
self._internal_channel_type_store[wf_type] = types

def _register_signal_channels(self, wf: ObjectWorkflow):
wf_type = get_workflow_type(wf)
types: dict[str, Optional[type]] = {}
for method in wf.get_communication_schema().communication_methods:
if method.method_type == CommunicationMethodType.SignalChannel:
types[method.name] = method.value_type
self._signal_channel_type_store[wf_type] = types

def _register_data_attributes(self, wf: ObjectWorkflow):
wf_type = get_workflow_type(wf)
types: dict[str, Optional[type]] = {}
Expand Down
92 changes: 92 additions & 0 deletions iwf/tests/test_signal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import inspect
import time
import unittest

from iwf_api.models import ChannelRequestStatus

from iwf.client import Client
from iwf.command_request import (
CommandRequest,
SignalChannelCommand,
)
from iwf.command_results import CommandResults, SignalChannelCommandResult
from iwf.communication import Communication
from iwf.communication_schema import CommunicationMethod, CommunicationSchema
from iwf.persistence import Persistence
from iwf.state_decision import StateDecision
from iwf.state_schema import StateSchema
from iwf.tests.worker_server import registry
from iwf.workflow import ObjectWorkflow
from iwf.workflow_context import WorkflowContext
from iwf.workflow_state import T, WorkflowState

test_channel_int = "test-int"
test_channel_none = "test-none"
test_channel_str = "test-str"


class WaitState(WorkflowState[None]):
def wait_until(
self,
ctx: WorkflowContext,
input: T,
persistence: Persistence,
communication: Communication,
) -> CommandRequest:
return CommandRequest.for_all_command_completed(
SignalChannelCommand.by_name(test_channel_int),
SignalChannelCommand.by_name(test_channel_none),
SignalChannelCommand.by_name(test_channel_str),
)

def execute(
self,
ctx: WorkflowContext,
input: T,
command_results: CommandResults,
persistence: Persistence,
communication: Communication,
) -> StateDecision:
assert len(command_results.signal_channel_commands) == 3
sig1 = command_results.signal_channel_commands[0]
sig2 = command_results.signal_channel_commands[1]
sig3 = command_results.signal_channel_commands[2]
assert sig1 == SignalChannelCommandResult(
test_channel_int, 123, ChannelRequestStatus.RECEIVED, ""
)
assert sig2 == SignalChannelCommandResult(
test_channel_none, None, ChannelRequestStatus.RECEIVED, ""
)
assert sig3 == SignalChannelCommandResult(
test_channel_str, "abc", ChannelRequestStatus.RECEIVED, ""
)
return StateDecision.graceful_complete_workflow(sig3.value)


class WaitSignalWorkflow(ObjectWorkflow):
def get_communication_schema(self) -> CommunicationSchema:
return CommunicationSchema.create(
CommunicationMethod.signal_channel_def(test_channel_int, int),
CommunicationMethod.signal_channel_def(test_channel_none, type(None)),
CommunicationMethod.signal_channel_def(test_channel_str, str),
)

def get_workflow_states(self) -> StateSchema:
return StateSchema.with_starting_state(WaitState())


class TestSignal(unittest.TestCase):
@classmethod
def setUpClass(cls):
wf = WaitSignalWorkflow()
registry.add_workflow(wf)
cls.client = Client(registry)

def test_signal(self):
wf_id = f"{inspect.currentframe().f_code.co_name}-{time.time_ns()}"
self.client.start_workflow(WaitSignalWorkflow, wf_id, 1)
self.client.signal_workflow(wf_id, test_channel_int, 123)
self.client.signal_workflow(wf_id, test_channel_str, "abc")
self.client.signal_workflow(wf_id, test_channel_none)
res = self.client.get_simple_workflow_result_with_wait(wf_id)
assert res == "abc"
7 changes: 5 additions & 2 deletions iwf/tests/test_timer.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ def wait_until(
communication: Communication,
) -> CommandRequest:
return CommandRequest.for_all_command_completed(
TimerCommand.timer_command_by_duration(timedelta(seconds=input))
TimerCommand.timer_command_by_duration(timedelta(hours=input)),
TimerCommand.timer_command_by_duration(timedelta(seconds=input)),
)

def execute(
Expand Down Expand Up @@ -52,7 +53,9 @@ def test_timer_workflow():
wf_id = f"{inspect.currentframe().f_code.co_name}-{time.time_ns()}"

client.start_workflow(TimerWorkflow, wf_id, 100, 5)
time.sleep(1)
client.skip_timer_at_command_index(wf_id, WaitState)
start_ms = time.time_ns() / 1000000
client.get_simple_workflow_result_with_wait(wf_id, None)
elapsed_ms = time.time_ns() / 1000000 - start_ms
assert 4000 <= elapsed_ms <= 7000, f"expected 5000 ms timer, actual is {elapsed_ms}"
assert 3000 <= elapsed_ms <= 6000, f"expected 5000 ms timer, actual is {elapsed_ms}"

0 comments on commit 0756648

Please sign in to comment.