Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Update example to use spec object and units #163

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 92 additions & 60 deletions examples/fitting_simulated_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,18 @@
import numpy as np
from matplotlib.colors import LogNorm

import astropy.units as u
from astropy.modeling import fitting
from astropy.modeling.functional_models import Gaussian1D, Linear1D
from astropy.visualization import quantity_support

from sunkit_spex.data.simulated_data import simulate_square_response_matrix
from sunkit_spex.fitting.objective_functions.optimising_functions import minimize_func
from sunkit_spex.fitting.optimizer_tools.minimizer_tools import scipy_minimize
from sunkit_spex.fitting.statistics.gaussian import chi_squared
from sunkit_spex.models.instrument_response import MatrixModel
from sunkit_spex.models.models import GaussianModel, StraightLineModel
from sunkit_spex.spectrum import Spectrum
from sunkit_spex.spectrum.spectrum import SpectralAxis

#####################################################
#
Expand All @@ -37,111 +41,125 @@

start, inc = 1.6, 0.04
stop = 80 + inc / 2
ph_energies = np.arange(start, stop, inc)
ph_energies = np.arange(start, stop, inc) * u.keV

#####################################################
#
# Let's start making a simulated photon spectrum

sim_cont = {"slope": -1, "intercept": 100}
sim_line = {"amplitude": 100, "mean": 30, "stddev": 2}
sim_cont = {"slope": -1 * u.ph / u.keV, "intercept": 100 * u.ph}
sim_line = {"amplitude": 100 * u.ph, "mean": 30 * u.keV, "stddev": 2 * u.keV}
# use a straight line model for a continuum, Gaussian for a line
ph_model = StraightLineModel(**sim_cont) + GaussianModel(**sim_line)
ph_model = Linear1D(**sim_cont) + Gaussian1D(**sim_line)

plt.figure()
plt.plot(ph_energies, ph_model(ph_energies))
plt.xlabel("Energy [keV]")
plt.ylabel("ph s$^{-1}$ cm$^{-2}$ keV$^{-1}$")
plt.title("Simulated Photon Spectrum")
plt.show()
with quantity_support():
plt.figure()
plt.plot(ph_energies, ph_model(ph_energies))
plt.xlabel(f"Energy [{ph_energies.unit}]")
plt.title("Simulated Photon Spectrum")
plt.show()

#####################################################
#
# Now want a response matrix

srm = simulate_square_response_matrix(ph_energies.size)
srm_model = MatrixModel(matrix=srm)

