Skip to content

Commit 90bec68

Browse files
committed
Add off-grid non-Gaussian unit test
1 parent 33b8c2b commit 90bec68

File tree

1 file changed

+33
-2
lines changed

1 file changed

+33
-2
lines changed

tests/test_model.py

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -437,10 +437,41 @@ def test_highlevel_predict_coords_align_with_X_t_offgrid(self):
437437
df_raw.reset_index()["longitude"],
438438
)
439439

440-
def test_highlevel_predict_with_pred_params(self):
440+
def test_highlevel_predict_with_pred_params_pandas(self):
441441
"""
442442
Test that passing ``pred_params`` to ``.predict`` works with
443-
a spikes-beta likelihood.
443+
a spikes-beta likelihood for prediction to pandas.
444+
"""
445+
tl = TaskLoader(context=self.da, target=self.da)
446+
model = ConvNP(
447+
self.dp,
448+
tl,
449+
unet_channels=(5, 5, 5),
450+
verbose=False,
451+
likelihood="cnp-spikes-beta",
452+
)
453+
task = tl("2020-01-01", context_sampling=10, target_sampling=10)
454+
455+
# Off-grid prediction
456+
X_t = np.array([[0.0, 0.5, 1.0], [0.0, 0.5, 1.0]])
457+
458+
# Check that nothing breaks and the correct parameters are returned
459+
pred_params = ["mean", "std", "variance", "alpha", "beta"]
460+
pred = model.predict(task, X_t=X_t, pred_params=pred_params)
461+
for pred_param in pred_params:
462+
assert pred_param in pred["var"]
463+
464+
# Test mixture probs special case
465+
pred_params = ["mixture_probs"]
466+
pred = model.predict(task, X_t=self.da, pred_params=pred_params)
467+
for component in range(model.N_mixture_components):
468+
pred_param = f"mixture_probs_{component}"
469+
assert pred_param in pred["var"]
470+
471+
def test_highlevel_predict_with_pred_params_xarray(self):
472+
"""
473+
Test that passing ``pred_params`` to ``.predict`` works with
474+
a spikes-beta likelihood for prediction to xarray.
444475
"""
445476
tl = TaskLoader(context=self.da, target=self.da)
446477
model = ConvNP(

0 commit comments

Comments
 (0)