Commit fc37649
Add possibility to select different bucket types in iterative multivalid methods (#133)

gianlucadetommaso authored Sep 24, 2023
1 parent b79415d commit fc37649
Showing 16 changed files with 374 additions and 111 deletions.
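
The headline change adds a bucket_types argument to the calibrate method of the iterative multivalid calibrators. A minimal sketch of the new call, adapted from the examples/multivalid_coverage.pct.py diff below; scores, groups, and test_groups stand for the arrays prepared earlier in that example, and ("<=", ">=") is the only combination visible in this commit:

from fortuna.conformal import BatchMVPConformalRegressor

batchmvp = BatchMVPConformalRegressor()
test_thresholds, status = batchmvp.calibrate(
    scores=scores,            # conformity scores from the calibration set
    groups=groups,            # group-membership arrays for the calibration inputs
    test_groups=test_groups,  # group membership for the test inputs
    eta=1,                    # step size, as in the example below
    bucket_types=("<=", ">="),
)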
81 changes: 81 additions & 0 deletions benchmarks/multivalid/breast_cancer_multicalibrate.py
@@ -0,0 +1,81 @@
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPClassifier

from fortuna.conformal import BinaryClassificationMulticalibrator
from fortuna.metric.classification import accuracy

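# Load the breast cancer data, hold out 30% for testing, then split the
# remaining data evenly into train and calibration sets.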
data = load_breast_cancer()
inputs = data.data
targets = data.target
train_inputs, test_inputs, train_targets, test_targets = train_test_split(
    inputs, targets, test_size=0.3, random_state=1
)
train_size = int(len(train_inputs) * 0.5)
train_inputs, calib_inputs = train_inputs[:train_size], train_inputs[train_size:]
train_targets, calib_targets = train_targets[:train_size], train_targets[train_size:]

calib_size = calib_targets.shape[0]

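# Fit an MLP classifier on the training split.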
model = MLPClassifier(random_state=42)
model.fit(train_inputs, train_targets)

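# Class predictions and predicted probabilities on the calibration and test sets.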
calib_preds = model.predict(calib_inputs)
calib_probs = (
    model.predict_proba(calib_inputs)
    if hasattr(model, "predict_proba")
    else model._predict_proba_lr(calib_inputs)
)
test_preds = model.predict(test_inputs)
test_probs = (
    model.predict_proba(test_inputs)
    if hasattr(model, "predict_proba")
    else model._predict_proba_lr(test_inputs)
)

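# Multicalibrate the positive-class probabilities: learn patches on the
# calibration set and return patched test probabilities.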
mc = BinaryClassificationMulticalibrator()
mc_test_probs1, status = mc.calibrate(
    targets=calib_targets,
    probs=calib_probs[:, 1],
    test_probs=test_probs[:, 1],
)

mc_calib_probs1 = mc.apply_patches(probs=calib_probs[:, 1])

print(
    f"Calib accuracy pre/post calibration: {float(accuracy(calib_preds, calib_targets)), float(accuracy(mc_calib_probs1 > 0.5, calib_targets))}"
)
print(
    f"Test accuracy pre/post calibration: {float(accuracy(test_preds, test_targets)), float(accuracy(mc_test_probs1 > 0.5, test_targets))}"
)
print()
print(
    f"Calib MSE pre/post calibration: {float(mc.mean_squared_error(calib_probs[:, 1], calib_targets)), float(mc.mean_squared_error(mc_calib_probs1, calib_targets))}"
)
print(
    f"Test MSE pre/post calibration: {float(mc.mean_squared_error(test_probs[:, 1], test_targets)), float(mc.mean_squared_error(mc_test_probs1, test_targets))}"
)
print()

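# Repeat the comparison with the one-shot variant of the multicalibrator.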
from fortuna.conformal import OneShotBinaryClassificationMulticalibrator

osmc = OneShotBinaryClassificationMulticalibrator()
osmc_test_probs1 = osmc.calibrate(
    targets=calib_targets, probs=calib_probs[:, 1], test_probs=test_probs[:, 1]
)

osmc_calib_probs1 = osmc.apply_patches(probs=calib_probs[:, 1])

print(
    f"Calib accuracy pre/post one-shot calibration: {float(accuracy(calib_preds, calib_targets)), float(accuracy(osmc_calib_probs1 > 0.5, calib_targets))}"
)
print(
    f"Test accuracy pre/post one-shot calibration: {float(accuracy(test_preds, test_targets)), float(accuracy(osmc_test_probs1 > 0.5, test_targets))}"
)
print()
print(
    f"Calib MSE pre/post one-shot calibration: {float(mc.mean_squared_error(calib_probs[:, 1], calib_targets)), float(mc.mean_squared_error(osmc_calib_probs1, calib_targets))}"
)
print(
    f"Test MSE pre/post one-shot calibration: {float(mc.mean_squared_error(test_probs[:, 1], test_targets)), float(mc.mean_squared_error(osmc_test_probs1, test_targets))}"
)
4 changes: 1 addition & 3 deletions benchmarks/multivalid/two_moons_multicalibrate.py
@@ -93,8 +93,6 @@
     probs=probs,
     test_groups=test_groups,
     test_probs=test_probs,
-    n_buckets=100,
-    min_prob_b=0.0,
 )
 
 plt.figure(figsize=(10, 3))
@@ -116,7 +114,7 @@
 plt.show()
 
 plt.title("Mean-squared error decay during calibration")
-plt.semilogy(status["mean_squared_errors"])
+plt.semilogy(status["losses"])
 plt.show()
 
 print(
8 changes: 6 additions & 2 deletions examples/multivalid_coverage.pct.py
@@ -228,7 +228,11 @@ def plot_intervals(xx, means, intervals, test_data, method):
 
 batchmvp = BatchMVPConformalRegressor()
 test_thresholds, status = batchmvp.calibrate(
-    scores=scores, groups=groups, test_groups=test_groups, n_buckets=100, eta=1
+    scores=scores,
+    groups=groups,
+    test_groups=test_groups,
+    eta=1,
+    bucket_types=("<=", ">="),
 )
 test_thresholds = min_score + (max_score - min_score) * test_thresholds

@@ -237,7 +241,7 @@ def plot_intervals(xx, means, intervals, test_data, method):
 
 # %%
 plt.figure(figsize=(6, 3))
-plt.plot(status["mean_squared_errors"], label="mean squared error decay")
+plt.plot(status["losses"], label="mean squared error decay")
 plt.xlabel("rounds")
 plt.legend()
 plt.show()
25 changes: 0 additions & 25 deletions fortuna/conformal/multivalid/base.py
@@ -22,25 +22,6 @@ def __init__(self, seed: int = 0):
         self._patches = None
         self._n_buckets = None
 
-    def mean_squared_error(self, values: Array, scores: Array) -> Array:
-        """
-        The mean squared error between the model evaluations and the scores.
-        This is supposed to decrease at every round of the algorithm.
-
-        Parameters
-        ----------
-        values: Array
-            The model evaluations.
-        scores: Array
-            The scores.
-
-        Returns
-        -------
-        Array
-            The mean-squared error.
-        """
-        return self._mean_squared_error(values, scores)
-
     @property
     def patches(self):
         return self._patches
@@ -53,12 +34,6 @@ def n_buckets(self):
     def n_buckets(self, n_buckets):
         self._n_buckets = n_buckets
 
-    @staticmethod
-    def _mean_squared_error(values: Array, scores: Array) -> Array:
-        if scores.ndim == 2 and values.ndim == 1:
-            scores = scores[:, 0]
-        return jnp.mean((values - scores) ** 2)
-
     @staticmethod
     def _get_buckets(n_buckets: int):
         return jnp.linspace(0, 1, n_buckets)
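
The removal of mean_squared_error from the base class lines up with the status-key rename above (status["mean_squared_errors"] to status["losses"]); note that the new breast-cancer benchmark still calls mc.mean_squared_error, so the metric evidently lives on elsewhere in the 16-file diff. If a standalone equivalent of the removed helper is needed, a sketch that mirrors it exactly:

import jax.numpy as jnp

def mean_squared_error(values, scores):
    # As in the removed static method: align a 2-D scores array with 1-D
    # values, then average the squared differences.
    if scores.ndim == 2 and values.ndim == 1:
        scores = scores[:, 0]
    return jnp.mean((values - scores) ** 2)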
(The remaining 12 changed files are not shown.)