
Commit

Fix bug in BatchMVP due to typo in paper (#114)
* edit installation instructions in readme

* bump up version

* make small change in readme because of publish to pypi error

* bump up version

* fix error in batchmvp due to typo in paper
gianlucadetommaso authored Aug 1, 2023
1 parent 5b66a60 commit b01f8bb
Showing 4 changed files with 73 additions and 136 deletions.
18 changes: 1 addition & 17 deletions benchmarks/two_moons_multicalibrate.py
@@ -95,7 +95,7 @@
values=values,
test_groups=test_groups,
test_values=test_values,
n_buckets=1000,
n_buckets=100,
)

plt.figure(figsize=(10, 3))
@@ -132,19 +132,3 @@
scores=test_scores, groups=test_groups, values=calib_test_values
),
)

print(
"Mismatch between labels and probs before calibration: ",
jnp.mean(
jnp.maximum((1 - test_targets) * test_values, test_targets * (1 - test_values))
),
)
print(
"Mismatch between labels and probs after calibration: ",
jnp.mean(
jnp.maximum(
(1 - test_targets) * calib_test_values,
test_targets * (1 - calib_test_values),
)
),
)
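
The block removed above printed the average probability mass assigned to the wrong class before and after calibration. For reference, that metric can still be computed on its own; the following is a standalone sketch with toy arrays, not part of the benchmark:

import jax.numpy as jnp

# toy stand-ins for the benchmark's binary test targets and predicted probabilities of class 1
test_targets = jnp.array([0, 1, 1, 0])
test_values = jnp.array([0.2, 0.7, 0.4, 0.1])

# probability mass assigned to the wrong class, averaged over the test set
mismatch = jnp.mean(
    jnp.maximum((1 - test_targets) * test_values, test_targets * (1 - test_values))
)
print("Mismatch between labels and probs:", mismatch)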
3 changes: 2 additions & 1 deletion examples/multivalid_coverage.pct.py
@@ -227,7 +227,7 @@ def plot_intervals(xx, means, intervals, test_data, method):

batchmvp = BatchMVPConformalRegressor()
test_thresholds, status = batchmvp.calibrate(
scores=scores, groups=groups, test_groups=test_groups
scores=scores, groups=groups, test_groups=test_groups, n_buckets=300
)
test_thresholds = min_score + (max_score - min_score) * test_thresholds

@@ -283,3 +283,4 @@ def plot_intervals(xx, means, intervals, test_data, method):
(xx_qleft - xx_thresholds, xx_qright + xx_thresholds), axis=1
)
plot_intervals(xx, xx_means, xx_batchmvp_intervals, test_data, "BatchMVP")
plt.show()
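
Since the default `n_buckets` no longer grows automatically, the example now passes it explicitly, and the `status` dict returned by `calibrate` gains a `converged` flag (see `fortuna/conformal/multivalid/base.py` below). The following is a usage sketch with toy scores and groups; the import path follows Fortuna's examples and is an assumption here:

import jax.numpy as jnp
from fortuna.conformal import BatchMVPConformalRegressor  # import path as in Fortuna's examples

# toy conformity scores already normalized to [0, 1], and two overlapping groups
scores = jnp.linspace(0.01, 0.99, 200)
groups = jnp.stack([scores < 0.7, scores > 0.3], axis=1)
test_groups = groups

batchmvp = BatchMVPConformalRegressor()
test_thresholds, status = batchmvp.calibrate(
    scores=scores, groups=groups, test_groups=test_groups, n_buckets=300
)
if not status["converged"]:
    # the loop stalled before reaching the tolerance; as the new warning suggests, retry with more buckets
    test_thresholds, status = batchmvp.calibrate(
        scores=scores, groups=groups, test_groups=test_groups, n_buckets=1000
    )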
166 changes: 61 additions & 105 deletions fortuna/conformal/multivalid/base.py
@@ -3,7 +3,6 @@
from typing import (
Callable,
Dict,
List,
Optional,
Tuple,
Union,
@@ -12,41 +11,9 @@
from jax import vmap
import jax.numpy as jnp

from fortuna.data.loader import (
DataLoader,
InputsLoader,
)
from fortuna.typing import Array


class Normalizer:
def __init__(self, x_min: Union[float, Array], x_max: Union[float, Array]):
self.x_min = x_min
self.x_max = x_max if x_max != x_min else x_min + 1

def normalize(self, x: Array) -> Array:
return (x - self.x_min) / (self.x_max - self.x_min)

def unnormalize(self, y: Array) -> Array:
return self.x_min + (self.x_max - self.x_min) * y


class Model:
def __init__(self, model_fn: Callable[[Array], Array]):
self.model_fn = model_fn

def __call__(self, x: Array):
v = self.model_fn(x)
if v.ndim > 1:
raise ValueError(
"Evaluations of the model function `model_fn` must be one-dimensional arrays, "
f"but its shape was {v.shape}."
)
if jnp.any(v < 0) or jnp.any(v > 1):
raise ValueError("The model function must take values within [0, 1].")
return v


class MultivalidMethod:
def __init__(self):
self._patches = []
@@ -60,7 +27,7 @@ def calibrate(
test_groups: Optional[Array] = None,
test_values: Optional[Array] = None,
tol: float = 1e-4,
n_buckets: int = None,
n_buckets: int = 100,
n_rounds: int = 1000,
**kwargs,
) -> Union[Dict, Tuple[Array, Dict]]:
@@ -88,10 +55,7 @@ def calibrate(
tol: float
A tolerance on the reweighted average squared calibration error, i.e. :math:`\mu(g) K_2(f, g, \mathcal{D})`.
n_buckets: int
The number of buckets used in the algorithm. The smaller the number of buckets, the simpler the model,
the better its generalization abilities. If not provided, We start from 2 buckets, and progressively double
the number of buckets until we find a value for which the calibration error falls below the given
tolerance. Such number of buckets is guaranteed to exist.
The number of buckets used in the algorithm.
n_rounds: int
The maximum number of rounds to run the method for.
Returns
@@ -127,79 +91,71 @@

self._check_range(dict(scores=scores, values=values, test_values=test_values))

increase_n_buckets = False
if n_buckets is None:
n_buckets = 2
increase_n_buckets = True

n_groups = groups.shape[1]
tol_reached = False

while True:
logging.info(f"Attempt reaching tolerance with {n_buckets} buckets.")
buckets = self._get_buckets(n_buckets)
values = vmap(lambda v: self._round_to_buckets(v, buckets))(values_init)

max_calib_errors = []
old_calib_errors_vg = None
self._patches = []

for t in range(n_rounds):
calib_error_vg = vmap(
lambda g: vmap(
lambda v: self._calibration_error(
v,
g,
scores=scores,
groups=groups,
values=values,
n_buckets=n_buckets,
**kwargs,
)
)(buckets)
)(jnp.arange(n_groups))

max_calib_errors.append(calib_error_vg.sum(1).max())
if max_calib_errors[-1] <= tol:
tol_reached = True
logging.info(
f"Tolerance satisfied after {t} rounds with {n_buckets} buckets."
)
break
if old_calib_errors_vg is not None and jnp.allclose(
old_calib_errors_vg, calib_error_vg
):
break
old_calib_errors_vg = jnp.copy(calib_error_vg)

gt, vt = self._get_gt_and_vt(
calib_error_vg=calib_error_vg, buckets=buckets, n_groups=n_groups
)
bt = self._get_b(
groups=groups, values=values, v=vt, g=gt, n_buckets=len(buckets)
)
patch = self._get_patch(
vt=vt,
gt=gt,
scores=scores,
groups=groups,
values=values,
buckets=buckets,
**kwargs,
)
values = self._patch(values=values, patch=patch, bt=bt)

self._patches.append((gt, vt, patch))
self.n_buckets = n_buckets
buckets = self._get_buckets(n_buckets)
values = vmap(lambda v: self._round_to_buckets(v, buckets))(values_init)

if tol_reached:
max_calib_errors = []
old_calib_errors_vg = None
self._patches = []
converged = True

for t in range(n_rounds):
calib_error_vg = vmap(
lambda g: vmap(
lambda v: self._calibration_error(
v,
g,
scores=scores,
groups=groups,
values=values,
n_buckets=n_buckets,
**kwargs,
)
)(buckets)
)(jnp.arange(n_groups))

max_calib_errors.append(calib_error_vg.sum(1).max())
if max_calib_errors[-1] <= tol:
logging.info(f"Tolerance satisfied after {t} rounds.")
break
if increase_n_buckets:
n_buckets *= 2
else:
if old_calib_errors_vg is not None and jnp.allclose(
old_calib_errors_vg, calib_error_vg
):
converged = False
logging.warning(
"The algorithm cannot achieve the desired tolerance. "
"Please try increasing `n_buckets`."
)
break
old_calib_errors_vg = jnp.copy(calib_error_vg)

self.n_buckets = n_buckets
status = dict(n_rounds=len(self.patches), max_calib_errors=max_calib_errors)
gt, vt = self._get_gt_and_vt(
calib_error_vg=calib_error_vg, buckets=buckets, n_groups=n_groups
)
bt = self._get_b(
groups=groups, values=values, v=vt, g=gt, n_buckets=len(buckets)
)
patch = self._get_patch(
vt=vt,
gt=gt,
scores=scores,
groups=groups,
values=values,
buckets=buckets,
**kwargs,
)
values = self._patch(values=values, patch=patch, bt=bt)

self._patches.append((gt, vt, patch))

status = dict(
n_rounds=len(self.patches),
max_calib_errors=max_calib_errors,
converged=converged,
)

if test_groups is not None:
test_values = self.apply_patches(test_groups, test_values)
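
The restructured `calibrate` above drops the old outer loop that started from 2 buckets and kept doubling `n_buckets` until the tolerance was met. It now runs the patching rounds once with a fixed number of buckets, and returns `converged=False` (after logging a warning to increase `n_buckets`) if the per-(group, bucket) calibration errors stop changing before reaching `tol`. Below is a self-contained toy sketch of the patching idea, a mean-multicalibration loop on synthetic binary data; it is not Fortuna's implementation, and the error definition is a simplified stand-in for `_calibration_error`.

import numpy as np
import jax.numpy as jnp

rng = np.random.default_rng(0)
n = 2000
x = jnp.asarray(rng.uniform(size=n))
y = (jnp.asarray(rng.uniform(size=n)) < x).astype(float)      # P(y=1 | x) = x
f = jnp.full(n, 0.5)                                           # deliberately miscalibrated predictor
groups = jnp.stack([x < 0.5, x >= 0.5], axis=1)                # two groups of inputs

n_buckets, tol = 100, 1e-4
buckets = jnp.linspace(0.0, 1.0, n_buckets)
f = buckets[jnp.abs(f[:, None] - buckets[None, :]).argmin(1)]  # round predictions to buckets

converged = True
old_err = None
for t in range(1000):
    # mass-weighted squared mean-calibration error of every (group, bucket) cell
    err = jnp.zeros((groups.shape[1], n_buckets))
    for g in range(groups.shape[1]):
        for i, v in enumerate(buckets):
            b = groups[:, g] * (f == v)
            prob_b = jnp.mean(b)
            mean_y = jnp.where(prob_b > 0, jnp.sum(y * b) / jnp.maximum(jnp.sum(b), 1), v)
            err = err.at[g, i].set(prob_b * (mean_y - v) ** 2)
    if err.sum(1).max() <= tol:
        break
    if old_err is not None and jnp.allclose(old_err, err):
        converged = False                                      # stalled: a larger n_buckets is needed
        break
    old_err = err
    g, i = jnp.unravel_index(err.argmax(), err.shape)          # worst (group, bucket) cell
    b = groups[:, g] * (f == buckets[i])
    patch = buckets[jnp.abs(jnp.sum(y * b) / jnp.sum(b) - buckets).argmin()]
    f = jnp.where(b, patch, f)                                 # move the cell's predictions to the patch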
22 changes: 9 additions & 13 deletions fortuna/conformal/multivalid/batch_mvp.py
@@ -34,7 +34,7 @@ def calibrate(
test_groups: Optional[Array] = None,
test_values: Optional[Array] = None,
tol: float = 1e-4,
n_buckets: int = None,
n_buckets: int = 100,
n_rounds: int = 1000,
coverage: float = 0.95,
) -> Union[Dict, Tuple[Array, Dict]]:
@@ -117,41 +117,42 @@ def _calibration_error(
values: Array,
n_buckets: int,
coverage: float = None,
threshold: Array = None,
):
prob_error, prob_b = self._compute_probability_error(
v=v,
g=g,
delta=0.0,
scores=scores,
groups=groups,
values=values,
n_buckets=n_buckets,
return_prob_b=True,
coverage=coverage,
threshold=threshold,
)
return prob_b * prob_error

def _compute_probability_error(
self,
v: Array,
g: Array,
delta: Array,
scores: Array,
groups: Array,
values: Array,
n_buckets: int,
return_prob_b: bool = False,
coverage: float = None,
threshold: Array = None,
):
prob = self._compute_probability(
v=v,
g=g,
delta=delta,
scores=scores,
groups=groups,
values=values,
n_buckets=n_buckets,
return_prob_b=return_prob_b,
threshold=threshold,
)
if return_prob_b:
prob, prob_b = prob
@@ -162,15 +163,15 @@ def _compute_probability(
self,
v: Array,
g: Array,
delta: Array,
scores: Array,
groups: Array,
values: Array,
n_buckets: int,
return_prob_b: bool = False,
threshold: Array = None,
):
b = self._get_b(groups=groups, values=values, v=v, g=g, n_buckets=n_buckets)
conds = (scores <= v + delta) * b
conds = (scores <= (v if threshold is None else threshold)) * b
prob_b = jnp.mean(b)
prob = jnp.where(prob_b > 0, jnp.mean(conds) / prob_b, 0.0)
if return_prob_b:
@@ -190,21 +191,16 @@ def _get_patch(
return buckets[
jnp.argmin(
vmap(
lambda delta: self._compute_probability_error(
lambda v: self._compute_probability_error(
v=vt,
g=gt,
delta=delta,
scores=scores,
groups=groups,
values=values,
n_buckets=len(buckets),
coverage=coverage,
threshold=v,
)
)(buckets)
)
]

def _patch(
self, values: Array, patch: Array, bt: Array, _shift: bool = True
) -> Array:
return super()._patch(values=values, patch=patch, bt=bt, _shift=_shift)
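
The typo fix itself is in `_compute_probability` and `_get_patch`: rather than offsetting the current value `v` by a candidate `delta`, the patch search now evaluates each candidate bucket as an absolute `threshold`. The following is a minimal sketch of the two behaviours on toy arrays, using plain `jax.numpy` rather than Fortuna code; the squared deviation from the target coverage is only an assumed stand-in for the error computed in the part of `_compute_probability_error` not shown in this diff.

import jax.numpy as jnp
from jax import vmap

scores = jnp.array([0.1, 0.3, 0.5, 0.7, 0.9])   # toy conformity scores in the selected cell
b = jnp.ones_like(scores)                        # pretend every point falls in the (v, g) cell
v = 0.4                                          # current threshold value of the cell
buckets = jnp.linspace(0.0, 1.0, 11)
coverage = 0.8

# old behaviour (the typo): candidate thresholds were offsets v + delta from the current value
old_cov = vmap(lambda delta: jnp.mean((scores <= v + delta) * b) / jnp.mean(b))(buckets)

# new behaviour: each bucket value itself is tried as an absolute threshold
new_cov = vmap(lambda t: jnp.mean((scores <= t) * b) / jnp.mean(b))(buckets)

old_patch = buckets[jnp.argmin((old_cov - coverage) ** 2)]   # with offsets, v + delta can exceed 1 and leave the bucket grid
new_patch = buckets[jnp.argmin((new_cov - coverage) ** 2)]   # candidates stay on the [0, 1] bucket grid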
