|
99 | 99 | from flytekit.models.project import Project |
100 | 100 | from flytekit.remote.backfill import create_backfill_workflow |
101 | 101 | from flytekit.remote.data import download_literal |
102 | | -from flytekit.remote.entities import FlyteLaunchPlan, FlyteNode, FlyteTask, FlyteTaskNode, FlyteWorkflow |
| 102 | +from flytekit.remote.entities import ( |
| 103 | + FlyteBranchNode, |
| 104 | + FlyteLaunchPlan, |
| 105 | + FlyteNode, |
| 106 | + FlyteTask, |
| 107 | + FlyteTaskNode, |
| 108 | + FlyteWorkflow, |
| 109 | +) |
103 | 110 | from flytekit.remote.executions import FlyteNodeExecution, FlyteTaskExecution, FlyteWorkflowExecution |
104 | 111 | from flytekit.remote.interface import TypedInterface |
105 | 112 | from flytekit.remote.lazy_entity import LazyEntity |
@@ -2693,11 +2700,23 @@ def sync_node_execution( |
2693 | 2700 |
|
2694 | 2701 | # Handle the case where it's a branch node |
2695 | 2702 | elif execution._node.branch_node is not None: |
2696 | | - logger.info( |
2697 | | - "Skipping branch node execution for now - branch nodes will " |
2698 | | - "not have inputs and outputs filled in" |
2699 | | - ) |
2700 | | - return execution |
| 2703 | + sub_flyte_workflow = typing.cast(FlyteBranchNode, execution._node.flyte_entity) |
| 2704 | + sub_node_mapping = {} |
| 2705 | + if sub_flyte_workflow.if_else.case.then_node: |
| 2706 | + then_node = sub_flyte_workflow.if_else.case.then_node |
| 2707 | + sub_node_mapping[then_node.id] = then_node |
| 2708 | + if sub_flyte_workflow.if_else.other: |
| 2709 | + for case in sub_flyte_workflow.if_else.other: |
| 2710 | + then_node = case.then_node |
| 2711 | + sub_node_mapping[then_node.id] = then_node |
| 2712 | + if sub_flyte_workflow.if_else.else_node: |
| 2713 | + else_node = sub_flyte_workflow.if_else.else_node |
| 2714 | + sub_node_mapping[else_node.id] = else_node |
| 2715 | + |
| 2716 | + execution._underlying_node_executions = [ |
| 2717 | + self.sync_node_execution(FlyteNodeExecution.promote_from_model(cne), sub_node_mapping) |
| 2718 | + for cne in child_node_executions |
| 2719 | + ] |
2701 | 2720 | else: |
2702 | 2721 | logger.error(f"NE {execution} undeterminable, {type(execution._node)}, {execution._node}") |
2703 | 2722 | raise ValueError(f"Node execution undeterminable, entity has type {type(execution._node)}") |
@@ -2839,15 +2858,19 @@ def _assign_inputs_and_outputs( |
2839 | 2858 | self, |
2840 | 2859 | execution: typing.Union[FlyteWorkflowExecution, FlyteNodeExecution, FlyteTaskExecution], |
2841 | 2860 | execution_data, |
2842 | | - interface: TypedInterface, |
| 2861 | + interface: typing.Optional[TypedInterface] = None, |
2843 | 2862 | ): |
2844 | 2863 | """Helper for assigning synced inputs and outputs to an execution object.""" |
2845 | 2864 | input_literal_map = self._get_input_literal_map(execution_data) |
2846 | | - execution._inputs = LiteralsResolver(input_literal_map.literals, interface.inputs, self.context) |
| 2865 | + execution._inputs = LiteralsResolver( |
| 2866 | + input_literal_map.literals, interface.inputs if interface else None, self.context |
| 2867 | + ) |
2847 | 2868 |
|
2848 | 2869 | if execution.is_done and not execution.error: |
2849 | 2870 | output_literal_map = self._get_output_literal_map(execution_data) |
2850 | | - execution._outputs = LiteralsResolver(output_literal_map.literals, interface.outputs, self.context) |
| 2871 | + execution._outputs = LiteralsResolver( |
| 2872 | + output_literal_map.literals, interface.outputs if interface else None, self.context |
| 2873 | + ) |
2851 | 2874 | return execution |
2852 | 2875 |
|
2853 | 2876 | def _get_input_literal_map(self, execution_data: ExecutionDataResponse) -> literal_models.LiteralMap: |
|
0 commit comments