diff --git a/.github/workflows/ci_test.yml b/.github/workflows/ci_test.yml index f43615c9c..74c8ba9f5 100644 --- a/.github/workflows/ci_test.yml +++ b/.github/workflows/ci_test.yml @@ -29,6 +29,9 @@ jobs: - os: ubuntu-latest python: '3.12' tox_env: 'black' + - os: ubuntu-latest + python: '3.12' + tox_env: 'mypy' - os: ubuntu-latest python: '3.12' tox_env: 'py312-test-cov' diff --git a/.gitignore b/.gitignore index c95fdb4b1..549cd5bea 100644 --- a/.gitignore +++ b/.gitignore @@ -139,3 +139,6 @@ tests/datafile/par_*.par tests/datafile/fake_toas.tim tests/datafile/*.converted.par tests/datafile/_test_pintempo.out + +# mypy +.mypy_cache \ No newline at end of file diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6c2c4b8cc..de5ce1e24 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -6,6 +6,6 @@ repos: - id: check-merge-conflict - id: check-symlinks - repo: https://github.com/psf/black - rev: 24.2.0 + rev: 23.12.1 hooks: - id: black diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 000000000..34e80cf34 --- /dev/null +++ b/mypy.ini @@ -0,0 +1,56 @@ + +[mypy] +warn_unused_configs = True +files = src/pint +# 3.8 causes a problem with some versions of matplotlib +python_version = 3.9 +warn_unreachable = True +warn_return_any = True +local_partial_types = True +no_implicit_reexport = True +strict_equality = True + +[mypy-pint.observatory.*] +allow_untyped_globals = True +warn_unreachable = False +warn_return_any = False + +[mypy-pint.extern.*] +; external code, don't worry about it +ignore_errors = True + +; Other libraries that might not have type information +; some of them seem like they should? maybe we need new versions? + +[mypy-astropy.*] +ignore_missing_imports = True + +[mypy-erfa] +ignore_missing_imports = True + +[mypy-emcee] +ignore_missing_imports = True + +[mypy-jplephem] +ignore_missing_imports = True + +[mypy-numdifftools] +ignore_missing_imports = True + +[mypy-pylab] +ignore_missing_imports = True + +[mypy-scipy.*] +ignore_missing_imports = True + +[mypy-uncertainties.*] +ignore_missing_imports = True + +[mypy-fftfit] +ignore_missing_imports = True + +[mypy-pathos.*] +ignore_missing_imports = True + +[mypy-corner] +ignore_missing_imports = True \ No newline at end of file diff --git a/requirements_dev.txt b/requirements_dev.txt index a00511840..90348b340 100644 --- a/requirements_dev.txt +++ b/requirements_dev.txt @@ -41,3 +41,7 @@ loguru gprof2dot py-cpuinfo pytest-xdist +mypy==1.8.0 +GitPython +types-setuptools +types-tqdm diff --git a/src/pint/derived_quantities.py b/src/pint/derived_quantities.py index 104b73d8d..a6853fd75 100644 --- a/src/pint/derived_quantities.py +++ b/src/pint/derived_quantities.py @@ -1,5 +1,4 @@ -"""Functions to compute various derived quantities from pulsar spin parameters, masses, etc. -""" +"""Functions to compute various derived quantities from pulsar spin parameters, masses, etc.""" import astropy.constants as const import astropy.units as u @@ -34,7 +33,7 @@ p=[u.Hz, u.s], pd=[u.Hz / u.s, u.s / u.s], pdd=[u.Hz / u.s**2, u.s / u.s**2] ) def p_to_f(p, pd, pdd=None): - """Converts P, Pdot to F, Fdot (or vice versa) + r"""Converts P, Pdot to F, Fdot (or vice versa) Convert period, period derivative and period second derivative (if supplied) to the equivalent frequency counterparts. @@ -81,7 +80,7 @@ def p_to_f(p, pd, pdd=None): pdorfderr=[u.Hz / u.s, u.s / u.s], ) def pferrs(porf, porferr, pdorfd=None, pdorfderr=None): - """Convert P, Pdot to F, Fdot with uncertainties (or vice versa). + r"""Convert P, Pdot to F, Fdot with uncertainties (or vice versa). Calculate the period or frequency errors and the Pdot or fdot errors from the opposite ones. @@ -127,9 +126,11 @@ def pferrs(porf, porferr, pdorfd=None, pdorfderr=None): return [forp, forperr, fdorpd, fdorpderr] -@u.quantity_input(fo=u.Hz) -def pulsar_age(f: u.Hz, fdot: u.Hz / u.s, n=3, fo=1e99 * u.Hz): - """Compute pulsar characteristic age +@u.quantity_input(f=u.Hz, fdot=u.Hz / u.s, fo=u.Hz) +def pulsar_age( + f: u.Quantity, fdot: u.Quantity, n: float = 3, fo: u.Quantity = 1e99 * u.Hz +): + r"""Compute pulsar characteristic age Return the age of a pulsar given the spin frequency and frequency derivative. By default, the characteristic age @@ -165,14 +166,16 @@ def pulsar_age(f: u.Hz, fdot: u.Hz / u.s, n=3, fo=1e99 * u.Hz): .. math:: - \\tau = \\frac{f}{(n-1)\dot f}\\left(1-\\left(\\frac{f}{f_0}\\right)^{n-1}\\right) + \tau = \frac{f}{(n-1)\dot f}\left(1-\left(\frac{f}{f_0}\right)^{n-1}\right) """ return (-f / ((n - 1.0) * fdot) * (1.0 - (f / fo) ** (n - 1.0))).to(u.yr) -@u.quantity_input(I=u.g * u.cm**2) -def pulsar_edot(f: u.Hz, fdot: u.Hz / u.s, I=1.0e45 * u.g * u.cm**2): - """Compute pulsar spindown energy loss rate +@u.quantity_input(f=u.Hz, fdot=u.Hz / u.s, I=u.g * u.cm**2) +def pulsar_edot( + f: u.Quantity, fdot: u.Quantity, I: u.Quantity = 1.0e45 * u.g * u.cm**2 +): + r"""Compute pulsar spindown energy loss rate Return the pulsar `Edot` (:math:`\dot E`, in erg/s) given the spin frequency `f` and frequency derivative `fdot`. The NS moment of inertia is assumed to be @@ -206,9 +209,9 @@ def pulsar_edot(f: u.Hz, fdot: u.Hz / u.s, I=1.0e45 * u.g * u.cm**2): return (-4.0 * np.pi**2 * I * f * fdot).to(u.erg / u.s) -@u.quantity_input -def pulsar_B(f: u.Hz, fdot: u.Hz / u.s): - """Compute pulsar surface magnetic field +@u.quantity_input(f=u.Hz, fdot=u.Hz / u.s) +def pulsar_B(f: u.Quantity, fdot: u.Quantity): + r"""Compute pulsar surface magnetic field Return the estimated pulsar surface magnetic field strength given the spin frequency and frequency derivative. @@ -241,9 +244,9 @@ def pulsar_B(f: u.Hz, fdot: u.Hz / u.s): return 3.2e19 * u.G * np.sqrt(-fdot.to_value(u.Hz / u.s) / f.to_value(u.Hz) ** 3.0) -@u.quantity_input -def pulsar_B_lightcyl(f: u.Hz, fdot: u.Hz / u.s): - """Compute pulsar magnetic field at the light cylinder +@u.quantity_input(f=u.Hz, fdot=u.Hz / u.s) +def pulsar_B_lightcyl(f: u.Quantity, fdot: u.Quantity): + r"""Compute pulsar magnetic field at the light cylinder Return the estimated pulsar magnetic field strength at the light cylinder given the spin frequency and @@ -285,7 +288,7 @@ def pulsar_B_lightcyl(f: u.Hz, fdot: u.Hz / u.s): @u.quantity_input def mass_funct(pb: u.d, x: u.cm): - """Compute binary mass function from period and semi-major axis + r"""Compute binary mass function from period and semi-major axis Can handle scalar or array inputs. @@ -326,7 +329,7 @@ def mass_funct(pb: u.d, x: u.cm): @u.quantity_input def mass_funct2(mp: u.Msun, mc: u.Msun, i: u.deg): - """Compute binary mass function from masses and inclination + r"""Compute binary mass function from masses and inclination Can handle scalar or array inputs. @@ -371,61 +374,63 @@ def mass_funct2(mp: u.Msun, mc: u.Msun, i: u.deg): @u.quantity_input def pulsar_mass(pb: u.d, x: u.cm, mc: u.Msun, i: u.deg): - """Compute pulsar mass from orbital parameters - - Return the pulsar mass (in solar mass units) for a binary. - Can handle scalar or array inputs. - - Parameters - ---------- - pb : astropy.units.Quantity - Binary orbital period - x : astropy.units.Quantity - Projected pulsar semi-major axis (aka ASINI) in ``pint.ls`` - mc : astropy.units.Quantity - Companion mass in ``u.solMass`` - i : astropy.coordinates.Angle or astropy.units.Quantity - Inclination angle, in ``u.deg`` or ``u.rad`` - - Returns - ------- - mass : astropy.units.Quantity - In ``u.solMass`` - - Raises - ------ - astropy.units.UnitsError - If the input data are not appropriate quantities - TypeError - If the input data are not quantities - - Example - ------- - >>> import pint - >>> import pint.derived_quantities - >>> from astropy import units as u - >>> print(pint.derived_quantities.pulsar_mass(2*u.hr, .2*pint.ls, 0.5*u.Msun, 60*u.deg)) - 7.6018341985817885 solMass - - - Notes - ------- - This forms a quadratic equation of the form: - :math:`a M_p^2 + b M_p + c = 0`` - - with: - - - :math:`a = f(P_b,x)` (the mass function) - - :math:`b = 2 f(P_b,x) M_c` - - :math:`c = f(P_b,x) M_c^2 - M_c\sin^3 i` - - except the discriminant simplifies to: - :math:`4f(P_b,x) M_c^3 \sin^3 i` - - solve it directly - this has to be the positive branch of the quadratic - because the vertex is at :math:`-M_c`, so - the negative branch will always be < 0 + r"""Compute pulsar mass from orbital parameters + + Return the pulsar mass (in solar mass units) for a binary. + Can handle scalar or array inputs. + + Parameters + ---------- + pb : astropy.units.Quantity + Binary orbital period + x : astropy.units.Quantity + Projected pulsar semi-major axis (aka ASINI) in ``pint.ls`` + mc : astropy.units.Quantit[mypy-pint.templates.*] + ; ignore_errors = True + y + Companion mass in ``u.solMass`` + i : astropy.coordinates.Angle or astropy.units.Quantity + Inclination angle, in ``u.deg`` or ``u.rad`` + + Returns + ------- + mass : astropy.units.Quantity + In ``u.solMass`` + + Raises + ------ + astropy.units.UnitsError + If the input data are not appropriate quantities + TypeError + If the input data are not quantities + + Example + ------- + >>> import pint + >>> import pint.derived_quantities + >>> from astropy import units as u + >>> print(pint.derived_quantities.pulsar_mass(2*u.hr, .2*pint.ls, 0.5*u.Msun, 60*u.deg)) + 7.6018341985817885 solMass + + + Notes + ------- + This forms a quadratic equation of the form: + :math:`a M_p^2 + b M_p + c = 0`` + + with: + + - :math:`a = f(P_b,x)` (the mass function) + - :math:`b = 2 f(P_b,x) M_c` + - :math:`c = f(P_b,x) M_c^2 - M_c\sin^3 i` + + except the discriminant simplifies to: + :math:`4f(P_b,x) M_c^3 \sin^3 i` + + solve it directly + this has to be the positive branch of the quadratic + because the vertex is at :math:`-M_c`, so + the negative branch will always be < 0 """ massfunct = mass_funct(pb, x) @@ -438,7 +443,7 @@ def pulsar_mass(pb: u.d, x: u.cm, mc: u.Msun, i: u.deg): @u.quantity_input(inc=u.deg, mpsr=u.solMass) def companion_mass(pb: u.d, x: u.cm, i=60.0 * u.deg, mp=1.4 * u.solMass): - """Commpute the companion mass from the orbital parameters + r"""Commpute the companion mass from the orbital parameters Compute companion mass for a binary system from orbital mechanics, not Shapiro delay. @@ -542,7 +547,7 @@ def companion_mass(pb: u.d, x: u.cm, i=60.0 * u.deg, mp=1.4 * u.solMass): @u.quantity_input def pbdot(mp: u.Msun, mc: u.Msun, pb: u.d, e: u.dimensionless_unscaled): - """Post-Keplerian orbital decay pbdot, assuming general relativity. + r"""Post-Keplerian orbital decay pbdot, assuming general relativity. pbdot (:math:`\dot P_B`) is the change in the binary orbital period due to emission of gravitational waves. @@ -605,7 +610,7 @@ def pbdot(mp: u.Msun, mc: u.Msun, pb: u.d, e: u.dimensionless_unscaled): @u.quantity_input def gamma(mp: u.Msun, mc: u.Msun, pb: u.d, e: u.dimensionless_unscaled): - """Post-Keplerian time dilation and gravitational redshift gamma, assuming general relativity. + r"""Post-Keplerian time dilation and gravitational redshift gamma, assuming general relativity. gamma (:math:`\gamma`) is the amplitude of the modification in arrival times caused by the varying gravitational redshift of the companion and time dilation in an elliptical orbit. The time delay is @@ -661,7 +666,7 @@ def gamma(mp: u.Msun, mc: u.Msun, pb: u.d, e: u.dimensionless_unscaled): @u.quantity_input def omdot(mp: u.Msun, mc: u.Msun, pb: u.d, e: u.dimensionless_unscaled): - """Post-Keplerian longitude of periastron precession rate omdot, assuming general relativity. + r"""Post-Keplerian longitude of periastron precession rate omdot, assuming general relativity. omdot (:math:`\dot \omega`) is the relativistic advance of periastron. Can handle scalar or array inputs. @@ -716,7 +721,7 @@ def omdot(mp: u.Msun, mc: u.Msun, pb: u.d, e: u.dimensionless_unscaled): @u.quantity_input def sini(mp: u.Msun, mc: u.Msun, pb: u.d, x: u.cm): - """Post-Keplerian sine of inclination, assuming general relativity. + r"""Post-Keplerian sine of inclination, assuming general relativity. Can handle scalar or array inputs. @@ -770,7 +775,7 @@ def sini(mp: u.Msun, mc: u.Msun, pb: u.d, x: u.cm): @u.quantity_input def dr(mp: u.Msun, mc: u.Msun, pb: u.d): - """Post-Keplerian Roemer delay term + r"""Post-Keplerian Roemer delay term dr (:math:`\delta_r`) is part of the relativistic deformation of the orbit @@ -820,9 +825,9 @@ def dr(mp: u.Msun, mc: u.Msun, pb: u.d): @u.quantity_input def dth(mp: u.Msun, mc: u.Msun, pb: u.d): - """Post-Keplerian Roemer delay term + r"""Post-Keplerian Roemer delay term - dth (:math:`\delta_{\\theta}`) is part of the relativistic deformation of the orbit + dth (:math:`\delta_{\theta}`) is part of the relativistic deformation of the orbit Parameters ---------- @@ -850,8 +855,8 @@ def dth(mp: u.Msun, mc: u.Msun, pb: u.d): .. math:: - \delta_{\\theta} = T_{\odot}^{2/3} \\left(\\frac{P_b}{2\pi}\\right)^{2/3} - \\frac{3.5 m_p^2+6 m_p m_c +2m_c^2}{(m_p+m_c)^{4/3}} + \delta_{\theta} = T_{\odot}^{2/3} \left(\frac{P_b}{2\pi}\right)^{2/3} + \frac{3.5 m_p^2+6 m_p m_c +2m_c^2}{(m_p+m_c)^{4/3}} with :math:`T_\odot = GM_\odot c^{-3}`. @@ -868,9 +873,13 @@ def dth(mp: u.Msun, mc: u.Msun, pb: u.d): ).decompose() -@u.quantity_input -def omdot_to_mtot(omdot: u.deg / u.yr, pb: u.d, e: u.dimensionless_unscaled): - """Determine total mass from Post-Keplerian longitude of periastron precession rate omdot, +@u.quantity_input(omdot=u.deg / u.yr, pb=u.d, e=u.dimensionless_unscaled) +def omdot_to_mtot( + omdot: u.Quantity, + pb: u.Quantity, + e: u.Quantity, +): + r"""Determine total mass from Post-Keplerian longitude of periastron precession rate omdot, assuming general relativity. omdot (:math:`\dot \omega`) is the relativistic advance of periastron. It relates to the total @@ -904,8 +913,8 @@ def omdot_to_mtot(omdot: u.deg / u.yr, pb: u.d, e: u.dimensionless_unscaled): .. math:: - \dot \omega = 3T_{\odot}^{2/3} \\left(\\frac{P_b}{2\pi}\\right)^{-5/3} - \\frac{1}{1-e^2}(m_p+m_c)^{2/3} + \dot \omega = 3T_{\odot}^{2/3} \left(\frac{P_b}{2\pi}\right)^{-5/3} + \frac{1}{1-e^2}(m_p+m_c)^{2/3} to calculate :math:`m_{\\rm tot} = m_p + m_c`, with :math:`T_\odot = GM_\odot c^{-3}`. @@ -916,14 +925,12 @@ def omdot_to_mtot(omdot: u.deg / u.yr, pb: u.d, e: u.dimensionless_unscaled): """ return ( ( - ( - omdot - / ( - 3 - * (const.G / const.c**3) ** (2.0 / 3) - * (pb / (2 * np.pi)) ** (-5.0 / 3) - * (1 - e**2) ** (-1) - ) + omdot + / ( + 3 + * (const.G / const.c**3) ** (2.0 / 3) + * (pb / (2 * np.pi)) ** (-5.0 / 3) + * (1 - e**2) ** (-1) ) ) ** (3.0 / 2) @@ -932,7 +939,7 @@ def omdot_to_mtot(omdot: u.deg / u.yr, pb: u.d, e: u.dimensionless_unscaled): @u.quantity_input(pb=u.d, mp=u.Msun, mc=u.Msun, i=u.deg) def a1sini(mp, mc, pb, i=90 * u.deg): - """Pulsar's semi-major axis. + r"""Pulsar's semi-major axis. The full semi-major axis is given by Kepler's third law. This is the projection (:math:`\sin i`) of just the pulsar's orbit (:math:`m_c/(m_p+m_c)` @@ -968,8 +975,8 @@ def a1sini(mp, mc, pb, i=90 * u.deg): .. math:: - \\frac{a_p \sin i}{c} = \\frac{m_c \sin i}{(m_p+m_c)^{2/3}} - G^{1/3}\\left(\\frac{P_b}{2\pi}\\right)^{2/3} + \frac{a_p \sin i}{c} = \frac{m_c \sin i}{(m_p+m_c)^{2/3}} + G^{1/3}\left(\frac{P_b}{2\pi}\right)^{2/3} More details in :ref:`Timing Models`. Also see [8]_ @@ -982,9 +989,9 @@ def a1sini(mp, mc, pb, i=90 * u.deg): ).to(pint.ls) -@u.quantity_input -def shklovskii_factor(pmtot: u.mas / u.yr, D: u.kpc): - """Compute magnitude of Shklovskii correction factor. +@u.quantity_input(pmtot=u.mas / u.yr, d=u.kpc) +def shklovskii_factor(pmtot: u.Quantity, D: u.Quantity): + r"""Compute magnitude of Shklovskii correction factor. Computes the Shklovskii correction factor, as defined in Eq 8.12 of Lorimer & Kramer (2005) [10]_ This is the factor by which :math:`\dot P /P` is increased due to the transverse velocity. @@ -993,7 +1000,7 @@ def shklovskii_factor(pmtot: u.mas / u.yr, D: u.kpc): .. math:: - \dot P_{\\rm intrinsic} = \dot P_{\\rm observed} - a_s P + \dot P_{\rm intrinsic} = \dot P_{\rm observed} - a_s P Parameters ---------- @@ -1020,8 +1027,8 @@ def shklovskii_factor(pmtot: u.mas / u.yr, D: u.kpc): return a_s -@u.quantity_input -def dispersion_slope(dm: pint.dmu): +@u.quantity_input(dm=pint.dmu) +def dispersion_slope(dm: u.Quantity): """Compute the dispersion slope. This is equal to DMconst * DM. diff --git a/src/pint/eventstats.py b/src/pint/eventstats.py index b24c39aae..2fb68dbf1 100644 --- a/src/pint/eventstats.py +++ b/src/pint/eventstats.py @@ -47,7 +47,7 @@ def from_array(x): def sig2sigma(sig, two_tailed=True, logprob=False): - """Convert tail probability to "sigma" units. + r"""Convert tail probability to "sigma" units. Find the value of the argument for the normal distribution beyond which the integrated tail probability is sig. Note that the default is to interpret diff --git a/src/pint/fitter.py b/src/pint/fitter.py index c34ed3332..9021bb800 100644 --- a/src/pint/fitter.py +++ b/src/pint/fitter.py @@ -60,6 +60,8 @@ import contextlib import copy +from functools import cached_property +from typing import List, Optional from warnings import warn import astropy.units as u @@ -70,8 +72,10 @@ from numdifftools import Hessian import pint -import pint.utils import pint.derived_quantities +import pint.models +import pint.models.timing_model +import pint.utils from pint.models.parameter import ( AngleParameter, boolParameter, @@ -90,7 +94,6 @@ from pint.toa import TOAs from pint.utils import FTest, normalize_designmatrix - __all__ = [ "Fitter", "WLSFitter", @@ -107,64 +110,6 @@ "MaxiterReached", ] -try: - from functools import cached_property -except ImportError: - # not supported in python 3.7 - # This is just the code from python 3.8 - from _thread import RLock - - _NOT_FOUND = object() - - class cached_property: - def __init__(self, func): - self.func = func - self.attrname = None - self.__doc__ = func.__doc__ - self.lock = RLock() - - def __set_name__(self, owner, name): - if self.attrname is None: - self.attrname = name - elif name != self.attrname: - raise TypeError( - "Cannot assign the same cached_property to two different names " - f"({self.attrname!r} and {name!r})." - ) - - def __get__(self, instance, owner=None): - if instance is None: - return self - if self.attrname is None: - raise TypeError( - "Cannot use cached_property instance without calling __set_name__ on it." - ) - try: - cache = instance.__dict__ - except AttributeError: - # not all objects have __dict__ (e.g. class defines slots) - msg = ( - f"No '__dict__' attribute on {type(instance).__name__!r} " - f"instance to cache {self.attrname!r} property." - ) - raise TypeError(msg) from None - val = cache.get(self.attrname, _NOT_FOUND) - if val is _NOT_FOUND: - with self.lock: - # check if another thread filled cache while we awaited lock - val = cache.get(self.attrname, _NOT_FOUND) - if val is _NOT_FOUND: - val = self.func(instance) - try: - cache[self.attrname] = val - except TypeError: - msg = ( - f"The '__dict__' attribute on {type(instance).__name__!r} instance " - f"does not support item assignment for caching {self.attrname!r} property." - ) - raise TypeError(msg) from None - return val - class DegeneracyWarning(UserWarning): pass @@ -221,7 +166,41 @@ class Fitter: ``GLSFitter`` is used to compute ``chi2`` for appropriate Residuals objects. """ - def __init__(self, toas, model, track_mode=None, residuals=None): + toas: TOAs + """TOAs to fit.""" + model_init: pint.models.timing_model.TimingModel + """Initial timing model the Fitter was created with.""" + track_mode: Optional[str] + """How to handle phase wrapping. + + This is used when creating :class:`pint.residuals.Residuals` + objects, and its meaning is defined there. + """ + resids_init: Residuals + """Initial residuals with respect to the timing model.""" + model: pint.models.timing_model.TimingModel + """Current timing model in use by the Fitter.""" + fitresult: List + method: Optional[str] + is_wideband: bool + converged: bool + parameter_covariance_matrix: CovarianceMatrix + """The covariance matrix of the model parameters after fitting. + + This attribute may not exist if the fitter has not been run + (some subclasses of Fitter don't compute this matrix except + as part of the fit, and don't create the attribute). + """ + fac: np.ndarray + """Scaling factors applied to the columns(?) of the design matrix.""" + + def __init__( + self, + toas: TOAs, + model: pint.models.TimingModel, + track_mode: Optional[str] = None, + residuals: Optional[Residuals] = None, + ): if not set(model.free_params).issubset(model.fittable_params): free_unfittable_params = set(model.free_params).difference( model.fittable_params @@ -490,9 +469,11 @@ def get_derived_params(self, returndict=False): """ return self.model.get_derived_params( - rms=self.resids.toa.rms_weighted() - if self.is_wideband - else self.resids.rms_weighted(), + rms=( + self.resids.toa.rms_weighted() + if self.is_wideband + else self.resids.rms_weighted() + ), ntoas=self.toas.ntoas, returndict=returndict, ) @@ -1082,17 +1063,19 @@ def _fit_toas( self.parameter_covariance_matrix.to_correlation_matrix() ) - for p, e in zip(self.current_state.params, self.errors): + for p, error in zip(self.current_state.params, self.errors): try: # I don't know why this fails with multiprocessing, but bypass if it does with contextlib.suppress(ValueError): - log.trace(f"Setting {getattr(self.model, p)} uncertainty to {e}") + log.trace( + f"Setting {getattr(self.model, p)} uncertainty to {error}" + ) pm = getattr(self.model, p) except AttributeError: if p != "Offset": log.warning(f"Unexpected parameter {p}") else: - pm.uncertainty = e * pm.units + pm.uncertainty = error * pm.units self.update_model(self.current_state.chi2) if exception is not None: raise StepProblem( diff --git a/src/pint/gridutils.py b/src/pint/gridutils.py index 4a41d1c1d..521793b8c 100644 --- a/src/pint/gridutils.py +++ b/src/pint/gridutils.py @@ -1,4 +1,5 @@ """Tools for building chi-squared grids.""" + import concurrent.futures import copy import multiprocessing @@ -12,7 +13,7 @@ try: from tqdm import tqdm except ModuleNotFoundError: - tqdm = None + tqdm = None # type: ignore from astropy.utils.console import ProgressBar @@ -164,7 +165,7 @@ def grid_chisq( printprogress=True, **fitargs, ): - """Compute chisq over a grid of parameters + r"""Compute chisq over a grid of parameters Parameters ---------- diff --git a/src/pint/logging.py b/src/pint/logging.py index 0f4f67021..f1d367f8c 100644 --- a/src/pint/logging.py +++ b/src/pint/logging.py @@ -53,9 +53,10 @@ import re import sys import warnings -from loguru import logger as log +from typing import Dict, Tuple from erfa import ErfaWarning +from loguru import logger as log __all__ = ["LogFilter", "setup", "format", "levels", "get_level"] @@ -72,7 +73,7 @@ # https://loguru.readthedocs.io/en/stable/api/logger.html#color showwarning_ = warnings.showwarning -warning_onceregistry = {} +warning_onceregistry: Dict[Tuple[str, str], int] = {} # basic loguru level definitions from: # https://loguru.readthedocs.io/en/stable/api/logger.html @@ -124,12 +125,13 @@ def showwarning(message, category, filename, lineno, file=None, line=None): class LogFilter: """Custom logging filter for ``loguru``. + Define some messages that are never seen (e.g., Deprecation Warnings). Others that will only be seen once. Filtering of those is done on the basis of regular expressions. """ def __init__(self, onlyonce=None, never=None, onlyonce_level="INFO"): - """ + r""" Define regexs for messages that will only be seen once. Use ``\S+`` for a variable that might change. If a message comes through with a new value for that variable, it will be seen. diff --git a/src/pint/models/__init__.py b/src/pint/models/__init__.py index e08e69965..f55d4543c 100644 --- a/src/pint/models/__init__.py +++ b/src/pint/models/__init__.py @@ -45,6 +45,45 @@ from pint.models.wave import Wave from pint.models.wavex import WaveX +__all__ = [ + "AbsPhase", + "AstrometryEcliptic", + "AstrometryEquatorial", + "BinaryBT", + "BinaryBTPiecewise", + "BinaryDD", + "BinaryDDS", + "BinaryDDGR", + "BinaryDDK", + "BinaryELL1", + "BinaryELL1H", + "BinaryELL1k", + "DelayJump", + "DispersionDM", + "DispersionDMX", + "DMWaveX", + "EcorrNoise", + "FD", + "FDJump", + "Glitch", + "IFunc", + "PhaseJump", + "PiecewiseSpindown", + "PLRedNoise", + "ScaleToaError", + "SolarSystemShapiro", + "SolarWindDispersion", + "SolarWindDispersionX", + "Spindown", + "TroposphereDelay", + "Wave", + "WaveX", + "get_model", + "get_model_and_toas", + "TimingModel", + "DEFAULT_ORDER", +] + # Define a standard basic model StandardTimingModel = TimingModel( "StandardTimingModel", diff --git a/src/pint/models/binary_dd.py b/src/pint/models/binary_dd.py index b07df462f..adec86b4a 100644 --- a/src/pint/models/binary_dd.py +++ b/src/pint/models/binary_dd.py @@ -127,7 +127,7 @@ def validate(self): class BinaryDDS(BinaryDD): - """Damour and Deruelle model with alternate Shapiro delay parameterization. + r"""Damour and Deruelle model with alternate Shapiro delay parameterization. This extends the :class:`pint.models.binary_dd.BinaryDD` model with :math:`SHAPMAX = -\log(1-s)` instead of just :math:`s=\sin i`, which behaves better diff --git a/src/pint/models/binary_ddk.py b/src/pint/models/binary_ddk.py index 6b40f5ebc..1b5aa9e52 100644 --- a/src/pint/models/binary_ddk.py +++ b/src/pint/models/binary_ddk.py @@ -41,7 +41,7 @@ def _convert_kom(kom): class BinaryDDK(BinaryDD): - """Damour and Deruelle model with kinematics. + r"""Damour and Deruelle model with kinematics. This extends the :class:`pint.models.binary_dd.BinaryDD` model with "Shklovskii" and "Kopeikin" terms that account for the finite distance @@ -220,14 +220,14 @@ def validate(self): warnings.warn("Using A1DOT with a DDK model is not advised.") def alternative_solutions(self): - """Alternative Kopeikin solutions (potential local minima) + r"""Alternative Kopeikin solutions (potential local minima) There are 4 potential local minima for a DDK model where a1dot is the same These are given by where Eqn. 8 in Kopeikin (1996) is equal to the best-fit value. We first define the symmetry point where a1dot is zero (in equatorial coordinates): - :math:`KOM_0 = \\tan^{-1} (\mu_{\delta} / \mu_{\\alpha})` + :math:`KOM_0 = \tan^{-1} (\mu_{\delta} / \mu_{\alpha})` The solutions are then: diff --git a/src/pint/models/pulsar_binary.py b/src/pint/models/pulsar_binary.py index b27dc923c..dd9231797 100644 --- a/src/pint/models/pulsar_binary.py +++ b/src/pint/models/pulsar_binary.py @@ -38,7 +38,7 @@ class PulsarBinary(DelayComponent): - """Base class for binary models in PINT. + r"""Base class for binary models in PINT. This class provides a wrapper for internal classes that do the actual calculations. The calculations are done by the classes located in diff --git a/src/pint/models/stand_alone_psr_binaries/DDGR_model.py b/src/pint/models/stand_alone_psr_binaries/DDGR_model.py index 85c40fd39..86e5e8825 100644 --- a/src/pint/models/stand_alone_psr_binaries/DDGR_model.py +++ b/src/pint/models/stand_alone_psr_binaries/DDGR_model.py @@ -1,4 +1,5 @@ """The DDGR model - Damour and Deruelle with GR assumed""" + import astropy.constants as c import astropy.units as u import numpy as np @@ -618,29 +619,29 @@ def d_beta_d_M2(self): * self.d_omega_d_M2() ) - @SINI.setter + @SINI.setter # type: ignore[no-redef, attr-defined] def SINI(self, val): log.debug( "DDGR model uses MTOT to derive the inclination angle. SINI will not be used." ) - @PBDOT.setter + @PBDOT.setter # type: ignore[no-redef, attr-defined] def PBDOT(self, val): log.debug("DDGR model uses MTOT to derive PBDOT. PBDOT will not be used.") - @OMDOT.setter + @OMDOT.setter # type: ignore[no-redef, attr-defined] def OMDOT(self, val): log.debug("DDGR model uses MTOT to derive OMDOT. OMDOT will not be used.") - @GAMMA.setter + @GAMMA.setter # type: ignore[no-redef, attr-defined] def GAMMA(self, val): log.debug("DDGR model uses MTOT to derive GAMMA. GAMMA will not be used.") - @DR.setter + @DR.setter # type: ignore[no-redef, attr-defined] def DR(self, val): log.debug("DDGR model uses MTOT to derive Dr. Dr will not be used.") - @DTH.setter + @DTH.setter # type: ignore[no-redef, attr-defined] def DTH(self, val): log.debug("DDGR model uses MTOT to derive Dth. Dth will not be used.") diff --git a/src/pint/models/stand_alone_psr_binaries/DDH_model.py b/src/pint/models/stand_alone_psr_binaries/DDH_model.py index 3c4d59dcb..5fe95bcf6 100644 --- a/src/pint/models/stand_alone_psr_binaries/DDH_model.py +++ b/src/pint/models/stand_alone_psr_binaries/DDH_model.py @@ -1,4 +1,5 @@ """The DDS model - Damour and Deruelle with alternate Shapiro delay parametrization.""" + import astropy.constants as c import astropy.units as u import numpy as np @@ -69,13 +70,13 @@ def SINI(self): def M2(self): return self.H3 / self.STIGMA**3 / Tsun.value - @SINI.setter + @SINI.setter # type: ignore[no-redef, attr-defined] def SINI(self, val): log.debug( "DDH model uses H3/STIGMA as Shapiro delay parameter. SINI will not be used." ) - @M2.setter + @M2.setter # type: ignore[no-redef, attr-defined] def M2(self, val): log.debug( "DDH model uses H3/STIGMA as Shapiro delay parameter. M2 will not be used." diff --git a/src/pint/models/stand_alone_psr_binaries/DDK_model.py b/src/pint/models/stand_alone_psr_binaries/DDK_model.py index 709033aae..a586400b0 100644 --- a/src/pint/models/stand_alone_psr_binaries/DDK_model.py +++ b/src/pint/models/stand_alone_psr_binaries/DDK_model.py @@ -8,7 +8,7 @@ class DDKmodel(DDmodel): - """DDK model, a Kopeikin method corrected DD model. + r"""DDK model, a Kopeikin method corrected DD model. The main difference is that DDK model considers the effects on the pulsar binary parameters from the annual parallax of earth and the proper motion of the pulsar. @@ -155,7 +155,7 @@ def SINI(self, val): # Update binary parameters due to the pulser proper motion def delta_kin_proper_motion(self): - """The time dependent inclination angle + r"""The time dependent inclination angle (Kopeikin 1996 Eq 10): .. math:: @@ -231,7 +231,7 @@ def d_kin_d_par(self, par): return func() def delta_a1_proper_motion(self): - """The correction on a1 (projected semi-major axis) + r"""The correction on a1 (projected semi-major axis) due to the pulsar proper motion (Kopeikin 1996 Eq 8): @@ -289,7 +289,7 @@ def d_delta_a1_proper_motion_d_T0(self): return d_delta_a1_proper_motion_d_T0.to(a1.unit / self.T0.unit) def delta_omega_proper_motion(self): - """The correction on omega (Longitude of periastron) + r"""The correction on omega (Longitude of periastron) due to the pulsar proper motion (Kopeikin 1996 Eq 9): @@ -353,7 +353,7 @@ def d_delta_omega_proper_motion_d_T0(self): # Reference KOPEIKIN. 1995 Eq 18 -> Eq 19. def delta_I0(self): - """ + r""" :math:`\Delta_{I0}` Reference: (Kopeikin 1995 Eq 15) @@ -361,7 +361,7 @@ def delta_I0(self): return -self.obs_pos[:, 0] * self.sin_long + self.obs_pos[:, 1] * self.cos_long def delta_J0(self): - """ + r""" :math:`\Delta_{J0}` Reference: (Kopeikin 1995 Eq 16) @@ -373,19 +373,19 @@ def delta_J0(self): ) def delta_sini_parallax(self): - """Reference (Kopeikin 1995 Eq 18). Computes: + r"""Reference (Kopeikin 1995 Eq 18). Computes: .. math:: - x_{obs} = \\frac{a_p \sin(i)_{obs}}{c} + x_{obs} = \frac{a_p \sin(i)_{obs}}{c} Since :math:`a_p` and :math:`c` will not be changed by parallax: .. math:: - x_{obs} = \\frac{a_p}{c}(\sin(i)_{\\rm intrisic} + \delta_{\sin(i)}) + x_{obs} = \frac{a_p}{c}(\sin(i)_{\rm intrisic} + \delta_{\sin(i)}) - \delta_{\sin(i)} = \sin(i)_{\\rm intrisic} \\frac{\cot(i)_{\\rm intrisic}}{d} (\Delta_{I0} \sin KOM - \Delta_{J0} \cos KOM) + \delta_{\sin(i)} = \sin(i)_{\rm intrisic} \frac{\cot(i)_{\rm intrisic}}{d} (\Delta_{I0} \sin KOM - \Delta_{J0} \cos KOM) """ PX_kpc = self.PX.to(u.kpc, equivalencies=u.parallax()) diff --git a/src/pint/models/stand_alone_psr_binaries/DDS_model.py b/src/pint/models/stand_alone_psr_binaries/DDS_model.py index 1fc81fc56..db572f7f6 100644 --- a/src/pint/models/stand_alone_psr_binaries/DDS_model.py +++ b/src/pint/models/stand_alone_psr_binaries/DDS_model.py @@ -10,7 +10,7 @@ class DDSmodel(DDmodel): - """Damour and Deruelle model with alternate Shapiro delay parameterization. + r"""Damour and Deruelle model with alternate Shapiro delay parameterization. This extends the :class:`pint.models.binary_dd.BinaryDD` model with :math:`SHAPMAX = -\log(1-s)` instead of just :math:`s=\sin i`, which behaves better diff --git a/src/pint/models/stand_alone_psr_binaries/ELL1H_model.py b/src/pint/models/stand_alone_psr_binaries/ELL1H_model.py index 7a8de20e1..35930dbe8 100644 --- a/src/pint/models/stand_alone_psr_binaries/ELL1H_model.py +++ b/src/pint/models/stand_alone_psr_binaries/ELL1H_model.py @@ -9,7 +9,7 @@ class ELL1Hmodel(ELL1BaseModel): - """ELL1H pulsar binary model using H3, H4 or STIGMA as shapiro delay parameters. + r"""ELL1H pulsar binary model using H3, H4 or STIGMA as shapiro delay parameters. Note ---- @@ -21,7 +21,7 @@ class ELL1Hmodel(ELL1BaseModel): .. math:: - \\Delta_S = -2r \\left( \\frac{a_0}{2} + \\Sum_k (a_k \\cos k\\phi + b_k \\sin k \phi) \\right) + \Delta_S = -2r \left( \frac{a_0}{2} + \Sum_k (a_k \cos k\phi + b_k \sin k \phi) \right) The first two harmonics are generlly absorbed by the ELL1 Roemer delay. Thus, :class:`~pint.models.binary_ell1.BinaryELL1H` uses the series from the third diff --git a/src/pint/models/timing_model.py b/src/pint/models/timing_model.py index 06e3e1ead..7b8807624 100644 --- a/src/pint/models/timing_model.py +++ b/src/pint/models/timing_model.py @@ -24,53 +24,53 @@ :func:`~pint.models.model_builder.get_model`. See :ref:`Timing Models` for more details on how PINT's timing models work. - """ import abc +import contextlib import copy import inspect -import contextlib from collections import OrderedDict, defaultdict from functools import wraps +from typing import Dict from warnings import warn -from uncertainties import ufloat +import astropy.coordinates as coords import astropy.time as time -from astropy import units as u, constants as c import numpy as np +from astropy import constants as c +from astropy import units as u from astropy.utils.decorators import lazyproperty -import astropy.coordinates as coords -from pint.pulsar_ecliptic import OBL, PulsarEcliptic -from scipy.optimize import brentq from loguru import logger as log +from scipy.optimize import brentq +from uncertainties import ufloat import pint +from pint.derived_quantities import dispersion_slope from pint.models.parameter import ( - _parfile_formats, AngleParameter, MJDParameter, Parameter, + _parfile_formats, boolParameter, floatParameter, funcParameter, intParameter, maskParameter, - strParameter, prefixParameter, + strParameter, ) from pint.phase import Phase +from pint.pulsar_ecliptic import OBL, PulsarEcliptic from pint.toa import TOAs from pint.utils import ( PrefixError, - split_prefixed_name, - open_or_use, colorize, + open_or_use, + split_prefixed_name, xxxselections, ) -from pint.derived_quantities import dispersion_slope - __all__ = [ "DEFAULT_ORDER", @@ -490,10 +490,10 @@ def num_components_of_type(type): ), "Model can have at most one solar wind dispersion component." from pint.models.dispersion_model import DispersionDMX + from pint.models.dmwavex import DMWaveX + from pint.models.noise_model import PLDMNoise, PLRedNoise from pint.models.wave import Wave from pint.models.wavex import WaveX - from pint.models.dmwavex import DMWaveX - from pint.models.noise_model import PLRedNoise, PLDMNoise if num_components_of_type((DispersionDMX, PLDMNoise, DMWaveX)) > 1: log.warning( @@ -612,7 +612,9 @@ def free_params(self): """ return [p for p in self.params if not getattr(self, p).frozen] - @free_params.setter + # mypy doesn't understand the decorator syntax here + # maybe we'd need to express the type of property_exists better? + @free_params.setter # type: ignore def free_params(self, params): params_true = {self.match_param_aliases(p) for p in params} for p in self.params: @@ -635,10 +637,8 @@ def fittable_params(self): p in self.phase_deriv_funcs or p in self.delay_deriv_funcs or ( - ( - hasattr(self, "toasigma_deriv_funcs") - and p in self.toasigma_deriv_funcs - ) + hasattr(self, "toasigma_deriv_funcs") + and p in self.toasigma_deriv_funcs ) or (hasattr(self[p], "prefix") and self[p].prefix == "ECORR") ) @@ -3127,7 +3127,9 @@ def get_derived_params(self, rms=None, ntoas=None, returndict=False): ) s += "Conversion from ELL1 parameters:\n" ecc = um.sqrt(eps1**2 + eps2**2) - s += "ECC = {:P}\n".format(ecc) + # mypy does not know about uncertainties introducing a new + # format code, so we have to tell it to ignore this line + s += "ECC = {:P}\n".format(ecc) # type: ignore outdict["ECC"] = ecc om = um.atan2(eps1, eps2) * 180.0 / np.pi if om < 0.0: @@ -3181,14 +3183,12 @@ def get_derived_params(self, rms=None, ntoas=None, returndict=False): omdot = self.OMDOT.as_ufloat(u.rad / u.s) e = ecc if ell1 else self.ECC.as_ufloat() mt = ( - ( - omdot - / ( - 3 - * (c.G * u.Msun / c.c**3).to_value(u.s) ** (2.0 / 3) - * ((pb * 86400 / 2 / np.pi)) ** (-5.0 / 3) - * (1 - e**2) ** -1 - ) + omdot + / ( + 3 + * (c.G * u.Msun / c.c**3).to_value(u.s) ** (2.0 / 3) + * (pb * 86400 / 2 / np.pi) ** (-5.0 / 3) + * (1 - e**2) ** -1 ) ) ** (3.0 / 2) s += f"Total mass, assuming GR, from OMDOT is {mt:SP} Msun\n" @@ -3254,7 +3254,7 @@ class Component(metaclass=ModelMeta): invalid parameter values are chosen. """ - component_types = {} + component_types: Dict[str, ModelMeta] = {} def __init__(self): self.params = [] diff --git a/src/pint/observatory/__init__.py b/src/pint/observatory/__init__.py index 895d46411..caf3af767 100644 --- a/src/pint/observatory/__init__.py +++ b/src/pint/observatory/__init__.py @@ -21,14 +21,17 @@ necessary. """ -from copy import deepcopy import os import textwrap from collections import defaultdict +from collections.abc import Callable +from copy import deepcopy from io import StringIO from pathlib import Path +from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Union import astropy.coordinates +import astropy.time import astropy.units as u import numpy as np from astropy.coordinates import EarthLocation @@ -36,7 +39,10 @@ from pint.config import runtimefile from pint.pulsar_mjd import Time -from pint.utils import interesting_lines +from pint.utils import PosVel, interesting_lines + +if TYPE_CHECKING: + from pint.observatory.clock_file import ClockFile # Include any files that define observatories here. This will start # with the standard distribution files, then will read any system- or @@ -87,7 +93,7 @@ class ClockCorrectionOutOfRange(ClockCorrectionError): _bipm_clock_versions = {} -def _load_gps_clock(): +def _load_gps_clock() -> None: global _gps_clock if _gps_clock is None: log.info("Loading global GPS clock file") @@ -97,7 +103,7 @@ def _load_gps_clock(): ) -def _load_bipm_clock(bipm_version): +def _load_bipm_clock(bipm_version: str) -> None: bipm_version = bipm_version.lower() if bipm_version not in _bipm_clock_versions: try: @@ -136,34 +142,43 @@ class Observatory: position. """ + fullname: str + """Full human-readable name of the observatory.""" + include_gps: bool + """Whether to include GPS clock corrections.""" + include_bipm: bool + """Whether to include BIPM clock corrections.""" + bipm_version: str + """Version of the BIPM clock file to use.""" + # This is a dict containing all defined Observatory instances, # keyed on standard observatory name. - _registry = {} + _registry: Dict[str, "Observatory"] = {} # This is a dict mapping any defined aliases to the corresponding # standard name. - _alias_map = {} + _alias_map: Dict[str, str] = {} def __init__( self, - name, - fullname=None, - aliases=None, - include_gps=True, - include_bipm=True, - bipm_version=bipm_default, - overwrite=False, + name: str, + fullname: Optional[str] = None, + aliases: Optional[List[str]] = None, + include_gps: bool = True, + include_bipm: bool = True, + bipm_version: str = bipm_default, + overwrite: bool = False, ): - self._name = name.lower() - self._aliases = ( + self._name: str = name.lower() + self._aliases: List[str] = ( list(set(map(str.lower, aliases))) if aliases is not None else [] ) if aliases is not None: Observatory._add_aliases(self, aliases) - self.fullname = fullname if fullname is not None else name - self.include_gps = include_gps - self.include_bipm = include_bipm - self.bipm_version = bipm_version + self.fullname: str = fullname if fullname is not None else name + self.include_gps: bool = include_gps + self.include_bipm: bool = include_bipm + self.bipm_version: str = bipm_version if name.lower() in Observatory._registry: if not overwrite: @@ -175,16 +190,18 @@ def __init__( Observatory._register(self, name) @classmethod - def _register(cls, obs, name): - """Add an observatory to the registry using the specified name - (which will be converted to lower case). If an existing observatory + def _register(cls, obs: "Observatory", name: str) -> None: + """Add an observatory to the registry using the specified name (which will be converted to lower case). + + If an existing observatory of the same name exists, it will be replaced with the new one. The Observatory instance's name attribute will be updated for - consistency.""" + consistency. + """ cls._registry[name.lower()] = obs @classmethod - def _add_aliases(cls, obs, aliases): + def _add_aliases(cls, obs: "Observatory", aliases: List[str]) -> None: """Add aliases for the specified Observatory. Aliases should be given as a list. If any of the new aliases are already in use, they will be replaced. Aliases are not checked against the @@ -196,14 +213,17 @@ def _add_aliases(cls, obs, aliases): cls._alias_map[a.lower()] = obs.name @staticmethod - def gps_correction(t, limits="warn"): + def gps_correction(t: astropy.time.Time, limits: str = "warn") -> u.Quantity: """Compute the GPS clock corrections for times t.""" log.info("Applying GPS to UTC clock correction (~few nanoseconds)") _load_gps_clock() + assert _gps_clock is not None return _gps_clock.evaluate(t, limits=limits) @staticmethod - def bipm_correction(t, bipm_version=bipm_default, limits="warn"): + def bipm_correction( + t: astropy.time.Time, bipm_version: str = bipm_default, limits: str = "warn" + ) -> u.Quantity: """Compute the GPS clock corrections for times t.""" log.info(f"Applying TT(TAI) to TT({bipm_version}) clock correction (~27 us)") tt2tai = 32.184 * 1e6 * u.us @@ -214,7 +234,7 @@ def bipm_correction(t, bipm_version=bipm_default, limits="warn"): ) @classmethod - def clear_registry(cls): + def clear_registry(cls) -> None: """Clear registry for ground-based observatories.""" cls._registry = {} cls._alias_map = {} @@ -229,7 +249,7 @@ def names(cls): return cls._registry.keys() @classmethod - def names_and_aliases(cls): + def names_and_aliases(cls) -> Dict[str, List[str]]: """List all observatories and their aliases""" import pint.observatory.topo_obs # noqa import pint.observatory.special_locations # noqa @@ -241,15 +261,24 @@ def names_and_aliases(cls): # setter methods that update the registries appropriately. @property - def name(self): + def name(self) -> str: + """Short name of the observatory. + + This is the name used in TOA files and in the observatory registry. + """ return self._name @property - def aliases(self): + def aliases(self) -> List[str]: + """List of aliases for the observatory. + + These are short names also used to specify this observatory. + Includes ITOA and TEMPO codes, and any other common names. + """ return self._aliases @classmethod - def get(cls, name): + def get(cls, name: str) -> "Observatory": """Returns the Observatory instance for the specified name/alias. If the name has not been defined, an error will be raised. Aside @@ -303,9 +332,12 @@ def get(cls, name): # Any which raise NotImplementedError below must be implemented in # derived classes. - def earth_location_itrf(self, time=None): - """Returns observatory geocentric position as an astropy - EarthLocation object. For observatories where this is not + def earth_location_itrf( + self, time: Optional[astropy.time.Time] = None + ) -> Union[None, np.ndarray]: + """Returns observatory geocentric position as an astropy EarthLocation object. + + For observatories where this is not relevant, None can be returned. The location is in the International Terrestrial Reference Frame (ITRF). @@ -319,8 +351,9 @@ def earth_location_itrf(self, time=None): """ return None - def get_gcrs(self, t, ephem=None): - """Return position vector of observatory in GCRS + def get_gcrs(self, t: astropy.time.Time, ephem: Optional[str] = None): + """Return position vector of observatory in GCRS. + t is an astropy.Time or array of astropy.Time objects ephem is a link to an ephemeris file. Needed for SSB observatory Returns a 3-vector of Quantities representing the position @@ -329,14 +362,17 @@ def get_gcrs(self, t, ephem=None): raise NotImplementedError @property - def timescale(self): - """Returns the timescale that TOAs from this observatory will be in, - once any clock corrections have been applied. This should be a + def timescale(self) -> str: + """Returns the timescale that TOAs from this observatory will be in, once any clock corrections have been applied. + + This should be a string suitable to be passed directly to the scale argument of astropy.time.Time().""" raise NotImplementedError - def clock_corrections(self, t, limits="warn"): + def clock_corrections( + self, t: astropy.time.Time, limits: str = "warn" + ) -> u.Quantity: """Compute clock corrections for a Time array. Given an array-valued Time, return the clock corrections @@ -356,7 +392,7 @@ def clock_corrections(self, t, limits="warn"): return corr - def last_clock_correction_mjd(self): + def last_clock_correction_mjd(self) -> float: """Return the MJD of the last available clock correction. Returns ``np.inf`` if no clock corrections are relevant. @@ -365,6 +401,7 @@ def last_clock_correction_mjd(self): if self.include_gps: _load_gps_clock() + assert _gps_clock is not None t = min(t, _gps_clock.last_correction_mjd()) if self.include_bipm: _load_bipm_clock(self.bipm_version) @@ -374,7 +411,13 @@ def last_clock_correction_mjd(self): ) return t - def get_TDBs(self, t, method="default", ephem=None, options=None): + def get_TDBs( + self, + t: astropy.time.Time, + method: Union[str, Callable] = "default", + ephem: Optional[str] = None, + options: Optional[dict] = None, + ): """This is a high level function for converting TOAs to TDB time scale. Different method can be applied to obtain the result. Current supported @@ -409,13 +452,13 @@ def get_TDBs(self, t, method="default", ephem=None, options=None): t = Time([t]) if t.scale == "tdb": return t - # Check the method. This pattern is from numpy minimize - meth = "_custom" if callable(method) else method.lower() if options is None: options = {} - if meth == "_custom": + if callable(method): options = dict(options) return method(t, **options) + else: + meth = method.lower() if meth == "default": return self._get_TDB_default(t, ephem) elif meth == "ephemeris": @@ -428,17 +471,17 @@ def get_TDBs(self, t, method="default", ephem=None, options=None): else: raise ValueError(f"Unknown method '{method}'.") - def _get_TDB_default(self, t, ephem): + def _get_TDB_default(self, t: astropy.time.Time, ephem: Optional[str]): return t.tdb - def _get_TDB_ephem(self, t, ephem): + def _get_TDB_ephem(self, t: astropy.time.Time, ephem: Optional[str]): """Read the ephem TDB-TT column. This column is provided by DE4XXt version of ephemeris. """ raise NotImplementedError - def posvel(self, t, ephem, group=None): + def posvel(self, t: astropy.time.Time, ephem: Optional[str], group=None) -> PosVel: """Return observatory position and velocity for the given times. Position is relative to solar system barycenter; times are @@ -451,7 +494,10 @@ def posvel(self, t, ephem, group=None): def get_observatory( - name, include_gps=None, include_bipm=None, bipm_version=bipm_default + name: str, + include_gps: Optional[bool] = None, + include_bipm: Optional[bool] = None, + bipm_version: str = bipm_default, ): """Convenience function to get observatory object with options. @@ -491,14 +537,14 @@ def get_observatory( return Observatory.get(name) -def earth_location_distance(loc1, loc2): +def earth_location_distance(loc1: EarthLocation, loc2: EarthLocation) -> u.Quantity: """Compute the distance between two EarthLocations.""" return ( sum((u.Quantity(loc1.to_geocentric()) - u.Quantity(loc2.to_geocentric())) ** 2) ) ** 0.5 -def compare_t2_observatories_dat(t2dir=None): +def compare_t2_observatories_dat(t2dir: Optional[str] = None) -> Dict[str, List[Dict]]: """Read a tempo2 observatories.dat file and compare with PINT Produces a report including lines that can be added to PINT's @@ -531,10 +577,10 @@ def compare_t2_observatories_dat(t2dir=None): with open(filename) as f: for line in interesting_lines(f, comments="#"): try: - x, y, z, full_name, short_name = line.split() + x_str, y_str, z_str, full_name, short_name = line.split() except ValueError as e: raise ValueError(f"unrecognized line '{line}'") from e - x, y, z = float(x), float(y), float(z) + x, y, z = float(x_str), float(y_str), float(z_str) full_name, short_name = full_name.lower(), short_name.lower() topo_obs_entry = textwrap.dedent( f""" @@ -589,7 +635,7 @@ def compare_t2_observatories_dat(t2dir=None): return report -def compare_tempo_obsys_dat(tempodir=None): +def compare_tempo_obsys_dat(tempodir: Optional[str] = None) -> Dict[str, List[Dict]]: """Read a tempo obsys.dat file and compare with PINT. Produces a report including lines that can be added to PINT's @@ -629,8 +675,8 @@ def compare_tempo_obsys_dat(tempodir=None): y = float(line_io.read(15)) z = float(line_io.read(15)) line_io.read(2) - icoord = line_io.read(1).strip() - icoord = int(icoord) if icoord else 0 + icoord_str = line_io.read(1).strip() + icoord = int(icoord_str) if icoord_str else 0 line_io.read(2) obsnam = line_io.read(20).strip().lower() tempo_code = line_io.read(1) @@ -713,7 +759,7 @@ def convert_angle(x): return report -def list_last_correction_mjds(): +def list_last_correction_mjds() -> None: """Print out a list of the last MJD each clock correction is good for. Each observatory lists the clock files it uses and their last dates, @@ -744,7 +790,7 @@ def list_last_correction_mjds(): print(f" {c.friendly_name:<20} MISSING") -def update_clock_files(bipm_versions=None): +def update_clock_files(bipm_versions: Optional[List[str]] = None) -> None: """Obtain an up-to-date version of all clock files. This up-to-date version will be stored in the Astropy cache; @@ -786,13 +832,13 @@ def update_clock_files(bipm_versions=None): # Both topo_obs and special_locations need this def find_clock_file( - name, - format, - bogus_last_correction=False, - url_base=None, - clock_dir=None, - valid_beyond_ends=False, -): + name: str, + format: str, + bogus_last_correction: bool = False, + url_base: Optional[str] = None, + clock_dir: Union[str, Path, None] = None, + valid_beyond_ends: bool = False, +) -> "ClockFile": """Locate and return a ClockFile in one of several places. PINT looks for clock files in three places, in order: diff --git a/src/pint/observatory/topo_obs.py b/src/pint/observatory/topo_obs.py index 1d7fe8868..7af88807d 100644 --- a/src/pint/observatory/topo_obs.py +++ b/src/pint/observatory/topo_obs.py @@ -17,12 +17,15 @@ -------- :mod:`pint.observatory.special_locations` """ +import copy import json import os +from functools import cached_property from pathlib import Path -import copy +from typing import Optional, Union, List, Any, Dict import astropy.constants as c +import astropy.time import astropy.units as u import numpy as np from astropy.coordinates import EarthLocation @@ -36,13 +39,13 @@ NoClockCorrections, Observatory, bipm_default, + earth_location_distance, find_clock_file, get_observatory, - earth_location_distance, ) from pint.pulsar_mjd import Time from pint.solar_system_ephemerides import get_tdb_tt_ephem_geocenter, objPosVel_wrt_SSB -from pint.utils import has_astropy_unit, open_or_use +from pint.utils import has_astropy_unit, open_or_use, PosVel # environment variables that can override clock location and observatory location pint_obs_env_var = "PINT_OBS_OVERRIDE" @@ -147,38 +150,63 @@ class TopoObs(Observatory): """ + tempo_code: Optional[str] + """One-character TEMPO code.""" + itoa_code: Optional[str] + """Two-character ITOA code.""" + location: EarthLocation + """Location of the observatory.""" + clock_files: List[str] + """List of files to read for clock corrections. If empty, no clock corrections are applied.""" + clock_fmt: str + """Format of the clock files. + + See :class:`pint.observatory.clock_file.ClockFile` for allowed values. + """ + bogus_last_correction: bool + """Clock correction files include a bogus last correction. + + This is common with TEMPO/TEMPO2 clock files since neither program does + a good job with times past the end ot the table. It makes detecting values + past the end of real calibration difficult if it's not marked as bogus. + """ + clock_dir: Optional[Union[str, Path]] + """Where to look for the clock files.""" + origin: Optional[str] + """Documentation of the origin/author/date for the information.""" + def __init__( self, - name, + name: str, *, - fullname=None, - tempo_code=None, - itoa_code=None, - aliases=None, - location=None, + fullname: Optional[str] = None, + tempo_code: Optional[str] = None, + itoa_code: Optional[str] = None, + aliases: Optional[List[str]] = None, + location: Optional[EarthLocation] = None, itrf_xyz=None, - lat=None, - lon=None, + lat: Optional[float] = None, + lon: Optional[float] = None, height=None, - clock_file="", - clock_fmt="tempo", - clock_dir=None, - include_gps=True, - include_bipm=True, - bipm_version=bipm_default, - origin=None, - overwrite=False, - bogus_last_correction=False, + clock_file: str = "", + clock_fmt: str = "tempo", + clock_dir: Union[str, Path, None] = None, + include_gps: bool = True, + include_bipm: bool = True, + bipm_version: str = bipm_default, + origin: Optional[str] = None, + overwrite: bool = False, + bogus_last_correction: bool = False, ): input_values = [lat is not None, lon is not None, height is not None] - if sum(input_values) > 0 and sum(input_values) < 3: + if any(input_values) and not all(input_values): raise ValueError("All of lat, lon, height are required for observatory") input_values = [ location is not None, itrf_xyz is not None, (lat is not None and lon is not None and height is not None), ] - if sum(input_values) == 0: + if not any(input_values): raise ValueError( f"EarthLocation, ITRF coordinates, or lat/lon/height are required for observatory '{name}'" ) @@ -209,11 +237,12 @@ def __init__( # Save clock file info, the data will be read only if clock # corrections for this site are requested. - self.clock_files = [clock_file] if isinstance(clock_file, str) else clock_file - self.clock_files = [c for c in self.clock_files if c != ""] - self.clock_fmt = clock_fmt + clock_files: List[str] = ( + [clock_file] if isinstance(clock_file, str) else clock_file + ) + self.clock_files: List[str] = [c for c in clock_files if c != ""] + self.clock_fmt: str = clock_fmt self.clock_dir = clock_dir - self._clock = None # The ClockFile objects, will be read on demand # If using TEMPO time.dat we need to know the 1-char tempo-style # observatory code. @@ -248,7 +277,7 @@ def __init__( overwrite=overwrite, ) - def __repr__(self): + def __repr__(self) -> str: aliases = [f"'{x}'" for x in self.aliases] origin = ( f"{self.fullname}\n{self.origin}" @@ -258,10 +287,10 @@ def __repr__(self): return f"TopoObs('{self.name}' ({','.join(aliases)}) at [{self.location.x}, {self.location.y} {self.location.z}]:\n{origin})" @property - def timescale(self): + def timescale(self) -> str: return "utc" - def get_dict(self): + def get_dict(self) -> Dict[str, Dict[str, Any]]: """Return as a dict with limited/changed info""" # start with the default __dict__ # copy some attributes to rename them and remove those that aren't needed for initialization @@ -276,12 +305,12 @@ def get_dict(self): output["itrf_xyz"] = [x.to_value(u.m) for x in self.location.geocentric] return {self.name: output} - def get_json(self): - """Return as a JSON string""" + def get_json(self) -> str: + """Return as a JSON string.""" return json.dumps(self.get_dict()) - def separation(self, other, method="cartesian"): - """Return separation between two TopoObs objects + def separation(self, other: "TopoObs", method: str = "cartesian") -> u.Quantity: + """Return separation between two TopoObs objects. Parameters ---------- @@ -312,13 +341,12 @@ def separation(self, other, method="cartesian"): ) return (c.R_earth * dsigma).to(u.m, equivalencies=u.dimensionless_angles()) - def earth_location_itrf(self, time=None): + def earth_location_itrf(self, time=None) -> EarthLocation: return self.location - def _load_clock_corrections(self): - if self._clock is not None: - return - self._clock = [] + @cached_property + def _clock(self) -> list: + clock = [] for cf in self.clock_files: if cf == "": continue @@ -326,16 +354,20 @@ def _load_clock_corrections(self): if isinstance(cf, dict): kwargs.update(cf) cf = kwargs.pop("name") - self._clock.append( + clock.append( find_clock_file( cf, format=self.clock_fmt, clock_dir=self.clock_dir, - **kwargs, + # mypy is unhappy about passing in a dict as **kwargs + # which is fair enough since it can't check the keys + # are valid arguments. + **kwargs, # type: ignore ) ) + return clock - def clock_corrections(self, t, limits="warn"): + def clock_corrections(self, t: Time, limits: str = "warn") -> u.Quantity: """Compute the total clock corrections, Parameters @@ -344,17 +376,16 @@ def clock_corrections(self, t, limits="warn"): The time when the clock correcions are applied. """ - corr = super().clock_corrections(t, limits=limits) - # Read clock file if necessary - self._load_clock_corrections() + corr: u.Quantity = super().clock_corrections(t, limits=limits) if self._clock: log.info( f"Applying observatory clock corrections for observatory='{self.name}'." ) for clock in self._clock: corr += clock.evaluate(t, limits=limits) - elif self.clock_files: + # clock_files is not empty, but no clock corrections found + # FIXME: what if only some were found? msg = f"No clock corrections found for observatory {self.name} taken from file {self.clock_files}" if limits == "warn": log.warning(msg) @@ -365,19 +396,18 @@ def clock_corrections(self, t, limits="warn"): log.info(f"Observatory {self.name} requires no clock corrections.") return corr - def last_clock_correction_mjd(self): + def last_clock_correction_mjd(self) -> float: """Return the MJD of the last clock correction. Combines constraints based on Earth orientation parameters and on the available clock corrections specific to the telescope. """ t = super().last_clock_correction_mjd() - self._load_clock_corrections() for clock in self._clock: t = min(t, clock.last_correction_mjd()) return t - def _get_TDB_ephem(self, t, ephem): + def _get_TDB_ephem(self, t: Time, ephem: Optional[str]) -> Time: """Read the ephem TDB-TT column. This column is provided by DE4XXt version of ephemeris. This function is only @@ -389,8 +419,8 @@ def _get_TDB_ephem(self, t, ephem): # Topocenter to Geocenter # Since earth velocity is not going to change a lot in 3ms. The # differences between TT and TDB can be ignored. - earth_pv = objPosVel_wrt_SSB("earth", t.tdb, ephem) - obs_geocenter_pv = gcrs_posvel_from_itrf( + earth_pv: PosVel = objPosVel_wrt_SSB("earth", t.tdb, ephem) + obs_geocenter_pv: PosVel = gcrs_posvel_from_itrf( self.earth_location_itrf(), t, obsname=self.name ) # NOTE @@ -406,7 +436,7 @@ def _get_TDB_ephem(self, t, ephem): location=self.earth_location_itrf(), ) - def get_gcrs(self, t, ephem=None): + def get_gcrs(self, t: astropy.time.Time, ephem: Optional[str] = None): """Return position vector of TopoObs in GCRS Parameters @@ -418,22 +448,22 @@ def get_gcrs(self, t, ephem=None): np.array a 3-vector of Quantities representing the position in GCRS coordinates. """ - obs_geocenter_pv = gcrs_posvel_from_itrf( + obs_geocenter_pv: PosVel = gcrs_posvel_from_itrf( self.earth_location_itrf(), t, obsname=self.name ) return obs_geocenter_pv.pos - def posvel(self, t, ephem, group=None): + def posvel(self, t: astropy.time.Time, ephem: Optional[str], group=None) -> PosVel: if t.isscalar: t = Time([t]) - earth_pv = objPosVel_wrt_SSB("earth", t, ephem) - obs_geocenter_pv = gcrs_posvel_from_itrf( + earth_pv: PosVel = objPosVel_wrt_SSB("earth", t, ephem) + obs_geocenter_pv: PosVel = gcrs_posvel_from_itrf( self.earth_location_itrf(), t, obsname=self.name ) return obs_geocenter_pv + earth_pv -def export_all_clock_files(directory): +def export_all_clock_files(directory: Union[str, Path]) -> None: """Export all clock files PINT is using. This will export all the clock files PINT is using - every clock file used @@ -465,7 +495,7 @@ def export_all_clock_files(directory): clock.export(directory / Path(clock.filename).name) -def load_observatories(filename=observatories_json, overwrite=False): +def load_observatories(filename=observatories_json, overwrite: bool = False) -> None: """Load observatory definitions from JSON and create :class:`pint.observatory.topo_obs.TopoObs` objects, registering them Set `overwrite` to ``True`` if you want to re-read a file with updated definitions. @@ -499,7 +529,7 @@ def load_observatories(filename=observatories_json, overwrite=False): TopoObs(name=obsname, **obsdict) -def load_observatories_from_usual_locations(clear=False): +def load_observatories_from_usual_locations(clear: bool = False) -> None: """Load observatories from the default JSON file as well as ``$PINT_OBS_OVERRIDE``, optionally clearing the registry Running with ``clear=True`` will return PINT to the state it is on import. diff --git a/src/pint/output/publish.py b/src/pint/output/publish.py index 2cb82b0ac..9832eba6b 100644 --- a/src/pint/output/publish.py +++ b/src/pint/output/publish.py @@ -1,24 +1,28 @@ """Generate LaTeX summary of a timing model and TOAs.""" + +from io import StringIO +from typing import List, Union + +import numpy as np from pint.models import ( - TimingModel, - DispersionDMX, FD, + AbsPhase, + DispersionDMX, Glitch, PhaseJump, SolarWindDispersionX, - AbsPhase, + TimingModel, Wave, ) +from pint.models.timing_model import Component from pint.models.dispersion_model import DispersionJump from pint.models.noise_model import NoiseComponent from pint.models.parameter import ( Parameter, funcParameter, ) -from pint.toa import TOAs from pint.residuals import Residuals, WidebandTOAResiduals -from io import StringIO -import numpy as np +from pint.toa import TOAs def publish_param(param: Parameter): @@ -91,6 +95,7 @@ def publish( else "WLS" ) + res: Union[Residuals, WidebandTOAResiduals] if toas.is_wideband(): res = WidebandTOAResiduals(toas, model) toares = res.toa @@ -117,7 +122,7 @@ def publish( "BINARY", ] - exclude_components = [Wave] + exclude_components: List[type[Component]] = [Wave] if not include_dmx: exclude_components.append(DispersionDMX) if not include_jumps: @@ -259,7 +264,7 @@ def publish( ) tex.write("\\hline\n") - tex.write("\multicolumn{2}{c}{Measured Quantities} \\\\ \n") + tex.write("\\multicolumn{2}{c}{Measured Quantities} \\\\ \n") tex.write("\\hline\n") for fp in model.free_params: param = getattr(model, fp) @@ -273,7 +278,7 @@ def publish( tex.write("\\hline\n") if include_set_params: - tex.write("\multicolumn{2}{c}{Set Quantities} \\\\ \n") + tex.write("\\multicolumn{2}{c}{Set Quantities} \\\\ \n") tex.write("\\hline\n") for p in model.params: param = getattr(model, p) @@ -303,7 +308,7 @@ def publish( and getattr(model, p).quantity is not None ] if len(derived_params) > 0: - tex.write("\multicolumn{2}{c}{Derived Quantities} \\\\ \n") + tex.write("\\multicolumn{2}{c}{Derived Quantities} \\\\ \n") tex.write("\\hline\n") for param in derived_params: tex.write(publish_param(param)) diff --git a/src/pint/pintk/plk.py b/src/pint/pintk/plk.py index 1ea713b84..2e6c7d6c9 100644 --- a/src/pint/pintk/plk.py +++ b/src/pint/pintk/plk.py @@ -1,6 +1,7 @@ """ Interactive emulator of tempo2 plk """ + import copy import os import sys @@ -28,9 +29,9 @@ try: - from matplotlib.backends.backend_tkagg import NavigationToolbar2Tk + from matplotlib.backends.backend_tkagg import NavigationToolbar2Tk # type: ignore[attr-defined] except ImportError: - from matplotlib.backends.backend_tkagg import ( + from matplotlib.backends.backend_tkagg import ( # type: ignore[no-redef,attr-defined] NavigationToolbar2TkAgg as NavigationToolbar2Tk, ) @@ -587,11 +588,11 @@ class PlkToolbar(NavigationToolbar2Tk): necessary selections/un-selections on points """ - toolitems = [ + toolitems = tuple( t for t in NavigationToolbar2Tk.toolitems if t[0] in ("Home", "Back", "Forward", "Pan", "Zoom", "Save") - ] + ) class PlkActionsWidget(tk.Frame): diff --git a/src/pint/polycos.py b/src/pint/polycos.py index 686576b9b..e1f160fd9 100644 --- a/src/pint/polycos.py +++ b/src/pint/polycos.py @@ -1,17 +1,17 @@ -"""Polynomial coefficients for phase prediction +r"""Polynomial coefficients for phase prediction Polycos designed to predict the pulsar's phase and pulse-period over a -given interval using polynomial expansions. +given interval using polynomial expansions. The pulse phase and frequency at time T are then calculated as: .. math:: - \\Delta T = 1440(T-T_{\\rm mid}) + \Delta T = 1440(T-T_{\rm mid}) - \\phi = \\phi_0 + 60 \Delta T f_0 + COEFF[1] + COEFF[2] \\Delta T + COEFF[3] \\Delta T^2 + \\ldots + \phi = \phi_0 + 60 \Delta T f_0 + COEFF[1] + COEFF[2] \Delta T + COEFF[3] \Delta T^2 + \ldots - f({\\rm Hz}) = f_0 + \\frac{1}{60}\\left( COEFF[2] + 2 COEFF[3] \Delta T + 3 COEFF[4] \Delta T^2 + \\ldots \\right) + f({\rm Hz}) = f_0 + \frac{1}{60}\left( COEFF[2] + 2 COEFF[3] \Delta T + 3 COEFF[4] \Delta T^2 + \ldots \right) Examples -------- @@ -27,25 +27,28 @@ >>> from pint.polycos import Polycos >>> model = get_model(filename) >>> p = Polycos.generate_polycos(model, 50000, 50001, "AO", 144, 12, 1400) - + References ---------- http://tempo.sourceforge.net/ref_man_sections/tz-polyco.txt """ + +from collections import OrderedDict +from collections.abc import Callable +from typing import Dict, List, Union + import astropy.table as table import astropy.units as u import numpy as np from astropy.io import registry from astropy.time import Time -from collections import OrderedDict - from loguru import logger as log try: from tqdm import tqdm -except (ModuleNotFoundError, ImportError) as e: +except (ModuleNotFoundError, ImportError): - def tqdm(*args, **kwargs): + def tqdm(*args, **kwargs): # type: ignore return args[0] if args else kwargs.get("iterable", None) @@ -228,7 +231,7 @@ def evalfreqderiv(self, t): # Read polycos file data to table def tempo_polyco_table_reader(filename): - """Read tempo style polyco file to an astropy table. + r"""Read tempo style polyco file to an astropy table. Tempo style: The polynomial ephemerides are written to file 'polyco.dat'. Entries are listed sequentially within the file. The file format is:: @@ -262,11 +265,11 @@ def tempo_polyco_table_reader(filename): .. math:: - \\Delta T = 1440(T-T_{\\rm mid}) + \Delta T = 1440(T-T_{\rm mid}) - \\phi = \\phi_0 + 60 \Delta T f_0 + COEFF[1] + COEFF[2] \\Delta T + COEFF[3] \\Delta T^2 + \\ldots + \phi = \phi_0 + 60 \Delta T f_0 + COEFF[1] + COEFF[2] \Delta T + COEFF[3] \Delta T^2 + \ldots - f({\\rm Hz}) = f_0 + \\frac{1}{60}\\left( COEFF[2] + 2 COEFF[3] \Delta T + 3 COEFF[4] \Delta T^2 + \\ldots \\right) + f({\rm Hz}) = f_0 + \frac{1}{60}\left( COEFF[2] + 2 COEFF[3] \Delta T + 3 COEFF[4] \Delta T^2 + \ldots \right) Parameters ---------- @@ -356,7 +359,7 @@ def tempo_polyco_table_reader(filename): def tempo_polyco_table_writer(polycoTable, filename="polyco.dat"): - """Write tempo style polyco file from an astropy table. + r"""Write tempo style polyco file from an astropy table. Tempo style polyco file: The polynomial ephemerides are written to file 'polyco.dat'. Entries @@ -389,11 +392,11 @@ def tempo_polyco_table_writer(polycoTable, filename="polyco.dat"): .. math:: - \\Delta T = 1440(T-T_{\\rm mid}) + \Delta T = 1440(T-T_{\rm mid}) - \\phi = \\phi_0 + 60 \Delta T f_0 + COEFF[1] + COEFF[2] \\Delta T + COEFF[3] \\Delta T^2 + \\ldots + \phi = \phi_0 + 60 \Delta T f_0 + COEFF[1] + COEFF[2] \Delta T + COEFF[3] \Delta T^2 + \ldots - f({\\rm Hz}) = f_0 + \\frac{1}{60}\\left( COEFF[2] + 2 COEFF[3] \Delta T + 3 COEFF[4] \Delta T^2 + \\ldots \\right) + f({\rm Hz}) = f_0 + \frac{1}{60}\left( COEFF[2] + 2 COEFF[3] \Delta T + 3 COEFF[4] \Delta T^2 + \ldots \right) Parameters ---------- @@ -483,7 +486,7 @@ class Polycos: """ # loaded formats - polycoFormats = [] + polycoFormats: List[Dict[str, Union[str, Callable]]] = [] @classmethod def _register(cls, formatlist=_polycoFormats): @@ -918,7 +921,7 @@ def eval_phase(self, t): return self.eval_abs_phase(t).frac def eval_abs_phase(self, t): - """ + r""" Polyco evaluate absolute phase for a time array. Parameters @@ -937,7 +940,7 @@ def eval_abs_phase(self, t): .. math:: - \\phi = \\phi_0 + 60 \\Delta T f_0 + COEFF[1] + COEFF[2] \Delta T + COEFF[3] \Delta T^2 + \\ldots + \phi = \phi_0 + 60 \Delta T f_0 + COEFF[1] + COEFF[2] \Delta T + COEFF[3] \Delta T^2 + \ldots Calculation done using :meth:`pint.polycos.PolycoEntry.evalabsphase` """ diff --git a/src/pint/pulsar_mjd.py b/src/pint/pulsar_mjd.py index 3c0f14d8d..55f8f6b61 100644 --- a/src/pint/pulsar_mjd.py +++ b/src/pint/pulsar_mjd.py @@ -34,13 +34,6 @@ from astropy.time import Time from astropy.time.formats import TimeFormat -try: - maketrans = str.maketrans -except AttributeError: - # fallback for Python 2 - from string import maketrans - - # This check is implemented in pint.utils, but we want to avoid circular imports if np.finfo(np.longdouble).eps > 2e-19: import warnings @@ -303,7 +296,7 @@ def fortran_float(x): """ try: # First treat it as a string, wih d->e - return float(x.translate(maketrans("Dd", "ee"))) + return float(x.translate(str.maketrans("Dd", "ee"))) except AttributeError: # If that didn't work it may already be a numeric type return float(x) @@ -361,7 +354,7 @@ def str2longdouble(str_data): """ if not isinstance(str_data, (str, bytes)): raise TypeError("Need a string: {!r}".format(str_data)) - return np.longdouble(str_data.translate(maketrans("Dd", "ee"))) + return np.longdouble(str_data.translate(str.maketrans("Dd", "ee"))) # Simplified functions: These core functions, if they can be made to work @@ -453,7 +446,7 @@ def mjds_to_jds_pulsar(mjd1, mjd2): def _str_to_mjds(s): ss = s.lower().strip() if "e" in ss or "d" in ss: - ss = ss.translate(maketrans("d", "e")) + ss = ss.translate(str.maketrans("d", "e")) num, expon = ss.split("e") expon = int(expon) if expon < 0: diff --git a/src/pint/py.typed b/src/pint/py.typed new file mode 100644 index 000000000..e69de29bb diff --git a/src/pint/simulation.py b/src/pint/simulation.py index 958e6200d..b8be080eb 100644 --- a/src/pint/simulation.py +++ b/src/pint/simulation.py @@ -1,19 +1,19 @@ """Functions related to simulating TOAs and models """ +import pathlib from collections import OrderedDict from copy import deepcopy -from typing import Optional, List, Union -import pathlib +from typing import List, Optional, Tuple, Union, Dict, overload import astropy.units as u import numpy as np -from loguru import logger as log from astropy import time +from loguru import logger as log +import pint.fitter import pint.residuals import pint.toa -import pint.fitter from pint.observatory import bipm_default, get_observatory __all__ = [ @@ -33,7 +33,7 @@ def zero_residuals( subtract_mean: bool = True, maxiter: int = 10, tolerance: Optional[u.Quantity] = None, -): +) -> None: """Use a model to adjust a TOAs object, setting residuals to 0 iteratively. Parameters @@ -51,7 +51,7 @@ def zero_residuals( 1 nanosecond if operating in full precision or 5 us if not. """ ts.compute_pulse_numbers(model) - maxresid = None + maxresid: Optional[float] = None if tolerance is None: tolerance = 1 * u.ns if pint.utils.check_longdouble_precision() else 5 * u.us for i in range(maxiter): @@ -77,7 +77,7 @@ def get_fake_toa_clock_versions( model: pint.models.timing_model.TimingModel, include_bipm: bool = False, include_gps: bool = True, -) -> dict: +) -> Dict[str, Union[bool, str]]: """Get the clock settings (corrections, etc) for fake TOAs Parameters @@ -220,6 +220,54 @@ def update_fake_dms( return toas +@overload +def make_fake_toas_uniform( + startMJD: float, + endMJD: float, + ntoas: int, + model: pint.models.timing_model.TimingModel, + fuzz: u.Quantity = 0, + freq: u.Quantity = 1400 * u.MHz, + obs: str = "GBT", + error: u.Quantity = 1 * u.us, + add_noise: bool = False, + add_correlated_noise: bool = False, + wideband: bool = False, + wideband_dm_error: u.Quantity = 1e-4 * pint.dmu, + name: str = "fake", + include_bipm: bool = False, + include_gps: bool = True, + multi_freqs_in_epoch: bool = False, + flags: Optional[dict] = None, + subtract_mean: bool = True, +) -> pint.toa.TOAs: + ... + + +@overload +def make_fake_toas_uniform( + startMJD: u.Quantity, + endMJD: u.Quantity, + ntoas: int, + model: pint.models.timing_model.TimingModel, + fuzz: u.Quantity = 0, + freq: u.Quantity = 1400 * u.MHz, + obs: str = "GBT", + error: u.Quantity = 1 * u.us, + add_noise: bool = False, + add_correlated_noise: bool = False, + wideband: bool = False, + wideband_dm_error: u.Quantity = 1e-4 * pint.dmu, + name: str = "fake", + include_bipm: bool = False, + include_gps: bool = True, + multi_freqs_in_epoch: bool = False, + flags: Optional[dict] = None, + subtract_mean: bool = True, +) -> pint.toa.TOAs: + ... + + def make_fake_toas_uniform( startMJD: Union[float, u.Quantity, time.Time], endMJD: Union[float, u.Quantity, time.Time], @@ -566,7 +614,7 @@ def calculate_random_models( keep_models: bool = True, return_time: bool = False, params: str = "all", -) -> (np.ndarray, Optional[list]): +) -> Union[Tuple[np.ndarray, list], np.ndarray]: """ Calculates random models based on the covariance matrix of the `fitter` object. @@ -689,13 +737,35 @@ def calculate_random_models( return (dphase, random_models) if keep_models else dphase +@overload +def _get_freqs_and_times( + start: float, + end: float, + ntoas: int, + freqs: u.Quantity, + multi_freqs_in_epoch: bool = True, +) -> Tuple[np.ndarray, np.ndarray]: + ... + + +@overload +def _get_freqs_and_times( + start: time.Time, + end: time.Time, + ntoas: int, + freqs: u.Quantity, + multi_freqs_in_epoch: bool = True, +) -> Tuple[time.Time, np.ndarray]: + ... + + def _get_freqs_and_times( start: Union[float, u.Quantity, time.Time], end: Union[float, u.Quantity, time.Time], ntoas: int, freqs: u.Quantity, multi_freqs_in_epoch: bool = True, -) -> (Union[float, u.Quantity, time.Time], np.ndarray): +) -> Tuple[Union[np.ndarray, u.Quantity, time.Time], np.ndarray]: freqs = np.atleast_1d(freqs) assert ( len(freqs.shape) == 1 and len(freqs) <= ntoas diff --git a/src/pint/templates/lceprimitives.py b/src/pint/templates/lceprimitives.py index ad0a1a0d0..93f9b8750 100644 --- a/src/pint/templates/lceprimitives.py +++ b/src/pint/templates/lceprimitives.py @@ -1,10 +1,6 @@ from pint.templates.lcprimitives import * -def isvector(x): - return len(np.asarray(x).shape) > 0 - - def edep_gradient(self, grad_func, phases, log10_ens=3, free=False): """Return the analytic gradient of a general LCEPrimitive. @@ -232,7 +228,8 @@ def _einit(self): self.slope_bounds[2] = [-0.3, 0.3] -class LCELorentzian(LCEWrappedFunction, LCLorentzian): +# LCWrappedFunction.derivative doesn't accept index but LCLorentzian.derivative does +class LCELorentzian(LCEWrappedFunction, LCLorentzian): # type: ignore[misc] """Represent a (wrapped) Lorentzian peak. Parameters diff --git a/src/pint/templates/lcprimitives.py b/src/pint/templates/lcprimitives.py index cbcf665ba..de75583cf 100644 --- a/src/pint/templates/lcprimitives.py +++ b/src/pint/templates/lcprimitives.py @@ -661,6 +661,7 @@ def hessian(self, phases, log10_ens=3, free=False): # results[i,:] += gn[i] return results[self.free, self.free] if free else results + # This derivative doesn't accept an index argument, but LCLorentzian does def derivative(self, phases, log10_ens=3, order=1): """Return the phase gradient (dprim/dphi) at a vector of phases. diff --git a/src/pint/templates/lctemplate.py b/src/pint/templates/lctemplate.py index 8e14f0834..2fcf0fdc3 100644 --- a/src/pint/templates/lctemplate.py +++ b/src/pint/templates/lctemplate.py @@ -20,10 +20,6 @@ log = logging.getLogger(__name__) -def isvector(x): - return len(np.asarray(x).shape) > 0 - - class LCTemplate: """Manage a lightcurve template (collection of LCPrimitive objects). @@ -1071,7 +1067,3 @@ def check_gradient_derivative(templ): for i in range(gd.shape[0]): print(np.max(np.abs(gd[i] - ngd[i]))) return pcs, gd, ngd - - -def isvector(x): - return len(np.asarray(x).shape) > 0 diff --git a/src/pint/utils.py b/src/pint/utils.py index 1c0bcf21c..78a25a5d8 100644 --- a/src/pint/utils.py +++ b/src/pint/utils.py @@ -28,6 +28,7 @@ has moved to :mod:`pint.simulation`. """ + import configparser import datetime import getpass @@ -37,11 +38,13 @@ import re import sys import textwrap +import warnings +from collections.abc import Generator, Iterable from contextlib import contextmanager +from copy import deepcopy from pathlib import Path +from typing import IO, Any, Optional, Tuple, Union, List, Dict, Type, Mapping, cast from warnings import warn -from scipy.optimize import minimize -from numdifftools import Hessian import astropy.constants as const import astropy.coordinates as coords @@ -50,16 +53,15 @@ from astropy import constants from astropy.time import Time from loguru import logger as log -from scipy.special import fdtrc +from numdifftools import Hessian from scipy.linalg import cho_factor, cho_solve -from copy import deepcopy -import warnings +from scipy.optimize import minimize +from scipy.special import fdtrc import pint import pint.pulsar_ecliptic from pint.toa_select import TOASelect - __all__ = [ "PINTPrecisionError", "check_longdouble_precision", @@ -114,8 +116,17 @@ "get_unit", ] -COLOR_NAMES = ["black", "red", "green", "yellow", "blue", "magenta", "cyan", "white"] -TEXT_ATTRIBUTES = [ +COLOR_NAMES: List[str] = [ + "black", + "red", + "green", + "yellow", + "blue", + "magenta", + "cyan", + "white", +] +TEXT_ATTRIBUTES: List[str] = [ "normal", "bold", "subdued", @@ -145,7 +156,7 @@ def check_longdouble_precision(): return np.finfo(np.longdouble).eps < 2e-19 -def require_longdouble_precision(): +def require_longdouble_precision() -> None: """Raise an exception if long doubles do not have enough precision. Raises RuntimeError if PINT cannot be run with high precision on this @@ -181,7 +192,13 @@ class PosVel: """ - def __init__(self, pos, vel, obj=None, origin=None): + def __init__( + self, + pos: Union[u.Quantity, np.ndarray], + vel: Union[u.Quantity, np.ndarray], + obj=None, + origin=None, + ): if len(pos) != 3: raise ValueError(f"Position vector has length {len(pos)} instead of 3") self.pos = pos if isinstance(pos, u.Quantity) else np.asarray(pos) @@ -207,13 +224,13 @@ def __init__(self, pos, vel, obj=None, origin=None): self.origin = origin # FIXME: what about dtype compatibility? - def _has_labels(self): + def _has_labels(self) -> bool: return (self.obj is not None) and (self.origin is not None) - def __neg__(self): + def __neg__(self) -> "PosVel": return PosVel(-self.pos, -self.vel, obj=self.origin, origin=self.obj) - def __add__(self, other): + def __add__(self, other: "PosVel") -> "PosVel": obj = None origin = None if self._has_labels() and other._has_labels(): @@ -234,17 +251,17 @@ def __add__(self, other): self.pos + other.pos, self.vel + other.vel, obj=obj, origin=origin ) - def __sub__(self, other): + def __sub__(self, other: "PosVel") -> "PosVel": return self.__add__(other.__neg__()) - def __str__(self): + def __str__(self) -> str: return ( f"PosVel({str(self.pos)}, {str(self.vel)} {self.origin}->{self.obj})" if self._has_labels() else f"PosVel({str(self.pos)}, {str(self.vel)})" ) - def __getitem__(self, k): + def __getitem__(self, k: Union[int, Tuple[int, ...]]) -> "PosVel": """Allow extraction of slices of the contained arrays""" colon = slice(None, None, None) ix = (colon,) + k if isinstance(k, tuple) else (colon, k) @@ -305,7 +322,7 @@ def check_all_partials(f, args, delta=1e-6, atol=1e-4, rtol=1e-4): raise -def has_astropy_unit(x): +def has_astropy_unit(x) -> bool: """Test whether x has a unit attribute containing an astropy unit. This is useful, because different data types can still have units @@ -328,7 +345,7 @@ class PrefixError(ValueError): pass -def split_prefixed_name(name): +def split_prefixed_name(name: str) -> Tuple[str, str, int]: """Split a prefixed name. Parameters @@ -365,17 +382,16 @@ def split_prefixed_name(name): """ for pt in prefix_pattern: - try: - prefix_part, index_part = pt.match(name).groups() + m = pt.match(name) + if m is not None: + prefix_part, index_part = m.groups() break - except AttributeError: - continue else: raise PrefixError(f"Unrecognized prefix name pattern '{name}'.") return prefix_part, index_part, int(index_part) -def taylor_horner(x, coeffs): +def taylor_horner(x: Union[float, np.ndarray, u.Quantity], coeffs): """Evaluate a Taylor series of coefficients at x via the Horner scheme. For example, if we want: 10 + 3*x/1! + 4*x^2/2! + 12*x^3/3! with @@ -444,7 +460,10 @@ def taylor_horner_deriv(x, coeffs, deriv_order=1): @contextmanager -def open_or_use(f, mode="r"): +def open_or_use( + f: Union[str, bytes, Path, IO[Any]], + mode: str = "r", +) -> Generator[IO[Any], None, None]: """Open a filename or use an open file. Specifically, if f is a string, try to use it as an argument to @@ -459,7 +478,7 @@ def open_or_use(f, mode="r"): yield f -def lines_of(f): +def lines_of(f: Union[str, bytes, Path, IO[str]]) -> Generator[str, None, None]: """Iterate over the lines of a file, an open file, or an iterator. If ``f`` is a string, try to open a file of that name. Otherwise @@ -472,7 +491,10 @@ def lines_of(f): yield from fo -def interesting_lines(lines, comments=None): +def interesting_lines( + lines: Iterable[str], + comments: Union[None, str, Iterable[Union[str]]] = None, +) -> Generator[str, None, None]: """Iterate over lines skipping whitespace and comments. Each line has its whitespace stripped and then it is checked whether @@ -480,6 +502,7 @@ def interesting_lines(lines, comments=None): a list of strings. """ + cc: Tuple[str, ...] if comments is None: cc = () elif isinstance(comments, (str, bytes)): @@ -490,8 +513,8 @@ def interesting_lines(lines, comments=None): cs = c.strip() if not cs or not c.startswith(cs): raise ValueError( - "Unable to deal with comments that start with whitespace, " - "but comment string {!r} was requested.".format(c) + f"Unable to deal with comments that start with whitespace, " + f"but comment string {c:!r} was requested." ) for ln in lines: ln = ln.strip() @@ -1077,7 +1100,7 @@ def dmxparse(fitter, save=False): } -def get_prefix_timerange(model, prefixname): +def get_prefix_timerange(model, prefixname: str) -> Tuple[Time, Time]: """Get time range for a prefix quantity like DMX or SWX Parameters @@ -1105,7 +1128,7 @@ def get_prefix_timerange(model, prefixname): return getattr(model, r1).quantity, getattr(model, r2).quantity -def get_prefix_timeranges(model, prefixname): +def get_prefix_timeranges(model, prefixname: str) -> Tuple[np.ndarray, Time, Time]: """Get all time ranges and indices for a prefix quantity like DMX or SWX Parameters @@ -1142,7 +1165,9 @@ def get_prefix_timeranges(model, prefixname): ) -def find_prefix_bytime(model, prefixname, t): +def find_prefix_bytime( + model, prefixname: str, t: Union[Time, u.Quantity] +) -> Union[int, np.ndarray]: """Identify matching index(es) for a prefix parameter like DMX Parameters @@ -1163,11 +1188,14 @@ def find_prefix_bytime(model, prefixname, t): indices, r1, r2 = get_prefix_timeranges(model, prefixname) matches = np.where((t >= r1) & (t < r2))[0] if len(matches) == 1: - matches = int(matches) - return indices[matches] + return int(indices[int(matches)]) + else: + return indices[matches] -def merge_dmx(model, index1, index2, value="mean", frozen=True): +def merge_dmx( + model, index1: int, index2: int, value: str = "mean", frozen: bool = True +) -> int: """Merge two DMX bins Parameters @@ -1197,7 +1225,7 @@ def merge_dmx(model, index1, index2, value="mean", frozen=True): ) if value.lower() == "first": dmx = getattr(model, f"DMX_{index1:04d}").quantity - elif value.lower == "second": + elif value.lower() == "second": dmx = getattr(model, f"DMX_{index2:04d}").quantity elif value.lower() == "mean": dmx = ( @@ -1205,14 +1233,13 @@ def merge_dmx(model, index1, index2, value="mean", frozen=True): + getattr(model, f"DMX_{index2:04d}").quantity ) / 2 # add the new one before we delete previous ones to make sure we have >=1 present - newindex = model.add_DMX_range(tstart, tend, dmx=dmx, frozen=frozen) + newindex: int = model.add_DMX_range(tstart, tend, dmx=dmx, frozen=frozen) model.remove_DMX_range([index1, index2]) return newindex -def split_dmx(model, time): - """ - Split an existing DMX bin at the desired time +def split_dmx(model, time: Time) -> Tuple[int, int]: + """Split an existing DMX bin at the desired time. Parameters ---------- @@ -1234,10 +1261,10 @@ def split_dmx(model, time): dmx_epochs = [f"{x:04d}" for x in DMX_mapping.keys()] DMX_R1 = np.zeros(len(dmx_epochs)) DMX_R2 = np.zeros(len(dmx_epochs)) - for ii, epoch in enumerate(dmx_epochs): - DMX_R1[ii] = getattr(model, "DMXR1_{:}".format(epoch)).value - DMX_R2[ii] = getattr(model, "DMXR2_{:}".format(epoch)).value - ii = np.where((time.mjd > DMX_R1) & (time.mjd < DMX_R2))[0] + for iii, epoch in enumerate(dmx_epochs): + DMX_R1[iii] = getattr(model, "DMXR1_{:}".format(epoch)).value + DMX_R2[iii] = getattr(model, "DMXR2_{:}".format(epoch)).value + ii: np.ndarray = np.where((time.mjd > DMX_R1) & (time.mjd < DMX_R2))[0] if len(ii) == 0: raise ValueError(f"Time {time} not in any DMX bins") ii = ii[0] @@ -1255,9 +1282,8 @@ def split_dmx(model, time): return index, newindex -def split_swx(model, time): - """ - Split an existing SWX bin at the desired time +def split_swx(model, time: Time) -> Tuple[int, int]: + """Split an existing SWX bin at the desired time. Parameters ---------- @@ -1270,7 +1296,6 @@ def split_swx(model, time): Index of existing bin that was split newindex : int Index of new bin that was added - """ try: SWX_mapping = model.get_prefix_mapping("SWX_") @@ -1279,9 +1304,9 @@ def split_swx(model, time): swx_epochs = [f"{x:04d}" for x in SWX_mapping.keys()] SWX_R1 = np.zeros(len(swx_epochs)) SWX_R2 = np.zeros(len(swx_epochs)) - for ii, epoch in enumerate(swx_epochs): - SWX_R1[ii] = getattr(model, "SWXR1_{:}".format(epoch)).value - SWX_R2[ii] = getattr(model, "SWXR2_{:}".format(epoch)).value + for iii, epoch in enumerate(swx_epochs): + SWX_R1[iii] = getattr(model, "SWXR1_{:}".format(epoch)).value + SWX_R2[iii] = getattr(model, "SWXR2_{:}".format(epoch)).value ii = np.where((time.mjd > SWX_R1) & (time.mjd < SWX_R2))[0] if len(ii) == 0: raise ValueError(f"Time {time} not in any SWX bins") @@ -1301,7 +1326,8 @@ def split_swx(model, time): def wavex_setup(model, T_span, freqs=None, n_freqs=None, freeze_params=False): - """ + """Set up a WaveX model. + Set-up a WaveX model based on either an array of user-provided frequencies or the wave number frequency calculation. Sine and Cosine amplitudes are initially set to zero @@ -1644,9 +1670,10 @@ def get_wavex_amps(model, index=None, quantity=False): model.components["WaveX"].get_prefix_mapping_component("WXSIN_").keys() ) if len(indices) == 1: - values = getattr( - model.components["WaveX"], f"WXSIN_{int(indices):04d}" - ), getattr(model.components["WaveX"], f"WXCOS_{int(indices):04d}") + values = ( + getattr(model.components["WaveX"], f"WXSIN_{int(indices):04d}"), + getattr(model.components["WaveX"], f"WXCOS_{int(indices):04d}"), + ) else: values = [ ( @@ -1657,8 +1684,9 @@ def get_wavex_amps(model, index=None, quantity=False): ] elif isinstance(index, (int, float, np.int64)): idx_rf = f"{int(index):04d}" - values = getattr(model.components["WaveX"], f"WXSIN_{idx_rf}"), getattr( - model.components["WaveX"], f"WXCOS_{idx_rf}" + values = ( + getattr(model.components["WaveX"], f"WXSIN_{idx_rf}"), + getattr(model.components["WaveX"], f"WXCOS_{idx_rf}"), ) elif isinstance(index, (list, set, np.ndarray)): idx_rf = [f"{int(idx):04d}" for idx in index] @@ -1723,7 +1751,13 @@ def translate_wavex_to_wave(model): return new_model -def weighted_mean(arrin, weights_in, inputmean=None, calcerr=False, sdev=False): +def weighted_mean( + arrin: np.ndarray, + weights_in: np.ndarray, + inputmean: Optional[float] = None, + calcerr: bool = False, + sdev: bool = False, +) -> Union[Tuple[float, float], Tuple[float, float, float]]: """Compute weighted mean of input values Calculate the weighted mean, error, and optionally standard deviation of @@ -1734,10 +1768,10 @@ def weighted_mean(arrin, weights_in, inputmean=None, calcerr=False, sdev=False): Parameters ---------- arrin : array - Array containing the numbers whose weighted mean is desired. + Array containing the numbers whose weighted mean is desired. weights: array - A set of weights for each element in array. For measurements with - uncertainties, these should be 1/sigma^2. + A set of weights for each element in array. For measurements with + uncertainties, these should be 1/sigma^2. inputmean: float, optional An input mean value, around which the mean is calculated. calcerr : bool, optional @@ -1751,8 +1785,8 @@ def weighted_mean(arrin, weights_in, inputmean=None, calcerr=False, sdev=False): Returns ------- wmean, werr: tuple - A tuple of the weighted mean and error. If sdev=True the - tuple will also contain sdev: wmean,werr,wsdev + A tuple of the weighted mean and error. If sdev=True the + tuple will also contain sdev: wmean,werr,wsdev Notes ----- @@ -1784,12 +1818,12 @@ def weighted_mean(arrin, weights_in, inputmean=None, calcerr=False, sdev=False): def ELL1_check( A1: u.cm, E: u.dimensionless_unscaled, TRES: u.us, NTOA: int, outstring=True ): - """Check for validity of assumptions in ELL1 binary model + r"""Check for validity of assumptions in ELL1 binary model Checks whether the assumptions that allow ELL1 to be safely used are satisfied. To work properly, we should have: - :math:`asini/c e^4 \ll {\\rm timing precision} / \sqrt N_{\\rm TOA}` - or :math:`A1 E^4 \ll TRES / \sqrt N_{\\rm TOA}` + :math:`asini/c e^4 \ll {\rm timing precision} / \sqrt N_{\rm TOA}` + or :math:`A1 E^4 \ll TRES / \sqrt N_{\rm TOA}` since the ELL1 model now includes terms up to O(E^3) @@ -1810,7 +1844,6 @@ def ELL1_check( bool or str Returns True if ELL1 is safe to use, otherwise False. If outstring is True then returns a string summary instead. - """ lhs = A1 / const.c * E**4.0 rhs = TRES / np.sqrt(NTOA) @@ -1838,7 +1871,7 @@ def ELL1_check( return False -def FTest(chi2_1, dof_1, chi2_2, dof_2): +def FTest(chi2_1: float, dof_1: int, chi2_2: float, dof_2: int) -> float: """Run F-test. Compute an F-test to see if a model with extra parameters is @@ -1875,7 +1908,7 @@ def FTest(chi2_1, dof_1, chi2_2, dof_2): delta_dof = dof_1 - dof_2 new_redchi2 = chi2_2 / dof_2 F = float((delta_chi2 / delta_dof) / new_redchi2) # fdtr doesn't like float128 - return fdtrc(delta_dof, dof_2, F) + return float(fdtrc(delta_dof, dof_2, F)) elif dof_1 == dof_2: log.warning("Models have equal degrees of freedom, cannot perform F-test.") return np.nan @@ -1886,7 +1919,9 @@ def FTest(chi2_1, dof_1, chi2_2, dof_2): return 1.0 -def add_dummy_distance(c, distance=1 * u.kpc): +def add_dummy_distance( + c: coords.SkyCoord, distance: u.Quantity = 1 * u.kpc +) -> coords.SkyCoord: """Adds a dummy distance to a SkyCoord object for applying proper motion Parameters @@ -1958,7 +1993,7 @@ def add_dummy_distance(c, distance=1 * u.kpc): return c -def remove_dummy_distance(c): +def remove_dummy_distance(c: coords.SkyCoord) -> coords.SkyCoord: """Removes a dummy distance from a SkyCoord object after applying proper motion Parameters @@ -2023,7 +2058,9 @@ def remove_dummy_distance(c): return c -def info_string(prefix_string="# ", comment=None, detailed=False): +def info_string( + prefix_string: str = "# ", comment: Optional[str] = None, detailed: bool = False +) -> str: """Returns an informative string about the current state of PINT. Adds: @@ -2131,7 +2168,7 @@ def info_string(prefix_string="# ", comment=None, detailed=False): # user-level git config c = git.GitConfigParser() - username = c.get_value("user", option="name") + f" ({getpass.getuser()})" + username = str(c.get_value("user", option="name")) + f" ({getpass.getuser()})" except (configparser.NoOptionError, configparser.NoSectionError, ImportError): username = getpass.getuser() @@ -2145,13 +2182,14 @@ def info_string(prefix_string="# ", comment=None, detailed=False): } if detailed: - from numpy import __version__ as numpy_version - from scipy import __version__ as scipy_version from astropy import __version__ as astropy_version from erfa import __version__ as erfa_version from jplephem import __version__ as jpleph_version + from loguru import __version__ as loguru_version # type: ignore[attr-defined] from matplotlib import __version__ as matplotlib_version - from loguru import __version__ as loguru_version + from numpy import __version__ as numpy_version + from scipy import __version__ as scipy_version + from pint import __file__ as pint_file info_dict.update( @@ -2204,7 +2242,7 @@ def info_string(prefix_string="# ", comment=None, detailed=False): return s -def list_parameters(class_=None): +def list_parameters(class_=None) -> List[Dict]: """List parameters understood by PINT. Parameters @@ -2264,7 +2302,7 @@ def list_parameters(class_=None): results = {} ct = pint.models.timing_model.Component.component_types.copy() - ct["TimingModel"] = pint.models.timing_model.TimingModel + ct["TimingModel"] = pint.models.timing_model.TimingModel # type: ignore[assignment] for v in ct.values(): for d in list_parameters(v): n = d["name"] @@ -2283,7 +2321,12 @@ def list_parameters(class_=None): return sorted(results.values(), key=lambda d: d["name"]) -def colorize(text, fg_color=None, bg_color=None, attribute=None): +def colorize( + text: str, + fg_color: Optional[str] = None, + bg_color: Optional[str] = None, + attribute: Optional[str] = None, +) -> str: """Colorizes a string (including unicode strings) for printing on the terminal For an example of usage, as well as a demonstration as to what the @@ -2310,9 +2353,11 @@ def colorize(text, fg_color=None, bg_color=None, attribute=None): The colorized string using the defined text attribute. """ COLOR_FORMAT = "\033[%dm\033[%d;%dm%s\033[0m" - FOREGROUND = dict(zip(COLOR_NAMES, list(range(30, 38)))) - BACKGROUND = dict(zip(COLOR_NAMES, list(range(40, 48)))) - ATTRIBUTE = dict(zip(TEXT_ATTRIBUTES, [0, 1, 2, 3, 4, 5, 7, 8])) + FOREGROUND: Dict[Optional[str], int] = dict(zip(COLOR_NAMES, list(range(30, 38)))) + BACKGROUND: Dict[Optional[str], int] = dict(zip(COLOR_NAMES, list(range(40, 48)))) + ATTRIBUTE: Dict[Optional[str], int] = dict( + zip(TEXT_ATTRIBUTES, [0, 1, 2, 3, 4, 5, 7, 8]) + ) fg = FOREGROUND.get(fg_color, 39) bg = BACKGROUND.get(bg_color, 49) att = ATTRIBUTE.get(attribute, 0) @@ -2331,7 +2376,7 @@ def print_color_examples(): print("") -def group_iterator(items): +def group_iterator(items: np.ndarray) -> Generator[Tuple[Any, np.ndarray], None, None]: """An iterator to step over identical items in a :class:`numpy.ndarray` Example @@ -2348,7 +2393,7 @@ def group_iterator(items): yield item, np.where(items == item)[0] -def compute_hash(filename): +def compute_hash(filename: Union[str, Path, IO[bytes]]) -> bytes: """Compute a unique hash of a file. This is designed to keep around to detect changes, not to be @@ -2377,9 +2422,10 @@ def compute_hash(filename): return h.digest() -def get_conjunction(coord, t0, precision="low", ecl="IERS2010"): - """ - Find first time of Solar conjuction after t0 and approximate elongation at conjunction +def get_conjunction( + coord: coords.SkyCoord, t0: Time, precision: str = "low", ecl: str = "IERS2010" +) -> Tuple[Time, u.Quantity]: + """Find first time of Solar conjuction after t0 and approximate elongation at conjunction. Offers a low-precision version (based on analytic expression of Solar longitude) Or a higher-precision version (based on interpolating :func:`astropy.coordinates.get_sun`) @@ -2444,9 +2490,8 @@ def get_conjunction(coord, t0, precision="low", ecl="IERS2010"): return conjunction, csun.separation(coord) -def divide_times(t, t0, offset=0.5): - """ - Divide input times into years relative to t0 +def divide_times(t: Time, t0: Time, offset: float = 0.5) -> np.ndarray: + """Divide input times into years relative to t0. Years are centered around the requested offset value @@ -2478,7 +2523,7 @@ def divide_times(t, t0, offset=0.5): """ dt = t - t0 values = (dt.to(u.yr).value + offset) // 1 - return np.digitize(values, np.unique(values), right=True) + return cast(np.ndarray, np.digitize(values, np.unique(values), right=True)) def convert_dispersion_measure(dm, dmconst=None): @@ -2734,8 +2779,8 @@ def woodbury_dot(Ndiag, U, Phidiag, x, y): def _get_wx2pl_lnlike(model, component_name, ignore_fyr=True): - from pint.models.noise_model import powerlaw from pint import DMconst + from pint.models.noise_model import powerlaw assert component_name in ["WaveX", "DMWaveX"] prefix = "WX" if component_name == "WaveX" else "DMWX" diff --git a/tox.ini b/tox.ini index 05ed53ca3..e9126a3fa 100644 --- a/tox.ini +++ b/tox.ini @@ -13,6 +13,7 @@ envlist = report codestyle black + mypy py{38,39,310,311,312}-test{,-alldeps,-devdeps}{,-cov} skip_missing_interpreters = True @@ -131,8 +132,19 @@ commands = sphinx-build -d "{toxworkdir}/docs_doctree" . "{toxworkdir}/docs_out" skip_install = true changedir = . description = use black +basepython = python3.12 deps = black~=23.0 -commands = black --check src tests examples +commands = black src tests examples {posargs:--check} +[testenv:mypy] +changedir = . +description = use mypy +basepython = python3.12 +deps = + mypy==1.8.0 + GitPython + types-setuptools + types-tqdm +commands = mypy --no-incremental {posargs} \ No newline at end of file