Skip to content

Commit

Permalink
Rename false negative to true negative threshold (#137)
Browse files Browse the repository at this point in the history
  • Loading branch information
gianlucadetommaso authored Sep 28, 2023
1 parent fdcc372 commit 4b2e44e
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 24 deletions.
2 changes: 1 addition & 1 deletion examples/mnist_classification.pct.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def download(split_range, shuffle=False):
val_probs=val_means,
test_probs=test_means,
val_targets=val_data_loader.to_array_targets(),
error=0.05
error=0.05,
)

# %% [markdown]
Expand Down
2 changes: 1 addition & 1 deletion examples/mnist_classification_sghmc.pct.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def download(split_range, shuffle=False):
val_probs=val_means,
test_probs=test_means,
val_targets=val_data_loader.to_array_targets(),
error=0.05
error=0.05,
)

# %% [markdown]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,19 @@
class MaxCoverageFixedPrecisionBinaryClassificationCalibrator:
def __init__(self):
"""
A base iterative multivalid method.
Parameters
----------
seed: int
Random seed.
Given a binary classification framework, let us define true positive precision by
:math:`\mathbb{P}(Y=1|f(X)\ge T_{tp})`
and true negative precision by
:math:`\mathbb{P}(Y=0|f(X)\le T_{tn})`,
where :math:`T_{tp}` and :math:`T_{tn}` are two thresholds greater than :math:`\frac{1}{2}`,
and :math:`f(X)` is a model for the probability that :math:`Y=1`.
We further define coverage as
:math:`\mathbb{P}(f(X)\le T_{tn} + \mathbb{P}(f(X)\ge T_{tp}`.
Then this algorithm defines a new model
:math:`\hat{f}(x)=(\tau_{tp}\,1[f(x)\ge \frac{1}{2}] + \tau_{tn}f(x)\,1[f(x)<\frac{1}{2}])\,f(x)`,
and searches for the :math:`tau_{tp}\in[1, 2T_{tp}]` and :math:`tau_{tn}\in[2T_{tp}, 1]`
that maximize the coverage while guaranteeing that true positive and negative precisions are at least
:math:`T_{tp}` and `T_{tn}`, respectively.
"""
self._patches = dict()

Expand All @@ -27,17 +34,17 @@ def calibrate(
targets: Array,
probs: Array,
true_positive_precision_threshold: float,
false_negative_precision_threshold: float,
true_negative_precision_threshold: float,
test_probs: Optional[Array] = None,
n_taus: int = 100,
margin: float = 0.0,
) -> Union[None, Array]:
if (
false_negative_precision_threshold <= 0.5
true_negative_precision_threshold <= 0.5
or true_positive_precision_threshold <= 0.5
):
raise ValueError(
"Both `false_negative_precision_threshold` and"
"Both `true_negative_precision_threshold` and"
" `true_positive_precision_threshold` must be greater than 0.5."
)
probs = jnp.copy(probs)
Expand All @@ -51,28 +58,28 @@ def _true_positive_objective_fn(tau: Array):
pos_cond = pos_prec >= true_positive_precision_threshold + margin
return prob_b_pos_prec * pos_cond

def _false_negative_objective_fn(tau: Array):
def _true_negative_objective_fn(tau: Array):
calib_probs = (1 + (tau - 1) * (probs < 0.5)) * probs
b_neg_prec = calib_probs <= 1 - false_negative_precision_threshold
b_neg_prec = calib_probs <= 1 - true_negative_precision_threshold
prob_b_neg_prec = jnp.mean(b_neg_prec)
neg_prec = jnp.mean((1 - targets) * b_neg_prec) / prob_b_neg_prec
neg_cond = neg_prec >= false_negative_precision_threshold + margin
neg_cond = neg_prec >= true_negative_precision_threshold + margin
return prob_b_neg_prec * neg_cond

taus_pos = jnp.linspace(1, 2 * true_positive_precision_threshold, n_taus)
taus_neg = jnp.linspace(
2 * (1 - false_negative_precision_threshold), 1, n_taus
)[::-1]
taus_neg = jnp.linspace(2 * (1 - true_negative_precision_threshold), 1, n_taus)[
::-1
]

values_pos = vmap(_true_positive_objective_fn)(taus_pos)

msg = "The {} could not be satisfied. Please consider improving the classifier or decreasing the threshold."

if jnp.max(values_pos) == 0:
logging.warning(msg.format("`true_positive_precision_threshold`"))
values_neg = vmap(_false_negative_objective_fn)(taus_neg)
values_neg = vmap(_true_negative_objective_fn)(taus_neg)
if jnp.max(values_neg) == 0:
logging.warning(msg.format("`false_negative_precision_threshold`"))
logging.warning(msg.format("`true_negative_precision_threshold`"))

self._patches["tau_pos"] = taus_pos[jnp.argmax(values_pos)]
self._patches["tau_neg"] = taus_neg[jnp.argmax(values_neg)]
Expand All @@ -99,7 +106,7 @@ def true_positive_precision(probs: Array, targets: Array, threshold: float):
return jnp.mean(targets * b) / prob_b

@staticmethod
def false_negative_precision(probs: Array, targets: Array, threshold: float):
def true_negative_precision(probs: Array, targets: Array, threshold: float):
b = probs <= 1 - threshold
prob_b = jnp.mean(b)
return jnp.mean((1 - targets) * b) / prob_b
Expand All @@ -109,7 +116,7 @@ def true_positive_coverage(probs: Array, threshold: float):
return jnp.mean(probs >= threshold)

@staticmethod
def false_negative_coverage(probs: Array, threshold: float):
def true_negative_coverage(probs: Array, threshold: float):
return jnp.mean(probs <= threshold)

@property
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "aws-fortuna"
version = "0.1.35"
version = "0.1.36"
description = "A Library for Uncertainty Quantification."
authors = ["Gianluca Detommaso <[email protected]>", "Alberto Gasparin <[email protected]>"]
license = "Apache-2.0"
Expand Down
4 changes: 2 additions & 2 deletions tests/fortuna/test_conformal_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -818,12 +818,12 @@ def test_max_coverage_fixed_precision_binary_classification_calibrator(self):
targets=targets,
probs=probs,
true_positive_precision_threshold=0.99,
false_negative_precision_threshold=0.99,
true_negative_precision_threshold=0.99,
)
test_values = calib.calibrate(
targets=targets,
probs=probs,
test_probs=test_probs,
true_positive_precision_threshold=0.99,
false_negative_precision_threshold=0.99,
true_negative_precision_threshold=0.99,
)

0 comments on commit 4b2e44e

Please sign in to comment.