Tests & small fixes
SimonHeybrock committed Feb 4, 2025
1 parent 4d91258 · commit 8abd87d
Showing 2 changed files with 116 additions and 21 deletions.
37 changes: 16 additions & 21 deletions src/ess/reduce/streaming.py
@@ -4,6 +4,7 @@

from abc import ABC, abstractmethod
from collections.abc import Callable
+from copy import deepcopy
from typing import Any, Generic, TypeVar

import networkx as nx
@@ -29,6 +30,8 @@ def maybe_hist(value: T) -> T:
    :
        Histogram.
    """
+    if not isinstance(value, sc.Variable | sc.DataArray):
+        return value
    return value if value.bins is None else value.hist()
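For illustration, a rough sketch of what the new guard changes, assuming scipp's binning API and that maybe_hist is importable from ess.reduce.streaming (this snippet is not part of the commit):

import scipp as sc

from ess.reduce.streaming import maybe_hist

# Binned (event) data is converted to a histogram ...
table = sc.DataArray(
    sc.ones(dims=['event'], shape=[4]),
    coords={'x': sc.linspace('event', 0.0, 1.0, num=4)},
)
assert maybe_hist(table.bin(x=2)).bins is None
# ... while non-scipp values now pass through unchanged thanks to the new guard.
assert maybe_hist(3.5) == 3.5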


@@ -90,11 +93,11 @@ def __init__(self, **kwargs: Any) -> None:

    @property
    def value(self) -> T:
-        return self._value.copy()
+        return deepcopy(self._value)

    def _do_push(self, value: T) -> None:
        if self._value is None:
-            self._value = value.copy()
+            self._value = deepcopy(value)
        else:
            self._value += value
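The switch from .copy() to deepcopy presumably keeps the defensive-copy behaviour of value while also covering accumulated values that lack a .copy() method. A small sketch, assuming the Accumulator base class exposes a public push() that forwards to _do_push (that method is not shown in this diff):

import scipp as sc

from ess.reduce.streaming import EternalAccumulator

acc = EternalAccumulator()
acc.push(sc.scalar(1.0))  # assumed public wrapper around _do_push
acc.push(sc.scalar(2.0))
snapshot = acc.value             # deepcopy: a detached copy of the internal state
snapshot += sc.scalar(10.0)      # mutating the copy ...
assert sc.identical(acc.value, sc.scalar(3.0))  # ... leaves the accumulator untouched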

@@ -146,6 +149,7 @@ def __init__(
        target_keys: tuple[sciline.typing.Key, ...],
        accumulators: dict[sciline.typing.Key, Accumulator, Callable[..., Accumulator]]
        | tuple[sciline.typing.Key, ...],
+        allow_bypass: bool = False,
    ) -> None:
        """
        Create a stream processor.
@@ -163,6 +167,12 @@ def __init__(
            passed, :py:class:`EternalAccumulator` is used for all keys. Otherwise, a
            dict mapping keys to accumulator instances can be passed. If a dict value is
            a callable, base_workflow.bind_and_call(value) is used to make an instance.
+        allow_bypass:
+            If True, allow bypassing accumulators for keys that are not in the
+            accumulators dict. This is useful for dynamic keys that are not "terminated"
+            in any accumulator. USE WITH CARE! This will lead to incorrect results
+            unless the values for these keys are valid for all chunks comprised in the
+            final accumulators at the point where :py:meth:`finalize` is called.
        """
        workflow = sciline.Pipeline()
        for key in target_keys:
@@ -201,6 +211,7 @@ def __init__(
            for key, value in self._accumulators.items()
        }
        self._target_keys = target_keys
+        self._allow_bypass = allow_bypass
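To make the accumulator handling concrete, here is a minimal, self-contained usage sketch modelled on the tests further down; the Dynamic/Accum/Target keys and the providers are illustrative, not part of this commit:

from typing import NewType

import sciline
import scipp as sc

from ess.reduce.streaming import EternalAccumulator, StreamProcessor

Dynamic = NewType('Dynamic', sc.Variable)
Accum = NewType('Accum', sc.Variable)
Target = NewType('Target', sc.Variable)


def make_accum(x: Dynamic) -> Accum:
    return Accum(x)


def make_target(a: Accum) -> Target:
    return Target(2 * a)


wf = sciline.Pipeline((make_accum, make_target))
processor = StreamProcessor(
    base_workflow=wf,
    dynamic_keys=(Dynamic,),
    target_keys=(Target,),
    # A dict value may be an Accumulator instance or a callable; a callable is
    # instantiated via base_workflow.bind_and_call(...), as documented above.
    accumulators={Accum: EternalAccumulator()},
    allow_bypass=False,  # the default; see the docstring warning above
)
processor.accumulate({Dynamic: sc.scalar(1.0)})
processor.accumulate({Dynamic: sc.scalar(2.0)})
result = processor.finalize()
assert sc.identical(result[Target], sc.scalar(2 * 3.0))  # accumulated sum is 3.0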

    def add_chunk(
        self, chunks: dict[sciline.typing.Key, Any]
@@ -210,14 +221,6 @@ def add_chunk(
        It is recommended to use :py:meth:`accumulate` and :py:meth:`finalize` instead.
-        Warning
-        -------
-        This method automatically bypasses accumulators for keys that are not in the
-        accumulators dict. This is useful for dynamic keys that are not "terminated" in
-        any accumulator. USE WITH CARE! This will lead to incorrect results unless the
-        values for these keys are valid for all chunks comprised in the final result.
        Parameters
        ----------
        chunks:
@@ -228,32 +231,24 @@ def add_chunk(
        :
            Finalized result.
        """
-        self.accumulate(chunks, allow_bypass=True)
+        self.accumulate(chunks)
        return self.finalize()
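Continuing the sketch above: a single add_chunk call is accumulate followed by finalize, so the one-shot form should give the same result as the recommended two-step form:

# Two-step form recommended in the docstring:
#     processor.accumulate({Dynamic: sc.scalar(3.0)})
#     result = processor.finalize()
result = processor.add_chunk({Dynamic: sc.scalar(3.0)})
assert sc.identical(result[Target], sc.scalar(2 * (3.0 + 3.0)))  # running sum is now 6.0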

-    def accumulate(
-        self, chunks: dict[sciline.typing.Key, Any], allow_bypass: bool = False
-    ) -> None:
+    def accumulate(self, chunks: dict[sciline.typing.Key, Any]) -> None:
        """
        Accumulate values from chunks without finalizing the result.
        Parameters
        ----------
        chunks:
            Chunks to be processed.
-        allow_bypass:
-            If True, allow bypassing accumulators for keys that are not in the
-            accumulators dict. This is useful for dynamic keys that are not "terminated"
-            in any accumulator. USE WITH CARE! This will lead to incorrect results
-            unless the values for these keys are valid for all chunks comprised in the
-            final result.
        """
        for key, value in chunks.items():
            self._process_chunk_workflow[key] = value
            # There can be dynamic keys that do not "terminate" in any accumulator. In
            # that case, we need to make sure they can be and are used when computing
            # the target keys.
-            if allow_bypass:
+            if self._allow_bypass:
                self._finalize_workflow[key] = value
        to_accumulate = self._process_chunk_workflow.compute(self._accumulators)
        for key, processed in to_accumulate.items():
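Regarding the bypass branch above: with allow_bypass=True, a dynamic key that has no accumulator is written directly into the finalize workflow, so only its most recent chunk value enters the result. A tiny continuation of the earlier sketch (illustrative; test_StreamProcessor_with_bypass below exercises the same behaviour on the real test fixtures):

# Hypothetical second dynamic key that is not accumulated.
Norm = NewType('Norm', sc.Variable)


def make_normalized(a: Accum, n: Norm) -> Target:
    return Target(a / n)


proc = StreamProcessor(
    base_workflow=sciline.Pipeline((make_accum, make_normalized)),
    dynamic_keys=(Dynamic, Norm),
    target_keys=(Target,),
    accumulators=(Accum,),  # Norm is not accumulated ...
    allow_bypass=True,      # ... so it must be allowed to bypass
)
proc.accumulate({Dynamic: sc.scalar(1.0), Norm: sc.scalar(4.0)})
proc.accumulate({Dynamic: sc.scalar(2.0), Norm: sc.scalar(4.0)})
# Accum is the running sum (3.0); Norm is whatever the last chunk provided (4.0).
assert sc.identical(proc.finalize()[Target], sc.scalar(3.0 / 4.0))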
100 changes: 100 additions & 0 deletions tests/streaming_test.py
@@ -3,6 +3,7 @@

from typing import NewType

import pytest
import sciline
import scipp as sc

@@ -214,6 +215,7 @@ def test_StreamProcess_with_zero_accumulators_for_buffered_workflow_calls() -> None:
        dynamic_keys=(DynamicA, DynamicB),
        target_keys=(Target,),
        accumulators=(),
        allow_bypass=True,
    )
    result = streaming_wf.add_chunk({DynamicA: sc.scalar(1), DynamicB: sc.scalar(4)})
    assert sc.identical(result[Target], sc.scalar(2 * 1.0 / 4.0))
@@ -222,3 +224,101 @@ def test_StreamProcess_with_zero_accumulators_for_buffered_workflow_calls() -> None:
    result = streaming_wf.add_chunk({DynamicA: sc.scalar(3), DynamicB: sc.scalar(6)})
    assert sc.identical(result[Target], sc.scalar(2 * 3.0 / 6.0))
    assert make_static_a.call_count == 1


def test_StreamProcessor_with_bypass() -> None:
    def _make_static_a() -> StaticA:
        _make_static_a.call_count += 1
        return StaticA(2.0)

    _make_static_a.call_count = 0

    base_workflow = sciline.Pipeline(
        (_make_static_a, make_accum_a, make_accum_b, make_target)
    )
    orig_workflow = base_workflow.copy()

    streaming_wf = streaming.StreamProcessor(
        base_workflow=base_workflow,
        dynamic_keys=(DynamicA, DynamicB),
        target_keys=(Target,),
        accumulators=(AccumA,),  # Note: No AccumB
        allow_bypass=True,
    )
    streaming_wf.accumulate({DynamicA: sc.scalar(1), DynamicB: sc.scalar(4)})
    result = streaming_wf.finalize()
    assert sc.identical(result[Target], sc.scalar(2 * 1.0 / 4.0))
    streaming_wf.accumulate({DynamicA: sc.scalar(2), DynamicB: sc.scalar(5)})
    result = streaming_wf.finalize()
    # Note denominator is 5, not 9
    assert sc.identical(result[Target], sc.scalar(2 * 3.0 / 5.0))
    streaming_wf.accumulate({DynamicA: sc.scalar(3), DynamicB: sc.scalar(6)})
    result = streaming_wf.finalize()
    # Note denominator is 6, not 15
    assert sc.identical(result[Target], sc.scalar(2 * 6.0 / 6.0))
    assert _make_static_a.call_count == 1

    # Consistency check: Run the original workflow with the same inputs, all at once
    orig_workflow[DynamicA] = sc.scalar(1 + 2 + 3)
    orig_workflow[DynamicB] = sc.scalar(6)
    expected = orig_workflow.compute(Target)
    assert sc.identical(expected, result[Target])


def test_StreamProcessor_without_bypass_raises() -> None:
    def _make_static_a() -> StaticA:
        _make_static_a.call_count += 1
        return StaticA(2.0)

    _make_static_a.call_count = 0

    base_workflow = sciline.Pipeline(
        (_make_static_a, make_accum_a, make_accum_b, make_target)
    )

    streaming_wf = streaming.StreamProcessor(
        base_workflow=base_workflow,
        dynamic_keys=(DynamicA, DynamicB),
        target_keys=(Target,),
        accumulators=(AccumA,),  # Note: No AccumB
    )
    streaming_wf.accumulate({DynamicA: 1, DynamicB: 4})
    # Sciline passes `None` to the provider that needs AccumB.
    with pytest.raises(TypeError, match='unsupported operand type'):
        _ = streaming_wf.finalize()


def test_StreamProcessor_calls_providers_after_accumulators_only_when_finalizing() -> (
    None
):
    def _make_target(accum_a: AccumA, accum_b: AccumB) -> Target:
        _make_target.call_count += 1
        return Target(accum_a / accum_b)

    _make_target.call_count = 0

    base_workflow = sciline.Pipeline(
        (make_accum_a, make_accum_b, _make_target), params={StaticA: 2.0}
    )

    streaming_wf = streaming.StreamProcessor(
        base_workflow=base_workflow,
        dynamic_keys=(DynamicA, DynamicB),
        target_keys=(Target,),
        accumulators=(AccumA, AccumB),
    )
    streaming_wf.accumulate({DynamicA: sc.scalar(1), DynamicB: sc.scalar(4)})
    streaming_wf.accumulate({DynamicA: sc.scalar(2), DynamicB: sc.scalar(5)})
    assert _make_target.call_count == 0
    result = streaming_wf.finalize()
    assert sc.identical(result[Target], sc.scalar(2 * 3.0 / 9.0))
    assert _make_target.call_count == 1
    streaming_wf.accumulate({DynamicA: sc.scalar(3), DynamicB: sc.scalar(6)})
    assert _make_target.call_count == 1
    result = streaming_wf.finalize()
    assert sc.identical(result[Target], sc.scalar(2 * 6.0 / 15.0))
    assert _make_target.call_count == 2
    result = streaming_wf.finalize()
    assert sc.identical(result[Target], sc.scalar(2 * 6.0 / 15.0))
    # Outputs are not cached.
    assert _make_target.call_count == 3
