Skip to content

Commit

Permalink
Historical SNAPInstPrm (#540)
Browse files Browse the repository at this point in the history
* refactor indexer to rounte saves and loads through a generic VersionedObject form of each method

integrate new lookup, passing tests

fix some tests, break others maybe

fixes from manual testing and script for initializing the new indexed snap inst prm

Cleanup some names

added some additional test coverage of new code

always pull new instrument state when writing new calibrations

fix tests

point at new data commit for changes

up coverage a bit

last bit of coverage?

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix missed rebase conflict?

* fix tests, give validate_arguments better stack trace

* take a guess at why remote is failing but not local?

* another change to try and fix the remote weirdness

* update in response to comments

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
walshmm and pre-commit-ci[bot] authored Feb 19, 2025
1 parent 4b6648a commit 677c5b2
Show file tree
Hide file tree
Showing 33 changed files with 439 additions and 246 deletions.
6 changes: 2 additions & 4 deletions src/snapred/backend/dao/InstrumentConfig.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
from pydantic import BaseModel

from snapred.backend.dao.indexing.Versioning import VersionedObject
from snapred.meta.Config import Config


class InstrumentConfig(BaseModel):
class InstrumentConfig(VersionedObject):
"""Class to hold the instrument parameters."""

version: str
facility: str
name: str
nexusFileExtension: str
Expand Down
32 changes: 19 additions & 13 deletions src/snapred/backend/dao/indexing/IndexEntry.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,20 @@ class IndexEntry(VersionedObject, extra="ignore"):
author: Optional[str] = None
timestamp: float = Field(default_factory=lambda: time.time())

def parseAppliesTo(appliesTo: str):
@classmethod
def parseConditional(cls, conditional: str):
symbols = [">=", "<=", "<", ">"]
# find first
symbol = next((s for s in symbols if s in appliesTo), "")
symbol = next((s for s in symbols if s in conditional), "")
# parse runnumber
runNumber = appliesTo if symbol == "" else appliesTo.split(symbol)[-1]
runNumber = conditional if symbol == "" else conditional.split(symbol)[-1]
return symbol, runNumber

@classmethod
def parseAppliesTo(cls, appliesTo: str):
conditionals = appliesTo.split(",")
return [cls.parseConditional(c.strip()) for c in conditionals]

@field_validator("appliesTo", mode="before")
def appliesToFormatChecker(cls, v):
"""
Expand All @@ -44,16 +50,16 @@ def appliesToFormatChecker(cls, v):
"""
testValue = v
if testValue is not None:
symbol, _ = cls.parseAppliesTo(v)
if symbol != "":
testValue = testValue.split(symbol)[-1]
try:
int(testValue)
except ValueError:
raise ValueError(
"appliesTo must be in the format of 'runNumber',"
"or '{{symbol}}runNumber' where symbol is one of '>', '<', '>=', '<='.."
)
conditionals = cls.parseAppliesTo(v)
for _, runNumber in conditionals:
try:
# if runnumber isnt just an int, there were extra unrecognized characters
int(runNumber)
except ValueError:
raise ValueError(
"appliesTo must be in the format of 'runNumber',"
"or '{{symbol}}runNumber, ...' where symbol is one of '>', '<', '>=', '<='.."
)

return v

Expand Down
2 changes: 2 additions & 0 deletions src/snapred/backend/dao/indexing/Versioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ class VersionState(StrEnum):
DEFAULT = Config["version.friendlyName.default"]
LATEST = "latest"
NEXT = "next"
# NOTE: This is only so we may read old saved IntrumentConfigs
LEGACY_INST_PRM = "1.4"


# I'm not sure why ci is failing without this, it doesn't seem to be used anywhere
Expand Down
2 changes: 1 addition & 1 deletion src/snapred/backend/data/DataFactoryService.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def getRunConfig(self, runId: str) -> RunConfig: # noqa: ARG002
return self.lookupService.readRunConfig(runId)

def getInstrumentConfig(self, runId: str) -> InstrumentConfig: # noqa: ARG002
return self.lookupService.getInstrumentConfig()
return self.lookupService.readInstrumentConfig(runId)

def getStateConfig(self, runId: str, useLiteMode: bool) -> StateConfig: # noqa: ARG002
return self.lookupService.readStateConfig(runId, useLiteMode)
Expand Down
32 changes: 24 additions & 8 deletions src/snapred/backend/data/GroceryService.py
Original file line number Diff line number Diff line change
Expand Up @@ -806,7 +806,7 @@ def fetchNeutronDataSingleUse(self, item: GroceryListItem) -> Dict[str, Any]:

case (True, _, True, _):
# lite mode and lite-mode exists on disk
data = self.grocer.executeRecipe(str(liteModeFilePath), workspaceName, loader)
data = self.grocer.executeRecipe(str(liteModeFilePath), workspaceName, loader, runNumber=runNumber)
success = True

case (True, True, _, _):
Expand All @@ -821,13 +821,17 @@ def fetchNeutronDataSingleUse(self, item: GroceryListItem) -> Dict[str, Any]:

case (True, _, _, True):
# lite mode and native exists on disk
data = self.grocer.executeRecipe(str(nativeModeFilePath), workspaceName, loader)
data = self.grocer.executeRecipe(
str(nativeModeFilePath), workspaceName, loader, runNumber=runNumber
)
convertToLiteMode = True
success = True

case (False, _, _, True):
# native mode and native exists on disk
data = self.grocer.executeRecipe(str(nativeModeFilePath), workspaceName, loader)
data = self.grocer.executeRecipe(
str(nativeModeFilePath), workspaceName, loader, runNumber=runNumber
)
success = True

case _:
Expand All @@ -850,7 +854,10 @@ def fetchNeutronDataSingleUse(self, item: GroceryListItem) -> Dict[str, Any]:
"StartTime": startTime,
}
data = self.grocer.executeRecipe(
workspace=workspaceName, loader="LoadLiveData", loaderArgs=json.dumps(loaderArgs)
workspace=workspaceName,
loader="LoadLiveData",
loaderArgs=json.dumps(loaderArgs),
runNumber=runNumber,
)
if data["result"]:
logs = self.mantidSnapper.mtd[workspaceName].getRun()
Expand Down Expand Up @@ -928,7 +935,9 @@ def fetchNeutronDataCached(self, item: GroceryListItem) -> Dict[str, Any]:

