Skip to content

Commit

Permalink
Merge pull request #169 from scipp/tofworkflow-bug
Browse files Browse the repository at this point in the history
Change TofWorkflow API to not mirror the Pipeline api
  • Loading branch information
nvaytet authored Jan 28, 2025
2 parents 23cb90b + fa405ba commit 861eb73
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 36 deletions.
7 changes: 1 addition & 6 deletions src/ess/reduce/time_of_flight/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,7 @@
neutron time-of-arrival at the detectors.
"""

from .toa_to_tof import (
default_parameters,
resample_tof_data,
providers,
TofWorkflow,
)
from .toa_to_tof import default_parameters, resample_tof_data, providers, TofWorkflow
from .simulation import simulate_beamline
from .types import (
DistanceResolution,
Expand Down
53 changes: 23 additions & 30 deletions src/ess/reduce/time_of_flight/toa_to_tof.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,7 @@
event_time_offset coordinates to data with a time-of-flight coordinate.
"""

from __future__ import annotations

from collections.abc import Callable
from typing import Any

import numpy as np
import scipp as sc
Expand Down Expand Up @@ -461,8 +458,8 @@ def providers() -> tuple[Callable]:

class TofWorkflow:
"""
Helper class to build a time-of-flight workflow and cache the expensive part of
the computation: running the simulation and building the lookup table.
Helper class to build a time-of-flight workflow and cache the expensive part of the
computation: running the simulation and building the lookup table.
Parameters
----------
Expand All @@ -472,16 +469,17 @@ class TofWorkflow:
wavelength, and weight.
ltotal_range:
Range of total flight path lengths from the source to the detector.
This is used to create the lookup table to compute the neutron time-of-flight.
This is used to create the lookup table to compute the neutron
time-of-flight.
Note that the resulting table will extend slightly beyond this range, as the
supplied range is not necessarily a multiple of the distance resolution.
pulse_stride:
Stride of used pulses. Usually 1, but may be a small integer when
pulse-skipping.
pulse_stride_offset:
Integer offset of the first pulse in the stride (typically zero unless we are
using pulse-skipping and the events do not begin with the first pulse in the
stride).
Integer offset of the first pulse in the stride (typically zero unless we
are using pulse-skipping and the events do not begin with the first pulse in
the stride).
distance_resolution:
Resolution of the distance axis in the lookup table.
Should be a single scalar value with a unit of length.
Expand All @@ -490,8 +488,8 @@ class TofWorkflow:
Resolution of the time of arrival axis in the lookup table.
Can be an integer (number of bins) or a sc.Variable (bin width).
error_threshold:
Threshold for the variance of the projected time-of-flight above which regions
are masked.
Threshold for the variance of the projected time-of-flight above which
regions are masked.
"""

def __init__(
Expand Down Expand Up @@ -526,23 +524,18 @@ def __init__(
error_threshold or params[LookupTableRelativeErrorThreshold]
)

def __getitem__(self, key):
return self.pipeline[key]

def __setitem__(self, key, value):
self.pipeline[key] = value

def persist(self) -> None:
for t in (SimulationResults, MaskedTimeOfFlightLookupTable, FastestNeutron):
def cache_results(
self,
results=(SimulationResults, MaskedTimeOfFlightLookupTable, FastestNeutron),
) -> None:
"""
Cache a list of (usually expensive to compute) intermediate results of the
time-of-flight workflow.
Parameters
----------
results:
List of results to cache.
"""
for t in results:
self.pipeline[t] = self.pipeline.compute(t)

def compute(self, *args, **kwargs) -> Any:
return self.pipeline.compute(*args, **kwargs)

def visualize(self, *args, **kwargs) -> Any:
return self.pipeline.visualize(*args, **kwargs)

def copy(self) -> TofWorkflow:
out = self.__class__(choppers=None, facility=None, ltotal_range=None)
out.pipeline = self.pipeline.copy()
return out

0 comments on commit 861eb73

Please sign in to comment.