Skip to content

Commit 6a4c2d8

Browse files
authored
feat(remote): Support branch node execution sync (#3353)
1 parent a13edac commit 6a4c2d8

File tree

1 file changed

+32
-9
lines changed

1 file changed

+32
-9
lines changed

flytekit/remote/remote.py

Lines changed: 32 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,14 @@
9999
from flytekit.models.project import Project
100100
from flytekit.remote.backfill import create_backfill_workflow
101101
from flytekit.remote.data import download_literal
102-
from flytekit.remote.entities import FlyteLaunchPlan, FlyteNode, FlyteTask, FlyteTaskNode, FlyteWorkflow
102+
from flytekit.remote.entities import (
103+
FlyteBranchNode,
104+
FlyteLaunchPlan,
105+
FlyteNode,
106+
FlyteTask,
107+
FlyteTaskNode,
108+
FlyteWorkflow,
109+
)
103110
from flytekit.remote.executions import FlyteNodeExecution, FlyteTaskExecution, FlyteWorkflowExecution
104111
from flytekit.remote.interface import TypedInterface
105112
from flytekit.remote.lazy_entity import LazyEntity
@@ -2693,11 +2700,23 @@ def sync_node_execution(
26932700

26942701
# Handle the case where it's a branch node
26952702
elif execution._node.branch_node is not None:
2696-
logger.info(
2697-
"Skipping branch node execution for now - branch nodes will "
2698-
"not have inputs and outputs filled in"
2699-
)
2700-
return execution
2703+
sub_flyte_workflow = typing.cast(FlyteBranchNode, execution._node.flyte_entity)
2704+
sub_node_mapping = {}
2705+
if sub_flyte_workflow.if_else.case.then_node:
2706+
then_node = sub_flyte_workflow.if_else.case.then_node
2707+
sub_node_mapping[then_node.id] = then_node
2708+
if sub_flyte_workflow.if_else.other:
2709+
for case in sub_flyte_workflow.if_else.other:
2710+
then_node = case.then_node
2711+
sub_node_mapping[then_node.id] = then_node
2712+
if sub_flyte_workflow.if_else.else_node:
2713+
else_node = sub_flyte_workflow.if_else.else_node
2714+
sub_node_mapping[else_node.id] = else_node
2715+
2716+
execution._underlying_node_executions = [
2717+
self.sync_node_execution(FlyteNodeExecution.promote_from_model(cne), sub_node_mapping)
2718+
for cne in child_node_executions
2719+
]
27012720
else:
27022721
logger.error(f"NE {execution} undeterminable, {type(execution._node)}, {execution._node}")
27032722
raise ValueError(f"Node execution undeterminable, entity has type {type(execution._node)}")
@@ -2839,15 +2858,19 @@ def _assign_inputs_and_outputs(
28392858
self,
28402859
execution: typing.Union[FlyteWorkflowExecution, FlyteNodeExecution, FlyteTaskExecution],
28412860
execution_data,
2842-
interface: TypedInterface,
2861+
interface: typing.Optional[TypedInterface] = None,
28432862
):
28442863
"""Helper for assigning synced inputs and outputs to an execution object."""
28452864
input_literal_map = self._get_input_literal_map(execution_data)
2846-
execution._inputs = LiteralsResolver(input_literal_map.literals, interface.inputs, self.context)
2865+
execution._inputs = LiteralsResolver(
2866+
input_literal_map.literals, interface.inputs if interface else None, self.context
2867+
)
28472868

28482869
if execution.is_done and not execution.error:
28492870
output_literal_map = self._get_output_literal_map(execution_data)
2850-
execution._outputs = LiteralsResolver(output_literal_map.literals, interface.outputs, self.context)
2871+
execution._outputs = LiteralsResolver(
2872+
output_literal_map.literals, interface.outputs if interface else None, self.context
2873+
)
28512874
return execution
28522875

28532876
def _get_input_literal_map(self, execution_data: ExecutionDataResponse) -> literal_models.LiteralMap:

0 commit comments

Comments
 (0)