Skip to content

Commit 34cb6f4

Browse files
authored
Merge pull request #157 from scipp/prune-type-vars-v2
Prune type vars v2
2 parents f48e0cd + a7ecb43 commit 34cb6f4

File tree

3 files changed

+148
-32
lines changed

3 files changed

+148
-32
lines changed

src/ess/reduce/nexus/types.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,17 @@ class TransmissionRun(Generic[ScatteringRunType]):
7878
TransmissionRun[BackgroundRun],
7979
VanadiumRun,
8080
)
81-
"""TypeVar used for specifying BackgroundRun, EmptyBeamRun or SampleRun"""
81+
"""TypeVar for specifying what run some data belongs to.
82+
83+
Possible values:
84+
85+
- :class:`BackgroundRun`
86+
- :class:`EmptyBeamRun`
87+
- :class:`SampleRun`
88+
- :class:`TransmissionRun`
89+
- :class:`VanadiumRun`
90+
"""
91+
8292

8393
# 1.2 Monitor types
8494
Monitor1 = NewType('Monitor1', int)
@@ -108,7 +118,20 @@ class TransmissionRun(Generic[ScatteringRunType]):
108118
IncidentMonitor,
109119
TransmissionMonitor,
110120
)
111-
"""TypeVar used for specifying the monitor type such as Incident or Transmission"""
121+
"""TypeVar for specifying what monitor some data belongs to.
122+
123+
Possible values:
124+
125+
- :class:`Monitor1`
126+
- :class:`Monitor2`
127+
- :class:`Monitor3`
128+
- :class:`Monitor4`
129+
- :class:`Monitor5`
130+
- :class:`Monitor6`
131+
- :class:`IncidentMonitor`
132+
- :class:`TransmissionMonitor`
133+
"""
134+
112135

