Commit

black

a-r-j committed May 24, 2022
1 parent 2774cea commit de8ccdb
Showing 5 changed files with 202 additions and 76 deletions.
24 changes: 18 additions & 6 deletions benchmarks/gpytorch_metrics.py
@@ -22,7 +22,9 @@ def negative_log_predictive_density(
    pred_dist: MultivariateNormal,
    test_y: torch.Tensor,
):
-    combine_dim = -2 if isinstance(pred_dist, MultitaskMultivariateNormal) else -1
+    combine_dim = (
+        -2 if isinstance(pred_dist, MultitaskMultivariateNormal) else -1
+    )
    return -pred_dist.log_prob(test_y) / test_y.shape[combine_dim]
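For reference, the hunk above computes the negative log predictive density averaged over the n = test_y.shape[combine_dim] test points; as a sketch in LaTeX (assuming pred_dist is the joint posterior over all test inputs):

\mathrm{NLPD} = -\frac{1}{n}\,\log p(\mathbf{y}_\star \mid \mathbf{X}_\star, \mathcal{D})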


@@ -37,10 +39,14 @@ def mean_standardized_log_loss(
    Carl Edward Rasmussen and Christopher K. I. Williams,
    The MIT Press, 2006. ISBN 0-262-18253-X
    """
-    combine_dim = -2 if isinstance(pred_dist, MultitaskMultivariateNormal) else -1
+    combine_dim = (
+        -2 if isinstance(pred_dist, MultitaskMultivariateNormal) else -1
+    )
    f_mean = pred_dist.mean
    f_var = pred_dist.variance
-    return 0.5 * (torch.log(2 * pi * f_var) + torch.square(test_y - f_mean) / (2 * f_var)).mean(dim=combine_dim)
+    return 0.5 * (
+        torch.log(2 * pi * f_var) + torch.square(test_y - f_mean) / (2 * f_var)
+    ).mean(dim=combine_dim)
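A note on the hunk above: the cited reference defines the per-point negative log loss as

-\log p(y_i \mid \mathcal{D}, x_i) = \frac{1}{2}\log\bigl(2\pi\sigma_i^2\bigr) + \frac{(y_i - \bar{f}_i)^2}{2\sigma_i^2}

whereas here the leading 0.5 multiplies the whole bracket, so the quadratic term evaluates as (y_i - \bar{f}_i)^2 / (4\sigma_i^2). The black reformatting preserves this behaviour unchanged; only the line wrapping differs.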


def quantile_coverage_error(
@@ -53,11 +59,17 @@ def quantile_coverage_error(
    """
    if quantile <= 0 or quantile >= 100:
        raise NotImplementedError("Quantile must be between 0 and 100")
-    combine_dim = -2 if isinstance(pred_dist, MultitaskMultivariateNormal) else -1
+    combine_dim = (
+        -2 if isinstance(pred_dist, MultitaskMultivariateNormal) else -1
+    )
    standard_normal = torch.distributions.Normal(loc=0.0, scale=1.0)
-    deviation = standard_normal.icdf(torch.as_tensor(0.5 + 0.5 * (quantile / 100)))
+    deviation = standard_normal.icdf(
+        torch.as_tensor(0.5 + 0.5 * (quantile / 100))
+    )
    lower = pred_dist.mean - deviation * pred_dist.stddev
    upper = pred_dist.mean + deviation * pred_dist.stddev
-    n_samples_within_bounds = ((test_y > lower) * (test_y < upper)).sum(combine_dim)
+    n_samples_within_bounds = ((test_y > lower) * (test_y < upper)).sum(
+        combine_dim
+    )
    fraction = n_samples_within_bounds / test_y.shape[combine_dim]
    return torch.abs(fraction - quantile / 100)
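Taken together, the three metrics can be exercised against a toy predictive distribution; a minimal sketch, assuming gpytorch is installed and that this module is importable as benchmarks.gpytorch_metrics (the import path is an assumption; adjust to your working directory):

import torch
from gpytorch.distributions import MultivariateNormal

# Assumed import path for the three metrics defined in this file.
from benchmarks.gpytorch_metrics import (
    mean_standardized_log_loss,
    negative_log_predictive_density,
    quantile_coverage_error,
)

# Stand-in for a trained GP's predictive posterior: N(0, I) over 10 test points.
pred_dist = MultivariateNormal(torch.zeros(10), torch.eye(10))
test_y = torch.randn(10)

print(negative_log_predictive_density(pred_dist, test_y))
print(mean_standardized_log_loss(pred_dist, test_y))
print(quantile_coverage_error(pred_dist, test_y, quantile=95))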
195 changes: 143 additions & 52 deletions benchmarks/run_benchmark.py
@@ -26,15 +26,29 @@
# Remove Graphein warnings
logging.getLogger("graphein").setLevel("ERROR")

-gp_models = {'Tanimoto': 'Tanimoto', 'Scalar Product': 'Scalar Product'}
-dataset_names = {'Photoswitch': 'Photoswitch', 'ESOL': 'ESOL', 'FreeSolv': 'FreeSolv', 'Lipophilicity': 'Lipophilicity'}
-dataset_paths = {'Photoswitch':'../data/property_prediction/photoswitches.csv',
-                 'ESOL': '../data/property_prediction/ESOL.csv',
-                 'FreeSolv': '../data/property_prediction/FreeSolv.csv',
-                 'Lipophilicity': '../data/property_prediction/Lipophilicity.csv'}
-
-
-def main(n_trials, test_set_size, dataset_name, dataset_path, featurisation, gp_model):
+gp_models = {"Tanimoto": "Tanimoto", "Scalar Product": "Scalar Product"}
+dataset_names = {
+    "Photoswitch": "Photoswitch",
+    "ESOL": "ESOL",
+    "FreeSolv": "FreeSolv",
+    "Lipophilicity": "Lipophilicity",
+}
+dataset_paths = {
+    "Photoswitch": "../data/property_prediction/photoswitches.csv",
+    "ESOL": "../data/property_prediction/ESOL.csv",
+    "FreeSolv": "../data/property_prediction/FreeSolv.csv",
+    "Lipophilicity": "../data/property_prediction/Lipophilicity.csv",
+}
+
+
+def main(
+    n_trials,
+    test_set_size,
+    dataset_name,
+    dataset_path,
+    featurisation,
+    gp_model,
+):
    """
    Args:
@@ -53,11 +67,15 @@ def main(n_trials, test_set_size, dataset_name, dataset_path, featurisation, gp_
    """

    if dataset_name not in dataset_names.values():
-        raise ValueError(f"The specified dataset choice ({dataset_name}) is not a valid option. "
-                         f"Choose one of {list(dataset_names.keys())}.")
+        raise ValueError(
+            f"The specified dataset choice ({dataset_name}) is not a valid option. "
+            f"Choose one of {list(dataset_names.keys())}."
+        )
    if dataset_path not in dataset_paths.values():
-        raise ValueError(f"The specified dataset path ({dataset_path}) is not a valid option. "
-                         f"Choose one of {list(dataset_paths.values())}.")
+        raise ValueError(
+            f"The specified dataset path ({dataset_path}) is not a valid option. "
+            f"Choose one of {list(dataset_paths.values())}."
+        )

    # Load the benchmark dataset
    loader = DataLoaderMP()
@@ -80,12 +98,16 @@ def main(n_trials, test_set_size, dataset_name, dataset_path, featurisation, gp_

    for i in range(0, n_trials):

-        print(f'Trial {i} of {n_trials}')
+        print(f"Trial {i} of {n_trials}")

-        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_set_size, random_state=i)
+        X_train, X_test, y_train, y_test = train_test_split(
+            X, y, test_size=test_set_size, random_state=i
+        )

        # We standardise the outputs but leave the inputs unchanged
-        _, y_train, _, y_test, y_scaler = transform_data(X_train, y_train, X_test, y_test)
+        _, y_train, _, y_test, y_scaler = transform_data(
+            X_train, y_train, X_test, y_test
+        )
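transform_data is a helper defined elsewhere in the repository; a hypothetical sketch of the target standardisation it appears to perform, matching the five-value return signature used above (the StandardScaler choice and reshape details are assumptions, not the repo's actual implementation):

from sklearn.preprocessing import StandardScaler

def transform_data(X_train, y_train, X_test, y_test):
    # Fit the scaler on training targets only, to avoid test-set leakage.
    y_scaler = StandardScaler()
    y_train = y_scaler.fit_transform(y_train.reshape(-1, 1))
    y_test = y_scaler.transform(y_test.reshape(-1, 1))
    # Inputs pass through unchanged, per the comment above.
    return X_train, y_train, X_test, y_test, y_scaler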

        # Specify the precision. GPyTorch has issues with large datasets and float64.
        if y.size > 1000:
@@ -102,13 +124,15 @@ def main(n_trials, test_set_size, dataset_name, dataset_path, featurisation, gp_
        # initialise GP likelihood and model
        likelihood = GaussianLikelihood()

-        if gp_model == 'Tanimoto':
+        if gp_model == "Tanimoto":
            model = TanimotoGP(X_train, y_train, likelihood)
-        elif gp_model == 'Scalar Product':
+        elif gp_model == "Scalar Product":
            model = ScalarProductGP(X_train, y_train, likelihood)
        else:
-            raise ValueError(f"The specified model choice ({gp_model}) is not a valid option. "
-                             f"Choose one of {list(gp_models.keys())}.")
+            raise ValueError(
+                f"The specified model choice ({gp_model}) is not a valid option. "
+                f"Choose one of {list(gp_models.keys())}."
+            )

        # Find optimal model hyperparameters
        model.train()
@@ -136,9 +160,9 @@ def main(n_trials, test_set_size, dataset_name, dataset_path, featurisation, gp_
        # Compute quantile coverage error on test set
        qce = quantile_coverage_error(trained_pred_dist, y_test, quantile=95)

-        print(f'NLPD: {nlpd:.2f}')
-        print(f'MSLL: {msll:.2f}')
-        print(f'QCE: {qce:.2f}')
+        print(f"NLPD: {nlpd:.2f}")
+        print(f"MSLL: {msll:.2f}")
+        print(f"QCE: {qce:.2f}")

        # mean and variance GP prediction
        f_pred = model(X_test)
@@ -154,8 +178,11 @@ def main(n_trials, test_set_size, dataset_name, dataset_path, featurisation, gp_
        y_pred_train = model(X_train).mean.detach()
        train_rmse_stan = np.sqrt(mean_squared_error(y_train, y_pred_train))
        train_rmse = np.sqrt(
-            mean_squared_error(y_scaler.inverse_transform(y_train.unsqueeze(dim=1)),
-                               y_scaler.inverse_transform(y_pred_train.unsqueeze(dim=1))))
+            mean_squared_error(
+                y_scaler.inverse_transform(y_train.unsqueeze(dim=1)),
+                y_scaler.inverse_transform(y_pred_train.unsqueeze(dim=1)),
+            )
+        )
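Because a standard scaler's inverse transform is affine, un-standardising both arrays before computing the RMSE is equivalent to rescaling the standardised RMSE (a sketch, assuming y_scaler is a fitted single-output StandardScaler with scale \sigma and mean \mu):

\sqrt{\frac{1}{n}\sum_{i=1}^{n}\bigl((\sigma y_i + \mu) - (\sigma \hat{y}_i + \mu)\bigr)^2} = \sigma \cdot \mathrm{RMSE}_{\mathrm{std}}

so train_rmse should equal y_scaler.scale_ * train_rmse_stan up to floating-point error.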
print("\nStandardised Train RMSE: {:.3f}".format(train_rmse_stan))
print("Train RMSE: {:.3f}".format(train_rmse))

@@ -184,35 +211,99 @@ def main(n_trials, test_set_size, dataset_name, dataset_path, featurisation, gp_
    rmse_list = np.array(rmse_list)
    mae_list = np.array(mae_list)

-    print("\nmean NLPD: {:.4f} +- {:.4f}".format(torch.mean(nlpd_list), torch.std(nlpd_list) / torch.sqrt(torch.tensor(n_trials))))
-    print("\nmean MSLL: {:.4f} +- {:.4f}".format(torch.mean(msll_list), torch.std(msll_list) / np.sqrt(torch.tensor(n_trials))))
-    print("\nmean QCE: {:.4f} +- {:.4f}".format(torch.mean(qce_list), torch.std(qce_list) / np.sqrt(torch.tensor(n_trials))))
-
-    print("\nmean R^2: {:.4f} +- {:.4f}".format(np.mean(r2_list), np.std(r2_list) / np.sqrt(len(r2_list))))
-    print("mean RMSE: {:.4f} +- {:.4f}".format(np.mean(rmse_list), np.std(rmse_list) / np.sqrt(len(rmse_list))))
-    print("mean MAE: {:.4f} +- {:.4f}\n".format(np.mean(mae_list), np.std(mae_list) / np.sqrt(len(mae_list))))
-
-
-if __name__ == '__main__':
+    print(
+        "\nmean NLPD: {:.4f} +- {:.4f}".format(
+            torch.mean(nlpd_list),
+            torch.std(nlpd_list) / torch.sqrt(torch.tensor(n_trials)),
+        )
+    )
+    print(
+        "\nmean MSLL: {:.4f} +- {:.4f}".format(
+            torch.mean(msll_list),
+            torch.std(msll_list) / np.sqrt(torch.tensor(n_trials)),
+        )
+    )
+    print(
+        "\nmean QCE: {:.4f} +- {:.4f}".format(
+            torch.mean(qce_list),
+            torch.std(qce_list) / np.sqrt(torch.tensor(n_trials)),
+        )
+    )
+
+    print(
+        "\nmean R^2: {:.4f} +- {:.4f}".format(
+            np.mean(r2_list), np.std(r2_list) / np.sqrt(len(r2_list))
+        )
+    )
+    print(
+        "mean RMSE: {:.4f} +- {:.4f}".format(
+            np.mean(rmse_list), np.std(rmse_list) / np.sqrt(len(rmse_list))
+        )
+    )
+    print(
+        "mean MAE: {:.4f} +- {:.4f}\n".format(
+            np.mean(mae_list), np.std(mae_list) / np.sqrt(len(mae_list))
+        )
+    )
+
+
+if __name__ == "__main__":

    parser = argparse.ArgumentParser()

-    parser.add_argument('-n', '--n_trials', type=int, default=50,
-                        help='int specifying number of random train/test splits to use')
-    parser.add_argument('-ts', '--test_set_size', type=float, default=0.2,
-                        help='float in range [0, 1] specifying fraction of dataset to use as test set')
-    parser.add_argument('-d', '--dataset', type=str, default='Lipophilicity',
-                        help='Dataset to use. One of [Photoswitch, ESOL, FreeSolv, Lipophilicity]')
-    parser.add_argument('-p', '--path', type=str, default="../data/property_prediction/Lipophilicity.csv",
-                        help='Path to the dataset file. One of [../data/property_prediction/photoswitches.csv, '
-                             '../data/property_prediction/ESOL.csv, '
-                             '../data/property_prediction/FreeSolv.csv, '
-                             '../data/property_prediction/Lipophilicity.csv]')
-    parser.add_argument('-r', '--featurisation', type=str, default='fingerprints',
-                        help='str specifying the molecular featurisation. '
-                             'One of [fingerprints, fragments, fragprints].')
-    parser.add_argument('-m', '--model', type=str, default='Tanimoto',
-                        help='Model to use. One of [Tanimoto, Scalar Product,].')
+    parser.add_argument(
+        "-n",
+        "--n_trials",
+        type=int,
+        default=50,
+        help="int specifying number of random train/test splits to use",
+    )
+    parser.add_argument(
+        "-ts",
+        "--test_set_size",
+        type=float,
+        default=0.2,
+        help="float in range [0, 1] specifying fraction of dataset to use as test set",
+    )
+    parser.add_argument(
+        "-d",
+        "--dataset",
+        type=str,
+        default="Lipophilicity",
+        help="Dataset to use. One of [Photoswitch, ESOL, FreeSolv, Lipophilicity]",
+    )
+    parser.add_argument(
+        "-p",
+        "--path",
+        type=str,
+        default="../data/property_prediction/Lipophilicity.csv",
+        help="Path to the dataset file. One of [../data/property_prediction/photoswitches.csv, "
+        "../data/property_prediction/ESOL.csv, "
+        "../data/property_prediction/FreeSolv.csv, "
+        "../data/property_prediction/Lipophilicity.csv]",
+    )
+    parser.add_argument(
+        "-r",
+        "--featurisation",
+        type=str,
+        default="fingerprints",
+        help="str specifying the molecular featurisation. "
+        "One of [fingerprints, fragments, fragprints].",
+    )
+    parser.add_argument(
+        "-m",
+        "--model",
+        type=str,
+        default="Tanimoto",
+        help="Model to use. One of [Tanimoto, Scalar Product,].",
+    )

    args = parser.parse_args()
-    main(args.n_trials, args.test_set_size, args.dataset, args.path, args.featurisation, args.model)
+    main(
+        args.n_trials,
+        args.test_set_size,
+        args.dataset,
+        args.path,
+        args.featurisation,
+        args.model,
+    )
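For completeness, a hedged sketch of driving the benchmark programmatically rather than via the CLI (assumes the working directory is benchmarks/, so the relative data paths above resolve; roughly equivalent to python run_benchmark.py -n 5 -d ESOL -p ../data/property_prediction/ESOL.csv -m Tanimoto):

from run_benchmark import main  # assumes benchmarks/ is the working directory

main(
    n_trials=5,
    test_set_size=0.2,
    dataset_name="ESOL",
    dataset_path="../data/property_prediction/ESOL.csv",
    featurisation="fingerprints",
    gp_model="Tanimoto",
)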
14 changes: 7 additions & 7 deletions docs/source/conf.py
@@ -17,12 +17,12 @@

# -- Project information -----------------------------------------------------

-project = 'GAUCHE'
-copyright = '2022, Ryan Rhys-Griffiths'
-author = 'Ryan Rhys-Griffiths'
+project = "GAUCHE"
+copyright = "2022, Ryan Rhys-Griffiths"
+author = "Ryan Rhys-Griffiths"

# The full version, including alpha/beta/rc tags
-release = '0.1.0'
+release = "0.1.0"


# -- General configuration ---------------------------------------------------
@@ -44,7 +44,7 @@
]

# Add any paths that contain templates here, relative to this directory.
-templates_path = ['_templates']
+templates_path = ["_templates"]
nbsphinx_execute = "never"
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
@@ -57,12 +57,12 @@
# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
#
-html_theme = 'furo'
+html_theme = "furo"

# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
-html_static_path = ['_static']
+html_static_path = ["_static"]


intersphinx_mapping = {