Skip to content

Commit

Permalink
Merge pull request #157 from scipp/prune-type-vars-v2
Browse files Browse the repository at this point in the history
Prune type vars v2
  • Loading branch information
jl-wynen authored Dec 17, 2024
2 parents f48e0cd + a7ecb43 commit 34cb6f4
Show file tree
Hide file tree
Showing 3 changed files with 148 additions and 32 deletions.
27 changes: 25 additions & 2 deletions src/ess/reduce/nexus/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,17 @@ class TransmissionRun(Generic[ScatteringRunType]):
TransmissionRun[BackgroundRun],
VanadiumRun,
)
"""TypeVar used for specifying BackgroundRun, EmptyBeamRun or SampleRun"""
"""TypeVar for specifying what run some data belongs to.
Possible values:
- :class:`BackgroundRun`
- :class:`EmptyBeamRun`
- :class:`SampleRun`
- :class:`TransmissionRun`
- :class:`VanadiumRun`
"""


# 1.2 Monitor types
Monitor1 = NewType('Monitor1', int)
Expand Down Expand Up @@ -108,7 +118,20 @@ class TransmissionRun(Generic[ScatteringRunType]):
IncidentMonitor,
TransmissionMonitor,
)
"""TypeVar used for specifying the monitor type such as Incident or Transmission"""
"""TypeVar for specifying what monitor some data belongs to.
Possible values:
- :class:`Monitor1`
- :class:`Monitor2`
- :class:`Monitor3`
- :class:`Monitor4`
- :class:`Monitor5`
- :class:`Monitor6`
- :class:`IncidentMonitor`
- :class:`TransmissionMonitor`
"""


