Skip to content

Commit 8d0ef5a

Browse files
authored
Merge pull request #561 from scipp/multi-distance-frames
Fix propagation of chopper cascade frames to an array of distances
2 parents bff8428 + 148d5f3 commit 8d0ef5a

File tree

4 files changed

+82
-76
lines changed

4 files changed

+82
-76
lines changed

docs/user-guide/wfm/wfm-time-of-flight.ipynb

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -312,21 +312,16 @@
312312
"metadata": {},
313313
"outputs": [],
314314
"source": [
315-
"threshold = 0.01 # Select part of the pulse where signal is above 1% of max counts\n",
316-
"\n",
317-
"t = one_pulse.hist(time=300)\n",
318-
"tsel = t.data > threshold * t.data.max()\n",
319-
"pulse_times = sc.midpoints(t.coords[\"time\"])[tsel]\n",
320-
"\n",
321-
"w = one_pulse.hist(wavelength=300)\n",
322-
"wsel = w.data > threshold * w.data.max()\n",
323-
"pulse_wavs = sc.midpoints(w.coords[\"wavelength\"])[wsel]\n",
315+
"time_min = sc.scalar(0.0, unit='ms')\n",
316+
"time_max = sc.scalar(3.4, unit='ms')\n",
317+
"wavs_min = sc.scalar(0.01, unit='angstrom')\n",
318+
"wavs_max = sc.scalar(10.0, unit='angstrom')\n",
324319
"\n",
325320
"frames = chopper_cascade.FrameSequence.from_source_pulse(\n",
326-
" time_min=pulse_times.min(),\n",
327-
" time_max=pulse_times.max(),\n",
328-
" wavelength_min=pulse_wavs.min(),\n",
329-
" wavelength_max=pulse_wavs.max(),\n",
321+
" time_min=time_min,\n",
322+
" time_max=time_max,\n",
323+
" wavelength_min=wavs_min,\n",
324+
" wavelength_max=wavs_max,\n",
330325
")"
331326
]
332327
},
@@ -398,8 +393,8 @@
398393
"\n",
399394
"workflow[unwrap.PulsePeriod] = sc.reciprocal(ess_beamline.source.frequency)\n",
400395
"workflow[unwrap.PulseStride | None] = None\n",
401-
"workflow[unwrap.SourceTimeRange] = pulse_times.min(), pulse_times.max()\n",
402-
"workflow[unwrap.SourceWavelengthRange] = pulse_wavs.min(), pulse_wavs.max()\n",
396+
"workflow[unwrap.SourceTimeRange] = time_min, time_max\n",
397+
"workflow[unwrap.SourceWavelengthRange] = wavs_min, wavs_max\n",
403398
"workflow[unwrap.Choppers] = choppers\n",
404399
"\n",
405400
"workflow[unwrap.Ltotal] = Ltotal\n",

src/scippneutron/tof/chopper_cascade.py

