Skip to content

Commit

Permalink
Allow rpc input use different parameter name (#16)
Browse files Browse the repository at this point in the history
  • Loading branch information
longquanzheng committed Sep 12, 2023
1 parent ad60b8b commit bee7f4b
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 11 deletions.
19 changes: 10 additions & 9 deletions iwf/rpc.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from dataclasses import dataclass
from functools import wraps
from inspect import signature
from typing import Any, Callable, Optional
from typing import Callable, Optional

from iwf_api.models import PersistenceLoadingPolicy, PersistenceLoadingType

Expand All @@ -21,7 +21,7 @@ class RPCInfo:

rpc_definition_err = WorkflowDefinitionError(
"an RPC must have at most 5 params: self, context:WorkflowContext, input:Any, persistence:Persistence, "
'communication:Communication, where input can be any type as long as the param name is "input" '
"communication:Communication, where input can be any type"
)


Expand Down Expand Up @@ -49,9 +49,8 @@ def wrapper(*args, **kwargs):
from iwf.workflow_context import WorkflowContext
from iwf.communication import Communication

valid_param_types = {
valid_param_types_exclude_input = {
_empty: True,
Any: True,
Persistence: True,
WorkflowContext: True,
Communication: True,
Expand All @@ -61,16 +60,18 @@ def wrapper(*args, **kwargs):
if len(params) > 5:
raise rpc_definition_err

has_input = False
for k, v in params.items():
if k != "self":
params_order.append(v.annotation)
if k == "input":
rpc_info.input_type = v.annotation
continue

if v.annotation == Persistence:
need_persistence = True
if v.annotation not in valid_param_types:
raise rpc_definition_err
if v.annotation not in valid_param_types_exclude_input:
if not has_input:
has_input = True
else:
raise rpc_definition_err
if not need_persistence:
rpc_info.data_attribute_loading_policy = PersistenceLoadingPolicy(
persistence_loading_type=PersistenceLoadingType.LOAD_NONE
Expand Down
8 changes: 6 additions & 2 deletions iwf/tests/test_rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,8 @@ def test_rpc_persistence_read(self, pers: Persistence):
return pers.get_data_attribute(test_data_attribute)

@rpc()
def test_rpc_trigger_state(self, com: Communication):
def test_rpc_trigger_state(self, pers: Persistence, com: Communication, i: int):
pers.set_data_attribute(test_data_attribute, i)
com.trigger_state_execution(WaitState)

@rpc()
Expand Down Expand Up @@ -117,7 +118,10 @@ def test_complicated_rpc(self):
self.client.invoke_rpc(wf_id, RPCWorkflow.test_rpc_persistence_write, 100)
res = self.client.invoke_rpc(wf_id, RPCWorkflow.test_rpc_persistence_read)
assert res == 100
self.client.invoke_rpc(wf_id, RPCWorkflow.test_rpc_trigger_state)
self.client.invoke_rpc(wf_id, RPCWorkflow.test_rpc_trigger_state, 200)
res = self.client.invoke_rpc(wf_id, RPCWorkflow.test_rpc_persistence_read)
assert res == 200

self.client.invoke_rpc(wf_id, RPCWorkflow.test_rpc_publish_channel)
result = self.client.get_simple_workflow_result_with_wait(wf_id, str)
assert result == "done"

0 comments on commit bee7f4b

Please sign in to comment.