Skip to content

Commit b195255

Browse files
committed
Bug fix: some cloning methods to use super/abstract cloning
1 parent 3692303 commit b195255

File tree

4 files changed

+177
-74
lines changed

4 files changed

+177
-74
lines changed

pyNexafs/nexafs/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
This includes treating the data, such as normalising it, and performing background subtraction.
55
"""
66

7-
from pyNexafs.nexafs.scan import scan_base, scan_abstract
7+
from pyNexafs.nexafs.scan import scan_base, scan_abstract, scan_simple
88
from pyNexafs.nexafs.scan_normalised import (
99
scan_normalised,
1010
scan_normalised_edges,

pyNexafs/nexafs/scan.py

Lines changed: 70 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,44 @@ def __init__(self) -> None:
5050
self._y_units = [] # List[str]
5151
return
5252

53+
@abc.abstractmethod
54+
def copy(self, *args, **kwargs) -> scan_abstract:
55+
"""
56+
Creates a copy of the scan object.
57+
Reloads parser object data, but links to the same parser object as `self`.
58+
59+
Returns
60+
-------
61+
scan_abstract
62+
A copy of the scan object with unique data.
63+
"""
64+
newobj = type(self)(*args, **kwargs)
65+
66+
# Copy Data
67+
newobj._x = self._x.copy() if self._x is not None else None
68+
newobj._y = self._y.copy() if self._y is not None else None
69+
newobj._y_errs = self._y_errs.copy() if self._y_errs is not None else None
70+
newobj._x_errs = self._x_errs.copy() if self._x_errs is not None else None
71+
72+
# Copy Labels and Units
73+
newobj._x_label = self._x_label
74+
newobj._x_unit = self._x_unit
75+
if self._y_labels is not None:
76+
if isinstance(self._y_labels, list):
77+
newobj._y_labels = [label for label in self._y_labels]
78+
else: # string
79+
newobj._y_labels = self._y_labels
80+
else:
81+
newobj._y_labels = None
82+
if self._y_units is not None:
83+
if isinstance(self._y_units, list):
84+
newobj._y_units = [unit for unit in self._y_units]
85+
else: # string
86+
newobj._y_units = self._y_units
87+
else:
88+
newobj._y_units = None
89+
return newobj
90+
5391
@property
5492
def ctime(self) -> datetime.datetime:
5593
"""
@@ -366,44 +404,6 @@ def y_units(self, units: list[str] | None):
366404
raise ValueError(f"Provided 'units' {units} is not a list of strings.")
367405
return
368406

