From 6ee82a57ec8f1818fd656de85ae91f5fdbc91a54 Mon Sep 17 00:00:00 2001
From: Michael Walsh <68125095+walshmm@users.noreply.github.com>
Date: Mon, 16 Dec 2024 09:49:55 -0500
Subject: [PATCH 1/7] Ewm8253 normcal index being misread (#509)

* refactor indexer to not allow possibility of misinterpreting None

failing 40 test wip commit

all unit tests passing

changes found during manual testing of workflows

fix remaining unit tests

why does git think this file still uses VERSION_DEFAULT

add back VERSION_DEFAULT because of arbitrary ci failure?

fix versioning integration checks

migrate missed spots for datafactoryservice

correct version for integration test workspace name

fix wng?

disallow overwriting default calibration entry

disallow overwriting default calibration

add debug info for ci, because I cannot easily recreate this locally

fix integration tests, added neat summary in case they fail

move reduction completion summary note to correct spot

update tests and gitmodules

extend test coverage

catch a couple more lines up that test coverage

change refspec?

respond to reece's comments

* fix rebase

* fix failing unit tests

* hubris

* address the last comments I missed

* fix reduction with no initial state + default state

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* make sure the continueAnywayFlags get reset after workflow completion.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 .gitmodules | 1 +
 pyproject.toml | 3 +-
 .../backend/dao/indexing/Versioning.py | 108 +++----
 .../dao/normalization/NormalizationRecord.py | 12 +-
 .../request/CreateCalibrationRecordRequest.py | 3 +-
 .../dao/request/CreateIndexEntryRequest.py | 6 +-
 .../request/DiffractionCalibrationRequest.py | 4 +-
 .../dao/request/FarmFreshIngredients.py | 17 +-
 .../request/LoadCalibrationRecordRequest.py | 9 +
 .../backend/dao/request/ReductionRequest.py | 9 +-
 src/snapred/backend/data/DataExportService.py | 10 +-
 .../backend/data/DataFactoryService.py | 43 +--
 src/snapred/backend/data/GroceryService.py | 13 +-
 src/snapred/backend/data/Indexer.py | 197 ++++++------
 src/snapred/backend/data/LocalDataService.py | 79 +++--
 .../backend/service/CalibrationService.py | 24 +-
 .../backend/service/NormalizationService.py | 18 +-
 .../backend/service/ReductionService.py | 16 +-
 src/snapred/backend/service/SousChef.py | 2 +
 .../meta/mantid/WorkspaceNameGenerator.py | 13 +-
 src/snapred/resources/application.yml | 2 +-
 src/snapred/ui/workflow/DiffCalWorkflow.py | 6 +-
 .../ui/workflow/NormalizationWorkflow.py | 6 +-
 src/snapred/ui/workflow/ReductionWorkflow.py | 7 +-
 .../ui/workflow/WorkflowImplementer.py | 1 +
 tests/data/snapred-data | 2 +-
 tests/integration/test_versions_in_order.py | 44 +--
 .../test_workflow_panels_happy_path.py | 139 +++++++-
 .../ReductionRecord_20240614T130420.json | 1 +
 .../unit/backend/dao/test_VersionedObject.py | 145 ++++-----
 .../backend/data/test_DataFactoryService.py | 40 +--
 .../unit/backend/data/test_GroceryService.py | 31 ++
 tests/unit/backend/data/test_Indexer.py | 300 ++++++++----------
 .../backend/data/test_LocalDataService.py | 132 +++++---
 .../service/test_CalibrationService.py | 24 +-
 .../service/test_NormalizationService.py | 15 +-
 .../backend/service/test_ReductionService.py | 18 +-
 tests/unit/backend/service/test_SousChef.py | 16 +
 .../mantid/test_WorkspaceNameGenerator.py | 13 +-
 tests/unit/meta/test_Decorators.py | 1 +
 .../test_InitializeStatePresenter.py | 9 +-
.../ui/view/test_CalibrationAssessmentView.py | 4 + .../ui/view/test_InitializeStateCheckView.py | 3 + tests/unit/ui/widget/test_Workflow.py | 1 + .../unit/ui/workflow/test_DiffCalWorkflow.py | 3 + .../ui/workflow/test_WorkflowImplementer.py | 3 + tests/util_tests/test_state_helpers.py | 4 +- 47 files changed, 864 insertions(+), 693 deletions(-) create mode 100644 src/snapred/backend/dao/request/LoadCalibrationRecordRequest.py diff --git a/.gitmodules b/.gitmodules index f5c64879f..dcb998528 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +1,4 @@ [submodule "tests/data/snapred-data"] path = tests/data/snapred-data url = https://code.ornl.gov/sns-hfir-scse/infrastructure/test-data/snapred-data.git + branch = main diff --git a/pyproject.toml b/pyproject.toml index 8d77718f3..b04de4928 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,7 +55,8 @@ markers = [ "integration: mark a test as an integration test", "mount_snap: mark a test as using /SNS/SNAP/ data mount", "golden_data(*, path=None, short_name=None, date=None): mark golden data to use with a test", - "datarepo: mark a test as using snapred-data repo" + "datarepo: mark a test as using snapred-data repo", + "ui: mark a test as a UI test", ] # The following will be overridden by the commandline option "-m integration" addopts = "-m 'not (integration or datarepo)'" diff --git a/src/snapred/backend/dao/indexing/Versioning.py b/src/snapred/backend/dao/indexing/Versioning.py index b29b10463..7657baa00 100644 --- a/src/snapred/backend/dao/indexing/Versioning.py +++ b/src/snapred/backend/dao/indexing/Versioning.py @@ -1,79 +1,51 @@ -from typing import Any, Optional - -from numpy import integer -from pydantic import BaseModel, computed_field, field_serializer +from pydantic import BaseModel, ConfigDict, field_validator from snapred.meta.Config import Config +from snapred.meta.Enum import StrEnum VERSION_START = Config["version.start"] VERSION_NONE_NAME = Config["version.friendlyName.error"] -VERSION_DEFAULT_NAME = Config["version.friendlyName.default"] -# VERSION_DEFAULT is a SNAPRed-internal "magic" integer: -# * it is implicitely set during `Config` initialization. -VERSION_DEFAULT = Config["version.default"] + +class VersionState(StrEnum): + DEFAULT = Config["version.friendlyName.default"] + LATEST = "latest" + NEXT = "next" + + +# I'm not sure why ci is failing without this, it doesn't seem to be used anywhere +VERSION_DEFAULT = VersionState.DEFAULT + +Version = int | VersionState class VersionedObject(BaseModel): # Base class for all versioned DAO - # In pydantic, a leading double underscore activates - # the `__pydantic_private__` feature, which limits the visibility - # of the attribute to the interior scope of its own class. 
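# [Editor's sketch, not part of the patch] The `__version` attribute being
# removed below relied on pydantic's private-attribute mechanism described in
# the comment above. A minimal illustration, assuming pydantic v2 semantics
# and a hypothetical model `Hidden`:
#
#   from typing import Optional
#   from pydantic import BaseModel
#
#   class Hidden(BaseModel):
#       __stash: Optional[int] = None   # name-mangled, kept in __pydantic_private__
#
#   "stash" in Hidden().model_dump()    # -> False: private attrs never serialize
#
# This invisibility is exactly why the old code needed the custom
# `dict()`/`field_serializer` plumbing that this patch deletes.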
- __version: Optional[int] = None - - @classmethod - def parseVersion(cls, version, *, exclude_none: bool = False, exclude_default: bool = False) -> int | None: - v: int | None - # handle two special cases - if (not exclude_none) and (version is None or version == VERSION_NONE_NAME): - v = None - elif (not exclude_default) and (version == VERSION_DEFAULT_NAME or version == VERSION_DEFAULT): - v = VERSION_DEFAULT - # parse integers - elif isinstance(version, int | integer): - if int(version) >= VERSION_START: - v = int(version) - else: - raise ValueError(f"Given version {version} is smaller than start version {VERSION_START}") - # otherwise this is an error - else: - raise ValueError(f"Cannot initialize version as {version}") - return v - - @classmethod - def writeVersion(cls, version) -> int | str: - v: int | str - if version is None: - v = VERSION_NONE_NAME - elif version == VERSION_DEFAULT: - v = VERSION_DEFAULT_NAME - elif isinstance(version, int | integer): - v = int(version) - else: - raise ValueError("Version is not valid") - return v - - def __init__(self, **kwargs): - version = kwargs.pop("version", None) - super().__init__(**kwargs) - self.__version = self.parseVersion(version) - - @field_serializer("version", check_fields=False, when_used="json") - def write_user_defaults(self, value: Any): # noqa ARG002 - return self.writeVersion(self.__version) - - # NOTE some serialization still using the dict() method - def dict(self, **kwargs): - res = super().dict(**kwargs) - res["version"] = self.writeVersion(res["version"]) - return res - - @computed_field - @property - def version(self) -> int: - return self.__version - - @version.setter - def version(self, v): - self.__version = self.parseVersion(v, exclude_none=True) + version: Version + + @field_validator("version", mode="before") + def validate_version(cls, value: Version) -> Version: + if value in VersionState.values(): + return value + + if isinstance(value, str): + raise ValueError(f"Version must be an int or {VersionState.values()}") + + if value is None: + raise ValueError("Version must be specified") + + if value < VERSION_START: + raise ValueError(f"Version must be greater than {VERSION_START}") + + return value + + # NOTE: This approach was taken because 'field_serializer' was checking against the + # INITIAL value of version for some reason. This is a workaround. 
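# [Editor's sketch, not part of the patch] Expected behavior of the validator
# above and the JSON guard below, assuming VERSION_START == 0 per the
# application.yml change later in this patch:
#
#   VersionedObject(version=3).version                   # -> 3
#   VersionedObject(version=VersionState.NEXT).version   # -> 'next' (use_enum_values)
#   VersionedObject(version=-1)                          # raises: below VERSION_START
#   VersionedObject(version="bogus")                     # raises: not an int or known state
#   VersionedObject(version=VersionState.NEXT).model_dump_json()
#                                                        # raises: symbolic versions must be
#                                                        # flattened to an int before writing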
+ # + def model_dump_json(self, *args, **kwargs): # noqa ARG002 + if self.version in VersionState.values(): + raise ValueError(f"Version {self.version} must be flattened to an int before writing to JSON") + return super().model_dump_json(*args, **kwargs) + + model_config = ConfigDict(use_enum_values=True, validate_assignment=True) diff --git a/src/snapred/backend/dao/normalization/NormalizationRecord.py b/src/snapred/backend/dao/normalization/NormalizationRecord.py index b1aa040fd..a2e6db446 100644 --- a/src/snapred/backend/dao/normalization/NormalizationRecord.py +++ b/src/snapred/backend/dao/normalization/NormalizationRecord.py @@ -1,9 +1,9 @@ from typing import Any, List -from pydantic import field_serializer, field_validator +from pydantic import field_validator from snapred.backend.dao.indexing.Record import Record -from snapred.backend.dao.indexing.Versioning import VERSION_DEFAULT, VersionedObject +from snapred.backend.dao.indexing.Versioning import VERSION_START, Version, VersionedObject from snapred.backend.dao.Limit import Limit from snapred.backend.dao.normalization.Normalization import Normalization from snapred.meta.mantid.WorkspaceNameGenerator import WorkspaceName @@ -31,7 +31,7 @@ class NormalizationRecord(Record, extra="ignore"): smoothingParameter: float # detectorPeaks: List[DetectorPeak] # TODO: need to save this for reference during reduction workspaceNames: List[WorkspaceName] = [] - calibrationVersionUsed: int = VERSION_DEFAULT + calibrationVersionUsed: Version = VERSION_START crystalDBounds: Limit[float] normalizationCalibrantSamplePath: str @@ -44,8 +44,4 @@ def validate_backgroundRunNumber(cls, v: Any) -> Any: @field_validator("calibrationVersionUsed", mode="before") @classmethod def version_is_integer(cls, v: Any) -> Any: - return VersionedObject.parseVersion(v) - - @field_serializer("calibrationVersionUsed", when_used="json") - def write_user_defaults(self, value: Any): # noqa ARG002 - return VersionedObject.writeVersion(self.calibrationVersionUsed) + return VersionedObject(version=v).version diff --git a/src/snapred/backend/dao/request/CreateCalibrationRecordRequest.py b/src/snapred/backend/dao/request/CreateCalibrationRecordRequest.py index c91e94b3f..8060c6538 100644 --- a/src/snapred/backend/dao/request/CreateCalibrationRecordRequest.py +++ b/src/snapred/backend/dao/request/CreateCalibrationRecordRequest.py @@ -5,6 +5,7 @@ from snapred.backend.dao.calibration.Calibration import Calibration from snapred.backend.dao.calibration.FocusGroupMetric import FocusGroupMetric from snapred.backend.dao.CrystallographicInfo import CrystallographicInfo +from snapred.backend.dao.indexing.Versioning import Version, VersionState from snapred.backend.dao.state.PixelGroup import PixelGroup from snapred.meta.mantid.WorkspaceNameGenerator import WorkspaceName, WorkspaceType @@ -18,7 +19,7 @@ class CreateCalibrationRecordRequest(BaseModel, extra="forbid"): runNumber: str useLiteMode: bool - version: Optional[int] = None + version: Version = VersionState.NEXT calculationParameters: Calibration crystalInfo: CrystallographicInfo pixelGroups: Optional[List[PixelGroup]] = None diff --git a/src/snapred/backend/dao/request/CreateIndexEntryRequest.py b/src/snapred/backend/dao/request/CreateIndexEntryRequest.py index a2be99075..3292afed4 100644 --- a/src/snapred/backend/dao/request/CreateIndexEntryRequest.py +++ b/src/snapred/backend/dao/request/CreateIndexEntryRequest.py @@ -1,7 +1,7 @@ -from typing import Optional - from pydantic import BaseModel +from 
snapred.backend.dao.indexing.Versioning import Version, VersionState + class CreateIndexEntryRequest(BaseModel): """ @@ -10,7 +10,7 @@ class CreateIndexEntryRequest(BaseModel): runNumber: str useLiteMode: bool - version: Optional[int] = None + version: Version = VersionState.NEXT comments: str author: str appliesTo: str diff --git a/src/snapred/backend/dao/request/DiffractionCalibrationRequest.py b/src/snapred/backend/dao/request/DiffractionCalibrationRequest.py index d773f13da..17b0ce27a 100644 --- a/src/snapred/backend/dao/request/DiffractionCalibrationRequest.py +++ b/src/snapred/backend/dao/request/DiffractionCalibrationRequest.py @@ -3,7 +3,7 @@ from pydantic import BaseModel, field_validator -from snapred.backend.dao.indexing.Versioning import VERSION_DEFAULT +from snapred.backend.dao.indexing.Versioning import Version, VersionState from snapred.backend.dao.Limit import Pair from snapred.backend.dao.state.FocusGroup import FocusGroup from snapred.backend.error.ContinueWarning import ContinueWarning @@ -40,7 +40,7 @@ class DiffractionCalibrationRequest(BaseModel, extra="forbid"): continueFlags: Optional[ContinueWarning.Type] = ContinueWarning.Type.UNSET - startingTableVersion: int = VERSION_DEFAULT + startingTableVersion: Version = VersionState.DEFAULT @field_validator("fwhmMultipliers", mode="before") @classmethod diff --git a/src/snapred/backend/dao/request/FarmFreshIngredients.py b/src/snapred/backend/dao/request/FarmFreshIngredients.py index d61f43add..e0fecbea2 100644 --- a/src/snapred/backend/dao/request/FarmFreshIngredients.py +++ b/src/snapred/backend/dao/request/FarmFreshIngredients.py @@ -2,13 +2,14 @@ from pydantic import BaseModel, ConfigDict, ValidationError, field_validator, model_validator +from snapred.backend.dao.indexing.Versioning import Version, VersionState from snapred.backend.dao.Limit import Limit, Pair from snapred.backend.dao.state import FocusGroup from snapred.meta.Config import Config from snapred.meta.mantid.AllowedPeakTypes import SymmetricPeakEnum # TODO: this declaration is duplicated in `ReductionRequest`. 
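# [Editor's sketch, not part of the patch] After the change just below, each
# field of `Versions` carries either a concrete int or a symbolic VersionState:
#
#   v = Versions(VersionState.LATEST, 7)
#   v.calibration     # -> VersionState.LATEST (== 'latest'), resolved at lookup time
#   v.normalization   # -> 7, a pinned version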
-Versions = NamedTuple("Versions", [("calibration", Optional[int]), ("normalization", Optional[int])]) +Versions = NamedTuple("Versions", [("calibration", Version), ("normalization", Version)]) class FarmFreshIngredients(BaseModel): @@ -21,19 +22,19 @@ class FarmFreshIngredients(BaseModel): runNumber: str - versions: Versions = Versions(None, None) + versions: Versions = Versions(VersionState.LATEST, VersionState.LATEST) # allow 'versions' to be accessed as a single version, # or, to be accessed ambiguously @property - def version(self) -> Optional[int]: + def version(self) -> Version: if self.versions.calibration is not None and self.versions.normalization is not None: raise RuntimeError("accessing 'version' property when 'versions' is non-singular") return self.versions[0] @version.setter - def version(self, v: Optional[int]): - self.versions = (v, None) + def version(self, v: Version): + self.versions = Versions(v, v) useLiteMode: bool @@ -83,6 +84,10 @@ def focusGroup(self, fg: FocusGroup): def validate_versions(cls, v) -> Versions: if not isinstance(v, Versions): v = Versions(v) + if v.calibration is None: + raise ValueError("Calibration version must be specified") + if v.normalization is None: + raise ValueError("Normalization version must be specified") return v @field_validator("crystalDBounds", mode="before") @@ -119,4 +124,4 @@ def validate_focusGroups(cls, v: Any): del v["focusGroup"] return v - model_config = ConfigDict(extra="forbid") + model_config = ConfigDict(extra="forbid", validate_assignment=True) diff --git a/src/snapred/backend/dao/request/LoadCalibrationRecordRequest.py b/src/snapred/backend/dao/request/LoadCalibrationRecordRequest.py new file mode 100644 index 000000000..ab7f0c180 --- /dev/null +++ b/src/snapred/backend/dao/request/LoadCalibrationRecordRequest.py @@ -0,0 +1,9 @@ +from pydantic import BaseModel + +from snapred.backend.dao import RunConfig +from snapred.backend.dao.indexing.Versioning import Version + + +class LoadCalibrationRecordRequest(BaseModel): + runConfig: RunConfig + version: Version diff --git a/src/snapred/backend/dao/request/ReductionRequest.py b/src/snapred/backend/dao/request/ReductionRequest.py index c8d44822c..4ca60f704 100644 --- a/src/snapred/backend/dao/request/ReductionRequest.py +++ b/src/snapred/backend/dao/request/ReductionRequest.py @@ -2,13 +2,14 @@ from pydantic import BaseModel, ConfigDict, field_validator +from snapred.backend.dao.indexing.Versioning import Version, VersionState from snapred.backend.dao.ingredients import ArtificialNormalizationIngredients from snapred.backend.dao.state.FocusGroup import FocusGroup from snapred.backend.error.ContinueWarning import ContinueWarning from snapred.meta.mantid.WorkspaceNameGenerator import WorkspaceName from snapred.meta.mantid.WorkspaceNameGenerator import WorkspaceNameGenerator as wng -Versions = NamedTuple("Versions", [("calibration", Optional[int]), ("normalization", Optional[int])]) +Versions = NamedTuple("Versions", [("calibration", Version), ("normalization", Version)]) class ReductionRequest(BaseModel): @@ -22,7 +23,7 @@ class ReductionRequest(BaseModel): # Calibration and normalization versions: # `None` => - versions: Versions = Versions(None, None) + versions: Versions = Versions(VersionState.LATEST, VersionState.LATEST) pixelMasks: List[WorkspaceName] = [] artificialNormalizationIngredients: Optional[ArtificialNormalizationIngredients] = None @@ -37,6 +38,10 @@ def validate_versions(cls, v) -> Versions: if not isinstance(v, Tuple): raise ValueError("'versions' must be a 
tuple: '(, )'") v = Versions(v) + if v.calibration is None: + raise ValueError("Calibration version must be specified") + if v.normalization is None: + raise ValueError("Normalization version must be specified") return v model_config = ConfigDict( diff --git a/src/snapred/backend/data/DataExportService.py b/src/snapred/backend/data/DataExportService.py index ae53f0eef..01bd654d3 100644 --- a/src/snapred/backend/data/DataExportService.py +++ b/src/snapred/backend/data/DataExportService.py @@ -1,6 +1,6 @@ import time from pathlib import Path -from typing import Tuple +from typing import Optional, Tuple from pydantic import validate_call @@ -64,11 +64,11 @@ def exportCalibrationIndexEntry(self, entry: IndexEntry): """ self.dataService.writeCalibrationIndexEntry(entry) - def exportCalibrationRecord(self, record: CalibrationRecord): + def exportCalibrationRecord(self, record: CalibrationRecord, entry: Optional[IndexEntry] = None): """ Record must have correct version set. """ - self.dataService.writeCalibrationRecord(record) + self.dataService.writeCalibrationRecord(record, entry) def exportCalibrationWorkspaces(self, record: CalibrationRecord): """ @@ -94,11 +94,11 @@ def exportNormalizationIndexEntry(self, entry: IndexEntry): """ self.dataService.writeNormalizationIndexEntry(entry) - def exportNormalizationRecord(self, record: NormalizationRecord): + def exportNormalizationRecord(self, record: NormalizationRecord, entry: Optional[IndexEntry] = None): """ Record must have correct version set. """ - self.dataService.writeNormalizationRecord(record) + self.dataService.writeNormalizationRecord(record, entry) def exportNormalizationWorkspaces(self, record: NormalizationRecord): """ diff --git a/src/snapred/backend/data/DataFactoryService.py b/src/snapred/backend/data/DataFactoryService.py index dcf28a8f5..776c046d0 100644 --- a/src/snapred/backend/data/DataFactoryService.py +++ b/src/snapred/backend/data/DataFactoryService.py @@ -1,11 +1,12 @@ import os from pathlib import Path -from typing import Dict, List, Optional +from typing import Dict, List from pydantic import validate_call from snapred.backend.dao.calibration.CalibrationRecord import CalibrationRecord from snapred.backend.dao.indexing.IndexEntry import IndexEntry +from snapred.backend.dao.indexing.Versioning import Version, VersionState from snapred.backend.dao.InstrumentConfig import InstrumentConfig from snapred.backend.dao.normalization.NormalizationRecord import NormalizationRecord from snapred.backend.dao.reduction import ReductionRecord @@ -81,7 +82,7 @@ def calibrationExists(self, runId: str, useLiteMode: bool): return self.lookupService.calibrationExists(runId, useLiteMode) @validate_call - def getCalibrationDataPath(self, runId: str, useLiteMode: bool, version: int): + def getCalibrationDataPath(self, runId: str, useLiteMode: bool, version: Version): return self.lookupService.calibrationIndexer(runId, useLiteMode).versionPath(version) def checkCalibrationStateExists(self, runId: str): @@ -102,28 +103,22 @@ def getCalibrationIndex(self, runId: str, useLiteMode: bool) -> List[IndexEntry] return self.lookupService.calibrationIndexer(runId, useLiteMode).getIndex() @validate_call - def getCalibrationRecord(self, runId: str, useLiteMode: bool, version: Optional[int] = None) -> CalibrationRecord: + def getCalibrationRecord( + self, runId: str, useLiteMode: bool, version: Version = VersionState.LATEST + ) -> CalibrationRecord: """ If no version is passed, will use the latest version applicable to runId """ return 
self.lookupService.readCalibrationRecord(runId, useLiteMode, version) @validate_call - def getCalibrationDataWorkspace(self, runId: str, useLiteMode: bool, version: int, name: str): + def getCalibrationDataWorkspace(self, runId: str, useLiteMode: bool, version: Version, name: str): path = self.lookupService.calibrationIndexer(runId, useLiteMode).versionPath(version) return self.groceryService.fetchWorkspace(os.path.join(path, name) + ".nxs", name) @validate_call - def getThisOrCurrentCalibrationVersion(self, runId: str, useLiteMode: bool, version: Optional[int] = None): - return self.lookupService.calibrationIndexer(runId, useLiteMode).thisOrCurrentVersion(version) - - @validate_call - def getThisOrNextCalibrationVersion(self, runId: str, useLiteMode: bool, version: Optional[int] = None): - return self.lookupService.calibrationIndexer(runId, useLiteMode).thisOrNextVersion(version) - - @validate_call - def getThisOrLatestCalibrationVersion(self, runId: str, useLiteMode: bool, version: Optional[int] = None): - return self.lookupService.calibrationIndexer(runId, useLiteMode).thisOrLatestApplicableVersion(runId, version) + def getLatestApplicableCalibrationVersion(self, runId: str, useLiteMode: bool): + return self.lookupService.calibrationIndexer(runId, useLiteMode).latestApplicableVersion(runId) ##### NORMALIZATION METHODS ##### @@ -131,7 +126,7 @@ def normalizationExists(self, runId: str, useLiteMode: bool): return self.lookupService.normalizationExists(runId, useLiteMode) @validate_call - def getNormalizationDataPath(self, runId: str, useLiteMode: bool, version: int): + def getNormalizationDataPath(self, runId: str, useLiteMode: bool, version: Version): return self.lookupService.normalizationIndexer(runId, useLiteMode).versionPath(version) def createNormalizationIndexEntry(self, request: NormalizationExportRequest) -> IndexEntry: @@ -149,28 +144,22 @@ def getNormalizationIndex(self, runId: str, useLiteMode: bool) -> List[IndexEntr return self.lookupService.normalizationIndexer(runId, useLiteMode).getIndex() @validate_call - def getNormalizationRecord(self, runId: str, useLiteMode: bool, version: Optional[int] = None): + def getNormalizationRecord( + self, runId: str, useLiteMode: bool, version: Version = VersionState.LATEST + ) -> NormalizationRecord: """ If no version is passed, will use the latest version applicable to runId """ return self.lookupService.readNormalizationRecord(runId, useLiteMode, version) @validate_call - def getNormalizationDataWorkspace(self, runId: str, useLiteMode: bool, version: int, name: str): + def getNormalizationDataWorkspace(self, runId: str, useLiteMode: bool, version: Version, name: str): path = self.getNormalizationDataPath(runId, useLiteMode, version) return self.groceryService.fetchWorkspace(os.path.join(path, name) + ".nxs", name) @validate_call - def getThisOrCurrentNormalizationVersion(self, runId: str, useLiteMode: bool, version: Optional[int] = None): - return self.lookupService.normalizationIndexer(runId, useLiteMode).thisOrCurrentVersion(version) - - @validate_call - def getThisOrNextNormalizationVersion(self, runId: str, useLiteMode: bool, version: Optional[int] = None): - return self.lookupService.normalizationIndexer(runId, useLiteMode).thisOrNextVersion(version) - - @validate_call - def getThisOrLatestNormalizationVersion(self, runId: str, useLiteMode: bool, version: Optional[int] = None): - return self.lookupService.normalizationIndexer(runId, useLiteMode).thisOrLatestApplicableVersion(runId, version) + def 
getLatestApplicableNormalizationVersion(self, runId: str, useLiteMode: bool): + return self.lookupService.normalizationIndexer(runId, useLiteMode).latestApplicableVersion(runId) ##### REDUCTION METHODS ##### diff --git a/src/snapred/backend/data/GroceryService.py b/src/snapred/backend/data/GroceryService.py index 7279b7799..258f83f58 100644 --- a/src/snapred/backend/data/GroceryService.py +++ b/src/snapred/backend/data/GroceryService.py @@ -14,7 +14,7 @@ ) from pydantic import validate_call -from snapred.backend.dao.indexing.Versioning import VERSION_DEFAULT +from snapred.backend.dao.indexing.Versioning import VERSION_START, Version, VersionState from snapred.backend.dao.ingredients import GroceryListItem from snapred.backend.dao.state import DetectorState from snapred.backend.dao.WorkspaceMetadata import WorkspaceMetadata @@ -236,9 +236,10 @@ def _createDiffcalTableFilepathFromWsName( calibrationDataPath = self._getCalibrationDataPath(runNumber, useLiteMode, version) expectedWsName = self.createDiffcalTableWorkspaceName(runNumber, useLiteMode, version) if wsName != expectedWsName: + record = self.dataService.calibrationIndexer(runNumber, useLiteMode).readRecord(version) raise ValueError( f"Workspace name {wsName} does not match the expected diffcal table workspace name for run {runNumber}", - f"(i.e. {expectedWsName})", + f"(i.e. {expectedWsName}), debug info: {record.model_dump_json(indent=4)}, path: {calibrationDataPath}", ) return str(calibrationDataPath / (wsName + self.diffcalTableFileExtension)) @@ -333,14 +334,14 @@ def createDiffcalTableWorkspaceName( self, runNumber: str, useLiteMode: bool, # noqa: ARG002 - version: Optional[int], + version: Optional[Version], ) -> WorkspaceName: """ - NOTE: This method will IGNORE runNumber if the provided version is VERSION_DEFAULT + NOTE: This method will IGNORE runNumber if the provided version is VersionState.DEFAULT """ wsName = wng.diffCalTable().runNumber(runNumber).version(version).build() - if version == VERSION_DEFAULT: - wsName = wsName = wng.diffCalTable().runNumber("default").version(VERSION_DEFAULT).build() + if version in [VersionState.DEFAULT, VERSION_START]: + wsName = wng.diffCalTable().runNumber("default").version(VersionState.DEFAULT).build() return wsName @validate_call diff --git a/src/snapred/backend/data/Indexer.py b/src/snapred/backend/data/Indexer.py index e6db9478c..aa908c235 100644 --- a/src/snapred/backend/data/Indexer.py +++ b/src/snapred/backend/data/Indexer.py @@ -1,7 +1,7 @@ import os import sys from pathlib import Path -from typing import Dict, List, Optional +from typing import Dict, List from pydantic import validate_call @@ -10,12 +10,7 @@ from snapred.backend.dao.indexing.CalculationParameters import CalculationParameters from snapred.backend.dao.indexing.IndexEntry import IndexEntry from snapred.backend.dao.indexing.Record import Record -from snapred.backend.dao.indexing.Versioning import ( - VERSION_DEFAULT, - VERSION_DEFAULT_NAME, - VERSION_START, - VersionedObject, -) +from snapred.backend.dao.indexing.Versioning import VERSION_START, Version, VersionedObject, VersionState from snapred.backend.dao.normalization.Normalization import Normalization from snapred.backend.dao.normalization.NormalizationRecord import NormalizationRecord from snapred.backend.dao.reduction.ReductionRecord import ReductionRecord @@ -101,18 +96,20 @@ def readDirectoryList(self): if os.path.isdir(fname): version = str(fname).split("_")[-1] # Warning: order matters here: - # check VERSION_DEFAULT_NAME _before_ the `isdigit` check. 
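# [Editor's note, not part of the patch] Why the ordering matters: a friendly
# default name need not be numeric, so the `isdigit` branch alone would never
# catch it. Sketch, using a hypothetical directory suffix "default":
#
#   "v_default".split("_")[-1].isdigit()   # -> False: must be name-checked first
#   "v_0004".split("_")[-1].isdigit()      # -> True: then parsed as int("0004") == 4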
- if str(version) == str(VERSION_DEFAULT_NAME): - version = VERSION_DEFAULT + # check VersionState.DEFAULT _before_ the `isdigit` check. + if version in [VersionState.DEFAULT, self.defaultVersion()]: + version = self.defaultVersion() elif version.isdigit(): version = int(version) + else: + logger.warning(f"Invalid version in directory: {version}") + continue versions.add(version) return versions def reconcileIndexToFiles(self): self.dirVersions = self.readDirectoryList() indexVersions = set(self.index.keys()) - # if a directory has no entry in the index, warn missingEntries = self.dirVersions.difference(indexVersions) if len(missingEntries) > 0: @@ -129,7 +126,12 @@ def reconcileIndexToFiles(self): f"The following records were expected, but not available on disk: {missingRecords}" ) else: - logger.warn(f"The following records were expected, but not available on disk: {missingRecords}") + logger.warning( + ( + f"The following records were expected, but not available on disk: {missingRecords}", + "\n Please contact your IS or CIS about these missing records.", + ) + ) # take the set of versions common to both commonVersions = self.dirVersions & indexVersions @@ -145,7 +147,7 @@ def defaultVersion(self) -> int: """ The version number to use for default states. """ - return VERSION_DEFAULT + return VERSION_START def currentVersion(self) -> int: """ @@ -176,8 +178,8 @@ def latestApplicableVersion(self, runNumber: str) -> int: elif len(relevantEntries) == 1: version = relevantEntries[0].version else: - if VERSION_DEFAULT in self.index: - relevantEntries.remove(self.index[VERSION_DEFAULT]) + if self.defaultVersion() in self.index: + relevantEntries.remove(self.index[self.defaultVersion()]) version = relevantEntries[-1].version return version @@ -186,65 +188,29 @@ def nextVersion(self) -> int: A new version number to use for saving calibration records. 
""" - version = None - # if the index and directories are in sync, the next version is one past them - if set(self.index.keys()) == self.dirVersions: - # remove the default version - dirVersions = [x for x in self.dirVersions if x != VERSION_DEFAULT] - # if nothing is left, the next is the start - if len(dirVersions) == 0: - version = VERSION_START - # otherwise, the next is max version + 1 - else: - version = max(dirVersions) + 1 - # if the index and directory are out of sync, find the largest in both sets - else: - # get the elements particular to each set -- the max of these is the next version - indexSet = set(self.index.keys()) - diffAB = indexSet.difference(self.dirVersions) - diffBA = self.dirVersions.difference(indexSet) - # if diffAB is nullset, diffBA has one more member -- that is next - if diffAB == set(): - version = list(diffBA)[0] - # if diffBA is nullset, diffAB has one more member -- that is next - elif diffBA == set(): - version = list(diffAB)[0] - # otherwise find the max of both differences and return that - else: - indexVersion = max(diffAB) - directoryVersion = max(diffBA) - version = max(indexVersion, directoryVersion) - - return version - - @validate_call - def thisOrCurrentVersion(self, version: Optional[int]): - if self.isValidVersion(version): - return version - else: - return self.currentVersion() - - @validate_call - def thisOrNextVersion(self, version: Optional[int]): - if self.isValidVersion(version): - return version - else: - return self.nextVersion() + if set(self.index.keys()) != self.dirVersions: + self.reconcileIndexToFiles() - @validate_call - def thisOrLatestApplicableVersion(self, runNumber: str, version: Optional[int]): - if self.isValidVersion(version) and self._isApplicableEntry(self.index[version], runNumber): - return version + if self.currentVersion() is None: + return self.defaultVersion() else: - return self.latestApplicableVersion(runNumber) + return self.currentVersion() + 1 - def isValidVersion(self, version): + def validateVersion(self, version): try: - VersionedObject.parseVersion(version, exclude_none=True) + VersionedObject(version=version) return True except ValueError: - return False + # This error would only ever result from a software bug. + # Saving/Loading/Refering to erroneous "current" versions just serves to obfuscate the error. + raise ValueError( + ( + f"The indexer has encountered an invalid version: {version}.", + "This is a software error. 
Please report this to your IS or CIS", + "so it may be patched.", + ) + ) ## VERSION COMPARISON METHODS ## @@ -277,24 +243,20 @@ def indexPath(self): """ return self.rootDirectory / f"{self.indexerType}Index.json" - def recordPath(self, version: Optional[int] = None): + def recordPath(self, version: int): """ Path to a specific version of a calculation record """ return self.versionPath(version) / f"{self.indexerType}Record.json" - def parametersPath(self, version: Optional[int] = None): + def parametersPath(self, version: int): """ Path to a specific version of calculation parameters """ return self.versionPath(version) / f"{self.indexerType}Parameters.json" - @validate_call - def versionPath(self, version: Optional[int] = None) -> Path: - if version is None: - version = VERSION_START - else: - version = self.thisOrCurrentVersion(version) + def versionPath(self, version: int) -> Path: + self.validateVersion(version) return self.rootDirectory / wnvf.pathVersion(version) def currentPath(self) -> Path: @@ -304,14 +266,14 @@ def currentPath(self) -> Path: """ return self.versionPath(self.currentVersion()) - def latestApplicablePath(self, runNumber: str) -> Path: + def getLatestApplicablePath(self, runNumber: str) -> Path: return self.versionPath(self.latestApplicableVersion(runNumber)) ## INDEX MANIPULATION METHODS ## def createIndexEntry(self, *, version, **other_arguments): return IndexEntry( - version=self.thisOrNextVersion(version), + version=self._flattenVersion(version), **other_arguments, ) @@ -321,7 +283,7 @@ def getIndex(self) -> List[IndexEntry]: # remove the default version, if it exists res = self.index.copy() - res.pop(VERSION_DEFAULT, None) + res.pop(self.defaultVersion(), None) return list(res.values()) def readIndex(self) -> Dict[int, IndexEntry]: @@ -342,9 +304,7 @@ def addIndexEntry(self, entry: IndexEntry): Will save at the version on the index entry. If the version is invalid, will throw an error and refuse to save. """ - if not self.isValidVersion(entry.version): - raise RuntimeError(f"Invalid version {entry.version} on index entry. Save failed.") - + entry.version = self._flattenVersion(entry.version) self.index[entry.version] = entry self.writeIndex() @@ -352,58 +312,107 @@ def addIndexEntry(self, entry: IndexEntry): def createRecord(self, *, version, **other_arguments): record = RECORD_TYPE[self.indexerType]( - version=self.thisOrNextVersion(version), + version=self._flattenVersion(version), **other_arguments, ) record.calculationParameters.version = record.version return record - def _determineRecordType(self, version: Optional[int] = None): - version = self.thisOrCurrentVersion(version) + def _determineRecordType(self, version: int): recordType = None - if version == VERSION_DEFAULT: + if version == self.defaultVersion(): recordType = DEFAULT_RECORD_TYPE.get(self.indexerType, None) if recordType is None: recordType = RECORD_TYPE[self.indexerType] return recordType - def readRecord(self, version: Optional[int] = None) -> Record: + def readRecord(self, version: int) -> Record: """ If no version given, defaults to current version """ - version = self.thisOrCurrentVersion(version) filePath = self.recordPath(version) record = None if filePath.exists(): record = parse_file_as(self._determineRecordType(version), filePath) + else: + raise FileNotFoundError( + f"No record found at {filePath} for version {version}, latest version is {self.currentVersion()}" + ) return record + def _flattenVersion(self, version: Version): + """ + Converts a version to an int. 
+ This should only ever be used on write, + converting VersionState to a version that doesnt exist. + i.e. next, or default before state initialization. + """ + flattenedVersion = None + if version == VersionState.DEFAULT: + flattenedVersion = self.defaultVersion() + elif version == VersionState.NEXT: + flattenedVersion = self.nextVersion() + elif isinstance(version, int): + flattenedVersion = version + else: + acceptableVersionShorthands = [VersionState.DEFAULT, VersionState.NEXT] + raise ValueError(f"Version must be an int or {[acceptableVersionShorthands]}, not {version}") + + if flattenedVersion is None: + raise ValueError( + f"No available versions found during lookup using: " + f"v={version}, index={self.index}, dir={self.dirVersions}" + ) + return flattenedVersion + + def versionExists(self, version: Version): + return self._flattenVersion(version) in self.index + + def writeNewVersion(self, record: Record, entry: IndexEntry): + """ + Coupled write of a record and an index entry. + As required for new records. + """ + if self.versionExists(record.version): + raise ValueError(f"Version {record.version} already exists in index, please write a new version.") + + if entry.appliesTo is None: + entry.appliesTo = ">=" + record.runNumber + + self.addIndexEntry(entry) + # make sure they flatten to the same value. + record.version = entry.version + self.writeRecord(record) + def writeRecord(self, record: Record): """ Will save at the version on the record. If the version is invalid, will throw an error and refuse to save. """ - if not self.isValidVersion(record.version): - raise RuntimeError(f"Invalid version {record.version} on record. Save failed.") + record.version = self._flattenVersion(record.version) + + if not self.versionExists(record.version): + raise ValueError(f"Version {record.version} not found in index, please write an index entry first.") filePath = self.recordPath(record.version) filePath.parent.mkdir(parents=True, exist_ok=True) + write_model_pretty(record, filePath) + self.dirVersions.add(record.version) ## STATE PARAMETER READ / WRITE METHODS ## def createParameters(self, *, version, **other_arguments) -> CalculationParameters: return PARAMS_TYPE[self.indexerType]( - version=self.thisOrNextVersion(version), + version=self._flattenVersion(version), **other_arguments, ) - def readParameters(self, version: Optional[int] = None) -> CalculationParameters: + def readParameters(self, version: Version) -> CalculationParameters: """ If no version given, defaults to current version """ - version = self.thisOrCurrentVersion(version) filePath = self.parametersPath(version) parameters = None if filePath.exists(): @@ -419,9 +428,7 @@ def writeParameters(self, parameters: CalculationParameters): Will save at the version on the calculation parameters. If the version is invalid, will throw an error and refuse to save. """ - if not self.isValidVersion(parameters.version): - raise RuntimeError(f"Invalid version {parameters.version} on calculation parameters. 
Save failed.") - + parameters.version = self._flattenVersion(parameters.version) parametersPath = self.parametersPath(parameters.version) if parametersPath.exists(): logger.warn(f"Overwriting {self.indexerType} parameters at {parametersPath}") diff --git a/src/snapred/backend/data/LocalDataService.py b/src/snapred/backend/data/LocalDataService.py index 2de734071..4ef85944a 100644 --- a/src/snapred/backend/data/LocalDataService.py +++ b/src/snapred/backend/data/LocalDataService.py @@ -27,7 +27,7 @@ ) from snapred.backend.dao.calibration import Calibration, CalibrationDefaultRecord, CalibrationRecord from snapred.backend.dao.indexing.IndexEntry import IndexEntry -from snapred.backend.dao.indexing.Versioning import VERSION_DEFAULT +from snapred.backend.dao.indexing.Versioning import Version, VersionState from snapred.backend.dao.Limit import Limit, Pair from snapred.backend.dao.normalization import Normalization, NormalizationRecord from snapred.backend.dao.reduction import ReductionRecord @@ -142,7 +142,9 @@ def _readInstrumentParameters(self) -> Dict[str, Any]: raise _createFileNotFoundError("Instrument configuration file", self.instrumentConfigPath) from e def readStateConfig(self, runId: str, useLiteMode: bool) -> StateConfig: - diffCalibration = self.calibrationIndexer(runId, useLiteMode).readParameters() + indexer = self.calibrationIndexer(runId, useLiteMode) + version = indexer.latestApplicableVersion(runId) + diffCalibration = indexer.readParameters(version) stateId = str(diffCalibration.instrumentState.id) # Read the grouping-schema map associated with this `StateConfig`. @@ -481,28 +483,37 @@ def writeNormalizationIndexEntry(self, entry: IndexEntry): def createNormalizationIndexEntry(self, request: CreateIndexEntryRequest) -> IndexEntry: indexer = self.normalizationIndexer(request.runNumber, request.useLiteMode) - return indexer.createIndexEntry(**request.model_dump()) + entryParams = request.model_dump() + entryParams["version"] = entryParams.get("version", VersionState.NEXT) + if entryParams["version"] is None: + entryParams["version"] = VersionState.NEXT + return indexer.createIndexEntry(**entryParams) def createNormalizationRecord(self, request: CreateNormalizationRecordRequest) -> NormalizationRecord: indexer = self.normalizationIndexer(request.runNumber, request.useLiteMode) return indexer.createRecord(**request.model_dump()) def normalizationExists(self, runId: str, useLiteMode: bool) -> bool: - version = self.normalizationIndexer(runId, useLiteMode).currentVersion() + version = self.normalizationIndexer(runId, useLiteMode).latestApplicableVersion(runId) return version is not None @validate_call - def readNormalizationRecord(self, runId: str, useLiteMode: bool, version: Optional[int] = None): + def readNormalizationRecord(self, runId: str, useLiteMode: bool, version: Version = VersionState.LATEST): """ Will return a normalization record for the given version. If no version given, will choose the latest applicable version from the index. 
""" indexer = self.normalizationIndexer(runId, useLiteMode) - if version is None: + if version is VersionState.LATEST: version = indexer.latestApplicableVersion(runId) - return indexer.readRecord(version) + record = None + if version is not None: + record = indexer.readRecord(version) + logger.info(indexer.index) + logger.info(f"latest applicable version: {version} for runId: {runId} ") + return record - def writeNormalizationRecord(self, record: NormalizationRecord): + def writeNormalizationRecord(self, record: NormalizationRecord, entry: Optional[IndexEntry] = None): """ Persists a `NormalizationRecord` to either a new version folder, or overwrites a specific version. Record must be set with correct version. @@ -511,7 +522,10 @@ def writeNormalizationRecord(self, record: NormalizationRecord): indexer = self.normalizationIndexer(record.runNumber, record.useLiteMode) # write the record to file - indexer.writeRecord(record) + if entry is None: + indexer.writeRecord(record) + else: + indexer.writeNewVersion(record, entry) # separately write the normalization state indexer.writeParameters(record.calculationParameters) @@ -542,22 +556,25 @@ def createCalibrationRecord(self, request: CreateCalibrationRecordRequest) -> Ca return indexer.createRecord(**request.model_dump()) def calibrationExists(self, runId: str, useLiteMode: bool) -> bool: - version = self.calibrationIndexer(runId, useLiteMode).currentVersion() + version = self.calibrationIndexer(runId, useLiteMode).latestApplicableVersion(runId) return version is not None @validate_call - def readCalibrationRecord(self, runId: str, useLiteMode: bool, version: Optional[int] = None): + def readCalibrationRecord(self, runId: str, useLiteMode: bool, version: Version = VersionState.LATEST): """ Will return a calibration record for the given version. If no version given, will choose the latest applicable version from the index. """ indexer = self.calibrationIndexer(runId, useLiteMode) - if version is None: - # NOTE Indexer.readRecord defaults to currentVersion + if version is VersionState.LATEST: version = indexer.latestApplicableVersion(runId) - return indexer.readRecord(version) + record = None + if version is not None: + record = indexer.readRecord(version) + + return record - def writeCalibrationRecord(self, record: CalibrationRecord): + def writeCalibrationRecord(self, record: CalibrationRecord, entry: Optional[IndexEntry] = None): """ Persists a `CalibrationRecord` to either a new version folder, or overwrite a specific version. Record must be set with correct version. 
@@ -565,8 +582,12 @@ def writeCalibrationRecord(self, record: CalibrationRecord): """ indexer = self.calibrationIndexer(record.runNumber, record.useLiteMode) - # write record to file - indexer.writeRecord(record) + if entry is None: + # write record to file + indexer.writeRecord(record) + else: + # write record to file + indexer.writeNewVersion(record, entry) # separately write the calibration state indexer.writeParameters(record.calculationParameters) @@ -808,7 +829,7 @@ def readCifFilePath(self, sampleId: str): ##### READ / WRITE STATE METHODS ##### @validate_call - def readCalibrationState(self, runId: str, useLiteMode: bool, version: Optional[int] = None): + def readCalibrationState(self, runId: str, useLiteMode: bool, version: Optional[Version] = VersionState.LATEST): if not self.calibrationExists(runId, useLiteMode): if self._hasWritePermissionsCalibrationStateRoot(): raise RecoverableException.stateUninitialized(runId, useLiteMode) @@ -820,16 +841,19 @@ def readCalibrationState(self, runId: str, useLiteMode: bool, version: Optional[ indexer = self.calibrationIndexer(runId, useLiteMode) # NOTE if we prefer latest version in index, uncomment below - # if version is None: - # version = indexer.latestApplicableVersion(runId) - return indexer.readParameters(version) + parameters = None + if version is VersionState.LATEST: + version = indexer.latestApplicableVersion(runId) + if version is not None: + parameters = indexer.readParameters(version) + + return parameters @validate_call - def readNormalizationState(self, runId: str, useLiteMode: bool, version: Optional[int] = None): + def readNormalizationState(self, runId: str, useLiteMode: bool, version: Version = VersionState.LATEST): indexer = self.normalizationIndexer(runId, useLiteMode) - # NOTE if we prefer latest version in index, uncomment below - # if version is None: - # version = indexer.latestApplicableVersion(runId) + if version is VersionState.LATEST: + version = indexer.latestApplicableVersion(runId) return indexer.readParameters(version) def writeCalibrationState(self, calibration: Calibration): @@ -977,9 +1001,9 @@ def initializeState(self, runId: str, useLiteMode: bool, name: str = None): self._prepareStateRoot(stateId) # now save default versions of files in both lite and native resolution directories - version = VERSION_DEFAULT for liteMode in [True, False]: indexer = self.calibrationIndexer(runId, liteMode) + version = indexer.defaultVersion() calibration = indexer.createParameters( instrumentState=instrumentState, name=name, @@ -1010,9 +1034,8 @@ def initializeState(self, runId: str, useLiteMode: bool, name: str = None): comments="The default configuration when loading StateConfig if none other is found", ) # write the calibration state - indexer.writeRecord(record) + indexer.writeNewVersion(record, entry) indexer.writeParameters(record.calculationParameters) - indexer.addIndexEntry(entry) # write the default diffcal table self._writeDefaultDiffCalTable(runId, liteMode) diff --git a/src/snapred/backend/service/CalibrationService.py b/src/snapred/backend/service/CalibrationService.py index dd23c4e05..f4c75489f 100644 --- a/src/snapred/backend/service/CalibrationService.py +++ b/src/snapred/backend/service/CalibrationService.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List import pydantic @@ -9,6 +9,7 @@ FocusGroupMetric, ) from snapred.backend.dao.indexing import IndexEntry +from snapred.backend.dao.indexing.Versioning import VERSION_START, 
VersionState from snapred.backend.dao.ingredients import ( CalculateDiffCalResidualIngredients, CalibrationMetricsWorkspaceIngredients, @@ -29,6 +30,7 @@ FocusSpectraRequest, HasStateRequest, InitializeStateRequest, + LoadCalibrationRecordRequest, MatchRunsRequest, SimpleDiffCalRequest, ) @@ -134,6 +136,12 @@ def prepDiffractionCalibrationIngredients( @FromString def fetchDiffractionCalibrationGroceries(self, request: DiffractionCalibrationRequest) -> Dict[str, str]: # groceries + + # TODO: It would be nice for groceryclerk to be smart enough to flatten versions + # However I will save that scope for another time + if request.startingTableVersion == VersionState.DEFAULT: + request.startingTableVersion = VERSION_START + self.groceryClerk.name("inputWorkspace").neutron(request.runNumber).useLiteMode(request.useLiteMode).add() self.groceryClerk.name("groupingWorkspace").fromRun(request.runNumber).grouping( request.focusGroup.name @@ -304,6 +312,9 @@ def save(self, request: CalibrationExportRequest): entry = self.dataFactoryService.createCalibrationIndexEntry(request.createIndexEntryRequest) record = self.dataFactoryService.createCalibrationRecord(request.createRecordRequest) version = entry.version + if self.dataFactoryService.calibrationExists(entry.runNumber, entry.useLiteMode): + if version == VERSION_START: + raise RuntimeError("Overwriting the default calibration is not allowed.") # Rebuild the workspace names to strip any "iteration" number: savedWorkspaces = {} @@ -356,15 +367,16 @@ def save(self, request: CalibrationExportRequest): record.workspaces = savedWorkspaces # save the objects at the indicated version - self.dataExportService.exportCalibrationRecord(record) + self.dataExportService.exportCalibrationRecord(record, entry) self.dataExportService.exportCalibrationWorkspaces(record) - self.saveCalibrationToIndex(entry) @FromString - def load(self, run: RunConfig, version: Optional[int] = None): + def load(self, request: LoadCalibrationRecordRequest): """ If no version is given, will load the latest version applicable to the run number """ + run = request.runConfig + version = request.version return self.dataFactoryService.getCalibrationRecord(run.runNumber, run.useLiteMode, version) def matchRunsToCalibrationVersions(self, request: MatchRunsRequest) -> Dict[str, Any]: @@ -373,7 +385,7 @@ def matchRunsToCalibrationVersions(self, request: MatchRunsRequest) -> Dict[str, """ response = {} for runNumber in request.runNumbers: - response[runNumber] = self.dataFactoryService.getThisOrLatestCalibrationVersion( + response[runNumber] = self.dataFactoryService.getLatestApplicableCalibrationVersion( runNumber, request.useLiteMode ) return response @@ -395,8 +407,6 @@ def saveCalibrationToIndex(self, entry: IndexEntry): """ if entry.appliesTo is None: entry.appliesTo = ">=" + entry.runNumber - if entry.timestamp is None: - entry.timestamp = self.dataExportService.getUniqueTimestamp() logger.info("Saving calibration index entry for Run Number {}".format(entry.runNumber)) self.dataExportService.exportCalibrationIndexEntry(entry) diff --git a/src/snapred/backend/service/NormalizationService.py b/src/snapred/backend/service/NormalizationService.py index 4021e266e..9c20ff0a1 100644 --- a/src/snapred/backend/service/NormalizationService.py +++ b/src/snapred/backend/service/NormalizationService.py @@ -3,7 +3,7 @@ from snapred.backend.dao import Limit from snapred.backend.dao.indexing.IndexEntry import IndexEntry -from snapred.backend.dao.indexing.Versioning import VERSION_DEFAULT +from 
snapred.backend.dao.indexing.Versioning import VersionState from snapred.backend.dao.ingredients import ( GroceryListItem, ) @@ -119,7 +119,9 @@ def normalization(self, request: NormalizationRequest): request.useLiteMode ).add() - calVersion = self.dataFactoryService.getThisOrLatestCalibrationVersion(request.runNumber, request.useLiteMode) + calVersion = self.dataFactoryService.getLatestApplicableCalibrationVersion( + request.runNumber, request.useLiteMode + ) calRunNumber = self.dataFactoryService.getCalibrationRecord( request.runNumber, request.useLiteMode, calVersion ).runNumber @@ -206,8 +208,10 @@ def _validateDiffractionCalibrationExists(self, request: NormalizationRequest): self.sousChef.verifyCalibrationExists(request.runNumber, request.useLiteMode) - calVersion = self.dataFactoryService.getThisOrLatestCalibrationVersion(request.runNumber, request.useLiteMode) - if calVersion == VERSION_DEFAULT: + calVersion = self.dataFactoryService.getLatestApplicableCalibrationVersion( + request.runNumber, request.useLiteMode + ) + if calVersion is None: continueFlags = continueFlags | ContinueWarning.Type.DEFAULT_DIFFRACTION_CALIBRATION if request.continueFlags: @@ -278,6 +282,7 @@ def normalizationAssessment(self, request: NormalizationRequest): normalizationCalibrantSamplePath=request.calibrantSamplePath, calculationParameters=normalization, crystalDBounds=request.crystalDBounds, + version=VersionState.NEXT, ) return self.dataFactoryService.createNormalizationRecord(createRecordRequest) @@ -300,9 +305,8 @@ def saveNormalization(self, request: NormalizationExportRequest): record.workspaceNames = savedWorkspaces # save the objects at the indicated version - self.dataExportService.exportNormalizationRecord(record) + self.dataExportService.exportNormalizationRecord(record, entry) self.dataExportService.exportNormalizationWorkspaces(record) - self.saveNormalizationToIndex(entry) def saveNormalizationToIndex(self, entry: IndexEntry): """ @@ -387,7 +391,7 @@ def matchRunsToNormalizationVersions(self, request: MatchRunsRequest) -> Dict[st """ response = {} for runNumber in request.runNumbers: - response[runNumber] = self.dataFactoryService.getThisOrLatestNormalizationVersion( + response[runNumber] = self.dataFactoryService.getLatestApplicableNormalizationVersion( runNumber, request.useLiteMode ) return response diff --git a/src/snapred/backend/service/ReductionService.py b/src/snapred/backend/service/ReductionService.py index 1ad5a1d0c..c9d92e651 100644 --- a/src/snapred/backend/service/ReductionService.py +++ b/src/snapred/backend/service/ReductionService.py @@ -15,7 +15,6 @@ ReductionExportRequest, ReductionRequest, ) -from snapred.backend.dao.request.ReductionRequest import Versions from snapred.backend.dao.response.ReductionResponse import ReductionResponse from snapred.backend.dao.SNAPRequest import SNAPRequest from snapred.backend.dao.WorkspaceMetadata import DiffcalStateMetadata, NormalizationStateMetadata, WorkspaceMetadata @@ -323,6 +322,9 @@ def prepReductionIngredients( :return: The needed reduction ignredients. 
:rtype: ReductionIngredients """ + if request.versions is None or request.versions.calibration is None or request.versions.normalization is None: + raise ValueError("Reduction request must have versions set") + farmFresh = FarmFreshIngredients( runNumber=request.runNumber, useLiteMode=request.useLiteMode, @@ -363,11 +365,11 @@ def fetchReductionGroceries(self, request: ReductionRequest) -> Dict[str, Any]: calVersion = None normVersion = None if ContinueWarning.Type.MISSING_DIFFRACTION_CALIBRATION not in request.continueFlags: - calVersion = self.dataFactoryService.getThisOrLatestCalibrationVersion( + calVersion = self.dataFactoryService.getLatestApplicableCalibrationVersion( request.runNumber, request.useLiteMode ) if ContinueWarning.Type.MISSING_NORMALIZATION not in request.continueFlags: - normVersion = self.dataFactoryService.getThisOrLatestNormalizationVersion( + normVersion = self.dataFactoryService.getLatestApplicableNormalizationVersion( request.runNumber, request.useLiteMode ) @@ -416,10 +418,6 @@ def fetchReductionGroceries(self, request: ReductionRequest) -> Dict[str, Any]: request.useLiteMode ).add() - request.versions = Versions( - calVersion, - normVersion, - ) groceries = self.groceryService.fetchGroceryDict( groceryDict=self.groceryClerk.buildDict(), **({"combinedPixelMask": combinedPixelMask} if combinedPixelMask else {}), @@ -496,7 +494,9 @@ def _groupByVanadiumVersion(self, requests: List[SNAPRequest]): for request in requests: runNumber = json.loads(request.payload)["runNumber"] useLiteMode = bool(json.loads(request.payload)["useLiteMode"]) - normalizationVersion = self.dataFactoryService.getThisOrCurrentNormalizationVersion(runNumber, useLiteMode) + normalizationVersion = self.dataFactoryService.getLatestApplicableNormalizationVersion( + runNumber, useLiteMode + ) version = "normalization_" + str(normalizationVersion) if versions.get(version) is None: versions[version] = [] diff --git a/src/snapred/backend/service/SousChef.py b/src/snapred/backend/service/SousChef.py index 804b49aab..fbb6aabec 100644 --- a/src/snapred/backend/service/SousChef.py +++ b/src/snapred/backend/service/SousChef.py @@ -222,6 +222,8 @@ def _pullCalibrationRecordFFI( self, ingredients: FarmFreshIngredients, ) -> FarmFreshIngredients: + if ingredients.versions.calibration is None: + raise ValueError("Calibration version must be specified") calibrationRecord = self.dataFactoryService.getCalibrationRecord( ingredients.runNumber, ingredients.useLiteMode, ingredients.versions.calibration ) diff --git a/src/snapred/meta/mantid/WorkspaceNameGenerator.py b/src/snapred/meta/mantid/WorkspaceNameGenerator.py index 2cc384db3..603abbe5f 100644 --- a/src/snapred/meta/mantid/WorkspaceNameGenerator.py +++ b/src/snapred/meta/mantid/WorkspaceNameGenerator.py @@ -8,12 +8,9 @@ from pydantic.functional_validators import BeforeValidator from typing_extensions import Annotated, Self +from snapred.backend.dao.indexing.Versioning import VERSION_START, VersionState from snapred.meta.Config import Config -# Bypass circular import: -VERSION_DEFAULT = Config["version.default"] -VERSION_DEFAULT_NAME = Config["version.friendlyName.default"] - class WorkspaceName(str): def __new__(cls, value: Any) -> Self: @@ -179,8 +176,8 @@ def formatVersion(cls, version: Optional[int], fmt=versionFormat.WORKSPACE): # in those cases, format will be a user-specified string formattedVersion = "" - if version == VERSION_DEFAULT: - formattedVersion = f"v{VERSION_DEFAULT_NAME}" + if version == VersionState.DEFAULT: + formattedVersion = 
f"v{VERSION_START}" elif isinstance(version, int): formattedVersion = fmt.format(version=version) elif str(version).isdigit(): @@ -191,8 +188,8 @@ def formatVersion(cls, version: Optional[int], fmt=versionFormat.WORKSPACE): def pathVersion(cls, version: int): # only one special case: default version - if version == VERSION_DEFAULT: - return f"v_{VERSION_DEFAULT_NAME}" + if version == VersionState.DEFAULT: + return f"v_{VersionState.DEFAULT}" return cls.formatVersion(version, fmt=cls.versionFormat.PATH) @classmethod diff --git a/src/snapred/resources/application.yml b/src/snapred/resources/application.yml index f21dc6865..6968c4004 100644 --- a/src/snapred/resources/application.yml +++ b/src/snapred/resources/application.yml @@ -193,7 +193,7 @@ version: friendlyName: error: "uninitialized" # alphanumeric default: 0 # alphanumeric - start: 1 # MUST be nonnegative integer + start: 0 # MUST be nonnegative integer cis_mode: false diff --git a/src/snapred/ui/workflow/DiffCalWorkflow.py b/src/snapred/ui/workflow/DiffCalWorkflow.py index 2ab4e305d..648343393 100644 --- a/src/snapred/ui/workflow/DiffCalWorkflow.py +++ b/src/snapred/ui/workflow/DiffCalWorkflow.py @@ -3,7 +3,7 @@ from snapred.backend.dao import RunConfig from snapred.backend.dao.indexing.IndexEntry import IndexEntry -from snapred.backend.dao.indexing.Versioning import VersionedObject +from snapred.backend.dao.indexing.Versioning import VersionedObject, VersionState from snapred.backend.dao.Limit import Pair from snapred.backend.dao.request import ( CalculateResidualRequest, @@ -534,10 +534,10 @@ def _resetSaveView(self): def _saveCalibration(self, workflowPresenter): view = workflowPresenter.widget.tabView runNumber = view.fieldRunNumber.get() - version = view.fieldVersion.get(None) + version = view.fieldVersion.get(VersionState.NEXT) appliesTo = view.fieldAppliesTo.get(f">={runNumber}") # validate the version number - version = VersionedObject.parseVersion(version, exclude_default=True) + version = VersionedObject(version=version).version # validate appliesTo field appliesTo = IndexEntry.appliesToFormatChecker(appliesTo) diff --git a/src/snapred/ui/workflow/NormalizationWorkflow.py b/src/snapred/ui/workflow/NormalizationWorkflow.py index 95b5e294f..91cdc4bcf 100644 --- a/src/snapred/ui/workflow/NormalizationWorkflow.py +++ b/src/snapred/ui/workflow/NormalizationWorkflow.py @@ -1,7 +1,7 @@ from qtpy.QtCore import Slot from snapred.backend.dao.indexing.IndexEntry import IndexEntry -from snapred.backend.dao.indexing.Versioning import VersionedObject +from snapred.backend.dao.indexing.Versioning import VersionedObject, VersionState from snapred.backend.dao.request import ( CalibrationWritePermissionsRequest, CreateIndexEntryRequest, @@ -217,10 +217,10 @@ def _specifyNormalization(self, workflowPresenter): # noqa: ARG002 def _saveNormalization(self, workflowPresenter): view = workflowPresenter.widget.tabView runNumber = view.fieldRunNumber.get() - version = view.fieldVersion.get() + version = view.fieldVersion.get(VersionState.NEXT) appliesTo = view.fieldAppliesTo.get(f">={self.calibrationRunNumber}") # validate version number - version = VersionedObject.parseVersion(version, exclude_default=True) + version = VersionedObject(version=version).version # validate appliesTo field appliesTo = IndexEntry.appliesToFormatChecker(appliesTo) diff --git a/src/snapred/ui/workflow/ReductionWorkflow.py b/src/snapred/ui/workflow/ReductionWorkflow.py index 1d8e046a1..d79337850 100644 --- a/src/snapred/ui/workflow/ReductionWorkflow.py +++ 
b/src/snapred/ui/workflow/ReductionWorkflow.py @@ -187,6 +187,10 @@ def _triggerReduction(self, workflowPresenter): response = self.request(path="reduction/groupings", payload=request_) self._keeps = set(response.data["groupingWorkspaces"]) + # Validate reduction; if artificial normalization is needed, handle it + # NOTE: this logic ONLY works because we are forbidding mixed cases of artnorm or loaded norm + response = self.request(path="reduction/validate", payload=request_) + # get the calibration and normalization versions for all runs to be processed matchRequest = MatchRunsRequest(runNumbers=self.runNumbers, useLiteMode=self.useLiteMode) loadedCalibrations, calVersions = self.request(path="calibration/fetchMatches", payload=matchRequest).data @@ -203,9 +207,6 @@ def _triggerReduction(self, workflowPresenter): "and try again." ) - # Validate reduction; if artificial normalization is needed, handle it - # NOTE: this logic ONLY works because we are forbidding mixed cases of artnorm or loaded norm - response = self.request(path="reduction/validate", payload=request_) if ContinueWarning.Type.MISSING_NORMALIZATION in self.continueAnywayFlags: if len(self.runNumbers) > 1: raise RuntimeError( diff --git a/src/snapred/ui/workflow/WorkflowImplementer.py b/src/snapred/ui/workflow/WorkflowImplementer.py index 251502981..c3640d73f 100644 --- a/src/snapred/ui/workflow/WorkflowImplementer.py +++ b/src/snapred/ui/workflow/WorkflowImplementer.py @@ -101,6 +101,7 @@ def reset(self, retainOutputs=False): self.responses = [] self.outputs = [] self.collectedOutputs = [] + self.continueAnywayFlags = ContinueWarning.Type.UNSET for hook in self.resetHooks: logger.info(f"Calling reset hook: {hook}") diff --git a/tests/data/snapred-data b/tests/data/snapred-data index b6fbfbadf..bd6930ff5 160000 --- a/tests/data/snapred-data +++ b/tests/data/snapred-data @@ -1 +1 @@ -Subproject commit b6fbfbadfe8c080b96895e26c5153db344821a02 +Subproject commit bd6930ff57eef257a17adecdab9b7d9cea76850f diff --git a/tests/integration/test_versions_in_order.py b/tests/integration/test_versions_in_order.py index be234c4f3..bdf068dc6 100644 --- a/tests/integration/test_versions_in_order.py +++ b/tests/integration/test_versions_in_order.py @@ -36,9 +36,9 @@ from util.dao import DAOFactory from util.diffraction_calibration_synthetic_data import SyntheticData -from snapred.backend.dao.calibration.CalibrationRecord import CalibrationRecord +from snapred.backend.dao.calibration.CalibrationRecord import CalibrationDefaultRecord, CalibrationRecord from snapred.backend.dao.indexing.IndexEntry import IndexEntry -from snapred.backend.dao.indexing.Versioning import VERSION_DEFAULT, VERSION_START +from snapred.backend.dao.indexing.Versioning import VERSION_START, VersionState from snapred.backend.dao.SNAPRequest import SNAPRequest from snapred.backend.dao.SNAPResponse import ResponseCode from snapred.backend.dao.state.DetectorState import DetectorState @@ -112,7 +112,7 @@ def _writeDefaultDiffCalTable(self, runNumber: str, useLiteMode: bool): """ Note this replicates the original in every respect, except using the ImitationGroceryService """ - version = VERSION_DEFAULT + version = VERSION_START grocer = ImitationGroceryService() outWS = grocer.fetchDefaultDiffCalTable(runNumber, useLiteMode, version) filename = Path(outWS + ".h5") @@ -294,9 +294,9 @@ def test_calibration_versioning(self): assert self.state_exists() # ensure the new state has grouping map, calibration state, and default diffcal table - diffCalTableName = 
wng.diffCalTable().runNumber("default").version(VERSION_DEFAULT).build() + diffCalTableName = wng.diffCalTable().runNumber("default").version(VersionState.DEFAULT).build() assert self.localDataService._groupingMapPath(self.stateId).exists() - versionDir = wnvf.pathVersion(VERSION_DEFAULT) + versionDir = wnvf.pathVersion(VERSION_START) assert Path(self.stateRoot, "lite", "diffraction", versionDir, "CalibrationParameters.json").exists() assert Path(self.stateRoot, "native", "diffraction", versionDir, "CalibrationParameters.json").exists() assert Path(self.stateRoot, "lite", "diffraction", versionDir, diffCalTableName + ".h5").exists() @@ -306,22 +306,22 @@ def test_calibration_versioning(self): assert [] == self.get_index() # assert the current diffcal version is the default, and the next is the start - assert self.indexer.currentVersion() == VERSION_DEFAULT - assert self.indexer.latestApplicableVersion(self.runNumber) == VERSION_DEFAULT - assert self.indexer.nextVersion() == VERSION_START + assert self.indexer.currentVersion() == VERSION_START + assert self.indexer.latestApplicableVersion(self.runNumber) == VERSION_START + assert self.indexer.nextVersion() == VERSION_START + 1 # run diffraction calibration for the first time, and save res = self.run_diffcal() - self.save_diffcal(res, version=None) + self.save_diffcal(res, version=VersionState.NEXT) # ensure things saved correctly - self.assert_diffcal_saved(VERSION_START) + self.assert_diffcal_saved(VERSION_START + 1) assert len(self.get_index()) == 1 # run diffraction calibration for a second time, and save res = self.run_diffcal() - self.save_diffcal(res, version=None) - self.assert_diffcal_saved(VERSION_START + 1) + self.save_diffcal(res, version=VersionState.NEXT) + self.assert_diffcal_saved(VERSION_START + 2) assert len(self.get_index()) == 2 # now save at version 7 @@ -334,7 +334,7 @@ def test_calibration_versioning(self): # now save at next version -- will be 8 version = 8 res = self.run_diffcal() - self.save_diffcal(res, version=None) # NOTE using None points it to next version + self.save_diffcal(res, version=VersionState.NEXT) self.assert_diffcal_saved(version) assert len(self.get_index()) == 4 @@ -376,7 +376,7 @@ def run_diffcal(self): assert response.code <= ResponseCode.MAX_OK return response.data - def save_diffcal(self, res, version=None): + def save_diffcal(self, res, version=VersionState.NEXT): # send a request through interface controller to save the diffcal results # needs the list of output workspaces, and may take an optional version # create an export request using an existing record as a basis @@ -410,7 +410,7 @@ def save_diffcal(self, res, version=None): "createIndexEntryRequest": createIndexEntryRequest, "createRecordRequest": createRecordRequest, } - request = SNAPRequest(path="calibration/save", payload=json.dumps(payload)) + request = SNAPRequest(path="calibration/save", payload=json.dumps(payload, default=str)) response = self.api.executeRequest(request) assert response.code <= ResponseCode.MAX_OK return response.data @@ -420,13 +420,19 @@ def assert_diffcal_saved(self, version): assert self.indexer.versionPath(version).exists() assert self.indexer.recordPath(version).exists() assert self.indexer.parametersPath(version).exists() - savedRecord = parse_file_as(CalibrationRecord, self.indexer.recordPath(version)) + savedRecord = None + if version == VERSION_START: + savedRecord = parse_file_as(CalibrationDefaultRecord, self.indexer.recordPath(version)) + else: + savedRecord = parse_file_as(CalibrationRecord, 
self.indexer.recordPath(version)) + assert savedRecord.version == version assert savedRecord.calculationParameters.version == version # make sure all workspaces exist workspaces = savedRecord.workspaces - assert (self.indexer.versionPath(version) / (workspaces[wngt.DIFFCAL_OUTPUT][0] + ".nxs.h5")).exists() - assert (self.indexer.versionPath(version) / (workspaces[wngt.DIFFCAL_DIAG][0] + ".nxs.h5")).exists() + if not version == VERSION_START: + assert (self.indexer.versionPath(version) / (workspaces[wngt.DIFFCAL_OUTPUT][0] + ".nxs.h5")).exists() + assert (self.indexer.versionPath(version) / (workspaces[wngt.DIFFCAL_DIAG][0] + ".nxs.h5")).exists() assert (self.indexer.versionPath(version) / (workspaces[wngt.DIFFCAL_TABLE][0] + ".h5")).exists() # assert this version is in the index index = self.indexer.readIndex() @@ -438,7 +444,7 @@ def assert_diffcal_saved(self, version): assert self.indexer.latestApplicableVersion(self.runNumber) == version assert self.indexer.nextVersion() == version + 1 # load the previous calibration and verify equality - runConfig = {"runNumber": self.runNumber, "useLiteMode": self.useLiteMode} + runConfig = {"runConfig": {"runNumber": self.runNumber, "useLiteMode": self.useLiteMode}, "version": version} request = SNAPRequest(path="calibration/load", payload=json.dumps(runConfig)) response = self.api.executeRequest(request) assert response.code <= ResponseCode.MAX_OK diff --git a/tests/integration/test_workflow_panels_happy_path.py b/tests/integration/test_workflow_panels_happy_path.py index aba5e5062..eca307711 100644 --- a/tests/integration/test_workflow_panels_happy_path.py +++ b/tests/integration/test_workflow_panels_happy_path.py @@ -20,6 +20,7 @@ # however, for the moment, the reduction-data output relocation fixture is defined in the current file. 
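# ---------------------------------------------------------------------------
# NOTE: a minimal sketch (not code from this patch) of how the VersionState
# sentinels exercised by the versioning assertions above are expected to
# resolve, based on the behavior asserted in test_versions_in_order.py;
# `resolve` is a hypothetical stand-in for the Indexer's internal
# version-flattening step:
#
#     def resolve(version, currentVersion, start=VERSION_START):
#         if version == VersionState.DEFAULT:
#             return start                 # "default" pins to the reserved start slot
#         if version == VersionState.NEXT:
#             # "next" advances one past the current version, or starts fresh
#             return start if currentVersion is None else currentVersion + 1
#         if isinstance(version, int) and version >= 0:
#             return version               # explicit integer versions pass through
#         raise ValueError("The indexer has encountered an invalid version")
# ---------------------------------------------------------------------------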
from snapred.backend.data.LocalDataService import LocalDataService from snapred.meta.Config import Config, Resource +from snapred.meta.Enum import StrEnum from snapred.ui.main import SNAPRedGUI, prependDataSearchDirectories from snapred.ui.view import InitializeStateCheckView from snapred.ui.view.DiffCalAssessmentView import DiffCalAssessmentView @@ -39,6 +40,63 @@ class InterruptWithBlock(BaseException): pass +class TestSummary: + def __init__(self): + self._index = 0 + self._steps = [] + + def SUCCESS(self): + step = self._steps[self._index] + step.status = self.TestStep.StepStatus.SUCCESS + self._index += 1 + + def FAILURE(self): + step = self._steps[self._index] + step.status = self.TestStep.StepStatus.FAILURE + self._index += 1 + + def isComplete(self): + return self._index == len(self._steps) + + def isFailure(self): + return any(step.status == self.TestStep.StepStatus.FAILURE for step in self._steps) + + def builder(): + return TestSummary.TestSummaryBuilder() + + def __str__(self): + longestStatus = max(len(step.status) for step in self._steps) + longestName = max(len(step.name) for step in self._steps) + tableCapStr = "#" * (longestName + longestStatus + 6) + tableStr = ( + f"\n{tableCapStr}\n" + + "\n".join(f"# {step.name:{longestName}}: {step.status:{longestStatus}} #" for step in self._steps) + + f"\n{tableCapStr}\n" + ) + return tableStr + + class TestStep: + class StepStatus(StrEnum): + SUCCESS = "SUCCESS" + FAILURE = "FAILURE" + INCOMPLETE = "INCOMPLETE" + + def __init__(self, name: str): + self.name = name + self.status = self.StepStatus.INCOMPLETE + + class TestSummaryBuilder: + def __init__(self): + self.summary = TestSummary() + + def step(self, name: str): + self.summary._steps.append(TestSummary.TestStep(name)) + return self + + def build(self): + return self.summary + + @pytest.fixture def calibration_home_from_mirror(): # Test fixture to create a copy of the calibration home directory from an existing mirror: @@ -196,8 +254,16 @@ def _setup_gui(self, qapp): # Establish context for each test: these normally run as part of `src/snapred/__main__.py`. self.exitStack = ExitStack() self.exitStack.enter_context(amend_config(data_dir=prependDataSearchDirectories(), prepend_datadir=True)) + + self.testSummary = None yield + if isinstance(self.testSummary, TestSummary): + if not self.testSummary.isComplete(): + self.testSummary.FAILURE() + if self.testSummary.isFailure(): + pytest.fail(f"Test Summary (-vv for full table): {self.testSummary}") + # teardown... self._warningMessageBox.stop() self._criticalMessageBox.stop() @@ -722,7 +788,18 @@ def test_diffraction_calibration_panel_happy_path(self, qtbot, qapp, calibration # Override the mirror with a new home directory, omitting any existing # calibration or normalization data. 
tmpCalibrationHomeDirectory = calibration_home_from_mirror() # noqa: F841 - + self.testSummary = ( + TestSummary.builder() + .step("Open the GUI") + .step("Open the calibration panel") + .step("Set the diffraction calibration request") + .step("Execute the diffraction calibration request") + .step("Tweak the peaks") + .step("Assess the peaks") + .step("Save the diffraction calibration") + .step("Close the GUI") + .build() + ) with ( qtbot.captureExceptions() as exceptions, suppress(InterruptWithBlock), @@ -750,6 +827,7 @@ def test_diffraction_calibration_panel_happy_path(self, qtbot, qapp, calibration """ # Open the calibration panel: + self.testSummary.SUCCESS() # QPushButton* button = pWin->findChild("Button name"); qtbot.mouseClick(gui.calibrationPanelButton, QtCore.Qt.LeftButton) if len(exceptions): @@ -780,6 +858,7 @@ def test_diffraction_calibration_panel_happy_path(self, qtbot, qapp, calibration requestView = workflowNodeTabs.currentWidget().view assert isinstance(requestView, DiffCalRequestView) + self.testSummary.SUCCESS() # set "Run Number", "Convergence Threshold", ,: requestView.runNumberField.setText("46680") @@ -804,6 +883,7 @@ def test_diffraction_calibration_panel_happy_path(self, qtbot, qapp, calibration assert requestView.peakFunctionDropdown.currentIndex() == 0 assert requestView.peakFunctionDropdown.currentText() == "Gaussian" + self.testSummary.SUCCESS() # execute the request # TODO: make sure that there's no initialized state => abort the test if there is! @@ -858,13 +938,17 @@ def test_diffraction_calibration_panel_happy_path(self, qtbot, qapp, calibration questionMessageBox.stop() successPrompt.stop() + # Now that there is a new state, we need to reselect the grouping file ... : + # Why was this error box being swallowed? + requestView.groupingFileDropdown.setCurrentIndex(1) + # (2) execute the calibration workflow with qtbot.waitSignal(actionCompleted, timeout=60000): qtbot.mouseClick(workflowNodeTabs.currentWidget().continueButton, Qt.MouseButton.LeftButton) qtbot.waitUntil( lambda: isinstance(workflowNodeTabs.currentWidget().view, DiffCalTweakPeakView), timeout=60000 ) - + self.testSummary.SUCCESS() tweakPeakView = workflowNodeTabs.currentWidget().view # --------------------------------------------------------------------------- @@ -896,6 +980,7 @@ def test_diffraction_calibration_panel_happy_path(self, qtbot, qapp, calibration # continue to the next panel with qtbot.waitSignal(actionCompleted, timeout=80000): qtbot.mouseClick(workflowNodeTabs.currentWidget().continueButton, Qt.MouseButton.LeftButton) + self.testSummary.SUCCESS() qtbot.waitUntil( lambda: isinstance(workflowNodeTabs.currentWidget().view, DiffCalAssessmentView), timeout=80000 @@ -913,6 +998,7 @@ def test_diffraction_calibration_panel_happy_path(self, qtbot, qapp, calibration # continue to the next panel with qtbot.waitSignal(actionCompleted, timeout=80000): qtbot.mouseClick(workflowNodeTabs.currentWidget().continueButton, Qt.MouseButton.LeftButton) + self.testSummary.SUCCESS() qtbot.waitUntil(lambda: isinstance(workflowNodeTabs.currentWidget().view, DiffCalSaveView), timeout=5000) saveView = workflowNodeTabs.currentWidget().view @@ -924,7 +1010,7 @@ def test_diffraction_calibration_panel_happy_path(self, qtbot, qapp, calibration # continue in order to save workspaces and to finish the workflow with qtbot.waitSignal(actionCompleted, timeout=60000): qtbot.mouseClick(workflowNodeTabs.currentWidget().continueButton, Qt.MouseButton.LeftButton) - + self.testSummary.SUCCESS() # `ActionPrompt.prompt("..The 
workflow has completed successfully..)` gives immediate mocked response: # Here we still need to wait until the ADS cleanup has occurred, # or else it will happen in the middle of the next workflow. :( @@ -938,6 +1024,7 @@ def test_diffraction_calibration_panel_happy_path(self, qtbot, qapp, calibration calibrationPanel.widget.close() gui.close() + self.testSummary.SUCCESS() ##################################################################### # Force a printout of information about any exceptions that happened# @@ -955,7 +1042,17 @@ def test_normalization_panel_happy_path(self, qtbot, qapp, calibration_home_from # Override the mirror with a new home directory, omitting any existing # calibration or normalization data. tmpCalibrationHomeDirectory = calibration_home_from_mirror() # noqa: F841 - + self.testSummary = ( + TestSummary.builder() + .step("Open the GUI") + .step("Open the Normalization panel") + .step("Set the normalization request") + .step("Execute the normalization request") + .step("Tweak the peaks") + .step("Save the normalization calibration") + .step("Close the GUI") + .build() + ) with ( qtbot.captureExceptions() as exceptions, suppress(InterruptWithBlock), @@ -963,7 +1060,7 @@ def test_normalization_panel_happy_path(self, qtbot, qapp, calibration_home_from gui = SNAPRedGUI(translucentBackground=True) gui.show() qtbot.addWidget(gui) - + self.testSummary.SUCCESS() """ SNAPRedGUI owns the following widgets: @@ -1013,6 +1110,7 @@ def test_normalization_panel_happy_path(self, qtbot, qapp, calibration_home_from requestView = workflowNodeTabs.currentWidget().view assert isinstance(requestView, NormalizationRequestView) + self.testSummary.SUCCESS() # set "Run Number", "Background run number": requestView.runNumberField.setText("46680") @@ -1031,7 +1129,7 @@ def test_normalization_panel_happy_path(self, qtbot, qapp, calibration_home_from requestView.groupingFileDropdown.setCurrentIndex(1) assert requestView.groupingFileDropdown.currentIndex() == 1 assert requestView.groupingFileDropdown.currentText() == "Bank" - + self.testSummary.SUCCESS() """ # Why no "peak function" for normalization calibration?! requestView.peakFunctionDropdown.setCurrentIndex(0) assert requestView.peakFunctionDropdown.currentIndex() == 0 @@ -1093,6 +1191,10 @@ def test_normalization_panel_happy_path(self, qtbot, qapp, calibration_home_from questionMessageBox.stop() successPrompt.stop() + # Now that there is a new state, we need to reselect the grouping file ... : + # Why was this error box being swallowed? 
+ requestView.groupingFileDropdown.setCurrentIndex(1) + warningMessageBox = mock.patch( # noqa: PT008 "qtpy.QtWidgets.QMessageBox.warning", lambda *args, **kwargs: QMessageBox.Yes, # noqa: ARG005 @@ -1106,6 +1208,7 @@ def test_normalization_panel_happy_path(self, qtbot, qapp, calibration_home_from lambda: isinstance(workflowNodeTabs.currentWidget().view, NormalizationTweakPeakView), timeout=60000 ) warningMessageBox.stop() + self.testSummary.SUCCESS() tweakPeakView = workflowNodeTabs.currentWidget().view # set "Smoothing", "xtal dMin", "xtal dMax", "intensity threshold", and "groupingDropDown" @@ -1131,6 +1234,7 @@ def test_normalization_panel_happy_path(self, qtbot, qapp, calibration_home_from qtbot.waitUntil( lambda: isinstance(workflowNodeTabs.currentWidget().view, NormalizationSaveView), timeout=60000 ) + self.testSummary.SUCCESS() saveView = workflowNodeTabs.currentWidget().view # set "author" and "comment" @@ -1140,7 +1244,7 @@ def test_normalization_panel_happy_path(self, qtbot, qapp, calibration_home_from # continue in order to save workspaces and to finish the workflow with qtbot.waitSignal(actionCompleted, timeout=60000): qtbot.mouseClick(workflowNodeTabs.currentWidget().continueButton, Qt.MouseButton.LeftButton) - + self.testSummary.SUCCESS() # `ActionPrompt.prompt("..The workflow has completed successfully..)` gives immediate mocked response: # Here we still need to wait until the ADS cleanup has occurred, # or else it will happen in the middle of the next workflow. :( @@ -1154,7 +1258,7 @@ def test_normalization_panel_happy_path(self, qtbot, qapp, calibration_home_from calibrationPanel.widget.close() gui.close() - + self.testSummary.SUCCESS() ##################################################################### # Force a printout of information about any exceptions that happened# # within the Qt event loop. # @@ -1181,6 +1285,16 @@ def test_reduction_panel_happy_path(self, qtbot, qapp, reduction_home_from_mirro # under the existing location within the mirror. tmpReductionHomeDirectory = reduction_home_from_mirror(reductionRunNumber) # noqa: F841 + self.testSummary = ( + TestSummary.builder() + .step("Open the GUI") + .step("Open the Reduction panel") + .step("Set the reduction request") + .step("Execute the reduction request") + .step("Close the GUI") + .build() + ) + self.completionMessageHasAppeared = False def completionMessageBoxAssert(*args, **kwargs): # noqa: ARG001 @@ -1202,7 +1316,7 @@ def completionMessageBoxAssert(*args, **kwargs): # noqa: ARG001 gui = SNAPRedGUI(translucentBackground=True) gui.show() qtbot.addWidget(gui) - + self.testSummary.SUCCESS() """ SNAPRedGUI owns the following widgets: @@ -1250,7 +1364,7 @@ def completionMessageBoxAssert(*args, **kwargs): # noqa: ARG001 requestView = workflowNodeTabs.currentWidget().view assert isinstance(requestView, ReductionRequestView) - + self.testSummary.SUCCESS() # Without this next wait, the "run number entry" section happens too fast. # (And I'd love to understand _why_! 
:( ) qtbot.wait(1000) @@ -1263,7 +1377,7 @@ def completionMessageBoxAssert(*args, **kwargs): # noqa: ARG001 _runNumbers = [requestView.runNumberDisplay.item(x).text() for x in range(_count)] assert reductionRunNumber in _runNumbers - + self.testSummary.SUCCESS() """ request.liteModeToggle.setState(True); request.retainUnfocusedDataCheckbox.setValue(False); @@ -1344,13 +1458,14 @@ def completionMessageBoxAssert(*args, **kwargs): # noqa: ARG001 timeout=60000, ) completionMessageBox.stop() - + self.testSummary.SUCCESS() ############################### ########### END OF TEST ####### ############################### calibrationPanel.widget.close() gui.close() + self.testSummary.SUCCESS() ##################################################################### # Force a printout of information about any exceptions that happened# diff --git a/tests/resources/inputs/reduction/ReductionRecord_20240614T130420.json b/tests/resources/inputs/reduction/ReductionRecord_20240614T130420.json index 16a3eede6..640b22625 100644 --- a/tests/resources/inputs/reduction/ReductionRecord_20240614T130420.json +++ b/tests/resources/inputs/reduction/ReductionRecord_20240614T130420.json @@ -761,6 +761,7 @@ "fitted_strippedFocussedData_stripped_2" ], "version": 7, + "calibrationVersionUsed": 1, "crystalDBounds": {"minimum": 0.4, "maximum": 100.0}, "normalizationCalibrantSamplePath": "/SNS/SNAP/shared/Calibration/Calibrants/Al2O3/Al2O3_20240614T130420.nxs" }, diff --git a/tests/unit/backend/dao/test_VersionedObject.py b/tests/unit/backend/dao/test_VersionedObject.py index ded21ea95..59dadcbfc 100644 --- a/tests/unit/backend/dao/test_VersionedObject.py +++ b/tests/unit/backend/dao/test_VersionedObject.py @@ -9,10 +9,9 @@ from snapred.backend.dao.indexing.IndexEntry import IndexEntry from snapred.backend.dao.indexing.Record import Record from snapred.backend.dao.indexing.Versioning import ( - VERSION_DEFAULT, - VERSION_DEFAULT_NAME, - VERSION_NONE_NAME, + VERSION_START, VersionedObject, + VersionState, ) @@ -25,24 +24,19 @@ def test_init_bad(): VersionedObject(version=1.2) -def test_init_name_none(): - vo = VersionedObject(version=VERSION_NONE_NAME) - assert vo.version is None - - def test_init_name_default(): - vo = VersionedObject(version=VERSION_DEFAULT_NAME) - assert vo.version == VERSION_DEFAULT + vo = VersionedObject(version=VersionState.DEFAULT) + assert vo.version == VersionState.DEFAULT def test_init_none(): - vo = VersionedObject(version=None) - assert vo.version is None + with pytest.raises(ValueError): + VersionedObject(version=None) def test_init_default(): - vo = VersionedObject(version=VERSION_DEFAULT) - assert vo.version == VERSION_DEFAULT + vo = VersionedObject(version=VersionState.DEFAULT) + assert vo.version == VersionState.DEFAULT def test_init_int(): @@ -73,19 +67,16 @@ def test_write_version_int(): def test_write_version_none(): - vo = VersionedObject(version=None) - assert vo.version is None - assert vo.model_dump_json() == f'{{"version":"{VERSION_NONE_NAME}"}}' - assert vo.model_dump_json() != '{"version":null}' - assert vo.dict()["version"] == VERSION_NONE_NAME + with pytest.raises(ValueError): + VersionedObject(version=None) def test_write_version_default(): - vo = VersionedObject(version=VERSION_DEFAULT_NAME) - assert vo.version == VERSION_DEFAULT - assert vo.model_dump_json() == f'{{"version":"{VERSION_DEFAULT_NAME}"}}' - assert vo.model_dump_json() != f'{{"version":{VERSION_DEFAULT}}}' - assert vo.dict()["version"] == VERSION_DEFAULT_NAME + vo = VersionedObject(version=VERSION_START) + assert 
vo.version == VERSION_START + assert vo.model_dump_json() == f'{{"version":{VERSION_START}}}' + assert vo.model_dump_json() != f'{{"version":"{VERSION_START}"}}' + assert vo.dict()["version"] == VERSION_START def test_can_set_valid(): @@ -95,11 +86,11 @@ def test_can_set_valid(): vo.version = new_version assert vo.version == new_version - vo.version = VERSION_DEFAULT - assert vo.version == VERSION_DEFAULT + vo.version = VersionState.DEFAULT + assert vo.version == VersionState.DEFAULT - vo.version = VERSION_DEFAULT_NAME - assert vo.version == VERSION_DEFAULT + vo.version = VersionState.DEFAULT + assert vo.version == VersionState.DEFAULT def test_cannot_set_invalid(): @@ -117,17 +108,12 @@ def test_shaped_liked_itself(): Make a versioned object. Serialize it. Then parse it back as itself. It should validate and create an identical object. """ - # version is None - vo_old = VersionedObject(version=None) - vo_new = VersionedObject.model_validate(vo_old.model_dump()) - assert vo_old == vo_new - vo_new = VersionedObject.model_validate(vo_old.dict()) - assert vo_old == vo_new - vo_new = VersionedObject.model_validate_json(vo_old.model_dump_json()) - assert vo_old == vo_new # version is default - vo_old = VersionedObject(version=VERSION_DEFAULT) + vo_old = VersionedObject(version=VersionState.DEFAULT) + with pytest.raises(ValueError, match="must be flattened to an int before writing to"): + vo_old.model_dump_json() + vo_old.version = VERSION_START vo_new = VersionedObject.model_validate(vo_old.model_dump()) assert vo_old == vo_new vo_new = VersionedObject.model_validate(vo_old.dict()) @@ -172,17 +158,10 @@ def test_init_index_entry(): vo = indexEntryWithVersion(version) assert vo.version == version - # init with none - vo = indexEntryWithVersion(VERSION_NONE_NAME) - assert vo.version is None - vo = indexEntryWithVersion(None) - assert vo.version is None - # init with default - vo = indexEntryWithVersion(VERSION_DEFAULT_NAME) - assert vo.version == VERSION_DEFAULT - vo = indexEntryWithVersion(VERSION_DEFAULT) - assert vo.version == VERSION_DEFAULT + vo = indexEntryWithVersion(VersionState.DEFAULT) + # This must be flattened to an int before writing to JSON + assert vo.version == VersionState.DEFAULT def test_set_index_entry(): @@ -192,11 +171,8 @@ def test_set_index_entry(): vo.version = new_version assert vo.version == new_version - vo.version = VERSION_DEFAULT - assert vo.version == VERSION_DEFAULT - - vo.version = VERSION_DEFAULT_NAME - assert vo.version == VERSION_DEFAULT + vo.version = VERSION_START + assert vo.version == VERSION_START vo = indexEntryWithVersion(randint(0, 120)) with pytest.raises(ValueError): @@ -215,19 +191,16 @@ def test_write_version_index_entry(): assert f'"version":{version}' in vo.model_dump_json() assert vo.dict()["version"] == version - # test write none - vo = indexEntryWithVersion(None) - assert vo.version is None - assert f'"version":"{VERSION_NONE_NAME}"' in vo.model_dump_json() - assert '"version":null' not in vo.model_dump_json() - assert vo.dict()["version"] == VERSION_NONE_NAME - # test write default - vo = indexEntryWithVersion(VERSION_DEFAULT) - assert vo.version == VERSION_DEFAULT - assert f'"version":"{VERSION_DEFAULT_NAME}"' in vo.model_dump_json() - assert f'"version":{VERSION_DEFAULT}' not in vo.model_dump_json() - assert vo.dict()["version"] == VERSION_DEFAULT_NAME + vo = indexEntryWithVersion(VersionState.DEFAULT) + + with pytest.raises(ValueError, match="must be flattened to an int before writing to"): + vo.model_dump_json() + + vo.version = 
VERSION_START + assert f'"version":{VERSION_START}' in vo.model_dump_json() + assert f'"version":"{VERSION_START}"' not in vo.model_dump_json() + assert vo.dict()["version"] == VERSION_START ### TESTS OF RECORDS AS VERSIONED OBJECTS ### @@ -258,17 +231,11 @@ def test_init_record(): vo = recordWithVersion(version) assert vo.version == version - # init with none - vo = recordWithVersion(VERSION_NONE_NAME) - assert vo.version is None - vo = recordWithVersion(None) - assert vo.version is None - # init with default - vo = recordWithVersion(VERSION_DEFAULT_NAME) - assert vo.version == VERSION_DEFAULT - vo = recordWithVersion(VERSION_DEFAULT) - assert vo.version == VERSION_DEFAULT + vo = recordWithVersion(VersionState.DEFAULT) + assert vo.version == VersionState.DEFAULT + vo = recordWithVersion(VersionState.DEFAULT) + assert vo.version == VersionState.DEFAULT def test_set_record(): @@ -278,11 +245,11 @@ def test_set_record(): vo.version = new_version assert vo.version == new_version - vo.version = VERSION_DEFAULT - assert vo.version == VERSION_DEFAULT + vo.version = VersionState.DEFAULT + assert vo.version == VersionState.DEFAULT - vo.version = VERSION_DEFAULT_NAME - assert vo.version == VERSION_DEFAULT + vo.version = VersionState.DEFAULT + assert vo.version == VersionState.DEFAULT vo = recordWithVersion(randint(0, 120)) with pytest.raises(ValueError): @@ -301,16 +268,14 @@ def test_write_version_record(): assert f'"version":{version}' in vo.model_dump_json() assert vo.dict()["version"] == version - # test write none - vo = recordWithVersion(None) - assert vo.version is None - assert f'"version":"{VERSION_NONE_NAME}"' in vo.model_dump_json() - assert '"version":null' not in vo.model_dump_json() - assert vo.dict()["version"] == VERSION_NONE_NAME - # test write default - vo = recordWithVersion(VERSION_DEFAULT) - assert vo.version == VERSION_DEFAULT - assert f'"version":"{VERSION_DEFAULT_NAME}"' in vo.model_dump_json() - assert f'"version":{VERSION_DEFAULT}' not in vo.model_dump_json() - assert vo.dict()["version"] == VERSION_DEFAULT_NAME + vo = recordWithVersion(VersionState.DEFAULT) + assert vo.version == VersionState.DEFAULT + + with pytest.raises(ValueError, match="must be flattened to an int before writing to"): + vo.model_dump_json() + vo.version = VERSION_START + + assert f'"version":{VERSION_START}' in vo.model_dump_json() + assert f'"version":"{VERSION_START}"' not in vo.model_dump_json() + assert vo.dict()["version"] == VERSION_START diff --git a/tests/unit/backend/data/test_DataFactoryService.py b/tests/unit/backend/data/test_DataFactoryService.py index e87bddd03..8d4423ce6 100644 --- a/tests/unit/backend/data/test_DataFactoryService.py +++ b/tests/unit/backend/data/test_DataFactoryService.py @@ -67,16 +67,12 @@ def setUpClass(cls): cls.mockLookupService.calibrationIndexer.return_value = mock.Mock( versionPath=mock.Mock(side_effect=lambda *x: cls.expected(cls, "Calibration", *x)), getIndex=mock.Mock(return_value=[cls.expected(cls, "Calibration")]), - thisOrNextVersion=mock.Mock(side_effect=lambda *x: cls.expected(cls, "Calibration", *x)), - thisOrCurrentVersion=mock.Mock(side_effect=lambda *x: cls.expected(cls, "Calibration", *x)), - thisOrLatestApplicableVersion=mock.Mock(side_effect=lambda *x: cls.expected(cls, "Calibration", *x)), + latestApplicableVersion=mock.Mock(side_effect=lambda *x: cls.expected(cls, "Calibration", *x)), ) cls.mockLookupService.normalizationIndexer.return_value = mock.Mock( versionPath=mock.Mock(side_effect=lambda *x: cls.expected(cls, "Normalization", *x)), 
getIndex=mock.Mock(return_value=[cls.expected(cls, "Normalization")]), - thisOrNextVersion=mock.Mock(side_effect=lambda *x: cls.expected(cls, "Normalization", *x)), - thisOrCurrentVersion=mock.Mock(side_effect=lambda *x: cls.expected(cls, "Normalization", *x)), - thisOrLatestApplicableVersion=mock.Mock(side_effect=lambda *x: cls.expected(cls, "Normalization", *x)), + latestApplicableVersion=mock.Mock(side_effect=lambda *x: cls.expected(cls, "Normalization", *x)), ) def setUp(self): @@ -187,20 +183,10 @@ def test_getCalibrationDataWorkspace(self): actual = self.instance.getCalibrationDataWorkspace("456", useLiteMode, self.version, "bunko") assert actual == self.instance.groceryService.fetchWorkspace.return_value - def test_getThisOrCurrentCalibrationVersion(self): + def test_getLatestCalibrationVersion(self): for useLiteMode in [True, False]: - actual = self.instance.getThisOrCurrentCalibrationVersion("123", useLiteMode, self.version) - assert actual == self.expected("Calibration", self.version) # NOTE mock indexer called only with version - - def test_getThisOrNextCalibrationVersion(self): - for useLiteMode in [True, False]: - actual = self.instance.getThisOrNextCalibrationVersion("123", useLiteMode, self.version) - assert actual == self.expected("Calibration", self.version) # NOTE mock indexer called only with version - - def test_getThisOrLatestCalibrationVersion(self): - for useLiteMode in [True, False]: - actual = self.instance.getThisOrLatestCalibrationVersion("123", useLiteMode, self.version) - assert actual == self.expected("Calibration", "123", self.version) + actual = self.instance.getLatestApplicableCalibrationVersion("123", useLiteMode) + assert actual == self.expected("Calibration", "123") ## TEST NORMALIZATION METHODS @@ -240,20 +226,10 @@ def test_getNormalizationDataWorkspace(self): actual = self.instance.getNormalizationDataWorkspace("456", useLiteMode, self.version, "bunko") assert actual == self.instance.groceryService.fetchWorkspace.return_value - def test_getThisOrCurrentNormalizationVersion(self): - for useLiteMode in [True, False]: - actual = self.instance.getThisOrCurrentNormalizationVersion("123", useLiteMode, self.version) - assert actual == self.expected("Normalization", self.version) # NOTE mock indexer called only with version - - def test_getThisOrNextNormalizationVersion(self): - for useLiteMode in [True, False]: - actual = self.instance.getThisOrNextNormalizationVersion("123", useLiteMode, self.version) - assert actual == self.expected("Normalization", self.version) # NOTE mock indexer called only with version - - def test_getThisOrLatestNormalizationVersion(self): + def test_getLatestNormalizationVersion(self): for useLiteMode in [True, False]: - actual = self.instance.getThisOrLatestNormalizationVersion("123", useLiteMode, self.version) - assert actual == self.expected("Normalization", "123", self.version) + actual = self.instance.getLatestApplicableNormalizationVersion("123", useLiteMode) + assert actual == self.expected("Normalization", "123") ## TEST REDUCTION METHODS diff --git a/tests/unit/backend/data/test_GroceryService.py b/tests/unit/backend/data/test_GroceryService.py index f5032a5a9..6ad8b7991 100644 --- a/tests/unit/backend/data/test_GroceryService.py +++ b/tests/unit/backend/data/test_GroceryService.py @@ -156,8 +156,21 @@ def setUp(self): .version(self.version) .build() ) + # self.versionOutputPath = Path(Resource.getPath(f"outputs/)) return super().setUp() + def mockIndexer(self, root=None, calType=None): + mockIndexer = mock.Mock() + 
mockIndexer.readRecord = mock.Mock( + return_value=mock.MagicMock(version=self.version, runNumber=self.runNumber1, useLiteMode=False) + ) + # tmpqw8kedh8/native/diffraction/v_0078/ + if root is not None: + liteModeStr = "lite" if self.useLiteMode == "lite" else "native" + versionPath = Path(f"{root}/{liteModeStr}/{calType}/v_{self.version:04d}") + mockIndexer.versionPath = mock.Mock(return_value=versionPath) + return mock.Mock(return_value=mockIndexer) + def clearoutWorkspaces(self) -> None: """Delete the workspaces created by loading""" for ws in mtd.getObjectNames(): @@ -313,6 +326,11 @@ def test_diffcal_table_filename_from_workspaceName(self): def test_diffcal_table_filename_from_workspaceName_failure_name_mismatch(self): workspaceName = "bogus_name_" + mockIndexer = mock.Mock() + mockIndexer.readRecord = mock.Mock() + mockDataService = mock.Mock() + mockDataService.calibrationIndexer = mock.Mock(return_value=mockIndexer) + self.instance.dataService = mockDataService with pytest.raises(ValueError, match=r".*Workspace name .* does not match the expected *"): self.instance._createDiffcalTableFilepathFromWsName( self.runNumber, self.useLiteMode, self.version, workspaceName @@ -328,6 +346,7 @@ def test_diffcal_table_filename(self): def test_normalization_workspace_filename(self): # Test name generation for diffraction-calibration table filename + self.instance.dataService.normalizationIndexer = self.mockIndexer("root", "normalization") res = self.instance._createNormalizationWorkspaceFilename(self.runNumber, self.useLiteMode, self.version) assert self.runNumber in res assert wnvf.formatVersion(self.version) in res @@ -1100,6 +1119,7 @@ def test_fetch_grocery_list_diffcal_fails(self): def test_fetch_grocery_list_diffcal_output(self): # Test of workspace type "diffcal_output" as `Input` argument in the `GroceryList` with state_root_redirect(self.instance.dataService) as tmpRoot: + self.instance.dataService.calibrationIndexer = self.mockIndexer(tmpRoot.path(), "diffraction") groceryList = ( GroceryListItem.builder() .native() @@ -1144,6 +1164,9 @@ def test_fetch_grocery_list_diffcal_output_cached(self): InputWorkspace=self.sampleWS, OutputWorkspace=diffCalOutputName, ) + + self.instance.dataService.calibrationIndexer = self.mockIndexer("root", "diffraction") + assert mtd.doesExist(diffCalOutputName) testTitle = "I'm a little teapot" mtd[diffCalOutputName].setTitle(testTitle) @@ -1156,6 +1179,7 @@ def test_fetch_grocery_list_diffcal_table(self): # Test of workspace type "diffcal_table" as `Input` argument in the `GroceryList` self.instance._fetchInstrumentDonor = mock.Mock(return_value=self.sampleWS) with state_root_redirect(self.instance.dataService) as tmpRoot: + self.instance.dataService.calibrationIndexer = self.mockIndexer(tmpRoot.path(), "diffraction") groceryList = GroceryListItem.builder().native().diffcal_table(self.runNumber1, self.version).buildList() # independently construct the pathname, move file to there, assert exists diffCalTableName = wng.diffCalTable().runNumber(self.runNumber1).version(self.version).build() @@ -1200,6 +1224,7 @@ def test_fetch_grocery_list_diffcal_table_loads_mask(self): # * corresponding mask workspace is also loaded from the hdf5-format file. 
self.instance._fetchInstrumentDonor = mock.Mock(return_value=self.sampleWS) with state_root_redirect(self.instance.dataService) as tmpRoot: + self.instance.dataService.calibrationIndexer = self.mockIndexer(tmpRoot.path(), "diffraction") groceryList = GroceryListItem.builder().native().diffcal_table(self.runNumber1, self.version).buildList() diffCalTableName = wng.diffCalTable().runNumber(self.runNumber1).version(self.version).build() self.instance.lookupDiffcalTableWorkspaceName = mock.Mock(return_value=diffCalTableName) @@ -1223,6 +1248,7 @@ def test_fetch_grocery_list_diffcal_mask(self): # Test of workspace type "diffcal_mask" as `Input` argument in the `GroceryList` self.instance._fetchInstrumentDonor = mock.Mock(return_value=self.sampleWS) with state_root_redirect(self.instance.dataService) as tmpRoot: + self.instance.dataService.calibrationIndexer = self.mockIndexer(tmpRoot.path(), "diffraction") groceryList = GroceryListItem.builder().native().diffcal_mask(self.runNumber1, self.version).buildList() diffCalMaskName = wng.diffCalMask().runNumber(self.runNumber1).version(self.version).build() @@ -1251,6 +1277,7 @@ def test_fetch_grocery_list_diffcal_mask_cached(self): self.instance.lookupDiffcalTableWorkspaceName = mock.Mock( return_value=wng.diffCalTable().runNumber(self.runNumber1).version(self.version).build() ) + self.instance.dataService.calibrationIndexer = self.mockIndexer("root", "diffraction") CloneWorkspace( InputWorkspace=self.sampleMaskWS, OutputWorkspace=diffCalMaskName, @@ -1275,8 +1302,10 @@ def test_fetch_grocery_list_diffcal_mask_cached(self): def test_fetch_grocery_list_diffcal_mask_loads_table(self): # Test of workspace type "diffcal_mask" as `Input` argument in the `GroceryList`: # * corresponding table workspace is also loaded from the hdf5-format file. 
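# ---------------------------------------------------------------------------
# NOTE: a minimal sketch of the fixture pattern repeated in the grocery tests
# above, assuming the mockIndexer helper defined earlier in this file: each
# fetch test now stubs the data service's indexer factory so that version
# paths resolve under the redirected state root:
#
#     with state_root_redirect(self.instance.dataService) as tmpRoot:
#         self.instance.dataService.calibrationIndexer = self.mockIndexer(
#             tmpRoot.path(), "diffraction"
#         )
#         # versionPath() for the mocked record now yields
#         # {tmpRoot}/native/diffraction/v_XXXX
# ---------------------------------------------------------------------------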
+ self.instance._fetchInstrumentDonor = mock.Mock(return_value=self.sampleWS) with state_root_redirect(self.instance.dataService) as tmpRoot: + self.instance.dataService.calibrationIndexer = self.mockIndexer(tmpRoot.path(), "diffraction") groceryList = GroceryListItem.builder().native().diffcal_mask(self.runNumber1, self.version).buildList() diffCalMaskName = wng.diffCalMask().runNumber(self.runNumber1).version(self.version).build() @@ -1302,6 +1331,7 @@ def test_fetch_grocery_list_normalization(self): # Test of workspace type "normalization" as `Input` argument in the `GroceryList` self.instance._fetchInstrumentDonor = mock.Mock(return_value=self.sampleWS) with state_root_redirect(self.instance.dataService) as tmpRoot: + self.instance.dataService.normalizationIndexer = self.mockIndexer(tmpRoot.path(), "normalization") groceryList = GroceryListItem.builder().native().normalization(self.runNumber1, self.version).buildList() # normalization filename is constructed @@ -1323,6 +1353,7 @@ def test_fetch_grocery_list_normalization_cached(self): # Test of workspace type "normalization" as `Input` argument in the `GroceryList`: # workspace already in ADS self.instance.grocer = mock.Mock() + self.instance.dataService.normalizationIndexer = self.mockIndexer() groceryList = GroceryListItem.builder().native().normalization(self.runNumber1, self.version).buildList() normalizationWorkspaceName = wng.rawVanadium().runNumber(self.runNumber1).version(self.version).build() CloneWorkspace( diff --git a/tests/unit/backend/data/test_Indexer.py b/tests/unit/backend/data/test_Indexer.py index 51071ef57..04eacf00d 100644 --- a/tests/unit/backend/data/test_Indexer.py +++ b/tests/unit/backend/data/test_Indexer.py @@ -10,7 +10,6 @@ from unittest import mock import pytest -from pydantic import ValidationError from util.dao import DAOFactory from snapred.backend.dao.calibration.Calibration import Calibration @@ -18,7 +17,7 @@ from snapred.backend.dao.indexing.CalculationParameters import CalculationParameters from snapred.backend.dao.indexing.IndexEntry import IndexEntry from snapred.backend.dao.indexing.Record import Record -from snapred.backend.dao.indexing.Versioning import VERSION_DEFAULT, VERSION_START +from snapred.backend.dao.indexing.Versioning import VERSION_START, VersionState from snapred.backend.dao.normalization.Normalization import Normalization from snapred.backend.dao.normalization.NormalizationRecord import NormalizationRecord from snapred.backend.data.Indexer import DEFAULT_RECORD_TYPE, Indexer, IndexerType @@ -262,14 +261,57 @@ def test_allVersions_some(self): def test_defaultVersion(self): indexer = self.initIndexer() - assert indexer.defaultVersion() == VERSION_DEFAULT + assert indexer.defaultVersion() == VERSION_START def test_currentVersion_none(self): # ensure the current version of an empty index is uninitialized indexer = self.initIndexer() assert indexer.currentVersion() is None - # the path should go to the starting version - indexer.currentPath() == self.versionPath(VERSION_START) + # if there is no current version then there is no current path on disk.
+ with pytest.raises(ValueError, match=r".*The indexer has encountered an invalid version*"): + indexer.currentPath() == self.versionPath(VERSION_START) + + def test_flattenVersion(self): + indexer = self.initIndexer() + indexer.currentVersion = lambda: 3 + indexer.nextVersion = lambda: 4 + assert indexer._flattenVersion(VersionState.DEFAULT) == indexer.defaultVersion() + assert indexer._flattenVersion(VersionState.NEXT) == indexer.nextVersion() + assert indexer._flattenVersion(3) == 3 + + with pytest.raises(ValueError, match=r".*Version must be an int or*"): + indexer._flattenVersion(None) + + def test_writeNewVersion_noAppliesTo(self): + # ensure that a new record is written to disk + # and the index is updated to reflect the new record + indexer = self.initIndexer() + version = randint(2, 120) + record = self.record(version) + entry = self.indexEntryFromRecord(record) + entry.appliesTo = None + indexer.writeNewVersion(record, entry) + assert self.recordPath(version).exists() + assert indexer.index[version] == entry + + def test_writeNewVersion_recordAlreadyExists(self): + # ensure that a new record is written to disk + # and the index is updated to reflect the new record + indexer = self.initIndexer() + version = randint(2, 120) + record = self.record(version) + entry = self.indexEntryFromRecord(record) + indexer.writeNewVersion(record, entry) + assert self.recordPath(version).exists() + assert indexer.index[version] == entry + + # now write the record again + # ensure that the record cannot be overwritten + record = self.record(version) + entry = self.indexEntryFromRecord(record) + + with pytest.raises(ValueError, match=".*already exists.*"): + indexer.writeNewVersion(record, entry) def test_currentVersion_add(self): # ensure current version advances when index entries are written @@ -378,23 +420,23 @@ def test_latestApplicableVersion_sorts_in_time(self): def test_latestApplicableVersion_returns_default(self): # ensure latest applicable version will be default if it is the only one runNumber = "123" - versionList = [VERSION_DEFAULT] + versionList = [VERSION_START] self.prepareVersions(versionList) indexer = self.initIndexer() # make it applicable - indexer.index[VERSION_DEFAULT].appliesTo = f">={runNumber}" + indexer.index[indexer.defaultVersion()].appliesTo = f">={runNumber}" # get latest applicable latest = indexer.latestApplicableVersion(runNumber) - assert latest == VERSION_DEFAULT + assert latest == VERSION_START def test_latestApplicableVersion_excludes_default(self): # ensure latest applicable version will remove default if other versions exist runNumber = "123" - versionList = [VERSION_DEFAULT, 4, 5] + versionList = [VERSION_START, 4, 5] self.prepareVersions(versionList) indexer = self.initIndexer() # make some entries applicable - applicableVersions = [VERSION_DEFAULT, 4] + applicableVersions = [VERSION_START, 4] print(indexer.index) for version in applicableVersions: indexer.index[version].appliesTo = f">={runNumber}" @@ -402,21 +444,7 @@ def test_latestApplicableVersion_excludes_default(self): latest = indexer.latestApplicableVersion(runNumber) assert latest == applicableVersions[-1] - def test_thisOrCurrentVersion(self): - version = randint(20, 120) - indexer = self.initIndexer() - assert indexer.thisOrCurrentVersion(None) == indexer.currentVersion() - assert indexer.thisOrCurrentVersion(VERSION_DEFAULT) == VERSION_DEFAULT - assert indexer.thisOrCurrentVersion(version) == version - - def test_thisOrNextVersion(self): - version = randint(20, 120) - indexer = self.initIndexer() - assert
indexer.thisOrNextVersion(None) == indexer.nextVersion() - assert indexer.thisOrNextVersion(VERSION_DEFAULT) == VERSION_DEFAULT - assert indexer.thisOrNextVersion(version) == version - - def test_thisOrLatestApplicableVersion(self): + def test_getLatestApplicableVersion(self): # make one applicable entry version1 = randint(1, 10) entry1 = self.indexEntry(version1) @@ -429,22 +457,19 @@ def test_thisOrLatestApplicableVersion(self): indexer = self.initIndexer() indexer.index = {version1: entry1, version2: entry2} # only the applicable entry is returned - assert indexer.thisOrLatestApplicableVersion("123", None) == version1 - assert indexer.thisOrLatestApplicableVersion("123", version1) == version1 - assert indexer.thisOrLatestApplicableVersion("123", version2) == version1 + assert indexer.latestApplicableVersion("123") == version1 def test_isValidVersion(self): indexer = self.initIndexer() # the good for i in range(10): - assert indexer.isValidVersion(randint(2, 120)) - assert indexer.isValidVersion(VERSION_DEFAULT) + assert indexer.validateVersion(randint(2, 120)) + assert indexer.validateVersion(VersionState.DEFAULT) # the bad - assert not indexer.isValidVersion("bad") - assert not indexer.isValidVersion("*") - assert not indexer.isValidVersion(None) - assert not indexer.isValidVersion(-2) - assert not indexer.isValidVersion(1.2) + badInput = ["bad", "*", None, -2, 1.2] + for i in badInput: + with pytest.raises(ValueError, match=r".*The indexer has encountered an invalid version*"): + indexer.validateVersion(i) def test_nextVersion(self): # check that the current version advances as expected as @@ -454,14 +479,10 @@ def test_nextVersion(self): expectedIndex = {} indexer = self.initIndexer() assert indexer.index == expectedIndex - # there is no current version assert indexer.currentVersion() is None - assert indexer.currentVersion() is None - # the first "next" version is the start assert indexer.nextVersion() == VERSION_START - assert indexer.nextVersion() == VERSION_START # add an entry to the calibration index here = VERSION_START @@ -470,18 +491,26 @@ def test_nextVersion(self): indexer.addIndexEntry(entry) expectedIndex[here] = entry assert indexer.index == expectedIndex - # the current version should be this version assert indexer.currentVersion() == here - assert indexer.currentVersion() == here # the next version also should be this version # until a record is written to disk assert indexer.nextVersion() == here - assert indexer.nextVersion() == here + expectedIndex.pop(here) + assert indexer.index == expectedIndex + + # NOTE: At this point, the index will have been purged because + # the entry did not have a corresponding record on disk. + # PREVIOUSLY: This had been testing against a ghost entry which + # did not represent a record on disk, + # but was still in the index. + # This is no longer the case. 
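# ---------------------------------------------------------------------------
# NOTE: a sketch of the invariant the rewritten assertions around this point
# encode, assuming the Indexer API introduced in this patch: an index entry
# and its record are persisted together through writeNewVersion, and an entry
# whose record never lands on disk is purged from the index rather than left
# as a ghost:
#
#     entry = self.indexEntry(indexer.nextVersion())
#     record = self.recordFromIndexEntry(entry)
#     indexer.writeNewVersion(record, entry)  # record + index entry, as a pair
#     assert indexer.currentVersion() == entry.version
# ---------------------------------------------------------------------------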
# now write the record + # >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> WRITE 1 <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< record = self.recordFromIndexEntry(entry) - indexer.writeRecord(record) + indexer.writeNewVersion(record, entry) + expectedIndex[here] = entry # the current version hasn't moved assert indexer.currentVersion() == here @@ -491,18 +520,12 @@ assert indexer.currentVersion() == here assert indexer.nextVersion() == here + 1 - # add another entry - here = here + 1 - # ensure it is added at the next version + # >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> WRITE 2 <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< entry = self.indexEntry(indexer.nextVersion()) + expectedIndex[entry.version] = entry indexer.addIndexEntry(entry) - expectedIndex[here] = entry - assert indexer.index == expectedIndex - assert indexer.currentVersion() == here - # the next version should be here - assert indexer.nextVersion() == here - # now write the record indexer.writeRecord(self.recordFromIndexEntry(entry)) + here = here + 1 # ensure current still here assert indexer.currentVersion() == here # ensure next is after here @@ -511,37 +534,24 @@ assert indexer.currentVersion() == here assert indexer.nextVersion() == here + 1 - # now write a record FIRST, at the next version - here = here + 1 - record = self.record(here) - indexer.writeRecord(record) - # the current version will point here - assert indexer.currentVersion() == here - assert indexer.currentVersion() == here - # the next version will point here - assert indexer.nextVersion() == here - assert indexer.nextVersion() == here - - # there is no index entry for this version - assert indexer.nextVersion() not in indexer.index + record = self.record(here + 1) + # NOTE: Writing arbitrary records to disk is not allowed. + # Our code should never produce an unindexed record. + # What value is there in writing a record that is + # inherently missing metadata supplied by the index?
+ with pytest.raises(ValueError, match=".*not found in index, please write an index entry first.*"): + indexer.writeRecord(record) - # add the entry - entry = self.indexEntryFromRecord(record) - indexer.addIndexEntry(entry) - expectedIndex[here] = entry - assert indexer.index == expectedIndex - # ensure current version points here, next points to next - assert indexer.currentVersion() == here - assert indexer.nextVersion() == here + 1 - assert indexer.currentVersion() == here - assert indexer.nextVersion() == here + 1 + # >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> WRITE 3 <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< # write a record first, at a much future version # then add an index entry, and ensure it matches here = here + 23 record = self.record(here) - indexer.writeRecord(record) - assert indexer.nextVersion() == here + entry = self.indexEntryFromRecord(record) + expectedIndex[entry.version] = entry + indexer.writeNewVersion(record, entry) + assert indexer.nextVersion() == here + 1 assert indexer.nextVersion() not in indexer.index # now add the entry @@ -562,49 +572,41 @@ def test_nextVersion_with_default_index_first(self): # there is no current version assert indexer.currentVersion() is None - # the first "next" version is the start assert indexer.nextVersion() == VERSION_START # add an entry at the default version - entry = self.indexEntry(VERSION_DEFAULT) - indexer.addIndexEntry(entry) - expectedIndex[VERSION_DEFAULT] = entry - assert indexer.index == expectedIndex - assert entry.version == VERSION_DEFAULT - - # the current version is now the default version - assert indexer.currentVersion() == VERSION_DEFAULT - # the next version also should be the default version - # until a record is written to disk - assert indexer.nextVersion() == VERSION_DEFAULT - - # now write the record -- it should write to default + entry = self.indexEntry(VersionState.DEFAULT) record = self.recordFromIndexEntry(entry) - indexer.writeRecord(record) - assert self.recordPath(VERSION_DEFAULT).exists() - + expectedIndex[VERSION_START] = entry + indexer.writeNewVersion(record, entry) + assert self.recordPath(indexer.defaultVersion()).exists() # the current version is still the default version - assert indexer.currentVersion() == VERSION_DEFAULT - # the next version will be the starting version - assert indexer.nextVersion() == VERSION_START + assert indexer.currentVersion() == indexer.defaultVersion() + # the next version will be the starting version + 1 + assert indexer.nextVersion() == VERSION_START + 1 # add another entry -- now at the start entry = self.indexEntry(indexer.nextVersion()) + expectedIndex[indexer.nextVersion()] = entry indexer.addIndexEntry(entry) - expectedIndex[VERSION_START] = entry assert indexer.index == expectedIndex - assert indexer.currentVersion() == VERSION_START + assert indexer.currentVersion() == VERSION_START + 1 + # the next version should be the starting version # until a record is written - assert indexer.nextVersion() == VERSION_START - # now write the record -- ensure it is written at the start - indexer.writeRecord(self.recordFromIndexEntry(entry)) + assert indexer.nextVersion() == VERSION_START + 1 + expectedIndex.pop(entry.version) + assert indexer.index == expectedIndex + # now write the record -- ensure it is written at the next version + record = self.recordFromIndexEntry(entry) + entry = self.indexEntryFromRecord(record) + indexer.writeNewVersion(record, entry) assert self.recordPath(VERSION_START).exists() # ensure current still here assert
indexer.currentVersion() == VERSION_START + 1 # ensure next is after here - assert indexer.nextVersion() == VERSION_START + 1 + assert indexer.nextVersion() == VERSION_START + 2 def test_nextVersion_with_default_record_first(self): # check default behaves correctly if a record is written first @@ -619,23 +621,15 @@ def test_nextVersion_with_default_record_first(self): assert indexer.nextVersion() == VERSION_START # add a record at the default version - record = self.record(VERSION_DEFAULT) - indexer.writeRecord(record) - - # the current version is now the default version - assert indexer.currentVersion() == VERSION_DEFAULT - # the next version also should be the default version - # until an entry is written to disk - assert indexer.nextVersion() == VERSION_DEFAULT - - # now write the index entry -- it should write to default + record = self.record(VersionState.DEFAULT) entry = self.indexEntryFromRecord(record) - indexer.addIndexEntry(entry) + assert indexer._flattenVersion(record.version) == VERSION_START + indexer.writeNewVersion(record, entry) # the current version is still the default version - assert indexer.currentVersion() == VERSION_DEFAULT + assert indexer.currentVersion() == VERSION_START # the next version will be one past the starting version - assert indexer.nextVersion() == VERSION_START + assert indexer.nextVersion() == VERSION_START + 1 ### TESTS OF VERSION COMPARISON METHODS ### @@ -689,9 +683,8 @@ def test_versionPath(self): self.writeRecordVersion(version) indexer = self.initIndexer() - # if version path is unitialized, path points to version start - ans1 = indexer.versionPath(None) - assert ans1 == self.versionPath(VERSION_START) + with pytest.raises(ValueError, match=".*The indexer has encountered an invalid version*"): + indexer.versionPath(None) # if version is specified, return that one for i in versionList: @@ -703,6 +696,7 @@ def test_currentPath(self): versionList = [3, 4, 5] for version in versionList: self.writeRecordVersion(version) + self.prepareIndex(versionList) indexer = self.initIndexer() indexer.currentPath() == self.versionPath(max(versionList)) @@ -720,7 +714,7 @@ def test_latestApplicablePath(self): indexer.index[version].appliesTo = f">={runNumber}" print(indexer.index) latest = indexer.latestApplicableVersion(runNumber) - assert indexer.latestApplicablePath(runNumber) == self.versionPath(latest) + assert indexer.getLatestApplicablePath(runNumber) == self.versionPath(latest) ### TEST INDEX MANIPULATION METHODS ### @@ -765,19 +759,11 @@ def test_addEntry_writes(self): readIndex = parse_file_as(List[IndexEntry], indexer.indexPath()) assert readIndex == list(indexer.index.values()) - def test_addEntry_fails(self): - # adding an entry with a bad version will fail - indexer = self.initIndexer() - indexer.isValidVersion = lambda x: False # now it will always return false - entry = self.indexEntry() - with pytest.raises(RuntimeError): - indexer.addIndexEntry(entry) - def test_addEntry_default(self): indexer = self.initIndexer() - entry = self.indexEntry(VERSION_DEFAULT) + entry = self.indexEntry(indexer.defaultVersion()) indexer.addIndexEntry(entry) - assert VERSION_DEFAULT in indexer.index + assert indexer.defaultVersion() in indexer.index def test_addEntry_advances(self): # adding an index entry advances the current version @@ -831,7 +817,8 @@ def test_readWriteRecord_next_version(self): # make sure the record was saved at the next version # and the read / written records match record = self.record(nextVersion) - indexer.writeRecord(record) + entry = 
self.indexEntryFromRecord(record) + indexer.writeNewVersion(record, entry) res = indexer.readRecord(nextVersion) assert record.version == nextVersion assert res == record @@ -840,11 +827,12 @@ def test_readWriteRecord_any_version(self): # write a record at some version number version = randint(10, 20) record = self.record(version) + entry = self.indexEntryFromRecord(record) indexer = self.initIndexer() # write then read the record # make sure the record version was updated # and the read / written records match - indexer.writeRecord(record) + indexer.writeNewVersion(record, entry) res = indexer.readRecord(version) assert record.version == version assert res == record @@ -855,8 +843,8 @@ def test_readRecord_none(self): version = randint(1, 11) indexer = self.initIndexer() assert not self.recordPath(version).exists() - res = indexer.readRecord(version) - assert res is None + with pytest.raises(FileNotFoundError, match=r".*No record found at*"): + indexer.readRecord(version) def test_readRecord(self): record = self.record(randint(1, 100)) @@ -870,25 +858,19 @@ def test_readRecord_invalid_version(self): record = self.record(randint(1, 100)) self.writeRecord(record) indexer = self.initIndexer() - with pytest.raises(ValidationError): + with pytest.raises(ValueError, match=r".*The indexer has encountered an invalid version*"): indexer.readRecord("*") # write # - def test_writeRecord_fails(self): - record = self.record() - indexer = self.initIndexer() - indexer.isValidVersion = lambda x: False - with pytest.raises(RuntimeError): - indexer.writeRecord(record) - def test_writeRecord_with_version(self): # this test ensures a record can be written to the indicated version # create a record and write it version = randint(2, 120) record = self.record(version) + entry = self.indexEntryFromRecord(record) indexer = self.initIndexer() - indexer.writeRecord(record) + indexer.writeNewVersion(record, entry) assert record.version == version assert self.recordPath(version).exists() # read it back in and ensure it is the same @@ -909,7 +891,8 @@ def test_writeRecord_next_version(self): assert nextVersion != VERSION_START # now write the record record = self.record(nextVersion) - indexer.writeRecord(record) + entry = self.indexEntryFromRecord(record) + indexer.writeNewVersion(record, entry) assert record.version == nextVersion assert self.recordPath(nextVersion).exists() res = parse_file_as(Record, self.recordPath(nextVersion)) @@ -924,10 +907,11 @@ def test_readWriteRecord_calibration(self): # prepare the record record = DAOFactory.calibrationRecord() record.version = randint(2, 100) + entry = self.indexEntryFromRecord(record) # write then read in the record indexer = self.initIndexer(IndexerType.CALIBRATION) - indexer.writeRecord(record) - res = indexer.readRecord() + indexer.writeNewVersion(record, entry) + res = indexer.readRecord(record.version) assert type(res) is CalibrationRecord assert res == record @@ -935,10 +919,11 @@ def test_readWriteRecord_normalization(self): # prepare the record record = DAOFactory.normalizationRecord() record.version = randint(2, 100) + entry = self.indexEntryFromRecord(record) # write then read in the record indexer = self.initIndexer(IndexerType.NORMALIZATION) - indexer.writeRecord(record) - res = indexer.readRecord() + indexer.writeNewVersion(record, entry) + res = indexer.readRecord(record.version) assert type(res) is NormalizationRecord assert res == record @@ -950,13 +935,6 @@ def test_readParameters_nope(self): with pytest.raises(FileNotFoundError): 
indexer.readParameters(1) - def test_writeParameters_fails(self): - params = self.calculationParameters(randint(2, 10)) - indexer = self.initIndexer() - indexer.isValidVersion = lambda x: False - with pytest.raises(RuntimeError): - indexer.writeParameters(params) - def test_readWriteParameters(self): version = randint(1, 10) params = self.calculationParameters(version) @@ -990,7 +968,7 @@ def test_readWriteParameters_calibration(self): params = DAOFactory.calibrationParameters() indexer = self.initIndexer(IndexerType.CALIBRATION) indexer.writeParameters(params) - res = indexer.readParameters() + res = indexer.readParameters(params.version) assert type(res) is Calibration assert res == params @@ -998,7 +976,7 @@ def test_readWriteParameters_normalization(self): params = DAOFactory.normalizationParameters() indexer = self.initIndexer(IndexerType.NORMALIZATION) indexer.writeParameters(params) - res = indexer.readParameters() + res = indexer.readParameters(params.version) assert type(res) is Normalization assert res == params @@ -1006,10 +984,12 @@ def test_readWriteParameters_reduction(self): params = CalculationParameters(**DAOFactory.calibrationParameters().model_dump()) indexer = self.initIndexer(IndexerType.REDUCTION) indexer.writeParameters(params) - res = indexer.readParameters() + res = indexer.readParameters(params.version) assert type(res) is CalculationParameters assert res == params def test__determineRecordType(self): indexer = self.initIndexer(IndexerType.CALIBRATION) - assert indexer._determineRecordType(VERSION_DEFAULT) == DEFAULT_RECORD_TYPE.get(IndexerType.CALIBRATION) + assert indexer._determineRecordType(indexer.defaultVersion()) == DEFAULT_RECORD_TYPE.get( + IndexerType.CALIBRATION + ) diff --git a/tests/unit/backend/data/test_LocalDataService.py b/tests/unit/backend/data/test_LocalDataService.py index a54a86ea7..2152c3f71 100644 --- a/tests/unit/backend/data/test_LocalDataService.py +++ b/tests/unit/backend/data/test_LocalDataService.py @@ -45,7 +45,7 @@ from snapred.backend.dao.calibration.CalibrationRecord import CalibrationRecord from snapred.backend.dao.GroupPeakList import GroupPeakList from snapred.backend.dao.indexing.IndexEntry import IndexEntry -from snapred.backend.dao.indexing.Versioning import VERSION_DEFAULT +from snapred.backend.dao.indexing.Versioning import VERSION_START, VersionState from snapred.backend.dao.ingredients import ReductionIngredients from snapred.backend.dao.normalization.NormalizationRecord import NormalizationRecord from snapred.backend.dao.reduction.ReductionRecord import ReductionRecord @@ -129,6 +129,18 @@ def _capture_logging(monkeypatch): fakeInstrumentFilePath = Resource.getPath("inputs/testInstrument/fakeSNAP_Definition.xml") + +def entryFromRecord(record): + return IndexEntry( + runNumber=record.runNumber, + useLiteMode=record.useLiteMode, + appliesTo=record.runNumber, + comments="test comment", + author="test author", + version=record.version, + ) + + ### GENERALIZED METHODS FOR TESTING NORMALIZATION / CALIBRATION METHODS ### # Note: the REDUCTION workflow does not use the Indexer system except indirectly. 
@@ -231,8 +243,13 @@ def do_test_read_state_no_version(workflow: Literal["Calibration", "Normalizatio expectedState = getattr(DAOFactory, f"{workflow.lower()}Parameters")() indexer = localDataService.indexer("xyz", useLiteMode, workflow) tmpRoot.saveObjectAt(expectedState, indexer.parametersPath(currentVersion)) - indexer.index = {currentVersion: mock.Mock()} # NOTE manually update indexer - actualState = getattr(localDataService, f"read{workflow}State")("xyz", useLiteMode) # NOTE no version + indexer.index = { + currentVersion: mock.MagicMock(appliesTo="123", version=currentVersion) + } # NOTE manually update indexer + indexer.dirVersions = [currentVersion] # NOTE manually update indexer + actualState = getattr(localDataService, f"read{workflow}State")( + "123", useLiteMode, VersionState.LATEST + ) # NOTE no version assert actualState == expectedState @@ -293,13 +310,16 @@ def getMockInstrumentConfig(): def test_readStateConfig_default(): # readstateConfig will load the default parameters file groupingMap = DAOFactory.groupingMap_SNAP() - parameters = DAOFactory.calibrationParameters("57514", True, VERSION_DEFAULT) localDataService = LocalDataService() with state_root_redirect(localDataService) as tmpRoot: indexer = localDataService.calibrationIndexer("57514", True) + parameters = DAOFactory.calibrationParameters("57514", True, indexer.defaultVersion()) tmpRoot.saveObjectAt(groupingMap, localDataService._groupingMapPath(tmpRoot.stateId)) - tmpRoot.saveObjectAt(parameters, indexer.parametersPath(VERSION_DEFAULT)) - indexer.index = {VERSION_DEFAULT: mock.Mock()} # NOTE manually update the Indexer + tmpRoot.saveObjectAt(parameters, indexer.parametersPath(indexer.defaultVersion())) + + indexer.index = { + VersionState.DEFAULT: mock.MagicMock(appliesTo="57514", version=indexer.defaultVersion()) + } # NOTE manually update the Indexer actual = localDataService.readStateConfig("57514", True) assert actual is not None assert actual.stateId == DAOFactory.magical_state_id @@ -315,7 +335,9 @@ def test_readStateConfig_previous(): indexer = localDataService.calibrationIndexer("57514", True) tmpRoot.saveObjectAt(groupingMap, localDataService._groupingMapPath(tmpRoot.stateId)) tmpRoot.saveObjectAt(parameters, indexer.parametersPath(version)) - indexer.index = {version: mock.Mock()} # NOTE manually update the Indexer + indexer.index = { + version: mock.MagicMock(appliesTo="57514", version=version) + } # NOTE manually update the Indexer actual = localDataService.readStateConfig("57514", True) assert actual is not None assert actual.stateId == DAOFactory.magical_state_id @@ -331,7 +353,9 @@ def test_readStateConfig_attaches_grouping_map(): indexer = localDataService.calibrationIndexer("57514", True) tmpRoot.saveObjectAt(groupingMap, localDataService._groupingMapPath(tmpRoot.stateId)) tmpRoot.saveObjectAt(parameters, indexer.parametersPath(version)) - indexer.index = {version: mock.Mock()} # NOTE manually update the Indexer + indexer.index = { + version: mock.MagicMock(appliesTo="57514", version=version) + } # NOTE manually update the Indexer actual = localDataService.readStateConfig("57514", True) expectedMap = DAOFactory.groupingMap_SNAP() assert actual.groupingMap == expectedMap @@ -348,7 +372,9 @@ def test_readStateConfig_invalid_grouping_map(): indexer = localDataService.calibrationIndexer("57514", True) tmpRoot.saveObjectAt(groupingMap, localDataService._groupingMapPath(tmpRoot.stateId)) tmpRoot.saveObjectAt(parameters, indexer.parametersPath(version)) - indexer.index = {version: mock.Mock()} # NOTE 
manually update the Indexer + indexer.index = { + version: mock.MagicMock(appliesTo="57514", version=version) + } # NOTE manually update the Indexer # 'GroupingMap.defaultStateId' is _not_ a valid grouping-map 'stateId' for an existing `StateConfig`. with pytest.raises( # noqa: PT012 RuntimeError, @@ -366,7 +392,9 @@ def test_readStateConfig_calls_prepareStateRoot(): with state_root_redirect(localDataService, stateId=expected.instrumentState.id.hex) as tmpRoot: indexer = localDataService.calibrationIndexer("57514", True) tmpRoot.saveObjectAt(expected, indexer.parametersPath(version)) - indexer.index = {version: mock.Mock()} # NOTE manually update the Indexer + indexer.index = { + version: mock.MagicMock(appliesTo="57514", version=version) + } # NOTE manually update the Indexer assert not localDataService._groupingMapPath(tmpRoot.stateId).exists() localDataService._prepareStateRoot = mock.Mock( side_effect=lambda x: tmpRoot.saveObjectAt( # noqa ARG005 @@ -702,8 +730,9 @@ def test_write_model_pretty_StateConfig_excludes_grouping_map(): with state_root_redirect(localDataService) as tmpRoot: # move the calculation parameters into correct folder indexer = localDataService.calibrationIndexer("57514", True) - indexer.writeParameters(DAOFactory.calibrationParameters("57514", True, VERSION_DEFAULT)) - indexer.index = {VERSION_DEFAULT: mock.Mock()} + indexer.writeParameters(DAOFactory.calibrationParameters("57514", True, indexer.defaultVersion())) + indexer.index = {indexer.defaultVersion(): mock.MagicMock(appliesTo="57514", version=indexer.defaultVersion())} + # move the grouping map into correct folder write_model_pretty(DAOFactory.groupingMap_SNAP(), localDataService._groupingMapPath(tmpRoot.stateId)) @@ -1168,10 +1197,11 @@ def test_createCalibrationIndexEntry(): assert ans.useLiteMode == request.useLiteMode assert ans.version == request.version - request.version = None - indexer = localDataService.calibrationIndexer(request.runNumber, request.useLiteMode) + request.version = VersionState.NEXT + localDataService.calibrationIndexer(request.runNumber, request.useLiteMode) ans = localDataService.createCalibrationIndexEntry(request) - assert ans.version == indexer.nextVersion() + # Set to next version, which on the first call should be the start version + assert ans.version == VERSION_START def test_createCalibrationRecord(): @@ -1185,10 +1215,11 @@ def test_createCalibrationRecord(): assert ans.useLiteMode == request.useLiteMode assert ans.version == request.version - request.version = None - indexer = localDataService.calibrationIndexer(request.runNumber, request.useLiteMode) + request.version = VersionState.NEXT + localDataService.calibrationIndexer(request.runNumber, request.useLiteMode) ans = localDataService.createCalibrationRecord(request) - assert ans.version == indexer.nextVersion() + # Set to next version, which on the first call should be the start version + assert ans.version == VERSION_START def test_readCalibrationRecord_with_version(): @@ -1214,6 +1245,8 @@ def test_readWriteCalibrationRecord(): for useLiteMode in [True, False]: record = DAOFactory.calibrationRecord("57514", useLiteMode, version=1) with state_root_redirect(localDataService): + entry = entryFromRecord(record) + localDataService.writeCalibrationIndexEntry(entry) localDataService.writeCalibrationRecord(record) actualRecord = localDataService.readCalibrationRecord("57514", useLiteMode) assert actualRecord.version == record.version @@ -1311,9 +1344,10 @@ def test_createNormalizationIndexEntry(): assert ans.version == 
request.version request.version = None - indexer = localDataService.normalizationIndexer(request.runNumber, request.useLiteMode) + localDataService.normalizationIndexer(request.runNumber, request.useLiteMode) ans = localDataService.createNormalizationIndexEntry(request) - assert ans.version == indexer.nextVersion() + # Set to next version, which on the first call should be the start version + assert ans.version == VERSION_START def test_createNormalizationRecord(): @@ -1327,10 +1361,9 @@ def test_createNormalizationRecord(): assert ans.useLiteMode == request.useLiteMode assert ans.version == request.version - request.version = None - indexer = localDataService.normalizationIndexer(request.runNumber, request.useLiteMode) + request.version = VersionState.NEXT ans = localDataService.createNormalizationRecord(request) - assert ans.version == indexer.nextVersion() + assert ans.version == VERSION_START def test_readNormalizationRecord_with_version(): @@ -1355,11 +1388,18 @@ def test_readWriteNormalizationRecord(): localDataService = LocalDataService() for useLiteMode in [True, False]: record.useLiteMode = useLiteMode + currentVersion = randint(VERSION_START, 120) + runNumber = record.runNumber + record.version = currentVersion # NOTE redirect nested so assertion occurs outside of redirect # failing assertions inside tempdirs can create unwanted files with state_root_redirect(localDataService): - localDataService.writeNormalizationRecord(record) - actualRecord = localDataService.readNormalizationRecord("57514", useLiteMode) + entry = entryFromRecord(record) + localDataService.writeNormalizationRecord(record, entry) + indexer = localDataService.normalizationIndexer(runNumber, useLiteMode) + + indexer.index = {currentVersion: mock.MagicMock(appliesTo=runNumber, version=currentVersion)} + actualRecord = localDataService.readNormalizationRecord(runNumber, useLiteMode) assert actualRecord.version == record.version assert actualRecord.calculationParameters.version == record.calculationParameters.version assert actualRecord == record @@ -1507,7 +1547,7 @@ def test_readWriteReductionRecord(): localDataService.groceryService = mock.Mock() localDataService.writeReductionRecord(testRecord) actualRecord = localDataService.readReductionRecord(runNumber, testRecord.useLiteMode, testRecord.timestamp) - assert actualRecord == testRecord + assert actualRecord.dict() == testRecord.dict() @pytest.fixture @@ -1833,7 +1873,9 @@ def test_readWriteReductionData(readSyntheticReductionRecord, createReductionWor cleanup_workspace_at_exit(_uniquePrefix + ws) actualRecord = localDataService.readReductionData(runNumber, useLiteMode, timestamp) - assert actualRecord == testRecord + + assert actualRecord.normalization.calibrationVersionUsed == testRecord.normalization.calibrationVersionUsed + assert actualRecord.dict() == testRecord.dict() # workspaces should have been reloaded with their original names # Implementation note: @@ -2017,16 +2059,17 @@ def test_readNormalizationState_no_version(): def test_readWriteCalibrationState(): - # NOTE this test is already covered by tests of the Indexer - # but it doesn't hurt to retain this test anyway runNumber = "123" localDataService = LocalDataService() - for useLiteMode in [True, False]: - calibration = DAOFactory.calibrationParameters(runNumber, useLiteMode) - with state_root_redirect(localDataService): - localDataService.writeCalibrationState(calibration) - ans = localDataService.readCalibrationState(runNumber, useLiteMode) - assert ans == calibration + mockCalibrationIndexer = 
mock.Mock() + + localDataService.calibrationIndexer = mock.Mock(return_value=mockCalibrationIndexer) + localDataService.calibrationIndexer().latestApplicableVersion = mock.Mock(return_value=1) + mockCalibrationIndexer.nextVersion = mock.Mock(return_value=1) + + ans = localDataService.readCalibrationState(runNumber, True, VersionState.LATEST) + assert ans == mockCalibrationIndexer.readParameters.return_value + mockCalibrationIndexer.readParameters.assert_called_once_with(1) def test_readWriteCalibrationState_noWritePermissions(): @@ -2055,7 +2098,7 @@ def test_readCalibrationState_hasWritePermissions(): def test_writeDefaultDiffCalTable(fetchInstrumentDonor, createDiffCalTableWorkspaceName): # verify that the default diffcal table is being written to the default state directory runNumber = "default" - version = VERSION_DEFAULT + version = VERSION_START useLiteMode = True # mock the grocery service to return the fake instrument to use for geometry idfWS = mtd.unique_name(prefix="_idf_") @@ -2081,12 +2124,15 @@ def test_readWriteNormalizationState(): # but it doesn't hurt to retain this test anyway runNumber = "123" localDataService = LocalDataService() - for useLiteMode in [True, False]: - normalization = DAOFactory.normalizationParameters(runNumber, useLiteMode) - with state_root_redirect(localDataService): - localDataService.writeNormalizationState(normalization) - ans = localDataService.readNormalizationState(runNumber, useLiteMode) - assert ans == normalization + mockNormalizationIndexer = mock.Mock() + + localDataService.normalizationIndexer = mock.Mock(return_value=mockNormalizationIndexer) + localDataService.normalizationIndexer().latestApplicableVersion = mock.Mock(return_value=1) + mockNormalizationIndexer.nextVersion = mock.Mock(return_value=1) + + ans = localDataService.readNormalizationState(runNumber, True, VersionState.LATEST) + assert ans == mockNormalizationIndexer.readParameters.return_value + mockNormalizationIndexer.readParameters.assert_called_once_with(1) def test_readDetectorState(): @@ -2279,7 +2325,7 @@ def test_initializeState(): testCalibrationData = DAOFactory.calibrationParameters( runNumber=runNumber, useLiteMode=useLiteMode, - version=VERSION_DEFAULT, + version=VERSION_START, instrumentState=DAOFactory.pv_instrument_state.copy(), ) diff --git a/tests/unit/backend/service/test_CalibrationService.py b/tests/unit/backend/service/test_CalibrationService.py index 3a993823d..4524a8b22 100644 --- a/tests/unit/backend/service/test_CalibrationService.py +++ b/tests/unit/backend/service/test_CalibrationService.py @@ -89,25 +89,13 @@ def test_exportCalibrationIndex(): calibrationService.dataExportService.exportCalibrationIndexEntry = mock.Mock() calibrationService.dataExportService.exportCalibrationIndexEntry.return_value = "expected" calibrationService.saveCalibrationToIndex( - IndexEntry(runNumber="1", useLiteMode=True, comments="", author=""), + IndexEntry(runNumber="1", useLiteMode=True, comments="", author="", version=1), ) assert calibrationService.dataExportService.exportCalibrationIndexEntry.called savedEntry = calibrationService.dataExportService.exportCalibrationIndexEntry.call_args.args[0] assert savedEntry.appliesTo == ">=1" assert savedEntry.timestamp is not None - def test_exportCalibrationIndex_no_timestamp(): - calibrationService = CalibrationService() - calibrationService.dataExportService.exportCalibrationIndexEntry = mock.Mock() - calibrationService.dataExportService.exportCalibrationIndexEntry.return_value = "expected" - 
calibrationService.dataExportService.getUniqueTimestamp = mock.Mock(return_value=123.123) - entry = IndexEntry(runNumber="1", useLiteMode=True, comments="", author="") - entry.timestamp = None - calibrationService.saveCalibrationToIndex(entry) - assert calibrationService.dataExportService.exportCalibrationIndexEntry.called - savedEntry = calibrationService.dataExportService.exportCalibrationIndexEntry.call_args.args[0] - assert savedEntry.timestamp == calibrationService.dataExportService.getUniqueTimestamp.return_value - def test_save(): workspace = mtd.unique_name(prefix="_dsp_") CreateSingleValuedWorkspace(OutputWorkspace=workspace) @@ -117,6 +105,7 @@ def test_save(): calibrationService.dataExportService.exportCalibrationWorkspaces = mock.Mock() calibrationService.dataExportService.exportCalibrationIndexEntry = mock.Mock() calibrationService.dataFactoryService.createCalibrationIndexEntry = mock.Mock() + calibrationService.dataFactoryService.calibrationExists = mock.Mock(return_value=False) calibrationService.dataFactoryService.createCalibrationRecord = mock.Mock( return_value=mock.Mock( runNumber="012345", @@ -138,6 +127,7 @@ def test_save_unexpected_units(): calibrationService.dataExportService.exportCalibrationWorkspaces = mock.Mock() calibrationService.dataExportService.exportCalibrationIndexEntry = mock.Mock() calibrationService.dataFactoryService.createCalibrationIndexEntry = mock.Mock() + calibrationService.dataFactoryService.calibrationExists = mock.Mock(return_value=False) calibrationService.dataFactoryService.createCalibrationRecord = mock.Mock( return_value=mock.Mock( runNumber="012345", @@ -157,6 +147,7 @@ def test_save_unexpected_type(): calibrationService.dataExportService.exportCalibrationWorkspaces = mock.Mock() calibrationService.dataExportService.exportCalibrationIndexEntry = mock.Mock() calibrationService.dataFactoryService.createCalibrationIndexEntry = mock.Mock() + calibrationService.dataFactoryService.calibrationExists = mock.Mock(return_value=False) calibrationService.dataFactoryService.createCalibrationRecord = mock.Mock( return_value=mock.Mock( runNumber="012345", @@ -176,7 +167,7 @@ def test_load(): def test_getCalibrationIndex(): calibrationService = CalibrationService() calibrationService.dataFactoryService.getCalibrationIndex = mock.Mock( - return_value=IndexEntry(runNumber="1", useLiteMode=True, comments="", author="") + return_value=IndexEntry(runNumber="1", useLiteMode=True, comments="", author="", version=1) ) calibrationService.getCalibrationIndex(mock.MagicMock(run=mock.MagicMock(runNumber="123"))) assert calibrationService.dataFactoryService.getCalibrationIndex.called @@ -458,9 +449,8 @@ def test_load_quality_assessment_no_calibration_record_exception(self): version=self.version, checkExistent=False, ) - with pytest.raises(ValueError) as excinfo: # noqa: PT011 + with pytest.raises(FileNotFoundError) as excinfo: # noqa: PT011 self.instance.loadQualityAssessment(mockRequest) - assert str(mockRequest.runId) in str(excinfo.value) assert str(mockRequest.version) in str(excinfo.value) @mock.patch(thisService + "CalibrationMetricsWorkspaceIngredients") @@ -1046,7 +1036,7 @@ def test_fitPeaks(self, FitMultiplePeaksRecipe): assert res == FitMultiplePeaksRecipe.return_value.executeRecipe.return_value def test_matchRuns(self): - self.instance.dataFactoryService.getThisOrLatestCalibrationVersion = mock.Mock( + self.instance.dataFactoryService.getLatestApplicableCalibrationVersion = mock.Mock( side_effect=[mock.sentinel.version1, mock.sentinel.version2, 
mock.sentinel.version3], ) request = mock.Mock(runNumbers=[mock.sentinel.run1, mock.sentinel.run2], useLiteMode=True) diff --git a/tests/unit/backend/service/test_NormalizationService.py b/tests/unit/backend/service/test_NormalizationService.py index 4946ec3d5..dbc1076ec 100644 --- a/tests/unit/backend/service/test_NormalizationService.py +++ b/tests/unit/backend/service/test_NormalizationService.py @@ -10,6 +10,7 @@ mtd, ) +from snapred.backend.dao.indexing.Versioning import VersionState from snapred.backend.dao.request import CalibrationWritePermissionsRequest from snapred.backend.dao.response.NormalizationResponse import NormalizationResponse from snapred.backend.error.ContinueWarning import ContinueWarning @@ -63,7 +64,7 @@ def test_exportNormalizationIndexEntry(): normalizationService.dataExportService.exportNormalizationIndexEntry = MagicMock() normalizationService.dataExportService.exportNormalizationIndexEntry.return_value = "expected" normalizationService.saveNormalizationToIndex( - IndexEntry(runNumber="1", useLiteMode=True, backgroundRunNumber="2") + IndexEntry(runNumber="1", useLiteMode=True, backgroundRunNumber="2", version=VersionState.NEXT) ) assert normalizationService.dataExportService.exportNormalizationIndexEntry.called savedEntry = normalizationService.dataExportService.exportNormalizationIndexEntry.call_args.args[0] @@ -210,7 +211,7 @@ def test_smoothDataExcludingPeaks( ) def test_matchRuns(self): - self.instance.dataFactoryService.getThisOrLatestNormalizationVersion = mock.Mock( + self.instance.dataFactoryService.getLatestApplicableNormalizationVersion = mock.Mock( side_effect=[mock.sentinel.version1, mock.sentinel.version2], ) request = mock.Mock(runNumbers=[mock.sentinel.run1, mock.sentinel.run2], useLiteMode=True) @@ -276,7 +277,7 @@ def test_normalization( self.instance.sousChef = SculleryBoy() self.instance.groceryService = mockGroceryService self.instance.dataFactoryService.getCifFilePath = MagicMock(return_value="path/to/cif") - self.instance.dataFactoryService.getThisOrLatestCalibrationVersion = mock.Mock(return_value=1) + self.instance.dataFactoryService.getLatestApplicableCalibrationVersion = mock.Mock(return_value=1) self.instance.dataExportService.getCalibrationStateRoot = mock.Mock(return_value="lah/dee/dah") self.instance.dataFactoryService.calibrationExists = mock.Mock(return_value=True) self.instance.dataFactoryService.getCalibrationRecord = mock.Mock(return_value=mock.Mock(runNumber="12345")) @@ -297,7 +298,7 @@ def test_validateRequest(self): # test `validateRequest` internal calls self.instance._sameStates = mock.Mock(return_value=True) self.instance.dataFactoryService.calibrationExists = mock.Mock(return_value=True) - self.instance.dataFactoryService.getThisOrLatestCalibrationVersion = mock.Mock(return_value=1) + self.instance.dataFactoryService.getLatestApplicableCalibrationVersion = mock.Mock(return_value=1) permissionsRequest = CalibrationWritePermissionsRequest( runNumber=self.request.runNumber, continueFlags=self.request.continueFlags ) @@ -309,7 +310,7 @@ def test_validateRequest(self): def test_validateDiffractionCalibrationExists_failure(self): request = mock.Mock(runNumber="12345", backgroundRunNumber="67890", continueFlags=ContinueWarning.Type.UNSET) self.instance.sousChef = SculleryBoy() - self.instance.dataFactoryService.getThisOrLatestCalibrationVersion = mock.Mock(return_value=-1) + self.instance.dataFactoryService.getLatestApplicableCalibrationVersion = mock.Mock(return_value=None) with pytest.raises( ContinueWarning, @@ -324,7 
+325,7 @@ def test_validateDiffractionCalibrationExists_success_contineuAnyway(self): continueFlags=ContinueWarning.Type.DEFAULT_DIFFRACTION_CALIBRATION, ) self.instance.sousChef = SculleryBoy() - self.instance.dataFactoryService.getThisOrLatestCalibrationVersion = mock.Mock(return_value=-1) + self.instance.dataFactoryService.getLatestApplicableCalibrationVersion = mock.Mock(return_value=None) self.instance._validateDiffractionCalibrationExists(request) def test_validateRequest_different_states(self): @@ -370,7 +371,7 @@ def test_cachedNormalization(self, mockFarmFreshIngredients): self.instance.dataFactoryService.getCifFilePath = mock.Mock(return_value="path/to/cif") self.instance.dataExportService.getCalibrationStateRoot = mock.Mock(return_value="lah/dee/dah") self.instance.dataFactoryService.calibrationExists = mock.Mock(return_value=True) - self.instance.dataFactoryService.getThisOrLatestCalibrationVersion = mock.Mock(return_value=1) + self.instance.dataFactoryService.getLatestApplicableCalibrationVersion = mock.Mock(return_value=1) self.instance.dataExportService.checkWritePermissions = mock.Mock(return_value=True) result = self.instance.normalization(self.request) diff --git a/tests/unit/backend/service/test_ReductionService.py b/tests/unit/backend/service/test_ReductionService.py index e66372a72..c92f95856 100644 --- a/tests/unit/backend/service/test_ReductionService.py +++ b/tests/unit/backend/service/test_ReductionService.py @@ -125,8 +125,8 @@ def test_prepReductionIngredients(self): assert result == expected def test_fetchReductionGroceries(self): - self.instance.dataFactoryService.getThisOrLatestCalibrationVersion = mock.Mock(return_value=1) - self.instance.dataFactoryService.getThisOrLatestNormalizationVersion = mock.Mock(return_value=1) + self.instance.dataFactoryService.getLatestApplicableCalibrationVersion = mock.Mock(return_value=1) + self.instance.dataFactoryService.getLatestApplicableNormalizationVersion = mock.Mock(return_value=1) self.instance._markWorkspaceMetadata = mock.Mock() self.request.continueFlags = ContinueWarning.Type.UNSET res = self.instance.fetchReductionGroceries(self.request) @@ -142,10 +142,10 @@ def test_reduction(self, mockReductionRecipe): "outputs": ["one", "two", "three"], } mockReductionRecipe.return_value.cook = mock.Mock(return_value=mockResult) - self.instance.dataFactoryService.getThisOrLatestCalibrationVersion = mock.Mock(return_value=1) + self.instance.dataFactoryService.getLatestApplicableCalibrationVersion = mock.Mock(return_value=1) self.instance.dataFactoryService.stateExists = mock.Mock(return_value=True) self.instance.dataFactoryService.calibrationExists = mock.Mock(return_value=True) - self.instance.dataFactoryService.getThisOrLatestNormalizationVersion = mock.Mock(return_value=1) + self.instance.dataFactoryService.getLatestApplicableNormalizationVersion = mock.Mock(return_value=1) self.instance.dataFactoryService.normalizationExists = mock.Mock(return_value=True) self.instance._markWorkspaceMetadata = mock.Mock() @@ -326,7 +326,7 @@ def test_groupRequests(self): # Verify the request is sorted by state id then normalization version mockDataFactory = mock.Mock() - mockDataFactory.getThisOrCurrentNormalizationVersion.side_effect = [0, 1] + mockDataFactory.getLatestApplicableNormalizationVersion.side_effect = [0, 1] mockDataFactory.constructStateId.return_value = ("state1", "_") self.instance.dataFactoryService = mockDataFactory @@ -680,8 +680,8 @@ def trackFetchGroceryDict(*args, **kwargs): pixelMasks=[self.maskWS1, self.maskWS2, 
self.maskWS5], focusGroups=[FocusGroup(name="apple", definition="path/to/grouping")], ) - self.service.dataFactoryService.getThisOrLatestCalibrationVersion = mock.Mock(return_value=1) - self.service.dataFactoryService.getThisOrLatestNormalizationVersion = mock.Mock(return_value=2) + self.service.dataFactoryService.getLatestApplicableCalibrationVersion = mock.Mock(return_value=1) + self.service.dataFactoryService.getLatestApplicableNormalizationVersion = mock.Mock(return_value=2) self.service._markWorkspaceMetadata = mock.Mock() groceryClerk = self.service.groceryClerk @@ -741,8 +741,8 @@ def test_fetchReductionGroceries_pixelMasks_not_a_mask(self): pixelMasks=[self.maskWS1, self.maskWS2, self.maskWS5, not_a_mask], focusGroups=[FocusGroup(name="apple", definition="path/to/grouping")], ) - self.service.dataFactoryService.getThisOrLatestCalibrationVersion = mock.Mock(return_value=1) - self.service.dataFactoryService.getThisOrLatestNormalizationVersion = mock.Mock(return_value=2) + self.service.dataFactoryService.getLatestApplicableCalibrationVersion = mock.Mock(return_value=1) + self.service.dataFactoryService.getLatestApplicableNormalizationVersion = mock.Mock(return_value=2) combinedMaskName = wng.reductionPixelMask().runNumber(request.runNumber).build() mockPrepCombinedMask.return_value = combinedMaskName diff --git a/tests/unit/backend/service/test_SousChef.py b/tests/unit/backend/service/test_SousChef.py index 83c05d9d1..eb80e69bb 100644 --- a/tests/unit/backend/service/test_SousChef.py +++ b/tests/unit/backend/service/test_SousChef.py @@ -3,6 +3,7 @@ import unittest from unittest import mock +import pytest from mantid.simpleapi import DeleteWorkspace, mtd from snapred.backend.dao.CrystallographicInfo import CrystallographicInfo @@ -110,6 +111,7 @@ def test_prepCalibration_userFWHM(self): self.instance._getThresholdFromCalibrantSample = mock.Mock(return_value=0.5) fakeLeft = 116 fakeRight = 17 + self.ingredients.model_config["validate_assignment"] = False self.ingredients.fwhmMultipliers = mock.Mock(left=fakeLeft, right=fakeRight) self.instance.prepCalibrantSample = mock.Mock() @@ -332,6 +334,7 @@ def test_prepPeakIngredients(self, PeakIngredients): self.instance.prepPixelGroup = mock.Mock() self.instance.prepCalibrantSample = mock.Mock() calibrantSample = self.instance.prepCalibrantSample() + self.ingredients.model_config["validate_assignment"] = False self.ingredients.peakIntensityThreshold = calibrantSample.peakIntensityFractionThreshold result = self.instance.prepPeakIngredients(self.ingredients) @@ -496,6 +499,7 @@ def test_prepReductionIngredients(self, ReductionIngredients, mockOS): # noqa: # Modifications to a copy of `ingredients` during the first part of `prepReductionIngredients`, # before the `prepManyPixelGroups` calls: ingredients_ = self.ingredients.model_copy() + ingredients_.model_config["validate_assignment"] = False # ... from calibration record: ingredients_.cifPath = self.instance.dataFactoryService.getCifFilePath.return_value # ... 
from normalization record: @@ -598,3 +602,15 @@ def test__getThresholdFromCalibrantSample(self, mockOS): # noqa: ARG002 def test__getThresholdFromCalibrantSample_none_path(self): result = self.instance._getThresholdFromCalibrantSample(None) assert result == Config["constants.PeakIntensityFractionThreshold"] + + def test_pullCalibrationRecordFFI_noCalibrationVersion(self): + mockDataFactory = mock.Mock() + mockDataFactory.getCalibrationRecord = mock.Mock() + mockDataFactory.getCalibrationRecord.return_value = None + self.instance.dataFactoryService = mockDataFactory + self.ingredients.version = None + + assert self.ingredients.versions.calibration is None + + with pytest.raises(ValueError, match="Calibration version must be specified"): + self.instance._pullCalibrationRecordFFI(self.ingredients) diff --git a/tests/unit/meta/mantid/test_WorkspaceNameGenerator.py b/tests/unit/meta/mantid/test_WorkspaceNameGenerator.py index b2982317e..a3776ae46 100644 --- a/tests/unit/meta/mantid/test_WorkspaceNameGenerator.py +++ b/tests/unit/meta/mantid/test_WorkspaceNameGenerator.py @@ -7,12 +7,11 @@ from snapred.meta.Config import Config from snapred.meta.mantid.WorkspaceNameGenerator import ( - VERSION_DEFAULT, - VERSION_DEFAULT_NAME, - WorkspaceName, + ValueFormatter as wnvf, ) from snapred.meta.mantid.WorkspaceNameGenerator import ( - ValueFormatter as wnvf, + VersionState, + WorkspaceName, ) from snapred.meta.mantid.WorkspaceNameGenerator import ( WorkspaceNameGenerator as wng, @@ -297,9 +296,9 @@ def test_pathVersion_none(): assert ans == expected -def test_pathVersion_default(): - expected = f"v_{VERSION_DEFAULT_NAME}" - ans = wnvf.pathVersion(VERSION_DEFAULT) +def test_pathversion_default(): + expected = f"v_{VersionState.DEFAULT}" + ans = wnvf.pathVersion(VersionState.DEFAULT) assert ans == expected diff --git a/tests/unit/meta/test_Decorators.py b/tests/unit/meta/test_Decorators.py index 138e8b18a..1bf4f61cd 100644 --- a/tests/unit/meta/test_Decorators.py +++ b/tests/unit/meta/test_Decorators.py @@ -184,6 +184,7 @@ def setText(self, text): self.text = text +@pytest.mark.ui def test_resettable(qtbot): parent = QWidget() qtbot.addWidget(parent) diff --git a/tests/unit/ui/presenter/test_InitializeStatePresenter.py b/tests/unit/ui/presenter/test_InitializeStatePresenter.py index 07ff9a923..1e7a3def7 100644 --- a/tests/unit/ui/presenter/test_InitializeStatePresenter.py +++ b/tests/unit/ui/presenter/test_InitializeStatePresenter.py @@ -9,8 +9,6 @@ from snapred.ui.presenter.InitializeStatePresenter import InitializeStatePresenter from snapred.ui.widget.LoadingCursor import LoadingCursor -app = QApplication(sys.argv) - @not_a_test class TestableQWidget(QWidget): @@ -26,11 +24,14 @@ def __init__(self, *args, **kwargs): @pytest.fixture def setup_view_and_workflow(): + if QApplication.instance() is None: + QApplication(sys.argv) view = TestableQWidget() workflow = InitializeStatePresenter(view=view) return view, workflow +@pytest.mark.ui def test_handleButtonClicked_with_valid_input(setup_view_and_workflow): view, workflow = setup_view_and_workflow view.getRunNumber.return_value = "12345" @@ -42,6 +43,7 @@ def test_handleButtonClicked_with_valid_input(setup_view_and_workflow): mock_initializeState.assert_called_once_with("12345", "Test State", True) +@pytest.mark.ui def test_handleButtonClicked_with_invalid_input(setup_view_and_workflow): view, workflow = setup_view_and_workflow view.getRunNumber.return_value = "invalid" @@ -51,6 +53,7 @@ def 
test_handleButtonClicked_with_invalid_input(setup_view_and_workflow): mock_warning.assert_called_once() +@pytest.mark.ui def test__initializeState(setup_view_and_workflow): view, workflow = setup_view_and_workflow view.getRunNumber.return_value = "12345" @@ -66,6 +69,7 @@ def test__initializeState(setup_view_and_workflow): mock_dialog_showSuccess.assert_called_once() +@pytest.mark.ui def test__handleResponse_error(setup_view_and_workflow): view, workflow = setup_view_and_workflow error_response = SNAPResponse(code=ResponseCode.ERROR, message="Error message") @@ -78,6 +82,7 @@ def test__handleResponse_error(setup_view_and_workflow): mock_critical.assert_called_once() +@pytest.mark.ui def test__handleResponse_success(setup_view_and_workflow): view, workflow = setup_view_and_workflow success_response = SNAPResponse(code=ResponseCode.OK) diff --git a/tests/unit/ui/view/test_CalibrationAssessmentView.py b/tests/unit/ui/view/test_CalibrationAssessmentView.py index 4b1d7856a..8990064b7 100644 --- a/tests/unit/ui/view/test_CalibrationAssessmentView.py +++ b/tests/unit/ui/view/test_CalibrationAssessmentView.py @@ -1,9 +1,12 @@ from unittest.mock import MagicMock +import pytest + from snapred.backend.dao.indexing.IndexEntry import IndexEntry from snapred.ui.view.DiffCalAssessmentView import DiffCalAssessmentView +@pytest.mark.ui def test_calibration_record_dropdown(qtbot): view = DiffCalAssessmentView() assert view.getCalibrationRecordCount() == 0 @@ -26,6 +29,7 @@ def test_calibration_record_dropdown(qtbot): assert view.getSelectedCalibrationRecordData() == (runNumber, useLiteMode, version) +@pytest.mark.ui def test_error_on_load_calibration_record(qtbot): view = DiffCalAssessmentView() qtbot.addWidget(view.loadButton) diff --git a/tests/unit/ui/view/test_InitializeStateCheckView.py b/tests/unit/ui/view/test_InitializeStateCheckView.py index b08bc9a64..51908eec7 100644 --- a/tests/unit/ui/view/test_InitializeStateCheckView.py +++ b/tests/unit/ui/view/test_InitializeStateCheckView.py @@ -1,6 +1,9 @@ +import pytest + from snapred.ui.view.InitializeStateCheckView import InitializationMenu +@pytest.mark.ui def test_run_number_field(qtbot): menu = InitializationMenu(None) qtbot.addWidget(menu) diff --git a/tests/unit/ui/widget/test_Workflow.py b/tests/unit/ui/widget/test_Workflow.py index 2e4c62ccf..ef05416de 100644 --- a/tests/unit/ui/widget/test_Workflow.py +++ b/tests/unit/ui/widget/test_Workflow.py @@ -40,6 +40,7 @@ def continueAction(workflowPresenter): # noqa: ARG001 return WorkflowBuilder().addNode(continueAction, view, "Test").build() +@pytest.mark.ui def test_workflowPresenterHandleContinueButtonClicked(qtbot): # Mock the worker pool mockWorkerPool = MagicMock() diff --git a/tests/unit/ui/workflow/test_DiffCalWorkflow.py b/tests/unit/ui/workflow/test_DiffCalWorkflow.py index c58d2005d..d6ab83870 100644 --- a/tests/unit/ui/workflow/test_DiffCalWorkflow.py +++ b/tests/unit/ui/workflow/test_DiffCalWorkflow.py @@ -14,6 +14,7 @@ from snapred.ui.workflow.DiffCalWorkflow import DiffCalWorkflow +@pytest.mark.ui @patch("snapred.ui.workflow.DiffCalWorkflow.WorkflowImplementer.request") def test_purge_bad_peaks(workflowRequest, qtbot): # noqa: ARG001 """ @@ -65,6 +66,7 @@ def test_purge_bad_peaks(workflowRequest, qtbot): # noqa: ARG001 ) +@pytest.mark.ui @patch("snapred.ui.workflow.DiffCalWorkflow.WorkflowImplementer.request") def test_purge_bad_peaks_two_wkspindex(workflowRequest, qtbot): # noqa: ARG001 """ @@ -126,6 +128,7 @@ def test_purge_bad_peaks_two_wkspindex(workflowRequest, qtbot): # noqa: ARG001 
) +@pytest.mark.ui @patch("snapred.ui.workflow.DiffCalWorkflow.WorkflowImplementer.request") def test_purge_bad_peaks_too_few(workflowRequest, qtbot): # noqa: ARG001 """ diff --git a/tests/unit/ui/workflow/test_WorkflowImplementer.py b/tests/unit/ui/workflow/test_WorkflowImplementer.py index 2414c41e1..2f3dbd671 100644 --- a/tests/unit/ui/workflow/test_WorkflowImplementer.py +++ b/tests/unit/ui/workflow/test_WorkflowImplementer.py @@ -1,11 +1,13 @@ from random import randint from unittest.mock import MagicMock +import pytest from mantid.simpleapi import CreateSingleValuedWorkspace, GroupWorkspaces, mtd from snapred.ui.workflow.WorkflowImplementer import WorkflowImplementer +@pytest.mark.ui def test_rename_on_iterate_list(qtbot): # noqa: ARG001 """ Test that on iteration, a list of workspaces will be renamed according to the iteration template. @@ -24,6 +26,7 @@ def test_rename_on_iterate_list(qtbot): # noqa: ARG001 assert instance.collectedOutputs == newNames +@pytest.mark.ui def test_rename_on_iterate_group(qtbot): # noqa: ARG001 """ Test that on iteration, a workspace group has all of its members renamed. diff --git a/tests/util_tests/test_state_helpers.py b/tests/util_tests/test_state_helpers.py index 77499884d..543e3204b 100644 --- a/tests/util_tests/test_state_helpers.py +++ b/tests/util_tests/test_state_helpers.py @@ -8,7 +8,7 @@ from util.dao import DAOFactory from util.state_helpers import reduction_root_redirect, state_root_override, state_root_redirect -from snapred.backend.dao.indexing.Versioning import VERSION_DEFAULT +from snapred.backend.dao.indexing.Versioning import VERSION_START from snapred.backend.data.LocalDataService import LocalDataService from snapred.meta.Config import Config from snapred.meta.mantid.WorkspaceNameGenerator import ValueFormatter as wnvf @@ -68,7 +68,7 @@ def test_state_root_override_enter( assert Path(stateRootPath) == expectedStateRootPath assert Path(stateRootPath).exists() assert Path(stateRootPath).joinpath("groupingMap.json").exists() - versionString = wnvf.pathVersion(VERSION_DEFAULT) + versionString = wnvf.pathVersion(VERSION_START) assert (Path(stateRootPath) / "lite" / "diffraction" / versionString / "CalibrationParameters.json").exists() From 9d96eb4a584a68e5ee98beec5d9c3703876a81aa Mon Sep 17 00:00:00 2001 From: Michael Walsh <68125095+walshmm@users.noreply.github.com> Date: Tue, 17 Dec 2024 14:42:15 -0500 Subject: [PATCH 2/7] fix inputs for Artificial Norm (#520) add fixes for art norm and refactor art norm algo add distribution flag fix broken tests, still pending coverage add test for rebinfocussedgroupdatarecipe add rebin to the preview window reorg binning and some other steps fix unit tests remove old comment up coverage last bit of coverage remove unused members give art norm preview workspaces more appropriate names fix tests remove incorrect log --- .../GenerateFocussedVanadiumIngredients.py | 2 + .../RebinFocussedGroupDataIngredients.py | 8 ++ .../dao/ingredients/ReductionIngredients.py | 1 + .../CreateArtificialNormalizationRequest.py | 10 +-- .../recipe/ApplyNormalizationRecipe.py | 47 ++++++----- .../recipe/GenerateFocussedVanadiumRecipe.py | 56 +++++++++---- .../recipe/RebinFocussedGroupDataRecipe.py | 82 +++++++++++++++++++ src/snapred/backend/recipe/Recipe.py | 9 +- .../recipe/ReductionGroupProcessingRecipe.py | 11 ++- src/snapred/backend/recipe/ReductionRecipe.py | 19 +++-- .../CreateArtificialNormalizationAlgo.py | 57 +++++++------ .../backend/service/ReductionService.py | 42 +++++++--- 
.../meta/mantid/WorkspaceNameGenerator.py | 15 ++++ src/snapred/resources/application.yml | 1 + .../reduction/ArtificialNormalizationView.py | 8 +- src/snapred/ui/widget/TrueFalseDropDown.py | 3 + src/snapred/ui/workflow/ReductionWorkflow.py | 11 ++- tests/resources/application.yml | 1 + .../test_CreateArtificialNormalizationAlgo.py | 15 +++- .../recipe/test_ApplyNormalizationRecipe.py | 58 +++---------- .../test_GenerateFocussedVanadiumRecipe.py | 36 ++++++-- .../test_RebinFocussedGroupDataRecipe.py | 74 +++++++++++++++++ .../test_ReductionGroupProcessingRecipe.py | 8 +- .../backend/recipe/test_ReductionRecipe.py | 32 ++++---- .../backend/service/test_ReductionService.py | 21 ++++- tests/util/SculleryBoy.py | 4 + 26 files changed, 446 insertions(+), 185 deletions(-) create mode 100644 src/snapred/backend/dao/ingredients/RebinFocussedGroupDataIngredients.py create mode 100644 src/snapred/backend/recipe/RebinFocussedGroupDataRecipe.py create mode 100644 tests/unit/backend/recipe/test_RebinFocussedGroupDataRecipe.py diff --git a/src/snapred/backend/dao/ingredients/GenerateFocussedVanadiumIngredients.py b/src/snapred/backend/dao/ingredients/GenerateFocussedVanadiumIngredients.py index 916f4815e..d82a74b6a 100644 --- a/src/snapred/backend/dao/ingredients/GenerateFocussedVanadiumIngredients.py +++ b/src/snapred/backend/dao/ingredients/GenerateFocussedVanadiumIngredients.py @@ -3,6 +3,7 @@ from pydantic import BaseModel from snapred.backend.dao.GroupPeakList import GroupPeakList +from snapred.backend.dao.ingredients.ArtificialNormalizationIngredients import ArtificialNormalizationIngredients from snapred.backend.dao.state.PixelGroup import PixelGroup from snapred.meta.Config import Config @@ -14,3 +15,4 @@ class GenerateFocussedVanadiumIngredients(BaseModel): pixelGroup: PixelGroup # This can be None if we lack a calibration detectorPeaks: Optional[list[GroupPeakList]] = None + artificialNormalizationIngredients: Optional[ArtificialNormalizationIngredients] = None diff --git a/src/snapred/backend/dao/ingredients/RebinFocussedGroupDataIngredients.py b/src/snapred/backend/dao/ingredients/RebinFocussedGroupDataIngredients.py new file mode 100644 index 000000000..05bf70da4 --- /dev/null +++ b/src/snapred/backend/dao/ingredients/RebinFocussedGroupDataIngredients.py @@ -0,0 +1,8 @@ +from pydantic import BaseModel + +from snapred.backend.dao.state.PixelGroup import PixelGroup + + +class RebinFocussedGroupDataIngredients(BaseModel): + pixelGroup: PixelGroup + preserveEvents: bool = False diff --git a/src/snapred/backend/dao/ingredients/ReductionIngredients.py b/src/snapred/backend/dao/ingredients/ReductionIngredients.py index ea9ce4e4c..740b0a0fa 100644 --- a/src/snapred/backend/dao/ingredients/ReductionIngredients.py +++ b/src/snapred/backend/dao/ingredients/ReductionIngredients.py @@ -58,6 +58,7 @@ def generateFocussedVanadium(self, groupingIndex: int) -> GenerateFocussedVanadi smoothingParameter=self.smoothingParameter, pixelGroup=self.pixelGroups[groupingIndex], detectorPeaks=self.getDetectorPeaks(groupingIndex), + artificialNormalizationIngredients=self.artificialNormalizationIngredients, ) def applyNormalization(self, groupingIndex: int) -> ApplyNormalizationIngredients: diff --git a/src/snapred/backend/dao/request/CreateArtificialNormalizationRequest.py b/src/snapred/backend/dao/request/CreateArtificialNormalizationRequest.py index f8792ea94..69dd18b55 100644 --- a/src/snapred/backend/dao/request/CreateArtificialNormalizationRequest.py +++ 
b/src/snapred/backend/dao/request/CreateArtificialNormalizationRequest.py @@ -1,4 +1,4 @@ -from pydantic import BaseModel, root_validator +from pydantic import BaseModel from snapred.meta.mantid.WorkspaceNameGenerator import WorkspaceName @@ -11,13 +11,7 @@ class CreateArtificialNormalizationRequest(BaseModel): decreaseParameter: bool = True lss: bool = True diffractionWorkspace: WorkspaceName - outputWorkspace: WorkspaceName = None - - @root_validator(pre=True) - def set_output_workspace(cls, values): - if values.get("diffractionWorkspace") and not values.get("outputWorkspace"): - values["outputWorkspace"] = WorkspaceName(f"{values['diffractionWorkspace']}_artificialNorm") - return values + outputWorkspace: WorkspaceName class Config: arbitrary_types_allowed = True # Allow arbitrary types like WorkspaceName diff --git a/src/snapred/backend/recipe/ApplyNormalizationRecipe.py b/src/snapred/backend/recipe/ApplyNormalizationRecipe.py index d48547a03..25c56adc9 100644 --- a/src/snapred/backend/recipe/ApplyNormalizationRecipe.py +++ b/src/snapred/backend/recipe/ApplyNormalizationRecipe.py @@ -2,6 +2,7 @@ from snapred.backend.dao.ingredients import ApplyNormalizationIngredients as Ingredients from snapred.backend.log.logger import snapredLogger +from snapred.backend.recipe.RebinFocussedGroupDataRecipe import RebinFocussedGroupDataRecipe from snapred.backend.recipe.Recipe import Recipe from snapred.meta.Config import Config from snapred.meta.decorators.Singleton import Singleton @@ -26,19 +27,6 @@ def chopIngredients(self, ingredients: Ingredients): We are mostly concerned about the drange for a ResampleX operation. """ self.pixelGroup = ingredients.pixelGroup - # The adjustment below is a temp fix, will be permanently fixed in EWM 6262 - lowdSpacingCrop = Config["constants.CropFactors.lowdSpacingCrop"] - if lowdSpacingCrop < 0: - raise ValueError("Low d-spacing crop factor must be positive") - highdSpacingCrop = Config["constants.CropFactors.highdSpacingCrop"] - if highdSpacingCrop < 0: - raise ValueError("High d-spacing crop factor must be positive") - dMin = [x + lowdSpacingCrop for x in self.pixelGroup.dMin()] - dMax = [x - highdSpacingCrop for x in self.pixelGroup.dMax()] - if not dMax > dMin: - raise ValueError("d-spacing crop factors are too large -- resultant dMax must be > resultant dMin") - self.dMin = dMin - self.dMax = dMax def unbagGroceries(self, groceries: Dict[str, WorkspaceName]): """ @@ -48,6 +36,8 @@ def unbagGroceries(self, groceries: Dict[str, WorkspaceName]): The background workspace, backgroundWorkspace, is optional, not implemented, in dspacing. """ self.sampleWs = groceries["inputWorkspace"] + # NOTE: the normalization workspace should be appropriately binned + # and then converted to a histogram prior to this recipe self.normalizationWs = groceries.get("normalizationWorkspace", "") self.backgroundWs = groceries.get("backgroundWorkspace", "") @@ -59,11 +49,22 @@ def stirInputs(self): if self.backgroundWs != "": raise NotImplementedError("Background Subtraction is not implemented for this release.") + def _rebinSample(self, preserveEvents: bool): + """ + Rebins the sample workspace to the pixel group. + """ + rebinRecipe = RebinFocussedGroupDataRecipe(self.utensils) + rebinIngredients = RebinFocussedGroupDataRecipe.Ingredients( + pixelGroup=self.pixelGroup, preserveEvents=preserveEvents + ) + rebinRecipe.cook(rebinIngredients, {"inputWorkspace": self.sampleWs}) + def queueAlgos(self): """ Queues up the procesing algorithms for the recipe. 
Requires: unbagged groceries and chopped ingredients. """ + if self.normalizationWs: self.mantidSnapper.Divide( "Dividing out the normalization..", LHSWorkspace=self.sampleWs, RHSWorkspace=self.normalizationWs, OutputWorkspace=self.sampleWs, ) - self.mantidSnapper.RebinRagged( "Resampling X-axis...", InputWorkspace=self.sampleWs, XMin=self.dMin, XMax=self.dMax, Delta=self.pixelGroup.dBin(), OutputWorkspace=self.sampleWs, PreserveEvents=False, ) # NOTE: Metaphorically, would ingredients better have been called Spices? # Considering they are mostly never the meat of a recipe. @@ -90,8 +82,19 @@ def cook(self, ingredients: Ingredients, groceries: Dict[str, str]) -> Dict[str, """ self.prep(ingredients, groceries) self.execute() + self.mantidSnapper.mtd[self.sampleWs].setDistribution(True) return self.sampleWs + def execute(self): + """ + Final step in a recipe, executes the queued algorithms. + Requires: queued algorithms. + """ + self._rebinSample(preserveEvents=True) + self.mantidSnapper.executeQueue() + self._rebinSample(preserveEvents=False) + return True + def cater(self, shipment: List[Pallet]) -> List[WorkspaceName]: """ A secondary interface method for the recipe. diff --git a/src/snapred/backend/recipe/GenerateFocussedVanadiumRecipe.py b/src/snapred/backend/recipe/GenerateFocussedVanadiumRecipe.py index c1d3648b0..6623b9b81 100644 --- a/src/snapred/backend/recipe/GenerateFocussedVanadiumRecipe.py +++ b/src/snapred/backend/recipe/GenerateFocussedVanadiumRecipe.py @@ -2,6 +2,7 @@ from snapred.backend.dao.ingredients import GenerateFocussedVanadiumIngredients as Ingredients from snapred.backend.log.logger import snapredLogger +from snapred.backend.recipe.RebinFocussedGroupDataRecipe import RebinFocussedGroupDataRecipe from snapred.backend.recipe.Recipe import Recipe from snapred.meta.decorators.Singleton import Singleton from snapred.meta.redantic import list_to_raw @@ -25,36 +26,63 @@ class GenerateFocussedVanadiumRecipe(Recipe[Ingredients]): def chopIngredients(self, ingredients: Ingredients): self.smoothingParameter = ingredients.smoothingParameter - self.detectorPeaks = list_to_raw(ingredients.detectorPeaks) - self.dMin = ingredients.pixelGroup.dMin() - self.dMax = ingredients.pixelGroup.dMax() - self.dBin = ingredients.pixelGroup.dBin() + self.detectorPeaks = list_to_raw(ingredients.detectorPeaks) if ingredients.detectorPeaks is not None else None + self.pixelGroup = ingredients.pixelGroup + + self.artificialNormalizationIngredients = ingredients.artificialNormalizationIngredients def unbagGroceries(self, groceries: Dict[str, Any]): self.inputWS = groceries["inputWorkspace"] self.outputWS = groceries.get("outputWorkspace", groceries["inputWorkspace"]) - def queueAlgos(self): + def queueArtificialNormalization(self): """ - Queues up the procesing algorithms for the recipe. - Requires: unbagged groceries. + Queues up the artificial normalization algorithm if the ingredients are available.
""" + self.mantidSnapper.CreateArtificialNormalizationAlgo( + "Create Artificial Normalization...", + InputWorkspace=self.inputWS, + OutputWorkspace=self.outputWS, + peakWindowClippingSize=self.artificialNormalizationIngredients.peakWindowClippingSize, + smoothingParameter=self.artificialNormalizationIngredients.smoothingParameter, + decreaseParameter=self.artificialNormalizationIngredients.decreaseParameter, + LSS=self.artificialNormalizationIngredients.lss, + ) + + def queueNaturalNormalization(self): self.mantidSnapper.SmoothDataExcludingPeaksAlgo( "Smoothing Data Excluding Peaks...", - InputWorkspace=self.outputWS, + InputWorkspace=self.inputWS, OutputWorkspace=self.outputWS, DetectorPeaks=self.detectorPeaks, SmoothingParameter=self.smoothingParameter, ) + def _rebinInputWorkspace(self): + """ + Rebins the input workspace to the pixel group. + """ + rebinRecipe = RebinFocussedGroupDataRecipe(self.utensils) + rebinIngredients = RebinFocussedGroupDataRecipe.Ingredients(pixelGroup=self.pixelGroup) + rebinRecipe.cook(rebinIngredients, {"inputWorkspace": self.inputWS}) + + def queueAlgos(self): + """ + Queues up the procesing algorithms for the recipe. + Requires: unbagged groceries. + """ + self._rebinInputWorkspace() + + if self.artificialNormalizationIngredients is not None: + self.queueArtificialNormalization() + else: + self.queueNaturalNormalization() + def cook(self, ingredients: Ingredients, groceries: Dict[str, str]) -> Dict[str, Any]: self.prep(ingredients, groceries) - output = None - if self.inputWS is not None: - self.execute() - output = self.outputWS - else: - raise NotImplementedError("Fake Vanadium not implemented yet.") + + self.execute() + output = self.outputWS logger.info(f"Finished generating focussed vanadium for {self.inputWS}...") return output diff --git a/src/snapred/backend/recipe/RebinFocussedGroupDataRecipe.py b/src/snapred/backend/recipe/RebinFocussedGroupDataRecipe.py new file mode 100644 index 000000000..f75663d1b --- /dev/null +++ b/src/snapred/backend/recipe/RebinFocussedGroupDataRecipe.py @@ -0,0 +1,82 @@ +from typing import Any, Dict, Tuple + +from snapred.backend.dao.ingredients import RebinFocussedGroupDataIngredients as Ingredients +from snapred.backend.log.logger import snapredLogger +from snapred.backend.recipe.Recipe import Recipe +from snapred.meta.Config import Config +from snapred.meta.decorators.Singleton import Singleton +from snapred.meta.mantid.WorkspaceNameGenerator import WorkspaceName + +logger = snapredLogger.getLogger(__name__) + +Pallet = Tuple[Ingredients, Dict[str, str]] + + +@Singleton +class RebinFocussedGroupDataRecipe(Recipe[Ingredients]): + NUM_BINS = Config["constants.ResampleX.NumberBins"] + LOG_BINNING = True + + def mandatoryInputWorkspaces(self): + return {"inputWorkspace"} + + def chopIngredients(self, ingredients: Ingredients): + """ + Chops off the needed elements of the ingredients. + We are mostly concerned about the drange for a ResampleX operation. 
+        """
+        self.pixelGroup = ingredients.pixelGroup
+        # The adjustment below is a temp fix, will be permanently fixed in EWM 6262
+        lowdSpacingCrop = Config["constants.CropFactors.lowdSpacingCrop"]
+        if lowdSpacingCrop < 0:
+            raise ValueError("Low d-spacing crop factor must be positive")
+
+        highdSpacingCrop = Config["constants.CropFactors.highdSpacingCrop"]
+        if highdSpacingCrop < 0:
+            raise ValueError("High d-spacing crop factor must be positive")
+
+        dMin = [x + lowdSpacingCrop for x in self.pixelGroup.dMin()]
+        dMax = [x - highdSpacingCrop for x in self.pixelGroup.dMax()]
+
+        if not dMax > dMin:
+            raise ValueError("d-spacing crop factors are too large -- resultant dMax must be > resultant dMin")
+        self.dMin = dMin
+        self.dMax = dMax
+        self.dBin = self.pixelGroup.dBin()
+
+        self.preserveEvents = ingredients.preserveEvents
+
+    def unbagGroceries(self, groceries: Dict[str, WorkspaceName]):
+        """
+        Unpacks the workspace data from the groceries.
+        Only the input sample data workspace, inputWorkspace, is required, in dspacing.
+        """
+        self.sampleWs = groceries["inputWorkspace"]
+
+    def queueAlgos(self):
+        """
+        Queues up the processing algorithms for the recipe.
+        Requires: unbagged groceries and chopped ingredients.
+        """
+        self.mantidSnapper.RebinRagged(
+            "Rebinning workspace for group...",
+            InputWorkspace=self.sampleWs,
+            XMin=self.dMin,
+            XMax=self.dMax,
+            Delta=self.dBin,
+            OutputWorkspace=self.sampleWs,
+            PreserveEvents=self.preserveEvents,
+        )
+
+    # NOTE: Metaphorically, would ingredients better have been called Spices?
+    # Considering they are mostly never the meat of a recipe.
+    def cook(self, ingredients: Ingredients, groceries: Dict[str, str]) -> Dict[str, Any]:
+        """
+        Main interface method for the recipe.
+        Given the ingredients and groceries, it prepares, executes and returns the final workspace.
+        """
+        self.prep(ingredients, groceries)
+        self.execute()
+        return self.sampleWs
diff --git a/src/snapred/backend/recipe/Recipe.py b/src/snapred/backend/recipe/Recipe.py
index 9dd0bf940..9765d1e09 100644
--- a/src/snapred/backend/recipe/Recipe.py
+++ b/src/snapred/backend/recipe/Recipe.py
@@ -22,10 +22,11 @@ def __init__(self, utensils: Utensils = None):
         Sets up the recipe with the necessary utensils.
         """
         # NOTE: workaround, we just add an empty host algorithm.
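Keeping the `Utensils` on `self` (the change below) is what lets a recipe hand its `MantidSnapper` to a sub-recipe; this is how `ApplyNormalizationRecipe._rebinSample` and `GenerateFocussedVanadiumRecipe._rebinInputWorkspace` above delegate to `RebinFocussedGroupDataRecipe`. A minimal sketch of the pattern, assuming a prepared `PixelGroup` instance named `pixelGroup` and a focussed workspace "focussed_data" already in the ADS (both placeholder names):

    from snapred.backend.recipe.algorithm.Utensils import Utensils
    from snapred.backend.recipe.RebinFocussedGroupDataRecipe import RebinFocussedGroupDataRecipe

    # Host recipes normally construct these in Recipe.__init__; shown explicitly here.
    utensils = Utensils()
    utensils.PyInit()

    # Passing the utensils in means the sub-recipe queues onto the same MantidSnapper.
    rebinRecipe = RebinFocussedGroupDataRecipe(utensils)
    # `pixelGroup` is assumed prepared elsewhere (e.g. from a calibration record).
    ingredients = RebinFocussedGroupDataRecipe.Ingredients(pixelGroup=pixelGroup, preserveEvents=True)
    # The recipe rebins "focussed_data" in place and returns its name.
    rebinRecipe.cook(ingredients, {"inputWorkspace": "focussed_data"})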
- if utensils is None: - utensils = Utensils() - utensils.PyInit() - self.mantidSnapper = utensils.mantidSnapper + self.utensils = utensils + if self.utensils is None: + self.utensils = Utensils() + self.utensils.PyInit() + self.mantidSnapper = self.utensils.mantidSnapper def __init_subclass__(cls) -> None: cls._Ingredients = get_args(cls.__orig_bases__[0])[0] # type: ignore diff --git a/src/snapred/backend/recipe/ReductionGroupProcessingRecipe.py b/src/snapred/backend/recipe/ReductionGroupProcessingRecipe.py index b26a19f5d..423dc6f43 100644 --- a/src/snapred/backend/recipe/ReductionGroupProcessingRecipe.py +++ b/src/snapred/backend/recipe/ReductionGroupProcessingRecipe.py @@ -32,13 +32,13 @@ def queueAlgos(self): "Converting to TOF...", InputWorkspace=self.rawInput, Target="TOF", - OutputWorkspace=self.rawInput, + OutputWorkspace=self.outputWS, ) self.mantidSnapper.FocusSpectraAlgorithm( "Focusing Spectra...", - InputWorkspace=self.rawInput, - OutputWorkspace=self.rawInput, + InputWorkspace=self.outputWS, + OutputWorkspace=self.outputWS, GroupingWorkspace=self.groupingWS, Ingredients=self.pixelGroup.json(), RebinOutput=False, @@ -46,10 +46,9 @@ def queueAlgos(self): self.mantidSnapper.NormalizeByCurrentButTheCorrectWay( "Normalizing Current ... but the correct way!", - InputWorkspace=self.rawInput, - OutputWorkspace=self.rawInput, + InputWorkspace=self.outputWS, + OutputWorkspace=self.outputWS, ) - self.outputWS = self.rawInput def mandatoryInputWorkspaces(self) -> Set[WorkspaceName]: return {"inputWorkspace", "groupingWorkspace"} diff --git a/src/snapred/backend/recipe/ReductionRecipe.py b/src/snapred/backend/recipe/ReductionRecipe.py index aaf905d8b..65da198a2 100644 --- a/src/snapred/backend/recipe/ReductionRecipe.py +++ b/src/snapred/backend/recipe/ReductionRecipe.py @@ -148,6 +148,8 @@ def _applyRecipe(self, recipe: Type[Recipe], ingredients_, **kwargs): if not inputWorkspace: self.logger().debug(f"{recipe.__name__} :: Skipping recipe with default empty input workspace") return + if "outputWorkspace" not in kwargs and "outputWorkspace" in self.groceries: + del self.groceries["outputWorkspace"] if self.mantidSnapper.mtd.doesExist(inputWorkspace): self.groceries.update(kwargs) recipe().cook(ingredients_, self.groceries) @@ -253,19 +255,22 @@ def execute(self): ) self._cloneIntermediateWorkspace(normalizationClone, f"normalization_GroupProcessing_{groupingIndex}") + vanadiumBasisWorkspace = normalizationClone + # if there was no normalization and the user elected to use artificial normalization + # generate one given the params and the processed sample data + if self.ingredients.artificialNormalizationIngredients: + vanadiumBasisWorkspace = sampleClone + normalizationClone = self._getNormalizationWorkspaceName(groupingIndex) + # 3. 
GenerateFocussedVanadiumRecipe self._applyRecipe( GenerateFocussedVanadiumRecipe, self.ingredients.generateFocussedVanadium(groupingIndex), - inputWorkspace=normalizationClone, + inputWorkspace=vanadiumBasisWorkspace, + outputWorkspace=normalizationClone, ) - self._cloneIntermediateWorkspace(normalizationClone, f"normalization_FoocussedVanadium_{groupingIndex}") - # if there was no normalization and the user elected to use artificial normalization - # generate one given the params and the processed sample data - # Skipping the above steps as they are accounted for in generating the artificial normalization - if self.ingredients.artificialNormalizationIngredients: - normalizationClone = self._prepareArtificialNormalization(sampleClone, groupingIndex) + self._cloneIntermediateWorkspace(normalizationClone, f"normalization_FoocussedVanadium_{groupingIndex}") # 4. ApplyNormalizationRecipe self._applyRecipe( diff --git a/src/snapred/backend/recipe/algorithm/CreateArtificialNormalizationAlgo.py b/src/snapred/backend/recipe/algorithm/CreateArtificialNormalizationAlgo.py index 741b99a33..b7d7ac027 100644 --- a/src/snapred/backend/recipe/algorithm/CreateArtificialNormalizationAlgo.py +++ b/src/snapred/backend/recipe/algorithm/CreateArtificialNormalizationAlgo.py @@ -1,5 +1,3 @@ -import json - import numpy as np from mantid.api import ( AlgorithmFactory, @@ -42,19 +40,33 @@ def PyInit(self): doc="Workspace that contains artificial normalization.", ) self.declareProperty( - "Ingredients", - defaultValue="", + "decreaseParameter", + defaultValue=True, + direction=Direction.Input, + ) + self.declareProperty( + "lss", + defaultValue=True, + direction=Direction.Input, + ) + self.declareProperty( + "peakWindowClippingSize", + defaultValue=5, + direction=Direction.Input, + ) + self.declareProperty( + "smoothingParameter", + defaultValue=5.0, direction=Direction.Input, ) self.setRethrows(True) self.mantidSnapper = MantidSnapper(self, __name__) - def chopInredients(self, ingredientsStr: str): - ingredientsDict = json.loads(ingredientsStr) - self.peakWindowClippingSize = ingredientsDict["peakWindowClippingSize"] - self.smoothingParameter = ingredientsDict["smoothingParameter"] - self.decreaseParameter = ingredientsDict["decreaseParameter"] - self.LSS = ingredientsDict["lss"] + def chopInredients(self): + self.peakWindowClippingSize = int(self.getPropertyValue("peakWindowClippingSize")) + self.smoothingParameter = float(self.getPropertyValue("smoothingParameter")) + self.decreaseParameter = int(self.getPropertyValue("decreaseParameter")) + self.LSS = int(self.getPropertyValue("lss")) def unbagGroceries(self): self.inputWorkspaceName = self.getPropertyValue("InputWorkspace") @@ -113,11 +125,18 @@ def InvLLSTransformation(self, input): # noqa: A002 def PyExec(self): # Main execution method for the algorithm self.unbagGroceries() - ingredients = self.getProperty("Ingredients").value - self.chopInredients(ingredients) - self.mantidSnapper.CloneWorkspace( - "Cloning input workspace...", - InputWorkspace=self.inputWorkspaceName, + self.chopInredients() + if self.outputWorkspaceName != self.inputWorkspaceName: + self.mantidSnapper.CloneWorkspace( + "Cloning input workspace...", + InputWorkspace=self.inputWorkspaceName, + OutputWorkspace=self.outputWorkspaceName, + ) + + self.mantidSnapper.ConvertUnits( + "Converting to dSpacing...", + InputWorkspace=self.outputWorkspaceName, + Target="dSpacing", OutputWorkspace=self.outputWorkspaceName, ) # if input workspace is an eventworkspace, convert it to a histogram workspace 
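With the JSON `Ingredients` property gone, callers set the four scalar properties directly, as the unit tests later in this patch do. A minimal sketch, assuming only that the module path mirrors the file location above; the workspace names and parameter values are placeholders:

    from mantid.simpleapi import CreateSampleWorkspace
    from snapred.backend.recipe.algorithm.CreateArtificialNormalizationAlgo import (
        CreateArtificialNormalizationAlgo,
    )

    CreateSampleWorkspace(OutputWorkspace="raw_data")  # stand-in for real diffraction data

    algo = CreateArtificialNormalizationAlgo()
    algo.initialize()
    algo.setProperty("InputWorkspace", "raw_data")
    algo.setProperty("peakWindowClippingSize", 10)
    algo.setProperty("smoothingParameter", 0.1)
    algo.setProperty("decreaseParameter", True)
    algo.setProperty("lss", True)
    algo.setProperty("OutputWorkspace", "artificial_norm")
    algo.execute()

Because the output name differs from the input here, the algorithm clones the input first; passing the same name for both now skips that clone entirely.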
@@ -129,15 +148,7 @@ def PyExec(self): OutputWorkspace=self.outputWorkspaceName, ) - self.mantidSnapper.ConvertUnits( - "Converting to dSpacing...", - InputWorkspace=self.outputWorkspaceName, - Target="dSpacing", - OutputWorkspace=self.outputWorkspaceName, - ) - self.mantidSnapper.executeQueue() - self.inputWorkspace = self.mantidSnapper.mtd[self.inputWorkspaceName] self.outputWorkspace = self.mantidSnapper.mtd[self.outputWorkspaceName] # Apply peak clipping to each histogram in the workspace diff --git a/src/snapred/backend/service/ReductionService.py b/src/snapred/backend/service/ReductionService.py index c9d92e651..142e9b172 100644 --- a/src/snapred/backend/service/ReductionService.py +++ b/src/snapred/backend/service/ReductionService.py @@ -4,7 +4,6 @@ from typing import Any, Dict, List, Optional from snapred.backend.dao.ingredients import ( - ArtificialNormalizationIngredients, GroceryListItem, ReductionIngredients, ) @@ -27,18 +26,15 @@ from snapred.backend.log.logger import snapredLogger from snapred.backend.recipe.algorithm.MantidSnapper import MantidSnapper from snapred.backend.recipe.GenericRecipe import ArtificialNormalizationRecipe +from snapred.backend.recipe.RebinFocussedGroupDataRecipe import RebinFocussedGroupDataRecipe from snapred.backend.recipe.ReductionGroupProcessingRecipe import ReductionGroupProcessingRecipe from snapred.backend.recipe.ReductionRecipe import ReductionRecipe from snapred.backend.service.Service import Service from snapred.backend.service.SousChef import SousChef from snapred.meta.decorators.FromString import FromString from snapred.meta.decorators.Singleton import Singleton -from snapred.meta.mantid.WorkspaceNameGenerator import ( - WorkspaceName, -) -from snapred.meta.mantid.WorkspaceNameGenerator import ( - WorkspaceNameGenerator as wng, -) +from snapred.meta.mantid.WorkspaceNameGenerator import WorkspaceName +from snapred.meta.mantid.WorkspaceNameGenerator import WorkspaceNameGenerator as wng from snapred.meta.mantid.WorkspaceNameGenerator import ( WorkspaceType as wngt, ) @@ -508,15 +504,12 @@ def getCompatibleMasks(self, request: ReductionRequest) -> List[WorkspaceName]: return self.dataFactoryService.getCompatibleReductionMasks(runNumber, useLiteMode) def artificialNormalization(self, request: CreateArtificialNormalizationRequest): - ingredients = ArtificialNormalizationIngredients( + artificialNormWorkspace = ArtificialNormalizationRecipe().executeRecipe( + InputWorkspace=request.diffractionWorkspace, peakWindowClippingSize=request.peakWindowClippingSize, smoothingParameter=request.smoothingParameter, decreaseParameter=request.decreaseParameter, lss=request.lss, - ) - artificialNormWorkspace = ArtificialNormalizationRecipe().executeRecipe( - InputWorkspace=request.diffractionWorkspace, - Ingredients=ingredients, OutputWorkspace=request.outputWorkspace, ) return artificialNormWorkspace @@ -535,9 +528,32 @@ def grabWorkspaceforArtificialNorm(self, request: ReductionRequest): request.focusGroups = [columnGroup] # 2.5. get ingredients ingredients = self.prepReductionIngredients(request) + + artNormBasisWorkspace = ( + wng.artificialNormalizationPreview() + .runNumber(request.runNumber) + .group(wng.Groups.COLUMN) + .type(wng.ArtificialNormWorkspaceType.SOURCE) + .build() + ) groceries = { "inputWorkspace": runWorkspace, "groupingWorkspace": columnGroupWorkspace, + "outputWorkspace": artNormBasisWorkspace, } # 3. 
Diffraction Focus Spectra - return ReductionGroupProcessingRecipe().cook(ingredients.groupProcessing(0), groceries) + ReductionGroupProcessingRecipe().cook(ingredients.groupProcessing(0), groceries) + + # 4. Rebin + rebinIngredients = RebinFocussedGroupDataRecipe.Ingredients( + pixelGroup=ingredients.pixelGroups[0], preserveEvents=True + ) + + # NOTE: This is PURPOSELY reinstanced to support testing. + # assert_called_with DOES NOT deep copy the dictionary. + # Thus reusing the above dict would fail the test. + groceries = {"inputWorkspace": artNormBasisWorkspace} + + rebinResult = RebinFocussedGroupDataRecipe().cook(rebinIngredients, groceries) + # 5. Return the rebin result + return rebinResult diff --git a/src/snapred/meta/mantid/WorkspaceNameGenerator.py b/src/snapred/meta/mantid/WorkspaceNameGenerator.py index 603abbe5f..2bcce2c16 100644 --- a/src/snapred/meta/mantid/WorkspaceNameGenerator.py +++ b/src/snapred/meta/mantid/WorkspaceNameGenerator.py @@ -69,6 +69,7 @@ class WorkspaceType(str, Enum): RAW_VANADIUM = "rawVanadium" FOCUSED_RAW_VANADIUM = "focusedRawVanadium" SMOOTHED_FOCUSED_RAW_VANADIUM = "smoothedFocusedRawVanadium" + ARTIFICIAL_NORMALIZATION_PREVIEW = "artificialNormalizationPreview" # __ REDUCTION_OUTPUT = "reductionOutput" @@ -293,6 +294,10 @@ class Lite: TRUE = "lite" FALSE = "" + class ArtificialNormWorkspaceType: + PREVIEW = "preview" + SOURCE = "source" + # TODO: Return abstract WorkspaceName type to help facilitate control over workspace names # and discourage non-standard names. def run(self): @@ -412,6 +417,16 @@ def smoothedFocusedRawVanadium(self): version=None, ) + def artificialNormalizationPreview(self): + return NameBuilder( + WorkspaceType.ARTIFICIAL_NORMALIZATION_PREVIEW, + self._normCalArtificialNormalizationPreviewTemplate, + self._normCalArtificialNormalizationPreviewTemplateKeys, + self._delimiter, + unit=self.Units.DSP, + type=self.ArtificialNormWorkspaceType.PREVIEW, + ) + def reductionOutput(self): return NameBuilder( WorkspaceType.REDUCTION_OUTPUT, diff --git a/src/snapred/resources/application.yml b/src/snapred/resources/application.yml index 6968c4004..1486b60c7 100644 --- a/src/snapred/resources/application.yml +++ b/src/snapred/resources/application.yml @@ -126,6 +126,7 @@ mantid: rawVanadium: "{unit},{group},{runNumber},raw_van_corr,{version}" focusedRawVanadium: "{unit},{group},{runNumber},raw_van_corr,{version}" smoothedFocusedRawVanadium: "{unit},{group},{runNumber},fitted_van_corr,{version}" + artificialNormalizationPreview: "artificial_norm,{unit},{group},{runNumber},{type}" reduction: output: "reduced,{unit},{group},{runNumber},{timestamp}" outputGroup: "reduced,{runNumber},{timestamp}" diff --git a/src/snapred/ui/view/reduction/ArtificialNormalizationView.py b/src/snapred/ui/view/reduction/ArtificialNormalizationView.py index 1af46ff62..f2b8b0dd7 100644 --- a/src/snapred/ui/view/reduction/ArtificialNormalizationView.py +++ b/src/snapred/ui/view/reduction/ArtificialNormalizationView.py @@ -105,8 +105,8 @@ def emitValueChange(self): # verify the fields before recalculation try: smoothingValue = float(self.smoothingSlider.field.text()) - lss = self.lssDropdown.currentIndex() == "True" - decreaseParameter = self.decreaseParameterDropdown.currentIndex == "True" + lss = self.lssDropdown.getValue() + decreaseParameter = self.decreaseParameterDropdown.getValue() peakWindowClippingSize = int(self.peakWindowClippingSize.field.text()) except ValueError as e: QMessageBox.warning( @@ -199,7 +199,7 @@ def getSmoothingParameter(self): return 
float(self.smoothingSlider.field.text()) def getLSS(self): - return self.lssDropdown.currentIndex() == 1 + return self.lssDropdown.getValue() def getDecreaseParameter(self): - return self.decreaseParameterDropdown.currentIndex() == 1 + return self.decreaseParameterDropdown.getValue() diff --git a/src/snapred/ui/widget/TrueFalseDropDown.py b/src/snapred/ui/widget/TrueFalseDropDown.py index bb41951f2..03f534f7f 100644 --- a/src/snapred/ui/widget/TrueFalseDropDown.py +++ b/src/snapred/ui/widget/TrueFalseDropDown.py @@ -29,3 +29,6 @@ def setCurrentIndex(self, index): def currentText(self): return self.dropDown.currentText() + + def getValue(self): + return self.currentText() == "True" diff --git a/src/snapred/ui/workflow/ReductionWorkflow.py b/src/snapred/ui/workflow/ReductionWorkflow.py index d79337850..4b0d07e00 100644 --- a/src/snapred/ui/workflow/ReductionWorkflow.py +++ b/src/snapred/ui/workflow/ReductionWorkflow.py @@ -14,6 +14,7 @@ from snapred.backend.log.logger import snapredLogger from snapred.meta.decorators.ExceptionToErrLog import ExceptionToErrLog from snapred.meta.mantid.WorkspaceNameGenerator import WorkspaceName +from snapred.meta.mantid.WorkspaceNameGenerator import WorkspaceNameGenerator as wng from snapred.ui.view.reduction.ArtificialNormalizationView import ArtificialNormalizationView from snapred.ui.view.reduction.ReductionRequestView import ReductionRequestView from snapred.ui.workflow.WorkflowBuilder import WorkflowBuilder @@ -233,7 +234,9 @@ def _triggerReduction(self, workflowPresenter): # SPECIAL FOR THE REDUCTION WORKFLOW: clear everything _except_ the output workspaces # _before_ transitioning to the "save" panel. # TODO: make '_clearWorkspaces' a public method (i.e make this combination a special `cleanup` method). - self._clearWorkspaces(exclude=self.outputs, clearCachedWorkspaces=True) + + # NOTE: should this not occur in the 'finalizeReduction' method? 
+ # self._clearWorkspaces(exclude=self.outputs, clearCachedWorkspaces=True) return self.responses[-1] def _artificialNormalization(self, workflowPresenter, responseData, runNumber): @@ -244,9 +247,10 @@ def _artificialNormalization(self, workflowPresenter, responseData, runNumber): useLiteMode=self.useLiteMode, peakWindowClippingSize=int(self._artificialNormalizationView.peakWindowClippingSize.field.text()), smoothingParameter=self._artificialNormalizationView.getSmoothingParameter(), - decreaseParameter=self._artificialNormalizationView.decreaseParameterDropdown.currentIndex() == 1, - lss=self._artificialNormalizationView.lssDropdown.currentIndex() == 1, + decreaseParameter=self._artificialNormalizationView.decreaseParameterDropdown.getValue(), + lss=self._artificialNormalizationView.lssDropdown.getValue(), diffractionWorkspace=responseData, + outputWorkspace=wng.artificialNormalizationPreview().runNumber(runNumber).group(wng.Groups.COLUMN).build(), ) response = self.request(path="reduction/artificialNormalization", payload=request_) # Update artificial normalization view with the response @@ -272,6 +276,7 @@ def onArtificialNormalizationValueChange(self, smoothingValue, lss, decreasePara decreaseParameter=decreaseParameter, lss=lss, diffractionWorkspace=diffractionWorkspace, + outputWorkspace=wng.artificialNormalizationPreview().runNumber(runNumber).group(wng.Groups.COLUMN).build(), ) response = self.request(path="reduction/artificialNormalization", payload=request_) diff --git a/tests/resources/application.yml b/tests/resources/application.yml index 3229e6c2c..5dbfb7249 100644 --- a/tests/resources/application.yml +++ b/tests/resources/application.yml @@ -127,6 +127,7 @@ mantid: rawVanadium: "_{unit},{group},raw_van_corr,{runNumber},{version}" focusedRawVanadium: "_{unit},{group},raw_van_corr,{runNumber},{version}" smoothedFocusedRawVanadium: "_{unit},{group},fitted_van_corr,{runNumber},{version}" + artificialNormalizationPreview: "artificial_norm,{unit},{group},{runNumber},{type}" reduction: output: "_reduced,{unit},{group},{runNumber},{timestamp}" outputGroup: "_reduced,{runNumber},{timestamp}" diff --git a/tests/unit/backend/recipe/algorithm/test_CreateArtificialNormalizationAlgo.py b/tests/unit/backend/recipe/algorithm/test_CreateArtificialNormalizationAlgo.py index 7f9aa80ca..50e8c2d4f 100644 --- a/tests/unit/backend/recipe/algorithm/test_CreateArtificialNormalizationAlgo.py +++ b/tests/unit/backend/recipe/algorithm/test_CreateArtificialNormalizationAlgo.py @@ -42,7 +42,10 @@ def test_chop_ingredients(self): algo = Algo() algo.initialize() algo.setProperty("InputWorkspace", self.fakeRawData) - algo.setProperty("Ingredients", self.fakeIngredients.json()) + algo.setProperty("decreaseParameter", self.fakeIngredients.decreaseParameter) + algo.setProperty("lss", self.fakeIngredients.lss) + algo.setProperty("peakWindowClippingSize", self.fakeIngredients.peakWindowClippingSize) + algo.setProperty("smoothingParameter", self.fakeIngredients.smoothingParameter) algo.setProperty("OutputWorkspace", "test_output_ws") originalData = [] inputWs = mtd[self.fakeRawData] @@ -70,7 +73,10 @@ def test_execute(self): algo = Algo() algo.initialize() algo.setProperty("InputWorkspace", self.fakeRawData) - algo.setProperty("Ingredients", self.fakeIngredients.json()) + algo.setProperty("decreaseParameter", self.fakeIngredients.decreaseParameter) + algo.setProperty("lss", self.fakeIngredients.lss) + algo.setProperty("peakWindowClippingSize", self.fakeIngredients.peakWindowClippingSize) + 
algo.setProperty("smoothingParameter", self.fakeIngredients.smoothingParameter) algo.setProperty("OutputWorkspace", "test_output_ws") assert algo.execute() @@ -84,7 +90,10 @@ def test_output_data_characteristics(self): algo = Algo() algo.initialize() algo.setProperty("InputWorkspace", self.fakeRawData) - algo.setProperty("Ingredients", self.fakeIngredients.json()) + algo.setProperty("decreaseParameter", self.fakeIngredients.decreaseParameter) + algo.setProperty("lss", self.fakeIngredients.lss) + algo.setProperty("peakWindowClippingSize", self.fakeIngredients.peakWindowClippingSize) + algo.setProperty("smoothingParameter", self.fakeIngredients.smoothingParameter) algo.setProperty("OutputWorkspace", "test_output_ws") algo.execute() output_ws = mtd["test_output_ws"] diff --git a/tests/unit/backend/recipe/test_ApplyNormalizationRecipe.py b/tests/unit/backend/recipe/test_ApplyNormalizationRecipe.py index bea9c20fa..772d740ad 100644 --- a/tests/unit/backend/recipe/test_ApplyNormalizationRecipe.py +++ b/tests/unit/backend/recipe/test_ApplyNormalizationRecipe.py @@ -2,12 +2,10 @@ import pytest from mantid.simpleapi import CreateSingleValuedWorkspace, mtd -from util.Config_helpers import Config_override from util.SculleryBoy import SculleryBoy from snapred.backend.recipe.algorithm.Utensils import Utensils from snapred.backend.recipe.ApplyNormalizationRecipe import ApplyNormalizationRecipe, Ingredients -from snapred.meta.Config import Config class ApplyNormalizationRecipeTest(unittest.TestCase): @@ -32,17 +30,13 @@ def test_init_reuseUtensils(self): recipe = ApplyNormalizationRecipe(utensils=utensils) assert recipe.mantidSnapper == utensils.mantidSnapper - def test_chopIngredients(self): + @unittest.mock.patch("snapred.backend.recipe.ApplyNormalizationRecipe.RebinFocussedGroupDataRecipe") + def test_chopIngredients(self, mockRebinRecipe): # noqa: ARG002 recipe = ApplyNormalizationRecipe() ingredients = Ingredients(pixelGroup=self.sculleryBoy.prepPixelGroup()) recipe.chopIngredients(ingredients) assert recipe.pixelGroup == ingredients.pixelGroup - # Apply adjustment that happens in chopIngredients step - dMin = [x + Config["constants.CropFactors.lowdSpacingCrop"] for x in ingredients.pixelGroup.dMin()] - dMax = [x - Config["constants.CropFactors.highdSpacingCrop"] for x in ingredients.pixelGroup.dMax()] - assert recipe.dMin == dMin - assert recipe.dMax == dMax def test_unbagGroceries(self): recipe = ApplyNormalizationRecipe() @@ -93,25 +87,14 @@ def test_queueAlgos(self): queuedAlgos = recipe.mantidSnapper._algorithmQueue divideTuple = queuedAlgos[0] - rebinRaggedTuple = queuedAlgos[1] assert divideTuple[0] == "Divide" - assert rebinRaggedTuple[0] == "RebinRagged" - # Excessive testing maybe? - assert divideTuple[1] == "Dividing out the normalization.." - assert rebinRaggedTuple[1] == "Resampling X-axis..." 
- assert divideTuple[2]["LHSWorkspace"] == groceries["inputWorkspace"] - assert divideTuple[2]["RHSWorkspace"] == groceries["normalizationWorkspace"] - assert rebinRaggedTuple[2]["InputWorkspace"] == groceries["inputWorkspace"] - # Apply adjustment that happens in chopIngredients step - dMin = [x + Config["constants.CropFactors.lowdSpacingCrop"] for x in ingredients.pixelGroup.dMin()] - dMax = [x - Config["constants.CropFactors.highdSpacingCrop"] for x in ingredients.pixelGroup.dMax()] - assert rebinRaggedTuple[2]["XMin"] == dMin - assert rebinRaggedTuple[2]["XMax"] == dMax - - def test_cook(self): + + @unittest.mock.patch("snapred.backend.recipe.ApplyNormalizationRecipe.RebinFocussedGroupDataRecipe") + def test_cook(self, mockRebinRecipe): untensils = Utensils() mockSnapper = unittest.mock.Mock() + mockSnapper.mtd = unittest.mock.MagicMock() untensils.mantidSnapper = mockSnapper recipe = ApplyNormalizationRecipe(utensils=untensils) ingredients = Ingredients(pixelGroup=self.sculleryBoy.prepPixelGroup()) @@ -126,9 +109,10 @@ def test_cook(self): assert mockSnapper.executeQueue.called assert mockSnapper.Divide.called - assert mockSnapper.RebinRagged.called + assert mockRebinRecipe().cook.called - def test_cater(self): + @unittest.mock.patch("snapred.backend.recipe.ApplyNormalizationRecipe.RebinFocussedGroupDataRecipe") + def test_cater(self, mockRebinRecipe): untensils = Utensils() mockSnapper = unittest.mock.Mock() untensils.mantidSnapper = mockSnapper @@ -145,26 +129,4 @@ def test_cater(self): assert mockSnapper.executeQueue.called assert mockSnapper.Divide.called - assert mockSnapper.RebinRagged.called - - def test_badChopIngredients(self): - recipe = ApplyNormalizationRecipe() - ingredients = Ingredients(pixelGroup=self.sculleryBoy.prepPixelGroup()) - with ( - Config_override("constants.CropFactors.lowdSpacingCrop", 500.0), - Config_override("constants.CropFactors.highdSpacingCrop", 1000.0), - pytest.raises(ValueError, match="d-spacing crop factors are too large"), - ): - recipe.chopIngredients(ingredients) - - with ( - Config_override("constants.CropFactors.lowdSpacingCrop", -10.0), - pytest.raises(ValueError, match="Low d-spacing crop factor must be positive"), - ): - recipe.chopIngredients(ingredients) - - with ( - Config_override("constants.CropFactors.highdSpacingCrop", -10.0), - pytest.raises(ValueError, match="High d-spacing crop factor must be positive"), - ): - recipe.chopIngredients(ingredients) + assert mockRebinRecipe().cook.called diff --git a/tests/unit/backend/recipe/test_GenerateFocussedVanadiumRecipe.py b/tests/unit/backend/recipe/test_GenerateFocussedVanadiumRecipe.py index 6cb5a3934..f6c4b72f9 100644 --- a/tests/unit/backend/recipe/test_GenerateFocussedVanadiumRecipe.py +++ b/tests/unit/backend/recipe/test_GenerateFocussedVanadiumRecipe.py @@ -1,7 +1,6 @@ import unittest from unittest import mock -import pytest from mantid.simpleapi import ( LoadNexusProcessed, mtd, @@ -56,6 +55,7 @@ def tearDown(self) -> None: return super().tearDown() def test_execute_successful(self): + self.recipe._rebinInputWorkspace = mock.Mock() mock_instance = self.mockSnapper.SmoothDataExcludingPeaksAlgo.return_value mock_instance.execute.return_value = None mock_instance.getPropertyValue.return_value = self.fakeOutputWorkspace @@ -66,14 +66,22 @@ def test_execute_successful(self): self.assertEqual(output, expected_output) # noqa: PT009 - def test_execute_unsuccessful(self): - mock_instance = self.mockSnapper.SmoothDataExcludingPeaksAlgo.return_value - mock_instance.execute.return_value = None 
- mock_instance.getPropertyValue.return_value = None - - with pytest.raises(NotImplementedError) as e: - self.recipe.cook(self.fakeIngredients, self.errorGroceryList) - assert str(e.value) == "Fake Vanadium not implemented yet." # noqa: PT009 + def test_execute_artificial(self): + self.recipe._rebinInputWorkspace = mock.Mock() + self.fakeIngredients.artificialNormalizationIngredients = SculleryBoy().prepArtificialNormalizationIngredients() + self.recipe.cook(self.fakeIngredients, self.groceryList) + + self.recipe.mantidSnapper.CreateArtificialNormalizationAlgo.assert_called_once() + self.recipe.mantidSnapper.SmoothDataExcludingPeaksAlgo.assert_not_called() + self.recipe.mantidSnapper.CreateArtificialNormalizationAlgo.assert_called_with( + "Create Artificial Normalization...", + InputWorkspace=self.fakeInputWorkspace, + OutputWorkspace=self.fakeOutputWorkspace, + peakWindowClippingSize=10, + smoothingParameter=0.1, + decreaseParameter=True, + LSS=True, + ) def test_catering(self): self.recipe.cook = mock.Mock() @@ -83,3 +91,13 @@ def test_catering(self): assert self.recipe.cook.called assert output[0] == self.recipe.cook.return_value + + @mock.patch(f"{ThisRecipe}.RebinFocussedGroupDataRecipe") + def test_rebinInputWorkspace(self, mockRebinRecipe): + self.recipe.prep(self.fakeIngredients, self.groceryList) + self.recipe._rebinInputWorkspace() + + mockRebinRecipe().cook.assert_called_with( + mockRebinRecipe.Ingredients(pixelGroup=self.fakeIngredients.pixelGroup), + {"inputWorkspace": self.fakeInputWorkspace}, + ) diff --git a/tests/unit/backend/recipe/test_RebinFocussedGroupDataRecipe.py b/tests/unit/backend/recipe/test_RebinFocussedGroupDataRecipe.py new file mode 100644 index 000000000..d1673ea40 --- /dev/null +++ b/tests/unit/backend/recipe/test_RebinFocussedGroupDataRecipe.py @@ -0,0 +1,74 @@ +import unittest + +import pytest +from mantid.api import MatrixWorkspace +from mantid.simpleapi import ( + CreateSampleWorkspace, + DeleteWorkspaces, + GroupDetectors, + mtd, +) +from util.Config_helpers import Config_override +from util.dao import DAOFactory +from util.SculleryBoy import SculleryBoy + +from snapred.backend.recipe.RebinFocussedGroupDataRecipe import RebinFocussedGroupDataRecipe as Recipe + +ThisRecipe: str = "snapred.backend.recipe.RebinFocussedGroupDataRecipe" + + +class TestRebinFocussedGroupDataRecipe(unittest.TestCase): + sculleryBoy = SculleryBoy() + + def setUp(self): + testCalibration = DAOFactory.calibrationRecord("57514", True, 1) + self.pixelGroup = testCalibration.pixelGroups[0] + self.pixelGroup + self.ingredients = Recipe.Ingredients(pixelGroup=self.pixelGroup, preserveEvents=True) + + self.sampleWorkspace = "sampleWorkspace" + CreateSampleWorkspace( + OutputWorkspace=self.sampleWorkspace, + BankPixelWidth=3, + ) + GroupDetectors( + InputWorkspace=self.sampleWorkspace, + OutputWorkspace=self.sampleWorkspace, + GroupingPattern="0-3,4-5,6-8,9-17", + ) + + def tearDown(self) -> None: + DeleteWorkspaces(self.sampleWorkspace) + return super().tearDown() + + def test_recipe(self): + inputWs = mtd[self.sampleWorkspace] + assert not inputWs.isRaggedWorkspace() + recipe = Recipe() + recipe.cook(self.ingredients, {"inputWorkspace": self.sampleWorkspace}) + + outputWs = mtd[self.sampleWorkspace] + assert isinstance(outputWs, MatrixWorkspace) + assert outputWs.isRaggedWorkspace() + + def test_badChopIngredients(self): + ingredients = Recipe.Ingredients(pixelGroup=self.sculleryBoy.prepPixelGroup()) + recipe = Recipe() + with ( + 
Config_override("constants.CropFactors.lowdSpacingCrop", 500.0), + Config_override("constants.CropFactors.highdSpacingCrop", 1000.0), + pytest.raises(ValueError, match="d-spacing crop factors are too large"), + ): + recipe.chopIngredients(ingredients) + # + with ( + Config_override("constants.CropFactors.lowdSpacingCrop", -10.0), + pytest.raises(ValueError, match="Low d-spacing crop factor must be positive"), + ): + recipe.chopIngredients(ingredients) + # + with ( + Config_override("constants.CropFactors.highdSpacingCrop", -10.0), + pytest.raises(ValueError, match="High d-spacing crop factor must be positive"), + ): + recipe.chopIngredients(ingredients) diff --git a/tests/unit/backend/recipe/test_ReductionGroupProcessingRecipe.py b/tests/unit/backend/recipe/test_ReductionGroupProcessingRecipe.py index bb7325a8e..b915fa5e8 100644 --- a/tests/unit/backend/recipe/test_ReductionGroupProcessingRecipe.py +++ b/tests/unit/backend/recipe/test_ReductionGroupProcessingRecipe.py @@ -85,9 +85,9 @@ def test_cook(self): output = recipe.cook(self.mockIngredients(), groceries) assert recipe.rawInput == groceries["inputWorkspace"] - assert recipe.outputWS == groceries["inputWorkspace"] + assert recipe.outputWS == groceries["outputWorkspace"] assert recipe.groupingWS == groceries["groupingWorkspace"] - assert output == groceries["inputWorkspace"] + assert output == groceries["outputWorkspace"] assert mockSnapper.executeQueue.called assert mockSnapper.FocusSpectraAlgorithm.called @@ -110,9 +110,9 @@ def test_cater(self): output = recipe.cater([(self.mockIngredients(), groceries)]) assert recipe.rawInput == groceries["inputWorkspace"] - assert recipe.outputWS == groceries["inputWorkspace"] + assert recipe.outputWS == groceries["outputWorkspace"] assert recipe.groupingWS == groceries["groupingWorkspace"] - assert output[0] == groceries["inputWorkspace"] + assert output[0] == groceries["outputWorkspace"] assert mockSnapper.executeQueue.called assert mockSnapper.FocusSpectraAlgorithm.called diff --git a/tests/unit/backend/recipe/test_ReductionRecipe.py b/tests/unit/backend/recipe/test_ReductionRecipe.py index 65ce78d4f..ac651f530 100644 --- a/tests/unit/backend/recipe/test_ReductionRecipe.py +++ b/tests/unit/backend/recipe/test_ReductionRecipe.py @@ -384,8 +384,8 @@ def test_execute(self, mockMtd): recipe = ReductionRecipe() recipe.mantidSnapper = mockMantidSnapper recipe.mantidSnapper.mtd = mockMtd - recipe._prepareArtificialNormalization = mock.Mock() - recipe._prepareArtificialNormalization.return_value = "norm_grouped" + recipe._getNormalizationWorkspaceName = mock.Mock() + recipe._getNormalizationWorkspaceName.return_value = "norm_grouped" # Set up ingredients and other variables for the recipe recipe.groceries = {} @@ -447,12 +447,14 @@ def test_execute(self, mockMtd): recipe._applyRecipe.assert_any_call( GenerateFocussedVanadiumRecipe, recipe.ingredients.generateFocussedVanadium(0), - inputWorkspace="norm_grouped", + inputWorkspace="sample_grouped", + outputWorkspace=recipe._getNormalizationWorkspaceName.return_value, ) recipe._applyRecipe.assert_any_call( GenerateFocussedVanadiumRecipe, recipe.ingredients.generateFocussedVanadium(1), - inputWorkspace="norm_grouped", + inputWorkspace="sample_grouped", + outputWorkspace=recipe._getNormalizationWorkspaceName.return_value, ) recipe._applyRecipe.assert_any_call( @@ -468,9 +470,9 @@ def test_execute(self, mockMtd): normalizationWorkspace="norm_grouped", ) - recipe._prepareArtificialNormalization.call_count == 2 - 
recipe._prepareArtificialNormalization.assert_any_call("sample_grouped", 0) - recipe._prepareArtificialNormalization.assert_any_call("sample_grouped", 1) + recipe._getNormalizationWorkspaceName.call_count == 2 + recipe._getNormalizationWorkspaceName.assert_any_call(0) + recipe._getNormalizationWorkspaceName.assert_any_call(1) recipe.ingredients.effectiveInstrument.assert_not_called() @@ -497,8 +499,8 @@ def test_execute_useEffectiveInstrument(self, mockMtd): recipe = ReductionRecipe() recipe.mantidSnapper = mockMantidSnapper recipe.mantidSnapper.mtd = mockMtd - recipe._prepareArtificialNormalization = mock.Mock() - recipe._prepareArtificialNormalization.return_value = "norm_grouped" + recipe._getNormalizationWorkspaceName = mock.Mock() + recipe._getNormalizationWorkspaceName.return_value = "norm_grouped" # Set up ingredients and other variables for the recipe recipe.groceries = {} @@ -560,12 +562,14 @@ def test_execute_useEffectiveInstrument(self, mockMtd): recipe._applyRecipe.assert_any_call( GenerateFocussedVanadiumRecipe, recipe.ingredients.generateFocussedVanadium(0), - inputWorkspace="norm_grouped", + inputWorkspace="sample_grouped", + outputWorkspace=recipe._getNormalizationWorkspaceName.return_value, ) recipe._applyRecipe.assert_any_call( GenerateFocussedVanadiumRecipe, recipe.ingredients.generateFocussedVanadium(1), - inputWorkspace="norm_grouped", + inputWorkspace="sample_grouped", + outputWorkspace=recipe._getNormalizationWorkspaceName.return_value, ) recipe._applyRecipe.assert_any_call( @@ -581,9 +585,9 @@ def test_execute_useEffectiveInstrument(self, mockMtd): normalizationWorkspace="norm_grouped", ) - recipe._prepareArtificialNormalization.call_count == 2 - recipe._prepareArtificialNormalization.assert_any_call("sample_grouped", 0) - recipe._prepareArtificialNormalization.assert_any_call("sample_grouped", 1) + recipe._getNormalizationWorkspaceName.call_count == 2 + recipe._getNormalizationWorkspaceName.assert_any_call(0) + recipe._getNormalizationWorkspaceName.assert_any_call(1) recipe._applyRecipe.assert_any_call( EffectiveInstrumentRecipe, diff --git a/tests/unit/backend/service/test_ReductionService.py b/tests/unit/backend/service/test_ReductionService.py index c92f95856..10db2143d 100644 --- a/tests/unit/backend/service/test_ReductionService.py +++ b/tests/unit/backend/service/test_ReductionService.py @@ -471,23 +471,31 @@ def test_artificialNormalization(self, mockArtificialNormalizationRecipe): decreaseParameter=True, lss=True, diffractionWorkspace="mock_diffraction_workspace", - outputWorkspace="mock_output_workspace", + outputWorkspace="artificial_norm_dsp_column_000123_preview", ) result = self.instance.artificialNormalization(request) mockArtificialNormalizationRecipe.return_value.executeRecipe.assert_called_once_with( InputWorkspace=request.diffractionWorkspace, - Ingredients=mock.ANY, + peakWindowClippingSize=request.peakWindowClippingSize, + smoothingParameter=request.smoothingParameter, + decreaseParameter=request.decreaseParameter, + lss=request.lss, OutputWorkspace=request.outputWorkspace, ) assert result == mockResult + @mock.patch(thisService + "RebinFocussedGroupDataRecipe") @mock.patch(thisService + "ReductionGroupProcessingRecipe") @mock.patch(thisService + "GroceryService") @mock.patch(thisService + "DataFactoryService") def test_grabWorkspaceforArtificialNorm( - self, mockDataFactoryService, mockGroceryService, mockReductionGroupProcessingRecipe + self, + mockDataFactoryService, + mockGroceryService, + mockReductionGroupProcessingRecipe, + 
mockRebinFocussedGroupDataRecipe, ): self.instance.groceryService = mockGroceryService self.instance.dataFactoryService = mockDataFactoryService @@ -500,7 +508,10 @@ def test_grabWorkspaceforArtificialNorm( pixelMasks=[], focusGroups=[FocusGroup(name="apple", definition="path/to/grouping")], ) + mockIngredients = mock.Mock() + mockIngredients.pixelGroups = [mock.Mock()] + runWorkspaceName = "runworkspace" columnGroupingWS = "columnGroupingWS" self.instance.groceryService.fetchGroceryList.return_value = [runWorkspaceName] @@ -517,9 +528,13 @@ def test_grabWorkspaceforArtificialNorm( groceries = { "inputWorkspace": runWorkspaceName, "groupingWorkspace": columnGroupingWS, + "outputWorkspace": "artificial_norm_dsp_column_000123_source", } mockReductionGroupProcessingRecipe().cook.assert_called_once_with(mockIngredients.groupProcessing(0), groceries) + groceries = {"inputWorkspace": groceries["outputWorkspace"]} + rebinIngredients = mockRebinFocussedGroupDataRecipe.Ingredients() + mockRebinFocussedGroupDataRecipe().cook.assert_called_once_with(rebinIngredients, groceries) class TestReductionServiceMasks: diff --git a/tests/util/SculleryBoy.py b/tests/util/SculleryBoy.py index 21167e55f..63b9a997e 100644 --- a/tests/util/SculleryBoy.py +++ b/tests/util/SculleryBoy.py @@ -6,6 +6,7 @@ from snapred.backend.dao.GroupPeakList import GroupPeakList from snapred.backend.dao.ingredients import ( + ArtificialNormalizationIngredients, DiffractionCalibrationIngredients, NormalizationIngredients, ReductionIngredients, @@ -106,6 +107,9 @@ def prepNormalizationIngredients(self, ingredients: FarmFreshIngredients) -> Nor detectorPeaks=self.prepDetectorPeaks(ingredients), ) + def prepArtificialNormalizationIngredients(self): + return ArtificialNormalizationIngredients(smoothingParameter=0.1) + def verifyCalibrationExists(self, runNumber: str, useLiteMode: bool) -> bool: # noqa ARG002 return True From 195a469ac45a5a526812935fdddf0f66a6b81279 Mon Sep 17 00:00:00 2001 From: Michael Walsh <68125095+walshmm@users.noreply.github.com> Date: Wed, 18 Dec 2024 09:54:43 -0500 Subject: [PATCH 3/7] Fix conda build failure dec 17 2024 (#522) * update python version of conda build job, update required mantid version in pyproject.toml * update conda build channels to account for mantid nightly --- .github/workflows/actions.yml | 4 ++-- pyproject.toml | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/actions.yml b/.github/workflows/actions.yml index de6f10715..13f55d60e 100644 --- a/.github/workflows/actions.yml +++ b/.github/workflows/actions.yml @@ -100,7 +100,7 @@ jobs: fail-fast: false matrix: os: [ubuntu-latest] - python-version: ["3.8"] + python-version: ["3.10"] defaults: run: shell: bash -l {0} @@ -125,7 +125,7 @@ jobs: cd conda.recipe echo "versioningit $(versioningit ../)" # build the package - VERSION=$(versioningit ../) conda mambabuild --channel conda-forge --output-folder . . + VERSION=$(versioningit ../) conda mambabuild --channel conda-forge --channel mantid/label/nightly --output-folder . . 
conda verify noarch/snapred*.tar.bz2 - name: Deploy to Anaconda shell: bash -l {0} diff --git a/pyproject.toml b/pyproject.toml index b04de4928..ab9b5d5f9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ description = "A desktop application for Lifecycle Managment of data collected f dynamic = ["version"] requires-python = ">=3.10" dependencies = [ - "mantidworkbench >= 6.10.0.2rc1", + "mantidworkbench >= 6.11.20241203", "pyoncat ~= 1.6" ] readme = "README.md" From a7d0f20afb314c84fce62b1a57dc93a614a588a6 Mon Sep 17 00:00:00 2001 From: Daniel Caballero <132241327+dlcaballero16@users.noreply.github.com> Date: Thu, 19 Dec 2024 10:15:20 -0500 Subject: [PATCH 4/7] Updated integration tests to work with ultralite data (#516) * Updated integration tests to work with ultralite data, Integration tests now use ultralite version of run 58882, set more fields in integration_test.yml to get tests to work, moved ultralite stuff into tests location * More fixes * Fixed toggle issue * Fixed issue with grouping not being selected correctly * Updated git LFS reference --- tests/data/snapred-data | 2 +- .../test_workflow_panels_happy_path.py | 60 +++-- tests/resources/integration_test.yml | 50 ++++ .../ultralite/CRACKLEFocGroup_Column.xml | 21 ++ .../ultralite/CRACKLELiteDataMap.xml | 59 +++++ .../ultralite/CRACKLELite_Definition.xml | 237 ++++++++++++++++++ .../ultralite/CRACKLE_Definition.xml | 221 ++++++++++++++++ .../ultralite/create_ultralite_data.py | 141 +++++++++++ 8 files changed, 764 insertions(+), 27 deletions(-) create mode 100644 tests/resources/ultralite/CRACKLEFocGroup_Column.xml create mode 100644 tests/resources/ultralite/CRACKLELiteDataMap.xml create mode 100644 tests/resources/ultralite/CRACKLELite_Definition.xml create mode 100644 tests/resources/ultralite/CRACKLE_Definition.xml create mode 100644 tests/resources/ultralite/create_ultralite_data.py diff --git a/tests/data/snapred-data b/tests/data/snapred-data index bd6930ff5..59443a567 160000 --- a/tests/data/snapred-data +++ b/tests/data/snapred-data @@ -1 +1 @@ -Subproject commit bd6930ff57eef257a17adecdab9b7d9cea76850f +Subproject commit 59443a567f5b78447a9ee13ef5483af28fcdcaeb diff --git a/tests/integration/test_workflow_panels_happy_path.py b/tests/integration/test_workflow_panels_happy_path.py index eca307711..8351f2b10 100644 --- a/tests/integration/test_workflow_panels_happy_path.py +++ b/tests/integration/test_workflow_panels_happy_path.py @@ -227,7 +227,10 @@ def _setup_gui(self, qapp): lambda self, *args, **kwargs: QMessageBox.Ok if ( "The backend has encountered warning(s)" in self.text() - and "InstrumentDonor will only be used if GroupingFilename is in XML format." in self.detailedText() + and ( + "InstrumentDonor will only be used if GroupingFilename is in XML format." 
in self.detailedText() + or "No valid FocusGroups were specified for mode: 'lite'" in self.detailedText() + ) ) else pytest.fail( "unexpected QMessageBox.exec:" @@ -861,28 +864,29 @@ def test_diffraction_calibration_panel_happy_path(self, qtbot, qapp, calibration self.testSummary.SUCCESS() # set "Run Number", "Convergence Threshold", ,: - requestView.runNumberField.setText("46680") - requestView.fieldConvergenceThreshold.setText("0.1") - requestView.fieldNBinsAcrossPeakWidth.setText("10") + requestView.runNumberField.setText("58882") + requestView.litemodeToggle.setState(False) # set all dropdown selections, but make sure that the dropdown contents are as expected - requestView.sampleDropdown.setCurrentIndex(0) - assert requestView.sampleDropdown.currentIndex() == 0 - assert requestView.sampleDropdown.currentText().endswith("Diamond_001.json") + requestView.sampleDropdown.setCurrentIndex(3) + assert requestView.sampleDropdown.currentIndex() == 3 + assert requestView.sampleDropdown.currentText().endswith("Silicon_NIST_640D_001.json") # Without this next 'qtbot.wait(1000)', # the 'groupingFileDropdown' gets reset after this successful initialization. # I assume this is because somehow the 'populateGroupingDropdown', # triggered by the 'runNumberField' 'editComplete' hasn't actually occurred yet? qtbot.wait(1000) - requestView.groupingFileDropdown.setCurrentIndex(1) - assert requestView.groupingFileDropdown.currentIndex() == 1 - assert requestView.groupingFileDropdown.currentText() == "Bank" + requestView.groupingFileDropdown.setCurrentIndex(0) + assert requestView.groupingFileDropdown.currentIndex() == 0 + assert requestView.groupingFileDropdown.currentText() == "Column" requestView.peakFunctionDropdown.setCurrentIndex(0) assert requestView.peakFunctionDropdown.currentIndex() == 0 assert requestView.peakFunctionDropdown.currentText() == "Gaussian" + requestView.skipPixelCalToggle.setState(False) + self.testSummary.SUCCESS() # execute the request @@ -940,7 +944,7 @@ def test_diffraction_calibration_panel_happy_path(self, qtbot, qapp, calibration # Now that there is a new state, we need to reselect the grouping file ... : # Why was this error box being swallowed? - requestView.groupingFileDropdown.setCurrentIndex(1) + requestView.groupingFileDropdown.setCurrentIndex(0) # (2) execute the calibration workflow with qtbot.waitSignal(actionCompleted, timeout=60000): @@ -962,11 +966,11 @@ def test_diffraction_calibration_panel_happy_path(self, qtbot, qapp, calibration warningMessageBox.start() # --------------------------------------------------------------------------- - # set "xtal dMin", "FWHM left", and "FWHM right": these are sufficient to get "46680" to pass. + # set "xtal dMin", "FWHM left", and "FWHM right": these are sufficient to get "58882" to pass. # TODO: set ALL of the relevant fields, and use a test initialization template for this. 
- tweakPeakView.fieldXtalDMin.setText("0.72") - tweakPeakView.fieldFWHMleft.setText("2.0") - tweakPeakView.fieldFWHMright.setText("2.0") + tweakPeakView.fieldFWHMleft.setText("1.5") + tweakPeakView.fieldFWHMright.setText("2") + tweakPeakView.maxChiSqField.setText("1000.0") tweakPeakView.peakFunctionDropdown.setCurrentIndex(0) assert tweakPeakView.peakFunctionDropdown.currentIndex() == 0 assert tweakPeakView.peakFunctionDropdown.currentText() == "Gaussian" @@ -1113,23 +1117,26 @@ def test_normalization_panel_happy_path(self, qtbot, qapp, calibration_home_from self.testSummary.SUCCESS() # set "Run Number", "Background run number": - requestView.runNumberField.setText("46680") - requestView.backgroundRunNumberField.setText("46680") + requestView.runNumberField.setText("58882") + requestView.backgroundRunNumberField.setText("58882") + + requestView.litemodeToggle.setState(False) # set all dropdown selections, but make sure that the dropdown contents are as expected - requestView.sampleDropdown.setCurrentIndex(0) - assert requestView.sampleDropdown.currentIndex() == 0 - assert requestView.sampleDropdown.currentText().endswith("Diamond_001.json") + requestView.sampleDropdown.setCurrentIndex(3) + assert requestView.sampleDropdown.currentIndex() == 3 + assert requestView.sampleDropdown.currentText().endswith("Silicon_NIST_640D_001.json") # Without this next 'qtbot.wait(1000)', # the 'groupingFileDropdown' gets reset after this successful initialization. # I assume this is because somehow the 'populateGroupingDropdown', # triggered by the 'runNumberField' 'editComplete' hasn't actually occurred yet? qtbot.wait(1000) - requestView.groupingFileDropdown.setCurrentIndex(1) - assert requestView.groupingFileDropdown.currentIndex() == 1 - assert requestView.groupingFileDropdown.currentText() == "Bank" + requestView.groupingFileDropdown.setCurrentIndex(0) + assert requestView.groupingFileDropdown.currentIndex() == 0 + assert requestView.groupingFileDropdown.currentText() == "Column" self.testSummary.SUCCESS() + """ # Why no "peak function" for normalization calibration?! requestView.peakFunctionDropdown.setCurrentIndex(0) assert requestView.peakFunctionDropdown.currentIndex() == 0 @@ -1193,7 +1200,7 @@ def test_normalization_panel_happy_path(self, qtbot, qapp, calibration_home_from # Now that there is a new state, we need to reselect the grouping file ... : # Why was this error box being swallowed? - requestView.groupingFileDropdown.setCurrentIndex(1) + requestView.groupingFileDropdown.setCurrentIndex(0) warningMessageBox = mock.patch( # noqa: PT008 "qtpy.QtWidgets.QMessageBox.warning", @@ -1220,7 +1227,7 @@ def test_normalization_panel_happy_path(self, qtbot, qapp, calibration_home_from qtbot.wait(1000) tweakPeakView.groupingFileDropdown.setCurrentIndex(0) assert tweakPeakView.groupingFileDropdown.currentIndex() == 0 - assert tweakPeakView.groupingFileDropdown.currentText() == "All" + assert tweakPeakView.groupingFileDropdown.currentText() == "Column" # recalculate using the new values # * recalculate => peak display is recalculated, @@ -1278,7 +1285,7 @@ def test_reduction_panel_happy_path(self, qtbot, qapp, reduction_home_from_mirro ## # TODO: these could be initialized in the 'setup', but the current plan is to use a YML test template. 
-        reductionRunNumber = "46680"
+        reductionRunNumber = "58882"
         reductionStateId = "04bd2c53f6bf6754"

         # Override the standard reduction-output location, using a temporary directory
@@ -1370,6 +1377,7 @@ def completionMessageBoxAssert(*args, **kwargs):  # noqa: ARG001
             qtbot.wait(1000)

             # enter a "Run Number":
+            requestView.liteModeToggle.setState(False)
             requestView.runNumberInput.setText(reductionRunNumber)
             qtbot.mouseClick(requestView.enterRunNumberButton, Qt.MouseButton.LeftButton)

diff --git a/tests/resources/integration_test.yml b/tests/resources/integration_test.yml
index 2be427cad..bb52b9671 100644
--- a/tests/resources/integration_test.yml
+++ b/tests/resources/integration_test.yml
@@ -12,7 +12,57 @@ IPTS:
 constants:
   # For tests with '46680' this seems to be necessary.
   maskedPixelThreshold: 1.0
+  logsLocation: "/mantid_workspace_1/logs"
+
+  DetectorPeakPredictor:
+    fwhm: 1.17741002252 # used to convert gaussian sigma to fwhm (sqrt(2 * log_e(2)))

   RawVanadiumCorrection:
     numberOfSlices: 1
     numberOfAnnuli: 1
+
+instrument:
+  native:
+    pixelResolution: 72
+    definition:
+      file: ${module.root}/resources/ultralite/CRACKLE_Definition.xml
+  lite:
+    pixelResolution: 18
+    definition:
+      file: ${module.root}/resources/ultralite/CRACKLELite_Definition.xml
+    map:
+      file: ${module.root}/resources/ultralite/CRACKLELiteDataMap.xml
+
+mantid:
+  workspace:
+    nameTemplate:
+      delimiter: "_"
+      template:
+        run: "{unit},{group},{lite},{auxiliary},{runNumber}"
+        diffCal:
+          input: "{unit},{runNumber},raw"
+          table: "diffract_consts,{runNumber},{version}"
+          output: "{unit},{group},{runNumber},{version}"
+          diagnostic: "diagnostic,{group},{runNumber},{version}"
+          mask: "diffract_consts,mask,{runNumber},{version}"
+          metric: "calib_metrics,{metricName},{runNumber},{version}"
+          timedMetric: "calib_metrics,{metricName},{runNumber},{timestamp}"
+        normCal:
+          rawVanadium: "{unit},{group},{runNumber},raw_van_corr,{version}"
+          focusedRawVanadium: "{unit},{group},{runNumber},raw_van_corr,{version}"
+          smoothedFocusedRawVanadium: "{unit},{group},{runNumber},fitted_van_corr,{version}"
+
+calibration:
+  parameters:
+    default:
+      alpha: 0.1
+      # alpha: 1.1
+      beta:
+        - 0.02
+        - 0.05
+      # beta:
+      #   - 1
+      #   - 2
+  fitting:
+    # minSignal2Noise: 0.0
+    minSignal2Noise: 10
diff --git a/tests/resources/ultralite/CRACKLEFocGroup_Column.xml b/tests/resources/ultralite/CRACKLEFocGroup_Column.xml
new file mode 100644
index 000000000..3a6fe2fb8
--- /dev/null
+++ b/tests/resources/ultralite/CRACKLEFocGroup_Column.xml
@@ -0,0 +1,21 @@
+[XML markup lost in extraction: a 21-line grouping file defining six "Column" groups with detector IDs 0-11, 12-23, 24-35, 36-47, 48-59, and 60-71.]
diff --git a/tests/resources/ultralite/CRACKLELiteDataMap.xml b/tests/resources/ultralite/CRACKLELiteDataMap.xml
new file mode 100644
index 000000000..81f4bcd02
--- /dev/null
+++ b/tests/resources/ultralite/CRACKLELiteDataMap.xml
@@ -0,0 +1,59 @@
+[XML markup lost in extraction: a grouping map collapsing the 72 native CRACKLE pixels to 18 lite pixels, four consecutive detector IDs per group: (0,1,2,3), (4,5,6,7), ..., (68,69,70,71).]
diff --git a/tests/resources/ultralite/CRACKLELite_Definition.xml b/tests/resources/ultralite/CRACKLELite_Definition.xml
new file mode 100644
index 000000000..f849c97b6
--- /dev/null
+++ b/tests/resources/ultralite/CRACKLELite_Definition.xml
@@ -0,0 +1,237 @@
+[instrument definition XML: no element content survived extraction]
diff --git a/tests/resources/ultralite/CRACKLE_Definition.xml b/tests/resources/ultralite/CRACKLE_Definition.xml
new file mode 100644
index 000000000..05e80e474
--- /dev/null
+++ b/tests/resources/ultralite/CRACKLE_Definition.xml
@@ -0,0 +1,221 @@
+[instrument definition XML: no element content survived extraction]
diff --git a/tests/resources/ultralite/create_ultralite_data.py b/tests/resources/ultralite/create_ultralite_data.py
new file mode 100644
index 000000000..b48a9f74c
--- /dev/null
+++ b/tests/resources/ultralite/create_ultralite_data.py
@@ -0,0 +1,141 @@
+# import mantid algorithms, numpy and matplotlib
+import os  # needed for the expanduser call below
+
+from mantid.simpleapi import *
+
+from snapred.backend.dao.ingredients.GroceryListItem import GroceryListItem
+from snapred.backend.data.DataFactoryService import DataFactoryService
+from snapred.backend.data.GroceryService import GroceryService
+from snapred.meta.Config import Resource
+
+Resource._resourcesPath = os.path.expanduser("~/SNS/SNAP/shared/Calibration_next/Powder/")
+liteInstrumentFile = Resource.getPath("CRACKLE_Definition.xml")
+dfs = DataFactoryService()
+
+
+def superID(nativeID, xdim, ydim):
+    # accepts a numpy array of native ID from standard SNAP nexus file and returns a numpy array with
+    # super pixel ID according to provided dimensions xdim and ydim of the super pixel.
+ # xdim and ydim shall be multiples of 2 + + Nx = 256 # native number of horizontal pixels + Ny = 256 # native number of vertical pixels + NNat = Nx * Ny # native number of pixels per panel + + firstPix = (nativeID // NNat) * NNat + redID = nativeID % NNat # reduced ID beginning at zero in each panel + + (i, j) = divmod(redID, Ny) # native (reduced) coordinates on pixel face + superi = divmod(i, xdim)[0] + superj = divmod(j, ydim)[0] + + # some basics of the super panel + superNx = Nx / xdim # 32 running from 0 to 31 + superNy = Ny / ydim + superN = superNx * superNy + + superFirstPix = (firstPix / NNat) * superN + + superVal = superi * superNy + superj + superFirstPix + + return superVal + + +# create the mapping +LoadEmptyInstrument( + Filename="/SNS/SNAP/shared/Malcolm/dataFiles/SNAP_Definition.xml", + OutputWorkspace="SNAP", +) + +mapToCrackle = "map_from_SNAP_to_CRACKLE" +if mapToCrackle not in mtd: + # create the lite grouping ws using input run as template + CreateGroupingWorkspace( + InputWorkspace="SNAP", + GroupDetectorsBy="All", + OutputWorkspace=mapToCrackle, + ) + ws = mtd[mapToCrackle] + nHst = ws.getNumberHistograms() + for spec in range(nHst): + ws.setY(spec, [superID(spec, 128, 128) + 1]) + +# select run to convert to ultralite data, can convert multiple runs at once +runs_to_reduce = ["58882"] # ["46680", "58810", "58813", "57514"] + +clerk = GroceryListItem.builder() +for x in runs_to_reduce: + clerk.neutron(x).native().add() +groceries = GroceryService().fetchGroceryList(clerk.buildList()) + + +# The FileName should point to a "diffract_consts__v#.h5 file, this gets saved at the end of a diffcal run +LoadDiffCal( + InputWorkspace=groceries[0], + FileName="/SNS/users/8l2/SNS/SNAP/shared/Calibration_next/Powder/04bd2c53f6bf6754/native/diffraction/v_0003/diffract_consts_057514_v0003.h5", + WorkspaceName="57514", +) +# If set to False, will output data as histograms +eventMode = True + +for grocery in groceries: + ws = mtd[grocery] + ultralite = f"{grocery}_ULTRALITE" + CloneWorkspace( + InputWorkspace=grocery, + OutputWorkspace=ultralite, + ) + ConvertUnits( + InputWorkspace=ultralite, + OutputWorkspace=ultralite, + Target="dSpacing", + ) + if not eventMode: + uws = mtd[ultralite] + Rebin(InputWorkspace=ultralite, OutputWorkspace=ultralite, Params=(uws.getTofMin(), -0.001, uws.getTofMax())) + DiffractionFocussing( + InputWorkspace=ultralite, + OutputWorkspace=ultralite, + GroupingWorkspace=mapToCrackle, + PreserveEvents=eventMode, + ) + LoadInstrument( + Workspace=ultralite, + Filename=liteInstrumentFile, + RewriteSpectraMap=True, + ) + ConvertUnits( + InputWorkspace=ultralite, + OutputWorkspace=ultralite, + Target="TOF", + ) + if eventMode: + CompressEvents( + InputWorkspace=ultralite, + OutputWorkspace=ultralite, + BinningMode="Logarithmic", + Tolerance=-0.0001, + ) + uws = mtd[ultralite] + Rebin(InputWorkspace=ultralite, OutputWorkspace=ultralite, Params=(uws.getTofMax() - uws.getTofMin())) + logs = ( + "BL3:Det:TH:BL:Frequency", + "BL3:Mot:OpticsPos:Pos", + "BL3:Chop:Gbl:WavelengthReq", + "BL3:Chop:Skf1:WavelengthUserReq", + "BL3:Chop:Gbl:WavelengthReq", + "BL3:Chop:Skf1:WavelengthUserReq", + "det_arc1", + "det_arc2", + "BL3:Det:TH:BL:Frequency", + "BL3:Mot:OpticsPos:Pos", + "det_lin1", + "det_lin2", + "proton_charge", + "gd_prtn_chrg", + ) + RemoveLogs(Workspace=ultralite, KeepLogs=logs) + SaveNexusProcessed( + InputWorkspace=ultralite, + Filename=f"~/Documents/ultralite/{ultralite}.nxs.h5", + CompressNexus=True, + ) From c8603586e0009749185194e4fa6db356b9fbf65a Mon 
Sep 17 00:00:00 2001 From: Reece Boston <52183986+rboston628@users.noreply.github.com> Date: Fri, 20 Dec 2024 10:17:17 -0500 Subject: [PATCH 5/7] Update the use of pointers within certain algorithms (#523) * update the use of pointers within certain algorithms * fix minor defect with toggles --- .../recipe/GenerateFocussedVanadiumRecipe.py | 3 +- .../DiffractionSpectrumWeightCalculator.py | 44 +++++---- .../backend/recipe/algorithm/MantidSnapper.py | 2 +- .../recipe/algorithm/RemoveEventBackground.py | 3 +- .../algorithm/SmoothDataExcludingPeaksAlgo.py | 90 ++++++++++--------- .../backend/service/NormalizationService.py | 5 +- src/snapred/meta/pointer.py | 23 ++++- src/snapred/ui/workflow/DiffCalWorkflow.py | 5 +- .../cis_tests/smooth_data_excluding_peaks.py | 41 ++++++--- ...est_DiffractionSpectrumWeightCalculator.py | 26 +++--- .../algorithm/test_RemoveEventBackground.py | 25 ++---- .../test_SmoothDataExcludingPeaksAlgo.py | 29 ++++-- .../service/test_NormalizationService.py | 4 +- tests/unit/meta/test_pointer.py | 26 +++++- 14 files changed, 191 insertions(+), 135 deletions(-) diff --git a/src/snapred/backend/recipe/GenerateFocussedVanadiumRecipe.py b/src/snapred/backend/recipe/GenerateFocussedVanadiumRecipe.py index 6623b9b81..b29780b19 100644 --- a/src/snapred/backend/recipe/GenerateFocussedVanadiumRecipe.py +++ b/src/snapred/backend/recipe/GenerateFocussedVanadiumRecipe.py @@ -5,7 +5,6 @@ from snapred.backend.recipe.RebinFocussedGroupDataRecipe import RebinFocussedGroupDataRecipe from snapred.backend.recipe.Recipe import Recipe from snapred.meta.decorators.Singleton import Singleton -from snapred.meta.redantic import list_to_raw logger = snapredLogger.getLogger(__name__) @@ -26,7 +25,7 @@ class GenerateFocussedVanadiumRecipe(Recipe[Ingredients]): def chopIngredients(self, ingredients: Ingredients): self.smoothingParameter = ingredients.smoothingParameter - self.detectorPeaks = list_to_raw(ingredients.detectorPeaks) if ingredients.detectorPeaks is not None else None + self.detectorPeaks = ingredients.detectorPeaks self.pixelGroup = ingredients.pixelGroup self.artificialNormalizationIngredients = ingredients.artificialNormalizationIngredients diff --git a/src/snapred/backend/recipe/algorithm/DiffractionSpectrumWeightCalculator.py b/src/snapred/backend/recipe/algorithm/DiffractionSpectrumWeightCalculator.py index ba90aa356..730a64e27 100644 --- a/src/snapred/backend/recipe/algorithm/DiffractionSpectrumWeightCalculator.py +++ b/src/snapred/backend/recipe/algorithm/DiffractionSpectrumWeightCalculator.py @@ -1,14 +1,19 @@ -import json from typing import Dict, List import numpy as np -import pydantic -from mantid.api import AlgorithmFactory, IEventWorkspace, MatrixWorkspaceProperty, PropertyMode, PythonAlgorithm +from mantid.api import IEventWorkspace, MatrixWorkspaceProperty, PropertyMode, PythonAlgorithm from mantid.kernel import Direction +from mantid.kernel import ULongLongPropertyWithValue as PointerProperty +from mantid.simpleapi import ( + CloneWorkspace, + ConvertToEventWorkspace, + ConvertToMatrixWorkspace, + mtd, +) from snapred.backend.dao.GroupPeakList import GroupPeakList from snapred.backend.log.logger import snapredLogger -from snapred.backend.recipe.algorithm.MantidSnapper import MantidSnapper +from snapred.meta.pointer import access_pointer, inspect_pointer logger = snapredLogger.getLogger(__name__) @@ -27,10 +32,12 @@ def PyInit(self): MatrixWorkspaceProperty("WeightWorkspace", "", Direction.Output, PropertyMode.Mandatory), doc="The output workspace to be created by the 
algorithm", ) - self.declareProperty("DetectorPeaks", defaultValue="", direction=Direction.Input) + self.declareProperty( + PointerProperty("DetectorPeaks", id(None)), + doc="The memory address pointing to the list of grouped peaks.", + ) self.setRethrows(True) - self.mantidSnapper = MantidSnapper(self, __name__) def chopIngredients(self, ingredients: List[GroupPeakList]): self.groupIDs = [] @@ -44,25 +51,22 @@ def unbagGroceries(self): self.inputWorkspaceName = self.getPropertyValue("InputWorkspace") self.weightWorkspaceName = self.getPropertyValue("WeightWorkspace") # clone input workspace to create a weight workspace - self.mantidSnapper.CloneWorkspace( - "Cloning a weighting workspce...", + CloneWorkspace( InputWorkspace=self.inputWorkspaceName, OutputWorkspace=self.weightWorkspaceName, ) # check if this is event data, and if so covnert it to historgram data - self.isEventWorkspace = isinstance(self.mantidSnapper.mtd[self.inputWorkspaceName], IEventWorkspace) + self.isEventWorkspace = isinstance(mtd[self.inputWorkspaceName], IEventWorkspace) if self.isEventWorkspace: - self.mantidSnapper.ConvertToMatrixWorkspace( - "Converting event workspace to histogram workspace", + ConvertToMatrixWorkspace( InputWorkspace=self.weightWorkspaceName, OutputWorkspace=self.weightWorkspaceName, ) - self.mantidSnapper.executeQueue() def validateInputs(self) -> Dict[str, str]: errors = {} ws = self.getProperty("InputWorkspace").value - ingredients = json.loads(self.getPropertyValue("DetectorPeaks")) + ingredients = inspect_pointer(self.getProperty("DetectorPeaks").value) if ws.getNumberHistograms() != len(ingredients): msg = f""" Number of histograms {ws.getNumberHistograms()} @@ -73,13 +77,12 @@ def validateInputs(self) -> Dict[str, str]: return errors def PyExec(self): - predictedPeaksList = pydantic.TypeAdapter(List[GroupPeakList]).validate_json( - self.getPropertyValue("DetectorPeaks") - ) + peak_ptr: PointerProperty = self.getProperty("DetectorPeaks").value + predictedPeaksList = access_pointer(peak_ptr) self.chopIngredients(predictedPeaksList) self.unbagGroceries() - weight_ws = self.mantidSnapper.mtd[self.weightWorkspaceName] + weight_ws = mtd[self.weightWorkspaceName] for index, groupID in enumerate(self.groupIDs): # get spectrum X,Y x = weight_ws.readX(index) @@ -95,13 +98,8 @@ def PyExec(self): weight_ws.setY(index, weights) if self.isEventWorkspace: - self.mantidSnapper.ConvertToEventWorkspace( - "Converting histogram workspace back to event workspace", + ConvertToEventWorkspace( InputWorkspace=self.weightWorkspaceName, OutputWorkspace=self.weightWorkspaceName, ) - self.mantidSnapper.executeQueue() self.setPropertyValue("WeightWorkspace", self.weightWorkspaceName) - - -AlgorithmFactory.subscribe(DiffractionSpectrumWeightCalculator) diff --git a/src/snapred/backend/recipe/algorithm/MantidSnapper.py b/src/snapred/backend/recipe/algorithm/MantidSnapper.py index 13b7f8eb4..99b259bf1 100644 --- a/src/snapred/backend/recipe/algorithm/MantidSnapper.py +++ b/src/snapred/backend/recipe/algorithm/MantidSnapper.py @@ -147,7 +147,7 @@ def executeAlgorithm(self, name, outputs, **kwargs): # for pointer property, set via its pointer # allows for "pass-by-reference"-like behavior # this is safe even if the memory address is directly passed - if isinstance(algorithm.getProperty(prop), PointerProperty): + if isinstance(algorithm.getProperty(prop), PointerProperty) and type(val) is not int: val = create_pointer(val) algorithm.setProperty(prop, val) if not algorithm.execute(): diff --git 
a/src/snapred/backend/recipe/algorithm/RemoveEventBackground.py b/src/snapred/backend/recipe/algorithm/RemoveEventBackground.py index 30dcb92ac..2c31ee02a 100644 --- a/src/snapred/backend/recipe/algorithm/RemoveEventBackground.py +++ b/src/snapred/backend/recipe/algorithm/RemoveEventBackground.py @@ -9,7 +9,7 @@ PythonAlgorithm, WorkspaceUnitValidator, ) -from mantid.kernel import Direction +from mantid.kernel import Direction, FloatBoundedValidator from mantid.kernel import ULongLongPropertyWithValue as PointerProperty from mantid.simpleapi import ( ConvertToMatrixWorkspace, @@ -58,6 +58,7 @@ def PyInit(self): "SmoothingParameter", defaultValue=Config["calibration.diffraction.smoothingParameter"], direction=Direction.Input, + validator=FloatBoundedValidator(lower=0.0), ) self.setRethrows(True) diff --git a/src/snapred/backend/recipe/algorithm/SmoothDataExcludingPeaksAlgo.py b/src/snapred/backend/recipe/algorithm/SmoothDataExcludingPeaksAlgo.py index 17fdeed51..9013e40e4 100644 --- a/src/snapred/backend/recipe/algorithm/SmoothDataExcludingPeaksAlgo.py +++ b/src/snapred/backend/recipe/algorithm/SmoothDataExcludingPeaksAlgo.py @@ -10,22 +10,27 @@ create new workspace with csaps data """ -from datetime import datetime from typing import Dict from mantid.api import ( - AlgorithmFactory, IEventWorkspace, MatrixWorkspaceProperty, PropertyMode, PythonAlgorithm, WorkspaceUnitValidator, ) -from mantid.kernel import Direction +from mantid.kernel import Direction, FloatBoundedValidator +from mantid.kernel import ULongLongPropertyWithValue as PointerProperty +from mantid.simpleapi import ( + CloneWorkspace, + ConvertToMatrixWorkspace, + DiffractionSpectrumWeightCalculator, + WashDishes, + mtd, +) from scipy.interpolate import make_smoothing_spline from snapred.backend.log.logger import snapredLogger -from snapred.backend.recipe.algorithm.MantidSnapper import MantidSnapper logger = snapredLogger.getLogger(__name__) @@ -39,9 +44,9 @@ def PyInit(self): self.declareProperty( MatrixWorkspaceProperty( "InputWorkspace", - "", - Direction.Input, - PropertyMode.Mandatory, + defaultValue="", + direction=Direction.Input, + optional=PropertyMode.Mandatory, validator=WorkspaceUnitValidator("dSpacing"), ), doc="Workspace containing the peaks to be removed", @@ -49,17 +54,23 @@ def PyInit(self): self.declareProperty( MatrixWorkspaceProperty( "OutputWorkspace", - "", - Direction.Output, - PropertyMode.Mandatory, + defaultValue="", + direction=Direction.Output, + optional=PropertyMode.Mandatory, validator=WorkspaceUnitValidator("dSpacing"), ), doc="Histogram Workspace with removed peaks", ) - self.declareProperty("DetectorPeaks", defaultValue="", direction=Direction.Input) - self.declareProperty("SmoothingParameter", defaultValue=-1.0, direction=Direction.Input) + self.declareProperty( + PointerProperty("DetectorPeaks", id(None)), + "The memory address pointing to the list of grouped peaks.", + ) + self.declareProperty( + "SmoothingParameter", + defaultValue=-1.0, + validator=FloatBoundedValidator(lower=0.0), + ) self.setRethrows(True) - self.mantidSnapper = MantidSnapper(self, __name__) def chopIngredients(self, ingredients): # noqa ARG002 # NOTE there are no ingredients @@ -67,7 +78,6 @@ def chopIngredients(self, ingredients): # noqa ARG002 def unbagGroceries(self): self.inputWorkspaceName = self.getPropertyValue("InputWorkspace") - self.weightWorkspaceName = datetime.now().ctime() + "_weight_ws" self.outputWorkspaceName = self.getPropertyValue("OutputWorkspace") def validateInputs(self) -> Dict[str, str]: @@ -76,72 
+86,66 @@ def validateInputs(self) -> Dict[str, str]: def PyExec(self): self.log().notice("Removing peaks and smoothing data") - self.lam = float(self.getPropertyValue("SmoothingParameter")) + self.lam = self.getProperty("SmoothingParameter").value self.unbagGroceries() # copy input to make output workspace if self.inputWorkspaceName != self.outputWorkspaceName: - self.mantidSnapper.CloneWorkspace( - "Cloning new workspace for smoothed spectrum data...", + CloneWorkspace( InputWorkspace=self.inputWorkspaceName, OutputWorkspace=self.outputWorkspaceName, ) # check if input is an event workspace - if isinstance(self.mantidSnapper.mtd[self.inputWorkspaceName], IEventWorkspace): + if isinstance(mtd[self.inputWorkspaceName], IEventWorkspace): # convert it to a histogram - self.mantidSnapper.ConvertToMatrixWorkspace( - "Converting event workspace to histogram...", + ConvertToMatrixWorkspace( InputWorkspace=self.outputWorkspaceName, OutputWorkspace=self.outputWorkspaceName, ) # call the diffraction spectrum weight calculator - self.mantidSnapper.DiffractionSpectrumWeightCalculator( - "Calculating spectrum weights...", + weightWSname = mtd.unique_name(prefix="_weight_") + DiffractionSpectrumWeightCalculator( InputWorkspace=self.outputWorkspaceName, - DetectorPeaks=self.getPropertyValue("DetectorPeaks"), - WeightWorkspace=self.weightWorkspaceName, + DetectorPeaks=self.getProperty("DetectorPeaks").value, + WeightWorkspace=weightWSname, ) - self.mantidSnapper.executeQueue() - # get handles to the workspaces - inputWorkspace = self.mantidSnapper.mtd[self.inputWorkspaceName] - outputWorkspace = self.mantidSnapper.mtd[self.outputWorkspaceName] - weightWorkspace = self.mantidSnapper.mtd[self.weightWorkspaceName] + inputWorkspace = mtd[self.inputWorkspaceName] + outputWorkspace = mtd[self.outputWorkspaceName] + weightWorkspace = mtd[weightWSname] numSpec = weightWorkspace.getNumberHistograms() for index in range(numSpec): - x = inputWorkspace.readX(index) - y = inputWorkspace.readY(index) - + # get the weights weightX = weightWorkspace.readX(index) weightY = weightWorkspace.readY(index) + # use the weight midpoint weightXMidpoints = (weightX[:-1] + weightX[1:]) / 2 - xMidpoints = (x[:-1] + x[1:]) / 2 - weightXMidpoints = weightXMidpoints[weightY != 0] + + # get the data for background and remove peaks + y = inputWorkspace.readY(index).copy() y = y[weightY != 0] + x = inputWorkspace.readX(index) + xMidpoints = (x[:-1] + x[1:]) / 2.0 # throw an exception if y or weightXMidpoints are empty if len(y) == 0 or len(weightXMidpoints) == 0: raise ValueError("No data in the workspace, all data removed by peak removal.") + # Generate spline with purged dataset tck = make_smoothing_spline(weightXMidpoints, y, lam=self.lam) # fill in the removed data using the spline function and original datapoints smoothing_results = tck(xMidpoints, extrapolate=False) + smoothing_results[smoothing_results < 0] = 0 outputWorkspace.setY(index, smoothing_results) - self.mantidSnapper.WashDishes( - "Cleaning up weight workspace...", - Workspace=self.weightWorkspaceName, - ) - self.mantidSnapper.executeQueue() - self.setProperty("OutputWorkspace", outputWorkspace) + # cleanup + WashDishes(weightWSname) - -# Register algorithm with Mantid -AlgorithmFactory.subscribe(SmoothDataExcludingPeaksAlgo) + self.setProperty("OutputWorkspace", outputWorkspace) diff --git a/src/snapred/backend/service/NormalizationService.py b/src/snapred/backend/service/NormalizationService.py index 9c20ff0a1..aa7ae9718 100644 --- 
a/src/snapred/backend/service/NormalizationService.py
+++ b/src/snapred/backend/service/NormalizationService.py
@@ -45,6 +45,7 @@
     WorkspaceName,
 )
 from snapred.meta.mantid.WorkspaceNameGenerator import WorkspaceNameGenerator as wng
+from snapred.meta.pointer import create_pointer
 from snapred.meta.redantic import parse_obj_as

 logger = snapredLogger.getLogger(__name__)
@@ -164,7 +165,7 @@ def normalization(self, request: NormalizationRequest):
         # 3. smooth
         SmoothDataExcludingPeaksRecipe().executeRecipe(
             InputWorkspace=focusedVanadium,
-            DetectorPeaks=ingredients.detectorPeaks,
+            DetectorPeaks=create_pointer(ingredients.detectorPeaks),
             SmoothingParameter=request.smoothingParameter,
             OutputWorkspace=smoothedVanadium,
         )
@@ -369,7 +370,7 @@ def smoothDataExcludingPeaks(self, request: SmoothDataExcludingPeaksRequest):
         SmoothDataExcludingPeaksRecipe().executeRecipe(
             InputWorkspace=request.inputWorkspace,
             OutputWorkspace=request.outputWorkspace,
-            DetectorPeaks=peaks,
+            DetectorPeaks=create_pointer(peaks),
             SmoothingParameter=request.smoothingParameter,
         )

diff --git a/src/snapred/meta/pointer.py b/src/snapred/meta/pointer.py
index ddb7f2d86..350a004c8 100644
--- a/src/snapred/meta/pointer.py
+++ b/src/snapred/meta/pointer.py
@@ -15,8 +15,25 @@ def create_pointer(thing: Any) -> int:
     return id(thing)


-def access_pointer(pointer: int) -> Any:
-    thing = ctypes.cast(pointer, ctypes.py_object).value
+def inspect_pointer(pointer: int) -> Any:
+    """
+    Fetch an object referenced by the pointer, without removing it from the cache.
+    Useful for validateInputs with a pointer property.
+    @param the pointer to the object
+    @return the object pointed to
+    """
     if pointer in OBJCACHE:
-        del OBJCACHE[pointer]
+        return ctypes.cast(pointer, ctypes.py_object).value
+    else:
+        raise RuntimeError(f"No appropriate object held at address {hex(pointer)}")
+
+
+def access_pointer(pointer: int) -> Any:
+    """
+    Fetch an object referenced by the pointer, and remove it from the cache.
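+    Note: because the object is evicted from the cache, a given address can be
+    accessed at most once; any later inspect_pointer or access_pointer call on
+    the same address will raise a RuntimeError.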
+ @param the pointer to the object + @return the object pointed to + """ + thing = inspect_pointer(pointer) + del OBJCACHE[pointer] return thing diff --git a/src/snapred/ui/workflow/DiffCalWorkflow.py b/src/snapred/ui/workflow/DiffCalWorkflow.py index 648343393..93e3eb082 100644 --- a/src/snapred/ui/workflow/DiffCalWorkflow.py +++ b/src/snapred/ui/workflow/DiffCalWorkflow.py @@ -143,6 +143,7 @@ def _continueAnywayHandlerTweak(self, continueInfo: ContinueWarning.Model): # n def __setInteraction(self, state: bool): self._requestView.litemodeToggle.setEnabled(state) + self._requestView.skipPixelCalToggle.setEnabled(state) self._requestView.groupingFileDropdown.setEnabled(state) @ExceptionToErrLog @@ -261,10 +262,6 @@ def _specifyRun(self, workflowPresenter): self._renewFitPeaks(self.peakFunction) response = self._calculateResidual() - # freeze these toggles, as they can no longer function - self._requestView.litemodeToggle.setEnabled(False) - self._requestView.skipPixelCalToggle.setEnabled(False) - self._tweakPeakView.updateGraphs( self.focusedWorkspace, self.ingredients.groupedPeakLists, diff --git a/tests/cis_tests/smooth_data_excluding_peaks.py b/tests/cis_tests/smooth_data_excluding_peaks.py index cd68546ca..86219be39 100644 --- a/tests/cis_tests/smooth_data_excluding_peaks.py +++ b/tests/cis_tests/smooth_data_excluding_peaks.py @@ -4,7 +4,7 @@ import snapred.backend.recipe.algorithm -from mantid.simpleapi import Rebin, SmoothDataExcludingPeaksAlgo +from mantid.simpleapi import Rebin, SmoothDataExcludingPeaksAlgo, ConvertUnits, DiffractionFocussing # try to make the logger shutup from snapred.backend.log.logger import snapredLogger @@ -18,13 +18,13 @@ from snapred.backend.dao.ingredients.GroceryListItem import GroceryListItem from snapred.backend.data.GroceryService import GroceryService -from snapred.meta.redantic import list_to_raw +from snapred.meta.pointer import create_pointer #User inputs ########################### runNumber = "58882" #58409 isLite = True groupingScheme = "Column" -cifPath = "/SNS/SNAP/shared/Calibration/CalibrantSamples/Silicon_NIST_640d.cif" +calibrantSamplePath = "Silicon_NIST_640D_001.json" smoothingParameter = 0.05 ####################################### @@ -33,26 +33,41 @@ runNumber = runNumber, useLiteMode=isLite, focusGroups=[{"name": groupingScheme, "definition": ""}], - cifPath=cifPath, + calibrantSamplePath=calibrantSamplePath, ) peaks = SousChef().prepDetectorPeaks(farmFresh) ## FETCH GROCERIES simpleList = GroceryListItem.builder().neutron(runNumber).useLiteMode(isLite).buildList() -grocery = GroceryService().fetchGroceryList(simpleList)[0] + +clerk = GroceryListItem.builder() +clerk.name("inputWorkspace").neutron(runNumber).useLiteMode(isLite).add() +clerk.name("groupingWorkspace").fromRun(runNumber).grouping(groupingScheme).useLiteMode(isLite).add() +groceries = GroceryService().fetchGroceryDict(clerk.buildDict()) + +inputWS = groceries["inputWorkspace"] +focusWS = groceries["groupingWorkspace"] + +## PREPARE +# data must be in units of d-spacing # we must convert the event data to histogram data -# this rebin step will accomplish that, due to PreserveEvents = False -Rebin( - InputWorkspace = grocery, - OutputWorkspace = grocery, - Params = (1,-0.01,1667.7), - PreserveEvents = False, +ConvertUnits( + InputWorkspace = inputWS, + OutputWorkspace="in_ws", + Target="dSpacing", +) +DiffractionFocussing( + InputWorkspace="in_ws", + GroupingWorkspace=focusWS, + OutputWorkspace="in_ws", + PreserveEvents=False, # will convert to histogram ) ## RUN ALGORITHM + 
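+# NOTE on the hand-off (sketch): create_pointer(peaks) caches the peak list and
+# passes only its integer memory address through the algorithm property; the
+# algorithm side retrieves it with access_pointer, so no JSON round-trip is needed.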
assert SmoothDataExcludingPeaksAlgo( - InputWorkspace = grocery, - DetectorPeaks = list_to_raw(peaks), + InputWorkspace = "in_ws", + DetectorPeaks = create_pointer(peaks), SmoothingParameter = smoothingParameter, OutputWorkspace = "out_ws", ) diff --git a/tests/unit/backend/recipe/algorithm/test_DiffractionSpectrumWeightCalculator.py b/tests/unit/backend/recipe/algorithm/test_DiffractionSpectrumWeightCalculator.py index d1935791c..8c68ed746 100644 --- a/tests/unit/backend/recipe/algorithm/test_DiffractionSpectrumWeightCalculator.py +++ b/tests/unit/backend/recipe/algorithm/test_DiffractionSpectrumWeightCalculator.py @@ -1,12 +1,14 @@ import socket import unittest.mock as mock +from typing import List import pytest from mantid.testing import assert_almost_equal from util.diffraction_calibration_synthetic_data import SyntheticData from snapred.backend.dao import CrystallographicPeak, DetectorPeak, GroupPeakList -from snapred.meta.redantic import list_to_raw +from snapred.meta.pointer import create_pointer +from snapred.meta.redantic import parse_file_as with mock.patch.dict( "sys.modules", @@ -64,7 +66,7 @@ def test_chop_ingredients(): # initialize the algo and chop ingredients weightCalculatorAlgo = DiffractionSpectrumWeightCalculator() weightCalculatorAlgo.initialize() - weightCalculatorAlgo.setProperty("DetectorPeaks", list_to_raw(peaks)) + weightCalculatorAlgo.setProperty("DetectorPeaks", create_pointer(peaks)) weightCalculatorAlgo.chopIngredients(peaks) # verify result @@ -157,15 +159,13 @@ def test_validate_fail_wrong_sizes(): ) # create input detector peaks with ONE histogram - peaks_json = list_to_raw( - [GroupPeakList(peaks=SyntheticData.fakeDetectorPeaks(), groupID=i) for i in range(len_peaks)] - ) + peaks = [GroupPeakList(peaks=SyntheticData.fakeDetectorPeaks(), groupID=i) for i in range(len_peaks)] # initialize the algo and try to run -- verify that it fails with error weightCalculatorAlgo = DiffractionSpectrumWeightCalculator() weightCalculatorAlgo.initialize() weightCalculatorAlgo.setProperty("InputWorkspace", input_ws_name) - weightCalculatorAlgo.setProperty("DetectorPeaks", peaks_json) + weightCalculatorAlgo.setProperty("DetectorPeaks", create_pointer(peaks)) weightCalculatorAlgo.setProperty("WeightWorkspace", weight_ws_name) with pytest.raises(RuntimeError) as e: weightCalculatorAlgo.execute() @@ -190,15 +190,13 @@ def test_validate_pass_and_execute(): ) # create input detector peaks - peaks_json = list_to_raw( - [GroupPeakList(peaks=SyntheticData.fakeDetectorPeaks(), groupID=i) for i in range(len_wksp)] - ) + peaks = [GroupPeakList(peaks=SyntheticData.fakeDetectorPeaks(), groupID=i) for i in range(len_wksp)] # initialize and run the weight algo -- verify it runs weightCalculatorAlgo = DiffractionSpectrumWeightCalculator() weightCalculatorAlgo.initialize() weightCalculatorAlgo.setProperty("InputWorkspace", input_ws_name) - weightCalculatorAlgo.setProperty("DetectorPeaks", peaks_json) + weightCalculatorAlgo.setProperty("DetectorPeaks", create_pointer(peaks)) weightCalculatorAlgo.setProperty("WeightWorkspace", weight_ws_name) assert weightCalculatorAlgo.execute() @@ -246,13 +244,13 @@ def test_execute_correct_weights(): for x in range(len(peaks)) if peaks[x] == peak_hi ] - peaks_json = list_to_raw([GroupPeakList(peaks=peakList, groupID=0)]) + peaks = [GroupPeakList(peaks=peakList, groupID=0)] # initialize and run the weight algo weightCalculatorAlgo = DiffractionSpectrumWeightCalculator() weightCalculatorAlgo.initialize() weightCalculatorAlgo.setProperty("InputWorkspace", 
input_ws_name) - weightCalculatorAlgo.setProperty("DetectorPeaks", peaks_json) + weightCalculatorAlgo.setProperty("DetectorPeaks", create_pointer(peaks)) weightCalculatorAlgo.setProperty("WeightWorkspace", weight_ws_name) assert weightCalculatorAlgo.execute() @@ -288,13 +286,13 @@ def test_with_predicted_peaks(): ) # load predicted peaks - peaks_json = Resource.read("inputs/weight_spectra/peaks.json") + peaks = parse_file_as(List[GroupPeakList], Resource.getPath("inputs/weight_spectra/peaks.json")) # initialize and run the weight algo weightCalculatorAlgo = DiffractionSpectrumWeightCalculator() weightCalculatorAlgo.initialize() weightCalculatorAlgo.setProperty("InputWorkspace", input_ws_name) - weightCalculatorAlgo.setProperty("DetectorPeaks", peaks_json) + weightCalculatorAlgo.setProperty("DetectorPeaks", create_pointer(peaks)) weightCalculatorAlgo.setProperty("WeightWorkspace", weight_ws_name) assert weightCalculatorAlgo.execute() diff --git a/tests/unit/backend/recipe/algorithm/test_RemoveEventBackground.py b/tests/unit/backend/recipe/algorithm/test_RemoveEventBackground.py index 68dcbebd3..ed1f97d5c 100644 --- a/tests/unit/backend/recipe/algorithm/test_RemoveEventBackground.py +++ b/tests/unit/backend/recipe/algorithm/test_RemoveEventBackground.py @@ -129,31 +129,22 @@ def test_missing_properties(self): algo.execute() def test_smoothing_parameter_edge_cases(self): - peaks = self.create_test_peaks() ConvertToEventWorkspace( InputWorkspace=self.fakeData, OutputWorkspace=self.fakeData, ) algo = Algo() algo.initialize() - algo.setProperty("InputWorkspace", self.fakeData) - algo.setProperty("GroupingWorkspace", self.fakeGroupingWorkspace) - algo.setProperty("DetectorPeaks", create_pointer(peaks)) - algo.setProperty("SmoothingParameter", 0) - algo.setProperty("OutputWorkspace", "output_test_ws_no_smoothing") - - assert algo.execute() - algo.setProperty("SmoothingParameter", -1) - algo.setProperty("OutputWorkspace", "output_test_ws_negative_smoothing") + # negative values are excluded + with self.assertRaises(ValueError): # noqa: PT027 + algo.setProperty("SmoothingParameter", -1) - with self.assertRaises(RuntimeError): # noqa: PT027 - algo.execute() - - algo.setProperty("SmoothingParameter", 1000) - algo.setProperty("OutputWorkspace", "output_test_ws_large_smoothing") - - assert algo.execute() + # zero is valid, large numbers valid, floats valid + valid_values = [0, 1000, 3.141592] + for value in valid_values: + algo.setProperty("SmoothingParameter", value) + assert value == algo.getProperty("SmoothingParameter").value def test_output_workspace_creation(self): peaks = self.create_test_peaks() diff --git a/tests/unit/backend/recipe/algorithm/test_SmoothDataExcludingPeaksAlgo.py b/tests/unit/backend/recipe/algorithm/test_SmoothDataExcludingPeaksAlgo.py index 76f982065..c661da268 100644 --- a/tests/unit/backend/recipe/algorithm/test_SmoothDataExcludingPeaksAlgo.py +++ b/tests/unit/backend/recipe/algorithm/test_SmoothDataExcludingPeaksAlgo.py @@ -12,7 +12,7 @@ from snapred.backend.dao.request import FarmFreshIngredients from snapred.backend.recipe.algorithm.SmoothDataExcludingPeaksAlgo import SmoothDataExcludingPeaksAlgo as Algo from snapred.meta.Config import Resource -from snapred.meta.redantic import list_to_raw +from snapred.meta.pointer import create_pointer class TestSmoothDataAlgo(unittest.TestCase): @@ -39,16 +39,27 @@ def test_unbag_groceries(self): def test_execute_with_peaks(self): # input data testWS = CreateWorkspace(DataX=[0, 1, 2, 3, 4, 5, 6], DataY=[2, 2, 2, 2, 2, 2], 
UnitX="dSpacing") - jsonString = ( - '[{"groupID": 1, "peaks": [{"position": {"value":1, "minimum":0, "maximum":2},' - ' "peak": {"hkl": [1, 1, 1], "dSpacing": 3.13592994862768,' - '"fSquared": 535.9619564273586, "multiplicity": 8}}]}]' - ) + peaks = [ + { + "groupID": 1, + "peaks": [ + { + "position": {"value": 1, "minimum": 0, "maximum": 2}, + "peak": { + "hkl": [1, 1, 1], + "dSpacing": 3.13592994862768, + "fSquared": 535.9619564273586, + "multiplicity": 8, + }, + } + ], + } + ] algo = Algo() algo.initialize() algo.setProperty("InputWorkspace", testWS) algo.setPropertyValue("OutputWorkspace", "test_out_ws") - algo.setProperty("DetectorPeaks", jsonString) + algo.setProperty("DetectorPeaks", create_pointer(peaks)) algo.setProperty("SmoothingParameter", 0.0) assert algo.execute() @@ -72,6 +83,6 @@ def test_SmoothDataExcludingPeaksAlgo(self): smoothDataAlgo.initialize() smoothDataAlgo.setPropertyValue("InputWorkspace", test_ws_name) smoothDataAlgo.setPropertyValue("OutputWorkspace", "_output") - smoothDataAlgo.setPropertyValue("DetectorPeaks", list_to_raw(peaks)) - smoothDataAlgo.setPropertyValue("SmoothingParameter", "0.9") + smoothDataAlgo.setProperty("DetectorPeaks", create_pointer(peaks)) + smoothDataAlgo.setProperty("SmoothingParameter", 0.9) assert smoothDataAlgo.execute() diff --git a/tests/unit/backend/service/test_NormalizationService.py b/tests/unit/backend/service/test_NormalizationService.py index dbc1076ec..a9c4d4512 100644 --- a/tests/unit/backend/service/test_NormalizationService.py +++ b/tests/unit/backend/service/test_NormalizationService.py @@ -2,7 +2,7 @@ import unittest import unittest.mock as mock from pathlib import Path -from unittest.mock import MagicMock, patch +from unittest.mock import ANY, MagicMock, patch import pytest from mantid.simpleapi import ( @@ -206,7 +206,7 @@ def test_smoothDataExcludingPeaks( mockRecipeInst.executeRecipe.assert_called_once_with( InputWorkspace=mockRequest.inputWorkspace, OutputWorkspace=mockRequest.outputWorkspace, - DetectorPeaks=self.instance.sousChef.prepDetectorPeaks(FarmFreshIngredients.return_value), + DetectorPeaks=ANY, # NOTE it is impossible to see this pointer SmoothingParameter=mockRequest.smoothingParameter, ) diff --git a/tests/unit/meta/test_pointer.py b/tests/unit/meta/test_pointer.py index a4d585ee9..cddfa806e 100644 --- a/tests/unit/meta/test_pointer.py +++ b/tests/unit/meta/test_pointer.py @@ -1,7 +1,9 @@ import ctypes import gc -from snapred.meta.pointer import access_pointer, create_pointer +import pytest + +from snapred.meta.pointer import access_pointer, create_pointer, inspect_pointer def make_a(): @@ -42,3 +44,25 @@ def test_pointer_persistence(): aa = access_pointer(pa) assert aa.__dir__() != {} assert aa == make_a() + + +def test_inspect_pointer(): + """Ensure pointers can be expected and still remain in the queue""" + a = make_a() + pa = create_pointer(a) + aa = inspect_pointer(pa) + aaa = access_pointer(pa) + assert a == aa + assert a == aaa + + +def test_inspect_bad_pointer_error(): + """Ensure accessing a bad pointer raises an error""" + + # create a pointer then access it, which removes it from the cache + a = make_a() + pa = create_pointer(a) + access_pointer(pa) + + with pytest.raises(RuntimeError): + inspect_pointer(pa) From e1d2f3381c9e4f7d6be0d5a703b8576e1bf6beef Mon Sep 17 00:00:00 2001 From: Reece Boston <52183986+rboston628@users.noreply.github.com> Date: Fri, 20 Dec 2024 10:50:08 -0500 Subject: [PATCH 6/7] refactor and rename RemoveEventBackground (#505) --- 
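Outline of the renamed algorithm: RemoveSmoothedBackground diffraction-focuses the
input to get one averaged spectrum per group, smooths that spectrum while excluding
the predicted peak windows, and then subtracts a per-pixel scaled copy of the smooth
background from every detector in the group. A minimal numpy sketch of the
subtraction step (array names here are illustrative, not the algorithm's API):

    import numpy as np

    def subtract_scaled_background(y_pixel, y_group_avg, y_smooth):
        # scale the smoothed group background to this pixel's intensity level
        scale = np.mean(y_pixel) / np.mean(y_group_avg)
        y_new = y_pixel - scale * y_smooth
        y_new[y_new < 0] = 0  # clip unphysical negative counts
        return y_new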
.../backend/recipe/PixelDiffCalRecipe.py | 44 +--
 .../recipe/algorithm/RemoveEventBackground.py | 181 ------------
 .../algorithm/RemoveSmoothedBackground.py | 198 +++++++++++++
 src/snapred/ui/view/DiffCalRequestView.py | 4 +-
 src/snapred/ui/workflow/DiffCalWorkflow.py | 1 +
 tests/cis_tests/cc-and-smooth.py | 272 ++++++++++++++++++
 tests/cis_tests/check_remove.py | 85 ++++++
 ...ffraction_background_subtraction_script.py | 22 +-
 ...nd.py => test_RemoveSmoothedBackground.py} | 74 ++---
 .../backend/recipe/test_PixelDiffCalRecipe.py | 23 +-
 10 files changed, 615 insertions(+), 289 deletions(-)
 delete mode 100644 src/snapred/backend/recipe/algorithm/RemoveEventBackground.py
 create mode 100644 src/snapred/backend/recipe/algorithm/RemoveSmoothedBackground.py
 create mode 100644 tests/cis_tests/cc-and-smooth.py
 create mode 100644 tests/cis_tests/check_remove.py
 rename tests/unit/backend/recipe/algorithm/{test_RemoveEventBackground.py => test_RemoveSmoothedBackground.py} (74%)

diff --git a/src/snapred/backend/recipe/PixelDiffCalRecipe.py b/src/snapred/backend/recipe/PixelDiffCalRecipe.py
index 225098493..091746c72 100644
--- a/src/snapred/backend/recipe/PixelDiffCalRecipe.py
+++ b/src/snapred/backend/recipe/PixelDiffCalRecipe.py
@@ -51,13 +51,14 @@ def chopIngredients(self, ingredients: Ingredients) -> None:
         self.runNumber: str = ingredients.runConfig.runNumber

         # from grouping parameters, read the overall min/max d-spacings
-        dMin = ingredients.pixelGroup.dMin()
-        dMax = ingredients.pixelGroup.dMax()
-        dBin = ingredients.pixelGroup.dBin()
+        dMin: List[float] = ingredients.pixelGroup.dMin()
+        dMax: List[float] = ingredients.pixelGroup.dMax()
+        dBin: List[float] = ingredients.pixelGroup.dBin()
         self.overallDMin: float = min(dMin)
         self.overallDMax: float = max(dMax)
-        self.dBin: float = max([abs(d) for d in dBin])
+        self.dBin: float = min([abs(d) for d in dBin])
         self.dSpaceParams = (self.overallDMin, self.dBin, self.overallDMax)
+        self.tofParams = ingredients.pixelGroup.timeOfFlight.params
         self.removeBackground = ingredients.removeBackground
         self.detectorPeaks = ingredients.groupedPeakLists
         self.threshold = ingredients.convergenceThreshold
@@ -82,7 +83,6 @@ def unbagGroceries(self, groceries: Dict[str, WorkspaceName]) -> None:  # noqa A
         # the name of the output calibration table
         self.DIFCpixel = groceries["calibrationTable"]
         self.DIFCprev = groceries.get("previousCalibration", "")
-        self.isEventWs = self.mantidSnapper.mtd[self.wsTOF].id() == "EventWorkspace"
         # the input data converted to d-spacing
         self.wsDSP = wng.diffCalInputDSP().runNumber(self.runNumber).build()
         self.convertUnitsAndRebin(self.wsTOF, self.wsDSP)
@@ -124,36 +124,20 @@ def stirInputs(self):
         self.mantidSnapper.executeQueue()

     def stripBackground(self, peaks: List[Any], inputWS: WorkspaceName, groupingWS: WorkspaceName):
-        wsBG: str = inputWS + "_bg"
-
-        self.mantidSnapper.CloneWorkspace(
-            "Cloning input workspace for background subtraction",
+        self.mantidSnapper.Rebin(
+            "Rebin the data before removing background",
             InputWorkspace=inputWS,
-            OutPutWorkspace=wsBG,
+            OutputWorkspace=inputWS,
+            Params=self.tofParams,
+            BinningMode="Logarithmic",
         )
-        self.mantidSnapper.RemoveEventBackground(
-            "Extracting background events...",
-            InputWorkspace=wsBG,
-            OutputWorkspace=wsBG,
+        self.mantidSnapper.RemoveSmoothedBackground(
+            "Extracting smoothed background from input data",
+            InputWorkspace=inputWS,
+            OutputWorkspace=inputWS,
             GroupingWorkspace=groupingWS,
             DetectorPeaks=peaks,
         )
-        if self.isEventWs:
-            self.mantidSnapper.ConvertToEventWorkspace(
"Converting TOF data to EventWorkspace...", - InputWorkspace=wsBG, - OutputWorkspace=wsBG, - ) - self.mantidSnapper.Minus( - "Subtracting background from input data", - LHSWorkspace=inputWS, - RHSWorkspace=wsBG, - OutputWorkspace=inputWS, - ) - self.mantidSnapper.WashDishes( - "Delete the background after subtraction", - Workspace=wsBG, - ) def convertUnitsAndRebin(self, inputWS: str, outputWS: str) -> None: """ diff --git a/src/snapred/backend/recipe/algorithm/RemoveEventBackground.py b/src/snapred/backend/recipe/algorithm/RemoveEventBackground.py deleted file mode 100644 index 2c31ee02a..000000000 --- a/src/snapred/backend/recipe/algorithm/RemoveEventBackground.py +++ /dev/null @@ -1,181 +0,0 @@ -from typing import Dict, List, Tuple - -import numpy as np -from mantid.api import ( - AlgorithmFactory, - IEventWorkspaceProperty, - MatrixWorkspaceProperty, - PropertyMode, - PythonAlgorithm, - WorkspaceUnitValidator, -) -from mantid.kernel import Direction, FloatBoundedValidator -from mantid.kernel import ULongLongPropertyWithValue as PointerProperty -from mantid.simpleapi import ( - ConvertToMatrixWorkspace, - ConvertUnits, - GroupedDetectorIDs, - MakeDirtyDish, - mtd, -) -from scipy.interpolate import make_smoothing_spline - -from snapred.backend.dao.GroupPeakList import GroupPeakList -from snapred.meta.Config import Config -from snapred.meta.pointer import access_pointer - - -class RemoveEventBackground(PythonAlgorithm): - def category(self): - return "SNAPRed Data Processing" - - def PyInit(self): - # declare properties - self.declareProperty( - IEventWorkspaceProperty( - "InputWorkspace", "", Direction.Input, PropertyMode.Mandatory, validator=WorkspaceUnitValidator("TOF") - ), - doc="Event workspace containing the data with peaks and background, in TOF units", - ) - self.declareProperty( - MatrixWorkspaceProperty("GroupingWorkspace", "", Direction.Input, PropertyMode.Mandatory), - doc="Workspace holding the detector grouping information, for assigning peak windows", - ) - self.declareProperty( - MatrixWorkspaceProperty( - "OutputWorkspace", - "", - Direction.Output, - validator=WorkspaceUnitValidator("TOF"), - ), - doc="Histogram workspace representing the extracted background", - ) - self.declareProperty( - PointerProperty("DetectorPeaks", id(None)), - "The memory adress pointing to the list of grouped peaks.", - ) - self.declareProperty( - "SmoothingParameter", - defaultValue=Config["calibration.diffraction.smoothingParameter"], - direction=Direction.Input, - validator=FloatBoundedValidator(lower=0.0), - ) - self.setRethrows(True) - - def validateInputs(self) -> Dict[str, str]: - err = {} - if self.getProperty("DetectorPeaks").isDefault: - err["DetectorPeaks"] = "You must pass a pointer to a DetectorPeaks object" - return err - - def chopIngredients(self, predictedPeaksList: List[GroupPeakList]): - # for each group, create a list of regions to mask (i.e., the peak regions) - # these are the ranges between the peak min and max values - self.maskRegions: Dict[int, List[Tuple[float, float]]] = {} - for peakList in predictedPeaksList: - self.maskRegions[peakList.groupID] = [] - for peak in peakList.peaks: - self.maskRegions[peakList.groupID].append((peak.minimum, peak.maximum)) - # get handle to group focusing workspace and retrieve all detector IDs in each group - focusWSname: str = str(self.getPropertyValue("GroupingWorkspace")) - # get a list of the detector IDs in each group - result_ptr: PointerProperty = GroupedDetectorIDs(focusWSname) - self.groupDetectorIDs = 
access_pointer(result_ptr) - self.groupIDs: List[int] = list(self.groupDetectorIDs.keys()) - peakgroupIDs = [peakList.groupID for peakList in predictedPeaksList] - if self.groupIDs != peakgroupIDs: - raise RuntimeError(f"Groups IDs in workspace and peak list do not match: {self.groupIDs} vs {peakgroupIDs}") - - def unbagGroceries(self): - self.inputWorkspaceName = self.getPropertyValue("InputWorkspace") - self.outputBackgroundWorkspaceName = self.getPropertyValue("OutputWorkspace") - self.smoothingParameter = float(self.getPropertyValue("SmoothingParameter")) - - def PyExec(self): - """ - Extracts background from event data by masking the peak regions. - This background can later be subtracted from the main data. - """ - self.log().notice("Extracting background") - - # get the peak predictions from user input - peak_ptr: PointerProperty = self.getProperty("DetectorPeaks").value - predictedPeaksList = access_pointer(peak_ptr) - self.chopIngredients(predictedPeaksList) - self.unbagGroceries() - - # Creating copy of initial TOF data - MakeDirtyDish( - InputWorkspace=self.inputWorkspaceName, - OutputWorkspace=self.inputWorkspaceName + "_extractBegin", - ) - # Convert to d-spacing to match peak windows - ConvertUnits( - InputWorkspace=self.inputWorkspaceName, - OutputWorkspace=self.outputBackgroundWorkspaceName, - Target="dSpacing", - ) - # Creating copy of initial d-spacing data - MakeDirtyDish( - InputWorkspace=self.outputBackgroundWorkspaceName, - OutputWorkspace=self.outputBackgroundWorkspaceName + "_extractDSP", - ) - - # Converting EventWorkspace to MatrixWorkspace... - ConvertToMatrixWorkspace( - InputWorkspace=self.outputBackgroundWorkspaceName, - OutputWorkspace=self.outputBackgroundWorkspaceName, - ) - - # Replace peak regions with interpolated values from surrounding data - ws = mtd[self.outputBackgroundWorkspaceName] - for groupID in self.groupIDs: - for detid in self.groupDetectorIDs[groupID]: - y_data = ws.readY(detid).copy() - x_data = ws.readX(detid) - - for mask in self.maskRegions[groupID]: - mask_indices = (x_data >= mask[0]) & (x_data <= mask[1]) - before_mask = np.where(x_data < mask[0])[0][-1] - after_mask = np.where(x_data > mask[1])[0][0] - - # Linear interpolation across the masked region - interp_values = np.linspace(y_data[before_mask], y_data[after_mask], mask_indices.sum()) - y_data[mask_indices[:-1]] = interp_values - - ws.setY(detid, y_data) - - # Apply smoothing to the entire dataset - self.applySmoothing(ws) - - # Convert back to TOF - ConvertUnits( - InputWorkspace=self.outputBackgroundWorkspaceName, - OutputWorkspace=self.outputBackgroundWorkspaceName, - Target="TOF", - ) - - self.setPropertyValue("OutputWorkspace", self.outputBackgroundWorkspaceName) - - def applySmoothing(self, workspace): - """ - Applies smoothing to the entire workspace data. 
- """ - numSpec = workspace.getNumberHistograms() - - for index in range(numSpec): - x = workspace.readX(index) - y = workspace.readY(index) - - # Apply spline smoothing to the entire dataset - tck = make_smoothing_spline(x[:-1], y, lam=self.smoothingParameter) - y_smooth = tck(x[:-1], extrapolate=False) - - # Ensure no negative values after smoothing - y_smooth[y_smooth < 0] = 0 - - workspace.setY(index, y_smooth) - - -# Register algorithm with Mantid -AlgorithmFactory.subscribe(RemoveEventBackground) diff --git a/src/snapred/backend/recipe/algorithm/RemoveSmoothedBackground.py b/src/snapred/backend/recipe/algorithm/RemoveSmoothedBackground.py new file mode 100644 index 000000000..b014e955d --- /dev/null +++ b/src/snapred/backend/recipe/algorithm/RemoveSmoothedBackground.py @@ -0,0 +1,198 @@ +from typing import Dict, List + +from mantid.api import ( + IEventWorkspaceProperty, + MatrixWorkspaceProperty, + PropertyMode, + PythonAlgorithm, + WorkspaceUnitValidator, +) +from mantid.dataobjects import GroupingWorkspaceProperty +from mantid.kernel import Direction, FloatBoundedValidator +from mantid.kernel import ULongLongPropertyWithValue as PointerProperty +from mantid.simpleapi import ( + ConvertToEventWorkspace, + ConvertToMatrixWorkspace, + ConvertUnits, + DeleteWorkspaces, + GroupDetectors, + GroupedDetectorIDs, + MakeDirtyDish, + SmoothDataExcludingPeaksAlgo, + mtd, +) + +from snapred.backend.dao.GroupPeakList import GroupPeakList +from snapred.meta.Config import Config +from snapred.meta.pointer import access_pointer, create_pointer + + +class RemoveSmoothedBackground(PythonAlgorithm): + def category(self): + return "SNAPRed Data Processing" + + def PyInit(self): + # declare properties + self.declareProperty( + IEventWorkspaceProperty( + "InputWorkspace", + defaultValue="", + direction=Direction.Input, + optional=PropertyMode.Mandatory, + validator=WorkspaceUnitValidator("TOF"), + ), + doc="Event workspace containing the data with peaks and baseline, in TOF units", + ) + self.declareProperty( + GroupingWorkspaceProperty( + "GroupingWorkspace", + defaultValue="", + direction=Direction.Input, + optional=PropertyMode.Mandatory, + ), + doc="Workspace holding the detector grouping information, for assigning peak windows", + ) + self.declareProperty( + MatrixWorkspaceProperty( + "OutputWorkspace", + defaultValue="", + direction=Direction.Output, + validator=WorkspaceUnitValidator("TOF"), + ), + doc="Histogram workspace representing the extracted background", + ) + self.declareProperty( + PointerProperty("DetectorPeaks", id(None)), + doc="The memory adress pointing to the list of grouped peaks.", + ) + self.declareProperty( + "SmoothingParameter", + defaultValue=Config["calibration.diffraction.smoothingParameter"], + validator=FloatBoundedValidator(lower=0.0), + ) + self.setRethrows(True) + + def validateInputs(self) -> Dict[str, str]: + err = {} + if self.getProperty("DetectorPeaks").isDefault: + err["DetectorPeaks"] = "You must pass a pointer to a DetectorPeaks object" + return err + + def chopIngredients(self, predictedPeaksList: List[GroupPeakList]): + self.smoothingParameter = float(self.getProperty("SmoothingParameter").value) + # get handle to group focusing workspace and retrieve all detector IDs in each group + focusWSname: str = str(self.getPropertyValue("GroupingWorkspace")) + # get a list of the detector IDs in each group + result_ptr: PointerProperty = GroupedDetectorIDs(focusWSname) + self.groupDetectorIDs: Dict[int, List[int]] = access_pointer(result_ptr) + self.groupIDs: 
List[int] = list(self.groupDetectorIDs.keys()) + peakgroupIDs: List[int] = [peakList.groupID for peakList in predictedPeaksList] + if self.groupIDs != peakgroupIDs: + raise RuntimeError(f"Groups IDs in workspace and peak list do not match: {self.groupIDs} vs {peakgroupIDs}") + + def unbagGroceries(self): + self.inputWorkspaceName = self.getPropertyValue("InputWorkspace") + self.outputWorkspaceName = self.getPropertyValue("OutputWorkspace") + self.focusWorkspace = self.getPropertyValue("GroupingWorkspace") + self.isEventWs = mtd[self.inputWorkspaceName].id() == "EventWorkspace" + + def PyExec(self): + """ + Extracts background from event data by masking the peak regions. + This background can later be subtracted from the main data. + """ + self.log().notice("Extracting background") + + # get the peak predictions from user input + peak_ptr: PointerProperty = self.getProperty("DetectorPeaks").value + predictedPeaksList = access_pointer(peak_ptr) + self.chopIngredients(predictedPeaksList) + self.unbagGroceries() + + tmpDSPws = mtd.unique_name(prefix="dsp_") + diffocWSname = mtd.unique_name(prefix="diffoc_") + backgroundWSname = mtd.unique_name(prefix="bkgr_") + + # Creating copy of initial TOF data + MakeDirtyDish( + InputWorkspace=self.inputWorkspaceName, + OutputWorkspace=self.inputWorkspaceName + "_extractTOF_before", + ) + # Convert to d-spacing to match peak windows + ConvertUnits( + InputWorkspace=self.inputWorkspaceName, + OutputWorkspace=tmpDSPws, + Target="dSpacing", + ) + # Creating copy of initial d-spacing data + MakeDirtyDish( + InputWorkspace=tmpDSPws, + OutputWorkspace=self.outputWorkspaceName + "_extractDSP_before", + ) + + # find average spectra over focus groups + if self.isEventWs: + ConvertToMatrixWorkspace( + InputWorkspace=tmpDSPws, + OutputWorkspace=tmpDSPws, + ) + GroupDetectors( + InputWorkspace=self.inputWorkspaceName, + CopyGroupingFromWorkspace=self.focusWorkspace, + OutputWorkspace=diffocWSname, + Behaviour="Average", + PreserveEvents=False, + ) + ConvertUnits( + InputWorkspace=diffocWSname, + OutputWorkspace=diffocWSname, + Target="dSpacing", + ) + # create the smoothed background from averaged spectra + SmoothDataExcludingPeaksAlgo( + InputWorkspace=diffocWSname, + OutputWorkspace=backgroundWSname, + DetectorPeaks=create_pointer(predictedPeaksList), + SmoothingParameter=self.smoothingParameter, + ) + + # Subtract off the scaled background estimation + focusWS = mtd[diffocWSname] + smoothWS = mtd[backgroundWSname] + outputWS = mtd[tmpDSPws] + for wkspindx, groupID in enumerate(self.groupIDs): + y_smooth = smoothWS.readY(wkspindx).copy() + y_data = focusWS.readY(wkspindx).copy() + scale_denom = sum(y_data) / len(y_data) + for detid in self.groupDetectorIDs[groupID]: + # subtract off the background and update the date in the workspace + y_new = outputWS.readY(detid).copy() + scale_num = sum(y_new) / len(y_new) + + y_new = y_new - (scale_num / scale_denom) * y_smooth + y_new[y_new < 0] = 0 + outputWS.setY(detid, y_new) + + MakeDirtyDish( + InputWorkspace=tmpDSPws, + OutputWorkspace=self.outputWorkspaceName + "_extractDSP_after", + ) + ConvertUnits( + InputWorkspace=tmpDSPws, + OutputWorkspace=self.outputWorkspaceName, + Target="TOF", + ) + MakeDirtyDish( + InputWorkspace=self.outputWorkspaceName, + OutputWorkspace=self.outputWorkspaceName + "_extractTOF_after", + ) + if self.isEventWs: + ConvertToEventWorkspace( + InputWorkspace=self.outputWorkspaceName, + OutputWorkspace=self.outputWorkspaceName, + ) + + # Cleanup + DeleteWorkspaces(WorkspaceList=[diffocWSname, 
backgroundWSname, tmpDSPws]) + + self.setPropertyValue("OutputWorkspace", self.outputWorkspaceName) diff --git a/src/snapred/ui/view/DiffCalRequestView.py b/src/snapred/ui/view/DiffCalRequestView.py index 7c3775a2e..e79e140f0 100644 --- a/src/snapred/ui/view/DiffCalRequestView.py +++ b/src/snapred/ui/view/DiffCalRequestView.py @@ -1,4 +1,3 @@ -from snapred.meta.Config import Config from snapred.meta.decorators.Resettable import Resettable from snapred.meta.mantid.AllowedPeakTypes import SymmetricPeakEnum from snapred.ui.view.BackendRequestView import BackendRequestView @@ -34,9 +33,8 @@ def __init__(self, samples=[], groups=[], parent=None): self.peakFunctionDropdown = self._sampleDropDown("Peak Function", [p.value for p in SymmetricPeakEnum]) # checkbox for removing background - # NOTE not enabled unless in CIS mode until remove event background is fixed -- then re-enable self.removeBackgroundToggle = self._labeledToggle("RemoveBackground", False) - self.removeBackgroundToggle.setEnabled(Config["cis_mode"]) + self.removeBackgroundToggle.setEnabled(True) # set field properties self.litemodeToggle.setEnabled(True) diff --git a/src/snapred/ui/workflow/DiffCalWorkflow.py b/src/snapred/ui/workflow/DiffCalWorkflow.py index 93e3eb082..df2c3db2b 100644 --- a/src/snapred/ui/workflow/DiffCalWorkflow.py +++ b/src/snapred/ui/workflow/DiffCalWorkflow.py @@ -468,6 +468,7 @@ def _triggerDiffractionCalibration(self, workflowPresenter): self.focusGroupPath = view.groupingFileDropdown.currentText() self.groceries["previousCalibration"] = self.prevDiffCal + # perform the group calibration payload = SimpleDiffCalRequest( ingredients=self.ingredients, groceries=self.groceries, diff --git a/tests/cis_tests/cc-and-smooth.py b/tests/cis_tests/cc-and-smooth.py new file mode 100644 index 000000000..39937b4ef --- /dev/null +++ b/tests/cis_tests/cc-and-smooth.py @@ -0,0 +1,272 @@ +# import mantid algorithms, numpy and matplotlib +import snapred.backend.recipe.algorithm +from mantid.simpleapi import * +import matplotlib.pyplot as plt +import numpy as np +import json +import time +## for creating ingredients +from snapred.backend.dao.request.FarmFreshIngredients import FarmFreshIngredients +from snapred.backend.service.SousChef import SousChef + +## for loading data +from snapred.backend.dao.ingredients.GroceryListItem import GroceryListItem +from snapred.backend.data.GroceryService import GroceryService + +from snapred.meta.Config import Config +from snapred.meta.pointer import create_pointer, access_pointer + +#User input ########################### +runNumber = "58882" +groupingScheme = "Column" +calibrantSamplePath = "Silicon_NIST_640D_001.json" +peakThreshold = 0.05 +offsetConvergenceLimit = 0.1 +isLite = True +Config._config["cis_mode"] = True +####################################### + +### PREP INGREDIENTS ################ +farmFresh = FarmFreshIngredients( + runNumber=runNumber, + useLiteMode=isLite, + focusGroups=[{"name": groupingScheme, "definition": ""}], + # cifPath=cifPath, + calibrantSamplePath=calibrantSamplePath, + peakIntensityThreshold=peakThreshold, + convergenceThreshold=offsetConvergenceLimit, + maxOffset=100.0, +) +pixelGroup = SousChef().prepPixelGroup(farmFresh) +detectorPeaks = SousChef().prepDetectorPeaks(farmFresh) + +total = "total" +background = "background" +peaks = "peaks" +ref = "blanked" + +### FETCH GROCERIES ################## + +clerk = GroceryListItem.builder() +clerk.name("inputWorkspace").neutron(runNumber).useLiteMode(isLite).add() 
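+# also fetch the matching grouping workspace, which defines the focus groups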
+clerk.name("groupingWorkspace").fromRun(runNumber).grouping(groupingScheme).useLiteMode(isLite).add() +groceries = GroceryService().fetchGroceryDict(clerk.buildDict()) + +## UNBAG GROCERIES +inputWorkspace = groceries["inputWorkspace"] +focusWorkspace = groceries["groupingWorkspace"] +Rebin( + InputWorkspace=inputWorkspace, + OutputWorkspace=inputWorkspace, + Params=pixelGroup.timeOfFlight.params, +) + +## CHOP INGREDIENTS +groupIDs = [] +for peakList in detectorPeaks: + groupIDs.append(peakList.groupID) +groupDetectorIDs = access_pointer(GroupedDetectorIDs(focusWorkspace)) + +def getRefID(detectorIDs): + return sorted(detectorIDs)[int(np.round((len(detectorIDs)-1) / 2.0))] + +def performCrossCorrelation(inws): + inws_tof_update = f"{inws}_tof_temp" + inws_dsp_final = f"{inws}_dsp_after" + inws_dsp = f"{inws}_dsp_tmp" + outws = f"{inws}_cc" + difcTable = f"{inws}_difc" + CloneWorkspace( + InputWorkspace=inws, + OutputWorkspace=inws_tof_update, + ) + CalculateDiffCalTable( + InputWorkspace=inws_tof_update, + CalibrationTable=difcTable, + ) + medianOffsets = [100] + while medianOffsets[-1] > 0.5: + ConvertUnits( + InputWorkspace=inws_tof_update, + OutputWorkspace=inws_dsp, + Target="dSpacing", + ) + for i, groupID in enumerate(groupIDs): + workspaceIndices = list(groupDetectorIDs[groupID]) + refID = getRefID(workspaceIndices) + wstemp = f"{outws}_{i}" + CrossCorrelate( + InputWorkspace=inws_dsp, + OutputWorkspace=wstemp, + ReferenceSpectra=refID, + WorkspaceIndexList=workspaceIndices, + XMin = 0.4, + XMax = 4.0, + MaxDSpaceShift=0.1, + ) + if i==0: + CloneWorkspace( + InputWorkspace=wstemp, + OutputWorkspace=outws, + ) + DeleteWorkspace(wstemp) + else: + ConjoinWorkspaces( + InputWorkspace1=outws, + InputWorkspace2=wstemp, + ) + GetDetectorOffsets( + InputWorkspace=outws, + OutputWorkspace=f"{inws}_offset", + MaskWorkspace=f"{inws}_mask", + OffsetMode="Signed", + Xmin=-50, + Xmax=50, + MaxOffset=10, + ) + ConvertDiffCal( + PreviousCalibration=difcTable, + OffsetsWorkspace=f"{inws}_offset", + OutputWorkspace=difcTable, + OffsetMode="Signed", + BinWidth=min(pixelGroup.dBin()), + ) + ApplyDiffCal( + InstrumentWorkspace=inws_tof_update, + CalibrationWorkspace=difcTable, + ) + offsetStats = access_pointer(OffsetStatistics(f"{inws}_offset")) + medianOffsets.append(offsetStats["medianOffset"]) + # process over -- apply DIFC to raw data + CloneWorkspace( + InputWorkspace=inputWorkspace, + OutputWorkspace=inws_dsp_final, + ) + ApplyDiffCal( + InstrumentWorkspace=inws_dsp_final, + CalibrationWorkspace=difcTable, + ) + ConvertUnits( + InputWorkspace=inws_dsp_final, + OutputWorkspace=inws_dsp_final, + Target="dSpacing", + ) + return medianOffsets + + +ConvertUnits( + InputWorkspace=inputWorkspace, + OutputWorkspace="total_dsp_before", + Target="dSpacing", +) + +### NO BACKGROUND REMOVAL ## + +CloneWorkspace( + InputWorkspace=inputWorkspace, + OutputWorkspace=total, +) + +offsets = performCrossCorrelation(total) +print(offsets) + +### REMOVE EVENT BACKGROUND BY BLANKS ## + +""" Logic notes: + Given event data, and a list of known peak windows, remove all events not in a peak window. + The events can be removed with masking. + The peak windows are usually given in d-spacing, so requries first converting units to d-space. + The peak windows are specific to a grouping, so need to act by-group. + On each group, remove non-peak events from all detectors in that group. 
+""" + +# perform the steps of the prototype algo + +blanks = {} +for peakList in detectorPeaks: + blanks[peakList.groupID] = [(0, peakList.peaks[0].minimum)] + for i in range(len(peakList.peaks) - 1): + blanks[peakList.groupID].append((peakList.peaks[i].maximum, peakList.peaks[i+1].minimum)) + blanks[peakList.groupID].append((peakList.peaks[-1].maximum, 10.0)) + +ws = ConvertUnits( + InputWorkspace=inputWorkspace, + OutputWorkspace=ref, + Target="dSpacing", +) +for groupID in groupIDs: + for detid in groupDetectorIDs[groupID]: + event_list = ws.getEventList(detid) + for blank in blanks[groupID]: + event_list.maskTof(blank[0], blank[1]) +ConvertUnits( + InputWorkspace=ref, + OutputWorkspace=ref, + Target="TOF", +) + +blankOffsets = performCrossCorrelation(ref) +print(blankOffsets) + +### REMOVE EVENT BACKGROUND BY SMOOTHING ## + +start = time.time() +RemoveSmoothedBackground( + InputWorkspace=inputWorkspace, + GroupingWorkspace=focusWorkspace, + OutputWorkspace=peaks, + DetectorPeaks = create_pointer(detectorPeaks), + SmoothingParameter=0.5, +) +end = time.time() +print(f"TIME FOR ALGO = {end-start}") + +ConvertUnits( + InputWorkspace=peaks, + OutputWorkspace=peaks, + Target="dSpacing", +) + +removeOffsets = performCrossCorrelation(peaks) +print(removeOffsets) + + +## CONVERT SPECTRUM AXIS +for x in ["total_cc", "peaks_cc", "blanked_cc"]: + ConvertSpectrumAxis( + InputWorkspace=x, + OutputWorkspace=x, + Target="SignedTheta", + ) + +## FOCUS FOR SAKE OF GRAPHING + +for x in ["total_dsp_after", "total_dsp_before", "peaks_dsp_after", "blanked_dsp_after"]: + ConvertSpectrumAxis( + InputWorkspace=x, + Outputworkspace=x, + Target="SignedTheta", + ) + DiffractionFocussing( + InputWorkspace=x, + OutputWorkspace=f"{x}_foc", + GroupingWorkspace=focusWorkspace, + ) + +### PLOT PEAK RESULTS ################################# +fig, ax = plt.subplots(subplot_kw={'projection':'mantid'}) +ax.plot(mtd[f"{total}_dsp_before_foc"], wkspIndex=0, label="Raw", normalize_by_bin_width=True) +ax.plot(mtd[f"{total}_dsp_after_foc"], wkspIndex=0, label="Total Data", normalize_by_bin_width=True) +ax.plot(mtd[f"{ref}_dsp_after_foc"], wkspIndex=0, label="Event Blanking", normalize_by_bin_width=True) +ax.plot(mtd[f"{peaks}_dsp_after_foc"], wkspIndex=0, label="Smoothing Subtraction", normalize_by_bin_width=True) +ax.legend() +fig.show() + + +### PLOT CC RESULTS ################################# +fig, ax = plt.subplots(subplot_kw={'projection':'mantid'}) +ax.plot(mtd["no_removal_foc"], wkspIndex=0, label="Total Data", normalize_by_bin_width=True) +ax.plot(mtd["event_blank_foc"], wkspIndex=0, label="Event Blanking", normalize_by_bin_width=True) +ax.plot(mtd["smoothing_foc"], wkspIndex=0, label="Smoothing Subtraction", normalize_by_bin_width=True) +ax.legend() +fig.show() \ No newline at end of file diff --git a/tests/cis_tests/check_remove.py b/tests/cis_tests/check_remove.py new file mode 100644 index 000000000..b483b72cf --- /dev/null +++ b/tests/cis_tests/check_remove.py @@ -0,0 +1,85 @@ +# import mantid algorithms, numpy and matplotlib +import snapred.backend.recipe.algorithm +from mantid.simpleapi import DiffractionFocussing, Rebin, RemoveSmoothedBackground +import matplotlib.pyplot as plt +import numpy as np +import json +import time +## for creating ingredients +from snapred.backend.dao.request.FarmFreshIngredients import FarmFreshIngredients +from snapred.backend.service.SousChef import SousChef + +## for loading data +from snapred.backend.dao.ingredients.GroceryListItem import GroceryListItem +from 
+
+from snapred.meta.Config import Config
+from snapred.meta.pointer import create_pointer, access_pointer
+from snapred.meta.redantic import list_to_raw
+
+#User input ###########################
+runNumber = "58882"
+groupingScheme = "Column"
+calibrantSamplePath = "Silicon_NIST_640D_001.json"
+peakThreshold = 0.05
+offsetConvergenceLimit = 0.1
+isLite = True
+Config._config["cis_mode"] = True
+#######################################
+
+### PREP INGREDIENTS ################
+
+farmFresh = FarmFreshIngredients(
+    runNumber=runNumber,
+    useLiteMode=isLite,
+    focusGroups=[{"name": groupingScheme, "definition": ""}],
+    calibrantSamplePath=calibrantSamplePath,
+    peakIntensityThreshold=peakThreshold,
+    convergenceThreshold=offsetConvergenceLimit,
+    maxOffset=100.0,
+)
+pixelGroup = SousChef().prepPixelGroup(farmFresh)
+detectorPeaks = SousChef().prepDetectorPeaks(farmFresh)
+
+peaks = "peaks"
+
+### FETCH GROCERIES ##################
+
+clerk = GroceryListItem.builder()
+clerk.name("inputWorkspace").neutron(runNumber).useLiteMode(isLite).add()
+clerk.name("groupingWorkspace").fromRun(runNumber).grouping(groupingScheme).useLiteMode(isLite).add()
+groceries = GroceryService().fetchGroceryDict(clerk.buildDict())
+
+## UNBAG GROCERIES
+
+inputWorkspace = groceries["inputWorkspace"]
+focusWorkspace = groceries["groupingWorkspace"]
+Rebin(
+    InputWorkspace=inputWorkspace,
+    OutputWorkspace=inputWorkspace,
+    Params=pixelGroup.timeOfFlight.params,
+)
+
+### REMOVE EVENT BACKGROUND BY SMOOTHING ##
+
+start = time.time()
+RemoveSmoothedBackground(
+    InputWorkspace=inputWorkspace,
+    GroupingWorkspace=focusWorkspace,
+    OutputWorkspace=peaks,
+    DetectorPeaks = create_pointer(detectorPeaks),
+    SmoothingParameter=0.5,
+)
+end = time.time()
+print(f"TIME FOR ALGO = {end-start}")
+
+DiffractionFocussing(
+    InputWorkspace="peaks_extractDSP_before",
+    OutputWorkspace="peaks_extractDSP_before_foc",
+    GroupingWorkspace=focusWorkspace,
+)
+DiffractionFocussing(
+    InputWorkspace="peaks_extractDSP_after",
+    OutputWorkspace="peaks_extractDSP_after_foc",
+    GroupingWorkspace=focusWorkspace,
+)
\ No newline at end of file
diff --git a/tests/cis_tests/diffcal_pixel_diffraction_background_subtraction_script.py b/tests/cis_tests/diffcal_pixel_diffraction_background_subtraction_script.py
index 98ef358f6..02c12deac 100644
--- a/tests/cis_tests/diffcal_pixel_diffraction_background_subtraction_script.py
+++ b/tests/cis_tests/diffcal_pixel_diffraction_background_subtraction_script.py
@@ -1,8 +1,10 @@
 # Use this script to test Pixel Diffraction Background Subtraction
+import snapred.backend.recipe.algorithm
 from mantid.simpleapi import *
 import matplotlib.pyplot as plt
 import numpy as np
 import json
+import time
 
 from typing import List
 
@@ -18,18 +20,18 @@
 from snapred.backend.recipe.PixelDiffCalRecipe import PixelDiffCalRecipe as PixelRx
 
 from snapred.meta.Config import Config
+from snapred.meta.pointer import create_pointer
 
 #User input ###########################
 runNumber = "58882"
 groupingScheme = "Column"
-cifPath = "/SNS/SNAP/shared/Calibration/CalibrantSamples/cif/Silicon_NIST_640d.cif"
 calibrantSamplePath = "Silicon_NIST_640D_001.json"
 peakThreshold = 0.05
 offsetConvergenceLimit = 0.1
 isLite = True
 removeBackground = True
 Config._config["cis_mode"] = True
-Config._config["diffraction.smoothingParameter"] = 0.01 #This is the smoothing parameter to be set.
+Config._config["diffraction.smoothingParameter"] = 0.5 #This is the smoothing parameter to be set.
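+# NOTE 0.5 matches the SmoothingParameter passed to RemoveSmoothedBackground in the CIS scripts above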
 #######################################
 
 ### PREP INGREDIENTS ################
@@ -37,8 +39,8 @@
     runNumber=runNumber,
     useLiteMode=isLite,
     focusGroups=[{"name": groupingScheme, "definition": ""}],
-    cifPath=cifPath,
     calibrantSamplePath=calibrantSamplePath,
+    peakIntensityThreshold=peakThreshold,
     convergenceThreshold=offsetConvergenceLimit,
     maxOffset=100.0,
 )
@@ -57,18 +59,17 @@
     outputWorkspace="_out_",
     diagnosticWorkspace="_diag",
     maskWorkspace="_mask_",
-    calibrationTable="_DIFC_", 
+    calibrationTable="_DIFC_",
 )
 
 ### RUN PIXEL CALIBRATION ##########
 
 pixelRx = PixelRx()
-pixelRx.prep(ingredients, groceries)
-pixelRes = pixelRx.execute()
+pixelRx.cook(ingredients, groceries)
 
 ### PREPARE OUTPUTS ################
 DiffractionFocussing(
-    InputWorkspace=f"dsp_0{runNumber}_raw_startOfPixelDiffCal",
+    InputWorkspace=f"dsp_0{runNumber}_raw_beforeCrossCor",
     OutputWorkspace="BEFORE_REMOVAL",
     GroupingWorkspace=groceries["groupingWorkspace"],
 )
@@ -77,9 +78,4 @@
     OutputWorkspace="AFTER_REMOVAL",
     GroupingWorkspace=groceries["groupingWorkspace"],
 )
-DiffractionFocussing(
-    InputWorkspace="tof_all_lite_copy1_058882",
-    OutputWorkspace="FINAL",
-    GroupingWorkspace=groceries["groupingWorkspace"],
-)
-
+
diff --git a/tests/unit/backend/recipe/algorithm/test_RemoveEventBackground.py b/tests/unit/backend/recipe/algorithm/test_RemoveSmoothedBackground.py
similarity index 74%
rename from tests/unit/backend/recipe/algorithm/test_RemoveEventBackground.py
rename to tests/unit/backend/recipe/algorithm/test_RemoveSmoothedBackground.py
index ed1f97d5c..6c25940b3 100644
--- a/tests/unit/backend/recipe/algorithm/test_RemoveEventBackground.py
+++ b/tests/unit/backend/recipe/algorithm/test_RemoveSmoothedBackground.py
@@ -2,17 +2,18 @@
 
 from mantid.simpleapi import (
     ConvertToEventWorkspace,
+    Rebin,
     mtd,
 )
 from util.diffraction_calibration_synthetic_data import SyntheticData
 
 from snapred.backend.dao.GroupPeakList import GroupPeakList
-from snapred.backend.recipe.algorithm.RemoveEventBackground import RemoveEventBackground as Algo
+from snapred.backend.recipe.algorithm.RemoveSmoothedBackground import RemoveSmoothedBackground as Algo
 from snapred.backend.recipe.algorithm.Utensils import Utensils
 from snapred.meta.pointer import create_pointer
 
 
-class TestRemoveEventBackground(unittest.TestCase):
+class TestRemoveSmoothedBackground(unittest.TestCase):
     def setUp(self):
         inputs = SyntheticData()
         self.fakeIngredients = inputs.ingredients
@@ -22,6 +23,16 @@ def setUp(self):
         self.fakeGroupingWorkspace = f"_test_remove_event_background_{runNumber}_grouping"
         self.fakeMaskWorkspace = f"_test_remove_event_background_{runNumber}_mask"
         inputs.generateWorkspaces(self.fakeData, self.fakeGroupingWorkspace, self.fakeMaskWorkspace)
+        # this algorithm requires event workspaces
+        ConvertToEventWorkspace(
+            InputWorkspace=self.fakeData,
+            OutputWorkspace=self.fakeData,
+        )
+        Rebin(
+            InputWorkspace=self.fakeData,
+            OutputWorkspace=self.fakeData,
+            Params=self.fakeIngredients.pixelGroup.timeOfFlight.params,
+        )
 
     def tearDown(self) -> None:
         mtd.clear()
@@ -37,6 +48,12 @@ def create_test_peaks(self):
         ]
         return peaks
 
+    def test_validate_inputs(self):
+        algo = Algo()
+        algo.initialize()
+        err = algo.validateInputs()
+        assert "DetectorPeaks" in err
+
     def test_chop_ingredients(self):
         peaks = self.create_test_peaks()
 
@@ -49,38 +66,13 @@ def test_chop_ingredients(self):
 
         for peakList in peaks:
             groupID = peakList.groupID
-            assert groupID in algo.maskRegions, f"Group ID {groupID} not found in maskRegions"
-
-            expected_peak_count = len(peakList.peaks)
-
actual_peak_count = len(algo.maskRegions[groupID]) - assert ( - actual_peak_count == expected_peak_count - ), f"Mismatch in number of peaks for group {groupID}: expected {expected_peak_count}, found {actual_peak_count}" # noqa: E501 - - for peak, mask in zip(peakList.peaks, algo.maskRegions[groupID]): - assert ( - mask == (peak.minimum, peak.maximum) - ), f"Mask region mismatch for group {groupID}, peak {peak}: expected {(peak.minimum, peak.maximum)}, found {mask}" # noqa: E501 + assert groupID in algo.groupIDs, f"Group ID {groupID} not found in maskRegions" expected_group_ids = [peakList.groupID for peakList in peaks] assert ( algo.groupIDs == expected_group_ids ), f"Group IDs in workspace and peak list do not match: {algo.groupIDs} vs {expected_group_ids}" - def test_execute(self): - peaks = self.create_test_peaks() - ConvertToEventWorkspace( - InputWorkspace=self.fakeData, - OutputWorkspace=self.fakeData, - ) - algo = Algo() - algo.initialize() - algo.setProperty("InputWorkspace", self.fakeData) - algo.setProperty("GroupingWorkspace", self.fakeGroupingWorkspace) - algo.setProperty("DetectorPeaks", create_pointer(peaks)) - algo.setProperty("OutputWorkspace", "output_test_ws") - assert algo.execute() - def test_incorrect_group_ids(self): peaks = [ GroupPeakList(peaks=SyntheticData.fakeDetectorPeaks(), groupID=999), # Incorrect group ID @@ -107,11 +99,6 @@ def test_missing_properties(self): with self.assertRaises(RuntimeError): # noqa: PT027 algo.execute() - ConvertToEventWorkspace( - InputWorkspace=self.fakeData, - OutputWorkspace=self.fakeData, - ) - algo = Algo() algo.initialize() algo.setProperty("InputWorkspace", self.fakeData) @@ -129,10 +116,6 @@ def test_missing_properties(self): algo.execute() def test_smoothing_parameter_edge_cases(self): - ConvertToEventWorkspace( - InputWorkspace=self.fakeData, - OutputWorkspace=self.fakeData, - ) algo = Algo() algo.initialize() @@ -146,35 +129,22 @@ def test_smoothing_parameter_edge_cases(self): algo.setProperty("SmoothingParameter", value) assert value == algo.getProperty("SmoothingParameter").value - def test_output_workspace_creation(self): + def test_execute(self): peaks = self.create_test_peaks() - - ConvertToEventWorkspace( - InputWorkspace=self.fakeData, - OutputWorkspace=self.fakeData, - ) - algo = Algo() algo.initialize() algo.setProperty("InputWorkspace", self.fakeData) algo.setProperty("GroupingWorkspace", self.fakeGroupingWorkspace) algo.setProperty("DetectorPeaks", create_pointer(peaks)) algo.setProperty("OutputWorkspace", "output_test_ws") - assert algo.execute() - assert "output_test_ws" in mtd, "Output workspace not found in the Mantid workspace dictionary" - output_ws = mtd["output_test_ws"] # noqa: F841 def test_execute_from_mantidSnapper(self): peaks = self.create_test_peaks() - ConvertToEventWorkspace( - InputWorkspace=self.fakeData, - OutputWorkspace=self.fakeData, - ) utensils = Utensils() utensils.PyInit() - utensils.mantidSnapper.RemoveEventBackground( + utensils.mantidSnapper.RemoveSmoothedBackground( "Run in mantid snapper", InputWorkspace=self.fakeData, GroupingWorkspace=self.fakeGroupingWorkspace, diff --git a/tests/unit/backend/recipe/test_PixelDiffCalRecipe.py b/tests/unit/backend/recipe/test_PixelDiffCalRecipe.py index 761395a0c..1333557cf 100644 --- a/tests/unit/backend/recipe/test_PixelDiffCalRecipe.py +++ b/tests/unit/backend/recipe/test_PixelDiffCalRecipe.py @@ -13,20 +13,12 @@ from snapred.backend.recipe.PixelDiffCalRecipe import PixelDiffCalRecipe as Recipe from snapred.meta.Config import Config -""" -NOTE 
this is in fact a test of a recipe. Its location and name are a -TEMPORARY assignment as part of a refactor. This helps the git diff -be as useful as possible to reviewing devs. -As soon as the change with this string is merged, this file can be -renamed to `test_PixelDiffCalReipe.py` and moved to the recipe tests folder -""" - class TestPixelDiffCalRecipe(unittest.TestCase): def setUp(self): """Create a set of mocked ingredients for calculating DIFC corrected by offsets""" inputs = SyntheticData() - self.ingredients = inputs.ingredients + self.ingredients = inputs.ingredients.copy() runNumber = self.ingredients.runConfig.runNumber fakeRawData = f"_test_pixelcal_{runNumber}" @@ -53,7 +45,18 @@ def test_chop_ingredients(self): assert rx.runNumber == self.ingredients.runConfig.runNumber assert rx.overallDMin == min(self.ingredients.pixelGroup.dMin()) assert rx.overallDMax == max(self.ingredients.pixelGroup.dMax()) - assert rx.dBin == max([abs(db) for db in self.ingredients.pixelGroup.dBin()]) + assert rx.dBin == min([abs(db) for db in self.ingredients.pixelGroup.dBin()]) + + def test_removeBackground(self): + ingredients = self.ingredients.copy() + ingredients.removeBackground = True + + rx = Recipe() + rx.chopIngredients(ingredients) + rx.unbagGroceries(self.groceries) + algoQueue = rx.mantidSnapper._algorithmQueue + algoNames = [x[0] for x in algoQueue] + assert "RemoveSmoothedBackground" in algoNames def test_execute(self): """Test that the algorithm executes""" From eb12b3e0248baf9723090b67deed8f76ee52ea04 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 20 Dec 2024 11:07:50 -0500 Subject: [PATCH 7/7] [pre-commit.ci] pre-commit autoupdate (#518) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit updates: - [github.com/astral-sh/ruff-pre-commit: v0.8.1 → v0.8.3](https://github.com/astral-sh/ruff-pre-commit/compare/v0.8.1...v0.8.3) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 46a5eee4c..107f2d6cc 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -16,7 +16,7 @@ repos: - id: trailing-whitespace exclude: "tests/cis_tests/.*" - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.8.1 + rev: v0.8.3 hooks: - id: ruff args: [--fix, --exit-non-zero-on-fix]