Skip to content

Commit

Permalink
fix examples
Browse files Browse the repository at this point in the history
  • Loading branch information
gianlucadetommaso committed Sep 11, 2023
1 parent 8783f38 commit 0ceee31
Show file tree
Hide file tree
Showing 6 changed files with 39 additions and 18 deletions.
17 changes: 13 additions & 4 deletions benchmarks/multivalid/two_moons_multicalibrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,17 +114,26 @@
plt.tight_layout()
plt.show()

plt.title("Max calibration error decay during calibration")
plt.semilogy(status["max_calib_errors"])
plt.title("Mean-squared error decay during calibration")
plt.semilogy(status["mean_squared_errors"])
plt.show()

print(
"Per-group reweighted avg. squared calib. error before calibration: ",
"Per-group reweighted avg. squared calib. error before calibration on test data: ",
mc.calibration_error(targets=test_targets, groups=test_groups, probs=test_probs),
)
print(
"Per-group reweighted avg. squared calib. error after calibration: ",
"Per-group reweighted avg. squared calib. error after calibration on test data: ",
mc.calibration_error(
targets=test_targets, groups=test_groups, probs=calib_test_probs
),
)

print(
"Mean-squared error before calibration on test data: ",
mc.mean_squared_error(probs=test_probs, targets=test_targets),
)
print(
"Mean-squared error after calibration on test data: ",
mc.mean_squared_error(probs=calib_test_probs, targets=test_targets),
)
15 changes: 8 additions & 7 deletions examples/multivalid_coverage.pct.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,17 +26,18 @@
import numpy as np


def generate_data(n_data: int, sigma1=0.03, sigma2=0.5):
def generate_data(n_data: int, sigma1=0.03, sigma2=0.5, seed: int = 43):
rng = np.random.default_rng(seed=seed)
x = np.concatenate(
[
np.random.normal(loc=-1, scale=0.3, size=(n_data // 2, 1)),
np.random.normal(loc=1, scale=0.3, size=(n_data - n_data // 2, 1)),
rng.normal(loc=-1, scale=0.3, size=(n_data // 2, 1)),
rng.normal(loc=1, scale=0.3, size=(n_data - n_data // 2, 1)),
]
)
y = np.cos(x) + np.concatenate(
[
np.random.normal(scale=sigma1, size=(n_data // 2, 1)),
np.random.normal(scale=sigma2, size=(n_data - n_data // 2, 1)),
rng.normal(scale=sigma1, size=(n_data // 2, 1)),
rng.normal(scale=sigma2, size=(n_data - n_data // 2, 1)),
]
)
return x, y
Expand Down Expand Up @@ -227,7 +228,7 @@ def plot_intervals(xx, means, intervals, test_data, method):

batchmvp = BatchMVPConformalRegressor()
test_thresholds, status = batchmvp.calibrate(
scores=scores, groups=groups, test_groups=test_groups, n_buckets=300
scores=scores, groups=groups, test_groups=test_groups, n_buckets=100, eta=1
)
test_thresholds = min_score + (max_score - min_score) * test_thresholds

Expand All @@ -236,7 +237,7 @@ def plot_intervals(xx, means, intervals, test_data, method):

# %%
plt.figure(figsize=(6, 3))
plt.plot(status["max_calib_errors"], label="maximum calibration error decay")
plt.plot(status["mean_squared_errors"], label="mean squared error decay")
plt.xlabel("rounds")
plt.legend()
plt.show()
Expand Down
4 changes: 3 additions & 1 deletion fortuna/conformal/classification/binary_multicalibrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ def calibrate(
rtol: float = 1e-6,
n_buckets: int = 100,
n_rounds: int = 1000,
eta: float = 1.0,
eta: float = 0.1,
split: float = 0.8,
**kwargs,
) -> Union[Dict, Tuple[Array, Dict]]:
return super().calibrate(
Expand All @@ -38,6 +39,7 @@ def calibrate(
n_buckets=n_buckets,
n_rounds=n_rounds,
eta=eta,
split=split,
**kwargs,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ def calibrate(
rtol: float = 1e-6,
n_buckets: int = 100,
n_rounds: int = 1000,
eta: float = 1.0,
eta: float = 0.1,
split: float = 0.8,
**kwargs,
) -> Union[Dict, Tuple[Array, Dict]]:
return super().calibrate(
Expand All @@ -59,6 +60,7 @@ def calibrate(
n_buckets=n_buckets,
n_rounds=n_rounds,
eta=eta,
split=split,
**kwargs,
)

Expand Down
10 changes: 6 additions & 4 deletions fortuna/conformal/multivalid/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def calibrate(
rtol: float = 1e-6,
n_buckets: int = 100,
n_rounds: int = 1000,
eta: float = 1.0,
eta: float = 0.1,
split: float = 0.8,
**kwargs,
) -> Union[Dict, Tuple[Array, Dict]]:
Expand Down Expand Up @@ -106,12 +106,14 @@ def calibrate(
raise ValueError(
"If `groups` and `test_values` are provided, `test_groups` must also be provided."
)
if eta < 0 or eta > 1:
if eta <= 0 or eta > 1:
raise ValueError(
"`eta` must be a float between 0 and 1, extremes included."
"`eta` must be a float greater than 0 and less or equal than 1."
)
if split <= 0 or split > 1:
raise ValueError("`split` must be greater than 0 and less or equal than 1.")
raise ValueError(
"`split` must be a float greater than 0 and less or equal than 1."
)
self._check_scores(scores)
scores = self._process_scores(scores)
n_dims = scores.shape[1]
Expand Down
7 changes: 6 additions & 1 deletion fortuna/conformal/multivalid/batch_mvp.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ def calibrate(
rtol: float = 1e-6,
n_buckets: int = 100,
n_rounds: int = 1000,
eta: float = 1.0,
eta: float = 0.1,
split: float = 0.8,
coverage: float = 0.95,
) -> Union[Dict, Tuple[Array, Dict]]:
"""
Expand Down Expand Up @@ -77,6 +78,9 @@ def calibrate(
The maximum number of rounds to run the method for.
eta: float
Step size. By default, this is set to 1.
split: float
Split the calibration data into calibration and validation, according to the given proportion.
The validation data will be used for early stopping.
coverage: float
The desired level of coverage. This must be a scalar between 0 and 1.
Returns
Expand All @@ -99,6 +103,7 @@ def calibrate(
n_buckets=n_buckets,
n_rounds=n_rounds,
eta=eta,
split=split,
coverage=coverage,
)

Expand Down

0 comments on commit 0ceee31

Please sign in to comment.