Skip to content

Commit 8d655ca

Browse files
Googlertensorflower-gardener
authored andcommitted
Removes/relaxes dtype requirements from kernel in multitask_gaussian_process_regression_model.py
PiperOrigin-RevId: 715188555
1 parent f1dd1c7 commit 8d655ca

File tree

1 file changed

+9
-5
lines changed

1 file changed

+9
-5
lines changed

tensorflow_probability/python/experimental/distributions/multitask_gaussian_process_regression_model.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,8 @@ def _flattened_conditional_mean_fn_helper(
6767
observations = tf.convert_to_tensor(observations)
6868
if observation_index_points is not None:
6969
observation_index_points = nest_util.convert_to_nested_tensor(
70-
observation_index_points, dtype=kernel.dtype, allow_packing=True)
70+
observation_index_points, dtype_hint=kernel.dtype, allow_packing=True
71+
)
7172

7273
k_x_obs_linop = kernel.matrix_over_all_tasks(x, observation_index_points)
7374
if solve_on_observations is None:
@@ -296,12 +297,13 @@ def __init__(self,
296297

297298
input_dtype = dtype_util.common_dtype(
298299
dict(
299-
kernel=kernel,
300300
index_points=index_points,
301301
observation_index_points=observation_index_points,
302302
),
303303
dtype_hint=nest_util.broadcast_structure(
304-
kernel.feature_ndims, tf.float32))
304+
kernel.feature_ndims, tf.float32
305+
),
306+
)
305307

306308
# If the input dtype is non-nested float, we infer a single dtype for the
307309
# input and the float parameters, which is also the dtype of the MTGP's
@@ -573,9 +575,11 @@ def precompute_regression_model(
573575
with tf.name_scope(name) as name:
574576
if tf.nest.is_nested(kernel.feature_ndims):
575577
input_dtype = dtype_util.common_dtype(
576-
[kernel, index_points, observation_index_points],
578+
[index_points, observation_index_points],
577579
dtype_hint=nest_util.broadcast_structure(
578-
kernel.feature_ndims, tf.float32))
580+
kernel.feature_ndims, tf.float32
581+
),
582+
)
579583
dtype = dtype_util.common_dtype(
580584
[observations, observation_noise_variance,
581585
predictive_noise_variance], tf.float32)

0 commit comments

Comments
 (0)