diff --git a/dpctl/tensor/libtensor/include/kernels/accumulators.hpp b/dpctl/tensor/libtensor/include/kernels/accumulators.hpp index 731807c7d0..f0f72aee82 100644 --- a/dpctl/tensor/libtensor/include/kernels/accumulators.hpp +++ b/dpctl/tensor/libtensor/include/kernels/accumulators.hpp @@ -47,6 +47,8 @@ namespace kernels namespace accumulators { +namespace su_ns = dpctl::tensor::sycl_utils; + using dpctl::tensor::ssize_t; using namespace dpctl::tensor::offset_utils; @@ -84,9 +86,18 @@ template struct CastTransformer } }; +template struct needs_workaround +{ + // workaround needed due to crash in JITing on CPU + // remove when CMPLRLLVM-65813 is resolved + static constexpr bool value = su_ns::IsSyclLogicalAnd::value || + su_ns::IsSyclLogicalOr::value; +}; + template struct can_use_inclusive_scan_over_group { - static constexpr bool value = sycl::has_known_identity::value; + static constexpr bool value = sycl::has_known_identity::value && + !needs_workaround::value; }; namespace detail @@ -144,8 +155,6 @@ template class stack_strided_t // Iterative cumulative summation -namespace su_ns = dpctl::tensor::sycl_utils; - using nwiT = std::uint32_t; template struct TypePairSupportDataForProdAccumulation { static constexpr bool is_defined = std::disjunction< + td_ns::TypePairDefinedEntry, td_ns::TypePairDefinedEntry, td_ns::TypePairDefinedEntry, // input int8_t + td_ns::TypePairDefinedEntry, td_ns::TypePairDefinedEntry, td_ns::TypePairDefinedEntry, @@ -130,6 +132,11 @@ struct TypePairSupportDataForProdAccumulation td_ns::NotDefinedEntry>::is_defined; }; +template +using CumProdScanOpT = std::conditional_t, + sycl::logical_and, + sycl::multiplies>; + template struct CumProd1DContigFactory { @@ -138,7 +145,7 @@ struct CumProd1DContigFactory if constexpr (TypePairSupportDataForProdAccumulation::is_defined) { - using ScanOpT = sycl::multiplies; + using ScanOpT = CumProdScanOpT; constexpr bool include_initial = false; if constexpr (std::is_same_v) { using dpctl::tensor::kernels::accumulators::NoOpTransformer; @@ -171,7 +178,7 @@ struct CumProd1DIncludeInitialContigFactory if constexpr (TypePairSupportDataForProdAccumulation::is_defined) { - using ScanOpT = sycl::multiplies; + using ScanOpT = CumProdScanOpT; constexpr bool include_initial = true; if constexpr (std::is_same_v) { using dpctl::tensor::kernels::accumulators::NoOpTransformer; @@ -204,7 +211,7 @@ struct CumProdStridedFactory if constexpr (TypePairSupportDataForProdAccumulation::is_defined) { - using ScanOpT = sycl::multiplies; + using ScanOpT = CumProdScanOpT; constexpr bool include_initial = false; if constexpr (std::is_same_v) { using dpctl::tensor::kernels::accumulators::NoOpTransformer; @@ -237,7 +244,7 @@ struct CumProdIncludeInitialStridedFactory if constexpr (TypePairSupportDataForProdAccumulation::is_defined) { - using ScanOpT = sycl::multiplies; + using ScanOpT = CumProdScanOpT; constexpr bool include_initial = true; if constexpr (std::is_same_v) { using dpctl::tensor::kernels::accumulators::NoOpTransformer; diff --git a/dpctl/tensor/libtensor/source/accumulators/cumulative_sum.cpp b/dpctl/tensor/libtensor/source/accumulators/cumulative_sum.cpp index 2e6cfddfb6..e44678e15f 100644 --- a/dpctl/tensor/libtensor/source/accumulators/cumulative_sum.cpp +++ b/dpctl/tensor/libtensor/source/accumulators/cumulative_sum.cpp @@ -70,10 +70,12 @@ template struct TypePairSupportDataForSumAccumulation { static constexpr bool is_defined = std::disjunction< + td_ns::TypePairDefinedEntry, td_ns::TypePairDefinedEntry, td_ns::TypePairDefinedEntry, // input int8_t + td_ns::TypePairDefinedEntry, td_ns::TypePairDefinedEntry, td_ns::TypePairDefinedEntry, @@ -130,6 +132,10 @@ struct TypePairSupportDataForSumAccumulation td_ns::NotDefinedEntry>::is_defined; }; +template +using CumSumScanOpT = std:: + conditional_t, sycl::logical_or, sycl::plus>; + template struct CumSum1DContigFactory { @@ -138,7 +144,7 @@ struct CumSum1DContigFactory if constexpr (TypePairSupportDataForSumAccumulation::is_defined) { - using ScanOpT = sycl::plus; + using ScanOpT = CumSumScanOpT; constexpr bool include_initial = false; if constexpr (std::is_same_v) { using dpctl::tensor::kernels::accumulators::NoOpTransformer; @@ -171,7 +177,7 @@ struct CumSum1DIncludeInitialContigFactory if constexpr (TypePairSupportDataForSumAccumulation::is_defined) { - using ScanOpT = sycl::plus; + using ScanOpT = CumSumScanOpT; constexpr bool include_initial = true; if constexpr (std::is_same_v) { using dpctl::tensor::kernels::accumulators::NoOpTransformer; @@ -204,7 +210,7 @@ struct CumSumStridedFactory if constexpr (TypePairSupportDataForSumAccumulation::is_defined) { - using ScanOpT = sycl::plus; + using ScanOpT = CumSumScanOpT; constexpr bool include_initial = false; if constexpr (std::is_same_v) { using dpctl::tensor::kernels::accumulators::NoOpTransformer; @@ -237,7 +243,7 @@ struct CumSumIncludeInitialStridedFactory if constexpr (TypePairSupportDataForSumAccumulation::is_defined) { - using ScanOpT = sycl::plus; + using ScanOpT = CumSumScanOpT; constexpr bool include_initial = true; if constexpr (std::is_same_v) { using dpctl::tensor::kernels::accumulators::NoOpTransformer; diff --git a/dpctl/tests/test_tensor_accumulation.py b/dpctl/tests/test_tensor_accumulation.py index 962d2742a0..9c8eec91d1 100644 --- a/dpctl/tests/test_tensor_accumulation.py +++ b/dpctl/tests/test_tensor_accumulation.py @@ -421,3 +421,15 @@ def test_cumulative_sum_gh_1901(p): inp = dpt.ones(n, dtype=dt) r = dpt.cumulative_sum(inp, dtype=dt) assert dpt.all(r == dpt.arange(1, n + 1, dtype=dt)) + + +@pytest.mark.parametrize( + "dt", ["i1", "i2", "i4", "i8", "f2", "f4", "f8", "c8", "c16"] +) +def test_gh_2017(dt): + "See https://github.com/IntelPython/dpctl/issues/2017" + q = get_queue_or_skip() + skip_if_dtype_not_supported(dt, q) + x = dpt.asarray([-1, 1], dtype=dpt.dtype(dt), sycl_queue=q) + r = dpt.cumulative_sum(x, dtype="?") + assert dpt.all(r)