case (True, _, True, _):
# lite mode and lite-mode exists on disk
data = self.grocer.executeRecipe(str(liteModeFilePath), rawWorkspaceName, loader)
data = self.grocer.executeRecipe(
str(liteModeFilePath), rawWorkspaceName, loader, runNumber=runNumber
)
self._loadedRuns[key] = 0
success = True

Expand All @@ -941,14 +950,18 @@ def fetchNeutronDataCached(self, item: GroceryListItem) -> Dict[str, Any]:
case (True, _, _, True):
# lite mode and native exists on disk
goingNative = self._key(runNumber, False)
data = self.grocer.executeRecipe(str(nativeModeFilePath), nativeRawWorkspaceName, loader="")
data = self.grocer.executeRecipe(
str(nativeModeFilePath), nativeRawWorkspaceName, loader="", runNumber=runNumber
)
self._loadedRuns[self._key(*goingNative)] = 0
convertToLiteMode = True
success = True

case (False, _, _, True):
# native mode and native exists on disk
data = self.grocer.executeRecipe(str(nativeModeFilePath), nativeRawWorkspaceName, loader)
data = self.grocer.executeRecipe(
str(nativeModeFilePath), nativeRawWorkspaceName, loader, runNumber=runNumber
)
self._loadedRuns[key] = 0
success = True

Expand All @@ -973,7 +986,10 @@ def fetchNeutronDataCached(self, item: GroceryListItem) -> Dict[str, Any]:
"StartTime": startTime,
}
data = self.grocer.executeRecipe(
workspace=nativeRawWorkspaceName, loader="LoadLiveData", loaderArgs=json.dumps(loaderArgs)
workspace=nativeRawWorkspaceName,
loader="LoadLiveData",
loaderArgs=json.dumps(loaderArgs),
runNumber=runNumber,
)
if data["result"]:
logs = self.mantidSnapper.mtd[nativeRawWorkspaceName].getRun()
Expand Down
111 changes: 73 additions & 38 deletions src/snapred/backend/data/Indexer.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import os
import sys
from pathlib import Path
from typing import Dict, List
from typing import Dict, List, Type, TypeVar

from pydantic import validate_call

from snapred.backend.dao import InstrumentConfig
from snapred.backend.dao.calibration.Calibration import Calibration
from snapred.backend.dao.calibration.CalibrationRecord import CalibrationDefaultRecord, CalibrationRecord
from snapred.backend.dao.indexing.CalculationParameters import CalculationParameters
Expand All @@ -21,6 +22,7 @@

logger = snapredLogger.getLogger(__name__)

T = TypeVar("T", bound=VersionedObject)

