Skip to content

Commit

Permalink
test_event_helper_savable
Browse files Browse the repository at this point in the history
  • Loading branch information
unkcpz committed Jan 23, 2025
1 parent 64929b7 commit 9846280
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 8 deletions.
12 changes: 6 additions & 6 deletions src/plumpy/event_helper.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
import logging
from typing import TYPE_CHECKING, Any, Callable, Optional, Self
from typing import TYPE_CHECKING, Any, Callable, Optional, Self, final

from plumpy.loaders import ObjectLoader
from plumpy.persistence import LoadSaveContext, auto_load, auto_save, ensure_object_loader
Expand All @@ -9,20 +9,20 @@
from . import persistence

if TYPE_CHECKING:
from typing import Set, Type

from .process_listener import ProcessListener

_LOGGER = logging.getLogger(__name__)

# FIXME: test me

@final
@persistence.auto_persist('_listeners', '_listener_type')
class EventHelper:
def __init__(self, listener_type: 'Type[ProcessListener]'):
def __init__(self, listener_type: 'type[ProcessListener]'):
assert listener_type is not None, 'Must provide valid listener type'

self._listener_type = listener_type
self._listeners: 'Set[ProcessListener]' = set()
self._listeners: 'set[ProcessListener]' = set()

def add_listener(self, listener: 'ProcessListener') -> None:
assert isinstance(listener, self._listener_type), 'Listener is not of right type'
Expand Down Expand Up @@ -55,7 +55,7 @@ def save(self, loader: ObjectLoader | None = None) -> SAVED_STATE_TYPE:
return out_state

@property
def listeners(self) -> 'Set[ProcessListener]':
def listeners(self) -> 'set[ProcessListener]':
return self._listeners

def fire_event(self, event_function: Callable[..., Any], *args: Any, **kwargs: Any) -> None:
Expand Down
2 changes: 1 addition & 1 deletion src/plumpy/persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,7 +517,7 @@ def auto_save(obj: Savable, loader: loaders.ObjectLoader | None = None) -> SAVED
# Save object class name
SaveUtil.set_class_name(out_state, loader.identify_object(obj.__class__))

# FIXME: it should be an iter call to save until all resolved
# TODO: it should be an regression call to save until all resolved
if isinstance(obj, SavableWithAutoPersist):
for member in obj._auto_persist:
value = getattr(obj, member)
Expand Down
21 changes: 21 additions & 0 deletions tests/test_event_helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from plumpy.event_helper import EventHelper
from plumpy.persistence import Savable, load
from tests.utils import DummyProcess, ProcessListenerTester


def test_event_helper_savable():
eh = EventHelper(ProcessListenerTester)

proc = DummyProcess()
pl1 = ProcessListenerTester(proc, ('killed'))
pl2 = ProcessListenerTester(proc, ('paused'))
eh.add_listener(pl1)
eh.add_listener(pl2)

assert isinstance(eh, Savable)

saved = eh.save()
loaded = load(saved_state=saved)
saved2 = loaded.save()

assert saved == saved2
3 changes: 2 additions & 1 deletion tests/test_processes.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from plumpy.utils import AttributesFrozendict
from . import utils

# FIXME: any process listener is savable

class ForgetToCallParent(plumpy.Process):
def __init__(self, forget_on):
Expand All @@ -42,10 +41,12 @@ def on_kill(self, msg):
if self.forget_on != 'kill':
super().on_kill(msg)


def test_process_is_savable():
proc = utils.DummyProcess()
assert isinstance(proc, Savable)


@pytest.mark.asyncio
async def test_process_scope():
class ProcessTaskInterleave(plumpy.Process):
Expand Down

0 comments on commit 9846280

Please sign in to comment.