Skip to content

Commit

Permalink
Set the node' label as the workgraph/task's name. (#213)
Browse files Browse the repository at this point in the history
Set the node' label as the workgraph/task's name when the process node is created. This is very useful for querying. For example, for `WorkGraph<abc_name>` and `PythonJob<abc_name>`, the `node.label` will be `abc_name`.Note: `node.label` is mutable, thus user can modify it after.
  • Loading branch information
superstar54 authored Aug 10, 2024
1 parent e378b19 commit 1cdaa44
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 7 deletions.
6 changes: 6 additions & 0 deletions aiida_workgraph/calculations/python.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,12 @@ def _build_process_label(self) -> str:
else:
return f"PythonJob<{self.inputs.function_name.value}>"

def on_create(self) -> None:
"""Called when a Process is created."""

super().on_create()
self.node.label = self.inputs.process_label.value

def prepare_for_submission(self, folder: Folder) -> CalcInfo:
"""Prepare the calculation for submission.
Expand Down
2 changes: 1 addition & 1 deletion aiida_workgraph/engine/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def prepare_for_python_task(task: dict, kwargs: dict, var_kwargs: dict) -> dict:
function_kwargs = serialize_to_aiida_nodes(function_kwargs)
# transfer the args to kwargs
inputs = {
"process_label": f"PythonJob<{task['name']}",
"process_label": f"PythonJob<{task['name']}>",
"function_source_code": orm.Str(function_source_code),
"function_name": orm.Str(function_name),
"code": code,
Expand Down
1 change: 1 addition & 0 deletions aiida_workgraph/engine/workgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,7 @@ def on_create(self) -> None:
)
saver = WorkGraphSaver(self.node, wgdata, restart_process=restart_process)
saver.save()
self.node.label = wgdata["name"]

def setup(self) -> None:
# track if the awaitable callback is added to the runner
Expand Down
15 changes: 9 additions & 6 deletions tests/test_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,23 +24,26 @@ def multiply(x: Any, y: Any) -> Any:
wg = WorkGraph("test_PythonJob_outputs")
wg.add_task(
add,
name="add",
name="add1",
x=1,
y=2,
computer="localhost",
)
wg.add_task(
decorted_multiply,
name="multiply",
x=wg.tasks["add"].outputs["sum"],
name="multiply1",
x=wg.tasks["add1"].outputs["sum"],
y=3,
computer="localhost",
)
# wg.submit(wait=True)
wg.run()
assert wg.tasks["add"].outputs["sum"].value.value == 3
assert wg.tasks["add"].outputs["diff"].value.value == -1
assert wg.tasks["multiply"].outputs["result"].value.value == 9
assert wg.tasks["add1"].outputs["sum"].value.value == 3
assert wg.tasks["add1"].outputs["diff"].value.value == -1
assert wg.tasks["multiply1"].outputs["result"].value.value == 9
# process_label and label
assert wg.tasks["add1"].node.process_label == "PythonJob<add1>"
assert wg.tasks["add1"].node.label == "add1"


@pytest.mark.usefixtures("started_daemon_client")
Expand Down
1 change: 1 addition & 0 deletions tests/test_workgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def test_save_load(wg_calcfunction):
wg.save()
assert wg.process.process_state.value.upper() == "CREATED"
assert wg.process.process_label == "WorkGraph<test_save_load>"
assert wg.process.label == "test_save_load"
wg2 = WorkGraph.load(wg.process.pk)
assert len(wg.tasks) == len(wg2.tasks)

Expand Down

0 comments on commit 1cdaa44

Please sign in to comment.