Skip to content

Commit

Permalink
Adapt to using greenlet runner
Browse files Browse the repository at this point in the history
  • Loading branch information
unkcpz committed Feb 1, 2025
1 parent 6139cf1 commit b6a957c
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 25 deletions.
40 changes: 22 additions & 18 deletions src/plumpy/processes.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,7 @@ def __init__(
self.spec().seal()

self._loop = loop or asyncio.get_event_loop()
print("process: ", id(self._loop))

self._setup_event_hooks()

Expand Down Expand Up @@ -1316,13 +1317,28 @@ def execute(self) -> dict[str, Any] | None:
"""
from .reentry import _get_runner

if self.has_terminated():
return self.result()
if not self.has_terminated():
coro = self.step_until_terminated()
with _get_runner(loop=self.loop) as runner:
print("I got runner: ", runner)
result = runner.run(coro)

return result

else:
self.result()

async def step_until_terminated(self) -> Any:
"""If the process has not terminated,
run the current step and wait until the step finished.
runner = _get_runner()
with runner as runner:
return runner.run(self.step_until_terminated())
# return asyncio.run(self.step_until_terminated())
This is the function run by the event loop (not ``step``).
"""
while not self.has_terminated():
await self.step()

return await self.future()

@ensure_not_closed
async def step(self) -> None:
Expand Down Expand Up @@ -1378,18 +1394,6 @@ async def step(self) -> None:
self._stepping = False
self._set_interrupt_action(None)

async def step_until_terminated(self) -> Any:
"""If the process has not terminated,
run the current step and wait until the step finished.
This is the function run by the event loop (not ``step``).
"""
while not self.has_terminated():
await self.step()

return await self.future()

# endregion

@ensure_not_closed
Expand Down
33 changes: 30 additions & 3 deletions src/plumpy/reentry.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from weakref import WeakSet
from greenlet import greenlet


def _close_loop(loop):
if loop is not None:
try:
Expand All @@ -22,12 +23,14 @@ def _close_loop(loop):
finally:
loop.close()


class _Genlet(greenlet):
"""
Generator-like object based on ``greenlets``. It allows nested :class:`_Genlet`
to make their parent yield on their behalf, as if callees could decide to
be annotated ``yield from`` without modifying the caller.
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

Expand Down Expand Up @@ -225,21 +228,27 @@ def _check_executor_alive(executor):

_PATCHED_LOOP_LOCK = threading.Lock()
_PATCHED_LOOP = WeakSet()


def _install_task_factory(loop):
"""
Install a task factory on the given event ``loop`` so that top-level
coroutines are wrapped using :func:`allow_nested_run`. This ensures that
the nested :func:`run` infrastructure will be available.
"""

def install(loop):
if sys.version_info >= (3, 11):

def default_factory(loop, coro, context=None):
return asyncio.Task(coro, loop=loop, context=context)
else:

def default_factory(loop, coro, context=None):
return asyncio.Task(coro, loop=loop)

make_task = loop.get_task_factory() or default_factory

def factory(loop, coro, context=None):
# Make sure each Task will be able to yield on behalf of its nested
# await beneath blocking layers
Expand Down Expand Up @@ -273,6 +282,7 @@ class _CoroRunner(abc.ABC):
the awaitables yielded by an async generator that are all attached to a
single event loop.
"""

@abc.abstractmethod
def _run(self, coro):
pass
Expand All @@ -298,6 +308,7 @@ class _ThreadCoroRunner(_CoroRunner):
Critically, this allows running multiple coroutines out of the same thread,
which will be reserved until the runner ``__exit__`` method is called.
"""

def __init__(self, future, jobq, resq):
self._future = future
self._jobq = jobq
Expand Down Expand Up @@ -336,7 +347,9 @@ def from_executor(cls, executor):
if _check_executor_alive(executor):
raise e
else:
raise RuntimeError('Devlib relies on nested asyncio implementation requiring threads. These threads are not available while shutting down the interpreter.')
raise RuntimeError(
'Relies on nested asyncio implementation requiring threads. These threads are not available while shutting down the interpreter.'
)

return cls(
jobq=jobq,
Expand Down Expand Up @@ -368,6 +381,7 @@ class _LoopCoroRunner(_CoroRunner):
a new event loop will be created in ``__enter__`` and closed in
``__exit__``.
"""

def __init__(self, loop):
self.loop = loop
self._owned = False
Expand All @@ -380,6 +394,7 @@ def _run(self, coro):
# context=...) or loop.create_task(..., context=...) but these APIs are
# only available since Python 3.11
ctx = None

async def capture_ctx():
nonlocal ctx
try:
Expand All @@ -397,6 +412,7 @@ def __enter__(self):
if loop is None:
owned = True
loop = asyncio.new_event_loop()
print(id(loop))
else:
owned = False

Expand All @@ -410,28 +426,39 @@ def __exit__(self, *args, **kwargs):
if self._owned:
asyncio.set_event_loop(None)
_close_loop(self.loop)
print(f"close {id(self.loop)} {self.loop}")


class _GenletCoroRunner(_CoroRunner):
"""
Run a coroutine assuming one of the parent coroutines was wrapped with
:func:`allow_nested_run`.
"""

def __init__(self, g):
self._g = g

def _run(self, coro):
return self._g.consume_coro(coro, None)


def _get_runner():
def _get_runner(loop=None):
executor = _CORO_THREAD_EXECUTOR
g = _Genlet.get_enclosing()
try:
loop = asyncio.get_running_loop()
loop = loop or asyncio.get_running_loop()
except RuntimeError:
loop = None

print("!!!!!!")
print("NN", loop)
print("id is: ", id(loop))

# 1. If there is an existing _Genlet in the call stack, it uses _GenletCoroRunner.
# 2. Else if there is no running event loop, it creates or re-uses one with _LoopCoroRunner.
# 3. Else if there is a running event loop in the same thread but no _Genlet, it uses a _ThreadCoroRunner to
# offload to a separate thread with its own event loop.

# We have an coroutine wrapped with allow_nested_run() higher in the
# callstack, that we will be able to use as a conduit to yield the
# futures.
Expand Down
15 changes: 11 additions & 4 deletions tests/test_processes.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ def on_kill(self, msg):
super().on_kill(msg)


@pytest.mark.usefixtures('custom_event_loop_policy')
def test_process_is_savable():
proc = utils.DummyProcess()
assert isinstance(proc, Savable)
Expand Down Expand Up @@ -71,7 +70,6 @@ async def task(self, steps: list):


class TestProcess:
@pytest.mark.usefixtures('custom_event_loop_policy')
def test_spec(self):
"""
Check that the references to specs are doing the right thing...
Expand Down Expand Up @@ -586,7 +584,6 @@ def run(self):
proc = StackTest()
proc.execute()

@pytest.mark.usefixtures('custom_event_loop_policy')
def test_process_stack_multiple(self):
"""
Run multiple and nested processes to make sure the process stack is always correct
Expand Down Expand Up @@ -622,7 +619,6 @@ def run(self):

assert len(expect_true) == n_run * 3

@pytest.mark.usefixtures('custom_event_loop_policy')
def test_process_nested(self):
"""
Run multiple and nested processes to make sure the process stack is always correct
Expand All @@ -638,6 +634,17 @@ def run(self):

ParentProcess().execute()

@pytest.mark.asyncio
async def test_processes_run_in_sequence(self):
"""Run execute for two processes in sequence, to check the main thread event loop is intactive"""

class StackTest(plumpy.Process):
def run(self):
pass

StackTest().execute()
StackTest().execute()

def test_call_soon(self):
class CallSoon(plumpy.Process):
def run(self):
Expand Down

0 comments on commit b6a957c

Please sign in to comment.