Skip to content

Commit

Permalink
Test expectation via public API. (#75)
Browse files Browse the repository at this point in the history
  • Loading branch information
null-a authored and neerajprad committed Nov 29, 2019
1 parent f2f5a7a commit d72cc47
Showing 1 changed file with 25 additions and 32 deletions.
57 changes: 25 additions & 32 deletions tests/test_brm.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np
from scipy.special import expit as sigmoid
import numpyro.handlers as numpyro
import pandas as pd
import pytest
Expand Down Expand Up @@ -431,6 +432,30 @@ def test_mu_correctness(formula_str, cols, backend, expected):
assert np.allclose(actual_mu, expected_mu)


@pytest.mark.parametrize('cols, family, expected', [
([],
Normal,
lambda mu: mu),
([Integral('y', min=0, max=1)],
Bernoulli,
lambda mu: sigmoid(mu)),
([Integral('y', min=0, max=5)],
Binomial(num_trials=5),
lambda mu: sigmoid(mu) * 5),
])
@pytest.mark.parametrize('backend', [pyro_backend, numpyro_backend])
def test_expectation_correctness(cols, family, expected, backend):
formula_str = 'y ~ 1 + x'
df = dummy_df(expand_columns(parse(formula_str), cols), 10)
fit = brm(formula_str, df, family=family).prior(num_samples=1, backend=backend)
actual_expectation = fit.fitted(what='expectation')[0]
# We assume (since it's tested elsewhere) that `mu` is computed
# correctly by `fitted`. So given that, we check that `fitted`
# computes the correct expectation.
expected_expectation = expected(fit.fitted('linear')[0])
assert np.allclose(actual_expectation, expected_expectation)


@pytest.mark.parametrize('N', [0, 5])
@pytest.mark.parametrize('backend', [pyro_backend, numpyro_backend])
@pytest.mark.parametrize('formula_str, non_real_cols, contrasts, family, priors, expected', codegen_cases)
Expand All @@ -446,38 +471,6 @@ def test_sampling_from_prior_smoke(N, backend, formula_str, non_real_cols, contr
assert type(samples) == Samples


# Sanity checks to ensure that the generated `expected_response`
# function does something sensible for some common families.
@pytest.mark.parametrize('response_meta, family, args, expected', [
(RealValued('y'),
Normal,
[
np.array([[1., 2., 3.], [4., 5., 6.]]), # mean
np.array([[0.1], [0.2]]), # sd
],
np.array([[1., 2., 3.], [4., 5., 6.]])), # mean
(Integral('y', min=0, max=5),
Binomial(num_trials=5),
[
np.array([[-2., 0., 2.], [-1., 0., 1.]]), # logits
],
np.array([[0.59601461, 2.5, 4.40398539],
[1.34470711, 2.5, 3.65529289]])), # sigmoid(logits) * num_trials
])
@pytest.mark.parametrize('backend', [pyro_backend, numpyro_backend])
def test_expected_response_codegen(response_meta, family, args, expected, backend):
formula = parse('y ~ 1')
desc = makedesc(formula, metadata_from_cols([response_meta]), family, [], {})

def expected_response(*args):
backend_args = [backend.from_numpy(arg) for arg in args]
fn = backend.gen(desc).expected_response_fn
return backend.to_numpy(fn(*backend_args))

assert np.allclose(expected_response(*args), expected)


@pytest.mark.parametrize('formula_str, non_real_cols, contrasts, family, priors, expected', codegen_cases)
@pytest.mark.parametrize('fitargs', [
dict(backend=pyro_backend, num_samples=1, algo='prior'),
Expand Down

0 comments on commit d72cc47

Please sign in to comment.