-
Notifications
You must be signed in to change notification settings - Fork 24
/
Copy pathgpytorch_metrics.py
75 lines (64 loc) · 2.11 KB
/
gpytorch_metrics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
"""
Local copy of the GPyTorch metrics defined here:
https://github.com/cornellius-gp/gpytorch/blob/master/gpytorch/metrics/metrics.py
These metrics are not yet included in the latest GPyTorch release.
TODO: Remove this module once a GPyTorch release that ships them is available.
"""
from math import pi
from typing import Optional

import torch
from gpytorch.distributions import (
    MultitaskMultivariateNormal,
    MultivariateNormal,
)
# Re-bind `pi` (imported as a float from `math`) to a 0-dim tensor of the
# default torch dtype; the metrics below use this tensor version.
pi = torch.as_tensor(pi)
def negative_log_predictive_density(
    pred_dist: MultivariateNormal,
    test_y: torch.Tensor,
):
    """
    Negative log predictive density, normalised by the number of test points.

    Evaluates ``-pred_dist.log_prob(test_y)`` and divides it by the size of
    the data dimension of ``test_y``: dim -2 for multitask distributions,
    dim -1 otherwise.

    Args:
        pred_dist: Predictive distribution at the test inputs.
        test_y: Observed test targets.

    Returns:
        The normalised negative log predictive density.
    """
    is_multitask = isinstance(pred_dist, MultitaskMultivariateNormal)
    log_density = pred_dist.log_prob(test_y)
    num_points = test_y.shape[-2] if is_multitask else test_y.shape[-1]
    return -log_density / num_points
def mean_standardized_log_loss(
    pred_dist: MultivariateNormal,
    test_y: torch.Tensor,
    train_y: Optional[torch.Tensor] = None,
):
    """
    Mean Standardized Log Loss.

    Computes, for every test target, the Gaussian negative log predictive
    density

        0.5 * log(2 * pi * var) + (y - mean) ** 2 / (2 * var)

    from the marginal mean and variance of ``pred_dist``. It then averages
    this over the data dimension: dim -2 for multitask distributions,
    dim -1 otherwise.

    If ``train_y`` is given, the same loss of a trivial model is subtracted.
    The trivial model predicts every test point with the mean and variance
    of ``train_y``; subtracting its loss is the standardization described
    in the reference below. Negative values then mean the model does better
    than the trivial predictor. Without ``train_y`` no standardization is
    applied, and the mean log loss of the model alone is returned.

    Reference: Page No. 23,
    Gaussian Processes for Machine Learning,
    Carl Edward Rasmussen and Christopher K. I. Williams,
    The MIT Press, 2006. ISBN 0-262-18253-X

    Args:
        pred_dist: Predictive distribution at the test inputs.
        test_y: Observed test targets; must broadcast against
            ``pred_dist.mean``.
        train_y: Optional training targets for the trivial baseline model.
            Their mean and (unbiased) variance are taken over the data
            dimension, separately per batch/task, so at least two training
            points are needed for a finite result.

    Returns:
        The loss tensor with the data dimension reduced.
    """
    combine_dim = (
        -2 if isinstance(pred_dist, MultitaskMultivariateNormal) else -1
    )

    def _log_loss(mean, var):
        # -log N(y | mean, var) = 0.5 * log(2*pi*var) + (y - mean)^2 / (2*var).
        # Only the log term carries the 1/2 factor; the squared error is
        # already divided by 2 * var.
        return 0.5 * torch.log(2 * pi * var) + torch.square(test_y - mean) / (
            2 * var
        )

    loss = _log_loss(pred_dist.mean, pred_dist.variance).mean(dim=combine_dim)
    if train_y is None:
        return loss

    # Trivial model: one Gaussian per batch/task fitted to the training
    # targets. keepdim=True lets it broadcast against test_y's data dimension.
    data_mean = train_y.mean(dim=combine_dim, keepdim=True)
    data_var = train_y.var(dim=combine_dim, keepdim=True)
    trivial_loss = _log_loss(data_mean, data_var).mean(dim=combine_dim)
    return loss - trivial_loss
def quantile_coverage_error(
    pred_dist: MultivariateNormal,
    test_y: torch.Tensor,
    quantile: float = 95.0,
):
    """
    Quantile coverage error.

    Forms the central interval ``mean +/- z * stddev`` of each predictive
    marginal. Here ``z`` is the standard-normal quantile at
    ``0.5 + 0.5 * quantile / 100``. The function counts the test targets
    lying strictly inside that interval and returns the absolute gap
    between the observed fraction and ``quantile / 100``.

    Args:
        pred_dist: Predictive distribution at the test inputs.
        test_y: Observed test targets.
        quantile: Nominal interval coverage, in percent.

    Returns:
        Absolute coverage error, reduced over the data dimension
        (dim -2 for multitask distributions, dim -1 otherwise).

    Raises:
        NotImplementedError: If ``quantile`` is not strictly between 0 and 100.
    """
    if quantile >= 100 or quantile <= 0:
        raise NotImplementedError("Quantile must be between 0 and 100")
    is_multitask = isinstance(pred_dist, MultitaskMultivariateNormal)
    point_dim = -2 if is_multitask else -1
    target_coverage = quantile / 100

    # Standard-normal z-score bounding the central `target_coverage` mass.
    z = torch.distributions.Normal(loc=0.0, scale=1.0).icdf(
        torch.as_tensor(0.5 + 0.5 * target_coverage)
    )
    center = pred_dist.mean
    half_width = z * pred_dist.stddev
    inside = (test_y > center - half_width) & (test_y < center + half_width)

    observed_coverage = inside.sum(point_dim) / test_y.shape[point_dim]
    return torch.abs(observed_coverage - target_coverage)