Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Grocery keyset for increased safety calling Recipe #528

Merged
merged 2 commits into from
Jan 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
class CalculateResidualRequest(BaseModel):
inputWorkspace: WorkspaceName
outputWorkspace: WorkspaceName
fitPeaksDiagnostic: WorkspaceName
fitPeaksDiagnosticWorkspace: WorkspaceName

model_config = ConfigDict(
# required in order to use 'WorkspaceName'
Expand Down
7 changes: 5 additions & 2 deletions src/snapred/backend/recipe/ApplyNormalizationRecipe.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Tuple
from typing import Any, Dict, List, Set, Tuple

from snapred.backend.dao.ingredients import ApplyNormalizationIngredients as Ingredients
from snapred.backend.log.logger import snapredLogger
Expand All @@ -18,7 +18,10 @@ class ApplyNormalizationRecipe(Recipe[Ingredients]):
NUM_BINS = Config["constants.ResampleX.NumberBins"]
LOG_BINNING = True

def mandatoryInputWorkspaces(self):
def allGroceryKeys(self) -> Set[str]:
return {"inputWorkspace", "normalizationWorkspace", "backgroundWorkspace"}

def mandatoryInputWorkspaces(self) -> Set[str]:
return {"inputWorkspace"}

def chopIngredients(self, ingredients: Ingredients):
Expand Down
50 changes: 25 additions & 25 deletions src/snapred/backend/recipe/CalculateDiffCalResidualRecipe.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
from typing import Dict
from types import NoneType
from typing import Dict, Set

import numpy as np
from pydantic import BaseModel

from snapred.backend.dao.ingredients.CalculateDiffCalResidualIngredients import (
CalculateDiffCalResidualIngredients as Ingredients,
)
from snapred.backend.log.logger import snapredLogger
from snapred.backend.recipe.algorithm.Utensils import Utensils
from snapred.backend.recipe.Recipe import Recipe
Expand All @@ -18,7 +16,7 @@
outputWorkspace: str


class CalculateDiffCalResidualRecipe(Recipe[Ingredients]):
class CalculateDiffCalResidualRecipe(Recipe[None]):
def __init__(self, utensils: Utensils = None):
if utensils is None:
utensils = Utensils()
Expand All @@ -29,29 +27,31 @@
def logger(self):
return logger

def validateInputs(self, ingredients: Ingredients, groceries: Dict[str, WorkspaceName]):
super().validateInputs(ingredients, groceries)
def allGroceryKeys(self) -> Set[str]:
return {"inputWorkspace", "outputWorkspace", "fitPeaksDiagnosticWorkspace"}

def chopIngredients(self, ingredients: Ingredients) -> None:
"""Receive the ingredients from the recipe."""
self.inputWorkspaceName = ingredients.inputWorkspace
self.outputWorkspaceName = ingredients.outputWorkspace
inputGroupWorkspace = ingredients.fitPeaksDiagnosticWorkspace

fitPeaksGroupWorkspace = self.mantidSnapper.mtd[inputGroupWorkspace]
lastWorkspaceName = fitPeaksGroupWorkspace.getNames()[-1]
self.fitPeaksDiagnosticWorkSpaceName = lastWorkspaceName
def mandatoryInputWorkspaces(self) -> Set[str]:
return {"inputWorkspace", "fitPeaksDiagnosticWorkspace"}

def unbagGroceries(self):
def chopIngredients(self, ingredients: NoneType = None) -> None:
"""Receive the ingredients from the recipe."""
pass

def prep(self, ingredients: Ingredients):
def unbagGroceries(self, groceries: Dict[str, WorkspaceName]):
self.inputWorkspaceName = groceries["inputWorkspace"]
self.outputWorkspaceName = groceries["outputWorkspace"]
diagnosticWSname = groceries["fitPeaksDiagnosticWorkspace"]
diagnosticWorkspace = self.mantidSnapper.mtd[diagnosticWSname]
lastWorkspaceName = diagnosticWorkspace.getNames()[-1]
self.fitPeaksDiagnosticWorkspaceName = lastWorkspaceName

Check warning on line 46 in src/snapred/backend/recipe/CalculateDiffCalResidualRecipe.py

View check run for this annotation

Codecov / codecov/patch

src/snapred/backend/recipe/CalculateDiffCalResidualRecipe.py#L41-L46

Added lines #L41 - L46 were not covered by tests

def prep(self, ingredients: NoneType, groceries: Dict[str, WorkspaceName]):
"""
Convenience method to prepare the recipe for execution.
"""
self.validateInputs(ingredients, groceries=None)
self.validateInputs(ingredients, groceries)

Check warning on line 52 in src/snapred/backend/recipe/CalculateDiffCalResidualRecipe.py

View check run for this annotation

Codecov / codecov/patch

src/snapred/backend/recipe/CalculateDiffCalResidualRecipe.py#L52

Added line #L52 was not covered by tests
self.chopIngredients(ingredients)
self.unbagGroceries()
self.unbagGroceries(groceries)

Check warning on line 54 in src/snapred/backend/recipe/CalculateDiffCalResidualRecipe.py

View check run for this annotation

Codecov / codecov/patch

src/snapred/backend/recipe/CalculateDiffCalResidualRecipe.py#L54

Added line #L54 was not covered by tests
self.stirInputs()
self.queueAlgos()

Expand All @@ -64,14 +64,14 @@
)

