Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow multiple field names to parse dtype or units and add more tests. #235

Merged
merged 4 commits into from
Sep 27, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 50 additions & 10 deletions src/beamlime/applications/_nexus_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,48 @@ def _validate_module_keys(
)


def _validate_f144_module_spec(
    module_spec: StreamModuleValue,
) -> None:
    """Check that an f144 module spec can be used.

    Raises
    ------
    InvalidNexusStructureError
        If the parent group of the module does not have exactly one child,
        or if ``dtype`` or ``value_units`` of the spec is ``None``.
    """
    n_children = len(module_spec.parent["children"])
    if n_children != 1:
        raise InvalidNexusStructureError(
            "Group containing f144 module should have exactly one child"
        )
    # Both fields are mandatory for f144; either one missing is an error.
    has_required_fields = (
        module_spec.dtype is not None and module_spec.value_units is not None
    )
    if not has_required_fields:
        raise InvalidNexusStructureError(
            "f144 module spec should have dtype and value_units(or units)"
        )


def _validate_f144_module_values(
    key_value_dict: dict[StreamModuleKey, StreamModuleValue],
) -> None:
    """Validate every f144 entry of ``key_value_dict``.

    Entries whose key has a different ``module_type`` are left unchecked.
    Any error raised by ``_validate_f144_module_spec`` propagates.
    """
    # Lazy generator: each spec is validated as soon as its key is matched.
    f144_specs = (
        spec for key, spec in key_value_dict.items() if key.module_type == "f144"
    )
    for spec in f144_specs:
        _validate_f144_module_spec(spec)


def _validate_ev44_module_spec(
    module_spec: StreamModuleValue,
) -> None:
    """Check that an ev44 module spec can be used.

    Raises
    ------
    InvalidNexusStructureError
        If the parent group of the module does not have exactly one child.
    """
    n_children = len(module_spec.parent["children"])
    if n_children == 1:
        return
    raise InvalidNexusStructureError(
        "Group containing ev44 module should have exactly one child"
    )


def _validate_ev44_module_values(
    key_value_dict: dict[StreamModuleKey, StreamModuleValue],
) -> None:
    """Validate every ev44 entry of ``key_value_dict``.

    Entries whose key has a different ``module_type`` are left unchecked.
    Any error raised by ``_validate_ev44_module_spec`` propagates.
    """
    # Lazy generator: each spec is validated as soon as its key is matched.
    ev44_specs = (
        spec for key, spec in key_value_dict.items() if key.module_type == "ev44"
    )
    for spec in ev44_specs:
        _validate_ev44_module_spec(spec)


def collect_streaming_modules(
structure: Mapping,
) -> dict[StreamModuleKey, StreamModuleValue]:
Expand All @@ -208,8 +250,8 @@ def collect_streaming_modules(
# Modules do not have name so we remove the last element(None)
path=(parent_path := path[:-1]),
parent=cast(dict, find_nexus_structure(structure, parent_path)),
dtype=config.get("dtype"),
value_units=config.get("value_units"),
dtype=config.get("dtype", config.get("type")),
value_units=config.get("value_units", config.get("units")),
),
)
for path, node in iter_nexus_structure(structure)
Expand All @@ -222,7 +264,10 @@ def collect_streaming_modules(
)
_validate_module_configs(key_value_pairs)
_validate_module_keys(key_value_pairs)
return dict(key_value_pairs)
key_value_dict = dict(key_value_pairs)
_validate_f144_module_values(key_value_dict)
_validate_ev44_module_values(key_value_dict)
return key_value_dict


