diff --git a/src/ess/reduce/streaming.py b/src/ess/reduce/streaming.py
index bc138b41..507c615f 100644
--- a/src/ess/reduce/streaming.py
+++ b/src/ess/reduce/streaming.py
@@ -4,6 +4,7 @@
 
 from abc import ABC, abstractmethod
 from collections.abc import Callable
+from copy import deepcopy
 from typing import Any, Generic, TypeVar
 
 import networkx as nx
@@ -29,6 +30,8 @@ def maybe_hist(value: T) -> T:
     :
         Histogram.
     """
+    if not isinstance(value, sc.Variable | sc.DataArray):
+        return value
     return value if value.bins is None else value.hist()
 
 
@@ -90,11 +93,11 @@ def __init__(self, **kwargs: Any) -> None:
 
     @property
     def value(self) -> T:
-        return self._value.copy()
+        return deepcopy(self._value)
 
     def _do_push(self, value: T) -> None:
         if self._value is None:
-            self._value = value.copy()
+            self._value = deepcopy(value)
         else:
             self._value += value
 
@@ -146,6 +149,7 @@ def __init__(
         target_keys: tuple[sciline.typing.Key, ...],
         accumulators: dict[sciline.typing.Key, Accumulator | Callable[..., Accumulator]]
         | tuple[sciline.typing.Key, ...],
+        allow_bypass: bool = False,
     ) -> None:
         """
         Create a stream processor.
@@ -163,6 +167,12 @@
             passed, :py:class:`EternalAccumulator` is used for all keys. Otherwise, a
             dict mapping keys to accumulator instances can be passed. If a dict value is
             a callable, base_workflow.bind_and_call(value) is used to make an instance.
+        allow_bypass:
+            If True, allow bypassing accumulators for keys that are not in the
+            accumulators dict. This is useful for dynamic keys that are not "terminated"
+            in any accumulator. USE WITH CARE! This will lead to incorrect results
+            unless the values for these keys are valid for all chunks included in the
+            final accumulators at the point where :py:meth:`finalize` is called.
         """
         workflow = sciline.Pipeline()
         for key in target_keys:
@@ -201,19 +211,59 @@
             for key, value in self._accumulators.items()
         }
         self._target_keys = target_keys
+        self._allow_bypass = allow_bypass
 
     def add_chunk(
         self, chunks: dict[sciline.typing.Key, Any]
     ) -> dict[sciline.typing.Key, Any]:
+        """
+        Legacy interface for accumulating values from chunks and finalizing the result.
+
+        It is recommended to use :py:meth:`accumulate` and :py:meth:`finalize` instead.
+
+        Parameters
+        ----------
+        chunks:
+            Chunks to be processed.
+
+        Returns
+        -------
+        :
+            Finalized result.
+        """
+        self.accumulate(chunks)
+        return self.finalize()
+
+    def accumulate(self, chunks: dict[sciline.typing.Key, Any]) -> None:
+        """
+        Accumulate values from chunks without finalizing the result.
+
+        Parameters
+        ----------
+        chunks:
+            Chunks to be processed.
+        """
         for key, value in chunks.items():
             self._process_chunk_workflow[key] = value
             # There can be dynamic keys that do not "terminate" in any accumulator. In
             # that case, we need to make sure they can be and are used when computing
             # the target keys.
-            self._finalize_workflow[key] = value
+            if self._allow_bypass:
+                self._finalize_workflow[key] = value
         to_accumulate = self._process_chunk_workflow.compute(self._accumulators)
         for key, processed in to_accumulate.items():
             self._accumulators[key].push(processed)
+
+    def finalize(self) -> dict[sciline.typing.Key, Any]:
+        """
+        Get the final result by computing the target keys based on accumulated values.
+
+        Returns
+        -------
+        :
+            Finalized result.
+        """
+        for key in self._accumulators:
             self._finalize_workflow[key] = self._accumulators[key].value
         return self._finalize_workflow.compute(self._target_keys)
 
diff --git a/tests/streaming_test.py b/tests/streaming_test.py
index a4739149..407fac4c 100644
--- a/tests/streaming_test.py
+++ b/tests/streaming_test.py
@@ -3,6 +3,7 @@
 
 from typing import NewType
 
+import pytest
 import sciline
 import scipp as sc
 
@@ -214,6 +215,7 @@ def test_StreamProcess_with_zero_accumulators_for_buffered_workflow_calls() -> N
         dynamic_keys=(DynamicA, DynamicB),
         target_keys=(Target,),
         accumulators=(),
+        allow_bypass=True,
     )
     result = streaming_wf.add_chunk({DynamicA: sc.scalar(1), DynamicB: sc.scalar(4)})
     assert sc.identical(result[Target], sc.scalar(2 * 1.0 / 4.0))
@@ -222,3 +224,101 @@ def test_StreamProcess_with_zero_accumulators_for_buffered_workflow_calls() -> N
     result = streaming_wf.add_chunk({DynamicA: sc.scalar(3), DynamicB: sc.scalar(6)})
     assert sc.identical(result[Target], sc.scalar(2 * 3.0 / 6.0))
     assert make_static_a.call_count == 1
+
+
+def test_StreamProcessor_with_bypass() -> None:
+    def _make_static_a() -> StaticA:
+        _make_static_a.call_count += 1
+        return StaticA(2.0)
+
+    _make_static_a.call_count = 0
+
+    base_workflow = sciline.Pipeline(
+        (_make_static_a, make_accum_a, make_accum_b, make_target)
+    )
+    orig_workflow = base_workflow.copy()
+
+    streaming_wf = streaming.StreamProcessor(
+        base_workflow=base_workflow,
+        dynamic_keys=(DynamicA, DynamicB),
+        target_keys=(Target,),
+        accumulators=(AccumA,),  # Note: No AccumB
+        allow_bypass=True,
+    )
+    streaming_wf.accumulate({DynamicA: sc.scalar(1), DynamicB: sc.scalar(4)})
+    result = streaming_wf.finalize()
+    assert sc.identical(result[Target], sc.scalar(2 * 1.0 / 4.0))
+    streaming_wf.accumulate({DynamicA: sc.scalar(2), DynamicB: sc.scalar(5)})
+    result = streaming_wf.finalize()
+    # Note denominator is 5, not 9
+    assert sc.identical(result[Target], sc.scalar(2 * 3.0 / 5.0))
+    streaming_wf.accumulate({DynamicA: sc.scalar(3), DynamicB: sc.scalar(6)})
+    result = streaming_wf.finalize()
+    # Note denominator is 6, not 15
+    assert sc.identical(result[Target], sc.scalar(2 * 6.0 / 6.0))
+    assert _make_static_a.call_count == 1
+
+    # Consistency check: Run the original workflow with the same inputs, all at once
+    orig_workflow[DynamicA] = sc.scalar(1 + 2 + 3)
+    orig_workflow[DynamicB] = sc.scalar(6)
+    expected = orig_workflow.compute(Target)
+    assert sc.identical(expected, result[Target])
+
+
+def test_StreamProcessor_without_bypass_raises() -> None:
+    def _make_static_a() -> StaticA:
+        _make_static_a.call_count += 1
+        return StaticA(2.0)
+
+    _make_static_a.call_count = 0
+
+    base_workflow = sciline.Pipeline(
+        (_make_static_a, make_accum_a, make_accum_b, make_target)
+    )
+
+    streaming_wf = streaming.StreamProcessor(
+        base_workflow=base_workflow,
+        dynamic_keys=(DynamicA, DynamicB),
+        target_keys=(Target,),
+        accumulators=(AccumA,),  # Note: No AccumB
+    )
+    streaming_wf.accumulate({DynamicA: 1, DynamicB: 4})
+    # Sciline passes `None` to the provider that needs AccumB.
+    with pytest.raises(TypeError, match='unsupported operand type'):
+        _ = streaming_wf.finalize()
+
+
+def test_StreamProcessor_calls_providers_after_accumulators_only_when_finalizing() -> (
+    None
+):
+    def _make_target(accum_a: AccumA, accum_b: AccumB) -> Target:
+        _make_target.call_count += 1
+        return Target(accum_a / accum_b)
+
+    _make_target.call_count = 0
+
+    base_workflow = sciline.Pipeline(
+        (make_accum_a, make_accum_b, _make_target), params={StaticA: 2.0}
+    )
+
+    streaming_wf = streaming.StreamProcessor(
+        base_workflow=base_workflow,
+        dynamic_keys=(DynamicA, DynamicB),
+        target_keys=(Target,),
+        accumulators=(AccumA, AccumB),
+    )
+    streaming_wf.accumulate({DynamicA: sc.scalar(1), DynamicB: sc.scalar(4)})
+    streaming_wf.accumulate({DynamicA: sc.scalar(2), DynamicB: sc.scalar(5)})
+    assert _make_target.call_count == 0
+    result = streaming_wf.finalize()
+    assert sc.identical(result[Target], sc.scalar(2 * 3.0 / 9.0))
+    assert _make_target.call_count == 1
+    streaming_wf.accumulate({DynamicA: sc.scalar(3), DynamicB: sc.scalar(6)})
+    assert _make_target.call_count == 1
+    result = streaming_wf.finalize()
+    assert sc.identical(result[Target], sc.scalar(2 * 6.0 / 15.0))
+    assert _make_target.call_count == 2
+    result = streaming_wf.finalize()
+    assert sc.identical(result[Target], sc.scalar(2 * 6.0 / 15.0))
+    # Outputs are not cached.
+    assert _make_target.call_count == 3
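For reference, a minimal usage sketch of the accumulate()/finalize() API and the allow_bypass flag introduced above. The Chunk, Norm, Total, and Result keys and the two providers are illustrative stand-ins loosely modeled on the fixtures in tests/streaming_test.py, not part of ess.reduce; Norm demonstrates the bypass case, since it is never accumulated and only its most recently seen value enters the finalized result.

# Hypothetical keys and providers, for illustration only.
from typing import NewType

import sciline
import scipp as sc

from ess.reduce import streaming

Chunk = NewType('Chunk', sc.Variable)   # dynamic key, accumulated over chunks
Norm = NewType('Norm', sc.Variable)     # dynamic key, bypasses accumulation
Total = NewType('Total', sc.Variable)
Result = NewType('Result', sc.Variable)


def make_total(chunk: Chunk) -> Total:
    # Per-chunk contribution; the EternalAccumulator sums these across chunks.
    return Total(chunk)


def make_result(total: Total, norm: Norm) -> Result:
    # Runs only when finalizing, using the accumulated Total and bypassed Norm.
    return Result(total / norm)


processor = streaming.StreamProcessor(
    base_workflow=sciline.Pipeline((make_total, make_result)),
    dynamic_keys=(Chunk, Norm),
    target_keys=(Result,),
    accumulators=(Total,),  # Norm is not accumulated ...
    allow_bypass=True,      # ... so it must be allowed to bypass the accumulators
)
for i in (1, 2, 3):
    processor.accumulate({Chunk: sc.scalar(float(i)), Norm: sc.scalar(3.0)})
result = processor.finalize()
# Total accumulates 1 + 2 + 3 = 6; Norm is taken from the most recent chunk.
assert sc.identical(result[Result], sc.scalar(6.0 / 3.0))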