From 24d73a09b9cc93e6ba6cd55559dbc549ff529434 Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Thu, 23 Jan 2025 01:02:29 +0100 Subject: [PATCH] Move test_workchain and test func_stepper savable --- src/plumpy/event_helper.py | 7 +- src/plumpy/persistence.py | 8 +- src/plumpy/process_listener.py | 6 +- src/plumpy/process_states.py | 13 ++- src/plumpy/processes.py | 2 +- src/plumpy/workchains.py | 16 ++-- tests/test_event_helper.py | 1 + tests/test_process_listener.py | 1 + tests/test_processes.py | 2 + tests/workchain/__init__.py | 0 tests/workchain/test_steppers.py | 105 +++++++++++++++++++++++ tests/{ => workchain}/test_workchains.py | 5 +- 12 files changed, 142 insertions(+), 24 deletions(-) create mode 100644 tests/workchain/__init__.py create mode 100644 tests/workchain/test_steppers.py rename tests/{ => workchain}/test_workchains.py (99%) diff --git a/src/plumpy/event_helper.py b/src/plumpy/event_helper.py index 9d70e1c4..47188031 100644 --- a/src/plumpy/event_helper.py +++ b/src/plumpy/event_helper.py @@ -1,6 +1,10 @@ # -*- coding: utf-8 -*- +from __future__ import annotations + import logging -from typing import TYPE_CHECKING, Any, Callable, Optional, Self, final +from typing import TYPE_CHECKING, Any, Callable, Optional, final + +from typing_extensions import Self from plumpy.loaders import ObjectLoader from plumpy.persistence import LoadSaveContext, auto_load, auto_save, ensure_object_loader @@ -13,7 +17,6 @@ _LOGGER = logging.getLogger(__name__) -# FIXME: test me @final @persistence.auto_persist('_listeners', '_listener_type') diff --git a/src/plumpy/persistence.py b/src/plumpy/persistence.py index bcc037ae..02b5ff76 100644 --- a/src/plumpy/persistence.py +++ b/src/plumpy/persistence.py @@ -21,12 +21,12 @@ List, Optional, Protocol, - Self, TypeVar, runtime_checkable, ) import yaml +from typing_extensions import Self from . import futures, loaders from .utils import PID_TYPE, SAVED_STATE_TYPE @@ -577,18 +577,18 @@ def auto_load(cls: type[T], saved_state: SAVED_STATE_TYPE, load_context: LoadSav def auto_persist(*members: str) -> Callable[[type[T]], type[T]]: def wrapped(cls: type[T]) -> type[T]: - if not hasattr(cls, '_auto_persist') or cls._auto_persist is None: + if not hasattr(cls, '_auto_persist') or cls._auto_persist is None: # type: ignore[attr-defined] cls._auto_persist = set() # type: ignore[attr-defined] else: - cls._auto_persist = set(cls._auto_persist) + cls._auto_persist = set(cls._auto_persist) # type: ignore[attr-defined] cls._auto_persist.update(members) # type: ignore[attr-defined] - # XXX: validate on `save` and `recreate_from` method?? return cls return wrapped +# FIXME: test me after clear event loop management @auto_persist('_state', '_result') class SavableFuture(futures.Future): """ diff --git a/src/plumpy/process_listener.py b/src/plumpy/process_listener.py index 280137b0..8bc7c828 100644 --- a/src/plumpy/process_listener.py +++ b/src/plumpy/process_listener.py @@ -1,6 +1,10 @@ # -*- coding: utf-8 -*- +from __future__ import annotations + import abc -from typing import TYPE_CHECKING, Any, Dict, Optional, Self +from typing import TYPE_CHECKING, Any, Dict, Optional + +from typing_extensions import Self from plumpy.loaders import ObjectLoader from plumpy.persistence import LoadSaveContext, auto_save, ensure_object_loader diff --git a/src/plumpy/process_states.py b/src/plumpy/process_states.py index d2d38d4d..9874e616 100644 --- a/src/plumpy/process_states.py +++ b/src/plumpy/process_states.py @@ -11,7 +11,6 @@ Callable, ClassVar, Optional, - Self, Tuple, Type, Union, @@ -20,7 +19,7 @@ ) import yaml -from typing_extensions import override +from typing_extensions import Self, override from yaml.loader import Loader from plumpy.loaders import ObjectLoader @@ -268,7 +267,7 @@ class Running: COMMAND = 'command' # The key used to store an upcoming command # Class level defaults - _command: Kill | Stop | Wait | Continue | None = None + _command: Command | None = None _running: bool = False _run_handle = None @@ -311,13 +310,13 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa obj.run_fn = ensure_coroutine(getattr(obj.process, saved_state[obj.RUN_FN])) if obj.COMMAND in saved_state: loaded_cmd = persistence.load(saved_state[obj.COMMAND], load_context) - if isinstance(loaded_cmd, Command): + if not isinstance(loaded_cmd, Command): # runtime check for loading from persistence - obj._command = loaded_cmd - else: # XXX: debug log in principle unreachable raise RuntimeError(f'command `{obj.COMMAND}` loaded from Running state not a valid `Command` type') + obj._command = loaded_cmd + return obj def interrupt(self, reason: Any) -> None: @@ -354,7 +353,7 @@ async def execute(self) -> st.State: next_state = self._action_command(command) return next_state - def _action_command(self, command: Union[Kill, Stop, Wait, Continue]) -> st.State: + def _action_command(self, command: Command) -> st.State: if isinstance(command, Kill): state = st.create_state(self.process, ProcessState.KILLED, msg=command.msg) # elif isinstance(command, Pause): diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index 24db39a4..9d985344 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -27,7 +27,6 @@ Hashable, List, Optional, - Self, Sequence, Tuple, Type, @@ -37,6 +36,7 @@ ) import kiwipy +from typing_extensions import Self from plumpy.coordinator import Coordinator from plumpy.loaders import ObjectLoader diff --git a/src/plumpy/workchains.py b/src/plumpy/workchains.py index 9412b490..e6e527f3 100644 --- a/src/plumpy/workchains.py +++ b/src/plumpy/workchains.py @@ -16,7 +16,6 @@ MutableSequence, Optional, Protocol, - Self, Sequence, Tuple, Type, @@ -24,6 +23,8 @@ cast, ) +from typing_extensions import Self + from plumpy import utils from plumpy.base import state_machine from plumpy.base.utils import call_with_super_check @@ -399,7 +400,7 @@ def create_stepper(self, workchain: 'WorkChain') -> _FunctionStepper: def recreate_stepper(self, saved_state: SAVED_STATE_TYPE, workchain: 'WorkChain') -> _FunctionStepper: load_context = persistence.LoadSaveContext(workchain=workchain, func_spec=self) - return cast(_FunctionStepper, _FunctionStepper.recreate_from(saved_state, load_context)) + return _FunctionStepper.recreate_from(saved_state, load_context) def get_description(self) -> str: desc = self._fn.__name__ @@ -479,6 +480,7 @@ class _Block(_Instruction, collections.abc.Sequence): Represents a block of instructions i.e. a sequential list of instructions. """ + # XXX: swap workchain and instructions def __init__(self, instructions: Sequence[Union[_Instruction, WC_COMMAND_TYPE]]) -> None: # Build up the list of commands comms: MutableSequence[_Instruction | _FunctionCall] = [] @@ -502,7 +504,7 @@ def create_stepper(self, workchain: 'WorkChain') -> _BlockStepper: def recreate_stepper(self, saved_state: SAVED_STATE_TYPE, workchain: 'WorkChain') -> _BlockStepper: load_context = persistence.LoadSaveContext(workchain=workchain, block_instruction=self) - return cast(_BlockStepper, _BlockStepper.recreate_from(saved_state, load_context)) + return _BlockStepper.recreate_from(saved_state, load_context) def get_description(self) -> List[str]: return [instruction.get_description() for instruction in self._instruction] @@ -551,7 +553,7 @@ def is_true(self, workflow: 'WorkChain') -> bool: return result - def __call__(self, *instructions: Union[_Instruction, WC_COMMAND_TYPE]) -> _Instruction: + def __call__(self, *instructions: _Instruction | WC_COMMAND_TYPE) -> _Instruction: assert self._body is None, 'Instructions have already been set' self._body = _Block(instructions) return self._parent @@ -562,6 +564,7 @@ def __str__(self) -> str: @persistence.auto_persist('_pos') class _IfStepper: + # XXX: swap workchain and if_instruction def __init__(self, if_instruction: '_If', workchain: 'WorkChain') -> None: self._workchain = workchain self._if_instruction = if_instruction @@ -671,7 +674,7 @@ def create_stepper(self, workchain: 'WorkChain') -> _IfStepper: def recreate_stepper(self, saved_state: SAVED_STATE_TYPE, workchain: 'WorkChain') -> _IfStepper: load_context = persistence.LoadSaveContext(workchain=workchain, if_instruction=self) - return cast(_IfStepper, _IfStepper.recreate_from(saved_state, load_context)) + return _IfStepper.recreate_from(saved_state, load_context) def get_description(self) -> Mapping[str, Any]: description = collections.OrderedDict() @@ -759,7 +762,7 @@ def create_stepper(self, workchain: 'WorkChain') -> _WhileStepper: def recreate_stepper(self, saved_state: SAVED_STATE_TYPE, workchain: 'WorkChain') -> _WhileStepper: load_context = persistence.LoadSaveContext(workchain=workchain, while_instruction=self) - return cast(_WhileStepper, _WhileStepper.recreate_from(saved_state, load_context)) + return _WhileStepper.recreate_from(saved_state, load_context) def get_description(self) -> Dict[str, Any]: return {f'while({self.predicate.__name__})': self.body.get_description()} @@ -771,7 +774,6 @@ def __init__(self, exit_code: Optional[EXIT_CODE_TYPE]) -> None: self.exit_code = exit_code -@persistence.auto_persist() class _ReturnStepper: def __init__(self, return_instruction: '_Return', workchain: 'WorkChain') -> None: self._workchain = workchain diff --git a/tests/test_event_helper.py b/tests/test_event_helper.py index fc2310fa..a8351dd4 100644 --- a/tests/test_event_helper.py +++ b/tests/test_event_helper.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- from plumpy.event_helper import EventHelper from plumpy.persistence import Savable, load from tests.utils import DummyProcess, ProcessListenerTester diff --git a/tests/test_process_listener.py b/tests/test_process_listener.py index 37d41593..4223b4cc 100644 --- a/tests/test_process_listener.py +++ b/tests/test_process_listener.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- from plumpy.persistence import Savable, load from tests.utils import DummyProcess, ProcessListenerTester diff --git a/tests/test_processes.py b/tests/test_processes.py index 308d2fb1..7e6b82c8 100644 --- a/tests/test_processes.py +++ b/tests/test_processes.py @@ -46,6 +46,8 @@ def test_process_is_savable(): proc = utils.DummyProcess() assert isinstance(proc, Savable) + # TODO: direct save load round trip regression + @pytest.mark.asyncio async def test_process_scope(): diff --git a/tests/workchain/__init__.py b/tests/workchain/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/workchain/test_steppers.py b/tests/workchain/test_steppers.py new file mode 100644 index 00000000..6bd5a07b --- /dev/null +++ b/tests/workchain/test_steppers.py @@ -0,0 +1,105 @@ +# -*- coding: utf-8 -*- +import pytest +from plumpy.base.state_machine import StateMachine +from plumpy.persistence import LoadSaveContext, Savable, load +from plumpy.workchains import WorkChain, if_, while_ + + +class DummyWc(WorkChain): + @classmethod + def define(cls, spec): + super().define(spec) + spec.outline( + cls.do_nothing, + if_(cls.cond)(cls.do_cond), + while_(cls.wcond)( + cls.do_wcond, + ), + ) + + @staticmethod + def do_nothing(_wc: WorkChain) -> None: + pass + + @staticmethod + def cond(_wc: WorkChain) -> bool: + return True + + @staticmethod + def do_cond(_wc: WorkChain) -> None: + pass + + @staticmethod + def wcond(_wc: WorkChain) -> bool: + return True + + @staticmethod + def do_wcond(_wc: WorkChain) -> None: + pass + + +@pytest.fixture(scope='function') +def wc() -> StateMachine: + return DummyWc() + + +def test_func_stepper_savable(wc: DummyWc): + from plumpy.workchains import _FunctionStepper + + fs = _FunctionStepper(workchain=wc, fn=wc.do_nothing) + assert isinstance(fs, Savable) + + ctx = LoadSaveContext(workchain=wc) + saved_state = fs.save() + loaded_state = load(saved_state=saved_state, load_context=ctx) + saved_state2 = loaded_state.save() + + assert saved_state == saved_state2 + + +def test_block_stepper_savable(wc: DummyWc): + """block stepper test with a dummy function call""" + from plumpy.workchains import _BlockStepper, _FunctionCall + + block = [_FunctionCall(wc.do_nothing)] + bs = _BlockStepper(block=block, workchain=wc) + assert isinstance(bs, Savable) + + ctx = LoadSaveContext(workchain=wc, block_instruction=block) + saved_state = bs.save() + loaded_state = load(saved_state=saved_state, load_context=ctx) + saved_state2 = loaded_state.save() + + assert saved_state == saved_state2 + + +def test_if_stepper_savable(wc: DummyWc): + """block stepper test with a dummy function call""" + from plumpy.workchains import _If, _IfStepper + + dummy_if = _If(wc.cond) + ifs = _IfStepper(if_instruction=dummy_if, workchain=wc) + assert isinstance(ifs, Savable) + + ctx = LoadSaveContext(workchain=wc, if_instruction=ifs) + saved_state = ifs.save() + loaded_state = load(saved_state=saved_state, load_context=ctx) + saved_state2 = loaded_state.save() + + assert saved_state == saved_state2 + + +def test_while_stepper_savable(wc: DummyWc): + """block stepper test with a dummy function call""" + from plumpy.workchains import _While, _WhileStepper + + dummy_while = _While(wc.cond) + wfs = _WhileStepper(while_instruction=dummy_while, workchain=wc) + assert isinstance(wfs, Savable) + + ctx = LoadSaveContext(workchain=wc, while_instruction=wfs) + saved_state = wfs.save() + loaded_state = load(saved_state=saved_state, load_context=ctx) + saved_state2 = loaded_state.save() + + assert saved_state == saved_state2 diff --git a/tests/test_workchains.py b/tests/workchain/test_workchains.py similarity index 99% rename from tests/test_workchains.py rename to tests/workchain/test_workchains.py index 6d55c611..cb76020c 100644 --- a/tests/test_workchains.py +++ b/tests/workchain/test_workchains.py @@ -9,9 +9,8 @@ from plumpy.process_listener import ProcessListener from plumpy.workchains import * -from . import utils +from .. import utils -# FIXME: test steppers are savable and round trip persistence class Wf(WorkChain): # Keep track of which steps were completed by the workflow @@ -90,6 +89,8 @@ def test_workchain_is_savable(): w = Wf(inputs=dict(value='A', n=3)) assert isinstance(w, Savable) + # TODO: direct regression save load round trip + class IfTest(WorkChain): @classmethod