"""
The Indexer will automatically track versions and produce the next and current versions.
Expand All @@ -44,13 +46,15 @@ class IndexerType(StrEnum):
CALIBRATION = "Calibration"
NORMALIZATION = "Normalization"
REDUCTION = "Reduction"
INSTRUMENT_PARAMETER = "InstrumentParameter"


# the record type for each indexer type
RECORD_TYPE = {
IndexerType.CALIBRATION: CalibrationRecord,
IndexerType.NORMALIZATION: NormalizationRecord,
IndexerType.REDUCTION: ReductionRecord,
IndexerType.INSTRUMENT_PARAMETER: None,
IndexerType.DEFAULT: Record,
}

Expand All @@ -66,6 +70,17 @@ class IndexerType(StrEnum):
IndexerType.DEFAULT: CalculationParameters,
}

FRIENDLY_NAME_MAPPING = {
Calibration.__name__: "CalibrationParameters",
CalibrationDefaultRecord.__name__: "CalibrationRecord",
CalibrationRecord.__name__: "CalibrationRecord",
Normalization.__name__: "NormalizationParameters",
NormalizationRecord.__name__: "NormalizationRecord",
ReductionRecord.__name__: "ReductionRecord",
InstrumentConfig.__name__: "SNAPInstPrm",
CalculationParameters.__name__: "CalculationParameters",
}


class Indexer:
rootDirectory: Path
Expand Down Expand Up @@ -218,9 +233,11 @@ def _isApplicableEntry(self, entry: IndexEntry, runNumber1: str):
"""
Checks to see if an entry in the index applies to a given run id via numerical comparison.
"""

symbol, runNumber2 = self._parseAppliesTo(entry.appliesTo)
return self._compareRunNumbers(runNumber1, runNumber2, symbol)
isApplicable = True
conditionals = self._parseAppliesTo(entry.appliesTo)
for symbol, runNumber2 in conditionals:
isApplicable = isApplicable and self._compareRunNumbers(runNumber1, runNumber2, symbol)
return isApplicable

def _parseAppliesTo(self, appliesTo: str):
return IndexEntry.parseAppliesTo(appliesTo)
Expand Down Expand Up @@ -330,15 +347,8 @@ def readRecord(self, version: int) -> Record:
"""
If no version given, defaults to current version
"""
filePath = self.recordPath(version)
record = None
if filePath.exists():
record = parse_file_as(self._determineRecordType(version), filePath)
else:
raise FileNotFoundError(
f"No record found at {filePath} for version {version}, latest version is {self.currentVersion()}"
)
return record
recordType = self._determineRecordType(version)
return self.readVersionedObject(recordType, version)

def _flattenVersion(self, version: Version):
"""
Expand Down Expand Up @@ -382,24 +392,64 @@ def writeNewVersion(self, record: Record, entry: IndexEntry):
self.addIndexEntry(entry)
# make sure they flatten to the same value.
record.version = entry.version
record.calculationParameters.version = entry.version
self.writeRecord(record)

def writeRecord(self, record: Record):
def versionedObjectPath(self, type_: Type[T], version: Version):
"""
Will save at the version on the record.
Path to a specific version of a calculation record
"""
fileName = FRIENDLY_NAME_MAPPING.get(type_.__name__, type_.__name__)
return self.versionPath(version) / f"{fileName}.json"

def writeNewVersionedObject(self, obj: VersionedObject, entry: IndexEntry):
"""
Coupled write of parameters and an index entry.
As required for new parameters.
"""
if self.versionExists(obj.version):
raise ValueError(f"Version {obj.version} already exists in index, please write a new version.")

self.addIndexEntry(entry)
# make sure they flatten to the same value.
obj.version = entry.version
self.writeVersionedObject(obj)

def writeVersionedObject(self, obj: VersionedObject):
"""
Will save at the version on the object.
If the version is invalid, will throw an error and refuse to save.
"""
record.version = self._flattenVersion(record.version)
obj.version = self._flattenVersion(obj.version)

if not self.versionExists(record.version):
raise ValueError(f"Version {record.version} not found in index, please write an index entry first.")
if not self.versionExists(obj.version):
raise ValueError(f"Version {obj.version} not found in index, please write an index entry first.")

filePath = self.recordPath(record.version)
filePath = self.versionedObjectPath(type(obj), obj.version)
filePath.parent.mkdir(parents=True, exist_ok=True)

write_model_pretty(record, filePath)
write_model_pretty(obj, filePath)

self.dirVersions.add(record.version)
self.dirVersions.add(obj.version)

def readVersionedObject(self, type_: Type[T], version: Version) -> VersionedObject:
"""
If no version given, defaults to current version
"""
filePath = self.versionedObjectPath(type_, version)
obj = None
if filePath.exists():
obj = parse_file_as(type_, filePath)
else:
raise FileNotFoundError(f"No {type_.__name__} found at {filePath} for version {version}")
return obj

def writeRecord(self, record: Record):
"""
Will save at the version on the record.
If the version is invalid, will throw an error and refuse to save.
"""
self.writeVersionedObject(record)

## STATE PARAMETER READ / WRITE METHODS ##

Expand All @@ -413,26 +463,11 @@ def readParameters(self, version: Version) -> CalculationParameters:
"""
If no version given, defaults to current version
"""
filePath = self.parametersPath(version)
parameters = None
if filePath.exists():
parameters = parse_file_as(PARAMS_TYPE[self.indexerType], filePath)
else:
raise FileNotFoundError(
f"No {self.indexerType} calculation parameters found at {filePath} for version {version}"
)
return parameters
return self.readVersionedObject(PARAMS_TYPE[self.indexerType], version)

def writeParameters(self, parameters: CalculationParameters):
"""
Will save at the version on the calculation parameters.
If the version is invalid, will throw an error and refuse to save.
"""
parameters.version = self._flattenVersion(parameters.version)
parametersPath = self.parametersPath(parameters.version)
if parametersPath.exists():
logger.warn(f"Overwriting {self.indexerType} parameters at {parametersPath}")
else:
parametersPath.parent.mkdir(parents=True, exist_ok=True)
write_model_pretty(parameters, parametersPath)
self.dirVersions.add(parameters.version)
self.writeVersionedObject(parameters)
Loading

0 comments on commit 677c5b2

Please sign in to comment.