Skip to content

Commit 4fef66a

Browse files
author
Bing Li
committed
updating Fit1D class
1 parent 56bc809 commit 4fef66a

File tree

2 files changed

+25
-49
lines changed

2 files changed

+25
-49
lines changed

src/tavi/data/fit.py

Lines changed: 13 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def __init__(
8383

8484
self.background_models: models = []
8585
self.signal_models: models = []
86-
self.parameters: Optional[Parameters] = None
86+
self._parameters: Optional[Parameters] = None
8787
self._num_backgrounds = 0
8888
self._num_signals = 0
8989
self.result: Optional[ModelResult] = None
@@ -122,13 +122,7 @@ def add_signal(
122122
):
123123
self._num_signals += 1
124124
prefix = f"s{self._num_signals}_"
125-
self.signal_models.append(
126-
Fit1D._add_model(
127-
model,
128-
prefix,
129-
nan_policy=self.nan_policy,
130-
)
131-
)
125+
self.signal_models.append(Fit1D._add_model(model, prefix, nan_policy=self.nan_policy))
132126

133127
def add_background(
134128
self,
@@ -143,13 +137,7 @@ def add_background(
143137
):
144138
self._num_backgrounds += 1
145139
prefix = f"b{self._num_backgrounds}_"
146-
self.background_models.append(
147-
Fit1D._add_model(
148-
model,
149-
prefix,
150-
nan_policy=self.nan_policy,
151-
)
152-
)
140+
self.background_models.append(Fit1D._add_model(model, prefix, nan_policy=self.nan_policy))
153141

154142
@staticmethod
155143
def _get_param_names(models) -> list[list[str]]:
@@ -166,11 +154,15 @@ def signal_param_names(self):
166154
def background_param_names(self):
167155
return Fit1D._get_param_names(self.background_models)
168156

169-
@staticmethod
170-
def _get_params(all_pars, parsms_names: list[list[str]]) -> tuple[tuple[FitParams, ...], ...]:
171-
signal_params_list = []
157+
@property
158+
def params(self) -> dict[str, tuple[FitParams, ...]]:
159+
160+
all_pars = self.guess() if self._parameters is None else self._parameters
161+
parsms_names = Fit1D._get_param_names(self.signal_models + self.background_models)
162+
params_dict = {}
172163

173164
for params in parsms_names:
165+
key = params[0].split("_")[0]
174166
params_list = []
175167
for param_name in params:
176168
param = all_pars[param_name]
@@ -185,34 +177,17 @@ def _get_params(all_pars, parsms_names: list[list[str]]) -> tuple[tuple[FitParam
185177
brute_step=param.brute_step,
186178
)
187179
)
188-
signal_params_list.append(tuple(params_list))
189-
return tuple(signal_params_list)
190-
191-
@property
192-
def signal_params(self) -> tuple[tuple[FitParams, ...], ...]:
193-
194-
pars = self.guess() if self.parameters is None else self.parameters
195-
names = Fit1D._get_param_names(self.signal_models)
196-
signal_params = Fit1D._get_params(pars, names)
197-
198-
return signal_params
199-
200-
@property
201-
def background_params(self) -> tuple[tuple[FitParams, ...], ...]:
202-
203-
pars = self.guess() if self.parameters is None else self.parameters
204-
names = Fit1D._get_param_names(self.background_models)
205-
background_params = Fit1D._get_params(pars, names)
180+
params_dict.update({key: tuple(params_list)})
206181

207-
return background_params
182+
return params_dict
208183

209184
def guess(self) -> Parameters:
210185
pars = Parameters()
211186
for signal in self.signal_models:
212187
pars += signal.guess(self.y, x=self.x)
213188
for bkg in self.background_models:
214189
pars += bkg.guess(self.y, x=self.x)
215-
self.parameters = pars
190+
self._parameters = pars
216191
return pars
217192

218193
@property

tests/test_fit.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -58,16 +58,17 @@ def test_get_fitting_variables(fit_data):
5858
assert f1.signal_param_names == [["s1_amplitude", "s1_center", "s1_sigma"]]
5959
assert f1.background_param_names == [["b1_c"]]
6060

61-
assert len(f1.signal_params) == 1
62-
assert len(f1.signal_params[0]) == 5
61+
assert len(f1.params) == 2
6362

64-
assert f1.signal_params[0][0].name == "s1_amplitude"
65-
assert f1.signal_params[0][1].name == "s1_center"
66-
assert f1.signal_params[0][2].name == "s1_sigma"
67-
assert f1.signal_params[0][3].name == "s1_fwhm"
68-
assert f1.signal_params[0][4].name == "s1_height"
69-
assert f1.signal_params[0][4].expr == "0.3989423*s1_amplitude/max(1e-15, s1_sigma)"
70-
assert f1.background_params[0][0].name == "b1_c"
63+
s1_params = f1.params["s1"]
64+
assert len(s1_params) == 5
65+
assert s1_params[0].name == "s1_amplitude"
66+
assert s1_params[1].name == "s1_center"
67+
assert s1_params[2].name == "s1_sigma"
68+
assert s1_params[3].name == "s1_fwhm"
69+
assert s1_params[4].name == "s1_height"
70+
assert s1_params[4].expr == "0.3989423*s1_amplitude/max(1e-15, s1_sigma)"
71+
assert f1.params["b1"][0].name == "b1_c"
7172

7273

7374
def test_guess_initial(fit_data):
@@ -117,9 +118,9 @@ def test_fit_two_peak(fit_data):
117118

118119
s1_scan, PLOT = fit_data
119120

120-
f1 = Fit1D(s1_scan, fit_range=(0.0, 4.0))
121+
f1 = Fit1D(s1_scan, fit_range=(0.0, 4.0), name="scan42_fit2peaks")
121122

122-
f1.add_background(values=(0.7,))
123+
f1.add_background(model="Constant")
123124
f1.add_signal(values=(None, 3.5, 0.29), vary=(True, True, True))
124125
f1.add_signal(
125126
model="Gaussian",

0 commit comments

Comments
 (0)