Skip to content

Commit

Permalink
Merge pull request #10 from StFroese/update_validation
Browse files Browse the repository at this point in the history
separate plotting for validation
  • Loading branch information
StFroese authored Aug 31, 2023
2 parents c43d124 + 2c78ef7 commit 885200a
Show file tree
Hide file tree
Showing 7 changed files with 412 additions and 420 deletions.
476 changes: 143 additions & 333 deletions examples/Analysis.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[metadata]
name = titrate
version = 0.4.0
version = 0.4.1
author = Stefan Fröse
author_email = [email protected]
description = asympTotic lIkelihood Tests for daRk mAtTer sEarch
Expand Down
136 changes: 120 additions & 16 deletions titrate/plotting.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,25 @@
import h5py
import matplotlib.pyplot as plt
import numpy as np
from astropy import visualization as viz
from astropy.table import QTable, unique
from astropy.units import Quantity

from titrate.datasets import AsimovMapDataset
from titrate.statistics import QMuTestStatistic, QTildeMuTestStatistic

STATISTICS = {"qmu": QMuTestStatistic, "qtildemu": QTildeMuTestStatistic}


class UpperLimitPlotter:
def __init__(self, path, channel, axes=None):
def __init__(self, path, channel, ax=None):
self.path = path
self.axes = axes if axes is not None else plt.gca()
self.ax = ax if ax is not None else plt.gca()

try:
table = QTable.read(self.path, path=channel)
table = QTable.read(self.path, path=f"upperlimits/{channel}")
except OSError:
channels = list(h5py.File("/Users/stefan/Downloads/test.hdf5").keys())
channels = list(h5py.File(self.path).keys())
channels = [ch for ch in channels if "meta" not in ch]
raise KeyError(
f"Channel {channel} not in dataframe. " f"Choose from {channels}"
Expand All @@ -39,21 +46,21 @@ def __init__(self, path, channel, axes=None):
two_sigma_plus,
)

self.axes.set_xscale("log")
self.axes.set_yscale("log")
self.ax.set_xscale("log")
self.ax.set_yscale("log")

