Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Type hints in pint.fitter #1885

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG-unreleased.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ the released changes.
- Type hints in `pint.models.timing_model`
### Fixed
- Made `TimingModel.is_binary()` more robust.
- Bug in `Fitter.plot()`
### Removed
- Definition of `@cached_property` to support Python<=3.7
- The broken `data.nanograv.org` URL from the list of solar system ephemeris mirrors
144 changes: 107 additions & 37 deletions src/pint/fitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@

import contextlib
import copy
from typing import List, Literal, Optional
from warnings import warn
from functools import cached_property

Expand All @@ -72,6 +73,7 @@

import pint
import pint.derived_quantities
from pint.models.timing_model import TimingModel
import pint.utils
from pint.exceptions import (
ConvergenceFailure,
Expand Down Expand Up @@ -155,7 +157,13 @@ class Fitter:
``GLSFitter`` is used to compute ``chi2`` for appropriate Residuals objects.
"""

def __init__(self, toas, model, track_mode=None, residuals=None):
def __init__(
self,
toas: TOAs,
model: TimingModel,
track_mode: Optional[Literal["use_pulse_numbers", "nearest"]] = None,
residuals: Optional[Residuals] = None,
):
if not set(model.free_params).issubset(model.fittable_params):
free_unfittable_params = set(model.free_params).difference(
model.fittable_params
Expand All @@ -175,7 +183,7 @@ def __init__(self, toas, model, track_mode=None, residuals=None):
# residuals were provided, we're just going to use them
self.resids_init = residuals
# probably using GLSFitter to compute a chi-squared
self.model = copy.deepcopy(self.model_init)
self.model: TimingModel = copy.deepcopy(self.model_init)
self.resids = copy.deepcopy(self.resids_init)
self.fitresult = []
self.method = None
Expand All @@ -184,8 +192,14 @@ def __init__(self, toas, model, track_mode=None, residuals=None):

@classmethod
def auto(
cls, toas, model, downhill=True, track_mode=None, residuals=None, **kwargs
):
cls,
toas: TOAs,
model: TimingModel,
downhill: bool = True,
track_mode: Optional[Literal["use_pulse_numbers", "nearest"]] = None,
residuals: Optional[Residuals] = None,
**kwargs,
) -> "Fitter":
"""Automatically return the proper :class:`pint.fitter.Fitter` object depending on the TOAs and model.

In general the `downhill` fitters are to be preferred.
Expand Down Expand Up @@ -270,7 +284,7 @@ def auto(
**kwargs,
)

def fit_toas(self, maxiter=None, debug=False):
def fit_toas(self, maxiter: Optional[int] = None, debug: bool = False):
"""Run fitting operation.

This method needs to be implemented by subclasses. All implementations
Expand All @@ -279,7 +293,7 @@ def fit_toas(self, maxiter=None, debug=False):
"""
raise NotImplementedError

def get_summary(self, nodmx=False):
def get_summary(self, nodmx: bool = False) -> str:
"""Return a human-readable summary of the Fitter results.

Parameters
Expand All @@ -302,6 +316,7 @@ def get_summary(self, nodmx=False):
# First, print fit quality metrics
s = f"Fitted model using {self.method} method with {len(self.model.free_params)} free parameters to {self.toas.ntoas} TOAs\n"
if is_wideband:
self.resids_init: WidebandTOAResiduals
s += f"Prefit TOA residuals Wrms = {self.resids_init.toa.rms_weighted()}, Postfit TOA residuals Wrms = {self.resids.toa.rms_weighted()}\n"
s += f"Prefit DM residuals Wrms = {self.resids_init.dm.rms_weighted()}, Postfit DM residuals Wrms = {self.resids.dm.rms_weighted()}\n"
else:
Expand Down Expand Up @@ -405,7 +420,7 @@ def get_summary(self, nodmx=False):
s += "\n" + self.model.get_derived_params()
return s

def get_derived_params(self, returndict=False):
def get_derived_params(self, returndict: bool = False):
"""Return a string with various derived parameters from the fitted model

Parameters
Expand Down Expand Up @@ -433,7 +448,7 @@ def get_derived_params(self, returndict=False):
returndict=returndict,
)

def print_summary(self):
def print_summary(self) -> None:
"""Write a summary of the TOAs to stdout."""
print(self.get_summary())

Expand All @@ -452,41 +467,41 @@ def plot(self):
ax.set_xlabel("MJD")
ax.set_ylabel("Residuals")
try:
psr = self.model.PSR
psr = self.model["PSR"].value
except AttributeError:
psr = self.model.PSRJ
psr = self.model["PSRJ"].value
else:
psr = "Residuals"
ax.set_title(psr)
ax.grid(True)
plt.show()

