Skip to content

Commit 81aa661

Browse files
committed
gpu: sycl: pooling: Enabling different src/dst datatypes
1 parent 511d36a commit 81aa661

File tree

2 files changed

+14
-12
lines changed

2 files changed

+14
-12
lines changed

src/gpu/generic/sycl/ref_pooling.hpp

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -53,18 +53,17 @@ struct ref_pooling_fwd_t : public gpu::generic::sycl::primitive_t {
5353

5454
const bool ok = is_fwd() && set_default_params() == status::success
5555
&& (src_md(0)->format_desc.blocking.inner_nblks == 0)
56-
&& (utils::everyone_is(
57-
s8, src_md(0)->data_type, dst_md(0)->data_type)
58-
|| utils::everyone_is(u8, src_md(0)->data_type,
59-
dst_md(0)->data_type)
60-
|| utils::everyone_is(f32, src_md(0)->data_type,
61-
dst_md(0)->data_type)
62-
|| utils::everyone_is(bf16, src_md(0)->data_type,
63-
dst_md(0)->data_type)
64-
|| utils::everyone_is(f16, src_md(0)->data_type,
65-
dst_md(0)->data_type)
66-
|| utils::everyone_is(s32, src_md(0)->data_type,
67-
dst_md(0)->data_type))
56+
&& (!utils::one_of(
57+
f64, src_md(0)->data_type, dst_md(0)->data_type))
58+
&& (IMPLICATION(src_md(0)->data_type == bf16,
59+
dst_md(0)->data_type == bf16))
60+
&& (IMPLICATION(src_md(0)->data_type == s8,
61+
dst_md(0)->data_type != u8))
62+
&& (IMPLICATION(src_md(0)->data_type == u8,
63+
dst_md(0)->data_type != s8))
64+
&& (IMPLICATION(
65+
src_md(0)->data_type != dst_md(0)->data_type,
66+
desc()->prop_kind == forward_inference))
6867
&& attr()->has_default_values(sm::post_ops)
6968
&& attr_.set_default_formats(dst_md(0)) == status::success;
7069
if (!ok) return status::unimplemented;

tests/benchdnn/pool/ref_pool.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@ void compute_ref_fwd(const prb_t *prb, const args_t &args) {
3838
// XXX: this is a hack to let tests with padded area to pass for bf16
3939
// dt due to the library initialize values with -max_dt, but not -INF.
4040
float max_value = lowest_dt(prb->dst_dt());
41+
if (prb->src_dt() == dnnl_u8 || prb->src_dt() == dnnl_s8
42+
|| prb->src_dt() == dnnl_f16 || prb->src_dt() == dnnl_f32)
43+
max_value = lowest_dt(prb->src_dt());
4144
float avg_value = 0.;
4245
// Set initial value based on ws data type
4346
int ws_off = prb->kernel_size() <= UINT8_MAX ? UINT8_MAX : INT_MAX;

0 commit comments

Comments
 (0)