Skip to content

Commit

Permalink
Move test_workchain and test func_stepper savable
Browse files Browse the repository at this point in the history
  • Loading branch information
unkcpz committed Jan 23, 2025
1 parent 9846280 commit 24d73a0
Show file tree
Hide file tree
Showing 12 changed files with 142 additions and 24 deletions.
7 changes: 5 additions & 2 deletions src/plumpy/event_helper.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
# -*- coding: utf-8 -*-
from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Any, Callable, Optional, Self, final
from typing import TYPE_CHECKING, Any, Callable, Optional, final

from typing_extensions import Self

from plumpy.loaders import ObjectLoader
from plumpy.persistence import LoadSaveContext, auto_load, auto_save, ensure_object_loader
Expand All @@ -13,7 +17,6 @@

_LOGGER = logging.getLogger(__name__)

# FIXME: test me

@final
@persistence.auto_persist('_listeners', '_listener_type')
Expand Down
8 changes: 4 additions & 4 deletions src/plumpy/persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,12 @@
List,
Optional,
Protocol,
Self,
TypeVar,
runtime_checkable,
)

import yaml
from typing_extensions import Self

from . import futures, loaders
from .utils import PID_TYPE, SAVED_STATE_TYPE
Expand Down Expand Up @@ -577,18 +577,18 @@ def auto_load(cls: type[T], saved_state: SAVED_STATE_TYPE, load_context: LoadSav

def auto_persist(*members: str) -> Callable[[type[T]], type[T]]:
def wrapped(cls: type[T]) -> type[T]:
if not hasattr(cls, '_auto_persist') or cls._auto_persist is None:
if not hasattr(cls, '_auto_persist') or cls._auto_persist is None: # type: ignore[attr-defined]
cls._auto_persist = set() # type: ignore[attr-defined]
else:
cls._auto_persist = set(cls._auto_persist)
cls._auto_persist = set(cls._auto_persist) # type: ignore[attr-defined]

cls._auto_persist.update(members) # type: ignore[attr-defined]
# XXX: validate on `save` and `recreate_from` method??
return cls

return wrapped


# FIXME: test me after clear event loop management
@auto_persist('_state', '_result')
class SavableFuture(futures.Future):
"""
Expand Down
6 changes: 5 additions & 1 deletion src/plumpy/process_listener.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
# -*- coding: utf-8 -*-
from __future__ import annotations

import abc
from typing import TYPE_CHECKING, Any, Dict, Optional, Self
from typing import TYPE_CHECKING, Any, Dict, Optional

from typing_extensions import Self

from plumpy.loaders import ObjectLoader
from plumpy.persistence import LoadSaveContext, auto_save, ensure_object_loader
Expand Down
13 changes: 6 additions & 7 deletions src/plumpy/process_states.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
Callable,
ClassVar,
Optional,
Self,
Tuple,
Type,
Union,
Expand All @@ -20,7 +19,7 @@
)

import yaml
from typing_extensions import override
from typing_extensions import Self, override
from yaml.loader import Loader

from plumpy.loaders import ObjectLoader
Expand Down Expand Up @@ -268,7 +267,7 @@ class Running:
COMMAND = 'command' # The key used to store an upcoming command

# Class level defaults
_command: Kill | Stop | Wait | Continue | None = None
_command: Command | None = None
_running: bool = False
_run_handle = None

Expand Down Expand Up @@ -311,13 +310,13 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa
obj.run_fn = ensure_coroutine(getattr(obj.process, saved_state[obj.RUN_FN]))
if obj.COMMAND in saved_state:
loaded_cmd = persistence.load(saved_state[obj.COMMAND], load_context)
if isinstance(loaded_cmd, Command):
if not isinstance(loaded_cmd, Command):
# runtime check for loading from persistence
obj._command = loaded_cmd
else:
# XXX: debug log in principle unreachable
raise RuntimeError(f'command `{obj.COMMAND}` loaded from Running state not a valid `Command` type')

obj._command = loaded_cmd

return obj

def interrupt(self, reason: Any) -> None:
Expand Down Expand Up @@ -354,7 +353,7 @@ async def execute(self) -> st.State:
next_state = self._action_command(command)
return next_state

def _action_command(self, command: Union[Kill, Stop, Wait, Continue]) -> st.State:
def _action_command(self, command: Command) -> st.State:
if isinstance(command, Kill):
state = st.create_state(self.process, ProcessState.KILLED, msg=command.msg)
# elif isinstance(command, Pause):
Expand Down
2 changes: 1 addition & 1 deletion src/plumpy/processes.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
Hashable,
List,
Optional,
Self,
Sequence,
Tuple,
Type,
Expand All @@ -37,6 +36,7 @@
)

import kiwipy
from typing_extensions import Self

from plumpy.coordinator import Coordinator
from plumpy.loaders import ObjectLoader
Expand Down
16 changes: 9 additions & 7 deletions src/plumpy/workchains.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,15 @@
MutableSequence,
Optional,
Protocol,
Self,
Sequence,
Tuple,
Type,
Union,
cast,
)

from typing_extensions import Self

from plumpy import utils
from plumpy.base import state_machine
from plumpy.base.utils import call_with_super_check
Expand Down Expand Up @@ -399,7 +400,7 @@ def create_stepper(self, workchain: 'WorkChain') -> _FunctionStepper:

def recreate_stepper(self, saved_state: SAVED_STATE_TYPE, workchain: 'WorkChain') -> _FunctionStepper:
load_context = persistence.LoadSaveContext(workchain=workchain, func_spec=self)
return cast(_FunctionStepper, _FunctionStepper.recreate_from(saved_state, load_context))
return _FunctionStepper.recreate_from(saved_state, load_context)

def get_description(self) -> str:
desc = self._fn.__name__
Expand Down Expand Up @@ -479,6 +480,7 @@ class _Block(_Instruction, collections.abc.Sequence):
Represents a block of instructions i.e. a sequential list of instructions.
"""

# XXX: swap workchain and instructions
def __init__(self, instructions: Sequence[Union[_Instruction, WC_COMMAND_TYPE]]) -> None:
# Build up the list of commands
comms: MutableSequence[_Instruction | _FunctionCall] = []
Expand All @@ -502,7 +504,7 @@ def create_stepper(self, workchain: 'WorkChain') -> _BlockStepper:

def recreate_stepper(self, saved_state: SAVED_STATE_TYPE, workchain: 'WorkChain') -> _BlockStepper:
load_context = persistence.LoadSaveContext(workchain=workchain, block_instruction=self)
return cast(_BlockStepper, _BlockStepper.recreate_from(saved_state, load_context))
return _BlockStepper.recreate_from(saved_state, load_context)

def get_description(self) -> List[str]:
return [instruction.get_description() for instruction in self._instruction]
Expand Down Expand Up @@ -551,7 +553,7 @@ def is_true(self, workflow: 'WorkChain') -> bool:

return result

def __call__(self, *instructions: Union[_Instruction, WC_COMMAND_TYPE]) -> _Instruction:
def __call__(self, *instructions: _Instruction | WC_COMMAND_TYPE) -> _Instruction:
assert self._body is None, 'Instructions have already been set'
self._body = _Block(instructions)
return self._parent
Expand All @@ -562,6 +564,7 @@ def __str__(self) -> str:

@persistence.auto_persist('_pos')
class _IfStepper:
# XXX: swap workchain and if_instruction
def __init__(self, if_instruction: '_If', workchain: 'WorkChain') -> None:
self._workchain = workchain
self._if_instruction = if_instruction
Expand Down Expand Up @@ -671,7 +674,7 @@ def create_stepper(self, workchain: 'WorkChain') -> _IfStepper:

def recreate_stepper(self, saved_state: SAVED_STATE_TYPE, workchain: 'WorkChain') -> _IfStepper:
load_context = persistence.LoadSaveContext(workchain=workchain, if_instruction=self)
return cast(_IfStepper, _IfStepper.recreate_from(saved_state, load_context))
return _IfStepper.recreate_from(saved_state, load_context)

def get_description(self) -> Mapping[str, Any]:
description = collections.OrderedDict()
Expand Down Expand Up @@ -759,7 +762,7 @@ def create_stepper(self, workchain: 'WorkChain') -> _WhileStepper:

def recreate_stepper(self, saved_state: SAVED_STATE_TYPE, workchain: 'WorkChain') -> _WhileStepper:
load_context = persistence.LoadSaveContext(workchain=workchain, while_instruction=self)
return cast(_WhileStepper, _WhileStepper.recreate_from(saved_state, load_context))
return _WhileStepper.recreate_from(saved_state, load_context)

def get_description(self) -> Dict[str, Any]:
return {f'while({self.predicate.__name__})': self.body.get_description()}
Expand All @@ -771,7 +774,6 @@ def __init__(self, exit_code: Optional[EXIT_CODE_TYPE]) -> None:
self.exit_code = exit_code


@persistence.auto_persist()
class _ReturnStepper:
def __init__(self, return_instruction: '_Return', workchain: 'WorkChain') -> None:
self._workchain = workchain
Expand Down
1 change: 1 addition & 0 deletions tests/test_event_helper.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# -*- coding: utf-8 -*-
from plumpy.event_helper import EventHelper
from plumpy.persistence import Savable, load
from tests.utils import DummyProcess, ProcessListenerTester
Expand Down
1 change: 1 addition & 0 deletions tests/test_process_listener.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# -*- coding: utf-8 -*-
from plumpy.persistence import Savable, load
from tests.utils import DummyProcess, ProcessListenerTester

Expand Down
2 changes: 2 additions & 0 deletions tests/test_processes.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ def test_process_is_savable():
proc = utils.DummyProcess()
assert isinstance(proc, Savable)

# TODO: direct save load round trip regression


@pytest.mark.asyncio
async def test_process_scope():
Expand Down
Empty file added tests/workchain/__init__.py
Empty file.
105 changes: 105 additions & 0 deletions tests/workchain/test_steppers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
# -*- coding: utf-8 -*-
import pytest
from plumpy.base.state_machine import StateMachine
from plumpy.persistence import LoadSaveContext, Savable, load
from plumpy.workchains import WorkChain, if_, while_


class DummyWc(WorkChain):
@classmethod
def define(cls, spec):
super().define(spec)
spec.outline(
cls.do_nothing,
if_(cls.cond)(cls.do_cond),
while_(cls.wcond)(
cls.do_wcond,
),
)

@staticmethod
def do_nothing(_wc: WorkChain) -> None:
pass

@staticmethod
def cond(_wc: WorkChain) -> bool:
return True

@staticmethod
def do_cond(_wc: WorkChain) -> None:
pass

@staticmethod
def wcond(_wc: WorkChain) -> bool:
return True

@staticmethod
def do_wcond(_wc: WorkChain) -> None:
pass


@pytest.fixture(scope='function')
def wc() -> StateMachine:
return DummyWc()


def test_func_stepper_savable(wc: DummyWc):
from plumpy.workchains import _FunctionStepper

fs = _FunctionStepper(workchain=wc, fn=wc.do_nothing)
assert isinstance(fs, Savable)

ctx = LoadSaveContext(workchain=wc)
saved_state = fs.save()
loaded_state = load(saved_state=saved_state, load_context=ctx)
saved_state2 = loaded_state.save()

assert saved_state == saved_state2


def test_block_stepper_savable(wc: DummyWc):
"""block stepper test with a dummy function call"""
from plumpy.workchains import _BlockStepper, _FunctionCall

block = [_FunctionCall(wc.do_nothing)]
bs = _BlockStepper(block=block, workchain=wc)
assert isinstance(bs, Savable)

ctx = LoadSaveContext(workchain=wc, block_instruction=block)
saved_state = bs.save()
loaded_state = load(saved_state=saved_state, load_context=ctx)
saved_state2 = loaded_state.save()

assert saved_state == saved_state2


def test_if_stepper_savable(wc: DummyWc):
"""block stepper test with a dummy function call"""
from plumpy.workchains import _If, _IfStepper

dummy_if = _If(wc.cond)
ifs = _IfStepper(if_instruction=dummy_if, workchain=wc)
assert isinstance(ifs, Savable)

ctx = LoadSaveContext(workchain=wc, if_instruction=ifs)
saved_state = ifs.save()
loaded_state = load(saved_state=saved_state, load_context=ctx)
saved_state2 = loaded_state.save()

assert saved_state == saved_state2


def test_while_stepper_savable(wc: DummyWc):
"""block stepper test with a dummy function call"""
from plumpy.workchains import _While, _WhileStepper

dummy_while = _While(wc.cond)
wfs = _WhileStepper(while_instruction=dummy_while, workchain=wc)
assert isinstance(wfs, Savable)

ctx = LoadSaveContext(workchain=wc, while_instruction=wfs)
saved_state = wfs.save()
loaded_state = load(saved_state=saved_state, load_context=ctx)
saved_state2 = loaded_state.save()

assert saved_state == saved_state2
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,8 @@
from plumpy.process_listener import ProcessListener
from plumpy.workchains import *

from . import utils
from .. import utils

# FIXME: test steppers are savable and round trip persistence

class Wf(WorkChain):
# Keep track of which steps were completed by the workflow
Expand Down Expand Up @@ -90,6 +89,8 @@ def test_workchain_is_savable():
w = Wf(inputs=dict(value='A', n=3))
assert isinstance(w, Savable)

# TODO: direct regression save load round trip


class IfTest(WorkChain):
@classmethod
Expand Down

0 comments on commit 24d73a0

Please sign in to comment.