Calibration: Add source names for detector modules (using source_name_pattern) #146

Open · wants to merge 4 commits into base: master
Changes from all commits
1 change: 1 addition & 0 deletions docs/changelog.md
@@ -18,6 +18,7 @@
Added:

- A helper function named [fit_gaussian()][extra.utils.fit_gaussian] (!131).
- Support for module source names in `extra.calibration` (!146).

## [2024.1.1]

149 changes: 114 additions & 35 deletions src/extra/calibration.py
@@ -1,10 +1,12 @@
import ast
import json
import re
from collections.abc import Mapping
from dataclasses import dataclass, field, replace
from datetime import date, datetime, time, timezone
from functools import lru_cache
from pathlib import Path
from string import Formatter
from typing import Dict, List, Optional, Union
from urllib.parse import urljoin
from warnings import warn
@@ -37,6 +39,21 @@ def __str__(self):
return f"No module named {self.name!r}"


class SourceExprChecker(ast.NodeVisitor):
def visit_Call(self, node):
raise ValueError("Function calls not allowed in source name patterns")


class SourceNameFormatter(Formatter):
"""String formatter that evaluates simple operations like {modno + 2}"""

def get_field(self, field_name, args, kwargs):
node = ast.parse(field_name, "<source pattern>", "eval")
SourceExprChecker().visit(node)
obj = eval(compile(node, "<source pattern>", "eval"), kwargs)
return obj, 0

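As a rough illustration of what this formatter accepts and rejects (the pattern strings and module number below are invented for the example, not taken from CalCat):

```python
fmt = SourceNameFormatter()

# Plain substitution, as with str.format():
fmt.format("MY_DET/DET/{modno}CH0:xtdf", modno=3)
# -> 'MY_DET/DET/3CH0:xtdf'

# Simple expressions inside the field are evaluated:
fmt.format("MY_DET/DET/{modno + 8}CH0:xtdf", modno=3)
# -> 'MY_DET/DET/11CH0:xtdf'

# Function calls are rejected by SourceExprChecker:
fmt.format("{__import__('os').system('true')}", modno=3)
# -> ValueError: Function calls not allowed in source name patterns
```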

class CalCatAPIError(requests.HTTPError):
"""Used when the response includes error details as JSON"""

@@ -58,6 +75,7 @@ def __init__(self, base_api_url, oauth_client=None, user_email=""):

def default_headers(self):
from . import __version__

return {
"content-type": "application/json",
"Accept": "application/json; version=2",
@@ -184,15 +202,15 @@ def get_client():


def setup_client(
base_url,
client_id,
client_secret,
user_email,
scope="",
session_token=None,
oauth_retries=3,
oauth_timeout=12,
ssl_verify=True,
base_url,
client_id,
client_secret,
user_email,
scope="",
session_token=None,
oauth_retries=3,
oauth_timeout=12,
ssl_verify=True,
):
"""Configure the global CalCat API client."""
global global_client
@@ -297,7 +315,6 @@ def _load_calcat_metadata(self, client=None):
self._metadata = calcat_meta | self._metadata
self._have_calcat_metadata = True


def metadata(self, key, client=None):
"""Get a specific metadata field, e.g. 'begin_validity_at'

@@ -322,14 +339,25 @@ def metadata_dict(self, client=None):


def prepare_selection(
module_details, module_nums=None, aggregator_names=None, qm_names=None
module_details,
module_nums=None,
aggregator_names=None,
qm_names=None,
source_names=None,
):
aggs = aggregator_names # Shorter name -> fewer multi-line statements
n_specified = sum([module_nums is not None, aggs is not None, qm_names is not None])
n_specified = sum(
[
module_nums is not None,
aggs is not None,
qm_names is not None,
source_names is not None,
]
)
if n_specified > 1:
raise TypeError(
"select_modules() accepts only one of module_nums, aggregator_names "
"& qm_names"
"select_modules() accepts only one of module_nums, aggregator_names, "
"qm_names & source_names"
)

if module_nums is not None:
@@ -338,6 +366,11 @@ def prepare_selection(
elif qm_names is not None:
by_qm = {m["virtual_device_name"]: m for m in module_details}
return [by_qm[s]["karabo_da"] for s in qm_names]
elif source_names is not None:
by_src = {m["source_name"]: m for m in module_details if m.get("source_name")}
if not by_src:
raise KeyError("No source names found for this detector")
return [by_src[s]["karabo_da"] for s in source_names]
elif aggs is not None:
miss = set(aggs) - {m["karabo_da"] for m in module_details}
if miss:
@@ -383,7 +416,12 @@ def __getitem__(self, key):
candidate_kdas.add(key)

for m in self.module_details:
names = (m["module_number"], m["virtual_device_name"], m["physical_name"])
names = (
m["module_number"],
m["virtual_device_name"],
m["physical_name"],
m.get("source_name", None),
)
if key in names and m["karabo_da"] in self.constants:
candidate_kdas.add(m["karabo_da"])

@@ -395,14 +433,19 @@ def __getitem__(self, key):
return self.constants[candidate_kdas.pop()]

def select_modules(
self, module_nums=None, *, aggregator_names=None, qm_names=None
self,
module_nums=None,
*,
aggregator_names=None,
qm_names=None,
source_names=None,
) -> "MultiModuleConstant":
"""Return a new `MultiModuleConstant` object with only the selected modules

One of `module_nums`, `aggregator_names`, `qm_names` or `source_names` must be specified.
"""
aggs = prepare_selection(
self.module_details, module_nums, aggregator_names, qm_names
self.module_details, module_nums, aggregator_names, qm_names, source_names
)
d = {aggr: scv for (aggr, scv) in self.constants.items() if aggr in aggs}
mods = [m for m in self.module_details if m["karabo_da"] in d]
@@ -444,6 +487,17 @@ def pdu_names(self):
if m["karabo_da"] in self.constants
]

@property
def source_names(self):
"""Source names of the detector units making up the detector.

Only includes modules where we have this constant."""
return [
m["source_name"]
for m in self.module_details
if m["karabo_da"] in self.constants
]

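A short sketch of the new selection path on a single constant (`offset` stands for an existing `MultiModuleConstant`, and the source names are placeholders):

```python
# `offset` is assumed to be a MultiModuleConstant, e.g. one entry of a
# CalibrationData object, for a detector with a source_name_pattern in CalCat.
sel = offset.select_modules(
    source_names=["MY_DET/DET/0CH0:xtdf", "MY_DET/DET/1CH0:xtdf"]
)
print(sel.source_names)  # source names of the modules that were kept
data = sel.ndarray()     # load only those modules (needs access to the CalDB files)
```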
def ndarray(self, caldb_root=None, *, parallel=0):
"""Load this constant as a Numpy array.

@@ -532,13 +586,13 @@ def _format_cond(condition):

@classmethod
def from_condition(
cls,
condition: "ConditionsBase",
detector_name,
calibrations=None,
client=None,
event_at=None,
pdu_snapshot_at=None,
cls,
condition: "ConditionsBase",
detector_name,
calibrations=None,
client=None,
event_at=None,
pdu_snapshot_at=None,
):
"""Look up constants for the given detector conditions & timestamp.

@@ -557,18 +611,22 @@ def from_condition(

client = client or get_client()

detector_id = client.detector_by_identifier(detector_name)["id"]
calcat_detector = client.detector_by_identifier(detector_name)
pdus = client.get(
"physical_detector_units/get_all_by_detector",
{
"detector_id": detector_id,
"detector_id": calcat_detector["id"],
"pdu_snapshot_at": client.format_time(pdu_snapshot_at),
},
)
module_details = sorted(pdus, key=lambda d: d["karabo_da"])
for mod in module_details:
if mod.get("module_number") is None:
mod["module_number"] = int(re.findall(r"\d+", mod["karabo_da"])[-1])
if (source_pat := calcat_detector["source_name_pattern"]) is not None:
mod["source_name"] = SourceNameFormatter().format(
source_pat, modno=mod["module_number"]
)

constant_groups = {}

@@ -603,9 +661,9 @@

@classmethod
def from_report(
cls,
report_id_or_path: Union[int, str],
client=None,
cls,
report_id_or_path: Union[int, str],
client=None,
):
"""Look up constants by a report ID or path.

@@ -658,7 +716,14 @@ def from_report(
if len(det_ids) > 1:
raise Exception(f"Found multiple detector IDs in report: {det_ids}")
# The "identifier", "name" & "karabo_name" fields seem to have the same names
det_name = client.detector_by_id(det_ids.pop())["identifier"]
calcat_detector = client.detector_by_id(det_ids.pop())
det_name = calcat_detector["identifier"]

if (source_pat := calcat_detector["source_name_pattern"]) is not None:
for pdu in pdus.values():
pdu["source_name"] = SourceNameFormatter().format(
source_pat, modno=pdu["module_number"]
)

module_details = sorted(pdus.values(), key=lambda d: d["karabo_da"])
return cls(constant_groups, module_details, det_name)
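For example (a sketch; the report ID is a placeholder and a configured CalCat client is assumed):

```python
caldata = CalibrationData.from_report(1234)  # placeholder report ID
# source_name is filled in for each module when the detector defines
# source_name_pattern in CalCat, so module source names are now available:
print(caldata.source_names)
```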
@@ -714,6 +779,11 @@ def pdu_names(self):
May include missing modules."""
return [m["physical_name"] for m in self.module_details]

@property
def source_names(self):
"Source names of the detector modules. May include missing modules."
return [m["source_name"] for m in self.module_details]

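Putting the pieces together, a minimal end-to-end sketch (the detector name, condition values and selection are illustrative, and a configured CalCat client is assumed):

```python
from extra.calibration import AGIPDConditions, CalibrationData

cond = AGIPDConditions(
    sensor_bias_voltage=300.0,  # placeholder condition values
    memory_cells=352,
    acquisition_rate=1.1,
    gain_setting=0,
    gain_mode=0,
    source_energy=9.2,
)
caldata = CalibrationData.from_condition(cond, "SPB_DET_AGIPD1M-1")

print(caldata.source_names)  # source names for all modules (may include missing ones)

# Select modules by source name across all constant types:
sel = caldata.select_modules(source_names=caldata.source_names[:4])
print(sel.aggregator_names)
```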
def require_calibrations(self, calibrations) -> "CalibrationData":
"""Drop any modules missing the specified constant types"""
mods = set(self.aggregator_names)
@@ -722,7 +792,12 @@ def require_calibrations(self, calibrations) -> "CalibrationData":
return self.select_modules(aggregator_names=mods)

def select_modules(
self, module_nums=None, *, aggregator_names=None, qm_names=None
self,
module_nums=None,
*,
aggregator_names=None,
qm_names=None,
source_names=None,
) -> "CalibrationData":
"""Return a new `CalibrationData` object with only the selected modules

@@ -731,7 +806,7 @@ def select_modules(
# Validate the specified modules against those we know about.
# Each specific constant type may have only a subset of these modules.
aggs = prepare_selection(
self.module_details, module_nums, aggregator_names, qm_names
self.module_details, module_nums, aggregator_names, qm_names, source_names
)
constant_groups = {}
matched_aggregators = set()
@@ -815,12 +890,13 @@ def make_dict(self, parameters) -> dict:
@dataclass
class AGIPDConditions(ConditionsBase):
"""Conditions for AGIPD detectors"""

sensor_bias_voltage: float
memory_cells: int
acquisition_rate: float
gain_setting: Optional[int]
gain_mode: Optional[int]
source_energy: float
gain_setting: Optional[int] = None
gain_mode: Optional[int] = None
source_energy: Optional[float] = None
integration_time: int = 12
pixels_x: int = 512
pixels_y: int = 128
@@ -864,6 +940,7 @@ def make_dict(self, parameters):
@dataclass
class LPDConditions(ConditionsBase):
"""Conditions for LPD detectors"""

sensor_bias_voltage: float
memory_cells: int
memory_cell_order: Optional[str] = None
@@ -899,6 +976,7 @@ class LPDConditions(ConditionsBase):
@dataclass
class DSSCConditions(ConditionsBase):
"""Conditions for DSSC detectors"""

sensor_bias_voltage: float
memory_cells: int
pulse_id_checksum: Optional[float] = None
@@ -927,6 +1005,7 @@ class JUNGFRAUConditions(ConditionsBase):
@dataclass
class JUNGFRAUConditions(ConditionsBase):
"""Conditions for JUNGFRAU detectors"""

sensor_bias_voltage: float
memory_cells: int
integration_time: float