diff --git a/conda/meta.yaml b/conda/meta.yaml
index 4b70d028..96571c25 100644
--- a/conda/meta.yaml
+++ b/conda/meta.yaml
@@ -12,6 +12,8 @@ requirements:
     - setuptools_scm
   run:
     - python>=3.10
+    - scipp >= 24.02.0
+    - scippnexus >= 24.03.0

 test:
   imports:
diff --git a/docs/api-reference/index.md b/docs/api-reference/index.md
index 70d4d966..04e8cf95 100644
--- a/docs/api-reference/index.md
+++ b/docs/api-reference/index.md
@@ -26,4 +26,6 @@
    :toctree: ../generated/modules
    :template: module-template.rst
    :recursive:
+
+   nexus
 ```
diff --git a/pyproject.toml b/pyproject.toml
index 2379e945..92c51469 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -30,6 +30,8 @@ requires-python = ">=3.10"
 # Run 'tox -e deps' after making changes here. This will update requirement files.
 # Make sure to list one dependency per line.
 dependencies = [
+    "scipp >= 24.02.0",
+    "scippnexus >= 24.03.0",
 ]

 dynamic = ["version"]
diff --git a/requirements/base.in b/requirements/base.in
index b801db0e..d058df89 100644
--- a/requirements/base.in
+++ b/requirements/base.in
@@ -2,4 +2,5 @@
 # will not be touched by ``make_base.py``
 # --- END OF CUSTOM SECTION ---
 # The following was generated by 'tox -e deps', DO NOT EDIT MANUALLY!
-
+scipp >= 24.02.0
+scippnexus >= 24.03.0
diff --git a/requirements/base.txt b/requirements/base.txt
index c24fd27f..5e1a1dd3 100644
--- a/requirements/base.txt
+++ b/requirements/base.txt
@@ -1,8 +1,26 @@
-# SHA1:da39a3ee5e6b4b0d3255bfef95601890afd80709
+# SHA1:b5fdb6600edc83ab95fb0e848607edef52cdd293
 #
 # This file is autogenerated by pip-compile-multi
 # To update, run:
 #
 #    pip-compile-multi
 #
-
+h5py==3.10.0
+    # via scippnexus
+numpy==1.26.4
+    # via
+    #   h5py
+    #   scipp
+    #   scipy
+python-dateutil==2.9.0.post0
+    # via scippnexus
+scipp==24.2.0
+    # via
+    #   -r base.in
+    #   scippnexus
+scippnexus==24.3.1
+    # via -r base.in
+scipy==1.12.0
+    # via scippnexus
+six==1.16.0
+    # via python-dateutil
diff --git a/requirements/basetest.txt b/requirements/basetest.txt
index bc93620f..a1fc27b6 100644
--- a/requirements/basetest.txt
+++ b/requirements/basetest.txt
@@ -13,7 +13,7 @@ packaging==23.2
     # via pytest
 pluggy==1.4.0
     # via pytest
-pytest==8.0.1
+pytest==8.0.2
     # via -r basetest.in
 tomli==2.0.1
     # via pytest
diff --git a/requirements/ci.txt b/requirements/ci.txt
index eeef86ea..87aaf8e4 100644
--- a/requirements/ci.txt
+++ b/requirements/ci.txt
@@ -5,7 +5,7 @@
 #
 #    pip-compile-multi
 #
-cachetools==5.3.2
+cachetools==5.3.3
     # via tox
 certifi==2024.2.2
     # via requests
diff --git a/requirements/dev.txt b/requirements/dev.txt
index 61cfd1bb..42476c74 100644
--- a/requirements/dev.txt
+++ b/requirements/dev.txt
@@ -50,7 +50,7 @@ isoduration==20.11.0
     # via jsonschema
 jinja2-ansible-filters==1.3.2
     # via copier
-json5==0.9.17
+json5==0.9.20
     # via jupyterlab-server
 jsonpointer==2.4
     # via jsonschema
@@ -61,9 +61,9 @@ jsonschema[format-nongpl]==4.21.1
     #   nbformat
 jupyter-events==0.9.0
     # via jupyter-server
-jupyter-lsp==2.2.2
+jupyter-lsp==2.2.3
     # via jupyterlab
-jupyter-server==2.12.5
+jupyter-server==2.13.0
     # via
     #   jupyter-lsp
     #   jupyterlab
@@ -71,7 +71,7 @@ jupyter-server==2.12.5
     #   notebook-shim
 jupyter-server-terminals==0.5.2
     # via jupyter-server
-jupyterlab==4.1.2
+jupyterlab==4.1.3
     # via -r dev.in
 jupyterlab-server==2.25.3
     # via jupyterlab
@@ -91,9 +91,9 @@ prometheus-client==0.20.0
     # via jupyter-server
 pycparser==2.21
     # via cffi
-pydantic==2.6.1
+pydantic==2.6.3
     # via copier
-pydantic-core==2.16.2
+pydantic-core==2.16.3
     # via pydantic
 python-json-logger==2.0.7
     # via jupyter-events
@@ -111,7 +111,7 @@ rfc3986-validator==0.1.1
     #   jupyter-events
 send2trash==1.8.2
     # via jupyter-server
-sniffio==1.3.0
+sniffio==1.3.1
     # via
     #   anyio
     #   httpx
diff --git a/requirements/docs.txt b/requirements/docs.txt
index caed3661..4381e0a2 100644
--- a/requirements/docs.txt
+++ b/requirements/docs.txt
@@ -54,9 +54,9 @@ idna==3.6
     # via requests
 imagesize==1.4.1
     # via sphinx
-ipykernel==6.29.2
+ipykernel==6.29.3
     # via -r docs.in
-ipython==8.22.0
+ipython==8.22.2
     # via
     #   -r docs.in
     #   ipykernel
@@ -107,7 +107,7 @@ myst-parser==2.0.0
     # via -r docs.in
 nbclient==0.9.0
     # via nbconvert
-nbconvert==7.16.1
+nbconvert==7.16.2
     # via nbsphinx
 nbformat==5.9.2
     # via
@@ -149,8 +149,6 @@ pygments==2.17.2
     #   nbconvert
     #   pydata-sphinx-theme
     #   sphinx
-python-dateutil==2.8.2
-    # via jupyter-client
 pyyaml==6.0.1
     # via myst-parser
 pyzmq==25.1.2
@@ -167,11 +165,6 @@ rpds-py==0.18.0
     # via
     #   jsonschema
     #   referencing
-six==1.16.0
-    # via
-    #   asttokens
-    #   bleach
-    #   python-dateutil
 snowballstemmer==2.2.0
     # via sphinx
 soupsieve==2.5
@@ -223,7 +216,7 @@ traitlets==5.14.1
     #   nbconvert
     #   nbformat
     #   nbsphinx
-typing-extensions==4.9.0
+typing-extensions==4.10.0
     # via pydata-sphinx-theme
 urllib3==2.2.1
     # via requests
diff --git a/requirements/mypy.txt b/requirements/mypy.txt
index ac285686..49722576 100644
--- a/requirements/mypy.txt
+++ b/requirements/mypy.txt
@@ -10,5 +10,5 @@ mypy==1.8.0
     # via -r mypy.in
 mypy-extensions==1.0.0
     # via mypy
-typing-extensions==4.9.0
+typing-extensions==4.10.0
     # via mypy
diff --git a/requirements/nightly.in b/requirements/nightly.in
index 6b1ebcc2..0e1f5905 100644
--- a/requirements/nightly.in
+++ b/requirements/nightly.in
@@ -1,4 +1,5 @@
 -r basetest.in
 # --- END OF CUSTOM SECTION ---
 # The following was generated by 'tox -e deps', DO NOT EDIT MANUALLY!
-
+scipp >= 24.02.0
+scippnexus >= 24.03.0
diff --git a/requirements/nightly.txt b/requirements/nightly.txt
index a98564b2..2126652e 100644
--- a/requirements/nightly.txt
+++ b/requirements/nightly.txt
@@ -1,4 +1,4 @@
-# SHA1:e8b11c1210855f07eaedfbcfb3ecd1aec3595dee
+# SHA1:9bb7ade09fe2af7ab62c586f5f5dc6f3e9b8344b
 #
 # This file is autogenerated by pip-compile-multi
 # To update, run:
@@ -6,3 +6,22 @@
 #    pip-compile-multi
 #
 -r basetest.txt
+h5py==3.10.0
+    # via scippnexus
+numpy==1.26.4
+    # via
+    #   h5py
+    #   scipp
+    #   scipy
+python-dateutil==2.9.0.post0
+    # via scippnexus
+scipp==24.2.0
+    # via
+    #   -r nightly.in
+    #   scippnexus
+scippnexus==24.3.1
+    # via -r nightly.in
+scipy==1.12.0
+    # via scippnexus
+six==1.16.0
+    # via python-dateutil
diff --git a/requirements/wheels.txt b/requirements/wheels.txt
index 2e33cfa3..12ac3232 100644
--- a/requirements/wheels.txt
+++ b/requirements/wheels.txt
@@ -5,7 +5,7 @@
 #
 #    pip-compile-multi
 #
-build==1.0.3
+build==1.1.1
     # via -r wheels.in
 packaging==23.2
     # via build
diff --git a/src/ess/reduce/__init__.py b/src/ess/reduce/__init__.py
index 9da9479a..3fa47da5 100644
--- a/src/ess/reduce/__init__.py
+++ b/src/ess/reduce/__init__.py
@@ -4,9 +4,13 @@
 # flake8: noqa
 import importlib.metadata

+from . import nexus
+
 try:
     __version__ = importlib.metadata.version(__package__ or __name__)
 except importlib.metadata.PackageNotFoundError:
     __version__ = "0.0.0"

 del importlib
+
+__all__ = ['nexus']
diff --git a/src/ess/reduce/logging.py b/src/ess/reduce/logging.py
new file mode 100644
index 00000000..05adcfaf
--- /dev/null
+++ b/src/ess/reduce/logging.py
@@ -0,0 +1,17 @@
+# SPDX-License-Identifier: BSD-3-Clause
+# Copyright (c) 2024 Scipp contributors (https://github.com/scipp)
+
+"""Logging tools for ess.reduce."""
+
+import logging
+
+
+def get_logger() -> logging.Logger:
+    """Return the logger for ess.reduce.
+
+    Returns
+    -------
+    :
+        The requested logger.
+    """
+    return logging.getLogger('scipp.ess.reduce')
diff --git a/src/ess/reduce/nexus.py b/src/ess/reduce/nexus.py
new file mode 100644
index 00000000..54e3eb1e
--- /dev/null
+++ b/src/ess/reduce/nexus.py
@@ -0,0 +1,391 @@
+# SPDX-License-Identifier: BSD-3-Clause
+# Copyright (c) 2024 Scipp contributors (https://github.com/scipp)
+
+"""NeXus utilities.
+
+This module defines functions and domain types that can be used
+to build Sciline pipelines for simple workflows.
+If multiple different kinds of files (e.g., sample and background runs)
+are needed, custom types and providers need to be defined to wrap
+the basic ones here.
+"""
+
+from contextlib import nullcontext
+from pathlib import Path
+from typing import BinaryIO, ContextManager, NewType, Optional, Type, Union, cast
+
+import scipp as sc
+import scippnexus as snx
+
+from .logging import get_logger
+
+FilePath = NewType('FilePath', Path)
+"""Full path to a NeXus file on disk."""
+NeXusFile = NewType('NeXusFile', BinaryIO)
+"""An open NeXus file.
+
+Can be any file handle for reading binary data.
+
+Note that this cannot be used as a parameter in Sciline as there are no
+concrete implementations of ``BinaryIO``.
+The type alias is provided for callers of load functions outside of pipelines.
+"""
+NeXusGroup = NewType('NeXusGroup', snx.Group)
+"""A ScippNexus group in an open file."""
+
+NeXusDetectorName = NewType('NeXusDetectorName', str)
+"""Name of a detector (bank) in a NeXus file."""
+NeXusEntryName = NewType('NeXusEntryName', str)
+"""Name of an entry in a NeXus file."""
+NeXusMonitorName = NewType('NeXusMonitorName', str)
+"""Name of a monitor in a NeXus file."""
+NeXusSourceName = NewType('NeXusSourceName', str)
+"""Name of a source in a NeXus file."""
+
+RawDetector = NewType('RawDetector', sc.DataGroup)
+"""Full raw data from a NeXus detector."""
+RawDetectorData = NewType('RawDetectorData', sc.DataArray)
+"""Data extracted from a RawDetector."""
+RawMonitor = NewType('RawMonitor', sc.DataGroup)
+"""Full raw data from a NeXus monitor."""
+RawMonitorData = NewType('RawMonitorData', sc.DataArray)
+"""Data extracted from a RawMonitor."""
+RawSample = NewType('RawSample', sc.DataGroup)
+"""Raw data from a NeXus sample."""
+RawSource = NewType('RawSource', sc.DataGroup)
+"""Raw data from a NeXus source."""
+
+
+def load_detector(
+    file_path: Union[FilePath, NeXusFile, NeXusGroup],
+    *,
+    detector_name: NeXusDetectorName,
+    entry_name: Optional[NeXusEntryName] = None,
+) -> RawDetector:
+    """Load a single detector (bank) from a NeXus file.
+
+    The detector positions are computed automatically from NeXus transformations,
+    and the combined transformation is stored under the name 'transform'.
+
+    Parameters
+    ----------
+    file_path:
+        Indicates where to load data from.
+        One of:
+
+        - Path to a NeXus file on disk.
+        - File handle or buffer for reading binary data.
+        - A ScippNexus group of the root of a NeXus file.
+    detector_name:
+        Name of the detector (bank) to load.
+        Must be a group in an instrument group in the entry (see below).
+    entry_name:
+        Name of the entry that contains the detector.
+        If ``None``, the entry will be located based
+        on its NeXus class, but there cannot be more than one.
+
+    Returns
+    -------
+    :
+        A data group containing the detector events or histogram
+        and any auxiliary data stored in the same NeXus group.
+    """
+    return RawDetector(
+        _load_group_with_positions(
+            file_path,
+            group_name=detector_name,
+            nx_class=snx.NXdetector,
+            entry_name=entry_name,
+        )
+    )
+
+
+def load_monitor(
+    file_path: Union[FilePath, NeXusFile, NeXusGroup],
+    *,
+    monitor_name: NeXusMonitorName,
+    entry_name: Optional[NeXusEntryName] = None,
+) -> RawMonitor:
+    """Load a single monitor from a NeXus file.
+
+    The monitor position is computed automatically from NeXus transformations,
+    and the combined transformation is stored under the name 'transform'.
+
+    Parameters
+    ----------
+    file_path:
+        Indicates where to load data from.
+        One of:
+
+        - Path to a NeXus file on disk.
+        - File handle or buffer for reading binary data.
+        - A ScippNexus group of the root of a NeXus file.
+    monitor_name:
+        Name of the monitor to load.
+        Must be a group in an instrument group in the entry (see below).
+    entry_name:
+        Name of the entry that contains the monitor.
+        If ``None``, the entry will be located based
+        on its NeXus class, but there cannot be more than one.
+
+    Returns
+    -------
+    :
+        A data group containing the monitor events or histogram
+        and any auxiliary data stored in the same NeXus group.
+    """
+    return RawMonitor(
+        _load_group_with_positions(
+            file_path,
+            group_name=monitor_name,
+            nx_class=snx.NXmonitor,
+            entry_name=entry_name,
+        )
+    )
+
+
+def load_source(
+    file_path: Union[FilePath, NeXusFile, NeXusGroup],
+    *,
+    source_name: Optional[NeXusSourceName] = None,
+    entry_name: Optional[NeXusEntryName] = None,
+) -> RawSource:
+    """Load a source from a NeXus file.
+
+    The source position is computed automatically from NeXus transformations,
+    and the combined transformation is stored under the name 'transform'.
+
+    Parameters
+    ----------
+    file_path:
+        Indicates where to load data from.
+        One of:
+
+        - Path to a NeXus file on disk.
+        - File handle or buffer for reading binary data.
+        - A ScippNexus group of the root of a NeXus file.
+    source_name:
+        Name of the source to load.
+        Must be a group in an instrument group in the entry (see below).
+        If ``None``, the source will be located based
+        on its NeXus class.
+    entry_name:
+        Name of the entry that contains the source.
+        If ``None``, the entry will be located based
+        on its NeXus class, but there cannot be more than one.
+
+    Returns
+    -------
+    :
+        A data group containing all data stored in
+        the source NeXus group.
+    """
+    return RawSource(
+        _load_group_with_positions(
+            file_path,
+            group_name=source_name,
+            nx_class=snx.NXsource,
+            entry_name=entry_name,
+        )
+    )
+
+
+def load_sample(
+    file_path: Union[FilePath, NeXusFile, NeXusGroup],
+    entry_name: Optional[NeXusEntryName] = None,
+) -> RawSample:
+    """Load a sample from a NeXus file.
+
+    The sample is located based on its NeXus class.
+    There can be only one sample in a NeXus file or
+    in the entry indicated by ``entry_name``.
+
+    Parameters
+    ----------
+    file_path:
+        Indicates where to load data from.
+        One of:
+
+        - Path to a NeXus file on disk.
+        - File handle or buffer for reading binary data.
+        - A ScippNexus group of the root of a NeXus file.
+    entry_name:
+        Name of the entry that contains the sample.
+        If ``None``, the entry will be located based
+        on its NeXus class, but there cannot be more than one.
+
+    Returns
+    -------
+    :
+        A data group containing all data stored in
+        the sample NeXus group.
+    """
+    with _open_nexus_file(file_path) as f:
+        entry = _unique_child_group(f, snx.NXentry, entry_name)
+        loaded = cast(sc.DataGroup, _unique_child_group(entry, snx.NXsample, None)[()])
+    return RawSample(loaded)
+
+
+def _load_group_with_positions(
+    file_path: Union[FilePath, NeXusFile, NeXusGroup],
+    *,
+    group_name: Optional[str],
+    nx_class: Type[snx.NXobject],
+    entry_name: Optional[NeXusEntryName] = None,
+) -> sc.DataGroup:
+    with _open_nexus_file(file_path) as f:
+        entry = _unique_child_group(f, snx.NXentry, entry_name)
+        instrument = _unique_child_group(entry, snx.NXinstrument, None)
+        loaded = cast(
+            sc.DataGroup, _unique_child_group(instrument, nx_class, group_name)[()]
+        )
+
+        transform_out_name = 'transform'
+        if transform_out_name in loaded:
+            raise RuntimeError(
+                f"Loaded data contains an item '{transform_out_name}' but we want to "
+                "store the combined NeXus transformations under that name."
+            )
+        position_out_name = 'position'
+        if position_out_name in loaded:
+            raise RuntimeError(
+                f"Loaded data contains an item '{position_out_name}' but we want to "
+                "store the computed positions under that name."
+            )
+
+        loaded = snx.compute_positions(
+            loaded, store_position=position_out_name, store_transform=transform_out_name
+        )
+        return loaded
+
+
+def _open_nexus_file(
+    file_path: Union[FilePath, NeXusFile, NeXusGroup]
+) -> ContextManager:
+    if isinstance(file_path, getattr(NeXusGroup, '__supertype__', type(None))):
+        return nullcontext(file_path)
+    return snx.File(file_path)
+
+
+def _unique_child_group(
+    group: snx.Group, nx_class: Type[snx.NXobject], name: Optional[str]
+) -> snx.Group:
+    if name is not None:
+        child = group[name]
+        if isinstance(child, snx.Field):
+            raise ValueError(
+                f"Expected a NeXus group as item '{name}' but got a field."
+            )
+        if child.nx_class != nx_class:
+            raise ValueError(
+                f"The NeXus group '{name}' was expected to be a "
+                f'{nx_class} but is a {child.nx_class}.'
+            )
+        return child
+
+    children = group[nx_class]
+    if len(children) != 1:
+        raise ValueError(f'Expected exactly one {nx_class} group, got {len(children)}')
+    return next(iter(children.values()))  # type: ignore[return-value]
+
+
+def extract_detector_data(detector: RawDetector) -> RawDetectorData:
+    """Get and return the events or histogram from a detector loaded from NeXus.
+
+    This function looks for a data array in the detector group and returns that.
+
+    Parameters
+    ----------
+    detector:
+        A detector loaded from NeXus.
+
+    Returns
+    -------
+    :
+        A data array containing the events or histogram.
+
+    Raises
+    ------
+    ValueError
+        If there is more than one data array.
+
+    See also
+    --------
+    load_detector:
+        Load a detector from a NeXus file in a format compatible with
+        ``extract_detector_data``.
+    """
+    return RawDetectorData(_extract_events_or_histogram(detector))
+
+
+def extract_monitor_data(monitor: RawMonitor) -> RawMonitorData:
+    """Get and return the events or histogram from a monitor loaded from NeXus.
+
+    This function looks for a data array in the monitor group and returns that.
+
+    Parameters
+    ----------
+    monitor:
+        A monitor loaded from NeXus.
+
+    Returns
+    -------
+    :
+        A data array containing the events or histogram.
+
+    Raises
+    ------
+    ValueError
+        If there is more than one data array.
+
+    See also
+    --------
+    load_monitor:
+        Load a monitor from a NeXus file in a format compatible with
+        ``extract_monitor_data``.
+    """
+    return RawMonitorData(_extract_events_or_histogram(monitor))
+
+
+def _extract_events_or_histogram(dg: sc.DataGroup) -> sc.DataArray:
+    event_data_arrays = {
+        key: value
+        for key, value in dg.items()
+        if isinstance(value, sc.DataArray) and value.bins is not None
+    }
+    histogram_data_arrays = {
+        key: value
+        for key, value in dg.items()
+        if isinstance(value, sc.DataArray) and value.bins is None
+    }
+    if (array := _select_unique_array(event_data_arrays, 'event')) is not None:
+        if histogram_data_arrays:
+            get_logger().info(
+                "Selecting event data '%s' in favor of histogram data {%s}",
+                next(iter(event_data_arrays.keys())),
+                ', '.join(map(lambda k: f"'{k}'", histogram_data_arrays)),
+            )
+        return array
+
+    if (array := _select_unique_array(histogram_data_arrays, 'histogram')) is not None:
+        return array
+
+    raise ValueError(
+        "Raw data loaded from NeXus does not contain events or a histogram. "
+        "Expected to find a data array, "
+        f"but the data only contains {set(dg.keys())}"
+    )
+
+
+def _select_unique_array(
+    arrays: dict[str, sc.DataArray], mapping_name: str
+) -> Optional[sc.DataArray]:
+    if not arrays:
+        return None
+    if len(arrays) > 1:
+        raise ValueError(
+            f"Raw data loaded from NeXus contains more than one {mapping_name} "
+            "data array. Cannot uniquely identify the data to extract. "
+            f"Got {mapping_name} items {set(arrays.keys())}"
+        )
+    return next(iter(arrays.values()))
diff --git a/tests/nexus_test.py b/tests/nexus_test.py
new file mode 100644
index 00000000..6287c106
--- /dev/null
+++ b/tests/nexus_test.py
@@ -0,0 +1,346 @@
+# SPDX-License-Identifier: BSD-3-Clause
+# Copyright (c) 2024 Scipp contributors (https://github.com/scipp)
+
+from contextlib import contextmanager
+from io import BytesIO
+from pathlib import Path
+from typing import Union
+
+import numpy as np
+import pytest
+import scipp as sc
+import scipp.testing
+import scippnexus as snx
+
+from ess.reduce import nexus
+
+
+def _event_data_components() -> sc.DataGroup:
+    return sc.DataGroup(
+        {
+            'event_id': sc.array(dims=['event'], unit=None, values=[1, 2, 4, 1, 2, 2]),
+            'event_time_offset': sc.array(
+                dims=['event'], unit='s', values=[456, 7, 3, 345, 632, 23]
+            ),
+            'event_time_zero': sc.epoch(unit='s')
+            + sc.array(dims=['event_index'], unit='s', values=[1, 2, 3, 4]),
+            'event_index': sc.array(
+                dims=['event_index'], unit=None, values=[0, 3, 3, 6]
+            ),
+            'detector_number': sc.arange('detector_number', 5, unit=None),
+            'pixel_offset': sc.vectors(
+                dims=['detector_number'],
+                values=np.arange(3 * 5).reshape((5, 3)),
+                unit='m',
+            ),
+        }
+    )
+
+
+def detector_transformation_components() -> sc.DataGroup:
+    return sc.DataGroup(
+        {
+            'offset': sc.vector([0.4, 0.0, 11.5], unit='m'),
+        }
+    )
+
+
+def _monitor_histogram() -> sc.DataArray:
+    return sc.DataArray(
+        sc.array(dims=['time'], values=[2, 4, 8, 3], unit='counts'),
+        coords={
+            'time': sc.epoch(unit='ms')
+            + sc.array(dims=['time'], values=[2, 4, 6, 8, 10], unit='ms'),
+        },
+    )
+
+
+def _source_data() -> sc.DataGroup:
+    return sc.DataGroup(
+        {
+            'name': 'moderator',
+            'probe': 'neutron',
+            'type': 'Spallation Neutron Source',
+            'position': sc.vector([0, 0, 0], unit='m'),
+            'transform': sc.spatial.translation(value=[0, 0, 0], unit='m'),
+        }
+    )
+
+
+def _sample_data() -> sc.DataGroup:
+    return sc.DataGroup(
+        {
+            'name': 'water',
+            'chemical_formula': 'H2O',
+            'type': 'sample+can',
+        }
+    )
+
+
+def _write_transformation(group: snx.Group, offset: sc.Variable) -> None:
+    group.create_field('depends_on', sc.scalar('transformations/t1'))
+    transformations = group.create_class('transformations', snx.NXtransformations)
+    t1 = transformations.create_field('t1', sc.scalar(0.0, unit=offset.unit))
+    t1.attrs['depends_on'] = '.'
+    t1.attrs['transformation_type'] = 'translation'
+    t1.attrs['offset'] = offset.values
+    t1.attrs['offset_units'] = str(offset.unit)
+    t1.attrs['vector'] = sc.vector([0, 0, 1]).value
+
+
+def _write_nexus_data(store: Union[Path, BytesIO]) -> None:
+    with snx.File(store, 'w') as root:
+        entry = root.create_class('entry-001', snx.NXentry)
+        instrument = entry.create_class('reducer', snx.NXinstrument)
+
+        detector = instrument.create_class('bank12', snx.NXdetector)
+        events = detector.create_class('bank12_events', snx.NXevent_data)
+        detector_components = _event_data_components()
+        events['event_id'] = detector_components['event_id']
+        events['event_time_offset'] = detector_components['event_time_offset']
+        events['event_time_zero'] = detector_components['event_time_zero']
+        events['event_index'] = detector_components['event_index']
+        detector['x_pixel_offset'] = detector_components['pixel_offset'].fields.x
+        detector['y_pixel_offset'] = detector_components['pixel_offset'].fields.y
+        detector['z_pixel_offset'] = detector_components['pixel_offset'].fields.z
+        detector['detector_number'] = detector_components['detector_number']
+        _write_transformation(detector, detector_transformation_components()['offset'])
+
+        monitor_data = _monitor_histogram()
+        monitor = instrument.create_class('monitor', snx.NXmonitor)
+        data = monitor.create_class('data', snx.NXdata)
+        signal = data.create_field('signal', monitor_data.data)
+        signal.attrs['signal'] = 1
+        signal.attrs['axes'] = monitor_data.dim
+        data.create_field('time', monitor_data.coords['time'])
+
+        source_data = _source_data()
+        source = instrument.create_class('source', snx.NXsource)
+        source.create_field('name', source_data['name'])
+        source.create_field('probe', source_data['probe'])
+        source.create_field('type', source_data['type'])
+        _write_transformation(source, source_data['position'])
+
+        sample_data = _sample_data()
+        sample = entry.create_class('sample', snx.NXsample)
+        sample.create_field('name', sample_data['name'])
+        sample.create_field('chemical_formula', sample_data['chemical_formula'])
+        sample.create_field('type', sample_data['type'])
+
+
+@contextmanager
+def _file_store(request: pytest.FixtureRequest):
+    if request.param == BytesIO:
+        yield BytesIO()
+    else:
+        # It would be good to use pyfakefs here, but h5py
+        # uses C to open files and that bypasses the fake.
+        base = request.getfixturevalue('tmp_path')
+        yield base / 'testfile.nxs'
+
+
+@pytest.fixture(params=[Path, BytesIO, snx.Group])
+def nexus_file(request):
+    with _file_store(request) as store:
+        _write_nexus_data(store)
+        if isinstance(store, BytesIO):
+            store.seek(0)
+
+        if request.param in (Path, BytesIO):
+            yield store
+        else:
+            with snx.File(store, 'r') as f:
+                yield f
+
+
+@pytest.fixture()
+def expected_bank12():
+    components = _event_data_components()
+    buffer = sc.DataArray(
+        sc.ones(sizes={'event': 6}, unit='counts', dtype='float32'),
+        coords={
+            'detector_number': components['event_id'],
+            'event_time_offset': components['event_time_offset'],
+        },
+    )
+
+    # Bin by event_index to broadcast event_time_zero to events
+    binned_in_time = sc.DataArray(
+        sc.bins(
+            data=buffer,
+            begin=components['event_index'],
+            end=sc.concat(
+                [components['event_index'][1:], components['event_index'][-1]],
+                dim='event_index',
+            ),
+            dim='event',
+        )
+    )
+    binned_in_time.bins.coords['event_time_zero'] = sc.bins_like(
+        binned_in_time, components['event_time_zero']
+    )
+
+    # Bin by detector number like ScippNexus would
+    binned = binned_in_time.bins.concat().group(components['detector_number'])
+    binned.coords['x_pixel_offset'] = components['pixel_offset'].fields.x
+    binned.coords['y_pixel_offset'] = components['pixel_offset'].fields.y
+    binned.coords['z_pixel_offset'] = components['pixel_offset'].fields.z
+    # Computed position
+    offset = detector_transformation_components()['offset']
+    binned.coords['position'] = offset + components['pixel_offset']
+    return binned
+
+
+@pytest.fixture()
+def expected_monitor() -> sc.DataArray:
+    return _monitor_histogram()
+
+
+@pytest.fixture()
+def expected_source() -> sc.DataGroup:
+    return _source_data()
+
+
+@pytest.fixture()
+def expected_sample() -> sc.DataGroup:
+    return _sample_data()
+
+
+@pytest.mark.parametrize('entry_name', (None, nexus.NeXusEntryName('entry-001')))
+def test_load_detector(nexus_file, expected_bank12, entry_name):
+    detector = nexus.load_detector(
+        nexus_file,
+        detector_name=nexus.NeXusDetectorName('bank12'),
+        entry_name=entry_name,
+    )
+    sc.testing.assert_identical(detector['bank12_events'], expected_bank12)
+    offset = detector_transformation_components()['offset']
+    sc.testing.assert_identical(
+        detector['transform'],
+        sc.spatial.translation(unit=offset.unit, value=offset.value),
+    )
+
+
+def test_load_detector_requires_entry_name_if_not_unique(nexus_file):
+    if not isinstance(nexus_file, Path):
+        # For simplicity, only create a second entry in an actual file
+        return
+
+    with snx.File(nexus_file, 'r+') as f:
+        f.create_class('entry', snx.NXentry)
+
+    with pytest.raises(ValueError):
+        nexus.load_detector(
+            nexus.FilePath(nexus_file),
+            detector_name=nexus.NeXusDetectorName('bank12'),
+            entry_name=None,
+        )
+
+
+def test_load_detector_select_entry_if_not_unique(nexus_file, expected_bank12):
+    if not isinstance(nexus_file, Path):
+        # For simplicity, only create a second entry in an actual file
+        return
+
+    with snx.File(nexus_file, 'r+') as f:
+        f.create_class('entry', snx.NXentry)
+
+    detector = nexus.load_detector(
+        nexus.FilePath(nexus_file),
+        detector_name=nexus.NeXusDetectorName('bank12'),
+        entry_name=nexus.NeXusEntryName('entry-001'),
+    )
+    sc.testing.assert_identical(detector['bank12_events'], expected_bank12)
+
+
+@pytest.mark.parametrize('entry_name', (None, nexus.NeXusEntryName('entry-001')))
+def test_load_monitor(nexus_file, expected_monitor, entry_name):
+    monitor = nexus.load_monitor(
+        nexus_file,
+        monitor_name=nexus.NeXusMonitorName('monitor'),
+        entry_name=entry_name,
+    )
+    sc.testing.assert_identical(monitor['data'], expected_monitor)
+
+
+@pytest.mark.parametrize('entry_name', (None, nexus.NeXusEntryName('entry-001')))
+@pytest.mark.parametrize('source_name', (None, nexus.NeXusSourceName('source')))
+def test_load_source(nexus_file, expected_source, entry_name, source_name):
+    source = nexus.load_source(
+        nexus_file,
+        entry_name=entry_name,
+        source_name=source_name,
+    )
+    # NeXus details that we don't need to test as long as the positions are ok:
+    del source['depends_on']
+    del source['transformations']
+    sc.testing.assert_identical(source, nexus.RawSource(expected_source))
+
+
+@pytest.mark.parametrize('entry_name', (None, nexus.NeXusEntryName('entry-001')))
+def test_load_sample(nexus_file, expected_sample, entry_name):
+    sample = nexus.load_sample(nexus_file, entry_name=entry_name)
+    sc.testing.assert_identical(sample, nexus.RawSample(expected_sample))
+
+
+def test_extract_detector_data():
+    detector = sc.DataGroup(
+        {
+            'jdl2ab': sc.data.binned_x(10, 3),
+            'llk': 23,
+            ' _': sc.linspace('xx', 2, 3, 10),
+        }
+    )
+    data = nexus.extract_detector_data(nexus.RawDetector(detector))
+    sc.testing.assert_identical(data, nexus.RawDetectorData(detector['jdl2ab']))
+
+
+def test_extract_monitor_data():
+    monitor = sc.DataGroup(
+        {
+            '(eed)': sc.data.data_xy(),
+            'llk': 23,
+            ' _': sc.linspace('xx', 2, 3, 10),
+        }
+    )
+    data = nexus.extract_monitor_data(nexus.RawMonitor(monitor))
+    sc.testing.assert_identical(data, nexus.RawMonitorData(monitor['(eed)']))
+
+
+def test_extract_detector_data_requires_unique_dense_data():
+    detector = sc.DataGroup(
+        {
+            'jdl2ab': sc.data.data_xy(),
+            'llk': 23,
+            'lob': sc.data.data_xy(),
+            ' _': sc.linspace('xx', 2, 3, 10),
+        }
+    )
+    with pytest.raises(ValueError):
+        nexus.extract_detector_data(nexus.RawDetector(detector))
+
+
+def test_extract_detector_data_requires_unique_event_data():
+    detector = sc.DataGroup(
+        {
+            'jdl2ab': sc.data.binned_x(10, 3),
+            'llk': 23,
+            'lob': sc.data.binned_x(14, 5),
+            ' _': sc.linspace('xx', 2, 3, 10),
+        }
+    )
+    with pytest.raises(ValueError):
+        nexus.extract_detector_data(nexus.RawDetector(detector))
+
+
+def test_extract_detector_data_favors_event_data_over_histogram_data():
+    detector = sc.DataGroup(
+        {
+            'jdl2ab': sc.data.data_xy(),
+            'llk': 23,
+            'lob': sc.data.binned_x(14, 5),
+            ' _': sc.linspace('xx', 2, 3, 10),
+        }
+    )
+    data = nexus.extract_detector_data(nexus.RawDetector(detector))
+    sc.testing.assert_identical(data, nexus.RawDetectorData(detector['lob']))
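For reviewers, a minimal usage sketch of the new ess.reduce.nexus API added in this diff. The file name 'experiment.nxs' is hypothetical; the detector and monitor names mirror the tests above, and any NeXus file with matching groups would behave the same way:

from pathlib import Path

from ess.reduce import nexus

# Hypothetical input file; replace with a real NeXus file.
path = nexus.FilePath(Path('experiment.nxs'))

# Load the full detector group. Positions are computed from the NeXus
# transformations and stored under 'position', the combined transformation
# under 'transform'.
detector = nexus.load_detector(path, detector_name=nexus.NeXusDetectorName('bank12'))
# Pick out the single event or histogram data array from the group.
events = nexus.extract_detector_data(detector)

monitor = nexus.load_monitor(path, monitor_name=nexus.NeXusMonitorName('monitor'))
monitor_data = nexus.extract_monitor_data(monitor)

# Sample and source are located by their NeXus class; no name is needed
# as long as they are unique in the entry.
sample = nexus.load_sample(path)
source = nexus.load_source(path)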