Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose NeXus transformation chains in workflow steps #114

Merged
merged 20 commits into from
Oct 8, 2024
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/ess/reduce/nexus/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
group_event_data,
load_component,
compute_component_position,
extract_events_or_histogram,
extract_signal_data_array,
)

__all__ = [
Expand All @@ -25,5 +25,5 @@
'load_data',
'load_component',
'compute_component_position',
'extract_events_or_histogram',
'extract_signal_data_array',
]
9 changes: 7 additions & 2 deletions src/ess/reduce/nexus/_nexus_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,15 @@ def load_component(
component = _unique_child_group(instrument, nx_class, group_name)
loaded = cast(sc.DataGroup, component[selection])
loaded['nexus_component_name'] = component.name.split('/')[-1]
return compute_component_position(loaded)
return loaded
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the heart of the change in this PR: We are no longer computing the position directly here. This is moved into an extra step.



def compute_component_position(dg: sc.DataGroup) -> sc.DataGroup:
# In some downstream packages we use some of the Nexus components which attempt
# to compute positions without having actual Nexus data defining depends_on chains.
# We assume positions have been set in the non-Nexus input somehow and return early.
if 'depends_on' not in dg:
return dg
transform_out_name = 'transform'
if transform_out_name in dg:
raise RuntimeError(
Expand Down Expand Up @@ -115,7 +120,7 @@ def _contains_nx_class(group: snx.Group, nx_class: type[snx.NXobject]) -> bool:
return False


def extract_events_or_histogram(dg: sc.DataGroup) -> sc.DataArray:
def extract_signal_data_array(dg: sc.DataGroup) -> sc.DataArray:
event_data_arrays = {
key: value
for key, value in dg.items()
Expand Down
124 changes: 76 additions & 48 deletions src/ess/reduce/nexus/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@
NeXusGroup = NewType('NeXusGroup', snx.Group)
"""A ScippNexus group in an open file."""

NeXusDetectorName = NewType('NeXusDetectorName', str)
"""Name of a detector (bank) in a NeXus file."""
NeXusEntryName = NewType('NeXusEntryName', str)
"""Name of an entry in a NeXus file."""
NeXusSourceName = NewType('NeXusSourceName', str)
Expand All @@ -33,8 +31,6 @@

GravityVector = NewType('GravityVector', sc.Variable)

Component = TypeVar('Component', bound=snx.NXobject)

PreopenNeXusFile = NewType('PreopenNeXusFile', bool)
"""Whether to preopen NeXus files before passing them to the rest of the workflow."""

Expand Down Expand Up @@ -95,9 +91,9 @@ class TransmissionRun(Generic[ScatteringRunType]):
"""Identifier for an arbitrary monitor"""
Monitor5 = NewType('Monitor5', int)
"""Identifier for an arbitrary monitor"""
Incident = NewType('Incident', int)
IncidentMonitor = NewType('IncidentMonitor', int)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added the monitor suffix here since these now show up as MyComponent[IndicidentMonitor] and it may be confusing without.

"""Incident monitor"""
Transmission = NewType('Transmission', int)
TransmissionMonitor = NewType('TransmissionMonitor', int)
"""Transmission monitor"""
MonitorType = TypeVar(
'MonitorType',
Expand All @@ -106,50 +102,56 @@ class TransmissionRun(Generic[ScatteringRunType]):
Monitor3,
Monitor4,
Monitor5,
Incident,
Transmission,
IncidentMonitor,
TransmissionMonitor,
)
"""TypeVar used for specifying the monitor type such as Incident or Transmission"""


class NeXusMonitorName(sciline.Scope[MonitorType, str], str):
"""Name of a monitor in a NeXus file."""


class NeXusDetector(sciline.Scope[RunType, sc.DataGroup], sc.DataGroup):
"""Full raw data from a NeXus detector."""


class NeXusMonitor(
sciline.ScopeTwoParams[RunType, MonitorType, sc.DataGroup], sc.DataGroup
):
"""Full raw data from a NeXus monitor."""
Component = TypeVar(
'Component',
snx.NXdetector,
snx.NXsample,
snx.NXsource,
Monitor1,
Monitor2,
Monitor3,
Monitor4,
Monitor5,
IncidentMonitor,
TransmissionMonitor,
)
UniqueComponent = TypeVar('UniqueComponent', snx.NXsample, snx.NXsource)
"""Components that can be identified by their type as there will only be one."""


class NeXusSample(sciline.Scope[RunType, sc.DataGroup], sc.DataGroup):
"""Raw data from a NeXus sample."""
class NeXusName(sciline.Scope[Component, str], str):
"""Name of a component in a NeXus file."""


class NeXusSource(sciline.Scope[RunType, sc.DataGroup], sc.DataGroup):
"""Raw data from a NeXus source."""
class NeXusClass(sciline.Scope[Component, str], str):
"""NX_class of a component in a NeXus file."""


class NeXusDetectorData(sciline.Scope[RunType, sc.DataArray], sc.DataArray):
"""Data array loaded from an NXevent_data or NXdata group within an NXdetector."""
NeXusDetectorName = NeXusName[snx.NXdetector]
"""Name of a detector (bank) in a NeXus file."""


class NeXusMonitorData(
sciline.ScopeTwoParams[RunType, MonitorType, sc.DataArray], sc.DataArray
class NeXusComponent(
sciline.ScopeTwoParams[Component, RunType, sc.DataGroup], sc.DataGroup
):
"""Data array loaded from an NXevent_data or NXdata group within an NXmonitor."""
"""Raw data from a NeXus component."""


class SourcePosition(sciline.Scope[RunType, sc.Variable], sc.Variable):
"""Position of the neutron source."""
class NeXusData(sciline.ScopeTwoParams[Component, RunType, sc.DataArray], sc.DataArray):
"""
Data array loaded from an NXevent_data or NXdata group.

This must be contained in an NXmonitor or NXdetector group.
"""


class SamplePosition(sciline.Scope[RunType, sc.Variable], sc.Variable):
"""Position of the sample."""
class Position(sciline.ScopeTwoParams[Component, RunType, sc.Variable], sc.Variable):
"""Position of a component such as source, sample, monitor, or detector."""


class DetectorPositionOffset(sciline.Scope[RunType, sc.Variable], sc.Variable):
Expand All @@ -166,6 +168,10 @@ class CalibratedDetector(sciline.Scope[RunType, sc.DataArray], sc.DataArray):
"""Calibrated data from a detector."""


class CalibratedBeamline(sciline.Scope[RunType, sc.DataArray], sc.DataArray):
"""Calibrated beamline with detector and other components."""


class CalibratedMonitor(
sciline.ScopeTwoParams[RunType, MonitorType, sc.DataArray], sc.DataArray
):
Expand Down Expand Up @@ -217,23 +223,45 @@ class NeXusComponentLocationSpec(NeXusLocationSpec, Generic[Component, RunType])


@dataclass
class NeXusMonitorLocationSpec(
NeXusComponentLocationSpec[snx.NXmonitor, RunType], Generic[RunType, MonitorType]
):
"""
NeXus filename and optional parameters to identify (parts of) a monitor to load.
"""
class NeXusDataLocationSpec(NeXusLocationSpec, Generic[Component, RunType]):
"""NeXus filename and parameters to identify (parts of) detector data to load."""


T = TypeVar('T', bound='NeXusTransformationChain')


@dataclass
class NeXusDetectorDataLocationSpec(
NeXusComponentLocationSpec[snx.NXevent_data, RunType], Generic[RunType]
):
"""NeXus filename and parameters to identify (parts of) detector data to load."""
class NeXusTransformationChain(snx.TransformationChain, Generic[Component, RunType]):
@classmethod
def from_base(cls: type[T], base: snx.TransformationChain) -> T:
return cls(
parent=base.parent,
value=base.value,
transformations=base.transformations,
)

def compute_position(self) -> sc.Variable | sc.DataArray:
return self.compute() * sc.vector([0, 0, 0], unit='m')


@dataclass
class NeXusMonitorDataLocationSpec(
NeXusComponentLocationSpec[snx.NXevent_data, RunType], Generic[RunType, MonitorType]
):
"""NeXus filename and parameters to identify (parts of) monitor data to load."""
class NeXusTransformation(Generic[Component, RunType]):
value: sc.Variable

@staticmethod
def from_chain(
chain: NeXusTransformationChain[Component, RunType],
) -> 'NeXusTransformation[Component, RunType]':
"""
Convert a transformation chain to a single transformation.

As transformation chains may be time-dependent, this method will need to select
a specific time point to convert to a single transformation. This may include
averaging as well as threshold checks. This is not implemented yet and we
therefore currently raise an error if the transformation chain does not compute
to a scalar.
"""
transform = chain.compute()
if transform.ndim == 0:
return NeXusTransformation(value=transform)
raise ValueError(f"Expected scalar transformation, got {transform}")
Loading
Loading