@@ -53,18 +53,17 @@ struct ref_pooling_fwd_t : public gpu::generic::sycl::primitive_t {
53
53
54
54
const bool ok = is_fwd () && set_default_params () == status::success
55
55
&& (src_md (0 )->format_desc .blocking .inner_nblks == 0 )
56
- && (utils::everyone_is (
57
- s8, src_md (0 )->data_type , dst_md (0 )->data_type )
58
- || utils::everyone_is (u8 , src_md (0 )->data_type ,
59
- dst_md (0 )->data_type )
60
- || utils::everyone_is (f32 , src_md (0 )->data_type ,
61
- dst_md (0 )->data_type )
62
- || utils::everyone_is (bf16 , src_md (0 )->data_type ,
63
- dst_md (0 )->data_type )
64
- || utils::everyone_is (f16 , src_md (0 )->data_type ,
65
- dst_md (0 )->data_type )
66
- || utils::everyone_is (s32, src_md (0 )->data_type ,
67
- dst_md (0 )->data_type ))
56
+ && (!utils::one_of (
57
+ f64 , src_md (0 )->data_type , dst_md (0 )->data_type ))
58
+ && (IMPLICATION (src_md (0 )->data_type == bf16 ,
59
+ dst_md (0 )->data_type == bf16 ))
60
+ && (IMPLICATION (src_md (0 )->data_type == s8,
61
+ dst_md (0 )->data_type != u8 ))
62
+ && (IMPLICATION (src_md (0 )->data_type == u8 ,
63
+ dst_md (0 )->data_type != s8))
64
+ && (IMPLICATION (
65
+ src_md (0 )->data_type != dst_md (0 )->data_type ,
66
+ desc ()->prop_kind == forward_inference))
68
67
&& attr ()->has_default_values (sm::post_ops)
69
68
&& attr_.set_default_formats (dst_md (0 )) == status::success;
70
69
if (!ok) return status::unimplemented;
0 commit comments