@@ -67,7 +67,8 @@ def _flattened_conditional_mean_fn_helper(
67
67
observations = tf .convert_to_tensor (observations )
68
68
if observation_index_points is not None :
69
69
observation_index_points = nest_util .convert_to_nested_tensor (
70
- observation_index_points , dtype = kernel .dtype , allow_packing = True )
70
+ observation_index_points , dtype_hint = kernel .dtype , allow_packing = True
71
+ )
71
72
72
73
k_x_obs_linop = kernel .matrix_over_all_tasks (x , observation_index_points )
73
74
if solve_on_observations is None :
@@ -296,12 +297,13 @@ def __init__(self,
296
297
297
298
input_dtype = dtype_util .common_dtype (
298
299
dict (
299
- kernel = kernel ,
300
300
index_points = index_points ,
301
301
observation_index_points = observation_index_points ,
302
302
),
303
303
dtype_hint = nest_util .broadcast_structure (
304
- kernel .feature_ndims , tf .float32 ))
304
+ kernel .feature_ndims , tf .float32
305
+ ),
306
+ )
305
307
306
308
# If the input dtype is non-nested float, we infer a single dtype for the
307
309
# input and the float parameters, which is also the dtype of the MTGP's
@@ -573,9 +575,11 @@ def precompute_regression_model(
573
575
with tf .name_scope (name ) as name :
574
576
if tf .nest .is_nested (kernel .feature_ndims ):
575
577
input_dtype = dtype_util .common_dtype (
576
- [kernel , index_points , observation_index_points ],
578
+ [index_points , observation_index_points ],
577
579
dtype_hint = nest_util .broadcast_structure (
578
- kernel .feature_ndims , tf .float32 ))
580
+ kernel .feature_ndims , tf .float32
581
+ ),
582
+ )
579
583
dtype = dtype_util .common_dtype (
580
584
[observations , observation_noise_variance ,
581
585
predictive_noise_variance ], tf .float32 )
0 commit comments