369-
def copy(self, *args, **kwargs) -> Type[Self]:
370-
"""
371-
Creates a copy of the scan object.
372-
Does reload parser object data, but does link to the same parser object.
373-
374-
Returns
375-
-------
376-
Type[scan_base]
377-
A copy of the scan object with unique data.
378-
"""
379-
newobj = type(self)(parser=None, *args, **kwargs)
380-
newobj.parser = self.parser
381-
382-
# Copy Data
383-
newobj._x = self._x.copy()
384-
newobj._y = self._y.copy()
385-
newobj._y_errs = self._y_errs.copy() if self._y_errs is not None else None
386-
newobj._x_errs = self._x_errs.copy() if self._x_errs is not None else None
387-
388-
# Copy Labels and Units
389-
newobj._x_label = self._x_label
390-
newobj._x_unit = self._x_unit
391-
if self._y_labels is not None:
392-
if isinstance(self._y_labels, list):
393-
newobj._y_labels = [label for label in self._y_labels]
394-
else: # string
395-
newobj._y_labels = self._y_labels
396-
else:
397-
newobj._y_labels = None
398-
if self._y_units is not None:
399-
if isinstance(self._y_units, list):
400-
newobj._y_units = [unit for unit in self._y_units]
401-
else: # string
402-
newobj._y_units = self._y_units
403-
else:
404-
newobj._y_units = None
405-
return newobj
406-
407407
def snapshot(self, columns: int = None) -> matplotlib.figure.Figure:
408408
"""
409409
Generates a grid of plots, showing all scan data.
@@ -438,6 +438,31 @@ def reload_labels_from_parser(self) -> None:
438438
pass
439439

440440

441+
class scan_simple(scan_abstract):
442+
"""
443+
Basic interface class for raw data that is not bundled in a parser object.
444+
"""
445+
446+
def __init__(self, x: npt.NDArray, y: npt.NDArray, *args, **kwargs) -> None:
447+
super().__init__(*args, **kwargs)
448+
self._x = x
449+
self._y = y
450+
return
451+
452+
def reload_labels_from_parser(self) -> None:
453+
return None
454+
455+
@overrides.overrides
456+
def copy(self) -> scan_simple:
457+
"""
458+
Creates a copy of the scan object.
459+
460+
Data is unique for a `scan_simple` object, so no need to reload parser data unlike scan_abstract.
461+
"""
462+
new_obj = scan_simple(x=self.x.copy(), y=self.y.copy())
463+
return new_obj
464+
465+
441466
class scan_base(scan_abstract):
442467
"""
443468
Base class for synchrotron measurements that scans across photon beam energies (eV).
@@ -481,6 +506,13 @@ def __init__(
481506
self._load_from_parser(load_all_columns=load_all_columns)
482507
self._all_columns_loaded = load_all_columns
483508

509+
@overrides.overrides
510+
def copy(self) -> type[scan_base]:
511+
newobj = super().copy(
512+
parser=self.parser, load_all_columns=self._all_columns_loaded
513+
)
514+
return newobj
515+
484516
@property
485517
def ctime(self) -> datetime.datetime:
486518
"""

pyNexafs/nexafs/scan_normalised.py

Lines changed: 66 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from scipy import optimize as sopt
99
from enum import Enum
1010
from types import NoneType
11-
from typing import Type
11+
from typing import Type, Self
1212
from pyNexafs.nexafs.scan import scan_base, scan_abstract
1313

1414

@@ -22,12 +22,26 @@ class scan_abstract_normalised(scan_abstract, metaclass=abc.ABCMeta):
2222
A scan object. Could be already normalised scan or a scan_base object.
2323
"""
2424

25-
def __init__(self, scan: Type[scan_abstract]):
25+
def __init__(self, scan: scan_abstract):
2626
self._origin = scan
2727
return
2828

29+
@overrides.overrides
30+
@abc.abstractmethod
31+
def copy(self) -> scan_abstract_normalised:
32+
"""
33+
Returns a copy of the scan object.
34+
35+
Returns
36+
-------
37+
scan_abstract
38+
A copy of the scan object.
39+
"""
40+
copy_obj = type(self)(scan=self._origin)
41+
return copy_obj
42+
2943
@property
30-
def origin(self) -> Type[scan_base]:
44+
def origin(self) -> scan_abstract:
3145
"""
3246
Property for the original scan object.
3347
@@ -113,7 +127,7 @@ def x_label(self) -> str:
113127
return self.origin.x_label
114128

115129
@x_label.setter
116-
def x_label(self, label: str) -> None:
130+
def x_label(self, label: str | None) -> None:
117131
"""
118132
Property setter for the x-axis label, to the origin scan.
119133
@@ -223,7 +237,7 @@ def y_units(self) -> list[str]:
223237
return self.origin.y_units
224238

225239
@y_units.setter
226-
def y_units(self, units: list[str]) -> None:
240+
def y_units(self, units: list[str] | None) -> None:
227241
"""
228242
Property setter for the y-axis units, to the origin scan.
229243
@@ -245,8 +259,12 @@ def __init__(
245259
raise ValueError("X data for scan and background scan do not match.")
246260
if scan.y.shape != scan_background.y.shape:
247261
raise ValueError("Y data for scan and background scan do not match.")
248-
self._origin = scan
262+
263+
super().__init__(scan)
249264
self._background = scan_background
265+
266+
# Run the normalisation
267+
self.load_and_normalise()
250268
return
251269

252270
@overrides.overrides
@@ -287,6 +305,10 @@ def _scale_from_normalisation_data(self) -> None:
287305
)
288306
return
289307

308+
@overrides.overrides
309+
def copy(self) -> scan_background_subtraction:
310+
return scan_background_subtraction(self._origin, self._background)
311+
290312

291313
class scan_normalised(scan_abstract_normalised):
292314
"""
@@ -585,6 +607,17 @@ def y_units(self, units: list[str] | None) -> None:
585607
self._origin.y_units = units
586608
return
587609

610+
def copy(self) -> Self:
611+
"""
612+
Returns a copy of the scan object.
613+
614+
Returns
615+
-------
616+
scan_abstract
617+
A copy of the scan object.
618+
"""
619+
return scan_normalised(self._origin, norm_channel=self._norm_channel)
620+
588621

589622
class scan_normalised_background_channel(scan_normalised):
590623
"""
@@ -599,8 +632,8 @@ class scan_normalised_background_channel(scan_normalised):
599632
@overrides.overrides
600633
def __init__(
601634
self,
602-
scan: Type[scan_abstract],
603-
background_scan: Type[scan_abstract],
635+
scan: scan_abstract,
636+
background_scan: scan_abstract,
604637
norm_channel: str,
605638
) -> None:
606639
# Save background scan and channel information
@@ -619,7 +652,8 @@ def __init__(
619652
@overrides.overrides
620653
def _load_from_origin(self):
621654
# Reload background_scan data
622-
self._background_scan.load_from_origin()
655+
if hasattr(self._background_scan, "_load_from_origin"):
656+
self._background_scan.load_from_origin()
623657
# Reload scan data
624658
super()._load_from_origin()
625659

@@ -746,9 +780,9 @@ class POSTEDGE_NORM_TYPE(Enum):
746780

747781
def __init__(
748782
self,
749-
scan: Type[scan_base],
750-
pre_edge_domain=list[int] | tuple[float, float] | None,
751-
post_edge_domain=list[int] | tuple[float, float] | None,
783+
scan: Type[scan_abstract],
784+
pre_edge_domain: list[int] | tuple[float, float] | None,
785+
post_edge_domain: list[int] | tuple[float, float] | None,
752786
pre_edge_normalisation: PREEDGE_NORM_TYPE = PREEDGE_NORM_TYPE.CONSTANT,
753787
post_edge_normalisation: POSTEDGE_NORM_TYPE = POSTEDGE_NORM_TYPE.CONSTANT,
754788
pre_edge_level: float | None = None,
@@ -780,6 +814,19 @@ def __init__(
780814
self.load_and_normalise()
781815
return
782816

817+
@overrides.overrides
818+
def copy(self) -> scan_normalised_edges:
819+
clone = scan_normalised_edges(
820+
self._origin,
821+
pre_edge_domain=self._pre_edge_domain,
822+
post_edge_domain=self._post_edge_domain,
823+
pre_edge_normalisation=self.pre_edge_normalisation,
824+
post_edge_normalisation=self.post_edge_normalisation,
825+
pre_edge_level=self._pre_edge_level,
826+
post_edge_level=self._post_edge_level,
827+
)
828+
return clone
829+
783830
@overrides.overrides
784831
def _scale_from_normalisation_data(self) -> None:
785832
"""
@@ -811,7 +858,9 @@ def _scale_from_normalisation_data(self) -> None:
811858
).nonzero()
812859
# Check indexes of each tuple element
813860
if len(pre_inds[0]) == 0:
814-
raise ValueError("Pre-edge domain is empty.")
861+
raise ValueError(
862+
f"Pre-edge domain ({self.pre_edge_domain[0]} to {self.pre_edge_domain[1]}) contains no datapoints."
863+
)
815864
else:
816865
raise AttributeError(
817866
"Pre-edge domain is not defined correctly. Should be a list of indexes or the range in a tuple."
@@ -879,15 +928,17 @@ def _scale_from_normalisation_data(self) -> None:
879928
if isinstance(self.post_edge_domain, list):
880929
post_inds = self.post_edge_domain
881930
if len(post_inds) == 0:
882-
raise ValueError("Pre-edge index list is empty.")
931+
raise ValueError("Post-edge index list is empty.")
883932
elif isinstance(self.post_edge_domain, tuple):
884933
post_inds = np.where(
885934
(self.x >= self.post_edge_domain[0])
886935
& (self.x <= self.post_edge_domain[1])
887936
)
888937
# Check indexes of each tuple element
889938
if len(post_inds[0]) == 0:
890-
raise ValueError("Pre-edge domain is empty.")
939+
raise ValueError(
940+
f"Post-edge domain ({self.post_edge_domain[0]} to {self.post_edge_domain[1]}) contains no datapoints."
941+
)
891942
else:
892943
raise AttributeError(
893944
"Post-edge domain is not defined correctly. Should be a list of indexes or the range in a tuple."

0 commit comments

Comments
 (0)