Skip to content

Commit b380d96

Browse files
authored
Merge pull request #651 from scipp/bifrost-sqw
Changes to SQW writer for BIFROST
2 parents 1c17bd8 + 5797c25 commit b380d96

File tree

8 files changed

+186
-279
lines changed

8 files changed

+186
-279
lines changed

src/scippneutron/io/sqw/_build.py

Lines changed: 78 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from typing import Any, BinaryIO
1313

1414
import numpy as np
15+
import numpy.typing as npt
1516
import scipp as sc
1617

1718
from .._files import open_or_pass
@@ -89,7 +90,7 @@ def __init__(
8990
("", "main_header"): main_header,
9091
}
9192

92-
self._dnd_placeholder: _DndPlaceholder | None = None
93+
self._dnd_data: _DndPlaceholder | _DndData | None = None
9394
self._pix_wrap: _PixWrap | None = None
9495
self._instrument: SqwIXNullInstrument | None = None
9596
self._sample: SqwIXSample | None = None
@@ -121,7 +122,7 @@ def create(self, *, chunk_size: int = 8192) -> Path | None:
121122
self._pix_wrap.write(sqw_io, chunk_size=chunk_size) # type: ignore[union-attr]
122123
case SqwDataBlockType.dnd:
123124
# Type guaranteed by _serialize_data_blocks
124-
self._dnd_placeholder.write(sqw_io) # type: ignore[union-attr]
125+
self._dnd_data.write(sqw_io) # type: ignore[union-attr]
125126
case _:
126127
raise NotImplementedError(
127128
f"Unsupported data block type: {descriptor.block_type}"
@@ -131,19 +132,18 @@ def create(self, *, chunk_size: int = 8192) -> Path | None:
131132

132133
def add_pixel_data(
133134
self,
134-
data: sc.DataArray,
135+
data: npt.NDArray[np.float32],
135136
*,
136137
experiments: list[SqwIXExperiment],
137138
n_dims: int = 4,
138-
rows: tuple[str, ...] = _DEFAULT_PIX_ROWS,
139139
row_units: tuple[str | None, ...] = _DEFAULT_PIX_ROW_UNITS,
140140
) -> SqwBuilder:
141141
self._n_dims = n_dims
142142
self._data_blocks[("experiment_info", "expdata")] = SqwMultiIXExperiment(
143143
experiments
144144
)
145145

146-
self._pix_wrap = _split_pix_rows(data, rows, row_units)
146+
self._pix_wrap = _PixWrap(row_data=data, row_units=row_units)
147147
metadata = self._make_pix_metadata(self._pix_wrap)
148148
self._data_blocks[("pix", "metadata")] = metadata
149149
self._data_blocks[("", "main_header")].nfiles = len(experiments)
@@ -164,11 +164,22 @@ def _add_dnd_metadata(self, block: SqwDndMetadata) -> SqwBuilder:
164164
self._data_blocks[("data", "metadata")] = block
165165
return self
166166

167-
def add_empty_dnd_data(self, block: SqwDndMetadata) -> SqwBuilder:
167+
def add_empty_dnd_data(self, metadata: SqwDndMetadata) -> SqwBuilder:
168168
# The file must always contain a DND block
169-
builder = self._add_dnd_metadata(block)
170-
builder._dnd_placeholder = _DndPlaceholder(
171-
shape=tuple(map(int, block.axes.n_bins_all_dims)) # type: ignore[call-overload]
169+
builder = self._add_dnd_metadata(metadata)
170+
builder._dnd_data = _DndPlaceholder(
171+
shape=tuple(map(int, metadata.axes.n_bins_all_dims)) # type: ignore[call-overload]
172+
)
173+
return builder
174+
175+
def add_dnd_data(
176+
self, metadata: SqwDndMetadata, *, data: sc.Variable, counts: sc.Variable
177+
) -> SqwBuilder:
178+
builder = self._add_dnd_metadata(metadata)
179+
builder._dnd_data = _DndData(
180+
shape=tuple(map(int, metadata.axes.n_bins_all_dims)), # type: ignore[call-overload]
181+
data=data,
182+
counts=counts,
172183
)
173184
return builder
174185

@@ -195,12 +206,10 @@ def _make_pix_metadata(self, pix_wrap: _PixWrap) -> SqwPixelMetadata:
195206
data_range=np.vstack(
196207
[
197208
(
198-
sc.to_unit(row.min(), unit).value,
199-
sc.to_unit(row.max(), unit).value,
200-
)
201-
for row, unit in zip(
202-
pix_wrap.row_data, pix_wrap.row_units, strict=True
209+
pix_wrap.row_data[:, i].min().astype(np.float64),
210+
pix_wrap.row_data[:, i].max().astype(np.float64),
203211
)
212+
for i, unit in enumerate(pix_wrap.row_units)
204213
]
205214
),
206215
)
@@ -232,15 +241,7 @@ def _serialize_data_blocks(
232241
locked=False,
233242
)
234243

235-
if self._dnd_placeholder is not None:
236-
buffers[("data", "nd_data")] = None
237-
descriptors[("data", "nd_data")] = SqwDataBlockDescriptor(
238-
block_type=SqwDataBlockType.dnd,
239-
name=("data", "nd_data"),
240-
position=0,
241-
size=self._dnd_placeholder.size(),
242-
locked=False,
243-
)
244+
self._serialize_dnd_data(buffers, descriptors)
244245

245246
if self._pix_wrap is not None:
246247
buffers[("pix", "data_wrap")] = None
@@ -254,6 +255,31 @@ def _serialize_data_blocks(
254255

255256
return buffers, descriptors
256257

258+
def _serialize_dnd_data(
259+
self,
260+
buffers: dict[DataBlockName, memoryview | None],
261+
descriptors: dict[tuple[str, str], SqwDataBlockDescriptor],
262+
) -> None:
263+
match self._dnd_data:
264+
case _DndPlaceholder() as placeholder:
265+
buffers[("data", "nd_data")] = None
266+
descriptors[("data", "nd_data")] = SqwDataBlockDescriptor(
267+
block_type=SqwDataBlockType.dnd,
268+
name=("data", "nd_data"),
269+
position=0,
270+
size=placeholder.size(),
271+
locked=False,
272+
)
273+
case _DndData() as data:
274+
buffers[("data", "nd_data")] = None
275+
descriptors[("data", "nd_data")] = SqwDataBlockDescriptor(
276+
block_type=SqwDataBlockType.dnd,
277+
name=("data", "nd_data"),
278+
position=0,
279+
size=data.size(),
280+
locked=False,
281+
)
282+
257283
def _prepare_data_blocks(self) -> dict[DataBlockName, Any]:
258284
filepath, filename = self._filepath_and_name
259285
blocks = {
@@ -380,33 +406,13 @@ def _to_canonical_block_order(
380406
("pix", "data_wrap"),
381407
)
382408
blocks = dict(blocks)
383-
out = {name: block for name in order if (block := blocks.get(name)) is not None}
409+
out = {
410+
name: block for name in order if (block := blocks.pop(name, None)) is not None
411+
}
384412
out.update(blocks) # append remaining blocks if any
385413
return out
386414

387415

388-
def _split_pix_rows(
389-
data: sc.DataArray, rows: tuple[str, ...], row_units: tuple[str | None, ...]
390-
) -> _PixWrap:
391-
"""Prepare the selected pixel rows for writing."""
392-
selected = []
393-
for name in rows:
394-
if name == 'signal':
395-
selected.append(sc.values(data.data))
396-
elif name == 'error':
397-
selected.append(sc.variances(data.data))
398-
else:
399-
if data.coords.is_edges(name, data.dim):
400-
raise sc.BinEdgeError(
401-
f"Pixel data must not contain bin-edges, got edges for '{name}'."
402-
)
403-
selected.append(data.coords[name])
404-
return _PixWrap(
405-
row_data=selected,
406-
row_units=row_units,
407-
)
408-
409-
410416
@dataclasses.dataclass(kw_only=True, slots=True, frozen=True)
411417
class _DndPlaceholder:
412418
shape: tuple[int, ...]
@@ -425,34 +431,44 @@ def write(self, sqw_io: LowLevelSqw) -> None:
425431
sqw_io.write_array(np.zeros(self.shape, dtype="uint64"))
426432

427433

434+
@dataclasses.dataclass(kw_only=True, slots=True, frozen=True)
435+
class _DndData:
436+
data: sc.Variable
437+
counts: sc.Variable
438+
shape: tuple[int, ...]
439+
440+
def size(self) -> int:
441+
n_elem = int(np.prod(self.shape))
442+
return 4 + 4 * len(self.shape) + 3 * 8 * n_elem
443+
444+
def write(self, sqw_io: LowLevelSqw) -> None:
445+
sqw_io.write_u32(len(self.shape))
446+
for s in self.shape:
447+
sqw_io.write_u32(s)
448+
sqw_io.write_array(self.data.values.astype("float64", copy=False))
449+
sqw_io.write_array(np.sqrt(self.data.variances.astype("float64", copy=False)))
450+
# Strictly speaking uint64, but we don't support that in Scipp.
451+
# So we have to assume that all numbers are positive anyway.
452+
# Then we can avoid unnecessary type casts.
453+
sqw_io.write_array(self.counts.values.astype("int64", copy=False))
454+
455+
428456
@dataclasses.dataclass(kw_only=True, slots=True, frozen=True)
429457
class _PixWrap:
430-
row_data: list[sc.Variable]
458+
row_data: npt.NDArray[np.float32] # shape: (n_pixels, n_rows)
431459
row_units: tuple[str | None, ...]
432460

433461
def size(self) -> int:
434462
# *4 for f32
435463
return 4 + 8 + self.n_rows() * 4 * self.n_pixels()
436464

437465
def n_rows(self) -> int:
438-
return len(self.row_data)
466+
return self.row_data.shape[1]
439467

440468
def n_pixels(self) -> int:
441-
return len(self.row_data[0])
469+
return self.row_data.shape[0]
442470

443471
def write(self, sqw_io: LowLevelSqw, chunk_size: int) -> None:
444472
sqw_io.write_u32(self.n_rows())
445473
sqw_io.write_u64(self.n_pixels())
446-
447-
buffer = np.empty((self.n_pixels(), self.n_rows()), dtype=np.float32)
448-
remaining = self.n_pixels()
449-
for offset in range(0, self.n_rows(), chunk_size):
450-
n = min(chunk_size, remaining)
451-
remaining -= n
452-
for i_row, (row, unit) in enumerate(
453-
zip(self.row_data, self.row_units, strict=True)
454-
):
455-
buffer[:n, i_row] = sc.to_unit(
456-
row[offset : offset + chunk_size], unit, copy=False
457-
).values
458-
sqw_io.write_array(buffer[:n])
474+
sqw_io.write_array_fortran_layout(self.row_data)

src/scippneutron/io/sqw/_ir.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,5 +150,7 @@ def _serialize_field(
150150
ty=field.ty, shape=(len(field.value),) if field.value else (), data=[field]
151151
)
152152
if isinstance(field, Array):
153-
return ObjectArray(ty=field.ty, shape=field.value.shape[::-1], data=field.value)
153+
return ObjectArray(
154+
ty=field.ty, shape=field.value.shape[::-1], data=field.value.T
155+
)
154156
return ObjectArray(ty=field.ty, shape=(1,), data=[field])

src/scippneutron/io/sqw/_low_level_io.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ def read_array(
155155
flat = flat.copy()
156156
self._file.seek(self.position + count * dtype.itemsize)
157157
else:
158-
flat = np.fromfile(self._file, dtype=dtype, count=int(np.prod(shape)))
158+
flat = np.fromfile(self._file, dtype=dtype, count=count)
159159
# Invert the shape because files use column-major layout.
160160
return flat.reshape(shape[::-1])
161161

@@ -198,6 +198,13 @@ def write_chars(self, value: str) -> None:
198198
@_annotate_write_exception("array")
199199
def write_array(
200200
self, array: npt.NDArray[np.float64] | npt.NDArray[np.float32]
201+
) -> None:
202+
# Transpose to match the column-major layout of the file.
203+
self.write_array_fortran_layout(array.T)
204+
205+
@_annotate_write_exception("array_fortran_layout")
206+
def write_array_fortran_layout(
207+
self, array: npt.NDArray[np.float64] | npt.NDArray[np.float32]
201208
) -> None:
202209
out = array.astype(array.dtype.newbyteorder(self.byteorder.value), copy=False)
203210
if isinstance(self._file, BytesIO):

src/scippneutron/io/sqw/_models.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,7 @@ def _serialize_to_dict(
252252
"version": ir.F64(self.version),
253253
"name": ir.String(self.name),
254254
"target_name": ir.String(self.target_name),
255-
"frequency": ir.F64(self.frequency.value), # TODO unit
255+
"frequency": ir.F64(self.frequency.to(unit='Hz').value),
256256
}
257257

258258

@@ -343,8 +343,8 @@ def _serialize_to_dict(
343343
"emode": ir.F64(float(self.emode.value)),
344344
"en": ir.Array(en.values, ty=ir.TypeTag.f64),
345345
"psi": ir.F64(_angle_value(self.psi)),
346-
"u": ir.Array(self.u.values, ty=ir.TypeTag.f64),
347-
"v": ir.Array(self.v.values, ty=ir.TypeTag.f64),
346+
"u": _variable_to_float_array(self.u, unit="1/angstrom"),
347+
"v": _variable_to_float_array(self.v, unit="1/angstrom"),
348348
"omega": ir.F64(_angle_value(self.omega)),
349349
"dpsi": ir.F64(_angle_value(self.dpsi)),
350350
"gl": ir.F64(_angle_value(self.gl)),

src/scippneutron/io/sqw/_sqw.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -342,7 +342,7 @@ def get_vec(name: str, unit: str) -> sc.Variable | None:
342342

343343
return SqwLineProj(
344344
lattice_spacing=sc.vector(
345-
_get_struct_field(struct, "alatt").data, unit="1/angstrom"
345+
_get_struct_field(struct, "alatt").data, unit="angstrom"
346346
),
347347
lattice_angle=sc.vector(_get_struct_field(struct, "angdeg").data, unit="deg"),
348348
offset=_parse_1d_multi_unit_array(
@@ -394,8 +394,7 @@ def _parse_ix_source_2_0(struct: ir.Struct) -> SqwIXSource:
394394
return SqwIXSource(
395395
name=name,
396396
target_name=target_name,
397-
# TODO unit currently unknown
398-
frequency=sc.scalar(frequency, unit=None),
397+
frequency=sc.scalar(frequency, unit="Hz"),
399398
)
400399

401400

@@ -408,7 +407,7 @@ def _parse_ix_null_instrument_1_0(struct: ir.Struct) -> SqwIXNullInstrument:
408407
def _parse_ix_sample_0_0(struct: ir.Struct) -> SqwIXSample:
409408
name = _get_scalar_struct_field(struct, "name")
410409
lattice_spacing = sc.vector(
411-
_get_struct_field(struct, "alatt").data, unit="1/angstrom"
410+
_get_struct_field(struct, "alatt").data, unit="angstrom"
412411
)
413412
lattice_angle = sc.vector(_get_struct_field(struct, "angdeg").data, unit="deg")
414413
return SqwIXSample(
@@ -436,11 +435,18 @@ def g(n: str) -> Any:
436435
(e,) = candidate_efix
437436
efix = sc.scalar(e.value, unit="meV")
438437

438+
emode = EnergyMode(g("emode"))
439+
439440
raw_en = _get_struct_field(struct, "en").data
440-
if isinstance(raw_en, np.ndarray):
441-
en = raw_en.squeeze()
441+
raw_en = (
442+
raw_en
443+
if isinstance(raw_en, np.ndarray)
444+
else np.array([e.value for e in raw_en])
445+
)
446+
if emode == EnergyMode.indirect:
447+
en = sc.array(dims=["detector", "energy_transfer"], values=raw_en, unit="meV")
442448
else:
443-
en = [e.value for e in raw_en]
449+
en = sc.array(dims=["energy_transfer"], values=raw_en.squeeze(), unit="meV")
444450

445451
angle_unit = sc.Unit("deg" if g("angular_is_degree") else "rad")
446452

@@ -449,11 +455,11 @@ def g(n: str) -> Any:
449455
filepath=g("filepath"),
450456
run_id=int(g("run_id")) - 1,
451457
efix=efix,
452-
emode=EnergyMode(g("emode")),
453-
en=sc.array(dims=["energy_transfer"], values=en, unit="meV"),
458+
emode=emode,
459+
en=en,
454460
psi=sc.scalar(g("psi"), unit=angle_unit),
455-
u=sc.vector(_get_struct_field(struct, "u").data),
456-
v=sc.vector(_get_struct_field(struct, "v").data),
461+
u=sc.vector(_get_struct_field(struct, "u").data, unit="1/angstrom"),
462+
v=sc.vector(_get_struct_field(struct, "v").data, unit="1/angstrom"),
457463
omega=sc.scalar(g("omega"), unit=angle_unit),
458464
dpsi=sc.scalar(g("dpsi"), unit=angle_unit),
459465
gl=sc.scalar(g("gl"), unit=angle_unit),

0 commit comments

Comments
 (0)