Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose NeXus transformation chains in workflow steps #114

Merged
merged 20 commits into from
Oct 8, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/ess/reduce/nexus/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
group_event_data,
load_component,
compute_component_position,
extract_events_or_histogram,
extract_signal_data_array,
)

__all__ = [
Expand All @@ -25,5 +25,5 @@
'load_data',
'load_component',
'compute_component_position',
'extract_events_or_histogram',
'extract_signal_data_array',
]
9 changes: 7 additions & 2 deletions src/ess/reduce/nexus/_nexus_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,15 @@ def load_component(
component = _unique_child_group(instrument, nx_class, group_name)
loaded = cast(sc.DataGroup, component[selection])
loaded['nexus_component_name'] = component.name.split('/')[-1]
return compute_component_position(loaded)
return loaded
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the heart of the change in this PR: We are no longer computing the position directly here. This is moved into an extra step.



def compute_component_position(dg: sc.DataGroup) -> sc.DataGroup:
# In some downstream packages we use some of the Nexus components which attempt
# to compute positions without having actual Nexus data defining depends_on chains.
# We assume positions have been set in the non-Nexus input somehow and return early.
if 'depends_on' not in dg:
return dg
transform_out_name = 'transform'
if transform_out_name in dg:
raise RuntimeError(
Expand Down Expand Up @@ -115,7 +120,7 @@ def _contains_nx_class(group: snx.Group, nx_class: type[snx.NXobject]) -> bool:
return False


def extract_events_or_histogram(dg: sc.DataGroup) -> sc.DataArray:
def extract_signal_data_array(dg: sc.DataGroup) -> sc.DataArray:
event_data_arrays = {
key: value
for key, value in dg.items()
Expand Down
17 changes: 17 additions & 0 deletions src/ess/reduce/nexus/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,3 +237,20 @@ class NeXusMonitorDataLocationSpec(
NeXusComponentLocationSpec[snx.NXevent_data, RunType], Generic[RunType, MonitorType]
):
"""NeXus filename and parameters to identify (parts of) monitor data to load."""


T = TypeVar('T', bound='NeXusTransformationChain')


@dataclass
class NeXusTransformationChain(snx.TransformationChain, Generic[Component, RunType]):
@classmethod
def from_base(cls: type[T], base: snx.TransformationChain) -> T:
return cls(
parent=base.parent,
value=base.value,
transformations=base.transformations,
)

def compute_position(self) -> sc.Variable | sc.DataArray:
return self.compute() * sc.vector([0, 0, 0], unit='m')
43 changes: 34 additions & 9 deletions src/ess/reduce/nexus/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
NeXusMonitorName,
NeXusSample,
NeXusSource,
NeXusTransformationChain,
PreopenNeXusFile,
PulseSelection,
RunType,
Expand Down Expand Up @@ -339,16 +340,33 @@ def load_nexus_monitor_data(
)


def get_source_position(source: NeXusSource[RunType]) -> SourcePosition[RunType]:
def get_source_transformation_chain(
source: NeXusSource[RunType],
) -> NeXusTransformationChain[snx.NXsource, RunType]:
"""
Extract the source position from a NeXus source group.
Extract the transformation chain from a NeXus source group.

Parameters
----------
source:
NeXus source group.
"""
return SourcePosition[RunType](source["position"])
chain = source['depends_on']
return NeXusTransformationChain[snx.NXsource, RunType].from_base(chain)


def get_source_position(
transformations: NeXusTransformationChain[snx.NXsource, RunType],
) -> SourcePosition[RunType]:
"""
Extract the source position of a NeXus source group.

Parameters
----------
transformations:
NeXus transformation chain of the source group.
"""
return SourcePosition[RunType](transformations.compute_position())


def get_sample_position(sample: NeXusSample[RunType]) -> SamplePosition[RunType]:
Expand All @@ -362,7 +380,8 @@ def get_sample_position(sample: NeXusSample[RunType]) -> SamplePosition[RunType]
sample:
NeXus sample group.
"""
return SamplePosition[RunType](sample.get("position", origin))
dg = nexus.compute_component_position(sample)
return SamplePosition[RunType](dg.get("position", origin))


def get_calibrated_detector(
Expand Down Expand Up @@ -397,7 +416,8 @@ def get_calibrated_detector(
bank_sizes:
Dictionary of detector bank sizes.
"""
da = nexus.extract_events_or_histogram(detector)
detector = nexus.compute_component_position(detector)
da = nexus.extract_signal_data_array(detector)
if (
sizes := (bank_sizes or {}).get(detector.get('nexus_component_name'))
) is not None:
Expand Down Expand Up @@ -459,8 +479,9 @@ def get_calibrated_monitor(
source_position:
Position of the neutron source.
"""
monitor = nexus.compute_component_position(monitor)
return CalibratedMonitor[RunType, MonitorType](
nexus.extract_events_or_histogram(monitor).assign_coords(
nexus.extract_signal_data_array(monitor).assign_coords(
position=monitor['position'] + offset.to(unit=monitor['position'].unit),
source_position=source_position,
)
Expand Down Expand Up @@ -552,7 +573,13 @@ def _add_variances(da: sc.DataArray) -> sc.DataArray:
definitions["NXmonitor"] = _StrippedMonitor


_common_providers = (gravity_vector_neg_y, file_path_to_file_spec, all_pulses)
_common_providers = (
gravity_vector_neg_y,
file_path_to_file_spec,
all_pulses,
get_source_transformation_chain,
get_source_position,
)

_monitor_providers = (
no_monitor_position_offset,
Expand All @@ -562,7 +589,6 @@ def _add_variances(da: sc.DataArray) -> sc.DataArray:
load_nexus_monitor,
load_nexus_monitor_data,
load_nexus_source,
get_source_position,
get_calibrated_monitor,
assemble_monitor_data,
)
Expand All @@ -577,7 +603,6 @@ def _add_variances(da: sc.DataArray) -> sc.DataArray:
load_nexus_detector_data,
load_nexus_source,
load_nexus_sample,
get_source_position,
get_sample_position,
get_calibrated_detector,
assemble_detector_data,
Expand Down
32 changes: 22 additions & 10 deletions tests/nexus/nexus_loader_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,7 @@ def test_load_detector(nexus_file, expected_bank12, entry_name, selection):
if selection is not None:
loc.selection = selection
detector = nexus.load_component(loc, nx_class=snx.NXdetector)
detector = nexus.compute_component_position(detector)
if selection:
expected = expected_bank12.bins[selection]
expected.coords.pop(selection[0])
Expand Down Expand Up @@ -282,7 +283,8 @@ def test_load_and_group_event_data_consistent_with_load_via_detector(
)
if selection:
loc.selection = selection
detector = nexus.load_component(loc, nx_class=snx.NXdetector)['bank12_events']
detector = nexus.load_component(loc, nx_class=snx.NXdetector)
detector = nexus.compute_component_position(detector)['bank12_events']
events = nexus.load_data(
nexus_file,
selection=selection,
Expand All @@ -300,7 +302,8 @@ def test_group_event_data_does_not_modify_input(nexus_file):
filename=nexus_file,
component_name=nexus.types.NeXusDetectorName('bank12'),
)
detector = nexus.load_component(loc, nx_class=snx.NXdetector)['bank12_events']
detector = nexus.load_component(loc, nx_class=snx.NXdetector)
detector = nexus.compute_component_position(detector)['bank12_events']
events = nexus.load_data(
nexus_file,
component_name=nexus.types.NeXusDetectorName('bank12'),
Expand Down Expand Up @@ -328,7 +331,7 @@ def test_load_detector_open_file_with_new_definitions_raises(nexus_file):
nexus.load_component(loc, nx_class=snx.NXdetector, definitions={})


def test_load_detector_new_definitions_applied(nexus_file, expected_bank12):
def test_load_detector_new_definitions_applied(nexus_file):
if not isinstance(nexus_file, snx.Group):
new_definition_used = False

Expand Down Expand Up @@ -381,6 +384,7 @@ def test_load_detector_select_entry_if_not_unique(nexus_file, expected_bank12):
entry_name=nexus.types.NeXusEntryName('entry-001'),
)
detector = nexus.load_component(loc, nx_class=snx.NXdetector)
detector = nexus.compute_component_position(detector)
sc.testing.assert_identical(detector['bank12_events'], expected_bank12)


Expand All @@ -402,6 +406,7 @@ def test_load_monitor(nexus_file, expected_monitor, entry_name, selection):
if selection is not None:
loc.selection = selection
monitor = nexus.load_component(loc, nx_class=snx.NXmonitor)
monitor = nexus.compute_component_position(monitor)
expected = expected_monitor[selection] if selection else expected_monitor
sc.testing.assert_identical(monitor['data'], expected)

Expand All @@ -415,6 +420,7 @@ def test_load_source(nexus_file, expected_source, entry_name, source_name):
component_name=source_name,
)
source = nexus.load_component(loc, nx_class=snx.NXsource)
source = nexus.compute_component_position(source)
# NeXus details that we don't need to test as long as the positions are ok:
del source['depends_on']
del source['transformations']
Expand Down Expand Up @@ -460,7 +466,7 @@ def test_extract_detector_data():
' _': sc.linspace('xx', 2, 3, 10),
}
)
data = nexus.extract_events_or_histogram(detector)
data = nexus.extract_signal_data_array(detector)
sc.testing.assert_identical(data, detector['jdl2ab'])


Expand All @@ -472,7 +478,7 @@ def test_extract_monitor_data():
' _': sc.linspace('xx', 2, 3, 10),
}
)
data = nexus.extract_events_or_histogram(monitor)
data = nexus.extract_signal_data_array(monitor)
sc.testing.assert_identical(data, monitor['(eed)'])


Expand All @@ -488,17 +494,17 @@ def test_extract_detector_data_requires_unique_dense_data():
with pytest.raises(
ValueError, match="Cannot uniquely identify the data to extract"
):
nexus.extract_events_or_histogram(detector)
nexus.extract_signal_data_array(detector)


def test_extract_detector_data_ignores_position_data_array():
detector = sc.DataGroup(jdl2ab=sc.data.data_xy(), position=sc.data.data_xy())
nexus.extract_events_or_histogram(detector)
nexus.extract_signal_data_array(detector)


def test_extract_detector_data_ignores_transform_data_array():
detector = sc.DataGroup(jdl2ab=sc.data.data_xy(), transform=sc.data.data_xy())
nexus.extract_events_or_histogram(detector)
nexus.extract_signal_data_array(detector)


def test_extract_detector_data_requires_unique_event_data():
Expand All @@ -513,7 +519,7 @@ def test_extract_detector_data_requires_unique_event_data():
with pytest.raises(
ValueError, match="Cannot uniquely identify the data to extract"
):
nexus.extract_events_or_histogram(detector)
nexus.extract_signal_data_array(detector)


def test_extract_detector_data_favors_event_data_over_histogram_data():
Expand All @@ -525,5 +531,11 @@ def test_extract_detector_data_favors_event_data_over_histogram_data():
' _': sc.linspace('xx', 2, 3, 10),
}
)
data = nexus.extract_events_or_histogram(detector)
data = nexus.extract_signal_data_array(detector)
sc.testing.assert_identical(data, detector['lob'])


def compute_component_position_returns_input_if_no_depends_on() -> None:
dg = sc.DataGroup(position=sc.vector([1, 2, 3], unit='m'))
result = nexus.compute_component_position(dg)
assert result is dg
Loading
Loading