From deb4102a44572d1be66bd4059743a5f03eb32586 Mon Sep 17 00:00:00 2001 From: Jan-Lukas Wynen Date: Wed, 22 Jan 2025 15:24:42 +0100 Subject: [PATCH] Select run norm via override_input --- src/ess/dream/workflow.py | 4 ++-- src/ess/powder/correction.py | 37 ++++++++++++++++++++++++------------ src/ess/powder/types.py | 16 ++++++++++++++++ 3 files changed, 43 insertions(+), 14 deletions(-) diff --git a/src/ess/dream/workflow.py b/src/ess/dream/workflow.py index 97cd8256..92ef401e 100644 --- a/src/ess/dream/workflow.py +++ b/src/ess/dream/workflow.py @@ -13,7 +13,7 @@ from ess.powder import with_pixel_mask_filenames from ess.powder.correction import ( RunNormalization, - insert_run_normalization, + select_run_normalization, ) from ess.powder.types import ( AccumulatedProtonCharge, @@ -65,7 +65,7 @@ def DreamGeant4Workflow(*, run_norm: RunNormalization) -> sciline.Pipeline: wf = LoadGeant4Workflow() for provider in itertools.chain(powder_providers, _dream_providers): wf.insert(provider) - insert_run_normalization(wf, run_norm) + wf = select_run_normalization(wf, run_norm) for key, value in default_parameters().items(): wf[key] = value wf.typical_outputs = typical_outputs diff --git a/src/ess/powder/correction.py b/src/ess/powder/correction.py index 7f46e8c1..61648ab2 100644 --- a/src/ess/powder/correction.py +++ b/src/ess/powder/correction.py @@ -16,9 +16,12 @@ DataWithScatteringCoordinates, FocussedDataDspacing, FocussedDataDspacingTwoTheta, + HistogramMonitorNormalizedRunData, + IntegratedMonitorNormalizedRunData, IofDspacing, IofDspacingTwoTheta, NormalizedRunData, + ProtonChargeNormalizedRunData, RunType, SampleRun, UncertaintyBroadcastMode, @@ -32,7 +35,7 @@ def normalize_by_monitor_histogram( *, monitor: WavelengthMonitor[RunType, CaveMonitor], uncertainty_broadcast_mode: UncertaintyBroadcastMode, -) -> NormalizedRunData[RunType]: +) -> HistogramMonitorNormalizedRunData[RunType]: """Normalize detector data by a histogrammed monitor. Parameters @@ -55,7 +58,9 @@ def normalize_by_monitor_histogram( norm = broadcast_uncertainties( monitor, prototype=detector, mode=uncertainty_broadcast_mode ) - return detector.bins / sc.lookup(norm, dim="wavelength") + return HistogramMonitorNormalizedRunData[RunType]( + detector.bins / sc.lookup(norm, dim="wavelength") + ) def normalize_by_monitor_integrated( @@ -63,7 +68,7 @@ def normalize_by_monitor_integrated( *, monitor: WavelengthMonitor[RunType, CaveMonitor], uncertainty_broadcast_mode: UncertaintyBroadcastMode, -) -> NormalizedRunData[RunType]: +) -> IntegratedMonitorNormalizedRunData[RunType]: """Normalize detector data by an integrated monitor. The monitor is integrated according to @@ -116,7 +121,7 @@ def normalize_by_monitor_integrated( norm = broadcast_uncertainties( norm, prototype=detector, mode=uncertainty_broadcast_mode ) - return NormalizedRunData[RunType](detector / norm) + return IntegratedMonitorNormalizedRunData[RunType](detector / norm) def _expect_monitor_covers_range_of_detector( @@ -200,7 +205,7 @@ def normalize_by_vanadium_dspacing_and_two_theta( def normalize_by_proton_charge( data: DataWithScatteringCoordinates[RunType], proton_charge: AccumulatedProtonCharge[RunType], -) -> NormalizedRunData[RunType]: +) -> ProtonChargeNormalizedRunData[RunType]: """Normalize data by an accumulated proton charge. Parameters @@ -215,7 +220,7 @@ def normalize_by_proton_charge( : ``data / proton_charge`` """ - return NormalizedRunData[RunType](data / proton_charge) + return ProtonChargeNormalizedRunData[RunType](data / proton_charge) def merge_calibration(*, into: sc.DataArray, calibration: sc.Dataset) -> sc.DataArray: @@ -325,21 +330,29 @@ class RunNormalization(enum.Enum): proton_charge = enum.auto() -def insert_run_normalization( +def select_run_normalization( workflow: sciline.Pipeline, run_norm: RunNormalization -) -> None: - """Insert providers for a specific normalization into a workflow.""" +) -> sciline.Pipeline: + """Connect a specific normalization to the rest of a workflow.""" match run_norm: case RunNormalization.monitor_histogram: - workflow.insert(normalize_by_monitor_histogram) + return workflow.override_input( + NormalizedRunData[RunType], HistogramMonitorNormalizedRunData[RunType] + ) case RunNormalization.monitor_integrated: - workflow.insert(normalize_by_monitor_integrated) + return workflow.override_input( + NormalizedRunData[RunType], IntegratedMonitorNormalizedRunData[RunType] + ) case RunNormalization.proton_charge: - workflow.insert(normalize_by_proton_charge) + return workflow.override_input( + NormalizedRunData[RunType], ProtonChargeNormalizedRunData[RunType] + ) providers = ( normalize_by_proton_charge, + normalize_by_monitor_histogram, + normalize_by_monitor_integrated, normalize_by_vanadium_dspacing, normalize_by_vanadium_dspacing_and_two_theta, ) diff --git a/src/ess/powder/types.py b/src/ess/powder/types.py index 0132b818..c8fd989b 100644 --- a/src/ess/powder/types.py +++ b/src/ess/powder/types.py @@ -155,9 +155,25 @@ class WavelengthMonitor( class NormalizedRunData(sciline.Scope[RunType, sc.DataArray], sc.DataArray): + """Data that has been normalized by proton charge or monitor.""" + + +class ProtonChargeNormalizedRunData(sciline.Scope[RunType, sc.DataArray], sc.DataArray): """Data that has been normalized by proton charge.""" +class HistogramMonitorNormalizedRunData( + sciline.Scope[RunType, sc.DataArray], sc.DataArray +): + """Data that has been normalized by a histogrammed monitor.""" + + +class IntegratedMonitorNormalizedRunData( + sciline.Scope[RunType, sc.DataArray], sc.DataArray +): + """Data that has been normalized by an integrated monitor.""" + + PixelMaskFilename = NewType("PixelMaskFilename", str) """Filename of a pixel mask."""