# Step 2: Check for overlapping spectra and manage them
fitPeaksWorkspace = self.mantidSnapper.mtd[self.fitPeaksDiagnosticWorkSpaceName]
fitPeaksWorkspace = self.mantidSnapper.mtd[self.fitPeaksDiagnosticWorkspaceName]

Check warning on line 67 in src/snapred/backend/recipe/CalculateDiffCalResidualRecipe.py

View check run for this annotation

Codecov / codecov/patch

src/snapred/backend/recipe/CalculateDiffCalResidualRecipe.py#L67

Added line #L67 was not covered by tests
numHistograms = fitPeaksWorkspace.getNumberHistograms()
processedSpectra = []
spectrumDict = {}

for i in range(numHistograms):
spectrumId = fitPeaksWorkspace.getSpectrum(i).getSpectrumNo()
singleSpectrumName = f"{self.fitPeaksDiagnosticWorkSpaceName}_spectrum_{spectrumId}"
singleSpectrumName = f"{self.fitPeaksDiagnosticWorkspaceName}_spectrum_{spectrumId}"

Check warning on line 74 in src/snapred/backend/recipe/CalculateDiffCalResidualRecipe.py

View check run for this annotation

Codecov / codecov/patch

src/snapred/backend/recipe/CalculateDiffCalResidualRecipe.py#L74

Added line #L74 was not covered by tests

# If this spectrum number is already processed, average with existing
if spectrumId in spectrumDict:
Expand All @@ -86,7 +86,7 @@
# Extract spectrum by position
self.mantidSnapper.ExtractSingleSpectrum(
f"Extracting spectrum with SpectrumNumber {spectrumId}...",
InputWorkspace=self.fitPeaksDiagnosticWorkSpaceName,
InputWorkspace=self.fitPeaksDiagnosticWorkspaceName,
OutputWorkspace=singleSpectrumName,
WorkspaceIndex=i,
)
Expand Down Expand Up @@ -128,8 +128,8 @@
# Set the output property to the final residual workspace
self.outputWorkspace = self.mantidSnapper.mtd[self.outputWorkspaceName]

def cook(self, ingredients: Ingredients):
self.prep(ingredients)
def cook(self, ingredients: NoneType, groceries: Dict[str, WorkspaceName]): # noqa ARG002
self.prep(None, groceries)

Check warning on line 132 in src/snapred/backend/recipe/CalculateDiffCalResidualRecipe.py

View check run for this annotation

Codecov / codecov/patch

src/snapred/backend/recipe/CalculateDiffCalResidualRecipe.py#L132

