Skip to content

Commit

Permalink
redesigning the interface to constrain fitting parameters in Fit1D class
Browse files Browse the repository at this point in the history
  • Loading branch information
Bing Li committed Dec 2, 2024
1 parent 4fef66a commit c215dfd
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 59 deletions.
93 changes: 50 additions & 43 deletions src/tavi/data/fit.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from dataclasses import dataclass
from typing import Literal, Optional

import numpy as np
Expand All @@ -7,16 +6,17 @@

from tavi.data.scan_data import ScanData1D

# @dataclass
# class FitParam:
# name: str
# value: Optional[float] = None
# vary: bool = True
# min: float = -np.inf
# max: float = np.inf
# expr: Optional[str] = None


@dataclass
class FitParams:
name: str
value: Optional[float] = None
vary: bool = True
min: float = -np.inf
max: float = np.inf
expr: Optional[str] = None
brute_step: Optional[float] = None
# brute_step: Optional[float] = None


# @dataclass
Expand All @@ -28,11 +28,7 @@ class FitParams:

class FitData1D(object):

def __init__(
self,
x: np.ndarray,
y: np.ndarray,
) -> None:
def __init__(self, x: np.ndarray, y: np.ndarray) -> None:

self.x = x
self.y = y
Expand Down Expand Up @@ -81,8 +77,8 @@ def __init__(
self.y: np.ndarray = data.y
self.err: Optional[np.ndarray] = data.err

self.background_models: models = []
self.signal_models: models = []
self._background_models: models = []
self._signal_models: models = []
self._parameters: Optional[Parameters] = None
self._num_backgrounds = 0
self._num_signals = 0
Expand Down Expand Up @@ -122,7 +118,7 @@ def add_signal(
):
self._num_signals += 1
prefix = f"s{self._num_signals}_"
self.signal_models.append(Fit1D._add_model(model, prefix, nan_policy=self.nan_policy))
self._signal_models.append(Fit1D._add_model(model, prefix, nan_policy=self.nan_policy))

def add_background(
self,
Expand All @@ -137,7 +133,7 @@ def add_background(
):
self._num_backgrounds += 1
prefix = f"b{self._num_backgrounds}_"
self.background_models.append(Fit1D._add_model(model, prefix, nan_policy=self.nan_policy))
self._background_models.append(Fit1D._add_model(model, prefix, nan_policy=self.nan_policy))

@staticmethod
def _get_param_names(models) -> list[list[str]]:
Expand All @@ -148,51 +144,62 @@ def _get_param_names(models) -> list[list[str]]:

@property
def signal_param_names(self):
return Fit1D._get_param_names(self.signal_models)
"""Get parameter names of all signals"""
return Fit1D._get_param_names(self._signal_models)

@property
def background_param_names(self):
return Fit1D._get_param_names(self.background_models)
"""Get parameter names of all backgrounds"""
return Fit1D._get_param_names(self._background_models)

# TODO
@property
def params(self) -> dict[str, tuple[FitParams, ...]]:
def params(self) -> dict[str, dict]:
"""Get fitting parameters as a dictionary with the model prefix being the key"""

all_pars = self.guess() if self._parameters is None else self._parameters
parsms_names = Fit1D._get_param_names(self.signal_models + self.background_models)
params_dict = {}
params_names = Fit1D._get_param_names(self._signal_models + self._background_models)

for params in parsms_names:
key = params[0].split("_")[0]
params_list = []
for param_name in params:
params_dict = {}
for names in params_names:
if len(names) < 1:
raise ValueError(f"Should have at least 1 parameter in {names}.")
prefix, _ = names[0].split("_")
param_dict = {}
for param_name in names:
param = all_pars[param_name]
params_list.append(
FitParams(
name=param.name,
value=param.value,
vary=param.vary,
min=param.min,
max=param.max,
expr=param.expr,
brute_step=param.brute_step,
)
param_dict.update(
{
"name": param.name,
"value": param.value,
"vary": param.vary,
"min": param.min,
"max": param.max,
"expr": param.expr,
}
)
params_dict.update({key: tuple(params_list)})

params_dict.update({prefix: param_dict})

return params_dict

def guess(self) -> Parameters:
"""Guess fitting parameters' values
Reutrn:
Parameters class in LMFIT"""

pars = Parameters()
for signal in self.signal_models:
for signal in self._signal_models:
pars += signal.guess(self.y, x=self.x)
for bkg in self.background_models:
for bkg in self._background_models:
pars += bkg.guess(self.y, x=self.x)
self._parameters = pars
return pars

@property
def model(self):
compposite_model = np.sum(self.signal_models + self.background_models)
"""Return the composite model of all singals and backgrounds"""
compposite_model = np.sum(self._signal_models + self._background_models)
return compposite_model

def x_to_plot(self, num_of_pts: Optional[int]):
Expand All @@ -215,4 +222,4 @@ def fit(self, pars: Parameters) -> ModelResult:

result = self.model.fit(self.y, pars, x=self.x, weights=self.err)
self.result = result
return self
return result
41 changes: 27 additions & 14 deletions src/tavi/plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,32 +54,44 @@ def add_scan(self, scan_data: ScanData1D, **kwargs):
for key, val in kwargs.items():
scan_data.fmt.update({key: val})

def _add_fit_from_eval(self, fit_data: FitData1D, **kwargs):
self.fit_data.append(fit_data)
for key, val in kwargs.items():
fit_data.fmt.update({key: val})

def _add_fit_from_fitting(self, fit_data: Fit1D, num_of_pts: Optional[int] = 100, **kwargs):
if (result := fit_data.result) is None:
raise ValueError("Fitting result is None.")
x = fit_data.x_to_plot(num_of_pts)
data = FitData1D(x=x, y=result.eval(param=result.params, x=x))
self.fit_data.append(data)
for key, val in kwargs.items():
data.fmt.update({key: val})

def add_fit(self, fit_data: Union[FitData1D, Fit1D], num_of_pts: Optional[int] = 100, **kwargs):
"""
Note:
PLOT_COMPONENTS is ignored if fit_data has the type FitData1D
"""
if isinstance(fit_data, FitData1D):
self.fit_data.append(fit_data)
for key, val in kwargs.items():
fit_data.fmt.update({key: val})
elif isinstance(fit_data, Fit1D) and (result := fit_data.result) is not None:
x = fit_data.x_to_plot(num_of_pts)
data = FitData1D(x=x, y=result.eval(param=result.params, x=x))
self.fit_data.append(data)
for key, val in kwargs.items():
data.fmt.update({key: val})
self._add_fit_from_eval(fit_data, **kwargs)
elif isinstance(fit_data, Fit1D):
self._add_fit_from_fitting(fit_data, num_of_pts, **kwargs)
else:
raise ValueError(f"Invalid input fit_data={fit_data}")

def add_fit_components(self, fit_data: Fit1D, num_of_pts: Optional[int] = 100, **kwargs):
if isinstance(fit_data, Fit1D) and (result := fit_data.result) is not None:
x = fit_data.x_to_plot(num_of_pts)
components = result.eval_components(result.params, x=x)

num_components = len(components)
for k, v in kwargs.items():
if len(v) != num_components:
raise ValueError(
f"Length of key word argument {k}={v} dose not match the number of fitting models."
)

for i, (prefix, y) in enumerate(components.items()):
data = FitData1D(x=x, y=y)
self.fit_data.append(data)
data.fmt.update({"label": prefix[:-1]})
data.fmt.update({"label": prefix[:-1]}) # remove "_"
for key, val in kwargs.items():
data.fmt.update({key: val[i]})

Expand All @@ -92,6 +104,7 @@ def plot(self, ax):
ax.plot(data.x, data.y, **data.fmt)
else:
ax.errorbar(x=data.x, y=data.y, yerr=data.err, **data.fmt)

for fit in self.fit_data:
ax.plot(fit.x, fit.y, **fit.fmt)

Expand Down
7 changes: 5 additions & 2 deletions tests/test_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,12 +102,13 @@ def test_fit_single_peak_internal_model(fit_data):
f1.add_background(model="Constant")
pars = f1.guess()
fit_result = f1.fit(pars)
assert np.allclose(fit_result.redchi, 37.6, atol=1)

if PLOT:
p1 = Plot1D()
p1.add_scan(s1_scan, fmt="o", label="data")
p1.add_fit(fit_result, label="fit", color="C3", num_of_pts=50, marker="^")
p1.add_fit_components(fit_result, color=["C4", "C5"])
p1.add_fit(f1, label="fit", color="C3", num_of_pts=50, marker="^")
p1.add_fit_components(f1, color=["C4", "C5"])

fig, ax = plt.subplots()
p1.plot(ax)
Expand All @@ -121,6 +122,8 @@ def test_fit_two_peak(fit_data):
f1 = Fit1D(s1_scan, fit_range=(0.0, 4.0), name="scan42_fit2peaks")

f1.add_background(model="Constant")
f1.add_signal(model="Gaussian")

f1.add_signal(values=(None, 3.5, 0.29), vary=(True, True, True))
f1.add_signal(
model="Gaussian",
Expand Down

0 comments on commit c215dfd

Please sign in to comment.