Skip to content

Commit

Permalink
Test ROIFilter setup
Browse files Browse the repository at this point in the history
  • Loading branch information
SimonHeybrock committed Jan 31, 2025
1 parent a83cfcd commit 0e038c8
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 1 deletion.
4 changes: 3 additions & 1 deletion src/ess/reduce/live/raw.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,8 @@ def from_coords(
def __call__(self, da: sc.DataArray) -> sc.DataArray:
self._current += 1
coords = self._coords[self._replica_dim, self._current % self._replicas]
return sc.DataArray(da.data, coords=coords).hist(self._edges)
# If input is multi-dim we need to flatten since those dims cannot be preserved.
return sc.DataArray(da.data, coords=coords).flatten(to='_').hist(self._edges)

def input_indices(self) -> sc.DataArray:
"""Return an array with input indices corresponding to each histogram bin."""
Expand Down Expand Up @@ -247,6 +248,7 @@ def __init__(
self._cache = self._history.sum('window')

def make_roi_filter(self) -> roi.ROIFilter:
"""Return a ROI filter operating via the projection plane of the view."""
norm = 1.0
if isinstance(self._projection, Histogrammer):
indices = self._projection.input_indices()
Expand Down
93 changes: 93 additions & 0 deletions tests/live/raw_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,3 +222,96 @@ def test_histogrammer_input_indices() -> None:
assert set(indices.coords) == {'x', 'y'}
assert indices.bins.size().sum().value == nx * ny * nz
assert indices.sizes == resolution


def test_ROIFilter_from_trivial_RollingDetectorView() -> None:
detector_number = sc.array(
dims=['x', 'y'], values=[[1, 2, 3], [4, 5, 6]], unit=None
)
view = raw.RollingDetectorView(detector_number=detector_number, window=2)
roi_filter = view.make_roi_filter()
data = detector_number.copy()
data.unit = 'counts'
flat = data.flatten(to='detector_number')

result, scale = roi_filter.apply(data)
# ROIFilter defaults to include nothing
assert sc.identical(result, flat[0:0])
assert sc.identical(scale, sc.zeros(dims=['detector_number'], shape=[0]))

roi_filter.set_roi_from_intervals(sc.DataGroup(x=(1, 2)))
result, scale = roi_filter.apply(data)
assert sc.identical(result, flat[3:6])
assert sc.identical(scale, sc.ones(dims=['detector_number'], shape=[3]))

roi_filter.set_roi_from_intervals(sc.DataGroup(x=(1, 2), y=(1, 3)))
result, scale = roi_filter.apply(data)
assert sc.identical(result, flat[4:6])
assert sc.identical(scale, sc.ones(dims=['detector_number'], shape=[2]))


def test_ROIFilter_from_RollingDetectorView_with_LogicalView() -> None:
logical_view = raw.LogicalView(select={'z': 0})
detector_number = sc.array(
dims=['x', 'y', 'z'], values=[[[1, 2], [3, 4]], [[5, 6], [7, 8]]], unit=None
)
view = raw.RollingDetectorView(
detector_number=detector_number, window=2, projection=logical_view
)
roi_filter = view.make_roi_filter()
data = detector_number.copy()
data.unit = 'counts'
flat = data['z', 0].flatten(to='detector_number')

result, scale = roi_filter.apply(data)
# ROIFilter defaults to include nothing
assert sc.identical(result, flat[0:0])
assert sc.identical(scale, sc.zeros(dims=['detector_number'], shape=[0]))

roi_filter.set_roi_from_intervals(sc.DataGroup(x=(1, 2)))
result, scale = roi_filter.apply(data)
assert sc.identical(result, flat[2:4])
assert sc.identical(scale, sc.ones(dims=['detector_number'], shape=[2]))

roi_filter.set_roi_from_intervals(sc.DataGroup(x=(1, 2), y=(1, 3)))
result, scale = roi_filter.apply(data)
assert sc.identical(result, flat[3:4])
assert sc.identical(scale, sc.ones(dims=['detector_number'], shape=[1]))


def test_ROIFilter_from_RollingDetectorView_with_xy_projection() -> None:
detector_number = sc.array(
dims=['x', 'y', 'z'], values=[[[1, 2], [3, 4]], [[5, 6], [7, 8]]], unit=None
)
nx, ny, nz = 2, 2, 2
coords = raw.project_xy(
make_grid_cube(nx=nx, ny=ny, nz=nz, center=(0.0, 0.0, 10.0))
)
coords = sc.concat(
[coords.fold(dim='point', sizes=detector_number.sizes)], 'replica'
)

resolution = {'x': 4, 'y': 4}
histogrammer = raw.Histogrammer.from_coords(coords=coords, resolution=resolution)
view = raw.RollingDetectorView(
detector_number=detector_number, window=1, projection=histogrammer
)
roi_filter = view.make_roi_filter()
data = detector_number.copy()
data.unit = 'counts'
flat = data.flatten(to='detector_number')

result, scale = roi_filter.apply(data)
# ROIFilter defaults to include nothing
assert sc.identical(result, flat[0:0])
assert sc.identical(scale, sc.zeros(dims=['detector_number'], shape=[0]))

roi_filter.set_roi_from_intervals(sc.DataGroup(x=(0, 2)))
result, scale = roi_filter.apply(data)
assert sc.identical(scale, sc.ones(dims=['detector_number'], shape=[4]))
assert sc.identical(result, flat[:4])

roi_filter.set_roi_from_intervals(sc.DataGroup(x=(0, 2), y=(0, 2)))
result, scale = roi_filter.apply(data)
assert sc.identical(scale, sc.ones(dims=['detector_number'], shape=[2]))
assert sc.identical(result, flat[0:2])

0 comments on commit 0e038c8

Please sign in to comment.