Added line #L132 was not covered by tests
self.execute()
return CalculateDiffCalServing(
outputWorkspace=self.outputWorkspaceName,
Expand Down
8 changes: 7 additions & 1 deletion src/snapred/backend/recipe/EffectiveInstrumentRecipe.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Tuple
from typing import Any, Dict, List, Set, Tuple

import numpy as np

Expand All @@ -16,6 +16,12 @@

@Singleton
class EffectiveInstrumentRecipe(Recipe[Ingredients]):
def allGroceryKeys(self) -> Set[str]:
return {"inputWorkspace", "outputWorkspace"}

def mandatoryInputWorkspaces(self) -> Set[str]:
return {"inputWorkspace"}

def unbagGroceries(self, groceries: Dict[str, Any]):
self.inputWS = groceries["inputWorkspace"]
self.outputWS = groceries.get("outputWorkspace", groceries["inputWorkspace"])
Expand Down
8 changes: 7 additions & 1 deletion src/snapred/backend/recipe/GenerateFocussedVanadiumRecipe.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Tuple
from typing import Any, Dict, List, Set, Tuple

from snapred.backend.dao.ingredients import GenerateFocussedVanadiumIngredients as Ingredients
from snapred.backend.log.logger import snapredLogger
Expand All @@ -23,6 +23,12 @@ class GenerateFocussedVanadiumRecipe(Recipe[Ingredients]):

"""

def allGroceryKeys(self) -> Set[str]:
return {"inputWorkspace", "outputWorkspace"}

def mandatoryInputWorkspaces(self) -> Set[str]:
return {"inputWorkspace"}

def chopIngredients(self, ingredients: Ingredients):
self.smoothingParameter = ingredients.smoothingParameter
self.detectorPeaks = ingredients.detectorPeaks
Expand Down
30 changes: 12 additions & 18 deletions src/snapred/backend/recipe/GroupDiffCalRecipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,17 +30,6 @@ class GroupDiffCalRecipe(Recipe[Ingredients]):
NOYZE_2_MIN = Config["calibration.fitting.minSignal2Noise"]
MAX_CHI_SQ = Config["constants.GroupDiffractionCalibration.MaxChiSq"]

GROCERIES = {
# NOTE this would be better as a StrEnum, which requires python 3.11
"inputWorkspace",
"groupingWorkspace",
"maskWorkspace",
"outputWorkspace",
"diagnosticWorkspace",
"previousCalibration",
"calibrationTable",
}

def __init__(self, utensils: Utensils = None):
if utensils is None:
utensils = Utensils()
Expand All @@ -51,18 +40,23 @@ def __init__(self, utensils: Utensils = None):
def logger(self):
return logger

def mandatoryInputWorkspaces(self) -> Set[WorkspaceName]:
def allGroceryKeys(self) -> Set[str]:
return {
"inputWorkspace",
"groupingWorkspace",
"maskWorkspace",
"outputWorkspace",
"diagnosticWorkspace",
"previousCalibration",
"calibrationTable",
}

def mandatoryInputWorkspaces(self) -> Set[str]:
return {"inputWorkspace", "groupingWorkspace"}

def validateInputs(self, ingredients: Ingredients, groceries: Dict[str, WorkspaceName]):
super().validateInputs(ingredients, groceries)

# make sure no invalid keys were passed
# NOTE this is for safer refactor, but not necessary for proper functioning
diff = set(groceries.keys()).difference(self.GROCERIES)
if bool(diff):
raise RuntimeError(f"The following invalid keys were found in the input groceries: {diff}")

pixelGroupIDs = ingredients.pixelGroup.groupIDs
groupIDs = [peakList.groupID for peakList in ingredients.groupedPeakLists]
if groupIDs != pixelGroupIDs:
Expand Down
12 changes: 12 additions & 0 deletions src/snapred/backend/recipe/PixelDiffCalRecipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,18 @@ def __init__(self, utensils: Utensils = None):
def logger(self):
return logger

def allGroceryKeys(self) -> Set[str]:
return {
"inputWorkspace",
"groupingWorkspace",
"calibrationTable",
"maskWorkspace",
"previousCalibration",
# NOTE these are used only in the entire diff cal workflow
"diagnosticWorkspace",
"outputWorkspace",
}

def mandatoryInputWorkspaces(self) -> Set[str]:
return {"inputWorkspace", "groupingWorkspace"}

Expand Down
21 changes: 16 additions & 5 deletions src/snapred/backend/recipe/PreprocessReductionRecipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from snapred.backend.log.logger import snapredLogger
from snapred.backend.recipe.Recipe import Recipe
from snapred.meta.decorators.Singleton import Singleton
from snapred.meta.mantid.WorkspaceNameGenerator import WorkspaceName

logger = snapredLogger.getLogger(__name__)

Expand All @@ -19,7 +18,17 @@ def chopIngredients(self, ingredients: Ingredients):
"""
pass

def mandatoryInputWorkspaces(self) -> Set[WorkspaceName]:
def allGroceryKeys(self) -> Set[str]:
return {
"inputWorkspace",
"backgroundWorkspace",
"groupingWorkspace",
"diffcalWorkspace",
"maskWorkspace",
"outputWorkspace",
}

def mandatoryInputWorkspaces(self) -> Set[str]:
return {"inputWorkspace"}

def unbagGroceries(self, groceries: Dict[str, Any]):
Expand All @@ -46,16 +55,18 @@ def queueAlgos(self):
OutputWorkspace=self.outputWs,
)