def create_dataset(
Expand Down Expand Up @@ -258,8 +303,7 @@ def _is_monitor(group: NexusGroup) -> bool:

def _initialize_ev44(module_spec: StreamModuleValue) -> NexusGroup:
parent = module_spec.parent
if len(parent['children']) != 1:
raise ValueError('Group containing ev44 module should have exactly one child')
_validate_ev44_module_spec(module_spec)
group: NexusGroup = cast(NexusGroup, parent.copy())
group['children'] = [
create_dataset(
Expand Down Expand Up @@ -346,11 +390,7 @@ def _merge_ev44(group: NexusGroup, data: DeserializedMessage) -> None:

def _initialize_f144(module_spec: StreamModuleValue) -> NexusGroup:
parent = module_spec.parent
if len(parent['children']) != 1:
raise ValueError('Group containing f144 module should have exactly one child')
if module_spec.dtype is None:
raise ValueError('f144 module spec should have dtype')

_validate_f144_module_spec(module_spec)
group: NexusGroup = cast(NexusGroup, parent.copy())
group["children"] = [
create_dataset(
Expand Down
197 changes: 157 additions & 40 deletions tests/applications/nexus_helpers_test.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
# SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) 2024 Scipp contributors (https://github.com/scipp)
import hashlib
import json
import pathlib
from collections.abc import Generator, Mapping

import numpy as np
import pytest

from beamlime.applications._nexus_helpers import (
InvalidNexusStructureError,
StreamModuleKey,
StreamModuleValue,
collect_streaming_modules,
Expand All @@ -33,6 +33,14 @@ def ymir_streaming_modules(ymir: dict) -> dict[StreamModuleKey, StreamModuleValu
return collect_streaming_modules(ymir)


def _find_attributes(group: dict[str, list[dict]], attr_name: str) -> dict:
attributes = group.get("attributes", [])
try:
return next(attr for attr in attributes if attr["name"] == attr_name)
except StopIteration as e:
raise KeyError(f"Attribute {attr_name} not found in {group}") from e


def test_iter_nexus_structure() -> None:
expected_keys = [(), ('a',), ('a', 'c'), ('b',)]
test_structure = {
Expand Down Expand Up @@ -112,21 +120,6 @@ def test_find_nexus_structure_not_found_raises() -> None:
find_nexus_structure({}, ("b0",))


def test_invalid_nexus_template_multiple_module_placeholders() -> None:
    """Merging into a group with several module placeholders must fail."""
    template_path = pathlib.Path(__file__).parent / "multiple_modules_datagroup.json"
    with template_path.open() as template_file:
        template = json.load(template_file)
    modules = collect_streaming_modules(template)

    detector_key = StreamModuleKey("ev44", "hypothetical_detector", "ymir_00")
    detector_spec = modules[detector_key]
    with pytest.raises(ValueError, match="should have exactly one child"):
        merge_message_into_nexus_store(
            module_key=detector_key,
            module_spec=detector_spec,
            nexus_store={},
            data={},  # data does not matter since it reaches the error first.
        )


def test_ymir_detector_template_checksum() -> None:
"""Test that the ymir template with detectors is updated.

Expand Down Expand Up @@ -170,15 +163,15 @@ def _is_class(partial_structure: Mapping, cls_name: str) -> bool:
)


def _is_detector(c: Mapping) -> bool:
    """Report whether ``_is_class`` classifies ``c`` as ``NXdetector``."""
    nx_class = "NXdetector"
    return _is_class(c, nx_class)


def _is_event_data(c: Mapping) -> bool:
    """Report whether ``_is_class`` classifies ``c`` as ``NXevent_data``."""
    nx_class = "NXevent_data"
    return _is_class(c, nx_class)


def _find_event_time_zerov_values(c: Mapping) -> np.ndarray:
def _get_values(c: Mapping) -> np.ndarray:
return c["config"]["values"]


def _find_event_time_zero_values(c: Mapping) -> np.ndarray:
    """Return the config values of the ``("event_time_zero",)`` entry of ``c``.

    The entry is looked up with ``find_nexus_structure``.
    """
    event_time_zero = find_nexus_structure(c, ("event_time_zero",))
    return event_time_zero["config"]["values"]


Expand Down Expand Up @@ -222,26 +215,40 @@ def test_ev44_module_merging(
assert stored_value['name'] == 'ymir_detector_events'
assert len(stored_value['children']) == 4 # 4 datasets
# Test event time zero
event_time_zero_values = _find_event_time_zerov_values(stored_value)
event_time_zero = find_nexus_structure(stored_value, ("event_time_zero",))
event_time_zero_values = _get_values(event_time_zero)
event_time_zero_unit = _find_attributes(event_time_zero, "units")
assert event_time_zero_unit["values"] == "ns"
inserted_event_time_zeros = np.concatenate(
[d["reference_time"] for d in stored_data[key]]
)
assert np.all(event_time_zero_values == inserted_event_time_zeros)
# Test event time offset
event_time_offset_values = _find_event_time_offset_values(stored_value)
event_time_offset = find_nexus_structure(stored_value, ("event_time_offset",))
event_time_offset_values = _get_values(event_time_offset)
event_time_offset_unit = _find_attributes(event_time_offset, "units")
assert event_time_offset_unit["values"] == "ns"
inserted_event_time_offsets = np.concatenate(
[d["time_of_flight"] for d in stored_data[key]]
)
assert np.all(event_time_offset_values == inserted_event_time_offsets)
# Test event id
event_id_values = _find_event_id_values(stored_value)
event_id = find_nexus_structure(stored_value, ("event_id",))
event_id_values = _get_values(event_id)
with pytest.raises(KeyError, match="units"):
# Event id should not have units
_find_attributes(event_id, "units")
inserted_event_ids = np.concatenate([d["pixel_id"] for d in stored_data[key]])
assert np.all(event_id_values == inserted_event_ids)
# Test event index
# event index values are calculated based on the length of the previous events
first_event_length = len(stored_data[key][0]["time_of_flight"])
expected_event_indices = np.array([0, first_event_length])
event_index_values = _find_event_index_values(stored_value)
event_index = find_nexus_structure(stored_value, ("event_index",))
event_index_values = _get_values(event_index)
with pytest.raises(KeyError, match="units"):
# Event index should not have units
_find_attributes(event_index, "units")
assert np.all(event_index_values == expected_event_indices)


Expand Down Expand Up @@ -308,6 +315,46 @@ def test_nxevent_data_ev44_generator_yields_frame_by_frame() -> None:
next(ev44)


def test_ev44_merge_no_children_raises() -> None:
    """Merging into an ev44 group without any child must be rejected."""
    ev44_key = StreamModuleKey("ev44", "", "")
    childless_spec = StreamModuleValue(
        path=("",),
        parent={"children": []},
        dtype="int32",
        value_units="km",
    )
    expected_message = "Group containing ev44 module should have exactly one child"
    with pytest.raises(InvalidNexusStructureError, match=expected_message):
        merge_message_into_nexus_store(
            module_key=ev44_key,
            module_spec=childless_spec,
            nexus_store={},
            data={},
        )


def test_ev44_merge_too_many_children_raises() -> None:
    """Merging into an ev44 group with more than one child must be rejected."""
    key = StreamModuleKey("ev44", "", "")
    wrong_value = StreamModuleValue(
        path=("",),
        # Two children: the group must hold exactly one, so this is invalid.
        # (An empty list here would only repeat the no-children test.)
        parent={"children": [{}, {}]},
        dtype="int32",
        value_units="km",
    )
    with pytest.raises(
        InvalidNexusStructureError,
        match="Group containing ev44 module should have exactly one child",
    ):
        merge_message_into_nexus_store(
            module_key=key,
            module_spec=wrong_value,
            nexus_store={},
            data={},  # Expected to fail on validation before data is used.
        )


@pytest.fixture()
def nexus_template_with_streamed_log(dtype):
return {
Expand Down Expand Up @@ -343,7 +390,7 @@ def f144_event_generator(shape, dtype):

@pytest.mark.parametrize('shape', [(1,), (2,), (2, 2)])
@pytest.mark.parametrize('dtype', ['int32', 'uint32', 'float32', 'float64', 'bool'])
def test_f144(nexus_template_with_streamed_log, shape, dtype):
def test_f144_merge(nexus_template_with_streamed_log, shape, dtype):
modules = collect_streaming_modules(nexus_template_with_streamed_log)
f144_modules = {
key: value for key, value in modules.items() if key.module_type == 'f144'
Expand All @@ -367,7 +414,89 @@ def test_f144(nexus_template_with_streamed_log, shape, dtype):
times = find_nexus_structure(stored_value, ('time',))
assert times['module'] == 'dataset'
assert values['config']['values'].shape[1:] == shape
assert values['attributes'][0]['values'] == 'km'
unit_attr = _find_attributes(values, 'units')
assert unit_attr['values'] == 'km'
assert unit_attr['dtype'] == 'string'


def test_f144_merge_no_children_raises():
    """Merging into an f144 group without any child must be rejected."""
    f144_key = StreamModuleKey(module_type='f144', topic='', source='')
    childless_spec = StreamModuleValue(
        path=('',),
        parent={'children': []},
        dtype='int32',
        value_units='km',
    )
    expected_message = "Group containing f144 module should have exactly one child"
    with pytest.raises(InvalidNexusStructureError, match=expected_message):
        merge_message_into_nexus_store(
            module_key=f144_key,
            module_spec=childless_spec,
            nexus_store={},
            data={},
        )


def test_f144_merge_too_many_children_raises():
    """Merging into an f144 group with two children must be rejected."""
    f144_key = StreamModuleKey(module_type='f144', topic='', source='')
    crowded_spec = StreamModuleValue(
        path=('',),
        parent={'children': [{}, {}]},
        dtype='int32',
        value_units='km',
    )
    expected_message = "Group containing f144 module should have exactly one child"
    with pytest.raises(InvalidNexusStructureError, match=expected_message):
        merge_message_into_nexus_store(
            module_key=f144_key,
            module_spec=crowded_spec,
            nexus_store={},
            data={},
        )


def test_f144_merge_missing_dtype_raises():
    """Merging with an f144 spec whose ``dtype`` is ``None`` must be rejected."""
    f144_key = StreamModuleKey(module_type='f144', topic='', source='')
    spec_without_dtype = StreamModuleValue(
        path=('',),
        parent={'children': [{}]},
        dtype=None,
        value_units='km',
    )
    expected_message = "f144 module spec should have dtype and value_units"
    with pytest.raises(InvalidNexusStructureError, match=expected_message):
        merge_message_into_nexus_store(
            module_key=f144_key,
            module_spec=spec_without_dtype,
            nexus_store={},
            data={},
        )


def test_f144_merge_missing_value_units_raises():
    """Merging with an f144 spec whose ``value_units`` is ``None`` must fail."""
    f144_key = StreamModuleKey(module_type='f144', topic='', source='')
    spec_without_units = StreamModuleValue(
        path=('',),
        parent={'children': [{}]},
        dtype='int32',
        value_units=None,
    )
    expected_message = "f144 module spec should have dtype and value_units"
    with pytest.raises(InvalidNexusStructureError, match=expected_message):
        merge_message_into_nexus_store(
            module_key=f144_key,
            module_spec=spec_without_units,
            nexus_store={},
            data={},
        )


@pytest.fixture()
Expand Down Expand Up @@ -401,7 +530,7 @@ def tdct_event_generator():
max_last_timestamp = timestamps.max()


def test_tdct(nexus_template_with_streamed_tdct: dict):
def test_tdct_merge(nexus_template_with_streamed_tdct: dict):
modules = collect_streaming_modules(nexus_template_with_streamed_tdct)
tdct_modules = {
key: value for key, value in modules.items() if key.module_type == 'tdct'
Expand All @@ -423,15 +552,3 @@ def test_tdct(nexus_template_with_streamed_tdct: dict):
assert np.issubdtype(
tdct['config']['values'].dtype, np.dtype(tdct['config']['dtype'])
)


@pytest.fixture()
def nexus_template_with_mixed_streams(
    nexus_template_with_streamed_log, nexus_template_with_streamed_tdct
):
    """Provide a template whose children are the streamed log and tdct templates."""
    children = [nexus_template_with_streamed_log, nexus_template_with_streamed_tdct]
    return {"children": children}
Loading