Skip to content

Commit

Permalink
Update workflow to use frame unwrapping from scippneutron package
Browse files Browse the repository at this point in the history
  • Loading branch information
YooSunYoung committed Oct 31, 2023
1 parent 50715a8 commit f9593b8
Showing 1 changed file with 54 additions and 43 deletions.
97 changes: 54 additions & 43 deletions tests/prototypes/workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,63 +19,41 @@
# Constants
FirstPulseTime = NewType("FirstPulseTime", sc.Variable)
FrameUnwrappingGraph = NewType("FrameUnwrappingGraph", dict)
CoordTransformGraph = NewType("CoordTransformGraph", dict)
LtotalGraph = NewType("LtotalGraph", dict)
WavelengthGraph = NewType("WavelengthGraph", dict)

# Generated/Calculated
Events = NewType("Events", List[sc.DataArray])
MergedData = NewType("MergedData", sc.DataArray)
PixelIDEdges = NewType("PixelIDEdges", sc.Variable)
PixelGrouped = NewType("PixelGrouped", sc.DataArray)
LTotalCalculated = NewType("Transformed", sc.DataArray)
FrameUnwrapped = NewType("FrameUnwrapped", sc.DataArray)
ReducedData = NewType("ReducedData", sc.DataArray)
Histogrammed = NewType("Histogrammed", sc.DataArray)


def provide_coord_transform_graph(
frame_rate: FrameRate, first_pulse_time: FirstPulseTime
) -> CoordTransformGraph:
from scipp.constants import h, m_n

lambda_min = sc.scalar(0, unit='angstrom')
frame_period = sc.scalar(1 / frame_rate, unit='ns') # No pulse skipping
scale_factor = (m_n / h).to(
unit=sc.units.us / sc.units.angstrom**2
) # All wavelength is in angstrom unit.
def provide_Ltotal_graph() -> LtotalGraph:
c_a = sc.scalar(0.00001, unit='m')
c_b = sc.scalar(0.1, unit='m')
c_c = sc.scalar(1, unit='1e-3m^2/s')

def time_offset_pivot(tof_min: sc.Variable, frame_offset: sc.Variable):
return (frame_offset + tof_min) % frame_period

def tof_from_time_offset(
event_time_offset: sc.Variable,
time_offset_pivot: sc.Variable,
tof_min: sc.Variable,
):
shift = tof_min - time_offset_pivot
tof = sc.where(
event_time_offset >= time_offset_pivot, shift, shift + frame_period
)
tof += event_time_offset
return tof

def wavelength_from_tof(tof, L):
return (c_c * tof / L).to(unit='angstrom')

return CoordTransformGraph(
return LtotalGraph(
{
'L1': lambda pixel_id: (pixel_id * c_a) + c_b,
'L2': lambda pixel_id: (pixel_id * c_a) + c_b,
'L': lambda L1, L2: L1 + L2,
'tof_min': lambda L: (L * scale_factor * lambda_min).to(unit=sc.units.ns),
'frame_offset': lambda event_time_zero: event_time_zero - first_pulse_time,
'time_offset_pivot': time_offset_pivot,
'tof': tof_from_time_offset,
'wavelength': wavelength_from_tof,
'Ltotal': lambda L1, L2: L1 + L2,
}
)


def provide_wavelength_graph() -> WavelengthGraph:
c_c = sc.scalar(1, unit='1e-3m^2/s')

return WavelengthGraph(
{'wavelength': lambda tof, Ltotal: (c_c * tof / Ltotal).to(unit='angstrom')}
)


def merge_data_list(da_list: Events) -> MergedData:
return MergedData(sc.concat(da_list, dim='event'))

Expand All @@ -88,17 +66,47 @@ def bin_pixel_id(da: MergedData, pixel_bin_coord: PixelIDEdges) -> PixelGrouped:
return PixelGrouped(da.group(pixel_bin_coord))


def transform_coords(
def calculate_ltotal(
binned: PixelGrouped,
graph: CoordTransformGraph,
graph: LtotalGraph,
) -> LTotalCalculated:
da = binned.transform_coords(['Ltotal'], graph=graph)
if not isinstance(da, sc.DataArray):
raise TypeError

return LTotalCalculated(da)


def unwrap_frames(
da: LTotalCalculated, frame_rate: FrameRate, first_pulse_time: FirstPulseTime
) -> FrameUnwrapped:
from scippneutron.tof import unwrap_frames

return FrameUnwrapped(
unwrap_frames(
da,
pulse_period=sc.scalar(1 / frame_rate, unit='ns'), # No pulse skipping
lambda_min=sc.scalar(5.0, unit='angstrom'),
frame_offset=first_pulse_time.to(unit='ms'),
first_pulse_time=first_pulse_time,
)
)


def calculate_wavelength(
unwrapped: FrameUnwrapped, graph: WavelengthGraph
) -> ReducedData:
return ReducedData(binned.transform_coords(['L', 'wavelength'], graph=graph))
da = unwrapped.transform_coords(['wavelength'], graph=graph)
if not isinstance(da, sc.DataArray):
raise TypeError

return ReducedData(da)


def histogram_result(
bin_size: HistogramBinSize, reduced_data: ReducedData
) -> Histogrammed:
return reduced_data.hist(wavelength=bin_size).sum('L')
return reduced_data.hist(wavelength=bin_size)


Workflow = NewType("Workflow", Factory)
Expand All @@ -108,10 +116,13 @@ def provide_workflow(
num_pixels: NumPixels, histogram_binsize: HistogramBinSize, frame_rate: FrameRate
) -> Workflow:
providers = ProviderGroup(
SingletonProvider(provide_wavelength_graph),
SingletonProvider(provide_Ltotal_graph),
merge_data_list,
bin_pixel_id,
provide_coord_transform_graph,
transform_coords,
calculate_ltotal,
calculate_wavelength,
unwrap_frames,
histogram_result,
provide_pixel_id_bin_edges,
)
Expand Down

0 comments on commit f9593b8

Please sign in to comment.