From e3c053ded54d077dc9b77d20bd7b57eb55d97912 Mon Sep 17 00:00:00 2001 From: Alberto Gasparin Date: Thu, 17 Aug 2023 15:00:03 +0200 Subject: [PATCH] Added API ood post-training ood classifiers - Added abstract base class for ood classifiers - Added Mahalanobis OOD classifiers - Added DDU --- docs/source/references/ood_classifier.rst | 10 + docs/source/references/references.rst | 1 + examples/index.rst | 2 +- examples/two_moons_classification_ood.pct.py | 337 ++++++++++++++++++ examples/two_moons_classification_sngp.pct.py | 215 ----------- fortuna/ood_detection/__init__.py | 0 fortuna/ood_detection/base.py | 32 ++ fortuna/ood_detection/ddu.py | 164 +++++++++ fortuna/ood_detection/mahalanobis.py | 161 +++++++++ 9 files changed, 706 insertions(+), 216 deletions(-) create mode 100644 docs/source/references/ood_classifier.rst create mode 100644 examples/two_moons_classification_ood.pct.py delete mode 100644 examples/two_moons_classification_sngp.pct.py create mode 100644 fortuna/ood_detection/__init__.py create mode 100644 fortuna/ood_detection/base.py create mode 100644 fortuna/ood_detection/ddu.py create mode 100644 fortuna/ood_detection/mahalanobis.py diff --git a/docs/source/references/ood_classifier.rst b/docs/source/references/ood_classifier.rst new file mode 100644 index 00000000..317ef7eb --- /dev/null +++ b/docs/source/references/ood_classifier.rst @@ -0,0 +1,10 @@ +.. _ood_detection: + +Out-Of-Distribution (OOD) detection +================== +Starting from a trained a neural classifier, it's possible to fit one of the models below +to help distinguish between in-distribution and out of distribution inputs. + +.. autoclass:: fortuna.ood_detection.mahalanobis.MalahanobisOODClassifier + +.. autoclass:: fortuna.ood_detection.ddu.DeepDeterministicUncertaintyOODClassifier diff --git a/docs/source/references/references.rst b/docs/source/references/references.rst index 2fa8f13e..cd89d17c 100644 --- a/docs/source/references/references.rst +++ b/docs/source/references/references.rst @@ -12,6 +12,7 @@ API References output_calibrator prob_output_layer conformal + ood_detection data_loader metric utils diff --git a/examples/index.rst b/examples/index.rst index cf3ae3c9..1f7539ca 100644 --- a/examples/index.rst +++ b/examples/index.rst @@ -16,8 +16,8 @@ In this section we show some examples of how to use Fortuna in classification an multivalid_coverage sinusoidal_regression two_moons_classification + two_moons_classification_ood subnet_calibration - two_moons_classification_sngp scaling_up_bayesian_inference mnist_classification_sghmc sgmcmc_diagnostics diff --git a/examples/two_moons_classification_ood.pct.py b/examples/two_moons_classification_ood.pct.py new file mode 100644 index 00000000..6b4ce9aa --- /dev/null +++ b/examples/two_moons_classification_ood.pct.py @@ -0,0 +1,337 @@ +# --- +# jupyter: +# jupytext: +# text_representation: +# extension: .py +# format_name: percent +# format_version: '1.3' +# jupytext_version: 1.14.5 +# kernelspec: +# display_name: python3 +# language: python +# name: python3 +# --- + +# %% [markdown] + +# # Two-moons Classification: Improved uncertainty quantification + +# %% [markdown] +# In this notebook we will see how to fix model overconfidence over inputs that are far-away from the training data. +# We will do that using two different approaches; let's dive right into it! + + +# %% [markdown] +# ### Setup +# #### Download the Two-Moons data from scikit-learn +# Let us first download the two-moons data from [scikit-learn](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_moons.html). + +# %% +TRAIN_DATA_SIZE = 500 + +from sklearn.datasets import make_moons + +train_data = make_moons(n_samples=TRAIN_DATA_SIZE, noise=0.1, random_state=0) +val_data = make_moons(n_samples=500, noise=0.1, random_state=1) +test_data = make_moons(n_samples=500, noise=0.1, random_state=2) + +# %% [markdown] +# #### Convert data to a compatible data loader +# Fortuna helps you convert data and data loaders into a data loader that Fortuna can digest. + +# %% +from fortuna.data import DataLoader + +train_data_loader = DataLoader.from_array_data( + train_data, batch_size=256, shuffle=True, prefetch=True +) +val_data_loader = DataLoader.from_array_data(val_data, batch_size=256, prefetch=True) +test_data_loader = DataLoader.from_array_data(test_data, batch_size=256, prefetch=True) + +# %% [markdown] +# #### Define some utils for plotting the estimated uncertainty + +# %% +import matplotlib.pyplot as plt +import numpy as np +from fortuna.data import InputsLoader +from fortuna.prob_model import ProbClassifier +import jax.numpy as jnp + + +def get_grid_inputs_loader(grid_size: int = 100): + xx = np.linspace(-4, 4, grid_size) + yy = np.linspace(-4, 4, grid_size) + grid = np.array([[_xx, _yy] for _xx in xx for _yy in yy]) + grid_inputs_loader = InputsLoader.from_array_inputs(grid) + grid = grid.reshape(grid_size, grid_size, 2) + return grid, grid_inputs_loader + + +def compute_test_modes( + prob_model: ProbClassifier, test_data_loader: DataLoader +): + test_inputs_loader = test_data_loader.to_inputs_loader() + test_means = prob_model.predictive.mean(inputs_loader=test_inputs_loader) + return prob_model.predictive.mode( + inputs_loader=test_inputs_loader, means=test_means + ) + +def plot_uncertainty_over_grid( + grid: jnp.ndarray, scores: jnp.ndarray, test_modes: jnp.ndarray, title: str = "Predictive uncertainty" +): + scores = scores.reshape(grid.shape[0], grid.shape[1]) + + _, ax = plt.subplots(figsize=(7, 5.5)) + plt.title(title, fontsize=12) + pcm = ax.imshow( + scores.T, + origin="lower", + extent=(-4., 4., -4., 4.), + interpolation='bicubic', + aspect='auto') + + # Plot training data. + plt.scatter( + test_data[0][:, 0], + test_data[0][:, 1], + s=1, + c=["C0" if i == 1 else "C1" for i in test_modes], + ) + plt.colorbar() + + +# %% [markdown] +# ### Define the deterministic model +# In this tutorial we will use a deep residual network, see `fortuna.model.mlp.DeepResidualNet` for +# more details on the model. + +# %% +from fortuna.model.mlp import DeepResidualNet +import flax.linen as nn + +output_dim = 2 +model = DeepResidualNet( + output_dim=output_dim, + activations=(nn.relu, nn.relu, nn.relu, nn.relu, nn.relu, nn.relu), + widths=(128, 128, 128, 128, 128, 128), + dropout_rate=0.1, +) + +# %% +from fortuna.prob_model import MAPPosteriorApproximator +from fortuna.prob_model import FitConfig, FitMonitor, FitOptimizer +from fortuna.metric.classification import accuracy + + +prob_model = ProbClassifier( + model=model, + posterior_approximator=MAPPosteriorApproximator(), + output_calibrator=None, +) +status = prob_model.train( + train_data_loader=train_data_loader, + val_data_loader=val_data_loader, + calib_data_loader=val_data_loader, + fit_config=FitConfig( + monitor=FitMonitor(metrics=(accuracy,)), + optimizer=FitOptimizer(n_epochs=100), + ), +) + +# %% +test_modes = compute_test_modes(prob_model, test_data_loader) +grid, grid_inputs_loader = get_grid_inputs_loader(grid_size=100) +grid_entropies = prob_model.predictive.entropy(grid_inputs_loader) +plot_uncertainty_over_grid(grid=grid, scores=grid_entropies, test_modes=test_modes) +plt.show() + +# %% [markdown] +# Clearly, the model is overconfident on inputs that are far away from the training data. +# This behaviour is not what one would expect, as we rather the model being less confident on out-of-distributin inputs. + +# %% [markdown] +# ### Fit an OOD classifier to distinguish between in-distribution and out-of-distribution inputs +# Given the trained model from above, we can now use one of the models provided by Fortuna to actually improve +# the model's confidence on the out-of-distribution inputs. +# In the example below we will use the Malahanobis-based classifier introduced in +# [Lee, Kimin, et al](https://proceedings.neurips.cc/paper/2018/file/abdeb6f575ac5c6676b747bca8d09cc2-Paper.pdf) + +# %% +from fortuna.ood_detection.mahalanobis import MalahanobisOODClassifier +from fortuna.model.mlp import DeepResidualFeatureExtractorSubNet +import jax + + +feature_extractor_subnet=DeepResidualFeatureExtractorSubNet( + dense=model.dense, + widths=model.widths, + activations=model.activations, + dropout=model.dropout, + dropout_rate=model.dropout_rate, + ) + +@jax.jit +def _apply(inputs, params, mutable): + variables = {'params': params["model"]['params']['dfe_subnet'].unfreeze()} + if mutable is not None: + mutable_variables = {k: v['dfe_subnet'].unfreeze() for k, v in mutable["model"].items()} + variables.update(mutable_variables) + return feature_extractor_subnet.apply(variables, inputs, train=False, mutable=False) + +ood_classifier = MalahanobisOODClassifier(num_classes=2) + +# %% [markdown] +# In the code block above we initialize our classifier (`MalahanobisOODClassifier`) and we also define a +# `feature_extractor_subnet`, which is a sub-network of our previously trained model +# that allow one to transform an input vector into an embedding vector. In the example, this is our original model +# (`DeepResidualNet`) without the output layer. +# We are now ready to fit the classifier using our training data and verify whether the model's overconfidence has been +# (at least partially) fixed: + +# %% +from typing import Tuple + +import tqdm + +from fortuna.data.loader.base import BaseDataLoaderABC, BaseInputsLoader +from fortuna.prob_model.posterior.state import PosteriorState +from fortuna.typing import Array + +# define some util functions + + +def get_embeddings_and_targets(state: PosteriorState, train_data_loader: BaseDataLoaderABC) -> Tuple[Array, Array]: + train_labels = [] + train_embeddings = [] + for x, y in tqdm.tqdm( + train_data_loader, desc="Computing embeddings for Malhanbis Classifier: " + ): + train_embeddings.append( + _apply(inputs=x, params=state.params, mutable=state.mutable) + ) + train_labels.append(y) + train_embeddings = jnp.concatenate(train_embeddings, 0) + train_labels = jnp.concatenate(train_labels) + return train_embeddings, train_labels + + +def get_embeddings(state: PosteriorState, inputs_loader: BaseInputsLoader): + return jnp.concatenate( + [ + _apply(inputs=x, params=state.params, mutable=state.mutable) + for x in inputs_loader + ], + 0, + ) + +# %% +state = prob_model.posterior.state.get() +train_embeddings, train_labels = get_embeddings_and_targets(state=state, train_data_loader=train_data_loader) +ood_classifier.fit(embeddings=train_embeddings, targets=train_labels) +grid, grid_inputs_loader = get_grid_inputs_loader(grid_size=100) +grid_embeddings = get_embeddings(state=state, inputs_loader=grid_inputs_loader) +grid_scores = ood_classifier.score(embeddings=grid_embeddings) +# for the sake of plotting we set a threshold on the OOD classifier scores using the max score +# obtained from a known in-distribution source +ind_embeddings = get_embeddings(state=state, inputs_loader=val_data_loader.to_inputs_loader()) +ind_scores = ood_classifier.score(embeddings=ind_embeddings) +threshold = ind_scores.max()*2 +grid_scores = jnp.where(grid_scores < threshold, grid_scores, threshold) +plot_uncertainty_over_grid(grid=grid, scores=grid_scores, test_modes=test_modes, title="OOD scores") +plt.show() + + +# %% [markdown] +# We will now see a different way of obtaining improved uncertainty estimation +# (for out-of-distribution inputs): [SNGP](https://arxiv.org/abs/2006.10108). +# Unlike before, we now have to retrain the model as the architecture will slighly change. +# The reason for this will be clear from the model definition below. + +# %% [markdown] +# ### Define the SNGP model +# Compared to the deterministic model obtained in the first part of this notebook, SNGP has two crucial differences: +# +# 1. [Spectral Normalization](https://arxiv.org/abs/1802.05957) is applied to all Dense (or Convolutional) layers. +# 2. The Dense output layer is replaced with a Gaussian Process layer. +# +# Let's see how to do it in Fortuna: + +# %% [markdown] +# In order to add Spectral Normalization to a deterministic network we just need to define a new deep feature extractor, +# inheriting from both the feature extractor used by the deterministic model (in this case `MLPDeepFeatureExtractorSubNet`) +# and `WithSpectralNorm`. It is worth highlighting that `WithSpectralNorm` should be taken as is, while the deep feature extractor +# can be replaced with any custom object: + +# %% +from fortuna.model.mlp import DeepResidualFeatureExtractorSubNet +from fortuna.model.utils.spectral_norm import WithSpectralNorm + + +class SNGPDeepFeatureExtractorSubNet( + WithSpectralNorm, DeepResidualFeatureExtractorSubNet +): + pass + + +# %% [markdown] +# Then, we can define our SNGP model by: +# +# - Replacing the deep feature extractor: from `MLPDeepFeatureExtractorSubNet` to `SNGPDeepFeatureExtractorSubNet` +# - Using the `SNGPPosteriorApproximator` as the `posterior_approximator` for the `ProbModel`. +# +# Nothing else is needed, Fortuna will take care of the rest for you! + +# %% +import jax.numpy as jnp + +from fortuna.prob_model.prior import IsotropicGaussianPrior +from fortuna.prob_model import SNGPPosteriorApproximator + +output_dim = 2 +model = SNGPDeepFeatureExtractorSubNet( + activations=tuple([nn.relu] * 6), + widths=tuple([128] * 6), + dropout_rate=0.1, + spectral_norm_bound=0.9, +) + +prob_model = ProbClassifier( + model=model, + prior=IsotropicGaussianPrior( + log_var=jnp.log(1.0 / 1e-4) - jnp.log(TRAIN_DATA_SIZE) + ), + posterior_approximator=SNGPPosteriorApproximator(output_dim=output_dim), + output_calibrator=None, +) + +# %% [markdown] +# Notice that the only required argument when initializing `SNGPPosteriorApproximator` is +# `output_dim`, which should be set to the number of classes in the classification task. +# `SNGPPosteriorApproximator` has more optional parameters that you can play with, to gain a better understanding of those you can +# check out the documentation and/or the [original paper](https://arxiv.org/abs/2006.10108). + +# %% [markdown] +# We are now ready to train the model as we usually do: + +# %% +status = prob_model.train( + train_data_loader=train_data_loader, + val_data_loader=val_data_loader, + calib_data_loader=val_data_loader, + fit_config=FitConfig( + monitor=FitMonitor(metrics=(accuracy,)), + optimizer=FitOptimizer(n_epochs=100), + ), +) + +# %% +test_modes = compute_test_modes(prob_model, test_data_loader) +grid, grid_inputs_loader = get_grid_inputs_loader(grid_size=100) +grid_entropies = prob_model.predictive.entropy(grid_inputs_loader) +plot_uncertainty_over_grid(grid=grid, scores=grid_entropies, test_modes=test_modes) +plt.show() + +# %% [markdown] + +# We can clearly see that the SNGP model provides much better uncertainty estimates compared to the deterministic one. diff --git a/examples/two_moons_classification_sngp.pct.py b/examples/two_moons_classification_sngp.pct.py deleted file mode 100644 index 181bf8a7..00000000 --- a/examples/two_moons_classification_sngp.pct.py +++ /dev/null @@ -1,215 +0,0 @@ -# --- -# jupyter: -# jupytext: -# text_representation: -# extension: .py -# format_name: percent -# format_version: '1.3' -# jupytext_version: 1.14.5 -# kernelspec: -# display_name: python3 -# language: python -# name: python3 -# --- - -# %% [markdown] - -# # Two-moons Classification: Improved uncertainty quantification with SNGP - -# %% [markdown] -# In this notebook we show how to train an [SNGP](https://arxiv.org/abs/2006.10108) model using Fortuna, showing improved -# uncertainty estimation on the two moons dataset with respect to it's deterministic counterpart. - - -# %% [markdown] -# ### Download the Two-Moons data from scikit-learn -# Let us first download the two-moons data from [scikit-learn](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_moons.html). - -# %% - -TRAIN_DATA_SIZE = 500 - -from sklearn.datasets import make_moons - -train_data = make_moons(n_samples=TRAIN_DATA_SIZE, noise=0.1, random_state=0) -val_data = make_moons(n_samples=500, noise=0.1, random_state=1) -test_data = make_moons(n_samples=500, noise=0.1, random_state=2) - -# %% [markdown] -# ### Convert data to a compatible data loader -# Fortuna helps you convert data and data loaders into a data loader that Fortuna can digest. - -# %% -from fortuna.data import DataLoader - -train_data_loader = DataLoader.from_array_data( - train_data, batch_size=256, shuffle=True, prefetch=True -) -val_data_loader = DataLoader.from_array_data(val_data, batch_size=256, prefetch=True) -test_data_loader = DataLoader.from_array_data(test_data, batch_size=256, prefetch=True) - -# %% [markdown] -# ### Define some utils for plotting the estimated uncertainty - -# %% -import matplotlib.pyplot as plt -import numpy as np -from fortuna.data import InputsLoader -from fortuna.prob_model import ProbClassifier - - -def plot_uncertainty( - prob_model: ProbClassifier, test_data_loader: DataLoader, grid_size: int = 100 -): - test_inputs_loader = test_data_loader.to_inputs_loader() - test_means = prob_model.predictive.mean(inputs_loader=test_inputs_loader) - test_modes = prob_model.predictive.mode( - inputs_loader=test_inputs_loader, means=test_means - ) - - fig = plt.figure(figsize=(6, 3)) - xx = np.linspace(-5, 5, grid_size) - yy = np.linspace(-5, 5, grid_size) - grid = np.array([[_xx, _yy] for _xx in xx for _yy in yy]) - grid_loader = InputsLoader.from_array_inputs(grid) - grid_entropies = prob_model.predictive.entropy(grid_loader).reshape( - grid_size, grid_size - ) - grid = grid.reshape(grid_size, grid_size, 2) - plt.title("Predictive uncertainty", fontsize=12) - im = plt.pcolor(grid[:, :, 0], grid[:, :, 1], grid_entropies) - plt.scatter( - test_data[0][:, 0], - test_data[0][:, 1], - s=1, - c=["C0" if i == 1 else "C1" for i in test_modes], - ) - plt.colorbar() - - -# %% [markdown] -# ### Define the deterministic model -# In this tutorial we will use a deep residual network, see `fortuna.model.mlp.DeepResidualNet` for -# more details on the model. - -# %% -from fortuna.model.mlp import DeepResidualNet -import flax.linen as nn - -output_dim = 2 -model = DeepResidualNet( - output_dim=output_dim, - activations=(nn.relu, nn.relu, nn.relu, nn.relu, nn.relu, nn.relu), - widths=(128, 128, 128, 128, 128, 128), - dropout_rate=0.1, -) - -# %% -from fortuna.prob_model import MAPPosteriorApproximator -from fortuna.prob_model import FitConfig, FitMonitor, FitOptimizer -from fortuna.metric.classification import accuracy - - -prob_model = ProbClassifier( - model=model, - posterior_approximator=MAPPosteriorApproximator(), - output_calibrator=None, -) -status = prob_model.train( - train_data_loader=train_data_loader, - val_data_loader=val_data_loader, - calib_data_loader=val_data_loader, - fit_config=FitConfig( - monitor=FitMonitor(metrics=(accuracy,)), - optimizer=FitOptimizer(n_epochs=100), - ), -) - -# %% -plot_uncertainty(prob_model, test_data_loader, grid_size=100) -plt.show() - -# %% [markdown] -# ### Define the SNGP model -# Compared to the deterministic model obtained above, SNGP has two crucial differences: -# -# 1. [Spectral Normalization](https://arxiv.org/abs/1802.05957) is applied to all Dense (or Convolutional) layers. -# 2. The Dense output layer is replaced with a Gaussian Process layer. -# -# Let's see how to do it in Fortuna: - -# %% [markdown] -# In order to add Spectral Normalization to a deterministic network we just need to define a new deep feature extractor, -# inheriting from both the feature extractor used by the deterministic model (in this case `MLPDeepFeatureExtractorSubNet`) -# and `WithSpectralNorm`. It is worth highlighting that `WithSpectralNorm` should be taken as is, while the deep feature extractor -# can be replaced with any custom object: - -# %% -from fortuna.model.mlp import DeepResidualFeatureExtractorSubNet -from fortuna.model.utils.spectral_norm import WithSpectralNorm - - -class SNGPDeepFeatureExtractorSubNet( - WithSpectralNorm, DeepResidualFeatureExtractorSubNet -): - pass - - -# %% [markdown] -# Then, we can define our SNGP model by: -# -# - Replacing the deep feature extractor: from `MLPDeepFeatureExtractorSubNet` to `SNGPDeepFeatureExtractorSubNet` -# - Using the `SNGPPosteriorApproximator` as the `posterior_approximator` for the `ProbModel`. -# -# Nothing else is needed, Fortuna will take care of the rest for you! - -# %% -import jax.numpy as jnp - -from fortuna.prob_model.prior import IsotropicGaussianPrior -from fortuna.prob_model import SNGPPosteriorApproximator - -output_dim = 2 -model = SNGPDeepFeatureExtractorSubNet( - activations=tuple([nn.relu] * 6), - widths=tuple([128] * 6), - dropout_rate=0.1, - spectral_norm_bound=0.9, -) - -prob_model = ProbClassifier( - model=model, - prior=IsotropicGaussianPrior( - log_var=jnp.log(1.0 / 1e-4) - jnp.log(TRAIN_DATA_SIZE) - ), - posterior_approximator=SNGPPosteriorApproximator(output_dim=output_dim), - output_calibrator=None, -) - -# %% [markdown] -# Notice that the only required argument when initializing `SNGPPosteriorApproximator` is -# `output_dim`, which should be set to the number of classes in the classification task. -# `SNGPPosteriorApproximator` has more optional parameters that you can play with, to gain a better understanding of those you can -# check out the documentation and/or the [original paper](https://arxiv.org/abs/2006.10108). - -# %% [markdown] -# We are now ready to train the model as we usually do: - -# %% -status = prob_model.train( - train_data_loader=train_data_loader, - val_data_loader=val_data_loader, - calib_data_loader=val_data_loader, - fit_config=FitConfig( - monitor=FitMonitor(metrics=(accuracy,)), - optimizer=FitOptimizer(n_epochs=100), - ), -) - -# %% -plot_uncertainty(prob_model, test_data_loader, grid_size=100) -plt.show() - -# %% [markdown] - -# We can clearly see that the SNGP model provides much better uncertainty estimates compared to the deterministic one. diff --git a/fortuna/ood_detection/__init__.py b/fortuna/ood_detection/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/fortuna/ood_detection/base.py b/fortuna/ood_detection/base.py new file mode 100644 index 00000000..78444d7a --- /dev/null +++ b/fortuna/ood_detection/base.py @@ -0,0 +1,32 @@ +import abc + +from fortuna.typing import Array + + +class NotFittedError(ValueError, AttributeError): + """Exception class to raise if estimator is used before fitting.""" + + +class OutOfDistributionClassifierABC: + """ + Post-training classifier that uses the training sample embeddings coming from the model + to score a (new) test sample w.r.t. its chance of belonging to the original training distribution + (i.e, it is in-distribution) or not (i.e., it is out of distribution). + """ + + def __init__(self, num_classes: int): + """ + Parameters + ---------- + num_classes: int + The number of classes for the in-distribution classification task. + """ + self.num_classes = num_classes + + @abc.abstractmethod + def fit(self, embeddings: Array, targets: Array) -> None: + pass + + @abc.abstractmethod + def score(self, embeddings: Array) -> Array: + pass diff --git a/fortuna/ood_detection/ddu.py b/fortuna/ood_detection/ddu.py new file mode 100644 index 00000000..4a7c9fa4 --- /dev/null +++ b/fortuna/ood_detection/ddu.py @@ -0,0 +1,164 @@ +""" +Adapted from https://github.com/omegafragger/DDU/blob/main/utils/gmm_utils.py +""" +import logging +from typing import ( + Callable, + Tuple, +) + +import jax +import jax.scipy as jsp +import jax.scipy.stats as jsp_stats +import numpy as np +from jax import numpy as jnp + +from fortuna.ood_detection.base import ( + NotFittedError, + OutOfDistributionClassifierABC, +) +from fortuna.typing import Array + +DOUBLE_INFO = np.finfo(np.double) +JITTERS = [0, DOUBLE_INFO.tiny] + [10**exp for exp in range(-10, 0, 1)] + + +def _centered_cov(x: Array) -> Array: + n = x.shape[0] + res = jnp.matmul(1 / (n - 1) * x.T, x) + return res + + +def compute_classwise_mean_and_cov( + embeddings: Array, labels: Array, num_classes: int +) -> Tuple[Array, Array]: + """ + Computes class-specific means and a covariance matrices given the training set embeddings + (e.g., the last-layer representation of the model for each training example). + + Parameters + ---------- + embeddings: Array + The embeddings of shape `(n, d)` where `n` is the number of training samples and `d` is the embbeding's size. + labels: Array + An array of length `n` containing, for each input sample, its ground-truth label. + num_classes: int + The total number of classes available in the classification task. + + Returns + ---------- + Tuple[Array, Array]: + A tuple containing: + 1) an `Array` containing the per-class mean vector of the fitted GMM. + The shape of the array is `(num_classes, d)`. + 2) an `Array` containing the per-class covariance matrix of the fitted GMM. + The shape of the array is `(num_classes, d, d)`. + """ + # + classwise_mean_features = np.stack( + [jnp.mean(embeddings[labels == c], 0) for c in range(num_classes)] + ) + # + classwise_cov_features = np.stack( + [ + _centered_cov(embeddings[labels == c] - classwise_mean_features[c]) + for c in range(num_classes) + ] + ) + return classwise_mean_features, classwise_cov_features + + +def _get_logpdf_fn( + classwise_mean_features: Array, classwise_cov_features: Array +) -> Callable[[Array], Array]: + """ + Returns a function to evaluate the log-likelihood of a test sample according to the (fitted) GMM. + + Parameters + ---------- + classwise_mean_features: Array + The per-class mean vector of the fitted GMM. The shape of the array is `(num_classes, d)`. + classwise_cov_features: Array + The per-class covariance matrix of the fitted GMM. The shape of the array is `(num_classes, d, d)`. + + Returns + ------- + Callable[[Array], Array] + A function to evaluate the log-likelihood of a test sample according to the (fitted) GMM. + """ + for jitter_eps in JITTERS: + jitter = np.expand_dims(jitter_eps * np.eye(classwise_cov_features.shape[1]), 0) + gmm_logprob_fn_vmapped = jax.vmap( + jsp_stats.multivariate_normal.logpdf, in_axes=(None, 0, 0) + ) + gmm_logprob_fn = lambda x: gmm_logprob_fn_vmapped( + x, classwise_mean_features, (classwise_cov_features + jitter) + ).T + + nans = np.isnan(gmm_logprob_fn(classwise_mean_features)).sum() + if nans > 0: + logging.info(f"Nans, jittering {jitter_eps}") + continue + break + + return gmm_logprob_fn + + +class DeepDeterministicUncertaintyOODClassifier(OutOfDistributionClassifierABC): + """ + A Gaussian Mixture Model :math:`q(\mathbf{x}, z)` with a single Gaussian mixture component per class :math:`k \in {1,...,K}` + is fit after training. + Each class component is fit computing the empirical mean :math:`\mathbf{\hat{\mu}_k}` and covariance matrix + :math:`\mathbf{\hat{\Sigma}_k}` of the feature vectors :math:`f(\mathbf{x})`. + + The confidence score :math:`M(\mathbf{x})` for a new test sample is obtained computing the negative marginal likelihood + of the feature representation. + + See `Mukhoti, Jishnu, et al. `_ + """ + + def __init__(self, *args, **kwargs): + super(DeepDeterministicUncertaintyOODClassifier, self).__init__(*args, **kwargs) + self._gmm_logpdf_fn = None + + def fit(self, embeddings: Array, targets: Array) -> None: + """ + Fits a Multivariate Gaussian to the training data using class-specific means and covariance matrix. + + Parameters + ---------- + embeddings: Array + The embeddings of shape `(n, d)` where `n` is the number of training samples and `d` is the embbeding's size. + targets: Array + An array of length `n` containing, for each input sample, its ground-truth label. + """ + ( + classwise_mean_features, + classwise_cov_features, + ) = compute_classwise_mean_and_cov(embeddings, targets, self.num_classes) + self._gmm_logpdf_fn = _get_logpdf_fn( + classwise_mean_features, classwise_cov_features + ) + + def score(self, embeddings: Array) -> Array: + """ + The confidence score :math:`M(\mathbf{x})` for a new test sample :math:`\mathbf{x}` is obtained computing + the negative marginal likelihood of the feature representation + :math:`-q(f(\mathbf{x})) = - \sum\limits_{k}q(f(\mathbf{x})|y) q(y)`. + + A high score signals that the test sample :math:`\mathbf{x}` is identified as OOD. + + Parameters + ---------- + embeddings: Array + The embeddings of shape `(n, d)` where `n` is the number of test samples and `d` is the embbeding's size. + + Returns + ------- + Array + An array of scores with length `n`. + """ + if self._gmm_logpdf_fn is None: + raise NotFittedError("You have to call fit before calling score.") + loglik = self._gmm_logpdf_fn(embeddings) + return -jsp.special.logsumexp(jnp.nan_to_num(loglik, 0.0), axis=1) diff --git a/fortuna/ood_detection/mahalanobis.py b/fortuna/ood_detection/mahalanobis.py new file mode 100644 index 00000000..9f7c4e04 --- /dev/null +++ b/fortuna/ood_detection/mahalanobis.py @@ -0,0 +1,161 @@ +import logging +from typing import Tuple + +import jax +import jax.numpy as jnp + +from fortuna.ood_detection.base import ( + NotFittedError, + OutOfDistributionClassifierABC, +) +from fortuna.typing import Array + + +@jax.jit +def compute_mean_and_joint_cov( + embeddings: jnp.ndarray, labels: jnp.ndarray, class_ids: jnp.ndarray +) -> Tuple[jnp.ndarray, jnp.ndarray]: + """ + Computes class-specific means and a shared covariance matrix given the training set embeddings + (e.g., the last-layer representation of the model for each training example). + + Parameters + ---------- + embeddings: jnp.ndarray + An array of shape `(n, d)`, where `n` is the sample size of training set, + `d` is the dimension of the embeddings. + labels: jnp.ndarray + An array of shape `(n,)` + class_ids: jnp.ndarray + An array of the unique class ids in `labels`. + + Returns + ------- + Tuple[jnp.ndarray, jnp.ndarray] + A tuple containing: + 1) A `jnp.ndarray` of len n_class, and the i-th element is an np.array of size + ` (d,)` corresponding to the mean of the fitted Gaussian distribution for the i-th class; + 2) The shared covariance matrix of shape `(d, d)`. + """ + n_dim = embeddings.shape[1] + cov = jnp.zeros((n_dim, n_dim)) + + def f(cov, class_id): + mask = jnp.expand_dims(labels == class_id, axis=-1) + data = embeddings * mask + mean = jnp.sum(data, axis=0) / jnp.sum(mask) + diff = (data - mean) * mask + cov += jnp.matmul(diff.T, diff) + return cov, mean + + cov, means = jax.lax.scan(f, cov, class_ids) + cov = cov / len(labels) + return means, cov + + +@jax.jit +def compute_mahalanobis_distance( + embeddings: jnp.ndarray, means: jnp.ndarray, cov: jnp.ndarray +) -> jnp.ndarray: + """ + Computes Mahalanobis distance between the input and the fitted Guassians. + + Parameters + ---------- + embeddings: jnp.ndarray + A matrix of shape `(n, d)`, where `n` is the sample size of the test set, and `d` is the size of the embeddings. + means: jnp.ndarray + A matrix of shape `(c, d)`, where `c` is the number of classes in the classification task. + The ith row of the matrix corresponds to the mean of the fitted Gaussian distribution for the i-th class. + cov: jnp.ndarray + The shared covariance mmatrix of the shape `(d, d)`. + + Returns + ------- + A matrix of size `(n, c)` where the `(i, j)` element + corresponds to the Mahalanobis distance between i-th sample to the j-th + class Gaussian. + """ + # NOTE: It's possible for `cov` to be singular, in part because it is + # estimated on a sample of data. This can be exacerbated by lower precision, + # where, for example, the matrix could be non-singular in float64, but + # singular in float32. For our purposes in computing Mahalanobis distance, + # using a pseudoinverse is a reasonable approach that will be equivalent to + # the inverse if `cov` is non-singular. + cov_inv = jnp.linalg.pinv(cov) + + def maha_dist(x, mean): + # NOTE: This computes the squared Mahalanobis distance. + diff = x - mean + return jnp.einsum("i,ij,j->", diff, cov_inv, diff) + + maha_dist_all_classes_fn = jax.vmap(maha_dist, in_axes=(None, 0)) + out = jax.lax.map(lambda x: maha_dist_all_classes_fn(x, means), embeddings) + return out + + +class MalahanobisOODClassifier(OutOfDistributionClassifierABC): + """ + The pre-trained features of a softmax neural classifier :math:`f(\mathbf{x})` are assumed to follow a + class-conditional gaussian distribution with a tied covariance matrix :math:`\mathbf{\Sigma}`: + + .. math:: + \mathbb{P}(f(\mathbf{x})|y=k) = \mathcal{N}(f(\mathbf{x})|\mu_k, \mathbf{\Sigma}) + + for all :math:`k \in {1,...,K}`, where :math:`K` is the number of classes. + + The confidence score :math:`M(\mathbf{x})` for a new test sample :math:`\mathbf{x}` is obtained computing + the max (squared) Mahalanobis distance between :math:`f(\mathbf{x})` and the fitted class-wise guassians. + + See `Lee, Kimin, et al. `_ + """ + + def __init__(self, *args, **kwargs): + super(MalahanobisOODClassifier, self).__init__(*args, **kwargs) + self._maha_dist_all_classes_fn = None + + def fit(self, embeddings: Array, targets: Array) -> None: + """ + Fits a Multivariate Gaussian to the training data using class-specific means and a shared covariance matrix. + + Parameters + ---------- + embeddings: Array + The embeddings of shape `(n, d)` where `n` is the number of training samples and `d` is the embbeding's size. + targets: Array + An array of length `n` containing, for each input sample, its ground-truth label. + """ + n_labels_observed = len(jnp.unique(targets)) + if n_labels_observed != self.num_classes: + logging.warning( + f"{self.num_classes} labels were expected but found {n_labels_observed} in the provided train set. " + f"Will proceed but performance may be hurt by this." + ) + + means, cov = compute_mean_and_joint_cov( + embeddings, targets, jnp.arange(self.num_classes) + ) + self._maha_dist_all_classes_fn = lambda x: compute_mahalanobis_distance( + x, means, cov + ) + + def score(self, embeddings: Array) -> Array: + """ + The confidence score :math:`M(\mathbf{x})` for a new test sample :math:`\mathbf{x}` is obtained computing + the max (squared) Mahalanobis distance between :math:`f(\mathbf{x})` and the fitted class-wise Guassians. + + A high score signals that the test sample :math:`\mathbf{x}` is identified as OOD. + + Parameters + ---------- + embeddings: Array + The embeddings of shape `(n, d)` where `n` is the number of test samples and `d` is the embbeding's size. + + Returns + ------- + Array + An array of scores with length `n`. + """ + if self._maha_dist_all_classes_fn is None: + raise NotFittedError("You have to call fit before calling score.") + return self._maha_dist_all_classes_fn(embeddings).min(axis=1)