Skip to content

Commit 5e38ab4

Browse files
committed
change task's data inputs, outputs and properties to dict
1 parent c74bf4e commit 5e38ab4

File tree

10 files changed

+100
-55
lines changed

10 files changed

+100
-55
lines changed

src/aiida_workgraph/config.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@
55
WORKGRAPH_SHORT_EXTRA_KEY = "_workgraph_short"
66

77

8-
builtin_inputs = [{"name": "_wait", "link_limit": 1e6, "arg_type": "none"}]
8+
builtin_inputs = [
9+
{"name": "_wait", "link_limit": 1e6, "metadata": {"arg_type": "none"}}
10+
]
911
builtin_outputs = [{"name": "_wait"}, {"name": "_outputs"}]
1012

1113

src/aiida_workgraph/decorator.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,14 @@
3939
def create_task(tdata):
4040
"""Wrap create_node from node_graph to create a Task."""
4141
from node_graph.decorator import create_node
42+
from node_graph.utils import list_to_dict
4243

4344
tdata["type_mapping"] = type_mapping
4445
tdata["metadata"]["node_type"] = tdata["metadata"].pop("task_type")
46+
tdata["properties"] = list_to_dict(tdata.get("properties", {}))
47+
tdata["inputs"] = list_to_dict(tdata.get("inputs", {}))
48+
tdata["outputs"] = list_to_dict(tdata.get("outputs", {}))
49+
4550
return create_node(tdata)
4651

4752

@@ -67,8 +72,11 @@ def add_input_recursive(
6772
{
6873
"identifier": "workgraph.namespace",
6974
"name": port_name,
70-
"arg_type": "kwargs",
71-
"metadata": {"required": required, "dynamic": port.dynamic},
75+
"metadata": {
76+
"arg_type": "kwargs",
77+
"required": required,
78+
"dynamic": port.dynamic,
79+
},
7280
}
7381
)
7482
for value in port.values():
@@ -87,8 +95,7 @@ def add_input_recursive(
8795
{
8896
"identifier": socket_type,
8997
"name": port_name,
90-
"arg_type": "kwargs",
91-
"metadata": {"required": required},
98+
"metadata": {"arg_type": "kwargs", "required": required},
9299
}
93100
)
94101
return inputs
@@ -249,8 +256,7 @@ def build_task_from_AiiDA(
249256
{
250257
"identifier": "workgraph.namespace",
251258
"name": name,
252-
"arg_type": "var_kwargs",
253-
"metadata": {"dynamic": True},
259+
"metadata": {"arg_type": "var_kwargs", "dynamic": True},
254260
}
255261
)
256262

src/aiida_workgraph/engine/workgraph.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -314,11 +314,10 @@ def read_wgdata_from_base(self) -> t.Dict[str, t.Any]:
314314
for name, task in wgdata["tasks"].items():
315315
wgdata["tasks"][name] = deserialize_unsafe(task)
316316
for _, input in wgdata["tasks"][name]["inputs"].items():
317-
if input["property"] is None:
318-
continue
319-
prop = input["property"]
320-
if isinstance(prop["value"], PickledLocalFunction):
321-
prop["value"] = prop["value"].value
317+
if input.get("property"):
318+
prop = input["property"]
319+
if isinstance(prop["value"], PickledLocalFunction):
320+
prop["value"] = prop["value"].value
322321
wgdata["error_handlers"] = deserialize_unsafe(wgdata["error_handlers"])
323322
wgdata["context"] = deserialize_unsafe(wgdata["context"])
324323
return wgdata