Component = TypeVar(
'Component',
Expand Down
71 changes: 49 additions & 22 deletions src/ess/reduce/nexus/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@

"""Workflow and workflow components for interacting with NeXus files."""

from collections.abc import Sequence
from collections.abc import Iterable
from copy import deepcopy
from typing import Any

import networkx as nx
import sciline
import sciline.typing
import scipp as sc
import scippnexus as snx
from scipp.constants import g
Expand Down Expand Up @@ -649,30 +649,39 @@ def LoadDetectorWorkflow() -> sciline.Pipeline:

def GenericNeXusWorkflow(
*,
run_types: Sequence[sciline.typing.Key] | None = None,
monitor_types: Sequence[sciline.typing.Key] | None = None,
run_types: Iterable[sciline.typing.Key] | None = None,
monitor_types: Iterable[sciline.typing.Key] | None = None,
) -> sciline.Pipeline:
"""
Generic workflow for loading detector and monitor data from a NeXus file.
It is possible to limit which run types and monitor types
are supported by the returned workflow.
This is useful to reduce the size of the workflow and make it easier to inspect.
Make sure to add *all* required run types and monitor types when using this feature.
Attention
---------
Filtering by run type and monitor type does not work with nested type vars.
E.g., if you have a type like ``Outer[Inner[RunType]]``, this type and its
provider will be removed.
Parameters
----------
run_types:
List of run types to include in the workflow. If not provided, all run types
are included. It is recommended to specify run types to avoid creating very
large workflows.
are included.
Must be a possible value of :class:`ess.reduce.nexus.types.RunType`.
monitor_types:
List of monitor types to include in the workflow. If not provided, all monitor
types are included. It is recommended to specify monitor types to avoid creating
very large workflows.
types are included.
Must be a possible value of :class:`ess.reduce.nexus.types.MonitorType`.
Returns
-------
:
The workflow.
"""
if monitor_types is not None and run_types is None:
raise ValueError("run_types must be specified if monitor_types is specified")
wf = sciline.Pipeline(
(
*_common_providers,
Expand All @@ -685,16 +694,34 @@ def GenericNeXusWorkflow(
wf[DetectorBankSizes] = DetectorBankSizes({})
wf[PreopenNeXusFile] = PreopenNeXusFile(False)

g = wf.underlying_graph
ancestors = set()
# DetectorData and MonitorData are the "final" outputs, so finding and removing all
# their ancestors is what we need to strip unused run and monitor types.
for rt in run_types or ():
ancestors |= nx.ancestors(g, DetectorData[rt])
ancestors.add(DetectorData[rt])
for mt in monitor_types or ():
ancestors |= nx.ancestors(g, MonitorData[rt, mt])
ancestors.add(MonitorData[rt, mt])
if run_types is not None:
g.remove_nodes_from(set(g.nodes) - ancestors)
if run_types is not None or monitor_types is not None:
_prune_type_vars(wf, run_types=run_types, monitor_types=monitor_types)

return wf


def _prune_type_vars(
workflow: sciline.Pipeline,
*,
run_types: Iterable[sciline.typing.Key] | None,
monitor_types: Iterable[sciline.typing.Key] | None,
) -> None:
# Remove all nodes that use a run type or monitor types that is
# not listed in the function arguments.
excluded_run_types = _excluded_type_args(RunType, run_types)
excluded_monitor_types = _excluded_type_args(MonitorType, monitor_types)
excluded_types = excluded_run_types | excluded_monitor_types

graph = workflow.underlying_graph
to_remove = [
node for node in graph if excluded_types & set(getattr(node, "__args__", set()))
]
graph.remove_nodes_from(to_remove)


def _excluded_type_args(
type_var: Any, keep: Iterable[sciline.typing.Key] | None
) -> set[sciline.typing.Key]:
if keep is None:
return set()
return set(type_var.__constraints__) - set(keep)
82 changes: 74 additions & 8 deletions tests/nexus/workflow_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,20 @@
BackgroundRun,
Choppers,
DetectorData,
EmptyBeamRun,
Filename,
Monitor1,
Monitor2,
Monitor3,
MonitorData,
MonitorType,
NeXusComponentLocationSpec,
NeXusName,
NeXusTransformation,
RunType,
SampleRun,
TimeInterval,
TransmissionMonitor,
)
from ess.reduce.nexus.workflow import (
GenericNeXusWorkflow,
Expand Down Expand Up @@ -574,16 +578,11 @@ def test_generic_nexus_workflow_load_analyzers() -> None:
assert analyzer['usage'] == 'Bragg'


def test_generic_nexus_workflow_raises_if_monitor_types_but_not_run_types_given() -> (
None
):
with pytest.raises(ValueError, match='run_types'):
GenericNeXusWorkflow(monitor_types=[Monitor1])


def test_generic_nexus_workflow_includes_only_given_run_and_monitor_types() -> None:
wf = GenericNeXusWorkflow(run_types=[SampleRun], monitor_types=[Monitor1, Monitor3])
graph = wf.underlying_graph

# Check some examples to avoid relying entirely on complicated loops below.
assert DetectorData[SampleRun] in graph
assert DetectorData[BackgroundRun] not in graph
assert MonitorData[SampleRun, Monitor1] in graph
Expand All @@ -592,7 +591,11 @@ def test_generic_nexus_workflow_includes_only_given_run_and_monitor_types() -> N
assert MonitorData[BackgroundRun, Monitor1] not in graph
assert MonitorData[BackgroundRun, Monitor2] not in graph
assert MonitorData[BackgroundRun, Monitor3] not in graph
# Many other keys are also removed, this is just an example
assert Choppers[SampleRun] in graph
assert Choppers[BackgroundRun] not in graph
assert Analyzers[SampleRun] in graph
assert Analyzers[BackgroundRun] not in graph

assert NeXusComponentLocationSpec[Monitor1, SampleRun] in graph
assert NeXusComponentLocationSpec[Monitor2, SampleRun] not in graph
assert NeXusComponentLocationSpec[Monitor3, SampleRun] in graph
Expand All @@ -605,3 +608,66 @@ def test_generic_nexus_workflow_includes_only_given_run_and_monitor_types() -> N
assert NeXusComponentLocationSpec[snx.NXdetector, BackgroundRun] not in graph
assert NeXusComponentLocationSpec[snx.NXsample, BackgroundRun] not in graph
assert NeXusComponentLocationSpec[snx.NXsource, BackgroundRun] not in graph

excluded_run_types = set(RunType.__constraints__) - {SampleRun}
excluded_monitor_types = set(MonitorType.__constraints__) - {Monitor1, Monitor3}
for node in graph:
assert_not_contains_type_arg(node, excluded_run_types)
assert_not_contains_type_arg(node, excluded_monitor_types)


def test_generic_nexus_workflow_includes_only_given_run_types() -> None:
wf = GenericNeXusWorkflow(run_types=[EmptyBeamRun])
graph = wf.underlying_graph

# Check some examples to avoid relying entirely on complicated loops below.
assert DetectorData[EmptyBeamRun] in graph
assert DetectorData[SampleRun] not in graph
assert MonitorData[EmptyBeamRun, Monitor1] in graph
assert MonitorData[EmptyBeamRun, Monitor2] in graph
assert MonitorData[EmptyBeamRun, Monitor3] in graph
assert MonitorData[SampleRun, Monitor1] not in graph
assert MonitorData[SampleRun, Monitor2] not in graph
assert MonitorData[SampleRun, Monitor3] not in graph
assert Choppers[EmptyBeamRun] in graph
assert Choppers[SampleRun] not in graph
assert Analyzers[EmptyBeamRun] in graph
assert Analyzers[SampleRun] not in graph

excluded_run_types = set(RunType.__constraints__) - {EmptyBeamRun}
for node in graph:
assert_not_contains_type_arg(node, excluded_run_types)


def test_generic_nexus_workflow_includes_only_given_monitor_types() -> None:
wf = GenericNeXusWorkflow(monitor_types=[TransmissionMonitor, Monitor1])
graph = wf.underlying_graph

# Check some examples to avoid relying entirely on complicated loops below.
assert DetectorData[SampleRun] in graph
assert DetectorData[BackgroundRun] in graph
assert MonitorData[SampleRun, TransmissionMonitor] in graph
assert MonitorData[SampleRun, Monitor1] in graph
assert MonitorData[SampleRun, Monitor2] not in graph
assert MonitorData[SampleRun, Monitor3] not in graph
assert MonitorData[BackgroundRun, TransmissionMonitor] in graph
assert MonitorData[BackgroundRun, Monitor1] in graph
assert MonitorData[BackgroundRun, Monitor2] not in graph
assert MonitorData[BackgroundRun, Monitor3] not in graph
assert Choppers[SampleRun] in graph
assert Choppers[BackgroundRun] in graph
assert Analyzers[SampleRun] in graph
assert Analyzers[BackgroundRun] in graph

excluded_monitor_types = set(MonitorType.__constraints__) - {
Monitor1,
TransmissionMonitor,
}
for node in graph:
assert_not_contains_type_arg(node, excluded_monitor_types)


def assert_not_contains_type_arg(node: object, excluded: set[type]) -> None:
assert not any(
arg in excluded for arg in getattr(node, "__args__", ())
), f"Node {node} contains one of {excluded!r}"

0 comments on commit 34cb6f4

Please sign in to comment.