@@ -54,18 +54,17 @@ struct ref_pooling_fwd_t : public gpu::generic::sycl::primitive_t {
54
54
55
55
const bool ok = is_fwd () && set_default_params () == status::success
56
56
&& (src_md (0 )->format_desc .blocking .inner_nblks == 0 )
57
- && (utils::everyone_is (
58
- s8, src_md (0 )->data_type , dst_md (0 )->data_type )
59
- || utils::everyone_is (u8 , src_md (0 )->data_type ,
60
- dst_md (0 )->data_type )
61
- || utils::everyone_is (f32 , src_md (0 )->data_type ,
62
- dst_md (0 )->data_type )
63
- || utils::everyone_is (bf16 , src_md (0 )->data_type ,
64
- dst_md (0 )->data_type )
65
- || utils::everyone_is (f16 , src_md (0 )->data_type ,
66
- dst_md (0 )->data_type )
67
- || utils::everyone_is (s32, src_md (0 )->data_type ,
68
- dst_md (0 )->data_type ))
57
+ && (!utils::one_of (
58
+ f64 , src_md (0 )->data_type , dst_md (0 )->data_type ))
59
+ && (IMPLICATION (src_md (0 )->data_type == bf16 ,
60
+ dst_md (0 )->data_type == bf16 ))
61
+ && (IMPLICATION (src_md (0 )->data_type == s8,
62
+ dst_md (0 )->data_type != u8 ))
63
+ && (IMPLICATION (src_md (0 )->data_type == u8 ,
64
+ dst_md (0 )->data_type != s8))
65
+ && (IMPLICATION (
66
+ src_md (0 )->data_type != dst_md (0 )->data_type ,
67
+ desc ()->prop_kind == forward_inference))
69
68
&& attr ()->has_default_values (sm::post_ops)
70
69
&& attr_.set_default_formats (dst_md (0 )) == status::success
71
70
&& md_dims_in_range (src_md ());
0 commit comments