diff --git a/src/ess/reduce/live/raw.py b/src/ess/reduce/live/raw.py index 7dd119e5..358ae7d4 100644 --- a/src/ess/reduce/live/raw.py +++ b/src/ess/reduce/live/raw.py @@ -110,7 +110,8 @@ def from_coords( def __call__(self, da: sc.DataArray) -> sc.DataArray: self._current += 1 coords = self._coords[self._replica_dim, self._current % self._replicas] - return sc.DataArray(da.data, coords=coords).hist(self._edges) + # If input is multi-dim we need to flatten since those dims cannot be preserved. + return sc.DataArray(da.data, coords=coords).flatten(to='_').hist(self._edges) def input_indices(self) -> sc.DataArray: """Return an array with input indices corresponding to each histogram bin.""" @@ -247,6 +248,7 @@ def __init__( self._cache = self._history.sum('window') def make_roi_filter(self) -> roi.ROIFilter: + """Return a ROI filter operating via the projection plane of the view.""" norm = 1.0 if isinstance(self._projection, Histogrammer): indices = self._projection.input_indices() diff --git a/tests/live/raw_test.py b/tests/live/raw_test.py index b34f8a05..6aa997de 100644 --- a/tests/live/raw_test.py +++ b/tests/live/raw_test.py @@ -222,3 +222,96 @@ def test_histogrammer_input_indices() -> None: assert set(indices.coords) == {'x', 'y'} assert indices.bins.size().sum().value == nx * ny * nz assert indices.sizes == resolution + + +def test_ROIFilter_from_trivial_RollingDetectorView() -> None: + detector_number = sc.array( + dims=['x', 'y'], values=[[1, 2, 3], [4, 5, 6]], unit=None + ) + view = raw.RollingDetectorView(detector_number=detector_number, window=2) + roi_filter = view.make_roi_filter() + data = detector_number.copy() + data.unit = 'counts' + flat = data.flatten(to='detector_number') + + result, scale = roi_filter.apply(data) + # ROIFilter defaults to include nothing + assert sc.identical(result, flat[0:0]) + assert sc.identical(scale, sc.zeros(dims=['detector_number'], shape=[0])) + + roi_filter.set_roi_from_intervals(sc.DataGroup(x=(1, 2))) + result, scale = roi_filter.apply(data) + assert sc.identical(result, flat[3:6]) + assert sc.identical(scale, sc.ones(dims=['detector_number'], shape=[3])) + + roi_filter.set_roi_from_intervals(sc.DataGroup(x=(1, 2), y=(1, 3))) + result, scale = roi_filter.apply(data) + assert sc.identical(result, flat[4:6]) + assert sc.identical(scale, sc.ones(dims=['detector_number'], shape=[2])) + + +def test_ROIFilter_from_RollingDetectorView_with_LogicalView() -> None: + logical_view = raw.LogicalView(select={'z': 0}) + detector_number = sc.array( + dims=['x', 'y', 'z'], values=[[[1, 2], [3, 4]], [[5, 6], [7, 8]]], unit=None + ) + view = raw.RollingDetectorView( + detector_number=detector_number, window=2, projection=logical_view + ) + roi_filter = view.make_roi_filter() + data = detector_number.copy() + data.unit = 'counts' + flat = data['z', 0].flatten(to='detector_number') + + result, scale = roi_filter.apply(data) + # ROIFilter defaults to include nothing + assert sc.identical(result, flat[0:0]) + assert sc.identical(scale, sc.zeros(dims=['detector_number'], shape=[0])) + + roi_filter.set_roi_from_intervals(sc.DataGroup(x=(1, 2))) + result, scale = roi_filter.apply(data) + assert sc.identical(result, flat[2:4]) + assert sc.identical(scale, sc.ones(dims=['detector_number'], shape=[2])) + + roi_filter.set_roi_from_intervals(sc.DataGroup(x=(1, 2), y=(1, 3))) + result, scale = roi_filter.apply(data) + assert sc.identical(result, flat[3:4]) + assert sc.identical(scale, sc.ones(dims=['detector_number'], shape=[1])) + + +def test_ROIFilter_from_RollingDetectorView_with_xy_projection() -> None: + detector_number = sc.array( + dims=['x', 'y', 'z'], values=[[[1, 2], [3, 4]], [[5, 6], [7, 8]]], unit=None + ) + nx, ny, nz = 2, 2, 2 + coords = raw.project_xy( + make_grid_cube(nx=nx, ny=ny, nz=nz, center=(0.0, 0.0, 10.0)) + ) + coords = sc.concat( + [coords.fold(dim='point', sizes=detector_number.sizes)], 'replica' + ) + + resolution = {'x': 4, 'y': 4} + histogrammer = raw.Histogrammer.from_coords(coords=coords, resolution=resolution) + view = raw.RollingDetectorView( + detector_number=detector_number, window=1, projection=histogrammer + ) + roi_filter = view.make_roi_filter() + data = detector_number.copy() + data.unit = 'counts' + flat = data.flatten(to='detector_number') + + result, scale = roi_filter.apply(data) + # ROIFilter defaults to include nothing + assert sc.identical(result, flat[0:0]) + assert sc.identical(scale, sc.zeros(dims=['detector_number'], shape=[0])) + + roi_filter.set_roi_from_intervals(sc.DataGroup(x=(0, 2))) + result, scale = roi_filter.apply(data) + assert sc.identical(scale, sc.ones(dims=['detector_number'], shape=[4])) + assert sc.identical(result, flat[:4]) + + roi_filter.set_roi_from_intervals(sc.DataGroup(x=(0, 2), y=(0, 2))) + result, scale = roi_filter.apply(data) + assert sc.identical(scale, sc.ones(dims=['detector_number'], shape=[2])) + assert sc.identical(result, flat[0:2])