Added API for post-training OOD classifiers
- Added abstract base class for ood classifiers
- Added Mahalanobis OOD classifiers
- Added DDU
Alberto Gasparin committed Aug 17, 2023
1 parent 36a0737 commit e3c053d
Showing 9 changed files with 706 additions and 216 deletions.
10 changes: 10 additions & 0 deletions docs/source/references/ood_classifier.rst
@@ -0,0 +1,10 @@
.. _ood_detection:

Out-Of-Distribution (OOD) detection
===================================
Starting from a trained neural classifier, it's possible to fit one of the models below
to help distinguish between in-distribution and out-of-distribution inputs.

.. autoclass:: fortuna.ood_detection.mahalanobis.MalahanobisOODClassifier

.. autoclass:: fortuna.ood_detection.ddu.DeepDeterministicUncertaintyOODClassifier
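
A minimal usage sketch of the shared ``fit``/``score`` interface (the embeddings and targets
are assumed to be precomputed from a trained classifier, as in the example notebook):

.. code-block:: python

    from fortuna.ood_detection.mahalanobis import MalahanobisOODClassifier

    ood_classifier = MalahanobisOODClassifier(num_classes=2)
    ood_classifier.fit(embeddings=train_embeddings, targets=train_labels)
    # higher scores indicate inputs farther from the training distribution
    scores = ood_classifier.score(embeddings=test_embeddings)
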
1 change: 1 addition & 0 deletions docs/source/references/references.rst
@@ -12,6 +12,7 @@ API References
output_calibrator
prob_output_layer
conformal
ood_detection
data_loader
metric
utils
2 changes: 1 addition & 1 deletion examples/index.rst
@@ -16,8 +16,8 @@ In this section we show some examples of how to use Fortuna in classification an
multivalid_coverage
sinusoidal_regression
two_moons_classification
two_moons_classification_ood
subnet_calibration
two_moons_classification_sngp
scaling_up_bayesian_inference
mnist_classification_sghmc
sgmcmc_diagnostics
337 changes: 337 additions & 0 deletions examples/two_moons_classification_ood.pct.py
@@ -0,0 +1,337 @@
# ---
# jupyter:
# jupytext:
# text_representation:
# extension: .py
# format_name: percent
# format_version: '1.3'
# jupytext_version: 1.14.5
# kernelspec:
# display_name: python3
# language: python
# name: python3
# ---

# %% [markdown]

# # Two-moons Classification: Improved uncertainty quantification

# %% [markdown]
# In this notebook we will see how to fix model overconfidence on inputs that are far away from the training data.
# We will do that using two different approaches; let's dive right into it!


# %% [markdown]
# ### Setup
# #### Download the Two-Moons data from scikit-learn
# Let us first download the two-moons data from [scikit-learn](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_moons.html).

# %%
TRAIN_DATA_SIZE = 500

from sklearn.datasets import make_moons

train_data = make_moons(n_samples=TRAIN_DATA_SIZE, noise=0.1, random_state=0)
val_data = make_moons(n_samples=500, noise=0.1, random_state=1)
test_data = make_moons(n_samples=500, noise=0.1, random_state=2)

# %% [markdown]
# #### Convert data to a compatible data loader
# Fortuna helps you convert your data into a data loader that it can digest.

# %%
from fortuna.data import DataLoader

train_data_loader = DataLoader.from_array_data(
    train_data, batch_size=256, shuffle=True, prefetch=True
)
val_data_loader = DataLoader.from_array_data(val_data, batch_size=256, prefetch=True)
test_data_loader = DataLoader.from_array_data(test_data, batch_size=256, prefetch=True)
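
# %% [markdown]
# As a quick sanity check, a loader can be iterated directly; each batch is an
# `(inputs, targets)` pair of arrays. The shapes below assume the batch size of 256 set above.

# %%
batch_inputs, batch_targets = next(iter(train_data_loader))
print(batch_inputs.shape, batch_targets.shape)  # e.g. (256, 2) (256,)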

# %% [markdown]
# #### Define some utils for plotting the estimated uncertainty

# %%
import matplotlib.pyplot as plt
import numpy as np
from fortuna.data import InputsLoader
from fortuna.prob_model import ProbClassifier
import jax.numpy as jnp


def get_grid_inputs_loader(grid_size: int = 100):
    xx = np.linspace(-4, 4, grid_size)
    yy = np.linspace(-4, 4, grid_size)
    grid = np.array([[_xx, _yy] for _xx in xx for _yy in yy])
    grid_inputs_loader = InputsLoader.from_array_inputs(grid)
    grid = grid.reshape(grid_size, grid_size, 2)
    return grid, grid_inputs_loader


def compute_test_modes(prob_model: ProbClassifier, test_data_loader: DataLoader):
    test_inputs_loader = test_data_loader.to_inputs_loader()
    test_means = prob_model.predictive.mean(inputs_loader=test_inputs_loader)
    return prob_model.predictive.mode(
        inputs_loader=test_inputs_loader, means=test_means
    )

def plot_uncertainty_over_grid(
    grid: jnp.ndarray,
    scores: jnp.ndarray,
    test_modes: jnp.ndarray,
    title: str = "Predictive uncertainty",
):
    scores = scores.reshape(grid.shape[0], grid.shape[1])

    _, ax = plt.subplots(figsize=(7, 5.5))
    plt.title(title, fontsize=12)
    pcm = ax.imshow(
        scores.T,
        origin="lower",
        extent=(-4.0, 4.0, -4.0, 4.0),
        interpolation="bicubic",
        aspect="auto",
    )

    # Plot the test data, colored by the predicted class.
    plt.scatter(
        test_data[0][:, 0],
        test_data[0][:, 1],
        s=1,
        c=["C0" if i == 1 else "C1" for i in test_modes],
    )
    plt.colorbar()


# %% [markdown]
# ### Define the deterministic model
# In this tutorial we will use a deep residual network; see `fortuna.model.mlp.DeepResidualNet` for
# more details on the model.

# %%
from fortuna.model.mlp import DeepResidualNet
import flax.linen as nn

output_dim = 2
model = DeepResidualNet(
    output_dim=output_dim,
    activations=(nn.relu, nn.relu, nn.relu, nn.relu, nn.relu, nn.relu),
    widths=(128, 128, 128, 128, 128, 128),
    dropout_rate=0.1,
)

# %%
from fortuna.prob_model import MAPPosteriorApproximator
from fortuna.prob_model import FitConfig, FitMonitor, FitOptimizer
from fortuna.metric.classification import accuracy


prob_model = ProbClassifier(
    model=model,
    posterior_approximator=MAPPosteriorApproximator(),
    output_calibrator=None,
)
status = prob_model.train(
    train_data_loader=train_data_loader,
    val_data_loader=val_data_loader,
    calib_data_loader=val_data_loader,
    fit_config=FitConfig(
        monitor=FitMonitor(metrics=(accuracy,)),
        optimizer=FitOptimizer(n_epochs=100),
    ),
)

# %%
test_modes = compute_test_modes(prob_model, test_data_loader)
grid, grid_inputs_loader = get_grid_inputs_loader(grid_size=100)
grid_entropies = prob_model.predictive.entropy(grid_inputs_loader)
plot_uncertainty_over_grid(grid=grid, scores=grid_entropies, test_modes=test_modes)
plt.show()

# %% [markdown]
# Clearly, the model is overconfident on inputs that are far away from the training data.
# This behaviour is not what one would expect: we would rather have the model be less confident on out-of-distribution inputs.

# %% [markdown]
# ### Fit an OOD classifier to distinguish between in-distribution and out-of-distribution inputs
# Given the trained model from above, we can now use one of the models provided by Fortuna to actually improve
# the model's confidence on the out-of-distribution inputs.
# In the example below we will use the Mahalanobis-distance-based classifier introduced in
# [Lee, Kimin, et al.](https://proceedings.neurips.cc/paper/2018/file/abdeb6f575ac5c6676b747bca8d09cc2-Paper.pdf).
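
# %% [markdown]
# As a rough sketch of the idea (see the paper for the exact formulation): the classifier fits a
# Gaussian per class over the embeddings $z$, with class means $\mu_c$ and a shared covariance
# $\Sigma$, and scores an input by its squared Mahalanobis distance to the closest class,
#
# $$\text{score}(z) = \min_c \, (z - \mu_c)^\top \Sigma^{-1} (z - \mu_c),$$
#
# so larger scores indicate inputs that are far from every class, i.e. likely out-of-distribution.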

# %%
from fortuna.ood_detection.mahalanobis import MalahanobisOODClassifier
from fortuna.model.mlp import DeepResidualFeatureExtractorSubNet
import jax


feature_extractor_subnet = DeepResidualFeatureExtractorSubNet(
    dense=model.dense,
    widths=model.widths,
    activations=model.activations,
    dropout=model.dropout,
    dropout_rate=model.dropout_rate,
)


@jax.jit
def _apply(inputs, params, mutable):
    variables = {"params": params["model"]["params"]["dfe_subnet"].unfreeze()}
    if mutable is not None:
        mutable_variables = {
            k: v["dfe_subnet"].unfreeze() for k, v in mutable["model"].items()
        }
        variables.update(mutable_variables)
    return feature_extractor_subnet.apply(variables, inputs, train=False, mutable=False)


ood_classifier = MalahanobisOODClassifier(num_classes=2)

# %% [markdown]
# In the code block above we initialize our classifier (`MalahanobisOODClassifier`) and define a
# `feature_extractor_subnet`, a sub-network of our previously trained model
# that maps an input vector into an embedding vector. In this example, it is our original model
# (`DeepResidualNet`) without the output layer.
# We are now ready to fit the classifier using our training data and verify whether the model's overconfidence has been
# (at least partially) fixed:

# %%
from typing import Tuple

import tqdm

from fortuna.data.loader.base import BaseDataLoaderABC, BaseInputsLoader
from fortuna.prob_model.posterior.state import PosteriorState
from fortuna.typing import Array

# define some util functions


def get_embeddings_and_targets(
    state: PosteriorState, train_data_loader: BaseDataLoaderABC
) -> Tuple[Array, Array]:
    train_labels = []
    train_embeddings = []
    for x, y in tqdm.tqdm(
        train_data_loader, desc="Computing embeddings for Mahalanobis classifier: "
    ):
        train_embeddings.append(
            _apply(inputs=x, params=state.params, mutable=state.mutable)
        )
        train_labels.append(y)
    train_embeddings = jnp.concatenate(train_embeddings, 0)
    train_labels = jnp.concatenate(train_labels)
    return train_embeddings, train_labels


def get_embeddings(state: PosteriorState, inputs_loader: BaseInputsLoader):
    return jnp.concatenate(
        [
            _apply(inputs=x, params=state.params, mutable=state.mutable)
            for x in inputs_loader
        ],
        0,
    )

# %%
state = prob_model.posterior.state.get()
train_embeddings, train_labels = get_embeddings_and_targets(
    state=state, train_data_loader=train_data_loader
)
ood_classifier.fit(embeddings=train_embeddings, targets=train_labels)
grid, grid_inputs_loader = get_grid_inputs_loader(grid_size=100)
grid_embeddings = get_embeddings(state=state, inputs_loader=grid_inputs_loader)
grid_scores = ood_classifier.score(embeddings=grid_embeddings)
# For the sake of plotting, we cap the OOD scores at twice the maximum score
# obtained from a known in-distribution source (the validation inputs).
ind_embeddings = get_embeddings(
    state=state, inputs_loader=val_data_loader.to_inputs_loader()
)
ind_scores = ood_classifier.score(embeddings=ind_embeddings)
threshold = ind_scores.max() * 2
grid_scores = jnp.where(grid_scores < threshold, grid_scores, threshold)
plot_uncertainty_over_grid(
    grid=grid, scores=grid_scores, test_modes=test_modes, title="OOD scores"
)
plt.show()
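
# %% [markdown]
# The cap above is purely for visualization. In practice, a decision threshold is often chosen
# as a high quantile of scores from held-out in-distribution data; a minimal sketch (the 99th
# percentile here is an assumed choice, not part of Fortuna's API):

# %%
threshold_99 = jnp.quantile(ind_scores, 0.99)  # assumed quantile level, for illustration
is_ood = ood_classifier.score(embeddings=grid_embeddings) > threshold_99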


# %% [markdown]
# We will now see a different way of obtaining improved uncertainty estimation
# (for out-of-distribution inputs): [SNGP](https://arxiv.org/abs/2006.10108).
# Unlike before, we now have to retrain the model, as the architecture will slightly change.
# The reason for this will be clear from the model definition below.

# %% [markdown]
# ### Define the SNGP model
# Compared to the deterministic model obtained in the first part of this notebook, SNGP has two crucial differences:
#
# 1. [Spectral Normalization](https://arxiv.org/abs/1802.05957) is applied to all Dense (or Convolutional) layers.
# 2. The Dense output layer is replaced with a Gaussian Process layer.
#
# Let's see how to do it in Fortuna:

# %% [markdown]
# In order to add Spectral Normalization to a deterministic network we just need to define a new deep feature extractor,
# inheriting from both the feature extractor used by the deterministic model (in this case `DeepResidualFeatureExtractorSubNet`)
# and `WithSpectralNorm`. It is worth highlighting that `WithSpectralNorm` should be taken as is, while the deep feature extractor
# can be replaced with any custom object:

# %%
from fortuna.model.mlp import DeepResidualFeatureExtractorSubNet
from fortuna.model.utils.spectral_norm import WithSpectralNorm


class SNGPDeepFeatureExtractorSubNet(
    WithSpectralNorm, DeepResidualFeatureExtractorSubNet
):
    pass


# %% [markdown]
# Then, we can define our SNGP model by:
#
# - Replacing the deep feature extractor: from `DeepResidualFeatureExtractorSubNet` to `SNGPDeepFeatureExtractorSubNet`
# - Using the `SNGPPosteriorApproximator` as the `posterior_approximator` for the `ProbModel`.
#
# Nothing else is needed; Fortuna will take care of the rest for you!

# %%
import jax.numpy as jnp

from fortuna.prob_model.prior import IsotropicGaussianPrior
from fortuna.prob_model import SNGPPosteriorApproximator

output_dim = 2
model = SNGPDeepFeatureExtractorSubNet(
    activations=tuple([nn.relu] * 6),
    widths=tuple([128] * 6),
    dropout_rate=0.1,
    spectral_norm_bound=0.9,
)

prob_model = ProbClassifier(
    model=model,
    prior=IsotropicGaussianPrior(
        log_var=jnp.log(1.0 / 1e-4) - jnp.log(TRAIN_DATA_SIZE)
    ),
    posterior_approximator=SNGPPosteriorApproximator(output_dim=output_dim),
    output_calibrator=None,
)
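
# %% [markdown]
# A note on the prior: the log-variance above corresponds to a prior variance
# $\sigma^2 = 1 / (10^{-4} \cdot N)$, with $N$ the training set size. One common reading is that
# this mirrors the correspondence between an isotropic Gaussian prior and an L2 weight-decay
# penalty of $10^{-4}$ on a loss averaged over $N$ data points.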

# %% [markdown]
# Notice that the only required argument when initializing `SNGPPosteriorApproximator` is
# `output_dim`, which should be set to the number of classes in the classification task.
# `SNGPPosteriorApproximator` has more optional parameters that you can play with; to gain a better understanding of those, you can
# check out the documentation and/or the [original paper](https://arxiv.org/abs/2006.10108).

# %% [markdown]
# We are now ready to train the model as we usually do:

# %%
status = prob_model.train(
    train_data_loader=train_data_loader,
    val_data_loader=val_data_loader,
    calib_data_loader=val_data_loader,
    fit_config=FitConfig(
        monitor=FitMonitor(metrics=(accuracy,)),
        optimizer=FitOptimizer(n_epochs=100),
    ),
)

# %%
test_modes = compute_test_modes(prob_model, test_data_loader)
grid, grid_inputs_loader = get_grid_inputs_loader(grid_size=100)
grid_entropies = prob_model.predictive.entropy(grid_inputs_loader)
plot_uncertainty_over_grid(grid=grid, scores=grid_entropies, test_modes=test_modes)
plt.show()

# %% [markdown]

# We can clearly see that the SNGP model provides much better uncertainty estimates compared to the deterministic one.