Skip to content

Commit

Permalink
Refactor tests
Browse files Browse the repository at this point in the history
  • Loading branch information
dirmeier committed Mar 8, 2023
1 parent afd9cca commit 4ca6ad2
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 56 deletions.
41 changes: 1 addition & 40 deletions tests/conftest.py → ramsey/_src/conftest.py
Original file line number Diff line number Diff line change
@@ -1,47 +1,8 @@
import haiku as hk
import pytest
from jax import numpy as np
from jax import random
from numpyro import distributions as dist

from ramsey import ANP, DANP, NP
from ramsey.attention import MultiHeadAttention
from ramsey.covariance_functions import exponentiated_quadratic
from ramsey.models import ANP, DANP, NP


# pylint: disable=too-many-locals,invalid-name,redefined-outer-name
@pytest.fixture()
def simple_data_set():
key = random.PRNGKey(0)
batch_size = 10
n, p = 50, 1
n_context = 20

key, sample_key = random.split(key, 2)
x = random.normal(key, shape=(n * p,)).reshape((n, p))
ys = []
for _ in range(batch_size):
key, sample_key1, sample_key2, sample_key3 = random.split(key, 4)
rho = dist.InverseGamma(5, 5).sample(sample_key1)
sigma = dist.InverseGamma(5, 5).sample(sample_key2)
K = exponentiated_quadratic(x, x, sigma, rho)
y = random.multivariate_normal(
sample_key3, mean=np.zeros(n), cov=K + np.diag(np.ones(n)) * 0.05
).reshape((1, n, 1))
ys.append(y)

x_target = np.tile(x, [batch_size, 1, 1])
y_target = np.vstack(np.array(ys))

key, sample_key = random.split(key, 2)
idxs_context = random.choice(
sample_key, np.arange(n), shape=(n_context,), replace=False
)

x_context = x_target[:, idxs_context, :]
y_context = y_target[:, idxs_context, :]

return x_context, y_context, x_target, y_target


def __lnp(**kwargs):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import haiku as hk
from jax import random

from ramsey.data import sample_from_gaussian_process
from ramsey.train import train_neural_process


# pylint: disable=too-many-locals,invalid-name,redefined-outer-name
def test_neural_process_training(simple_data_set, module):
def test_neural_process_training(module):
key = random.PRNGKey(1)
_, _, x_target, y_target = simple_data_set
(x_target, y_target), _ = sample_from_gaussian_process(key)

f = hk.transform(module)
params = f.init(
Expand All @@ -19,7 +20,7 @@ def test_neural_process_training(simple_data_set, module):
f,
params,
train_key,
n_iter=100,
n_iter=10,
x=x_target,
y=y_target,
n_context=10,
Expand Down
36 changes: 24 additions & 12 deletions tests/test_networks.py → ramsey/_src/test_networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@
import pytest
from jax import random

from ramsey.models import NP
from ramsey import NP
from ramsey.data import sample_from_gaussian_process


# pylint: disable=too-many-locals,invalid-name,redefined-outer-name
def test_module_dimensionality(simple_data_set):
def test_module_dimensionality():
key = random.PRNGKey(1)
x_context, y_context, x_target, _ = simple_data_set
(x_target, y_target), _ = sample_from_gaussian_process(key)

def module(**kwargs):
np = NP(
Expand All @@ -26,7 +27,10 @@ def module(**kwargs):

f = hk.transform(module)
params = f.init(
key, x_context=x_context, y_context=y_context, x_target=x_target
key,
x_context=x_target[:, :10, :],
y_context=y_target[:, :10, :],
x_target=x_target,
)

chex.assert_shape(params["latent_encoder1/~/linear_0"]["w"], (2, 3))
Expand All @@ -39,25 +43,28 @@ def module(**kwargs):
chex.assert_shape(params["decoder/~/linear_1"]["w"], (3, 2))


def test_modules(simple_data_set, module):
def test_modules(module):
key = random.PRNGKey(1)
x_context, y_context, x_target, y_target = simple_data_set
(x_target, y_target), _ = sample_from_gaussian_process(key)

f = hk.transform(module)
params = f.init(
key, x_context=x_context, y_context=y_context, x_target=x_target
key,
x_context=x_target[:, :10, :],
y_context=y_target[:, :10, :],
x_target=x_target,
)
y_star = f.apply(
rng=key,
params=params,
x_context=x_context,
y_context=y_context,
x_context=x_target[:, :10, :],
y_context=y_target[:, :10, :],
x_target=x_target,
)
chex.assert_equal_shape([y_target, y_star.mean])


def test_modules_false_decoder(simple_data_set):
def test_modules_false_decoder():
def f(**kwargs):
np = NP(
decoder=hk.nets.MLP([3, 3], name="decoder"),
Expand All @@ -72,8 +79,13 @@ def f(**kwargs):
return np(**kwargs)

key = random.PRNGKey(1)
x_context, y_context, x_target, _ = simple_data_set
(x_target, y_target), _ = sample_from_gaussian_process(key)

with pytest.raises(ValueError):
f = hk.transform(f)
f.init(key, x_context=x_context, y_context=y_context, x_target=x_target)
f.init(
key,
x_context=x_target[:, :10, :],
y_context=y_target[:, :10, :],
x_target=x_target,
)
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def _version():
"pandas"
],
extras_require={
"dev": ["pre-commit", "black", "isort", "pylint", "tox"],
"dev": ["pre-commit", "black", "isort", "pylint", "tox", "pytest"],
"examples": ["matplotlib"],
},
classifiers=[
Expand Down
Empty file removed tests/__init__.py
Empty file.

0 comments on commit 4ca6ad2

Please sign in to comment.