Skip to content

Commit 15cc1b6

Browse files
authored
don't transform bound inputs to list for remote entities (#3168)
* don't transform bound inputs to list for remote entities Signed-off-by: Paul Dittamo <[email protected]> * clean up Signed-off-by: Paul Dittamo <[email protected]> --------- Signed-off-by: Paul Dittamo <[email protected]>
1 parent 85e9a90 commit 15cc1b6

File tree

5 files changed

+22
-6
lines changed

5 files changed

+22
-6
lines changed

flytekit/core/array_node.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def __init__(
8888
self.target.python_interface, self._bound_inputs, output_as_list_of_optionals
8989
)
9090
elif self.target.interface:
91-
self._remote_interface = self.target.interface.transform_interface_to_list()
91+
self._remote_interface = self.target.interface.transform_interface_to_list(self.bound_inputs)
9292
else:
9393
raise ValueError("No interface found for the target entity.")
9494

flytekit/core/interface.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -372,7 +372,7 @@ def transform_interface_to_list_interface(
372372
Takes a single task interface and interpolates it to an array interface - to allow performing distributed python map
373373
like functions
374374
:param interface: Interface to be upgraded to a list interface
375-
:param bound_inputs: fixed inputs that should not upgraded to a list and will be maintained as scalars.
375+
:param bound_inputs: fixed inputs that should not be updated to a list and will be maintained as is
376376
"""
377377
map_inputs = transform_types_to_list_of_type(interface.inputs, bound_inputs)
378378
map_outputs = transform_types_to_list_of_type(interface.outputs, set(), optional_outputs)

flytekit/models/interface.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -158,13 +158,19 @@ def from_flyte_idl(cls, proto: _interface_pb2.TypedInterface) -> "TypedInterface
158158
outputs={k: Variable.from_flyte_idl(v) for k, v in proto.outputs.variables.items()},
159159
)
160160

161-
def transform_interface_to_list(self) -> "TypedInterface":
161+
def transform_interface_to_list(self, bound_inputs: typing.Set[str]) -> "TypedInterface":
162162
"""
163163
Takes a single task interface and interpolates it to an array interface - to allow performing distributed
164164
python map like functions
165+
:param bound_inputs: fixed inputs that should not be updated to a list and will be maintained as is
165166
"""
166167
list_interface = _interface_pb2.TypedInterface(
167-
inputs=_interface_pb2.VariableMap(variables={k: v.to_flyte_idl_list() for k, v in self.inputs.items()}),
168+
inputs=_interface_pb2.VariableMap(
169+
variables={
170+
k: (v.to_flyte_idl_list() if k not in bound_inputs else v.to_flyte_idl())
171+
for k, v in self.inputs.items()
172+
}
173+
),
168174
outputs=_interface_pb2.VariableMap(variables={k: v.to_flyte_idl_list() for k, v in self.outputs.items()}),
169175
)
170176
return self.from_flyte_idl(list_interface)

flytekit/remote/remote.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2639,7 +2639,7 @@ def sync_node_execution(
26392639
launch_plan = self.fetch_launch_plan(
26402640
launch_plan_id.project, launch_plan_id.domain, launch_plan_id.name, launch_plan_id.version
26412641
)
2642-
task_execution_interface = launch_plan.interface.transform_interface_to_list()
2642+
task_execution_interface = launch_plan.interface.transform_interface_to_list(bound_inputs=set())
26432643
execution._task_executions = [
26442644
self.sync_task_execution(
26452645
FlyteTaskExecution.promote_from_model(task_execution), task_execution_interface

tests/flytekit/unit/models/test_interface.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def test_typed_interface(literal_type):
5151
assert len(deserialized_typed_interface.inputs) == 1
5252
assert len(deserialized_typed_interface.outputs) == 2
5353

54-
deserialized_typed_interface_list = typed_interface.transform_interface_to_list()
54+
deserialized_typed_interface_list = typed_interface.transform_interface_to_list(set())
5555
assert deserialized_typed_interface_list.inputs["a"].type == types.LiteralType(collection_type=literal_type)
5656
assert deserialized_typed_interface_list.outputs["b"].type == types.LiteralType(collection_type=literal_type)
5757
assert deserialized_typed_interface_list.outputs["c"].type == types.LiteralType(collection_type=literal_type)
@@ -61,6 +61,16 @@ def test_typed_interface(literal_type):
6161
assert len(deserialized_typed_interface_list.inputs) == 1
6262
assert len(deserialized_typed_interface_list.outputs) == 2
6363

64+
deserialized_typed_interface_list = typed_interface.transform_interface_to_list({"a"})
65+
assert deserialized_typed_interface_list.inputs["a"].type == literal_type
66+
assert deserialized_typed_interface_list.outputs["b"].type == types.LiteralType(collection_type=literal_type)
67+
assert deserialized_typed_interface_list.outputs["c"].type == types.LiteralType(collection_type=literal_type)
68+
assert deserialized_typed_interface_list.inputs["a"].description == "description1"
69+
assert deserialized_typed_interface_list.outputs["b"].description == "description2"
70+
assert deserialized_typed_interface_list.outputs["c"].description == "description3"
71+
assert len(deserialized_typed_interface_list.inputs) == 1
72+
assert len(deserialized_typed_interface_list.outputs) == 2
73+
6474

6575
def test_parameter():
6676
v = interface.Variable(types.LiteralType(simple=types.SimpleType.BOOLEAN), "asdf asdf asdf")

0 commit comments

Comments
 (0)