Skip to content

Commit

Permalink
wg.wait now can wait for special tasks
Browse files Browse the repository at this point in the history
This is useful in the test.
  • Loading branch information
superstar54 committed Jul 31, 2024
1 parent d51e012 commit 2754577
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 12 deletions.
25 changes: 17 additions & 8 deletions aiida_workgraph/workgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,28 +206,37 @@ def to_dict(self, store_nodes=False) -> Dict[str, Any]:

return wgdata

def wait(self, timeout: int = 50) -> None:
def wait(self, timeout: int = 50, tasks: dict = None) -> None:
"""
Periodically checks and waits for the AiiDA workgraph process to finish until a given timeout.
Args:
timeout (int): The maximum time in seconds to wait for the process to finish. Defaults to 50.
"""

start = time.time()
self.update()
while self.state not in (
terminating_states = (
"KILLED",
"PAUSED",
"FINISHED",
"FAILED",
"CANCELLED",
"EXCEPTED",
):
time.sleep(0.5)
)
start = time.time()
self.update()
finished = False
while not finished:
self.update()
if tasks is not None:
states = []
for name in tasks:
flag = self.tasks[name].state in terminating_states
states.append(flag)
finished = all(states)
else:
finished = self.state in terminating_states
time.sleep(0.5)
if time.time() - start > timeout:
return
break

def update(self) -> None:
"""
Expand Down
1 change: 1 addition & 0 deletions tests/test_error_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def handle_negative_sum(self, task_name: str, **kwargs):
"add1": {"code": add_code, "x": orm.Int(1), "y": orm.Int(-2)},
},
wait=True,
timeout=80,
)
report = get_workchain_report(wg.process, "REPORT")
assert "Run error handler: handle_negative_sum." in report
Expand Down
14 changes: 10 additions & 4 deletions tests/test_workgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,10 @@ def test_pause_task_before_submit(wg_calcjob):
wg.name = "test_pause_task"
wg.pause_tasks(["add2"])
wg.submit()
time.sleep(20)
wg.wait(tasks=["add1"])
assert wg.tasks["add1"].node.process_state.value.upper() == "FINISHED"
# wait for the workgraph to launch add2
time.sleep(3)
wg.update()
assert wg.tasks["add2"].node.process_state.value.upper() == "CREATED"
assert wg.tasks["add2"].node.process_status == "Paused through WorkGraph"
Expand All @@ -120,20 +123,23 @@ def test_pause_task_before_submit(wg_calcjob):
assert wg.tasks["add2"].outputs["sum"].value == 9


@pytest.mark.usefixtures("started_daemon_client")
def test_pause_task_after_submit(wg_calcjob):
wg = wg_calcjob
wg.name = "test_pause_task"
wg.submit()
# wait for the daemon to start the workgraph
time.sleep(3)
time.sleep(2)
# wg.run()
wg.pause_tasks(["add2"])
time.sleep(20)
wg.wait(tasks=["add1"])
# wait for the workgraph to launch add2
time.sleep(3)
wg.update()
assert wg.tasks["add2"].node.process_state.value.upper() == "CREATED"
assert wg.tasks["add2"].node.process_status == "Paused through WorkGraph"
wg.play_tasks(["add2"])
wg.wait()
wg.wait(tasks=["add2"])
assert wg.tasks["add2"].outputs["sum"].value == 9


Expand Down

0 comments on commit 2754577

Please sign in to comment.