Skip to content

Commit de28cf2

Browse files
authored
gpu: generic: sycl: pooling: enable different src/dst dt (#1878)
1 parent 1411937 commit de28cf2

File tree

2 files changed

+12
-12
lines changed

2 files changed

+12
-12
lines changed

src/gpu/generic/sycl/ref_pooling.hpp

Lines changed: 11 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -54,18 +54,17 @@ struct ref_pooling_fwd_t : public gpu::generic::sycl::primitive_t {
54 54

55 55
const bool ok = is_fwd() && set_default_params() == status::success
56 56
&& (src_md(0)->format_desc.blocking.inner_nblks == 0)
57-
&& (utils::everyone_is(
58-
s8, src_md(0)->data_type, dst_md(0)->data_type)
59-
|| utils::everyone_is(u8, src_md(0)->data_type,
60-
dst_md(0)->data_type)
61-
|| utils::everyone_is(f32, src_md(0)->data_type,
62-
dst_md(0)->data_type)
63-
|| utils::everyone_is(bf16, src_md(0)->data_type,
64-
dst_md(0)->data_type)
65-
|| utils::everyone_is(f16, src_md(0)->data_type,
66-
dst_md(0)->data_type)
67-
|| utils::everyone_is(s32, src_md(0)->data_type,
68-
dst_md(0)->data_type))
57+
&& (!utils::one_of(
58+
f64, src_md(0)->data_type, dst_md(0)->data_type))
59+
&& (IMPLICATION(src_md(0)->data_type == bf16,
60+
dst_md(0)->data_type == bf16))
61+
&& (IMPLICATION(src_md(0)->data_type == s8,
62+
dst_md(0)->data_type != u8))
63+
&& (IMPLICATION(src_md(0)->data_type == u8,
64+
dst_md(0)->data_type != s8))
65+
&& (IMPLICATION(
66+
src_md(0)->data_type != dst_md(0)->data_type,
67+
desc()->prop_kind == forward_inference))
69 68
&& attr()->has_default_values(sm::post_ops)
70 69
&& attr_.set_default_formats(dst_md(0)) == status::success
71 70
&& md_dims_in_range(src_md());

tests/benchdnn/pool/ref_pool.cpp

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -38,6 +38,7 @@ void compute_ref_fwd(const prb_t *prb, const args_t &args) {
38 38
// XXX: this is a hack to let tests with padded area to pass for bf16
39 39
// dt due to the library initialize values with -max_dt, but not -INF.
40 40
float max_value = lowest_dt(prb->dst_dt());
41+
if (is_nvidia_gpu()) max_value = lowest_dt(prb->src_dt());
41 42
float avg_value = 0.;
42 43
// Set initial value based on ws data type
43 44
int ws_off = prb->kernel_size() <= UINT8_MAX ? UINT8_MAX : INT_MAX;

0 commit comments

Comments
 (0)