cl_type = unique(table[table["channel"] == self.channel], keys="cl_type")[
"cl_type"
][0]
cl = unique(table[table["channel"] == self.channel], keys="cl")["cl"][0]
self.axes.set_xlabel(f"m / {masses.unit:latex}")
self.axes.set_ylabel(
self.ax.set_xlabel(f"m / {masses.unit:latex}")
self.ax.set_ylabel(
rf"$CL_{cl_type}^{{{cl}}}$ upper limit on $< \sigma v>$ / {uls.unit:latex}"
)

self.axes.set_title(f"Annihilation Upper Limits for channel {self.channel}")
self.ax.set_title(f"Annihilation Upper Limits for channel {self.channel}")

self.axes.legend()
self.ax.legend()

def plot_channel(
self,
Expand All @@ -65,27 +72,124 @@ def plot_channel(
two_sigma_minus,
two_sigma_plus,
):
self.axes.plot(masses, uls, color="tab:orange", label="Upper Limits")
self.axes.plot(masses, median, color="tab:blue", label="Expected Upper Limits")
self.axes.fill_between(
self.ax.plot(masses, uls, color="tab:orange", label="Upper Limits")
self.ax.plot(masses, median, color="tab:blue", label="Expected Upper Limits")
self.ax.fill_between(
masses,
median,
one_sigma_plus,
color="tab:blue",
alpha=0.75,
label=r"$1\sigma$-region",
)
self.axes.fill_between(
self.ax.fill_between(
masses, median, one_sigma_minus, color="tab:blue", alpha=0.75
)
self.axes.fill_between(
self.ax.fill_between(
masses,
one_sigma_plus,
two_sigma_plus,
color="tab:blue",
alpha=0.5,
label=r"$2\sigma$-region",
)
self.axes.fill_between(
self.ax.fill_between(
masses, one_sigma_minus, two_sigma_minus, color="tab:blue", alpha=0.5
)


class ValidationPlotter:
def __init__(
self,
measurement_dataset,
path,
channel=None,
mass=None,
statistic="qmu",
poi_name="scale",
ax=None,
):
self.path = path
self.ax = ax if ax is not None else plt.gca()

asimov_dataset = AsimovMapDataset.from_MapDataset(measurement_dataset)

try:
table = QTable.read(
self.path, path=f"validation/{statistic}/{channel}/{mass}"
)
except OSError:
if channel is None:
channels = list(h5py.File(self.path)["validation"][statistic].keys())
channels = [ch for ch in channels if "meta" not in ch]
raise ValueError(f"Channel must be one of {channels}")
if mass is None:
masses = list(
h5py.File(self.path)["validation"][statistic][channel].keys()
)
masses = [Quantity(m) for m in masses if "meta" not in m]
raise ValueError(f"Mass must be one of {masses}")

toys_ts_same = table["toys_ts_same"]
toys_ts_diff = table["toys_ts_diff"]

max_ts = max(toys_ts_diff.max(), toys_ts_same.max())
bins = np.linspace(0, max_ts, 31)
linspace = np.linspace(0, max_ts, 1000)
statistic = STATISTICS[statistic](asimov_dataset, poi_name)
statistic_math_name = (
r"q_\mu" if isinstance(statistic, QMuTestStatistic) else r"\tilde{q}_\mu"
)

self.plot(
linspace, bins, toys_ts_same, toys_ts_diff, statistic, statistic_math_name
)

self.ax.set_yscale("log")
self.ax.set_xlim(0, max_ts)

self.ax.set_ylabel("pdf")
self.ax.set_xlabel(rf"${statistic_math_name}$")
self.ax.set_title(statistic.__class__.__name__)
self.ax.legend()

def plot(
self, linspace, bins, toys_ts_same, toys_ts_diff, statistic, statistic_math_name
):
plt.hist(
toys_ts_diff,
bins=bins,
density=True,
histtype="step",
color="tab:blue",
label=(
rf"$f({statistic_math_name}\vert\mu^\prime)$, "
r"$\mu=1$, $\mu^\prime=0$"
),
)
plt.hist(
toys_ts_same,
bins=bins,
density=True,
histtype="step",
color="tab:orange",
label=(
rf"$f({statistic_math_name}\vert\mu^\prime)$, "
r"$\mu=1$, $\mu^\prime=1$"
),
)

plt.plot(
linspace,
statistic.asympotic_approximation_pdf(
poi_val=1, same=False, poi_true_val=0, ts_val=linspace
),
color="tab:blue",
label=rf"$f({statistic_math_name}\vert\mu^\prime)$, asympotic",
)
plt.plot(
linspace,
statistic.asympotic_approximation_pdf(poi_val=1, ts_val=linspace),
color="tab:orange",
label=rf"$f({statistic_math_name}\vert\mu^\prime)$, asympotic",
)
4 changes: 2 additions & 2 deletions titrate/tests/test_upperlimits.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def upperlimits_file(jfact_map, measurement_dataset, tmp_path_factory):

@pytest.mark.parametrize("channel", ["b", "W"])
def test_ULFactory(upperlimits_file, channel):
table = QTable.read(upperlimits_file, path=channel)
table = QTable.read(upperlimits_file, path=f"upperlimits/{channel}")
assert np.all(table["mass"] == np.geomspace(0.1, 100, 5) * u.TeV)
assert len(table["ul"]) == 5
assert len(table["median_ul"]) == 5
Expand All @@ -64,4 +64,4 @@ def test_UpperLimitPlotter(upperlimits_file):
fig, axs = plt.subplots(nrows=1, ncols=2)

for channel, ax in zip(["b", "W"], np.array(axs).reshape(-1)):
UpperLimitPlotter(upperlimits_file, channel=channel, axes=ax)
UpperLimitPlotter(upperlimits_file, channel=channel, ax=ax)
61 changes: 51 additions & 10 deletions titrate/tests/test_validation.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
import astropy.units as u
import numpy as np
import pytest


def test_AsmyptoticValidator(measurement_dataset, asimov_dataset):
@pytest.fixture(scope="module")
def validation_file(measurement_dataset, tmp_path_factory):
from titrate.validation import AsymptoticValidator

validator = AsymptoticValidator(measurement_dataset, asimov_dataset, "qmu", "scale")
data = tmp_path_factory.mktemp("data")

validator = AsymptoticValidator(measurement_dataset, "qmu", "scale")
result = validator.validate(n_toys=10)
assert list(result.keys()) == ["pvalue_diff", "pvalue_same", "valid"]
assert result["pvalue_diff"] != 0
Expand All @@ -14,19 +18,56 @@ def test_AsmyptoticValidator(measurement_dataset, asimov_dataset):
assert result["pvalue_same"] != np.nan
assert isinstance(result["valid"], np.bool_)

# same for qtildemu
validator_tilde = AsymptoticValidator(
measurement_dataset, asimov_dataset, "qtildemu", "scale"
)
with pytest.raises(ValueError) as excinfo:
AsymptoticValidator(measurement_dataset, "stupidTest", "scale")

assert str(excinfo.value) == "Statistic must be one of ['qmu', 'qtildemu']"

validator.save_toys(f"{data}/val.h5")

validator_tilde = AsymptoticValidator(measurement_dataset, "qtildemu", "scale")
result_tilde = validator_tilde.validate(n_toys=10)
assert list(result_tilde.keys()) == ["pvalue_diff", "pvalue_same", "valid"]
assert list(result.keys()) == ["pvalue_diff", "pvalue_same", "valid"]
assert result_tilde["pvalue_diff"] != 0
assert result_tilde["pvalue_diff"] != np.nan
assert result_tilde["pvalue_same"] != 0
assert result_tilde["pvalue_same"] != np.nan
assert isinstance(result_tilde["valid"], np.bool_)

with pytest.raises(ValueError) as excinfo:
AsymptoticValidator(measurement_dataset, asimov_dataset, "stupidTest", "scale")
validator_tilde.save_toys(f"{data}/val.h5")

assert str(excinfo.value) == "Statistic must be one of ['qmu', 'qtildemu']"
return f"{data}/val.h5"


@pytest.mark.parametrize("statistic", ["qmu", "qtildemu"])
def test_AsmyptoticValidator(measurement_dataset, statistic, validation_file):
from titrate.validation import AsymptoticValidator

validator = AsymptoticValidator(
measurement_dataset,
statistic=statistic,
path=validation_file,
channel="b",
mass=50 * u.TeV,
)
result = validator.validate()

assert list(result.keys()) == ["pvalue_diff", "pvalue_same", "valid"]
assert result["pvalue_diff"] != 0
assert result["pvalue_diff"] != np.nan
assert result["pvalue_same"] != 0
assert result["pvalue_same"] != np.nan
assert isinstance(result["valid"], np.bool_)


@pytest.mark.parametrize("statistic", ["qmu", "qtildemu"])
def test_ValidationPlotter(measurement_dataset, statistic, validation_file):
from titrate.plotting import ValidationPlotter

ValidationPlotter(
measurement_dataset,
path=validation_file,
statistic=statistic,
channel="b",
mass=50 * u.TeV,
)
2 changes: 1 addition & 1 deletion titrate/upperlimits.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ def save_results(self, path, overwrite=False, **kwargs):
qtable.write(
path,
format="hdf5",
path=f"{channel}",
path=f"upperlimits/{channel}",
overwrite=overwrite,
append=True,
serialize_meta=True,
Expand Down
Loading

0 comments on commit 885200a

Please sign in to comment.