diff --git a/examples/mnist_classification.pct.py b/examples/mnist_classification.pct.py index 08468afd..3f8d2e82 100644 --- a/examples/mnist_classification.pct.py +++ b/examples/mnist_classification.pct.py @@ -153,7 +153,7 @@ def download(split_range, shuffle=False): val_probs=val_means, test_probs=test_means, val_targets=val_data_loader.to_array_targets(), - error=0.05 + error=0.05, ) # %% [markdown] diff --git a/examples/mnist_classification_sghmc.pct.py b/examples/mnist_classification_sghmc.pct.py index 0f133d4c..8bf628cf 100644 --- a/examples/mnist_classification_sghmc.pct.py +++ b/examples/mnist_classification_sghmc.pct.py @@ -143,7 +143,7 @@ def download(split_range, shuffle=False): val_probs=val_means, test_probs=test_means, val_targets=val_data_loader.to_array_targets(), - error=0.05 + error=0.05, ) # %% [markdown] diff --git a/fortuna/conformal/classification/maxcovfixprec_binary_classfication.py b/fortuna/conformal/classification/maxcovfixprec_binary_classfication.py index d484fd43..df577416 100644 --- a/fortuna/conformal/classification/maxcovfixprec_binary_classfication.py +++ b/fortuna/conformal/classification/maxcovfixprec_binary_classfication.py @@ -13,12 +13,19 @@ class MaxCoverageFixedPrecisionBinaryClassificationCalibrator: def __init__(self): """ - A base iterative multivalid method. - - Parameters - ---------- - seed: int - Random seed. + Given a binary classification framework, let us define true positive precision by + :math:`\mathbb{P}(Y=1|f(X)\ge T_{tp})` + and true negative precision by + :math:`\mathbb{P}(Y=0|f(X)\le T_{tn})`, + where :math:`T_{tp}` and :math:`T_{tn}` are two thresholds greater than :math:`\frac{1}{2}`, + and :math:`f(X)` is a model for the probability that :math:`Y=1`. + We further define coverage as + :math:`\mathbb{P}(f(X)\le T_{tn}) + \mathbb{P}(f(X)\ge T_{tp})`.
+ Then this algorithm defines a new model + :math:`\hat{f}(x)=(\tau_{tp}\,1[f(x)\ge \frac{1}{2}] + \tau_{tn}\,1[f(x)<\frac{1}{2}])\,f(x)`, + and searches for the :math:`\tau_{tp}\in[1, 2T_{tp}]` and :math:`\tau_{tn}\in[2(1-T_{tn}), 1]` + that maximize the coverage while guaranteeing that true positive and negative precisions are at least + :math:`T_{tp}` and :math:`T_{tn}`, respectively. """ self._patches = dict() @@ -27,17 +34,17 @@ def calibrate( targets: Array, probs: Array, true_positive_precision_threshold: float, - false_negative_precision_threshold: float, + true_negative_precision_threshold: float, test_probs: Optional[Array] = None, n_taus: int = 100, margin: float = 0.0, ) -> Union[None, Array]: if ( - false_negative_precision_threshold <= 0.5 + true_negative_precision_threshold <= 0.5 or true_positive_precision_threshold <= 0.5 ): raise ValueError( - "Both `false_negative_precision_threshold` and" + "Both `true_negative_precision_threshold` and" " `true_positive_precision_threshold` must be greater than 0.5."
) probs = jnp.copy(probs) @@ -51,18 +58,18 @@ def _true_positive_objective_fn(tau: Array): pos_cond = pos_prec >= true_positive_precision_threshold + margin return prob_b_pos_prec * pos_cond - def _false_negative_objective_fn(tau: Array): + def _true_negative_objective_fn(tau: Array): calib_probs = (1 + (tau - 1) * (probs < 0.5)) * probs - b_neg_prec = calib_probs <= 1 - false_negative_precision_threshold + b_neg_prec = calib_probs <= 1 - true_negative_precision_threshold prob_b_neg_prec = jnp.mean(b_neg_prec) neg_prec = jnp.mean((1 - targets) * b_neg_prec) / prob_b_neg_prec - neg_cond = neg_prec >= false_negative_precision_threshold + margin + neg_cond = neg_prec >= true_negative_precision_threshold + margin return prob_b_neg_prec * neg_cond taus_pos = jnp.linspace(1, 2 * true_positive_precision_threshold, n_taus) - taus_neg = jnp.linspace( - 2 * (1 - false_negative_precision_threshold), 1, n_taus - )[::-1] + taus_neg = jnp.linspace(2 * (1 - true_negative_precision_threshold), 1, n_taus)[ + ::-1 + ] values_pos = vmap(_true_positive_objective_fn)(taus_pos) @@ -70,9 +77,9 @@ def _false_negative_objective_fn(tau: Array): if jnp.max(values_pos) == 0: logging.warning(msg.format("`true_positive_precision_threshold`")) - values_neg = vmap(_false_negative_objective_fn)(taus_neg) + values_neg = vmap(_true_negative_objective_fn)(taus_neg) if jnp.max(values_neg) == 0: - logging.warning(msg.format("`false_negative_precision_threshold`")) + logging.warning(msg.format("`true_negative_precision_threshold`")) self._patches["tau_pos"] = taus_pos[jnp.argmax(values_pos)] self._patches["tau_neg"] = taus_neg[jnp.argmax(values_neg)] @@ -99,7 +106,7 @@ def true_positive_precision(probs: Array, targets: Array, threshold: float): return jnp.mean(targets * b) / prob_b @staticmethod - def false_negative_precision(probs: Array, targets: Array, threshold: float): + def true_negative_precision(probs: Array, targets: Array, threshold: float): b = probs <= 1 - threshold prob_b = jnp.mean(b) 
return jnp.mean((1 - targets) * b) / prob_b @@ -109,7 +116,7 @@ def true_positive_coverage(probs: Array, threshold: float): return jnp.mean(probs >= threshold) @staticmethod - def false_negative_coverage(probs: Array, threshold: float): + def true_negative_coverage(probs: Array, threshold: float): return jnp.mean(probs <= threshold) @property diff --git a/pyproject.toml b/pyproject.toml index d17b9d3c..fffac6bc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "aws-fortuna" -version = "0.1.35" +version = "0.1.36" description = "A Library for Uncertainty Quantification." authors = ["Gianluca Detommaso ", "Alberto Gasparin "] license = "Apache-2.0" diff --git a/tests/fortuna/test_conformal_methods.py b/tests/fortuna/test_conformal_methods.py index 2f815856..ef3220d6 100755 --- a/tests/fortuna/test_conformal_methods.py +++ b/tests/fortuna/test_conformal_methods.py @@ -818,12 +818,12 @@ def test_max_coverage_fixed_precision_binary_classification_calibrator(self): targets=targets, probs=probs, true_positive_precision_threshold=0.99, - false_negative_precision_threshold=0.99, + true_negative_precision_threshold=0.99, ) test_values = calib.calibrate( targets=targets, probs=probs, test_probs=test_probs, true_positive_precision_threshold=0.99, - false_negative_precision_threshold=0.99, + true_negative_precision_threshold=0.99, )