src/aiida_workgraph/task.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ def process_nested_inputs(
118118
# create input sockets and links for items inside a dynamic socket
119119
# TODO the input value could be nested, but we only support one level for now
120120
for key in data:
121-
if self.inputs[key].identifier == "workgraph.namespace":
121+
if self.inputs[key]._socket_identifier == "workgraph.namespace":
122122
process_nested_inputs(
123123
key,
124124
self.inputs[key].value,

src/aiida_workgraph/tasks/builtins.py

Lines changed: 39 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@ def __init__(self, *args, **kwargs):
1919
def create_sockets(self) -> None:
2020
self.inputs._clear()
2121
self.outputs._clear()
22-
self.add_input("workgraph.any", "_wait", arg_type="none", link_limit=100000)
22+
self.add_input(
23+
"workgraph.any", "_wait", link_limit=100000, metadata={"arg_type": "none"}
24+
)
2325
self.add_output("workgraph.any", "_wait")
2426

2527
def to_dict(self, short: bool = False) -> Dict[str, Any]:
@@ -43,7 +45,9 @@ class While(Zone):
4345
def create_sockets(self) -> None:
4446
self.inputs._clear()
4547
self.outputs._clear()
46-
self.add_input("workgraph.any", "_wait", arg_type="none", link_limit=100000)
48+
self.add_input(
49+
"workgraph.any", "_wait", link_limit=100000, metadata={"arg_type": "none"}
50+
)
4751
self.add_input(
4852
"node_graph.int", "max_iterations", property_data={"default": 10000}
4953
)
@@ -62,7 +66,9 @@ class If(Zone):
6266
def create_sockets(self) -> None:
6367
self.inputs._clear()
6468
self.outputs._clear()
65-
self.add_input("workgraph.any", "_wait", arg_type="none", link_limit=100000)
69+
self.add_input(
70+
"workgraph.any", "_wait", link_limit=100000, metadata={"arg_type": "none"}
71+
)
6672
self.add_input("workgraph.any", "conditions")
6773
self.add_input("workgraph.any", "invert_condition")
6874
self.add_output("workgraph.any", "_wait")
@@ -81,7 +87,9 @@ def create_sockets(self) -> None:
8187
self.outputs._clear()
8288
self.add_input("workgraph.any", "key")
8389
self.add_input("workgraph.any", "value")
84-
self.add_input("workgraph.any", "_wait", arg_type="none", link_limit=100000)
90+
self.add_input(
91+
"workgraph.any", "_wait", link_limit=100000, metadata={"arg_type": "none"}
92+
)
8593
self.add_output("workgraph.any", "_wait")
8694

8795

@@ -97,7 +105,9 @@ def create_sockets(self) -> None:
97105
self.inputs._clear()
98106
self.outputs._clear()
99107
self.add_input("workgraph.any", "key")
100-
self.add_input("workgraph.any", "_wait", arg_type="none", link_limit=100000)
108+
self.add_input(
109+
"workgraph.any", "_wait", link_limit=100000, metadata={"arg_type": "none"}
110+
)
101111
self.add_output("workgraph.any", "result")
102112
self.add_output("workgraph.any", "_wait")
103113

@@ -115,7 +125,9 @@ class AiiDAInt(Task):
115125

116126
def create_sockets(self) -> None:
117127
self.add_input("workgraph.any", "value", property_data={"default": 0.0})
118-
self.add_input("workgraph.any", "_wait", arg_type="none", link_limit=100000)
128+
self.add_input(
129+
"workgraph.any", "_wait", link_limit=100000, metadata={"arg_type": "none"}
130+
)
119131
self.add_output("workgraph.aiida_int", "result")
120132
self.add_output("workgraph.any", "_wait")
121133

@@ -135,7 +147,9 @@ def create_sockets(self) -> None:
135147
self.inputs._clear()
136148
self.outputs._clear()
137149
self.add_input("workgraph.float", "value", property_data={"default": 0.0})
138-
self.add_input("workgraph.any", "_wait", arg_type="none", link_limit=100000)
150+
self.add_input(
151+
"workgraph.any", "_wait", link_limit=100000, metadata={"arg_type": "none"}
152+
)
139153
self.add_output("workgraph.aiida_float", "result")
140154
self.add_output("workgraph.any", "_wait")
141155

@@ -155,7 +169,9 @@ def create_sockets(self) -> None:
155169
self.inputs._clear()
156170
self.outputs._clear()
157171
self.add_input("workgraph.string", "value", property_data={"default": ""})
158-
self.add_input("workgraph.any", "_wait", arg_type="none", link_limit=100000)
172+
self.add_input(
173+
"workgraph.any", "_wait", link_limit=100000, metadata={"arg_type": "none"}
174+
)
159175
self.add_output("workgraph.aiida_string", "result")
160176
self.add_output("workgraph.any", "_wait")
161177

@@ -175,7 +191,9 @@ def create_sockets(self) -> None:
175191
self.inputs._clear()
176192
self.outputs._clear()
177193
self.add_input("workgraph.any", "value", property_data={"default": []})
178-
self.add_input("workgraph.any", "_wait", arg_type="none", link_limit=100000)
194+
self.add_input(
195+
"workgraph.any", "_wait", link_limit=100000, metadata={"arg_type": "none"}
196+
)
179197
self.add_output("workgraph.aiida_list", "result")
180198
self.add_output("workgraph.any", "_wait")
181199

@@ -195,7 +213,9 @@ def create_sockets(self) -> None:
195213
self.inputs._clear()
196214
self.outputs._clear()
197215
self.add_input("workgraph.any", "value", property_data={"default": {}})
198-
self.add_input("workgraph.any", "_wait", arg_type="none", link_limit=100000)
216+
self.add_input(
217+
"workgraph.any", "_wait", link_limit=100000, metadata={"arg_type": "none"}
218+
)
199219
self.add_output("workgraph.aiida_dict", "result")
200220
self.add_output("workgraph.any", "_wait")
201221

@@ -223,7 +243,9 @@ def create_sockets(self) -> None:
223243
self.add_input("workgraph.any", "pk")
224244
self.add_input("workgraph.any", "uuid")
225245
self.add_input("workgraph.any", "label")
226-
self.add_input("workgraph.any", "_wait", arg_type="none", link_limit=100000)
246+
self.add_input(
247+
"workgraph.any", "_wait", link_limit=100000, metadata={"arg_type": "none"}
248+
)
227249
self.add_output("workgraph.any", "node")
228250
self.add_output("workgraph.any", "_wait")
229251

@@ -248,7 +270,9 @@ def create_sockets(self) -> None:
248270
self.add_input("workgraph.any", "pk")
249271
self.add_input("workgraph.any", "uuid")
250272
self.add_input("workgraph.any", "label")
251-
self.add_input("workgraph.any", "_wait", arg_type="none", link_limit=100000)
273+
self.add_input(
274+
"workgraph.any", "_wait", link_limit=100000, metadata={"arg_type": "none"}
275+
)
252276
self.add_output("workgraph.any", "Code")
253277
self.add_output("workgraph.any", "_wait")
254278

@@ -272,6 +296,8 @@ def create_sockets(self) -> None:
272296
self.add_input("workgraph.any", "condition")
273297
self.add_input("workgraph.any", "true")
274298
self.add_input("workgraph.any", "false")
275-
self.add_input("workgraph.any", "_wait", arg_type="none", link_limit=100000)
299+
self.add_input(
300+
"workgraph.any", "_wait", link_limit=100000, metadata={"arg_type": "none"}
301+
)
276302
self.add_output("workgraph.any", "result")
277303
self.add_output("workgraph.any", "_wait")

src/aiida_workgraph/tasks/monitors.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,9 @@ def create_sockets(self) -> None:
2222
inp.add_property("workgraph.any", default=1.0)
2323
inp = self.add_input("workgraph.any", "timeout")
2424
inp.add_property("workgraph.any", default=86400.0)
25-
self.add_input("workgraph.any", "_wait", arg_type="none", link_limit=100000)
25+
self.add_input(
26+
"workgraph.any", "_wait", metadata={"arg_type": "none"}, link_limit=100000
27+
)
2628
inp.socket_link_limit = 100000
2729
self.add_output("workgraph.any", "result")
2830
self.add_output("workgraph.any", "_wait")
@@ -49,7 +51,9 @@ def create_sockets(self) -> None:
4951
inp.add_property("workgraph.any", default=1.0)
5052
inp = self.add_input("workgraph.any", "timeout")
5153
inp.add_property("workgraph.any", default=86400.0)
52-
self.add_input("workgraph.any", "_wait", arg_type="none", link_limit=100000)
54+
self.add_input(
55+
"workgraph.any", "_wait", metadata={"arg_type": "none"}, link_limit=100000
56+
)
5357
inp.socket_link_limit = 100000
5458
self.add_output("workgraph.any", "result")
5559
self.add_output("workgraph.any", "_wait")
@@ -78,7 +82,9 @@ def create_sockets(self) -> None:
7882
inp.add_property("workgraph.any", default=1.0)
7983
inp = self.add_input("workgraph.any", "timeout")
8084
inp.add_property("workgraph.any", default=86400.0)
81-
self.add_input("workgraph.any", "_wait", arg_type="none", link_limit=100000)
85+
self.add_input(
86+
"workgraph.any", "_wait", metadata={"arg_type": "none"}, link_limit=100000
87+
)
8288
inp.socket_link_limit = 100000
8389
self.add_output("workgraph.any", "result")
8490
self.add_output("workgraph.any", "_wait")

src/aiida_workgraph/tasks/test.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,9 @@ def create_sockets(self) -> None:
2323
inp.add_property("workgraph.aiida_float", "x", default=0.0)
2424
inp = self.add_input("workgraph.aiida_float", "y")
2525
inp.add_property("workgraph.aiida_float", "y", default=0.0)
26-
self.add_input("workgraph.any", "_wait", arg_type="none", link_limit=100000)
26+
self.add_input(
27+
"workgraph.any", "_wait", metadata={"arg_type": "none"}, link_limit=100000
28+
)
2729
self.add_output("workgraph.aiida_float", "sum")
2830
self.add_output("workgraph.any", "_wait")
2931
self.add_output("workgraph.any", "_outputs")
@@ -51,7 +53,9 @@ def create_sockets(self) -> None:
5153
inp.add_property("workgraph.aiida_float", "x", default=0.0)
5254
inp = self.add_input("workgraph.aiida_float", "y")
5355
inp.add_property("workgraph.aiida_float", "y", default=0.0)
54-
self.add_input("workgraph.any", "_wait", arg_type="none", link_limit=100000)
56+
self.add_input(
57+
"workgraph.any", "_wait", metadata={"arg_type": "none"}, link_limit=100000
58+
)
5559
self.add_output("workgraph.aiida_float", "sum")
5660
self.add_output("workgraph.aiida_float", "diff")
5761
self.add_output("workgraph.any", "_wait")
@@ -83,7 +87,9 @@ def create_sockets(self) -> None:
8387
inp.add_property("workgraph.aiida_int", "y", default=0.0)
8488
inp = self.add_input("workgraph.aiida_int", "z")
8589
inp.add_property("workgraph.aiida_int", "z", default=0.0)
86-
self.add_input("workgraph.any", "_wait", arg_type="none", link_limit=100000)
90+
self.add_input(
91+
"workgraph.any", "_wait", metadata={"arg_type": "none"}, link_limit=100000
92+
)
8793
self.add_output("workgraph.aiida_int", "result")
8894
self.add_output("workgraph.any", "_wait")
8995
self.add_output("workgraph.any", "_outputs")

src/aiida_workgraph/utils/__init__.py

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -252,17 +252,16 @@ def organize_nested_inputs(wgdata: Dict[str, Any]) -> None:
252252
update_nested_dict(root_prop["value"], key, prop["value"])
253253
prop["value"] = None
254254
for key, input in task["inputs"].items():
255-
if input["property"] is None:
256-
continue
257-
prop = input["property"]
258-
if "." in key and prop["value"] not in [None, {}]:
259-
root, key = key.split(".", 1)
260-
root_prop = task["inputs"][root]["property"]
261-
# update the root property
262-
root_prop["value"] = update_nested_dict(
263-
root_prop["value"], key, prop["value"]
264-
)
265-
prop["value"] = None
255+
if input.get("property"):
256+
prop = input["property"]
257+
if "." in key and prop["value"] not in [None, {}]:
258+
root, key = key.split(".", 1)
259+
root_prop = task["inputs"][root]["property"]
260+
# update the root property
261+
root_prop["value"] = update_nested_dict(
262+
root_prop["value"], key, prop["value"]
263+
)
264+
prop["value"] = None
266265

267266

268267
def generate_node_graph(
@@ -466,11 +465,10 @@ def serialize_workgraph_inputs(wgdata):
466465
if task["metadata"]["node_type"].upper() == "PYTHONJOB":
467466
PythonJob.serialize_pythonjob_data(task)
468467
for _, input in task["inputs"].items():
469-
if input["property"] is None:
470-
continue
471-
prop = input["property"]
472-
if inspect.isfunction(prop["value"]):
473-
prop["value"] = PickledLocalFunction(prop["value"]).store()
468+
if input.get("property"):
469+
prop = input["property"]
470+
if inspect.isfunction(prop["value"]):
471+
prop["value"] = PickledLocalFunction(prop["value"]).store()
474472
# error_handlers of the workgraph
475473
for _, data in wgdata["error_handlers"].items():
476474
if not data["handler"]["use_module_path"]:
@@ -548,7 +546,7 @@ def process_properties(task: Dict) -> Dict:
548546
}
549547
#
550548
for name, input in task["inputs"].items():
551-
if input["property"] is not None:
549+
if input.get("property"):
552550
prop = input["property"]
553551
identifier = prop["identifier"]
554552
value = prop.get("value")
@@ -617,6 +615,7 @@ def validate_task_inout(inout_list: list[str | dict], list_type: str) -> list[di
617615
:raises TypeError: If wrong types are provided to the task
618616
:return: Processed `inputs`/`outputs` list.
619617
"""
618+
from node_graph.utils import list_to_dict
620619

621620
if not all(isinstance(item, (dict, str)) for item in inout_list):
622621
raise TypeError(
@@ -631,6 +630,8 @@ def validate_task_inout(inout_list: list[str | dict], list_type: str) -> list[di
631630
elif isinstance(item, dict):
632631
processed_inout_list.append(item)
633632

633+
processed_inout_list = list_to_dict(processed_inout_list)
634+
634635
return processed_inout_list
635636

636637

src/aiida_workgraph/utils/analysis.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -215,11 +215,10 @@ def insert_workgraph_to_db(self) -> None:
215215
self.save_task_states()
216216
for name, task in self.wgdata["tasks"].items():
217217
for _, input in task["inputs"].items():
218-
if input["property"] is None:
219-
continue
220-
prop = input["property"]
221-
if inspect.isfunction(prop["value"]):
222-
prop["value"] = PickledLocalFunction(prop["value"]).store()
218+
if input.get("property"):
219+
prop = input["property"]
220+
if inspect.isfunction(prop["value"]):
221+
prop["value"] = PickledLocalFunction(prop["value"]).store()
223222
self.wgdata["tasks"][name] = serialize(task)
224223
# nodes is a copy of tasks, so we need to pop it out
225224
self.wgdata["error_handlers"] = serialize(self.wgdata["error_handlers"])

src/aiida_workgraph/workgraph.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,7 @@ def update(self) -> None:
297297
i = 0
298298
for socket in self.tasks[name].outputs:
299299
socket.value = get_nested_dict(
300-
node.outputs, socket.name, default=None
300+
node.outputs, socket.socket_name, default=None
301301
)
302302
i += 1
303303
# read results from the process outputs

0 commit comments

Comments
 (0)