if self.maskWs:
if self.maskWs != "":
self.mantidSnapper.MaskDetectorFlags(
"Applying pixel mask...",
MaskWorkspace=self.maskWs,
OutputWorkspace=self.outputWs,
)

if self.diffcalWs:
if self.diffcalWs != "":
self.mantidSnapper.ApplyDiffCal(
"Applying diffcal..", InstrumentWorkspace=self.outputWs, CalibrationWorkspace=self.diffcalWs
"Applying diffcal..",
InstrumentWorkspace=self.outputWs,
CalibrationWorkspace=self.diffcalWs,
)

# convert to tof if needed
Expand Down
5 changes: 4 additions & 1 deletion src/snapred/backend/recipe/ReadWorkspaceMetadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@
class ReadWorkspaceMetadata(Recipe[WorkspaceMetadata]):
TAG_PREFIX = Config["metadata.tagPrefix"]

def mandatoryInputWorkspaces(self) -> Set[WorkspaceName]:
def allGroceryKeys(self) -> Set[str]:
return {"workspace"}

def mandatoryInputWorkspaces(self) -> Set[str]:
return {"workspace"}

def chopIngredients(self, ingredients): # noqa ARG002
Expand Down
7 changes: 5 additions & 2 deletions src/snapred/backend/recipe/RebinFocussedGroupDataRecipe.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, Tuple
from typing import Any, Dict, Set, Tuple

from snapred.backend.dao.ingredients import RebinFocussedGroupDataIngredients as Ingredients
from snapred.backend.log.logger import snapredLogger
Expand All @@ -17,7 +17,10 @@ class RebinFocussedGroupDataRecipe(Recipe[Ingredients]):
NUM_BINS = Config["constants.ResampleX.NumberBins"]
LOG_BINNING = True

def mandatoryInputWorkspaces(self):
def allGroceryKeys(self) -> Set[str]:
return {"inputWorkspace"}

def mandatoryInputWorkspaces(self) -> Set[str]:
return {"inputWorkspace"}

def chopIngredients(self, ingredients: Ingredients):
Expand Down
22 changes: 18 additions & 4 deletions src/snapred/backend/recipe/Recipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,20 @@
Requires: unbagged groceries and chopped ingredients.
"""

@abstractmethod
def allGroceryKeys(self) -> Set[str]:
"""
A set of all possible keys which may be in the grocery dictionary
"""
return set()

# methods which MAY be kept as is

def mandatoryInputWorkspaces(self) -> Set[WorkspaceName]:
def mandatoryInputWorkspaces(self) -> Set[str]:
"""
A list of workspace names corresponding to mandatory inputs
A set of workspace keys corresponding to mandatory inputs
"""
return {}
return set()

Check warning on line 68 in src/snapred/backend/recipe/Recipe.py

View check run for this annotation

Codecov / codecov/patch

src/snapred/backend/recipe/Recipe.py#L68

Added line #L68 was not covered by tests

@classmethod
def Ingredients(cls, **kwargs):
Expand Down Expand Up @@ -104,7 +111,14 @@
else:
logger.info("No ingredients given, skipping ingredient validation")
pass
# ensure all of the given workspaces exist

# make sure no invalid keys were passed
if groceries is not None:
diff = set(groceries.keys()).difference(self.allGroceryKeys())
if bool(diff):
raise ValueError(f"The following invalid keys were found in the input groceries: {diff}")

# ensure all of the mandatory workspaces exist
# NOTE may need to be tweaked to ignore output workspaces...
if groceries is not None:
logger.info(f"Validating the given workspaces: {groceries.values()}")
Expand Down
Loading
Loading