Split interface of StreamProcessor into two parts #175

Open
wants to merge 2 commits into base: main
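The pull request has no description, so for orientation: the change splits the old single-call add_chunk interface into separate accumulate and finalize steps, with add_chunk kept as a thin legacy wrapper. Below is a minimal, self-contained usage sketch; the NewType keys and providers are illustrative stand-ins modelled on tests/streaming_test.py and are not part of this PR.

from typing import NewType

import sciline
import scipp as sc

from ess.reduce import streaming

DynamicA = NewType('DynamicA', float)
DynamicB = NewType('DynamicB', float)
AccumA = NewType('AccumA', float)
AccumB = NewType('AccumB', float)
Target = NewType('Target', float)


def make_accum_a(value: DynamicA) -> AccumA:
    return AccumA(value)


def make_accum_b(value: DynamicB) -> AccumB:
    return AccumB(value)


def make_target(accum_a: AccumA, accum_b: AccumB) -> Target:
    return Target(accum_a / accum_b)


base_workflow = sciline.Pipeline((make_accum_a, make_accum_b, make_target))

processor = streaming.StreamProcessor(
    base_workflow=base_workflow,
    dynamic_keys=(DynamicA, DynamicB),
    target_keys=(Target,),
    accumulators=(AccumA, AccumB),
)

# New split interface: push chunks as they arrive ...
processor.accumulate({DynamicA: sc.scalar(1.0), DynamicB: sc.scalar(4.0)})
processor.accumulate({DynamicA: sc.scalar(2.0), DynamicB: sc.scalar(5.0)})
# ... and compute the target keys only when a result is needed.
result = processor.finalize()  # AccumA holds 1 + 2, AccumB holds 4 + 5.

# Legacy interface: add_chunk now simply calls accumulate followed by finalize.
result = processor.add_chunk({DynamicA: sc.scalar(3.0), DynamicB: sc.scalar(6.0)})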
56 changes: 53 additions & 3 deletions src/ess/reduce/streaming.py
@@ -4,6 +4,7 @@

from abc import ABC, abstractmethod
from collections.abc import Callable
from copy import deepcopy
from typing import Any, Generic, TypeVar

import networkx as nx
@@ -29,6 +30,8 @@ def maybe_hist(value: T) -> T:
:
Histogram.
"""
if not isinstance(value, sc.Variable | sc.DataArray):
return value
return value if value.bins is None else value.hist()


@@ -90,11 +93,11 @@ def __init__(self, **kwargs: Any) -> None:

@property
def value(self) -> T:
return self._value.copy()
return deepcopy(self._value)

def _do_push(self, value: T) -> None:
if self._value is None:
self._value = value.copy()
self._value = deepcopy(value)
else:
self._value += value

@@ -146,6 +149,7 @@ def __init__(
target_keys: tuple[sciline.typing.Key, ...],
accumulators: dict[sciline.typing.Key, Accumulator | Callable[..., Accumulator]]
| tuple[sciline.typing.Key, ...],
allow_bypass: bool = False,
) -> None:
"""
Create a stream processor.
@@ -163,6 +167,12 @@ def __init__(
passed, :py:class:`EternalAccumulator` is used for all keys. Otherwise, a
dict mapping keys to accumulator instances can be passed. If a dict value is
a callable, base_workflow.bind_and_call(value) is used to make an instance.
allow_bypass:
If True, allow bypassing accumulators for keys that are not in the
accumulators dict. This is useful for dynamic keys that are not "terminated"
in any accumulator. USE WITH CARE! This will lead to incorrect results
unless the values for these keys are valid for all chunks contained in the
final accumulators at the point where :py:meth:`finalize` is called.
"""
workflow = sciline.Pipeline()
for key in target_keys:
@@ -201,19 +211,59 @@ def __init__(
for key, value in self._accumulators.items()
}
self._target_keys = target_keys
self._allow_bypass = allow_bypass

def add_chunk(
self, chunks: dict[sciline.typing.Key, Any]
) -> dict[sciline.typing.Key, Any]:
"""
Legacy interface for accumulating values from chunks and finalizing the result.

It is recommended to use :py:meth:`accumulate` and :py:meth:`finalize` instead.

Parameters
----------
chunks:
Chunks to be processed.

Returns
-------
:
Finalized result.
"""
self.accumulate(chunks)
return self.finalize()

def accumulate(self, chunks: dict[sciline.typing.Key, Any]) -> None:
"""
Accumulate values from chunks without finalizing the result.

Parameters
----------
chunks:
Chunks to be processed.
"""
for key, value in chunks.items():
self._process_chunk_workflow[key] = value
# There can be dynamic keys that do not "terminate" in any accumulator. In
# that case, we need to make sure they can be and are used when computing
# the target keys.
self._finalize_workflow[key] = value
if self._allow_bypass:
self._finalize_workflow[key] = value
to_accumulate = self._process_chunk_workflow.compute(self._accumulators)
for key, processed in to_accumulate.items():
self._accumulators[key].push(processed)

def finalize(self) -> dict[sciline.typing.Key, Any]:
"""
Get the final result by computing the target keys based on accumulated values.

Returns
-------
:
Finalized result.
"""
for key in self._accumulators:
self._finalize_workflow[key] = self._accumulators[key].value
return self._finalize_workflow.compute(self._target_keys)

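To illustrate the allow_bypass flag documented in the diff above: keys that do not feed into any accumulator are written directly into the finalize workflow, so each finalize call sees only the value from the most recent chunk for those keys. A short sketch, reusing the illustrative setup from the snippet before the diff (again not part of this PR):

# No accumulator for AccumB, so DynamicB/AccumB bypass the accumulation step.
bypass_processor = streaming.StreamProcessor(
    base_workflow=base_workflow,
    dynamic_keys=(DynamicA, DynamicB),
    target_keys=(Target,),
    accumulators=(AccumA,),
    allow_bypass=True,
)

bypass_processor.accumulate({DynamicA: sc.scalar(1.0), DynamicB: sc.scalar(4.0)})
bypass_processor.accumulate({DynamicA: sc.scalar(2.0), DynamicB: sc.scalar(5.0)})
result = bypass_processor.finalize()
# AccumA has accumulated 1 + 2 = 3, but AccumB is derived from the latest chunk's
# value 5 only, so Target is 3 / 5. As the new docstring warns, this is only
# correct when the bypassed values are valid for every chunk in the accumulators.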
100 changes: 100 additions & 0 deletions tests/streaming_test.py
@@ -3,6 +3,7 @@

from typing import NewType

import pytest
import sciline
import scipp as sc

@@ -214,6 +215,7 @@ def test_StreamProcess_with_zero_accumulators_for_buffered_workflow_calls() -> N
dynamic_keys=(DynamicA, DynamicB),
target_keys=(Target,),
accumulators=(),
allow_bypass=True,
)
result = streaming_wf.add_chunk({DynamicA: sc.scalar(1), DynamicB: sc.scalar(4)})
assert sc.identical(result[Target], sc.scalar(2 * 1.0 / 4.0))
@@ -222,3 +224,101 @@ def test_StreamProcess_with_zero_accumulators_for_buffered_workflow_calls() -> N
result = streaming_wf.add_chunk({DynamicA: sc.scalar(3), DynamicB: sc.scalar(6)})
assert sc.identical(result[Target], sc.scalar(2 * 3.0 / 6.0))
assert make_static_a.call_count == 1


def test_StreamProcessor_with_bypass() -> None:
def _make_static_a() -> StaticA:
_make_static_a.call_count += 1
return StaticA(2.0)

_make_static_a.call_count = 0

base_workflow = sciline.Pipeline(
(_make_static_a, make_accum_a, make_accum_b, make_target)
)
orig_workflow = base_workflow.copy()

streaming_wf = streaming.StreamProcessor(
base_workflow=base_workflow,
dynamic_keys=(DynamicA, DynamicB),
target_keys=(Target,),
accumulators=(AccumA,), # Note: No AccumB
allow_bypass=True,
)
streaming_wf.accumulate({DynamicA: sc.scalar(1), DynamicB: sc.scalar(4)})
result = streaming_wf.finalize()
assert sc.identical(result[Target], sc.scalar(2 * 1.0 / 4.0))
streaming_wf.accumulate({DynamicA: sc.scalar(2), DynamicB: sc.scalar(5)})
result = streaming_wf.finalize()
# Note denominator is 5, not 9
assert sc.identical(result[Target], sc.scalar(2 * 3.0 / 5.0))
streaming_wf.accumulate({DynamicA: sc.scalar(3), DynamicB: sc.scalar(6)})
result = streaming_wf.finalize()
# Note denominator is 6, not 15
assert sc.identical(result[Target], sc.scalar(2 * 6.0 / 6.0))
assert _make_static_a.call_count == 1

# Consistency check: Run the original workflow with the same inputs, all at once
orig_workflow[DynamicA] = sc.scalar(1 + 2 + 3)
orig_workflow[DynamicB] = sc.scalar(6)
expected = orig_workflow.compute(Target)
assert sc.identical(expected, result[Target])


def test_StreamProcessor_without_bypass_raises() -> None:
def _make_static_a() -> StaticA:
_make_static_a.call_count += 1
return StaticA(2.0)

_make_static_a.call_count = 0

base_workflow = sciline.Pipeline(
(_make_static_a, make_accum_a, make_accum_b, make_target)
)

streaming_wf = streaming.StreamProcessor(
base_workflow=base_workflow,
dynamic_keys=(DynamicA, DynamicB),
target_keys=(Target,),
accumulators=(AccumA,), # Note: No AccumB
)
streaming_wf.accumulate({DynamicA: 1, DynamicB: 4})
# Sciline passes `None` to the provider that needs AccumB.
with pytest.raises(TypeError, match='unsupported operand type'):
_ = streaming_wf.finalize()


def test_StreamProcessor_calls_providers_after_accumulators_only_when_finalizing() -> (
None
):
def _make_target(accum_a: AccumA, accum_b: AccumB) -> Target:
_make_target.call_count += 1
return Target(accum_a / accum_b)

_make_target.call_count = 0

base_workflow = sciline.Pipeline(
(make_accum_a, make_accum_b, _make_target), params={StaticA: 2.0}
)

streaming_wf = streaming.StreamProcessor(
base_workflow=base_workflow,
dynamic_keys=(DynamicA, DynamicB),
target_keys=(Target,),
accumulators=(AccumA, AccumB),
)
streaming_wf.accumulate({DynamicA: sc.scalar(1), DynamicB: sc.scalar(4)})
streaming_wf.accumulate({DynamicA: sc.scalar(2), DynamicB: sc.scalar(5)})
assert _make_target.call_count == 0
result = streaming_wf.finalize()
assert sc.identical(result[Target], sc.scalar(2 * 3.0 / 9.0))
assert _make_target.call_count == 1
streaming_wf.accumulate({DynamicA: sc.scalar(3), DynamicB: sc.scalar(6)})
assert _make_target.call_count == 1
result = streaming_wf.finalize()
assert sc.identical(result[Target], sc.scalar(2 * 6.0 / 15.0))
assert _make_target.call_count == 2
result = streaming_wf.finalize()
assert sc.identical(result[Target], sc.scalar(2 * 6.0 / 15.0))
# Outputs are not cached.
assert _make_target.call_count == 3