-
Notifications
You must be signed in to change notification settings - Fork 24
/
Copy pathbenchmark_models.py
36 lines (28 loc) · 1.29 KB
/
benchmark_models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
"""
GP Model definitions to be used in the benchmarks
"""
from gpytorch.distributions import MultivariateNormal
from gpytorch.kernels import LinearKernel, ScaleKernel
from gpytorch.means import ConstantMean
from gpytorch.models import ExactGP
from gauche.kernels.fingerprint_kernels.tanimoto_kernel import TanimotoKernel
class TanimotoGP(ExactGP):
    """Exact GP with a constant mean and a scaled Tanimoto covariance.

    The Tanimoto kernel is used so the model can work directly on molecular
    fingerprint representations.
    """

    def __init__(self, train_x, train_y, likelihood):
        """Build the model and register the training data with ``ExactGP``.

        Args:
            train_x: Training inputs. Presumably a tensor of molecular
                fingerprints (one row per molecule) -- TODO confirm against
                the benchmark callers.
            train_y: Training targets, passed through to ``ExactGP``.
            likelihood: GPyTorch likelihood, passed through to ``ExactGP``.
        """
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = ConstantMean()
        # We use the Tanimoto kernel to work with molecular fingerprint
        # representations; ScaleKernel adds a learnable output scale on top.
        self.covar_module = ScaleKernel(TanimotoKernel())

    def forward(self, x) -> MultivariateNormal:
        """Return the GP prior at ``x``.

        The prior's mean comes from ``mean_module`` and its covariance from
        ``covar_module``.
        """
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return MultivariateNormal(mean_x, covar_x)
class ScalarProductGP(ExactGP):
    """Exact GP with a constant mean and a scaled linear (scalar-product) covariance.

    The scalar product kernel is applied to molecular fingerprint
    representations.
    """

    def __init__(self, train_x, train_y, likelihood):
        """Build the model and register the training data with ``ExactGP``.

        Args:
            train_x: Training inputs. Presumably a tensor of molecular
                fingerprints (one row per molecule) -- TODO confirm against
                the benchmark callers.
            train_y: Training targets, passed through to ``ExactGP``.
            likelihood: GPyTorch likelihood, passed through to ``ExactGP``.
        """
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = ConstantMean()
        # We use the scalar product kernel on molecular fingerprint
        # representations.
        # NOTE(review): LinearKernel appears to carry its own variance
        # parameter, so the outer ScaleKernel's outputscale may be
        # non-identifiable with it. This is left unchanged because parameter
        # names and values may be relied on by the benchmarks -- confirm
        # before simplifying.
        self.covar_module = ScaleKernel(LinearKernel())

    def forward(self, x) -> MultivariateNormal:
        """Return the GP prior at ``x``.

        The prior's mean comes from ``mean_module`` and its covariance from
        ``covar_module``.
        """
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return MultivariateNormal(mean_x, covar_x)