def update_model(self, chi2=None):
def update_model(self, chi2: Optional[float] = None):
"""Update the model to reflect fit results and TOA properties.

This is called by ``fit_toas`` to ensure that parameters like
``START``, ``FINISH``, ``EPHEM``, and ``DMDATA`` are set in the model
to reflect the TOAs in actual use.
"""
self.model.START.value = self.toas.first_MJD
self.model.FINISH.value = self.toas.last_MJD
self.model.NTOA.value = len(self.toas)
self.model.EPHEM.value = self.toas.ephem
self.model.DMDATA.value = hasattr(self.resids, "dm")
self.model.CLOCK.value = (
self.model["START"].value = self.toas.first_MJD
self.model["FINISH"].value = self.toas.last_MJD
self.model["NTOA"].value = len(self.toas)
self.model["EPHEM"].value = self.toas.ephem
self.model["DMDATA"].value = hasattr(self.resids, "dm")
self.model["CLOCK"].value = (
f"TT({self.toas.clock_corr_info['bipm_version']})"
if self.toas.clock_corr_info["include_bipm"]
else "TT(TAI)"
)
if chi2 is not None:
# assume a fit has been done
self.model.CHI2.value = chi2
self.model.CHI2R.value = chi2 / self.resids.dof
self.model["CHI2"].value = chi2
self.model["CHI2R"].value = chi2 / self.resids.dof
if not self.is_wideband:
self.model.TRES.quantity = self.resids.rms_weighted()
self.model["TRES"].quantity = self.resids.rms_weighted()
else:
self.model.TRES.quantity = self.resids.rms_weighted()["toa"]
self.model.DMRES.quantity = self.resids.rms_weighted()["dm"]
self.model["TRES"].quantity = self.resids.rms_weighted()["toa"]
self.model["DMRES"].quantity = self.resids.rms_weighted()["dm"]

def reset_model(self):
"""Reset the current model to the initial model."""
Expand All @@ -501,7 +516,7 @@ def update_resids(self):
"""
self.resids = self.make_resids(self.model)

def make_resids(self, model):
def make_resids(self, model: TimingModel):
return Residuals(toas=self.toas, model=model, track_mode=self.track_mode)

def get_designmatrix(self):
Expand Down Expand Up @@ -831,10 +846,12 @@ class ModelState:
These objects should be regarded as immutable but lazily evaluated.
"""

def __init__(self, fitter, model):
def __init__(self, fitter: Fitter, model: TimingModel):
self.fitter = fitter
self.model = model

self.params: List[str]

@cached_property
def resids(self):
try:
Expand All @@ -855,7 +872,7 @@ def step(self):
raise NotImplementedError

@cached_property
def parameter_covariance_matrix(self):
def parameter_covariance_matrix(self) -> CovarianceMatrix:
raise NotImplementedError

@property
Expand All @@ -878,7 +895,7 @@ def take_step_model(self, step, lambda_=1):
try:
with contextlib.suppress(ValueError):
log.trace(f"Adjusting {getattr(self.model, p)} by {s}")
pm = getattr(new_model, p)
pm = new_model[p]
if pm.value is None:
pm.value = 0
pm.value += s
Expand All @@ -890,7 +907,7 @@ def take_step_model(self, step, lambda_=1):
log.warning(f"Unexpected parameter {p}")
return new_model

def take_step(self, step, lambda_):
def take_step(self, step, lambda_) -> "ModelState":
"""Return a new state moved by lambda_*step."""
raise NotImplementedError

Expand All @@ -906,7 +923,13 @@ class DownhillFitter(Fitter):
for correlated or uncorrelated TOA errors and narrowband or wideband TOAs.
"""

def __init__(self, toas, model, track_mode=None, residuals=None):
def __init__(
self,
toas: TOAs,
model: TimingModel,
track_mode: Optional[Literal["use_pulse_numbers", "nearest"]] = None,
residuals=None,
):
super().__init__(
toas=toas, model=model, residuals=residuals, track_mode=track_mode
)
Expand Down Expand Up @@ -1005,7 +1028,7 @@ def _fit_toas(
# I don't know why this fails with multiprocessing, but bypass if it does
with contextlib.suppress(ValueError):
log.trace(f"Setting {getattr(self.model, p)} uncertainty to {e}")
pm = getattr(self.model, p)
pm = self.model[p]
except AttributeError:
if p != "Offset":
log.warning(f"Unexpected parameter {p}")
Expand Down Expand Up @@ -1123,6 +1146,9 @@ def fit_toas(
def fac(self):
return self.current_state.fac

def create_state(self) -> ModelState:
raise NotImplementedError

def _get_free_noise_params(self):
"""Returns a list of all free noise parameters."""
return [
Expand Down Expand Up @@ -1299,7 +1325,13 @@ class DownhillWLSFitter(DownhillFitter):
or :class:`pint.fitter.DownhillFitter`.
"""

def __init__(self, toas, model, track_mode=None, residuals=None):
def __init__(
self,
toas,
model,
track_mode: Optional[Literal["use_pulse_numbers", "nearest"]] = None,
residuals=None,
):
if model.has_correlated_errors:
raise CorrelatedErrors(model)
super().__init__(
Expand Down Expand Up @@ -1449,7 +1481,13 @@ class DownhillGLSFitter(DownhillFitter):

# FIXME: do something clever to efficiently compute chi-squared

def __init__(self, toas, model, track_mode=None, residuals=None):
def __init__(
self,
toas,
model,
track_mode: Optional[Literal["use_pulse_numbers", "nearest"]] = None,
residuals=None,
):
if not model.has_correlated_errors:
log.info(
"Model does not appear to have correlated errors so the GLS fitter "
Expand Down Expand Up @@ -1731,7 +1769,14 @@ class WidebandDownhillFitter(DownhillFitter):

# FIXME: do something clever to efficiently compute chi-squared

def __init__(self, toas, model, track_mode=None, residuals=None, add_args=None):
def __init__(
self,
toas,
model,
track_mode: Optional[Literal["use_pulse_numbers", "nearest"]] = None,
residuals=None,
add_args=None,
):
self.method = "downhill_wideband"
self.full_cov = False
self.threshold = 0
Expand Down Expand Up @@ -1820,7 +1865,13 @@ class PowellFitter(Fitter):
may serve as an example of how to write your own fitting procedure.
"""

def __init__(self, toas, model, track_mode=None, residuals=None):
def __init__(
self,
toas,
model,
track_mode: Optional[Literal["use_pulse_numbers", "nearest"]] = None,
residuals=None,
):
super().__init__(toas, model, residuals=residuals, track_mode=track_mode)
self.method = "Powell"

Expand Down Expand Up @@ -1866,7 +1917,13 @@ class WLSFitter(Fitter):
close enough that the derivatives are a good approximation.
"""

def __init__(self, toas, model, track_mode=None, residuals=None):
def __init__(
self,
toas,
model,
track_mode: Optional[Literal["use_pulse_numbers", "nearest"]] = None,
residuals=None,
):
super().__init__(
toas=toas, model=model, residuals=residuals, track_mode=track_mode
)
Expand Down Expand Up @@ -2005,7 +2062,13 @@ class GLSFitter(Fitter):
the data covariance matrix.
"""

def __init__(self, toas=None, model=None, track_mode=None, residuals=None):
def __init__(
self,
toas=None,
model=None,
track_mode: Optional[Literal["use_pulse_numbers", "nearest"]] = None,
residuals=None,
):
super().__init__(
toas=toas, model=model, residuals=residuals, track_mode=track_mode
)
Expand Down Expand Up @@ -2226,7 +2289,7 @@ def __init__(
fit_data,
model,
fit_data_names=["toa", "dm"],
track_mode=None,
track_mode: Optional[Literal["use_pulse_numbers", "nearest"]] = None,
additional_args={},
):
self.model_init = model
Expand Down Expand Up @@ -2681,7 +2744,14 @@ class WidebandLMFitter(LMFitter):
Unfortunately it doesn't.
"""

def __init__(self, toas, model, track_mode=None, residuals=None, add_args=None):
def __init__(
self,
toas: TOAs,
model: TimingModel,
track_mode: Optional[Literal["use_pulse_numbers", "nearest"]] = None,
residuals: Optional[Residuals] = None,
add_args=None,
):
self.method = "downhill_wideband"
self.full_cov = False
self.threshold = 0
Expand Down
6 changes: 2 additions & 4 deletions src/pint/models/timing_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2464,8 +2464,7 @@ def compare(
value2[pn] = str(otherpar.quantity)
if otherpar.quantity != par.quantity:
log.info(
"Parameter %s not fit, but has changed between these models"
% par.name
f"Parameter {par.name} not fit, but has changed between these models"
)
else:
value2[pn] = "Missing"
Expand Down Expand Up @@ -2566,8 +2565,7 @@ def compare(
)
else:
log.warning(
"Parameter %s not fit, but has changed between these models"
% par.name
f"Parameter {par.name} not fit, but has changed between these models"
)
modifier[pn].append("change")
if (
Expand Down
4 changes: 3 additions & 1 deletion src/pint/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2865,7 +2865,9 @@ def get_unit(parname: str) -> u.Unit:
return ac.param_to_unit(parname)


def normalize_designmatrix(M, params):
def normalize_designmatrix(
M: np.ndarray, params: List[str]
) -> Tuple[np.ndarray, np.ndarray]:
"""Normalize each row of the design matrix.

This is used while computing the GLS chi2 and the GLS fitting step. The
Expand Down