Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Assure that gaussfamily likelihoods have methods called in the right order #356

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ jobs:
pylint --rcfile firecrown/models/pylintrc firecrown/models
- name: Running pytest
shell: bash -l {0}
run: python -m pytest -vv
run: python -m pytest -vv --runslow
- name: Running example - cosmosis - cosmic-shear
shell: bash -l {0}
run: |
Expand Down
4 changes: 2 additions & 2 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@
# -- Project information -----------------------------------------------------

project = "firecrown"
copyright = "2022, LSST DESC Firecrown Contributors"
copyright = "2022--2024, LSST DESC Firecrown Contributors"
author = "LSST DESC Firecrown Contributors"

# The full version, including alpha/beta/rc tags
release = "1.1"
release = "1.6.1a0"


# -- General configuration ---------------------------------------------------
Expand Down
2 changes: 1 addition & 1 deletion docs/developer_installation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ Please also file an issue in the GitHub issue tracker describing the failure.
# Update the pip-installed products.
# The --no-deps flag is critical to avoid accidentally installing new packages
# with pip (rather than with conda).
python -m pip install --upgrade --no-deps autoclasstoc cobaya
python -m pip install --upgrade --no-deps autoclasstoc cobaya pygobject-stubs
# Rebuild the CSL
cd ${CSL_DIR}
# Optionally, you may want to update to the newest version of the CSL
Expand Down
63 changes: 62 additions & 1 deletion firecrown/likelihood/gauss_family/gauss_family.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
"""

from __future__ import annotations
from enum import Enum
from typing import List, Optional, Tuple, Sequence
from typing import final
import warnings
Expand All @@ -18,24 +19,46 @@

import sacc

from ...parameters import ParamsMap
from ..likelihood import Likelihood
from ...modeling_tools import ModelingTools
from ...updatable import UpdatableCollection
from .statistic.statistic import Statistic, GuardedStatistic


class State(Enum):
"""The states used in GaussFamily."""

INITIALIZED = 1
READY = 2
UPDATED = 3


class GaussFamily(Likelihood):
"""GaussFamily is an abstract class. It is the base class for all likelihoods
based on a chi-squared calculation. It provides an implementation of
Likelihood.compute_chisq. Derived classes must implement the abstract method
compute_loglike, which is inherited from Likelihood.

GaussFamily (and all classes that inherit from it) must abide by the the
following rules regarding the order of calling of methods.