Lines changed: 39 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ class Subframe:
5656
"""
5757

5858
def __init__(self, time: sc.Variable, wavelength: sc.Variable):
59-
if time.sizes != wavelength.sizes:
59+
if {dim: time.sizes.get(dim) for dim in wavelength.sizes} != wavelength.sizes:
6060
raise sc.DimensionError(
6161
f'Inconsistent dims or shape: {time.sizes} vs {wavelength.sizes}'
6262
)
@@ -106,22 +106,33 @@ def propagate_by(self, distance: sc.Variable) -> Subframe:
106106

107107
@property
108108
def start_time(self) -> sc.Variable:
109-
"""The start time of the subframe."""
110-
return self.time.min()
109+
"""The start time of the subframe for each of the distances in self.time."""
110+
# The `self.time` may have an additional dimension for distance, compared to
111+
# `self.wavelength`, and we need to keep that dimension in the output.
112+
out = self.time
113+
for dim in self.wavelength.dims:
114+
out = out.min(dim)
115+
return out
111116

112117
@property
113118
def end_time(self) -> sc.Variable:
114-
"""The end time of the subframe."""
115-
return self.time.max()
119+
"""The end time of the subframe for each of the distances in self.time."""
120+
# The `self.time` may have an additional dimension for distance, compared to
121+
# `self.wavelength`, and we need to keep that dimension in the output.
122+
out = self.time
123+
for dim in self.wavelength.dims:
124+
out = out.max(dim)
125+
return out
116126

117127
@property
118128
def start_wavelength(self) -> sc.Variable:
119-
"""The start wavelength of the subframe."""
129+
"""The start wavelength of the subframe for each of the distances in
130+
self.time"""
120131
return self.wavelength.min()
121132

122133
@property
123134
def end_wavelength(self) -> sc.Variable:
124-
"""The end wavelength of the subframe."""
135+
"""The end wavelength of the subframe for each of the distances in self.time."""
125136
return self.wavelength.max()
126137

127138

@@ -206,19 +217,14 @@ def bounds(self) -> sc.DataGroup:
206217

207218
def subbounds(self) -> sc.DataGroup:
208219
"""
209-
The bounds of the subframes, defined as the union over subframes.
210-
211-
This is not the same as the bounds of the individual subframes, but defined as
212-
the union of all subframes. Subframes that overlap in time are "merged" into a
213-
single subframe.
214-
215-
This function is to some extent experimental: It is not clear if taking the
216-
union of overlapping subframes has any utility in practice, since this may
217-
simply indicate a problem with the chopper cascade. Attempts to handle this
218-
automatically may be misguided.
220+
The bounds of the individual subframes, stored as a DataGroup.
219221
"""
220-
starts = [subframe.start_time for subframe in self.subframes]
221-
ends = [subframe.end_time for subframe in self.subframes]
222+
starts = sc.concat(
223+
[subframe.start_time for subframe in self.subframes], dim='subframe'
224+
)
225+
ends = sc.concat(
226+
[subframe.end_time for subframe in self.subframes], dim='subframe'
227+
)
222228
# Given how time-propagation and chopping works, the min wavelength is always
223229
# given by the same vertex as the min time, and the max wavelength by the same
224230
# vertex as the max time. Thus, this check should generally always pass.
@@ -228,48 +234,21 @@ def subbounds(self) -> sc.DataGroup:
228234
'Subframes must be regular, i.e., min/max time and wavelength must '
229235
'coincide.'
230236
)
231-
wav_starts = [subframe.start_wavelength for subframe in self.subframes]
232-
wav_ends = [subframe.end_wavelength for subframe in self.subframes]
233-
234-
@dataclass
235-
class Bound:
236-
start: sc.Variable
237-
end: sc.Variable
238-
wav_start: sc.Variable
239-
wav_end: sc.Variable
240-
241-
bounds = [
242-
Bound(start, end, wav_start, wav_end)
243-
for start, end, wav_start, wav_end in zip(
244-
starts, ends, wav_starts, wav_ends, strict=True
245-
)
246-
]
247-
bounds = sorted(bounds, key=lambda x: x.start)
248-
current = bounds[0]
249-
merged_bounds = []
250-
for bound in bounds[1:]:
251-
# If start is before current end, merge
252-
if bound.start <= current.end:
253-
current = Bound(
254-
current.start,
255-
max(current.end, bound.end),
256-
current.wav_start,
257-
max(current.wav_end, bound.wav_end),
258-
)
259-
else:
260-
merged_bounds.append(current)
261-
current = bound
262-
merged_bounds.append(current)
263-
time_bounds = [
264-
sc.concat([bound.start, bound.end], dim='bound') for bound in merged_bounds
265-
]
266-
wav_bounds = [
267-
sc.concat([bound.wav_start, bound.wav_end], dim='bound')
268-
for bound in merged_bounds
269-
]
237+
wav_starts = sc.concat(
238+
[subframe.start_wavelength for subframe in self.subframes], dim='subframe'
239+
)
240+
wav_ends = sc.concat(
241+
[subframe.end_wavelength for subframe in self.subframes], dim='subframe'
242+
)
243+
244+
time = sc.concat([starts, ends], dim='bound')
245+
wavelength = sc.concat([wav_starts, wav_ends], dim='bound')
246+
time_dims = list(set(time.dims) - {'subframe', 'bound'})
247+
wavelength_dims = list(set(wavelength.dims) - {'subframe', 'bound'})
248+
270249
return sc.DataGroup(
271-
time=sc.concat(time_bounds, dim='subframe'),
272-
wavelength=sc.concat(wav_bounds, dim='subframe'),
250+
time=time.transpose([*time_dims, 'subframe', 'bound']),
251+
wavelength=wavelength.transpose([*wavelength_dims, 'subframe', 'bound']),
273252
)
274253

275254

tests/tof/chopper_cascade_test.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -391,6 +391,38 @@ def test_frame_sequence_propagate_to_returns_new_sequence_with_added_propagated_
391391
assert result2[2] == result[1].propagate_to(distance * 2)
392392

393393

394+
def test_frame_sequence_propagate_to_array_of_distances(
395+
source_frame_sequence: chopper_cascade.FrameSequence,
396+
) -> None:
397+
frames = source_frame_sequence
398+
distances = sc.array(dims=['distance'], values=[1.5, 1.7], unit='m')
399+
result = frames.propagate_to(distances)
400+
assert len(frames) == 1
401+
assert len(result) == 2
402+
assert result[1] == frames[0].propagate_to(distances)
403+
subframe = result[1].subframes[0]
404+
assert subframe.start_time.sizes["distance"] == 2
405+
assert subframe.end_time.sizes["distance"] == 2
406+
407+
408+
def test_frame_sequence_propagate_to_array_of_distances_subbounds(
409+
source_frame_sequence: chopper_cascade.FrameSequence,
410+
) -> None:
411+
frames = source_frame_sequence
412+
distances = sc.array(dims=['distance'], values=[1.5, 1.7], unit='m')
413+
result = frames.propagate_to(distances)
414+
subbounds = result[-1].subbounds()
415+
subframe = result[1].subframes[0]
416+
assert_identical(subbounds['time']['subframe', 0]['bound', 0], subframe.start_time)
417+
assert_identical(subbounds['time']['subframe', 0]['bound', 1], subframe.end_time)
418+
assert_identical(
419+
subbounds['wavelength']['subframe', 0]['bound', 0], subframe.start_wavelength
420+
)
421+
assert_identical(
422+
subbounds['wavelength']['subframe', 0]['bound', 1], subframe.end_wavelength
423+
)
424+
425+
394426
def test_frame_sequence_chop_returns_new_sequence_with_added_chopped_frames(
395427
source_frame_sequence: chopper_cascade.FrameSequence,
396428
) -> None:

tests/tof/wfm_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ def test_compute_wavelengths_from_wfm(disk_choppers, npulses):
149149
# computed correctly
150150
workflow[unwrap.SourceTimeRange] = (
151151
sc.scalar(0.0, unit='ms'),
152-
sc.scalar(4.0, unit='ms'),
152+
sc.scalar(3.4, unit='ms'),
153153
)
154154
workflow[unwrap.SourceWavelengthRange] = (
155155
sc.scalar(0.01, unit='angstrom'),

0 commit comments

Comments
 (0)