@@ -437,10 +437,41 @@ def test_highlevel_predict_coords_align_with_X_t_offgrid(self):
437
437
df_raw .reset_index ()["longitude" ],
438
438
)
439
439
440
- def test_highlevel_predict_with_pred_params (self ):
440
+ def test_highlevel_predict_with_pred_params_pandas (self ):
441
441
"""
442
442
Test that passing ``pred_params`` to ``.predict`` works with
443
- a spikes-beta likelihood.
443
+ a spikes-beta likelihood for prediction to pandas.
444
+ """
445
+ tl = TaskLoader (context = self .da , target = self .da )
446
+ model = ConvNP (
447
+ self .dp ,
448
+ tl ,
449
+ unet_channels = (5 , 5 , 5 ),
450
+ verbose = False ,
451
+ likelihood = "cnp-spikes-beta" ,
452
+ )
453
+ task = tl ("2020-01-01" , context_sampling = 10 , target_sampling = 10 )
454
+
455
+ # Off-grid prediction
456
+ X_t = np .array ([[0.0 , 0.5 , 1.0 ], [0.0 , 0.5 , 1.0 ]])
457
+
458
+ # Check that nothing breaks and the correct parameters are returned
459
+ pred_params = ["mean" , "std" , "variance" , "alpha" , "beta" ]
460
+ pred = model .predict (task , X_t = X_t , pred_params = pred_params )
461
+ for pred_param in pred_params :
462
+ assert pred_param in pred ["var" ]
463
+
464
+ # Test mixture probs special case
465
+ pred_params = ["mixture_probs" ]
466
+ pred = model .predict (task , X_t = self .da , pred_params = pred_params )
467
+ for component in range (model .N_mixture_components ):
468
+ pred_param = f"mixture_probs_{ component } "
469
+ assert pred_param in pred ["var" ]
470
+
471
+ def test_highlevel_predict_with_pred_params_xarray (self ):
472
+ """
473
+ Test that passing ``pred_params`` to ``.predict`` works with
474
+ a spikes-beta likelihood for prediction to xarray.
444
475
"""
445
476
tl = TaskLoader (context = self .da , target = self .da )
446
477
model = ConvNP (
0 commit comments