1. after a new object is created, :meth:`read` must be called before any
other method in the interfaqce.
2. after :meth:`read` has been called it is legal to call
:meth:`get_data_vector`, or to call :meth:`update`.
3. after :meth:`update` is called it is then legal to call
:meth:`calculate_loglike` or :meth:`get_data_vector`, or to reset
the object (returning to the pre-update state) by calling
:meth:`reset`.
"""

def __init__(
self,
statistics: Sequence[Statistic],
):
super().__init__()
self.state: State = State.INITIALIZED
if len(statistics) == 0:
raise ValueError("GaussFamily requires at least one statistic")
self.statistics: UpdatableCollection = UpdatableCollection(
Expand All @@ -45,9 +68,28 @@ def __init__(
self.cholesky: Optional[npt.NDArray[np.float64]] = None
self.inv_cov: Optional[npt.NDArray[np.float64]] = None

def _update(self, _: ParamsMap) -> None:
"""Handle the state resetting required by :class:`GaussFamily`
likelihoods. Any derived class that needs to implement :meth:`_update`
for its own reasons must be sure to do what this does: check the state
at the start of the method, and change the state at the end of the
method."""
assert self.state == State.READY, "read() must be called before update()"
self.state = State.UPDATED

def _reset(self) -> None:
"""Handle the state resetting required by :class:`GaussFamily`
likelihoods. Any derived class that needs to implement :meth:`reset`
for its own reasons must be sure to do what this does: check the state
at the start of the method, and change the state at the end of the
method."""
assert self.state == State.UPDATED, "update() must be called before reset()"
self.state = State.READY

def read(self, sacc_data: sacc.Sacc) -> None:
"""Read the covariance matrix for this likelihood from the SACC file."""

assert self.state == State.INITIALIZED, "read() must only be called once"
if sacc_data.covariance is None:
msg = (
f"The {type(self).__name__} likelihood requires a covariance, "
Expand All @@ -71,20 +113,26 @@ def read(self, sacc_data: sacc.Sacc) -> None:
self.cholesky = scipy.linalg.cholesky(self.cov, lower=True)
self.inv_cov = np.linalg.inv(cov)

self.state = State.READY

@final
def get_cov(self) -> npt.NDArray[np.float64]:
"""Gets the current covariance matrix."""
assert self._is_ready(), "read() must be called before get_cov()"
assert self.cov is not None
# We do not change the state.
return self.cov

@final
def get_data_vector(self) -> npt.NDArray[np.float64]:
"""Get the data vector from all statistics and concatenate in the right
order."""
assert self._is_ready(), "read() must be called before get_data_vector()"

data_vector_list: List[npt.NDArray[np.float64]] = [
stat.get_data_vector() for stat in self.statistics
]
# We do not change the state.
return np.concatenate(data_vector_list)

@final
Expand All @@ -93,18 +141,22 @@ def compute_theory_vector(self, tools: ModelingTools) -> npt.NDArray[np.float64]

:param tools: Current ModelingTools object
"""
assert (
self.state == State.UPDATED
), "update() must be called before compute_theory_vector()"

theory_vector_list: List[npt.NDArray[np.float64]] = [
stat.compute_theory_vector(tools) for stat in self.statistics
]
# We do not change the state
return np.concatenate(theory_vector_list)

@final
def compute(
self, tools: ModelingTools
) -> Tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]]:
"""Calculate and return both the data and theory vectors."""

assert self.state == State.UPDATED, "update() must be called before compute()"
warnings.simplefilter("always", DeprecationWarning)
warnings.warn(
"The use of the `compute` method on Statistic is deprecated."
Expand All @@ -113,11 +165,15 @@ def compute(
category=DeprecationWarning,
)

# We do not change the state.
return self.get_data_vector(), self.compute_theory_vector(tools)

@final
def compute_chisq(self, tools: ModelingTools) -> float:
"""Calculate and return the chi-squared for the given cosmology."""
assert (
self.state == State.UPDATED
), "update() must be called before compute_chisq()"
theory_vector: npt.NDArray[np.float64]
data_vector: npt.NDArray[np.float64]
residuals: npt.NDArray[np.float64]
Expand All @@ -136,4 +192,9 @@ def compute_chisq(self, tools: ModelingTools) -> float:
x = scipy.linalg.solve_triangular(self.cholesky, residuals, lower=True)
chisq = np.dot(x, x)

# We do not change the state.
return chisq

def _is_ready(self) -> bool:
"""Return True if the state is either READY or UPDATED."""
return self.state in (State.READY, State.UPDATED)
2 changes: 1 addition & 1 deletion firecrown/likelihood/gauss_family/statistic/statistic.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
================================

The Statistic class describing objects that implement methods to compute the
data and theory vectors for a GaussianFamily subclass.
data and theory vectors for a :class:`GaussFamily` subclass.

"""

Expand Down
6 changes: 6 additions & 0 deletions tests/likelihood/gauss_family/test_student_t.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,12 @@ def test_require_nonempty_statistics():
_ = StudentT(statistics=[])


def test_update_fails_before_read(trivial_stats, trivial_params_student_t):
likelihood = StudentT(statistics=trivial_stats)
with pytest.raises(AssertionError):
likelihood.update(trivial_params_student_t)


def test_get_cov_fails_before_read(trivial_stats):
likelihood = StudentT(statistics=trivial_stats)
with pytest.raises(AssertionError):
Expand Down
Loading