plt.figure()
plt.imshow(
srm, origin="lower", extent=[ph_energies[0], ph_energies[-1], ph_energies[0], ph_energies[-1]], norm=LogNorm()
srm_model = MatrixModel(
matrix=srm, input_axis=SpectralAxis(ph_energies), output_axis=SpectralAxis(ph_energies), c=1 * u.ct / u.ph
)
plt.ylabel("Photon Energies [keV]")
plt.xlabel("Count Energies [keV]")
plt.title("Simulated SRM")
plt.show()
srm_model.input_units = {"x": u.ph}

with quantity_support():
plt.figure()
plt.imshow(
srm_model.matrix,
origin="lower",
extent=(
srm_model.inputs_axis[0].value,
srm_model.inputs_axis[-1].value,
srm_model.output_axis[0].value,
srm_model.output_axis[-1].value,
),
norm=LogNorm(),
)
plt.ylabel(f"Photon Energies [{srm_model.inputs_axis.unit}]")
plt.xlabel(f"Count Energies [{srm_model.output_axis.unit}]")
plt.title("Simulated SRM")
plt.show()

#####################################################
#
# Start work on a count model

sim_gauss = {"amplitude": 70, "mean": 40, "stddev": 2}
sim_gauss = {"amplitude": 70 * u.ct, "mean": 40 * u.keV, "stddev": 2 * u.keV}
# the brackets are very necessary
ct_model = (ph_model | srm_model) + GaussianModel(**sim_gauss)
ct_model = (ph_model | srm_model) + Gaussian1D(**sim_gauss)

#####################################################
#
# Generate simulated count data to (almost) fit

sim_count_model = ct_model(ph_energies)
sim_count_model = ct_model(SpectralAxis(ph_energies))

#####################################################
#
# Add some noise
np_rand = np.random.default_rng(seed=10)
sim_count_model_wn = sim_count_model + (2 * np_rand.random(sim_count_model.size) - 1) * np.sqrt(sim_count_model)
sim_count_model_wn = (
sim_count_model + (2 * np_rand.random(sim_count_model.size) - 1) * np.sqrt(sim_count_model.value) * u.ct
)

obs_spec = Spectrum(sim_count_model_wn.reshape(-1), spectral_axis=ph_energies)

#####################################################
#
# Can plot all the different components in the simulated count spectrum

plt.figure()
plt.plot(ph_energies, (ph_model | srm_model)(ph_energies), label="photon model features")
plt.plot(ph_energies, GaussianModel(**sim_gauss)(ph_energies), label="gaussian feature")
plt.plot(ph_energies, sim_count_model, label="total sim. spectrum")
plt.plot(ph_energies, sim_count_model_wn, label="total sim. spectrum + noise", lw=0.5)
plt.xlabel("Energy [keV]")
plt.ylabel("cts s$^{-1}$ keV$^{-1}$")
plt.title("Simulated Count Spectrum")
plt.legend()
with quantity_support():
plt.figure()
plt.plot(ph_energies, (ph_model | srm_model)(ph_energies), label="photon model features")
plt.plot(ph_energies, Gaussian1D(**sim_gauss)(ph_energies), label="gaussian feature")
plt.plot(ph_energies, sim_count_model, label="total sim. spectrum")
plt.plot(obs_spec._spectral_axis, obs_spec.data, label="total sim. spectrum + noise", lw=0.5)
plt.xlabel(f"Energy [{ph_energies.unit}]")
plt.title("Simulated Count Spectrum")
plt.legend()

plt.text(80, 170, "(ph_model(sl,in,am1,mn1,sd1) | srm)", ha="right", c="tab:blue", weight="bold")
plt.text(80, 150, "+ Gaussian(am2,mn2,sd2)", ha="right", c="tab:orange", weight="bold")
plt.show()
plt.text(80, 170, "(ph_model(sl,in,am1,mn1,sd1) | srm)", ha="right", c="tab:blue", weight="bold")
plt.text(80, 150, "+ Gaussian(am2,mn2,sd2)", ha="right", c="tab:orange", weight="bold")
plt.show()

#####################################################
#
# Now we have the simulated data, let's start setting up to fit it
#
# Get some initial guesses that are off from the simulated data above

guess_cont = {"slope": -0.5, "intercept": 80}
guess_line = {"amplitude": 150, "mean": 32, "stddev": 5}
guess_gauss = {"amplitude": 350, "mean": 39, "stddev": 0.5}
guess_cont = {"slope": -0.5 * u.ph / u.keV, "intercept": 80 * u.ph}
guess_line = {"amplitude": 150 * u.ph, "mean": 32 * u.keV, "stddev": 5 * u.keV}
guess_gauss = {"amplitude": 350 * u.ct, "mean": 39 * u.keV, "stddev": 0.5 * u.keV}

#####################################################
#
# Define a new model since we have a rough idea of the mode we should use

ph_mod_4fit = StraightLineModel(**guess_cont) + GaussianModel(**guess_line)
count_model_4fit = (ph_mod_4fit | srm_model) + GaussianModel(**guess_gauss)
ph_mod_4fit = Linear1D(**guess_cont) + Gaussian1D(**guess_line)
count_model_4fit = (ph_mod_4fit | srm_model) + Gaussian1D(**guess_gauss)

#####################################################
#
# Let's fit the simulated data and plot the result

opt_res = scipy_minimize(
minimize_func, count_model_4fit.parameters, (sim_count_model_wn, ph_energies, count_model_4fit, chi_squared)
)
opt_res = scipy_minimize(minimize_func, count_model_4fit.parameters, (obs_spec, count_model_4fit, chi_squared))

plt.figure()
plt.plot(ph_energies, sim_count_model_wn, label="total sim. spectrum + noise")
plt.plot(ph_energies, count_model_4fit.evaluate(ph_energies, *opt_res.x), ls=":", label="model fit")
plt.xlabel("Energy [keV]")
plt.ylabel("cts s$^{-1}$ keV$^{-1}$")
plt.title("Simulated Count Spectrum Fit with Scipy")
plt.legend()
plt.show()
with quantity_support():
plt.figure()
plt.plot(ph_energies, sim_count_model_wn, label="total sim. spectrum + noise")
plt.plot(ph_energies, count_model_4fit.evaluate(ph_energies.value, *opt_res.x), ls=":", label="model fit")
plt.xlabel(f"Energy [{ph_energies.unit}]")
plt.title("Simulated Count Spectrum Fit with Scipy")
plt.legend()
plt.show()


#####################################################
Expand All @@ -150,16 +168,26 @@
#
# Try and ensure we start fresh with new model definitions

ph_mod_4astropyfit = StraightLineModel(**guess_cont) + GaussianModel(**guess_line)
count_model_4astropyfit = (ph_mod_4fit | srm_model) + GaussianModel(**guess_gauss)
ph_mod_4astropyfit = Linear1D(**guess_cont) + Gaussian1D(**guess_line)
for m in ph_mod_4astropyfit:
m.output_units = {"y": u.ph}
cgauss = Gaussian1D(**guess_gauss)
cgauss.output_units = {"y": u.ct}
srm_model.output_units = {"y": u.ct}

astropy_fit = fitting.LevMarLSQFitter()
count_model_4astropyfit = (ph_mod_4astropyfit | srm_model) + cgauss

astropy_fitted_result = astropy_fit(count_model_4astropyfit, ph_energies, sim_count_model_wn)
astropy_fit = fitting.LevMarLSQFitter()
astropy_fitted_result = astropy_fit(count_model_4astropyfit, ph_energies, obs_spec.data << obs_spec.unit)

plt.figure()
plt.plot(ph_energies, sim_count_model_wn, label="total sim. spectrum + noise")
plt.plot(ph_energies, astropy_fitted_result(ph_energies), ls=":", label="model fit")
plt.plot(
ph_energies,
count_model_4astropyfit.evaluate(ph_energies.value, *astropy_fitted_result.parameters),
ls=":",
label="model fit",
)
plt.xlabel("Energy [keV]")
plt.ylabel("cts s$^{-1}$ keV$^{-1}$")
plt.title("Simulated Count Spectrum Fit with Astropy")
Expand All @@ -172,10 +200,14 @@

plt.figure(layout="constrained")

row_labels = tuple(sim_cont) + tuple(f"{p}1" for p in tuple(sim_line)) + tuple(f"{p}2" for p in tuple(sim_gauss))
row_labels = (
tuple(sim_cont) + tuple(f"{p}1" for p in tuple(sim_line)) + ("C",) + tuple(f"{p}2" for p in tuple(sim_gauss))
)
column_labels = ("True Values", "Guess Values", "Scipy Fit", "Astropy Fit")
true_vals = np.array(tuple(sim_cont.values()) + tuple(sim_line.values()) + tuple(sim_gauss.values()))
guess_vals = np.array(tuple(guess_cont.values()) + tuple(guess_line.values()) + tuple(guess_gauss.values()))
true_vals = tuple(sim_cont.values()) + tuple(sim_line.values()) + (1 * u.m,) + tuple(sim_gauss.values())
true_vals = [t.value for t in true_vals]
guess_vals = tuple(guess_cont.values()) + tuple(guess_line.values()) + (1 * u.m,) + tuple(guess_gauss.values())
guess_vals = [g.value for g in guess_vals]
scipy_vals = opt_res.x
astropy_vals = astropy_fitted_result.parameters
cell_vals = np.vstack((true_vals, guess_vals, scipy_vals, astropy_vals)).T
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ authors = [
{ name = "The SunPy Community", email = "[email protected]" },
]
dependencies = [
"astropy @ git+https://github.com/samaloney/[email protected]",
"corner>=2.2",
"emcee>=3.1",
"matplotlib>=3.7",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
__all__ = ["minimize_func"]


def minimize_func(params, data_y, model_x, model_func, statistic_func):
def minimize_func(params, obs_spec, model_func, statistic_func):
"""
Minimization function.

Expand All @@ -32,5 +32,5 @@ def minimize_func(params, data_y, model_x, model_func, statistic_func):
`float`
The value to be optimized that compares the model to the data.
"""
model_y = model_func.evaluate(model_x, *params)
return statistic_func(data_y, model_y)
model_y = model_func.evaluate(obs_spec._spectral_axis.value, *params)
return statistic_func(obs_spec.data, model_y)
13 changes: 7 additions & 6 deletions sunkit_spex/fitting/tests/test_objective_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,33 +4,34 @@

import numpy as np

import astropy.units as u

from sunkit_spex.fitting.objective_functions.optimising_functions import minimize_func
from sunkit_spex.fitting.statistics.gaussian import chi_squared
from sunkit_spex.models.models import StraightLineModel
from sunkit_spex.spectrum import Spectrum


def test_minimize_func():
"""Test the `minimize_func` function against known outputs."""
sim_x0 = np.arange(3)
sim_x0 = np.arange(3) * u.keV
model_params0 = {"slope": 1, "intercept": 0}
sim_model0 = StraightLineModel(**model_params0)
sim_data0 = sim_model0.evaluate(sim_x0, **model_params0)
res0 = minimize_func(
params=tuple(model_params0.values()),
data_y=sim_data0,
model_x=sim_x0,
obs_spec=Spectrum(sim_data0, spectral_axis=sim_x0),
model_func=sim_model0,
statistic_func=chi_squared,
)

sim_x1 = np.arange(3)
sim_x1 = np.arange(3) * u.keV
model_params1 = {"slope": 1, "intercept": 0}
sim_model1 = StraightLineModel(**model_params1)
sim_data1 = sim_model1.evaluate(sim_x1, **model_params1)[::-1]
res1 = minimize_func(
params=tuple(model_params1.values()),
data_y=sim_data1,
model_x=sim_x1,
obs_spec=Spectrum(sim_data1, spectral_axis=sim_x1),
model_func=sim_model1,
statistic_func=chi_squared,
)
Expand Down
15 changes: 13 additions & 2 deletions sunkit_spex/fitting/tests/test_optimizer_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,13 @@
import numpy as np
from numpy.testing import assert_allclose

import astropy.units as u

from sunkit_spex.fitting.objective_functions.optimising_functions import minimize_func
from sunkit_spex.fitting.optimizer_tools.minimizer_tools import scipy_minimize
from sunkit_spex.fitting.statistics.gaussian import chi_squared
from sunkit_spex.models.models import StraightLineModel
from sunkit_spex.spectrum import Spectrum


def test_scipy_minimize():
Expand All @@ -18,14 +21,22 @@ def test_scipy_minimize():
model_param_values0 = tuple(model_params0.values())
sim_model0 = StraightLineModel(**model_params0)
sim_data0 = sim_model0.evaluate(sim_x0, **model_params0)
opt_res0 = scipy_minimize(minimize_func, model_param_values0, (sim_data0, sim_x0, sim_model0, chi_squared))
opt_res0 = scipy_minimize(
minimize_func,
model_param_values0,
(Spectrum(sim_data0 * u.dimensionless_unscaled, spectral_axis=sim_x0 * u.keV), sim_model0, chi_squared),
)

sim_x1 = np.arange(3)
model_params1 = {"slope": 8, "intercept": 5}
model_param_values1 = tuple(model_params1.values())
sim_model1 = StraightLineModel(**model_params1)
sim_data1 = sim_model1.evaluate(sim_x1, **model_params1)
opt_res1 = scipy_minimize(minimize_func, model_param_values1, (sim_data1, sim_x1, sim_model1, chi_squared))
opt_res1 = scipy_minimize(
minimize_func,
model_param_values1,
(Spectrum(sim_data1 * u.dimensionless_unscaled, spectral_axis=sim_x1 * u.keV), sim_model1, chi_squared),
)

assert_allclose(opt_res0.x, model_param_values0, rtol=1e-3)
assert_allclose(opt_res1.x, model_param_values1, rtol=1e-3)
29 changes: 24 additions & 5 deletions sunkit_spex/models/instrument_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,29 @@


class MatrixModel(Fittable1DModel):
def __init__(self, matrix):
self.matrix = Parameter(default=matrix, description="The matrix with which to multiply the input.", fixed=True)
super().__init__()
# input_units = {"x": u.ph}
c = Parameter(fixed=True)

def evaluate(self, model_y):
def __init__(self, matrix, input_axis, output_axis, c):
self._input_units = None
self.inputs_axis = input_axis
self.output_axis = output_axis
self.matrix = matrix
super().__init__(c)
# self.matrix.value = self.matrix.value.flatten()

def evaluate(self, x, c):
# Requires input must have a specific dimensionality
return model_y @ self.matrix
return x @ self.matrix * c

@property
def input_units(self):
return self._input_units

@input_units.setter
def input_units(self, units):
self._input_units = units

@staticmethod
def _parameter_units_for_data_units(inputs_unit, outputs_unit):
return {"c": outputs_unit["y"] / inputs_unit["x"]}
2 changes: 1 addition & 1 deletion sunkit_spex/tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def test_MatrixModel():
"""Test the matrix model contents and compound model behaviour."""
size0 = 3
srm0 = simulate_square_response_matrix(size0)
srm_model0 = MatrixModel(matrix=srm0)
srm_model0 = MatrixModel(matrix=srm0, c=1, input_axis=None, output_axis=None)

assert_array_equal(srm_model0.matrix, srm0)

Expand Down
Loading