diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index 61070ec7..5a9ee115 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -1316,13 +1316,27 @@ def execute(self) -> dict[str, Any] | None: """ from .reentry import _get_runner - if self.has_terminated(): - return self.result() + if not self.has_terminated(): + coro = self.step_until_terminated() + with _get_runner() as runner: + result = runner.run(coro) + + return result + + else: + self.result() + + async def step_until_terminated(self) -> Any: + """If the process has not terminated, + run the current step and wait until the step finished. - runner = _get_runner() - with runner as runner: - return runner.run(self.step_until_terminated()) - # return asyncio.run(self.step_until_terminated()) + This is the function run by the event loop (not ``step``). + + """ + while not self.has_terminated(): + await self.step() + + return await self.future() @ensure_not_closed async def step(self) -> None: @@ -1378,18 +1392,6 @@ async def step(self) -> None: self._stepping = False self._set_interrupt_action(None) - async def step_until_terminated(self) -> Any: - """If the process has not terminated, - run the current step and wait until the step finished. - - This is the function run by the event loop (not ``step``). - - """ - while not self.has_terminated(): - await self.step() - - return await self.future() - # endregion @ensure_not_closed diff --git a/tests/test_processes.py b/tests/test_processes.py index eefd581d..767442e4 100644 --- a/tests/test_processes.py +++ b/tests/test_processes.py @@ -41,7 +41,6 @@ def on_kill(self, msg): super().on_kill(msg) -@pytest.mark.usefixtures('custom_event_loop_policy') def test_process_is_savable(): proc = utils.DummyProcess() assert isinstance(proc, Savable) @@ -71,7 +70,6 @@ async def task(self, steps: list): class TestProcess: - @pytest.mark.usefixtures('custom_event_loop_policy') def test_spec(self): """ Check that the references to specs are doing the right thing... @@ -586,7 +584,6 @@ def run(self): proc = StackTest() proc.execute() - @pytest.mark.usefixtures('custom_event_loop_policy') def test_process_stack_multiple(self): """ Run multiple and nested processes to make sure the process stack is always correct @@ -622,7 +619,6 @@ def run(self): assert len(expect_true) == n_run * 3 - @pytest.mark.usefixtures('custom_event_loop_policy') def test_process_nested(self): """ Run multiple and nested processes to make sure the process stack is always correct