113136
Component = TypeVar(
114137
'Component',

src/ess/reduce/nexus/workflow.py

Lines changed: 49 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,12 @@
33

44
"""Workflow and workflow components for interacting with NeXus files."""
55

6-
from collections.abc import Sequence
6+
from collections.abc import Iterable
77
from copy import deepcopy
88
from typing import Any
99

10-
import networkx as nx
1110
import sciline
11+
import sciline.typing
1212
import scipp as sc
1313
import scippnexus as snx
1414
from scipp.constants import g
@@ -649,30 +649,39 @@ def LoadDetectorWorkflow() -> sciline.Pipeline:
649649

650650
def GenericNeXusWorkflow(
651651
*,
652-
run_types: Sequence[sciline.typing.Key] | None = None,
653-
monitor_types: Sequence[sciline.typing.Key] | None = None,
652+
run_types: Iterable[sciline.typing.Key] | None = None,
653+
monitor_types: Iterable[sciline.typing.Key] | None = None,
654654
) -> sciline.Pipeline:
655655
"""
656656
Generic workflow for loading detector and monitor data from a NeXus file.
657657
658+
It is possible to limit which run types and monitor types
659+
are supported by the returned workflow.
660+
This is useful to reduce the size of the workflow and make it easier to inspect.
661+
Make sure to add *all* required run types and monitor types when using this feature.
662+
663+
Attention
664+
---------
665+
Filtering by run type and monitor type does not work with nested type vars.
666+
E.g., if you have a type like ``Outer[Inner[RunType]]``, this type and its
667+
provider will be removed.
668+
658669
Parameters
659670
----------
660671
run_types:
661672
List of run types to include in the workflow. If not provided, all run types
662-
are included. It is recommended to specify run types to avoid creating very
663-
large workflows.
673+
are included.
674+
Each must be a possible value of :class:`ess.reduce.nexus.types.RunType`.
664675
monitor_types:
665676
List of monitor types to include in the workflow. If not provided, all monitor
666-
types are included. It is recommended to specify monitor types to avoid creating
667-
very large workflows.
677+
types are included.
678+
Each must be a possible value of :class:`ess.reduce.nexus.types.MonitorType`.
668679
669680
Returns
670681
-------
671682
:
672683
The workflow.
673684
"""
674-
if monitor_types is not None and run_types is None:
675-
raise ValueError("run_types must be specified if monitor_types is specified")
676685
wf = sciline.Pipeline(
677686
(
678687
*_common_providers,
@@ -685,16 +694,34 @@ def GenericNeXusWorkflow(
685694
wf[DetectorBankSizes] = DetectorBankSizes({})
686695
wf[PreopenNeXusFile] = PreopenNeXusFile(False)
687696

688-
g = wf.underlying_graph
689-
ancestors = set()
690-
# DetectorData and MonitorData are the "final" outputs, so finding and removing all
691-
# their ancestors is what we need to strip unused run and monitor types.
692-
for rt in run_types or ():
693-
ancestors |= nx.ancestors(g, DetectorData[rt])
694-
ancestors.add(DetectorData[rt])
695-
for mt in monitor_types or ():
696-
ancestors |= nx.ancestors(g, MonitorData[rt, mt])
697-
ancestors.add(MonitorData[rt, mt])
698-
if run_types is not None:
699-
g.remove_nodes_from(set(g.nodes) - ancestors)
697+
if run_types is not None or monitor_types is not None:
698+
_prune_type_vars(wf, run_types=run_types, monitor_types=monitor_types)
699+
700700
return wf
701+
702+
703+
def _prune_type_vars(
704+
workflow: sciline.Pipeline,
705+
*,
706+
run_types: Iterable[sciline.typing.Key] | None,
707+
monitor_types: Iterable[sciline.typing.Key] | None,
708+
) -> None:
709+
# Remove all nodes that use a run type or monitor type that is
710+
# not listed in the function arguments.
711+
excluded_run_types = _excluded_type_args(RunType, run_types)
712+
excluded_monitor_types = _excluded_type_args(MonitorType, monitor_types)
713+
excluded_types = excluded_run_types | excluded_monitor_types
714+
715+
graph = workflow.underlying_graph
716+
to_remove = [
717+
node for node in graph if excluded_types & set(getattr(node, "__args__", set()))
718+
]
719+
graph.remove_nodes_from(to_remove)
720+
721+
722+
def _excluded_type_args(
723+
type_var: Any, keep: Iterable[sciline.typing.Key] | None
724+
) -> set[sciline.typing.Key]:
725+
if keep is None:
726+
return set()
727+
return set(type_var.__constraints__) - set(keep)

tests/nexus/workflow_test.py

Lines changed: 74 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,20 @@
1212
BackgroundRun,
1313
Choppers,
1414
DetectorData,
15+
EmptyBeamRun,
1516
Filename,
1617
Monitor1,
1718
Monitor2,
1819
Monitor3,
1920
MonitorData,
21+
MonitorType,
2022
NeXusComponentLocationSpec,
2123
NeXusName,
2224
NeXusTransformation,
25+
RunType,
2326
SampleRun,
2427
TimeInterval,
28+
TransmissionMonitor,
2529
)
2630
from ess.reduce.nexus.workflow import (
2731
GenericNeXusWorkflow,
@@ -574,16 +578,11 @@ def test_generic_nexus_workflow_load_analyzers() -> None:
574578
assert analyzer['usage'] == 'Bragg'
575579

576580

577-
def test_generic_nexus_workflow_raises_if_monitor_types_but_not_run_types_given() -> (
578-
None
579-
):
580-
with pytest.raises(ValueError, match='run_types'):
581-
GenericNeXusWorkflow(monitor_types=[Monitor1])
582-
583-
584581
def test_generic_nexus_workflow_includes_only_given_run_and_monitor_types() -> None:
585582
wf = GenericNeXusWorkflow(run_types=[SampleRun], monitor_types=[Monitor1, Monitor3])
586583
graph = wf.underlying_graph
584+
585+
# Check some examples to avoid relying entirely on complicated loops below.
587586
assert DetectorData[SampleRun] in graph
588587
assert DetectorData[BackgroundRun] not in graph
589588
assert MonitorData[SampleRun, Monitor1] in graph
@@ -592,7 +591,11 @@ def test_generic_nexus_workflow_includes_only_given_run_and_monitor_types() -> N
592591
assert MonitorData[BackgroundRun, Monitor1] not in graph
593592
assert MonitorData[BackgroundRun, Monitor2] not in graph
594593
assert MonitorData[BackgroundRun, Monitor3] not in graph
595-
# Many other keys are also removed, this is just an example
594+
assert Choppers[SampleRun] in graph
595+
assert Choppers[BackgroundRun] not in graph
596+
assert Analyzers[SampleRun] in graph
597+
assert Analyzers[BackgroundRun] not in graph
598+
596599
assert NeXusComponentLocationSpec[Monitor1, SampleRun] in graph
597600
assert NeXusComponentLocationSpec[Monitor2, SampleRun] not in graph
598601
assert NeXusComponentLocationSpec[Monitor3, SampleRun] in graph
@@ -605,3 +608,66 @@ def test_generic_nexus_workflow_includes_only_given_run_and_monitor_types() -> N
605608
assert NeXusComponentLocationSpec[snx.NXdetector, BackgroundRun] not in graph
606609
assert NeXusComponentLocationSpec[snx.NXsample, BackgroundRun] not in graph
607610
assert NeXusComponentLocationSpec[snx.NXsource, BackgroundRun] not in graph
611+
612+
excluded_run_types = set(RunType.__constraints__) - {SampleRun}
613+
excluded_monitor_types = set(MonitorType.__constraints__) - {Monitor1, Monitor3}
614+
for node in graph:
615+
assert_not_contains_type_arg(node, excluded_run_types)
616+
assert_not_contains_type_arg(node, excluded_monitor_types)
617+
618+
619+
def test_generic_nexus_workflow_includes_only_given_run_types() -> None:
620+
wf = GenericNeXusWorkflow(run_types=[EmptyBeamRun])
621+
graph = wf.underlying_graph
622+
623+
# Check some examples to avoid relying entirely on complicated loops below.
624+
assert DetectorData[EmptyBeamRun] in graph
625+
assert DetectorData[SampleRun] not in graph
626+
assert MonitorData[EmptyBeamRun, Monitor1] in graph
627+
assert MonitorData[EmptyBeamRun, Monitor2] in graph
628+
assert MonitorData[EmptyBeamRun, Monitor3] in graph
629+
assert MonitorData[SampleRun, Monitor1] not in graph
630+
assert MonitorData[SampleRun, Monitor2] not in graph
631+
assert MonitorData[SampleRun, Monitor3] not in graph
632+
assert Choppers[EmptyBeamRun] in graph
633+
assert Choppers[SampleRun] not in graph
634+
assert Analyzers[EmptyBeamRun] in graph
635+
assert Analyzers[SampleRun] not in graph
636+
637+
excluded_run_types = set(RunType.__constraints__) - {EmptyBeamRun}
638+
for node in graph:
639+
assert_not_contains_type_arg(node, excluded_run_types)
640+
641+
642+
def test_generic_nexus_workflow_includes_only_given_monitor_types() -> None:
643+
wf = GenericNeXusWorkflow(monitor_types=[TransmissionMonitor, Monitor1])
644+
graph = wf.underlying_graph
645+
646+
# Check some examples to avoid relying entirely on complicated loops below.
647+
assert DetectorData[SampleRun] in graph
648+
assert DetectorData[BackgroundRun] in graph
649+
assert MonitorData[SampleRun, TransmissionMonitor] in graph
650+
assert MonitorData[SampleRun, Monitor1] in graph
651+
assert MonitorData[SampleRun, Monitor2] not in graph
652+
assert MonitorData[SampleRun, Monitor3] not in graph
653+
assert MonitorData[BackgroundRun, TransmissionMonitor] in graph
654+
assert MonitorData[BackgroundRun, Monitor1] in graph
655+
assert MonitorData[BackgroundRun, Monitor2] not in graph
656+
assert MonitorData[BackgroundRun, Monitor3] not in graph
657+
assert Choppers[SampleRun] in graph
658+
assert Choppers[BackgroundRun] in graph
659+
assert Analyzers[SampleRun] in graph
660+
assert Analyzers[BackgroundRun] in graph
661+
662+
excluded_monitor_types = set(MonitorType.__constraints__) - {
663+
Monitor1,
664+
TransmissionMonitor,
665+
}
666+
for node in graph:
667+
assert_not_contains_type_arg(node, excluded_monitor_types)
668+
669+
670+
def assert_not_contains_type_arg(node: object, excluded: set[type]) -> None:
671+
assert not any(
672+
arg in excluded for arg in getattr(node, "__args__", ())
673+
), f"Node {node} contains one of {excluded!r}"

0 commit comments

Comments
 (0)