diff --git a/batched/dense/impl/KokkosBatched_Eigendecomposition_Serial_Internal.hpp b/batched/dense/impl/KokkosBatched_Eigendecomposition_Serial_Internal.hpp index c56fd6ba87..3538f1d1eb 100644 --- a/batched/dense/impl/KokkosBatched_Eigendecomposition_Serial_Internal.hpp +++ b/batched/dense/impl/KokkosBatched_Eigendecomposition_Serial_Internal.hpp @@ -12,7 +12,6 @@ #include "KokkosBatched_Schur_Serial_Internal.hpp" #include "KokkosBatched_RightEigenvectorFromSchur_Serial_Internal.hpp" #include "KokkosBatched_LeftEigenvectorFromSchur_Serial_Internal.hpp" -#include "KokkosBatched_Gemm_Serial_Internal.hpp" namespace KokkosBatched { diff --git a/batched/dense/impl/KokkosBatched_Gemm_Serial_Impl.hpp b/batched/dense/impl/KokkosBatched_Gemm_Serial_Impl.hpp index 3331e847b2..b895a0b554 100644 --- a/batched/dense/impl/KokkosBatched_Gemm_Serial_Impl.hpp +++ b/batched/dense/impl/KokkosBatched_Gemm_Serial_Impl.hpp @@ -43,343 +43,9 @@ #define __KOKKOSBATCHED_GEMM_SERIAL_IMPL_HPP__ #include "KokkosBatched_Util.hpp" -#include "KokkosBatched_Gemm_Serial_Internal.hpp" +#include "KokkosBlas3_gemm.hpp" namespace KokkosBatched { -/********************* BEGIN functor-level routines *********************/ -/// -/// Serial Impl -/// =========== - -/// -/// Implemented: -/// NT/NT, T/NT, NT/T, T/T -/// -/// Not yet immplemented (ConjTranspose): -/// CT/NT, NT/CT, CT/CT -/// - -/// -/// NT/NT -/// - -#if defined(__KOKKOSBATCHED_ENABLE_INTEL_MKL__) && \ - defined(__KOKKOSBATCHED_ENABLE_INTEL_MKL_BATCHED__) && \ - defined(__KOKKOSBATCHED_ENABLE_INTEL_MKL_COMPACT_BATCHED__) -template <> -template -KOKKOS_INLINE_FUNCTION int -SerialGemm::invoke(const ScalarType alpha, - const AViewType &A, - const BViewType &B, - const ScalarType beta, - const CViewType &C) { - typedef typename CViewType::value_type vector_type; - // typedef typename vector_type::value_type value_type; - - const int m = C.extent(0), n = C.extent(1), k = A.extent(1); - - static_assert(is_vector::value, "value type is not vector type"); - static_assert( - vector_type::vector_length == 4 || vector_type::vector_length == 8, - "AVX, AVX2 and AVX512 is supported"); - const MKL_COMPACT_PACK format = - vector_type::vector_length == 8 ? MKL_COMPACT_AVX512 : MKL_COMPACT_AVX; - - // no error check - int r_val = 0; - if (A.stride_0() == 1 && B.stride_0() == 1 && C.stride_0() == 1) { - mkl_dgemm_compact(MKL_COL_MAJOR, MKL_NOTRANS, MKL_NOTRANS, m, n, k, alpha, - (const double *)A.data(), A.stride_1(), - (const double *)B.data(), B.stride_1(), beta, - (double *)C.data(), C.stride_1(), format, - (MKL_INT)vector_type::vector_length); - } else if (A.stride_1() == 1 && B.stride_1() == 1 && C.stride_1() == 1) { - mkl_dgemm_compact(MKL_ROW_MAJOR, MKL_NOTRANS, MKL_NOTRANS, m, n, k, alpha, - (const double *)A.data(), A.stride_0(), - (const double *)B.data(), B.stride_0(), beta, - (double *)C.data(), C.stride_0(), format, - (MKL_INT)vector_type::vector_length); - } else { - r_val = -1; - } - return r_val; -} -#endif - -template <> -template -KOKKOS_INLINE_FUNCTION int -SerialGemm::invoke(const ScalarType alpha, - const AViewType &A, - const BViewType &B, - const ScalarType beta, - const CViewType &C) { - // C = beta C + alpha A B - // C (m x n), A(m x k), B(k x n) - return SerialGemmInternal::invoke( - C.extent(0), C.extent(1), A.extent(1), alpha, A.data(), A.stride_0(), - A.stride_1(), B.data(), B.stride_0(), B.stride_1(), beta, C.data(), - C.stride_0(), C.stride_1()); -} - -template <> -template -KOKKOS_INLINE_FUNCTION int -SerialGemm::invoke( - const ScalarType alpha, const AViewType &A, const BViewType &B, - const ScalarType beta, const CViewType &C) { - // C = beta C + alpha A B - // C (m x n), A(m x k), B(k x n) - return SerialGemmInternal::invoke( - C.extent(0), C.extent(1), A.extent(1), alpha, A.data(), A.stride_0(), - A.stride_1(), B.data(), B.stride_0(), B.stride_1(), beta, C.data(), - C.stride_0(), C.stride_1()); -} - -/// -/// T/NT -/// - -#if defined(__KOKKOSBATCHED_ENABLE_INTEL_MKL__) && \ - defined(__KOKKOSBATCHED_ENABLE_INTEL_MKL_BATCHED__) && \ - defined(__KOKKOSBATCHED_ENABLE_INTEL_MKL_COMPACT_BATCHED__) -template <> -template -KOKKOS_INLINE_FUNCTION int -SerialGemm::invoke(const ScalarType alpha, - const AViewType &A, - const BViewType &B, - const ScalarType beta, - const CViewType &C) { - typedef typename CViewType::value_type vector_type; - // typedef typename vector_type::value_type value_type; - - const int m = C.extent(0), n = C.extent(1), k = A.extent(0); - - static_assert(is_vector::value, "value type is not vector type"); - static_assert( - vector_type::vector_length == 4 || vector_type::vector_length == 8, - "AVX, AVX2 and AVX512 is supported"); - const MKL_COMPACT_PACK format = - vector_type::vector_length == 8 ? MKL_COMPACT_AVX512 : MKL_COMPACT_AVX; - - // no error check - int r_val = 0; - if (A.stride_0() == 1 && B.stride_0() == 1 && C.stride_0() == 1) { - mkl_dgemm_compact(MKL_COL_MAJOR, MKL_TRANS, MKL_NOTRANS, m, n, k, alpha, - (const double *)A.data(), A.stride_1(), - (const double *)B.data(), B.stride_1(), beta, - (double *)C.data(), C.stride_1(), format, - (MKL_INT)vector_type::vector_length); - } else if (A.stride_1() == 1 && B.stride_1() == 1 && C.stride_1() == 1) { - mkl_dgemm_compact(MKL_ROW_MAJOR, MKL_TRANS, MKL_NOTRANS, m, n, k, alpha, - (const double *)A.data(), A.stride_0(), - (const double *)B.data(), B.stride_0(), beta, - (double *)C.data(), C.stride_0(), format, - (MKL_INT)vector_type::vector_length); - } else { - r_val = -1; - } - return r_val; -} -#endif - -template <> -template -KOKKOS_INLINE_FUNCTION int -SerialGemm::invoke( - const ScalarType alpha, const AViewType &A, const BViewType &B, - const ScalarType beta, const CViewType &C) { - // C = beta C + alpha A B - // C (m x n), A(m x k), B(k x n) - return SerialGemmInternal::invoke( - C.extent(0), C.extent(1), A.extent(0), alpha, A.data(), A.stride_1(), - A.stride_0(), B.data(), B.stride_0(), B.stride_1(), beta, C.data(), - C.stride_0(), C.stride_1()); -} - -template <> -template -KOKKOS_INLINE_FUNCTION int -SerialGemm::invoke( - const ScalarType alpha, const AViewType &A, const BViewType &B, - const ScalarType beta, const CViewType &C) { - // C = beta C + alpha A B - // C (m x n), A(m x k), B(k x n) - return SerialGemmInternal::invoke( - C.extent(0), C.extent(1), A.extent(0), alpha, A.data(), A.stride_1(), - A.stride_0(), B.data(), B.stride_0(), B.stride_1(), beta, C.data(), - C.stride_0(), C.stride_1()); -} - -/// -/// NT/T -/// - -#if defined(__KOKKOSBATCHED_ENABLE_INTEL_MKL__) && \ - defined(__KOKKOSBATCHED_ENABLE_INTEL_MKL_BATCHED__) && \ - defined(__KOKKOSBATCHED_ENABLE_INTEL_MKL_COMPACT_BATCHED__) -template <> -template -KOKKOS_INLINE_FUNCTION int -SerialGemm::invoke(const ScalarType alpha, - const AViewType &A, - const BViewType &B, - const ScalarType beta, - const CViewType &C) { - typedef typename CViewType::value_type vector_type; - // typedef typename vector_type::value_type value_type; - - const int m = C.extent(0), n = C.extent(1), k = A.extent(1); - - static_assert(is_vector::value, "value type is not vector type"); - static_assert( - vector_type::vector_length == 4 || vector_type::vector_length == 8, - "AVX, AVX2 and AVX512 is supported"); - const MKL_COMPACT_PACK format = - vector_type::vector_length == 8 ? MKL_COMPACT_AVX512 : MKL_COMPACT_AVX; - - // no error check - int r_val = 0; - if (A.stride_0() == 1 && B.stride_0() == 1 && C.stride_0() == 1) { - mkl_dgemm_compact(MKL_COL_MAJOR, MKL_NOTRANS, MKL_TRANS, m, n, k, alpha, - (const double *)A.data(), A.stride_1(), - (const double *)B.data(), B.stride_1(), beta, - (double *)C.data(), C.stride_1(), format, - (MKL_INT)vector_type::vector_length); - } else if (A.stride_1() == 1 && B.stride_1() == 1 && C.stride_1() == 1) { - mkl_dgemm_compact(MKL_ROW_MAJOR, MKL_NOTRANS, MKL_TRANS, m, n, k, alpha, - (const double *)A.data(), A.stride_0(), - (const double *)B.data(), B.stride_0(), beta, - (double *)C.data(), C.stride_0(), format, - (MKL_INT)vector_type::vector_length); - } else { - r_val = -1; - } - return r_val; -} -#endif - -template <> -template -KOKKOS_INLINE_FUNCTION int -SerialGemm::invoke( - const ScalarType alpha, const AViewType &A, const BViewType &B, - const ScalarType beta, const CViewType &C) { - // C = beta C + alpha A B - // C (m x n), A(m x k), B(k x n) - return SerialGemmInternal::invoke( - C.extent(0), C.extent(1), A.extent(1), alpha, A.data(), A.stride_0(), - A.stride_1(), B.data(), B.stride_1(), B.stride_0(), beta, C.data(), - C.stride_0(), C.stride_1()); -} - -template <> -template -KOKKOS_INLINE_FUNCTION int -SerialGemm::invoke( - const ScalarType alpha, const AViewType &A, const BViewType &B, - const ScalarType beta, const CViewType &C) { - // C = beta C + alpha A B - // C (m x n), A(m x k), B(k x n) - return SerialGemmInternal::invoke( - C.extent(0), C.extent(1), A.extent(1), alpha, A.data(), A.stride_0(), - A.stride_1(), B.data(), B.stride_1(), B.stride_0(), beta, C.data(), - C.stride_0(), C.stride_1()); -} - -/// -/// T/T -/// - -#if defined(__KOKKOSBATCHED_ENABLE_INTEL_MKL__) && \ - defined(__KOKKOSBATCHED_ENABLE_INTEL_MKL_BATCHED__) && \ - defined(__KOKKOSBATCHED_ENABLE_INTEL_MKL_COMPACT_BATCHED__) -template <> -template -KOKKOS_INLINE_FUNCTION int -SerialGemm::invoke( - const ScalarType alpha, const AViewType &A, const BViewType &B, - const ScalarType beta, const CViewType &C) { - typedef typename CViewType::value_type vector_type; - // typedef typename vector_type::value_type value_type; - - const int m = C.extent(0), n = C.extent(1), k = A.extent(0); - - static_assert(is_vector::value, "value type is not vector type"); - static_assert( - vector_type::vector_length == 4 || vector_type::vector_length == 8, - "AVX, AVX2 and AVX512 is supported"); - const MKL_COMPACT_PACK format = - vector_type::vector_length == 8 ? MKL_COMPACT_AVX512 : MKL_COMPACT_AVX; - - // no error check - int r_val = 0; - if (A.stride_0() == 1 && B.stride_0() == 1 && C.stride_0() == 1) { - mkl_dgemm_compact(MKL_COL_MAJOR, MKL_TRANS, MKL_TRANS, m, n, k, alpha, - (const double *)A.data(), A.stride_1(), - (const double *)B.data(), B.stride_1(), beta, - (double *)C.data(), C.stride_1(), format, - (MKL_INT)vector_type::vector_length); - } else if (A.stride_1() == 1 && B.stride_1() == 1 && C.stride_1() == 1) { - mkl_dgemm_compact(MKL_ROW_MAJOR, MKL_TRANS, MKL_TRANS, m, n, k, alpha, - (const double *)A.data(), A.stride_0(), - (const double *)B.data(), B.stride_0(), beta, - (double *)C.data(), C.stride_0(), format, - (MKL_INT)vector_type::vector_length); - } else { - r_val = -1; - } - return r_val; -} -#endif - -template <> -template -KOKKOS_INLINE_FUNCTION int -SerialGemm::invoke( - const ScalarType alpha, const AViewType &A, const BViewType &B, - const ScalarType beta, const CViewType &C) { - // C = beta C + alpha A B - // C (m x n), A(m x k), B(k x n) - return SerialGemmInternal::invoke( - C.extent(0), C.extent(1), A.extent(0), alpha, A.data(), A.stride_1(), - A.stride_0(), B.data(), B.stride_1(), B.stride_0(), beta, C.data(), - C.stride_0(), C.stride_1()); -} - -template <> -template -KOKKOS_INLINE_FUNCTION int -SerialGemm::invoke( - const ScalarType alpha, const AViewType &A, const BViewType &B, - const ScalarType beta, const CViewType &C) { - // C = beta C + alpha A B - // C (m x n), A(m x k), B(k x n) - return SerialGemmInternal::invoke( - C.extent(0), C.extent(1), A.extent(0), alpha, A.data(), A.stride_1(), - A.stride_0(), B.data(), B.stride_1(), B.stride_0(), beta, C.data(), - C.stride_0(), C.stride_1()); -} -/********************* END functor-level routines *********************/ - namespace Impl { /********************* BEGIN non-functor-level routines *********************/ template ::invoke(alpha, svA_row, svB_col, beta, - svC_ele); + KokkosBlas::SerialGemm::invoke(alpha, svA_row, svB_col, beta, + svC_ele); } KOKKOS_INLINE_FUNCTION @@ -481,7 +147,7 @@ class BatchedSerialGemm { auto svC = subview_wrapper(C, i, Kokkos::ALL(), Kokkos::ALL(), batch_layout_tag); - KokkosBatched::SerialGemm::invoke( + KokkosBlas::SerialGemm::invoke( alpha, svA, svB, beta, svC); } }; diff --git a/batched/dense/impl/KokkosBatched_Gemm_Serial_Internal.hpp b/batched/dense/impl/KokkosBatched_Gemm_Serial_Internal.hpp deleted file mode 100644 index 1548d602e2..0000000000 --- a/batched/dense/impl/KokkosBatched_Gemm_Serial_Internal.hpp +++ /dev/null @@ -1,137 +0,0 @@ -#ifndef __KOKKOSBATCHED_GEMM_SERIAL_INTERNAL_HPP__ -#define __KOKKOSBATCHED_GEMM_SERIAL_INTERNAL_HPP__ - -/// \author Kyungjoo Kim (kyukim@sandia.gov) - -#include "KokkosBatched_Util.hpp" - -#include "KokkosBlas1_set_impl.hpp" -#include "KokkosBlas1_serial_scal_impl.hpp" - -#include "KokkosBatched_InnerGemmFixC_Serial_Impl.hpp" - -namespace KokkosBatched { - -/// -/// Serial Internal Impl -/// ==================== - -template -struct SerialGemmInternal { - template - KOKKOS_INLINE_FUNCTION static int invoke( - const int m, const int n, const int k, const ScalarType alpha, - const ValueType *KOKKOS_RESTRICT A, const int as0, const int as1, - const ValueType *KOKKOS_RESTRICT B, const int bs0, const int bs1, - const ScalarType beta, - /**/ ValueType *KOKKOS_RESTRICT C, const int cs0, const int cs1); -}; - -template <> -template -KOKKOS_INLINE_FUNCTION int SerialGemmInternal::invoke( - const int m, const int n, const int k, const ScalarType alpha, - const ValueType *KOKKOS_RESTRICT A, const int as0, const int as1, - const ValueType *KOKKOS_RESTRICT B, const int bs0, const int bs1, - const ScalarType beta, - /**/ ValueType *KOKKOS_RESTRICT C, const int cs0, const int cs1) { - // C = beta C + alpha A B - // C (m x n), A(m x k), B(k x n) - - const ScalarType one(1.0), zero(0.0); - - if (beta == zero) - KokkosBlas::Impl::SerialSetInternal::invoke(m, n, zero, C, cs0, cs1); - else if (beta != one) - KokkosBlas::Impl::SerialScaleInternal::invoke(m, n, beta, C, cs0, cs1); - - if (alpha != zero) { - if (m <= 0 || n <= 0 || k <= 0) return 0; - - ValueType *KOKKOS_RESTRICT pC = C; - for (int p = 0; p < k; ++p) { - const ValueType *KOKKOS_RESTRICT pA = A + p * as1, - *KOKKOS_RESTRICT pB = B + p * bs0; - for (int i = 0; i < m; ++i) { - const ValueType tA(alpha * pA[i * as0]); -#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL) -#pragma unroll -#endif - for (int j = 0; j < n; ++j) pC[i * cs0 + j * cs1] += tA * pB[j * bs1]; - } - } - } - return 0; -} - -template <> -template -KOKKOS_INLINE_FUNCTION int SerialGemmInternal::invoke( - const int m, const int n, const int k, const ScalarType alpha, - const ValueType *KOKKOS_RESTRICT A, const int as0, const int as1, - const ValueType *KOKKOS_RESTRICT B, const int bs0, const int bs1, - const ScalarType beta, - /**/ ValueType *KOKKOS_RESTRICT C, const int cs0, const int cs1) { - // C = beta C + alpha A B - // C (m x n), A(m x k), B(k x n) - - constexpr int mbAlgo = Algo::Gemm::Blocked::mb(); - constexpr int nbAlgo = Algo::Gemm::Blocked::mb(); - - const ScalarType one(1.0), zero(0.0); - - if (beta == zero) - KokkosBlas::Impl::SerialSetInternal::invoke(m, n, zero, C, cs0, cs1); - else if (beta != one) - KokkosBlas::Impl::SerialScaleInternal::invoke(m, n, beta, C, cs0, cs1); - - if (alpha != zero) { - if (m <= 0 || n <= 0 || k <= 0) return 0; - const ValueType alpha_value(alpha); - - InnerGemmFixC inner(as0, as1, bs0, bs1, cs0, cs1); - auto gemm = [&](const int ib, const int jb, const int pb, - const ValueType *KOKKOS_RESTRICT AA, - const ValueType *KOKKOS_RESTRICT BB, - /**/ ValueType *KOKKOS_RESTRICT CC) { - const int mb = mbAlgo, nb = nbAlgo; - for (int i = 0; i < ib; i += mb) - for (int j = 0; j < jb; j += nb) - inner.serial_invoke(alpha_value, AA + i * as0, BB + j * bs1, - (i + mb) > ib ? (ib - i) : mb, - (j + nb) > jb ? (jb - j) : nb, pb, - CC + i * cs0 + j * cs1); - }; - - const bool is_small = true; //(m*n*k <= 64*64*64); - if (is_small) { - gemm(m, n, k, A, B, C); - } else { - // // cache blocking - // const int - // nc = nb*10, kc = mb*4, mc = mb*4; - - // for (int jj=0;jj -struct TeamVectorGemm { - template - KOKKOS_INLINE_FUNCTION static int invoke( - const MemberType &member, const ScalarType alpha, const AViewType &A, - const BViewType &B, const ScalarType beta, const CViewType &C) { - // C = beta C + alpha A B - // C (m x n), A(m x k), B(k x n) - return TeamVectorGemmInternal::invoke( - member, C.extent(0), C.extent(1), A.extent(1), alpha, A.data(), - A.stride_0(), A.stride_1(), B.data(), B.stride_0(), B.stride_1(), beta, - C.data(), C.stride_0(), C.stride_1()); - } -}; - -/// -/// T/NT -/// - -template -struct TeamVectorGemm { - template - KOKKOS_INLINE_FUNCTION static int invoke( - const MemberType &member, const ScalarType alpha, const AViewType &A, - const BViewType &B, const ScalarType beta, const CViewType &C) { - // C = beta C + alpha A B - // C (m x n), A(m x k), B(k x n) - return TeamVectorGemmInternal::invoke( - member, C.extent(0), C.extent(1), A.extent(0), alpha, A.data(), - A.stride_1(), A.stride_0(), B.data(), B.stride_0(), B.stride_1(), beta, - C.data(), C.stride_0(), C.stride_1()); - } -}; - -/// -/// NT/T -/// - -template -struct TeamVectorGemm { - template - KOKKOS_INLINE_FUNCTION static int invoke( - const MemberType &member, const ScalarType alpha, const AViewType &A, - const BViewType &B, const ScalarType beta, const CViewType &C) { - // C = beta C + alpha A B - // C (m x n), A(m x k), B(k x n) - return TeamVectorGemmInternal::invoke( - member, C.extent(0), C.extent(1), A.extent(1), alpha, A.data(), - A.stride_0(), A.stride_1(), B.data(), B.stride_1(), B.stride_0(), beta, - C.data(), C.stride_0(), C.stride_1()); - } -}; - -/// -/// T/T -/// - -template -struct TeamVectorGemm { - template - KOKKOS_INLINE_FUNCTION static int invoke( - const MemberType &member, const ScalarType alpha, const AViewType &A, - const BViewType &B, const ScalarType beta, const CViewType &C) { - // C = beta C + alpha A B - // C (m x n), A(m x k), B(k x n) - return TeamVectorGemmInternal::invoke( - member, C.extent(0), C.extent(1), A.extent(0), alpha, A.data(), - A.stride_1(), A.stride_0(), B.data(), B.stride_1(), B.stride_0(), beta, - C.data(), C.stride_0(), C.stride_1()); - } -}; - -} // namespace KokkosBatched - -#endif diff --git a/batched/dense/impl/KokkosBatched_Gemm_TeamVector_Internal.hpp b/batched/dense/impl/KokkosBatched_Gemm_TeamVector_Internal.hpp deleted file mode 100644 index a516f765a1..0000000000 --- a/batched/dense/impl/KokkosBatched_Gemm_TeamVector_Internal.hpp +++ /dev/null @@ -1,114 +0,0 @@ -#ifndef __KOKKOSBATCHED_GEMM_TEAMVECTOR_INTERNAL_HPP__ -#define __KOKKOSBATCHED_GEMM_TEAMVECTOR_INTERNAL_HPP__ - -/// \author Kyungjoo Kim (kyukim@sandia.gov) - -#include "KokkosBatched_Util.hpp" - -#include "KokkosBlas1_set_impl.hpp" -#include "KokkosBlas1_team_scal_impl.hpp" - -namespace KokkosBatched { - -/// -/// TeamVector Internal Impl -/// ==================== -template -struct TeamVectorGemmInternal { - template - KOKKOS_INLINE_FUNCTION static int invoke( - const MemberType &member, const int m, const int n, const int k, - const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const int as0, - const int as1, const ValueType *KOKKOS_RESTRICT B, const int bs0, - const int bs1, const ScalarType beta, - /**/ ValueType *KOKKOS_RESTRICT C, const int cs0, const int cs1); -}; - -template <> -template -KOKKOS_INLINE_FUNCTION int -TeamVectorGemmInternal::invoke( - const MemberType &member, const int m, const int n, const int k, - const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const int as0, - const int as1, const ValueType *KOKKOS_RESTRICT B, const int bs0, - const int bs1, const ScalarType beta, - /**/ ValueType *KOKKOS_RESTRICT C, const int cs0, const int cs1) { - // C = beta C + alpha A B - // C (m x n), A(m x k), B(k x n) - - const ScalarType one(1.0), zero(0.0); - - if (beta == zero) - KokkosBlas::Impl::TeamVectorSetInternal::invoke(member, m, n, zero, C, cs0, - cs1); - else if (beta != one) - KokkosBlas::Impl::TeamVectorScaleInternal::invoke(member, m, n, beta, C, - cs0, cs1); - - if (alpha != ScalarType(0.0)) { - if (m <= 0 || n <= 0 || k <= 0) return 0; - - if (beta != one) member.team_barrier(); - - Kokkos::parallel_for(Kokkos::TeamThreadRange(member, m), [&](const int &i) { - const ValueType *KOKKOS_RESTRICT pA = A + i * as0; - Kokkos::parallel_for(Kokkos::ThreadVectorRange(member, n), - [&](const int &j) { - const ValueType *KOKKOS_RESTRICT pB = B + j * bs1; - - ValueType c = ValueType(0); - for (int p = 0; p < k; ++p) - c += pA[p * as1] * pB[p * bs0]; - C[i * cs0 + j * cs1] += alpha * c; - }); - }); - } - return 0; -} - -template <> -template -KOKKOS_INLINE_FUNCTION int -TeamVectorGemmInternal::invoke( - const MemberType &member, const int m, const int n, const int k, - const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const int as0, - const int as1, const ValueType *KOKKOS_RESTRICT B, const int bs0, - const int bs1, const ScalarType beta, - /**/ ValueType *KOKKOS_RESTRICT C, const int cs0, const int cs1) { - // C = beta C + alpha A B - // C (m x n), A(m x k), B(k x n) - - const ScalarType one(1.0), zero(0.0); - - if (beta == zero) - KokkosBlas::Impl::TeamVectorSetInternal::invoke(member, m, n, zero, C, cs0, - cs1); - else if (beta != one) - KokkosBlas::Impl::TeamVectorScaleInternal::invoke(member, m, n, beta, C, - cs0, cs1); - - if (alpha != ScalarType(0.0)) { - if (m <= 0 || n <= 0 || k <= 0) return 0; - - if (beta != one) member.team_barrier(); - - Kokkos::parallel_for(Kokkos::TeamThreadRange(member, m), [&](const int &i) { - const ValueType *KOKKOS_RESTRICT pA = A + i * as0; - Kokkos::parallel_for( - Kokkos::ThreadVectorRange(member, n), [&](const int &j) { - const ValueType *KOKKOS_RESTRICT pB = B + j * bs1; - - ValueType c = ValueType(0); - for (int p = 0; p < k; ++p) - c += Kokkos::ArithTraits::conj(pA[p * as1]) * - pB[p * bs0]; - C[i * cs0 + j * cs1] += alpha * c; - }); - }); - } - return 0; -} - -} // namespace KokkosBatched - -#endif diff --git a/batched/dense/impl/KokkosBatched_Gemm_Team_Impl.hpp b/batched/dense/impl/KokkosBatched_Gemm_Team_Impl.hpp deleted file mode 100644 index 62526fc96e..0000000000 --- a/batched/dense/impl/KokkosBatched_Gemm_Team_Impl.hpp +++ /dev/null @@ -1,177 +0,0 @@ -#ifndef __KOKKOSBATCHED_GEMM_TEAM_IMPL_HPP__ -#define __KOKKOSBATCHED_GEMM_TEAM_IMPL_HPP__ - -/// \author Kyungjoo Kim (kyukim@sandia.gov) - -#include "KokkosBatched_Util.hpp" -#include "KokkosBatched_Gemm_Team_Internal.hpp" - -namespace KokkosBatched { - -/// -/// Team Impl -/// ========= - -/// -/// Implemented: -/// NT/NT, T/NT, NT/T, T/T -/// -/// Not yet implemented (ConjTranspose) -/// CT/NT, NT/CT, CT/CT -/// - -/// -/// NT/NT -/// - -template -struct TeamGemm { - template - KOKKOS_INLINE_FUNCTION static int invoke( - const MemberType &member, const ScalarType alpha, const AViewType &A, - const BViewType &B, const ScalarType beta, const CViewType &C) { - // C = beta C + alpha A B - // C (m x n), A(m x k), B(k x n) - return TeamGemmInternal::invoke( - member, C.extent(0), C.extent(1), A.extent(1), alpha, A.data(), - A.stride_0(), A.stride_1(), B.data(), B.stride_0(), B.stride_1(), beta, - C.data(), C.stride_0(), C.stride_1()); - } -}; - -template -struct TeamGemm { - template - KOKKOS_INLINE_FUNCTION static int invoke( - const MemberType &member, const ScalarType alpha, const AViewType &A, - const BViewType &B, const ScalarType beta, const CViewType &C) { - // C = beta C + alpha A B - // C (m x n), A(m x k), B(k x n) - return TeamGemmInternal::invoke( - member, C.extent(0), C.extent(1), A.extent(1), alpha, A.data(), - A.stride_0(), A.stride_1(), B.data(), B.stride_0(), B.stride_1(), beta, - C.data(), C.stride_0(), C.stride_1()); - } -}; - -/// -/// T/NT -/// - -template -struct TeamGemm { - template - KOKKOS_INLINE_FUNCTION static int invoke( - const MemberType &member, const ScalarType alpha, const AViewType &A, - const BViewType &B, const ScalarType beta, const CViewType &C) { - // C = beta C + alpha A B - // C (m x n), A(m x k), B(k x n) - return TeamGemmInternal::invoke( - member, C.extent(0), C.extent(1), A.extent(0), alpha, A.data(), - A.stride_1(), A.stride_0(), B.data(), B.stride_0(), B.stride_1(), beta, - C.data(), C.stride_0(), C.stride_1()); - } -}; - -template -struct TeamGemm { - template - KOKKOS_INLINE_FUNCTION static int invoke( - const MemberType &member, const ScalarType alpha, const AViewType &A, - const BViewType &B, const ScalarType beta, const CViewType &C) { - // C = beta C + alpha A B - // C (m x n), A(m x k), B(k x n) - return TeamGemmInternal::invoke( - member, C.extent(0), C.extent(1), A.extent(0), alpha, A.data(), - A.stride_1(), A.stride_0(), B.data(), B.stride_0(), B.stride_1(), beta, - C.data(), C.stride_0(), C.stride_1()); - } -}; - -/// -/// NT/T -/// - -template -struct TeamGemm { - template - KOKKOS_INLINE_FUNCTION static int invoke( - const MemberType &member, const ScalarType alpha, const AViewType &A, - const BViewType &B, const ScalarType beta, const CViewType &C) { - // C = beta C + alpha A B - // C (m x n), A(m x k), B(k x n) - return TeamGemmInternal::invoke( - member, C.extent(0), C.extent(1), A.extent(1), alpha, A.data(), - A.stride_0(), A.stride_1(), B.data(), B.stride_1(), B.stride_0(), beta, - C.data(), C.stride_0(), C.stride_1()); - } -}; - -template -struct TeamGemm { - template - KOKKOS_INLINE_FUNCTION static int invoke( - const MemberType &member, const ScalarType alpha, const AViewType &A, - const BViewType &B, const ScalarType beta, const CViewType &C) { - // C = beta C + alpha A B - // C (m x n), A(m x k), B(k x n) - return TeamGemmInternal::invoke( - member, C.extent(0), C.extent(1), A.extent(1), alpha, A.data(), - A.stride_0(), A.stride_1(), B.data(), B.stride_1(), B.stride_0(), beta, - C.data(), C.stride_0(), C.stride_1()); - } -}; - -/// -/// T/T -/// - -template -struct TeamGemm { - template - KOKKOS_INLINE_FUNCTION static int invoke( - const MemberType &member, const ScalarType alpha, const AViewType &A, - const BViewType &B, const ScalarType beta, const CViewType &C) { - // C = beta C + alpha A B - // C (m x n), A(m x k), B(k x n) - return TeamGemmInternal::invoke( - member, C.extent(0), C.extent(1), A.extent(0), alpha, A.data(), - A.stride_1(), A.stride_0(), B.data(), B.stride_1(), B.stride_0(), beta, - C.data(), C.stride_0(), C.stride_1()); - } -}; - -template -struct TeamGemm { - template - KOKKOS_INLINE_FUNCTION static int invoke( - const MemberType &member, const ScalarType alpha, const AViewType &A, - const BViewType &B, const ScalarType beta, const CViewType &C) { - // C = beta C + alpha A B - // C (m x n), A(m x k), B(k x n) - return TeamGemmInternal::invoke( - member, C.extent(0), C.extent(1), A.extent(0), alpha, A.data(), - A.stride_1(), A.stride_0(), B.data(), B.stride_1(), B.stride_0(), beta, - C.data(), C.stride_0(), C.stride_1()); - } -}; - -} // namespace KokkosBatched - -#endif diff --git a/batched/dense/impl/KokkosBatched_Gemm_Team_Internal.hpp b/batched/dense/impl/KokkosBatched_Gemm_Team_Internal.hpp deleted file mode 100644 index 4f147a98fc..0000000000 --- a/batched/dense/impl/KokkosBatched_Gemm_Team_Internal.hpp +++ /dev/null @@ -1,157 +0,0 @@ -#ifndef __KOKKOSBATCHED_GEMM_TEAM_INTERNAL_HPP__ -#define __KOKKOSBATCHED_GEMM_TEAM_INTERNAL_HPP__ - -/// \author Kyungjoo Kim (kyukim@sandia.gov) - -#include "KokkosBatched_Util.hpp" -#include "KokkosKernels_ExecSpaceUtils.hpp" - -#include "KokkosBlas1_set_impl.hpp" -#include "KokkosBlas1_team_scal_impl.hpp" - -#include "KokkosBatched_InnerGemmFixC_Serial_Impl.hpp" - -namespace KokkosBatched { - -/// -/// Team Internal Impl -/// ==================== -template -struct TeamGemmInternal { - template - KOKKOS_INLINE_FUNCTION static int invoke( - const MemberType &member, const int m, const int n, const int k, - const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const int as0, - const int as1, const ValueType *KOKKOS_RESTRICT B, const int bs0, - const int bs1, const ScalarType beta, - /**/ ValueType *KOKKOS_RESTRICT C, const int cs0, const int cs1); -}; - -template <> -template -KOKKOS_INLINE_FUNCTION int TeamGemmInternal::invoke( - const MemberType &member, const int m, const int n, const int k, - const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const int as0, - const int as1, const ValueType *KOKKOS_RESTRICT B, const int bs0, - const int bs1, const ScalarType beta, - /**/ ValueType *KOKKOS_RESTRICT C, const int cs0, const int cs1) { - // C = beta C + alpha A B - // C (m x n), A(m x k), B(k x n) - - const ScalarType one(1.0), zero(0.0); - - if (beta == zero) - KokkosBlas::Impl::TeamSetInternal::invoke(member, m, n, zero, C, cs0, cs1); - else if (beta != one) - KokkosBlas::Impl::TeamScaleInternal::invoke(member, m, n, beta, C, cs0, - cs1); - - if (alpha != ScalarType(0.0)) { - if (m <= 0 || n <= 0 || k <= 0) return 0; - - if (beta != one) member.team_barrier(); - - Kokkos::parallel_for( - Kokkos::TeamThreadRange(member, 0, m * n), [&](const int &ij) { - // assume layout right for batched computation - const int i = ij / n, j = ij % n; - const ValueType *KOKKOS_RESTRICT pA = A + i * as0, - *KOKKOS_RESTRICT pB = B + j * bs1; - - ValueType c = ValueType(0); - for (int p = 0; p < k; ++p) c += pA[p * as1] * pB[p * bs0]; - C[i * cs0 + j * cs1] += alpha * c; - }); - } - return 0; -} - -template <> -template -KOKKOS_INLINE_FUNCTION int TeamGemmInternal::invoke( - const MemberType &member, const int m, const int n, const int k, - const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const int as0, - const int as1, const ValueType *KOKKOS_RESTRICT B, const int bs0, - const int bs1, const ScalarType beta, - /**/ ValueType *KOKKOS_RESTRICT C, const int cs0, const int cs1) { - // C = beta C + alpha A B - // C (m x n), A(m x k), B(k x n) - - constexpr int mbAlgo = Algo::Gemm::Blocked::mb(); - constexpr int nbAlgo = Algo::Gemm::Blocked::mb(); - - const ScalarType one(1.0), zero(0.0); - - if (beta == zero) - KokkosBlas::Impl::TeamSetInternal::invoke(member, m, n, zero, C, cs0, cs1); - else if (beta != one) - KokkosBlas::Impl::TeamScaleInternal::invoke(member, m, n, beta, C, cs0, - cs1); - - if (alpha != ScalarType(0.0)) { - if (m <= 0 || n <= 0 || k <= 0) return 0; - - if (beta != one) member.team_barrier(); - - /// - /// GPU case: team size is large and blocksize (mb,nb) is small - InnerGemmFixC inner(as0, as1, bs0, bs1, cs0, cs1); - auto gemm = [&](const int ib, const int jb, const int pb, - const ValueType *KOKKOS_RESTRICT AA, - const ValueType *KOKKOS_RESTRICT BB, - /**/ ValueType *KOKKOS_RESTRICT CC) { - // Made this non-const in order to WORKAROUND issue #349 - int mb = mbAlgo, mp = (ib % mb), mq = (ib / mb) + (mp > 0), nb = nbAlgo, - np = (jb % nb), nq = (jb / nb) + (np > 0); - - // square tiling - Kokkos::parallel_for( - Kokkos::TeamThreadRange(member, mq * nq), [&](const int &ij) { - int i, j; - // note: the condition is constexpr - if (KokkosKernels::Impl::kk_is_gpu_exec_space< - typename MemberType::execution_space>()) { - i = ij % mq * mb; - j = ij / mq * nb; - } else { - i = ij / nq * mb; - j = ij % nq * nb; - } - inner.serial_invoke( - alpha, AA + i * as0, BB + j * bs1, (i + mb) > ib ? mp : mb, - (j + nb) > jb ? np : nb, pb, CC + i * cs0 + j * cs1); - }); - }; - - const bool is_small = true; //(m*n*k <= 64*64*64); - if (is_small) { - gemm(m, n, k, A, B, C); - } else { - // // cache blocking - // const int - // nc = nb*10, kc = mb*4, mc = mb*4; - - // for (int jj=0;jj -template -KOKKOS_INLINE_FUNCTION int InnerGemmFixC::team_invoke( - const MemberType &member, const ScalarType alpha, - const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, - const int k, - /**/ ValueType *KOKKOS_RESTRICT C) { - Kokkos::parallel_for( - Kokkos::TeamThreadRange(member, 0, mb * nb), [&](const int &ij) { - const int i = ij / nb, j = ij % nb; - - const ValueType *KOKKOS_RESTRICT pA = A + i * _as0, - *KOKKOS_RESTRICT pB = B + j * _bs1; - - ValueType c = 0; - for (int p = 0; p < k; ++p) c += pA[p * _as1] * pB[p * _bs0]; - C[i * _cs0 + j * _cs1] += alpha * c; - }); - return 0; -} - -template -template -KOKKOS_INLINE_FUNCTION int InnerGemmFixC::team_invoke( - const MemberType &member, const ScalarType alpha, - const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, - const int m, const int n, const int k, - /**/ ValueType *KOKKOS_RESTRICT C) { - Kokkos::parallel_for( - Kokkos::TeamThreadRange(member, 0, m * n), [&](const int &ij) { - const int i = ij / n, j = ij % n; - - const ValueType *KOKKOS_RESTRICT pA = A + i * _as0, - *KOKKOS_RESTRICT pB = B + j * _bs1; - - ValueType c = 0; - for (int p = 0; p < k; ++p) c += pA[p * _as1] * pB[p * _bs0]; - C[i * _cs0 + j * _cs1] += alpha * c; - }); - return 0; -} -} // namespace KokkosBatched - -#endif diff --git a/batched/dense/impl/KokkosBatched_LU_Serial_Internal.hpp b/batched/dense/impl/KokkosBatched_LU_Serial_Internal.hpp index 5523f20653..5e33b8b9bd 100644 --- a/batched/dense/impl/KokkosBatched_LU_Serial_Internal.hpp +++ b/batched/dense/impl/KokkosBatched_LU_Serial_Internal.hpp @@ -7,7 +7,7 @@ #include "KokkosBatched_Vector.hpp" #include "KokkosBatched_InnerLU_Serial_Impl.hpp" #include "KokkosBatched_InnerTrsm_Serial_Impl.hpp" -#include "KokkosBatched_Gemm_Serial_Internal.hpp" +#include "KokkosBlas3_serial_gemm_internal.hpp" namespace KokkosBatched { @@ -106,7 +106,7 @@ KOKKOS_INLINE_FUNCTION int SerialLU_Internal::invoke( trsm_run.serial_invoke(Ap, pb, m_abr, Ap + mb * as0); // gemm update - SerialGemmInternal::invoke( + KokkosBlas::Impl::SerialGemmInternal::invoke( m_abr, n_abr, pb, minus_one, Ap + mb * as0, as0, as1, Ap + mb * as1, as0, as1, one, Ap + mb * as0 + mb * as1, as0, as1); } diff --git a/batched/dense/impl/KokkosBatched_LU_Team_Internal.hpp b/batched/dense/impl/KokkosBatched_LU_Team_Internal.hpp index 77b327e625..6960c40f18 100644 --- a/batched/dense/impl/KokkosBatched_LU_Team_Internal.hpp +++ b/batched/dense/impl/KokkosBatched_LU_Team_Internal.hpp @@ -9,7 +9,7 @@ #include "KokkosBatched_InnerTrsm_Serial_Impl.hpp" #include "KokkosBatched_Trsm_Team_Internal.hpp" -#include "KokkosBatched_Gemm_Team_Internal.hpp" +#include "KokkosBlas3_team_gemm_internal.hpp" namespace KokkosBatched { @@ -138,7 +138,7 @@ KOKKOS_INLINE_FUNCTION int TeamLU_Internal::invoke( member.team_barrier(); // gemm update - TeamGemmInternal::invoke( + KokkosBlas::Impl::TeamGemmInternal::invoke( member, m_abr, n_abr, pb, minus_one, Ap + mb * as0, as0, as1, Ap + mb * as1, as0, as1, one, Ap + mb * as0 + mb * as1, as0, as1); } diff --git a/batched/dense/impl/KokkosBatched_SolveUTV_TeamVector_Internal.hpp b/batched/dense/impl/KokkosBatched_SolveUTV_TeamVector_Internal.hpp index 88d0bfe561..82493b4ddb 100644 --- a/batched/dense/impl/KokkosBatched_SolveUTV_TeamVector_Internal.hpp +++ b/batched/dense/impl/KokkosBatched_SolveUTV_TeamVector_Internal.hpp @@ -5,10 +5,10 @@ #include "KokkosBatched_Util.hpp" -#include "KokkosBlas2_team_gemv_impl.hpp" +#include "KokkosBlas2_team_gemv_internal.hpp" #include "KokkosBatched_Trsv_TeamVector_Internal.hpp" -#include "KokkosBatched_Gemm_TeamVector_Internal.hpp" +#include "KokkosBlas3_team_gemm_internal.hpp" #include "KokkosBatched_Trsm_TeamVector_Internal.hpp" namespace KokkosBatched { @@ -81,7 +81,7 @@ struct TeamVectorSolveUTV_Internal { /// T is matrix_rank x matrix_rank /// V is matrix_rank x n /// W = U^T B - TeamVectorGemmInternal::invoke( + KokkosBlas::Impl::TeamVectorGemmInternal::invoke( member, matrix_rank, nrhs, m, one, U, us1, us0, B, bs0, bs1, zero, W, ws0, ws1); member.team_barrier(); @@ -92,13 +92,13 @@ struct TeamVectorSolveUTV_Internal { member.team_barrier(); /// X = V^T W - TeamVectorGemmInternal::invoke( + KokkosBlas::Impl::TeamVectorGemmInternal::invoke( member, n, nrhs, matrix_rank, one, V, vs1, vs0, W, ws0, ws1, zero, X, xs0, xs1); member.team_barrier(); } else { /// W = U^T B - TeamVectorGemmInternal::invoke( + KokkosBlas::Impl::TeamVectorGemmInternal::invoke( member, matrix_rank, nrhs, m, one, U, us1, us0, B, bs0, bs1, zero, X, xs0, xs1); member.team_barrier(); diff --git a/batched/dense/impl/KokkosBatched_Trsm_Serial_Internal.hpp b/batched/dense/impl/KokkosBatched_Trsm_Serial_Internal.hpp index b29b54931f..ba56af2c34 100644 --- a/batched/dense/impl/KokkosBatched_Trsm_Serial_Internal.hpp +++ b/batched/dense/impl/KokkosBatched_Trsm_Serial_Internal.hpp @@ -7,7 +7,7 @@ #include "KokkosBlas1_set_impl.hpp" #include "KokkosBlas1_serial_scal_impl.hpp" -#include "KokkosBatched_InnerGemmFixA_Serial_Impl.hpp" +#include "KokkosBlas3_gemm_inner_fix.hpp" #include "KokkosBatched_InnerTrsm_Serial_Impl.hpp" namespace KokkosBatched { @@ -96,7 +96,8 @@ SerialTrsmInternalLeftLower::invoke( InnerTrsmLeftLowerUnitDiag trsm_u(as0, as1, bs0, bs1); InnerTrsmLeftLowerNonUnitDiag trsm_n(as0, as1, bs0, bs1); - InnerGemmFixA gemm(as0, as1, bs0, bs1, bs0, bs1); + KokkosBlas::InnerGemmFixA gemm(as0, as1, bs0, bs1, bs0, + bs1); auto trsm = [&](const int ib, const int jb, const ValueType *KOKKOS_RESTRICT AA, /**/ ValueType *KOKKOS_RESTRICT BB) { @@ -213,7 +214,8 @@ SerialTrsmInternalLeftUpper::invoke( InnerTrsmLeftUpperUnitDiag trsm_u(as0, as1, bs0, bs1); InnerTrsmLeftUpperNonUnitDiag trsm_n(as0, as1, bs0, bs1); - InnerGemmFixA gemm(as0, as1, bs0, bs1, bs0, bs1); + KokkosBlas::InnerGemmFixA gemm(as0, as1, bs0, bs1, bs0, + bs1); auto trsm = [&](const int ib, const int jb, const ValueType *KOKKOS_RESTRICT AA, diff --git a/batched/dense/impl/KokkosBatched_Trsm_Team_Internal.hpp b/batched/dense/impl/KokkosBatched_Trsm_Team_Internal.hpp index f9e2bed8f8..8d67631ece 100644 --- a/batched/dense/impl/KokkosBatched_Trsm_Team_Internal.hpp +++ b/batched/dense/impl/KokkosBatched_Trsm_Team_Internal.hpp @@ -9,7 +9,7 @@ #include "KokkosBlas1_set_impl.hpp" #include "KokkosBlas1_team_scal_impl.hpp" #include "KokkosBatched_InnerTrsm_Serial_Impl.hpp" -#include "KokkosBatched_Gemm_Team_Internal.hpp" +#include "KokkosBlas3_team_gemm_internal.hpp" namespace KokkosBatched { @@ -135,7 +135,7 @@ TeamTrsmInternalLeftLower::invoke( member.team_barrier(); // gemm update - TeamGemmInternal::invoke( + KokkosBlas::Impl::TeamGemmInternal::invoke( member, ib - p - pb, jb, pb, minus_one, Ap + pb * as0, as0, as1, Bp, bs0, bs1, one, Bp + pb * bs0, bs0, bs1); } @@ -270,7 +270,7 @@ TeamTrsmInternalLeftUpper::invoke( member.team_barrier(); // gemm update - TeamGemmInternal::invoke( + KokkosBlas::Impl::TeamGemmInternal::invoke( member, p, jb, pb, minus_one, Ap - p * as0, as0, as1, Bp, bs0, bs1, one, BB, bs0, bs1); } diff --git a/batched/dense/impl/KokkosBatched_Trsv_Team_Internal.hpp b/batched/dense/impl/KokkosBatched_Trsv_Team_Internal.hpp index 5583b58537..6cf5aa085e 100644 --- a/batched/dense/impl/KokkosBatched_Trsv_Team_Internal.hpp +++ b/batched/dense/impl/KokkosBatched_Trsv_Team_Internal.hpp @@ -8,7 +8,7 @@ #include "KokkosBlas1_set_impl.hpp" #include "KokkosBlas1_team_scal_impl.hpp" #include "KokkosBatched_InnerTrsm_Serial_Impl.hpp" -#include "KokkosBlas2_team_gemv_spec.hpp" +#include "KokkosBlas2_gemv.hpp" namespace KokkosBatched { diff --git a/batched/dense/src/KokkosBatched_Gemm_Decl.hpp b/batched/dense/src/KokkosBatched_Gemm_Decl.hpp index 9e830c95d4..d01628027a 100644 --- a/batched/dense/src/KokkosBatched_Gemm_Decl.hpp +++ b/batched/dense/src/KokkosBatched_Gemm_Decl.hpp @@ -49,22 +49,29 @@ #include #include #include +#include namespace KokkosBatched { /********************* BEGIN functor-level routines *********************/ +// clang-format off +// Note: formatting gets mislead by [[deprecated]] attributes + /// /// Serial Gemm /// template -struct SerialGemm { +struct [[deprecated]] SerialGemm { template KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const BViewType &B, const ScalarType beta, - const CViewType &C); + const CViewType &C) { + return KokkosBlas::SerialGemm::invoke( + alpha, A, B, beta, C); + } }; /// @@ -73,12 +80,15 @@ struct SerialGemm { template -struct TeamGemm { +struct [[deprecated]] TeamGemm { template KOKKOS_INLINE_FUNCTION static int invoke( const MemberType &member, const ScalarType alpha, const AViewType &A, - const BViewType &B, const ScalarType beta, const CViewType &C); + const BViewType &B, const ScalarType beta, const CViewType &C) { + return KokkosBlas::TeamGemm::invoke( + member, alpha, A, B, beta, C); + } }; /// @@ -87,12 +97,15 @@ struct TeamGemm { template -struct TeamVectorGemm { +struct [[deprecated]] TeamVectorGemm { template KOKKOS_INLINE_FUNCTION static int invoke( const MemberType &member, const ScalarType alpha, const AViewType &A, - const BViewType &B, const ScalarType beta, const CViewType &C); + const BViewType &B, const ScalarType beta, const CViewType &C) { + return KokkosBlas::TeamVectorGemm::invoke( + member, alpha, A, B, beta, C); + } }; /// @@ -100,23 +113,16 @@ struct TeamVectorGemm { /// template -struct Gemm { +struct [[deprecated]] Gemm { template KOKKOS_FORCEINLINE_FUNCTION static int invoke( const MemberType &member, const ScalarType alpha, const AViewType &A, const BViewType &B, const ScalarType beta, const CViewType &C) { - int r_val = 0; - if (std::is_same::value) { - r_val = SerialGemm::invoke(alpha, A, B, - beta, C); - } else if (std::is_same::value) { - r_val = TeamGemm::invoke( - member, alpha, A, B, beta, C); - } - return r_val; + return KokkosBlas::Gemm(member, alpha, A, B, beta, C); } }; +// clang-format on /********************* END functor-level routines *********************/ /********************* BEGIN non-functor-level routines *********************/ @@ -647,8 +653,6 @@ int BatchedGemm(BatchedGemmHandleType *const handle, const ScalarType alpha, } // namespace KokkosBatched #include "KokkosBatched_Gemm_Serial_Impl.hpp" -#include "KokkosBatched_Gemm_Team_Impl.hpp" -#include "KokkosBatched_Gemm_TeamVector_Impl.hpp" #include "KokkosBatched_Gemm_DblBuf_Impl.hpp" #include "KokkosBatched_Gemm_Armpl_Impl.hpp" diff --git a/batched/dense/src/KokkosBatched_InnerGemmFixA_Decl.hpp b/batched/dense/src/KokkosBatched_InnerGemmFixA_Decl.hpp deleted file mode 100644 index de44fc10cc..0000000000 --- a/batched/dense/src/KokkosBatched_InnerGemmFixA_Decl.hpp +++ /dev/null @@ -1,36 +0,0 @@ -#ifndef __KOKKOSBATCHED_INNER_GEMM_FIX_A_DECL_HPP__ -#define __KOKKOSBATCHED_INNER_GEMM_FIX_A_DECL_HPP__ - -/// \author Kyungjoo Kim (kyukim@sandia.gov) - -namespace KokkosBatched { - -template -struct InnerGemmFixA { - const int _as0, _as1, _bs0, _bs1, _cs0, _cs1; - - KOKKOS_INLINE_FUNCTION - InnerGemmFixA(const int as0, const int as1, const int bs0, const int bs1, - const int cs0, const int cs1) - : _as0(as0), _as1(as1), _bs0(bs0), _bs1(bs1), _cs0(cs0), _cs1(cs1) {} - - // serial rank update - template - KOKKOS_INLINE_FUNCTION int serial_invoke(const ScalarType alpha, - const ValueType *KOKKOS_RESTRICT A, - const ValueType *KOKKOS_RESTRICT B, - const int n, - /**/ ValueType *KOKKOS_RESTRICT C); - - // serial rank update for remainder - template - KOKKOS_INLINE_FUNCTION int serial_invoke(const ScalarType alpha, - const ValueType *KOKKOS_RESTRICT A, - const ValueType *KOKKOS_RESTRICT B, - const int m, const int n, - const int k, - /**/ ValueType *KOKKOS_RESTRICT C); -}; -} // namespace KokkosBatched - -#endif diff --git a/batched/dense/src/KokkosBatched_InnerGemmFixB_Decl.hpp b/batched/dense/src/KokkosBatched_InnerGemmFixB_Decl.hpp deleted file mode 100644 index a1e4a2caf4..0000000000 --- a/batched/dense/src/KokkosBatched_InnerGemmFixB_Decl.hpp +++ /dev/null @@ -1,36 +0,0 @@ -#ifndef __KOKKOSBATCHED_INNER_GEMM_FIX_B_DECL_HPP__ -#define __KOKKOSBATCHED_INNER_GEMM_FIX_B_DECL_HPP__ - -/// \author Kyungjoo Kim (kyukim@sandia.gov) - -namespace KokkosBatched { - -template -struct InnerGemmFixB { - const int _as0, _as1, _bs0, _bs1, _cs0, _cs1; - - KOKKOS_INLINE_FUNCTION - InnerGemmFixA(const int as0, const int as1, const int bs0, const int bs1, - const int cs0, const int cs1) - : _as0(as0), _as1(as1), _bs0(bs0), _bs1(bs1), _cs0(cs0), _cs1(cs1) {} - - // serial rank update - template - KOKKOS_INLINE_FUNCTION int serial_invoke(const ScalarType alpha, - const ValueType *KOKKOS_RESTRICT A, - const ValueType *KOKKOS_RESTRICT B, - const int n, - /**/ ValueType *KOKKOS_RESTRICT C); - - // serial rank update for remainder - template - KOKKOS_INLINE_FUNCTION int serial_invoke(const ScalarType alpha, - const ValueType *KOKKOS_RESTRICT A, - const ValueType *KOKKOS_RESTRICT B, - const int m, const int n, - const int k, - /**/ ValueType *KOKKOS_RESTRICT C); -}; -} // namespace KokkosBatched - -#endif diff --git a/batched/dense/src/KokkosBatched_InnerGemmFixC_Decl.hpp b/batched/dense/src/KokkosBatched_InnerGemmFixC_Decl.hpp deleted file mode 100644 index a694609731..0000000000 --- a/batched/dense/src/KokkosBatched_InnerGemmFixC_Decl.hpp +++ /dev/null @@ -1,61 +0,0 @@ -#ifndef __KOKKOSBATCHED_INNER_GEMM_FIX_C_DECL_HPP__ -#define __KOKKOSBATCHED_INNER_GEMM_FIX_C_DECL_HPP__ - -/// \author Kyungjoo Kim (kyukim@sandia.gov) - -namespace KokkosBatched { - -template -struct InnerGemmFixC { - const int _as0, _as1, _bs0, _bs1, _cs0, _cs1; - - KOKKOS_INLINE_FUNCTION - InnerGemmFixC(const int as0, const int as1, const int bs0, const int bs1, - const int cs0, const int cs1) - : _as0(as0), _as1(as1), _bs0(bs0), _bs1(bs1), _cs0(cs0), _cs1(cs1) {} - - // serial rank update - template - KOKKOS_INLINE_FUNCTION int serial_invoke(const ScalarType alpha, - const ValueType *KOKKOS_RESTRICT A, - const ValueType *KOKKOS_RESTRICT B, - const int k, - /**/ ValueType *KOKKOS_RESTRICT C); - - // serial rank update for remainder - template - KOKKOS_INLINE_FUNCTION int serial_invoke(const ScalarType alpha, - const ValueType *KOKKOS_RESTRICT A, - const ValueType *KOKKOS_RESTRICT B, - const int m, const int k, - /**/ ValueType *KOKKOS_RESTRICT C); - - // serial rank update for remainder - template - KOKKOS_INLINE_FUNCTION int serial_invoke(const ScalarType alpha, - const ValueType *KOKKOS_RESTRICT A, - const ValueType *KOKKOS_RESTRICT B, - const int m, const int n, - const int k, - /**/ ValueType *KOKKOS_RESTRICT C); - - template - KOKKOS_INLINE_FUNCTION int team_invoke(const MemberType &member, - const ScalarType alpha, - const ValueType *KOKKOS_RESTRICT A, - const ValueType *KOKKOS_RESTRICT B, - const int k, - /**/ ValueType *KOKKOS_RESTRICT C); - - // team rank update for remainder - template - KOKKOS_INLINE_FUNCTION int team_invoke(const MemberType &member, - const ScalarType alpha, - const ValueType *KOKKOS_RESTRICT A, - const ValueType *KOKKOS_RESTRICT B, - const int m, const int n, const int k, - /**/ ValueType *KOKKOS_RESTRICT C); -}; -} // namespace KokkosBatched - -#endif diff --git a/batched/dense/unit_test/Test_Batched_Dense.hpp b/batched/dense/unit_test/Test_Batched_Dense.hpp index a154e9e14f..da88632add 100644 --- a/batched/dense/unit_test/Test_Batched_Dense.hpp +++ b/batched/dense/unit_test/Test_Batched_Dense.hpp @@ -7,9 +7,6 @@ #include "Test_Batched_SerialAxpy_Complex.hpp" #include "Test_Batched_SerialEigendecomposition.hpp" #include "Test_Batched_SerialEigendecomposition_Real.hpp" -#include "Test_Batched_SerialGemm.hpp" -#include "Test_Batched_SerialGemm_Real.hpp" -#include "Test_Batched_SerialGemm_Complex.hpp" #include "Test_Batched_BatchedGemm.hpp" #include "Test_Batched_BatchedGemm_Real.hpp" #include "Test_Batched_BatchedGemm_Complex.hpp" @@ -42,9 +39,6 @@ #include "Test_Batched_TeamAxpy.hpp" #include "Test_Batched_TeamAxpy_Real.hpp" #include "Test_Batched_TeamAxpy_Complex.hpp" -#include "Test_Batched_TeamGemm.hpp" -#include "Test_Batched_TeamGemm_Real.hpp" -#include "Test_Batched_TeamGemm_Complex.hpp" #include "Test_Batched_TeamGesv.hpp" #include "Test_Batched_TeamGesv_Real.hpp" #include "Test_Batched_TeamInverseLU.hpp" @@ -69,9 +63,6 @@ #include "Test_Batched_TeamVectorAxpy_Complex.hpp" #include "Test_Batched_TeamVectorEigendecomposition.hpp" #include "Test_Batched_TeamVectorEigendecomposition_Real.hpp" -#include "Test_Batched_TeamVectorGemm.hpp" -#include "Test_Batched_TeamVectorGemm_Real.hpp" -#include "Test_Batched_TeamVectorGemm_Complex.hpp" #include "Test_Batched_TeamVectorGesv.hpp" #include "Test_Batched_TeamVectorGesv_Real.hpp" #include "Test_Batched_TeamVectorQR.hpp" diff --git a/batched/dense/unit_test/Test_Batched_SerialGemm.hpp b/batched/dense/unit_test/Test_Batched_SerialGemm.hpp deleted file mode 100644 index a7ec3db6a9..0000000000 --- a/batched/dense/unit_test/Test_Batched_SerialGemm.hpp +++ /dev/null @@ -1,255 +0,0 @@ -/// \author Kyungjoo Kim (kyukim@sandia.gov) - -#include "gtest/gtest.h" -#include "Kokkos_Core.hpp" -#include "Kokkos_Random.hpp" - -//#include "KokkosBatched_Vector.hpp" - -#include "KokkosBatched_Gemm_Decl.hpp" -#include "KokkosBatched_Gemm_Serial_Impl.hpp" - -#include "KokkosKernels_TestUtils.hpp" - -using namespace KokkosBatched; - -namespace Test { -namespace Gemm { - -template -struct ParamTag { - typedef TA transA; - typedef TB transB; -}; - -template -struct Functor_TestBatchedSerialGemm { - ViewType _a, _b, _c; - - ScalarType _alpha, _beta; - - KOKKOS_INLINE_FUNCTION - Functor_TestBatchedSerialGemm(const ScalarType alpha, const ViewType &a, - const ViewType &b, const ScalarType beta, - const ViewType &c) - : _a(a), _b(b), _c(c), _alpha(alpha), _beta(beta) {} - - KOKKOS_INLINE_FUNCTION - void operator()(const ParamTagType &, const int k) const { - auto aa = Kokkos::subview(_a, k, Kokkos::ALL(), Kokkos::ALL()); - auto bb = Kokkos::subview(_b, k, Kokkos::ALL(), Kokkos::ALL()); - auto cc = Kokkos::subview(_c, k, Kokkos::ALL(), Kokkos::ALL()); - - SerialGemm::invoke(_alpha, aa, bb, _beta, cc); - } - - inline void run() { - typedef typename ViewType::value_type value_type; - std::string name_region("KokkosBatched::Test::SerialGemm"); - const std::string name_value_type = Test::value_type_name(); - std::string name = name_region + name_value_type; - Kokkos::Profiling::pushRegion(name.c_str()); - Kokkos::RangePolicy policy(0, _c.extent(0)); - Kokkos::parallel_for(name.c_str(), policy, *this); - Kokkos::Profiling::popRegion(); - } -}; - -template -void impl_test_batched_gemm(const int N, const int matAdim1, const int matAdim2, - const int matBdim1, const int matBdim2, - const int matCdim1, const int matCdim2) { - using execution_space = typename DeviceType::execution_space; - using transA = typename ParamTagType::transA; - using transB = typename ParamTagType::transB; - using value_type = typename ViewType::value_type; - using ats = Kokkos::Details::ArithTraits; - - /// randomized input testing views - ScalarType alpha = ScalarType(1.5); - ScalarType beta = ScalarType(3.0); - - ViewType a_expected("a_expected", N, matAdim1, matAdim2), - a_actual("a_actual", N, matAdim1, matAdim2), - b_expected("b_expected", N, matBdim1, matBdim2), - b_actual("b_actual", N, matBdim1, matBdim2), - c_expected("c_expected", N, matCdim1, matCdim2), - c_actual("c_actual", N, matCdim1, matCdim2); - - Kokkos::Random_XorShift64_Pool random(13718); - - Kokkos::fill_random(a_expected, random, value_type(1.0)); - Kokkos::fill_random(b_expected, random, value_type(1.0)); - Kokkos::fill_random(c_expected, random, value_type(1.0)); - - Kokkos::fence(); - - Kokkos::deep_copy(a_actual, a_expected); - Kokkos::deep_copy(b_actual, b_expected); - Kokkos::deep_copy(c_actual, c_expected); - - Functor_BatchedVanillaGEMM - vgemm; - vgemm.A_t = std::is_same::value; - vgemm.B_t = std::is_same::value; - vgemm.A_c = vgemm.B_c = false; - vgemm.A = a_expected; - vgemm.B = b_expected; - vgemm.C = c_expected; - vgemm.alpha = alpha; - vgemm.beta = beta; - vgemm.run(); // Compute c_expected - Functor_TestBatchedSerialGemm(alpha, a_actual, b_actual, beta, - c_actual) - .run(); - - typename ViewType::HostMirror c_expected_host = - Kokkos::create_mirror_view(c_expected); - typename ViewType::HostMirror c_actual_host = - Kokkos::create_mirror_view(c_actual); - - // Copy to host for comparison - Kokkos::deep_copy(c_expected_host, c_expected); - Kokkos::deep_copy(c_actual_host, c_actual); - - Kokkos::fence(); - - // check c_expected = c_actual - // std::conditional<, float, - using mag_type = typename ats::mag_type; - mag_type sum(1), diff(0); - - mag_type eps = ats::epsilon(); - - eps *= std::is_same::value || - std::is_same::value - ? 4 - : 1e3; - - for (int k = 0; k < N; ++k) - for (int i = 0; i < matCdim1; ++i) - for (int j = 0; j < matCdim2; ++j) { - sum += ats::abs(c_expected_host(k, i, j)); - diff += ats::abs(c_expected_host(k, i, j) - c_actual_host(k, i, j)); - } - EXPECT_NEAR_KK(diff / sum, 0, eps); -} -} // namespace Gemm -} // namespace Test - -template -int test_batched_gemm() { -#if defined(KOKKOSKERNELS_INST_LAYOUTLEFT) - { - typedef Kokkos::View - ViewType; - Test::Gemm::impl_test_batched_gemm(0, 10, 10, 10, - 10, 10, 10); - for (int i = 0; i < 10; ++i) { - // printf("Testing: LayoutLeft, Blksize %d\n", i); - Test::Gemm::impl_test_batched_gemm(1024, i, i, - i, i, i, i); - } - for (int i = 0; i < 10; ++i) { - // printf("Testing: LayoutLeft, Blksize %d\n", i); - int dimM = i; - int dimN = 2 * i; - int dimK = 3 * i; - if ((std::is_same::value) && - (std::is_same::value)) { - Test::Gemm::impl_test_batched_gemm( - 1024, dimM, dimK, dimK, dimN, dimM, dimN); - } - if ((std::is_same::value) && - (std::is_same::value)) { - Test::Gemm::impl_test_batched_gemm( - 1024, dimM, dimK, dimN, dimK, dimM, dimN); - } - if ((std::is_same::value) && - (std::is_same::value)) { - Test::Gemm::impl_test_batched_gemm( - 1024, dimK, dimM, dimK, dimN, dimM, dimN); - } - if ((std::is_same::value) && - (std::is_same::value)) { - Test::Gemm::impl_test_batched_gemm( - 1024, dimK, dimM, dimN, dimK, dimM, dimN); - } - } - } -#endif -#if defined(KOKKOSKERNELS_INST_LAYOUTRIGHT) - { - typedef Kokkos::View - ViewType; - Test::Gemm::impl_test_batched_gemm(0, 10, 10, 10, - 10, 10, 10); - for (int i = 0; i < 10; ++i) { - // printf("Testing: LayoutRight, Blksize %d\n", i); - Test::Gemm::impl_test_batched_gemm(1024, i, i, - i, i, i, i); - } - for (int i = 0; i < 10; ++i) { - // printf("Testing: LayoutLeft, Blksize %d\n", i); - int dimM = i; - int dimN = 2 * i; - int dimK = 3 * i; - if ((std::is_same::value) && - (std::is_same::value)) { - Test::Gemm::impl_test_batched_gemm( - 1024, dimM, dimK, dimK, dimN, dimM, dimN); - } - if ((std::is_same::value) && - (std::is_same::value)) { - Test::Gemm::impl_test_batched_gemm( - 1024, dimM, dimK, dimN, dimK, dimM, dimN); - } - if ((std::is_same::value) && - (std::is_same::value)) { - Test::Gemm::impl_test_batched_gemm( - 1024, dimK, dimM, dimK, dimN, dimM, dimN); - } - if ((std::is_same::value) && - (std::is_same::value)) { - Test::Gemm::impl_test_batched_gemm( - 1024, dimK, dimM, dimN, dimK, dimM, dimN); - } - } - } -#endif - - return 0; -} diff --git a/batched/dense/unit_test/Test_Batched_SerialGemm_Complex.hpp b/batched/dense/unit_test/Test_Batched_SerialGemm_Complex.hpp deleted file mode 100644 index 225b043f71..0000000000 --- a/batched/dense/unit_test/Test_Batched_SerialGemm_Complex.hpp +++ /dev/null @@ -1,85 +0,0 @@ -#if defined(KOKKOSKERNELS_INST_COMPLEX_DOUBLE) - -/// dcomplex, dcomplex - -TEST_F(TestCategory, batched_scalar_serial_gemm_nt_nt_dcomplex_dcomplex) { - typedef ::Test::Gemm::ParamTag - param_tag_type; - typedef Algo::Gemm::Blocked algo_tag_type; - test_batched_gemm, - Kokkos::complex, param_tag_type, algo_tag_type>(); -} -TEST_F(TestCategory, batched_scalar_serial_gemm_t_nt_dcomplex_dcomplex) { - typedef ::Test::Gemm::ParamTag - param_tag_type; - typedef Algo::Gemm::Blocked algo_tag_type; - test_batched_gemm, - Kokkos::complex, param_tag_type, algo_tag_type>(); -} -TEST_F(TestCategory, batched_scalar_serial_gemm_nt_t_dcomplex_dcomplex) { - typedef ::Test::Gemm::ParamTag - param_tag_type; - typedef Algo::Gemm::Blocked algo_tag_type; - test_batched_gemm, - Kokkos::complex, param_tag_type, algo_tag_type>(); -} -TEST_F(TestCategory, batched_scalar_serial_gemm_t_t_dcomplex_dcomplex) { - typedef ::Test::Gemm::ParamTag - param_tag_type; - typedef Algo::Gemm::Blocked algo_tag_type; - test_batched_gemm, - Kokkos::complex, param_tag_type, algo_tag_type>(); -} -// TEST_F( TestCategory, batched_scalar_serial_gemm_ct_nt_dcomplex_dcomplex ) { -// typedef ::Test::Gemm::ParamTag -// param_tag_type; typedef Algo::Gemm::Blocked algo_tag_type; -// test_batched_gemm,Kokkos::complex,param_tag_type,algo_tag_type>(); -// } -// TEST_F( TestCategory, batched_scalar_serial_gemm_nt_ct_dcomplex_dcomplex ) { -// typedef ::Test::Gemm::ParamTag -// param_tag_type; typedef Algo::Gemm::Blocked algo_tag_type; -// test_batched_gemm,Kokkos::complex,param_tag_type,algo_tag_type>(); -// } - -/// dcomplex, double - -TEST_F(TestCategory, batched_scalar_serial_gemm_nt_nt_dcomplex_double) { - typedef ::Test::Gemm::ParamTag - param_tag_type; - typedef Algo::Gemm::Blocked algo_tag_type; - test_batched_gemm, double, - param_tag_type, algo_tag_type>(); -} -TEST_F(TestCategory, batched_scalar_serial_gemm_t_nt_dcomplex_double) { - typedef ::Test::Gemm::ParamTag - param_tag_type; - typedef Algo::Gemm::Blocked algo_tag_type; - test_batched_gemm, double, - param_tag_type, algo_tag_type>(); -} -TEST_F(TestCategory, batched_scalar_serial_gemm_nt_t_dcomplex_double) { - typedef ::Test::Gemm::ParamTag - param_tag_type; - typedef Algo::Gemm::Blocked algo_tag_type; - test_batched_gemm, double, - param_tag_type, algo_tag_type>(); -} -TEST_F(TestCategory, batched_scalar_serial_gemm_t_t_dcomplex_double) { - typedef ::Test::Gemm::ParamTag - param_tag_type; - typedef Algo::Gemm::Blocked algo_tag_type; - test_batched_gemm, double, - param_tag_type, algo_tag_type>(); -} -// TEST_F( TestCategory, batched_scalar_serial_gemm_ct_nt_dcomplex_double ) { -// typedef ::Test::Gemm::ParamTag -// param_tag_type; typedef Algo::Gemm::Blocked algo_tag_type; -// test_batched_gemm,double,param_tag_type,algo_tag_type>(); -// } -// TEST_F( TestCategory, batched_scalar_serial_gemm_nt_ct_dcomplex_double ) { -// typedef ::Test::Gemm::ParamTag -// param_tag_type; typedef Algo::Gemm::Blocked algo_tag_type; -// test_batched_gemm,double,param_tag_type,algo_tag_type>(); -// } - -#endif diff --git a/batched/dense/unit_test/Test_Batched_SerialGemm_Real.hpp b/batched/dense/unit_test/Test_Batched_SerialGemm_Real.hpp deleted file mode 100644 index c10e6c0b78..0000000000 --- a/batched/dense/unit_test/Test_Batched_SerialGemm_Real.hpp +++ /dev/null @@ -1,155 +0,0 @@ -#if defined(KOKKOS_BHALF_T_IS_FLOAT) -TEST_F(TestCategory, batched_scalar_serial_gemm_nt_nt_bhalf_bhalf) { - typedef ::Test::Gemm::ParamTag - param_tag_type; - - test_batched_gemm(); - test_batched_gemm(); -} -TEST_F(TestCategory, batched_scalar_serial_gemm_t_nt_bhalf_bhalf) { - typedef ::Test::Gemm::ParamTag - param_tag_type; - - test_batched_gemm(); - test_batched_gemm(); -} -TEST_F(TestCategory, batched_scalar_serial_gemm_nt_t_bhalf_bhalf) { - typedef ::Test::Gemm::ParamTag - param_tag_type; - - test_batched_gemm(); - test_batched_gemm(); -} -TEST_F(TestCategory, batched_scalar_serial_gemm_t_t_bhalf_bhalf) { - typedef ::Test::Gemm::ParamTag - param_tag_type; - - test_batched_gemm(); - test_batched_gemm(); -} -#endif // KOKKOS_BHALF_T_IS_FLOAT - -#if defined(KOKKOS_HALF_T_IS_FLOAT) -TEST_F(TestCategory, batched_scalar_serial_gemm_nt_nt_half_half) { - typedef ::Test::Gemm::ParamTag - param_tag_type; - - test_batched_gemm(); - test_batched_gemm(); -} -TEST_F(TestCategory, batched_scalar_serial_gemm_t_nt_half_half) { - typedef ::Test::Gemm::ParamTag - param_tag_type; - - test_batched_gemm(); - test_batched_gemm(); -} -TEST_F(TestCategory, batched_scalar_serial_gemm_nt_t_half_half) { - typedef ::Test::Gemm::ParamTag - param_tag_type; - - test_batched_gemm(); - test_batched_gemm(); -} -TEST_F(TestCategory, batched_scalar_serial_gemm_t_t_half_half) { - typedef ::Test::Gemm::ParamTag - param_tag_type; - - test_batched_gemm(); - test_batched_gemm(); -} -#endif // KOKKOS_HALF_T_IS_FLOAT - -#if defined(KOKKOSKERNELS_INST_FLOAT) -TEST_F(TestCategory, batched_scalar_serial_gemm_nt_nt_float_float) { - typedef ::Test::Gemm::ParamTag - param_tag_type; - typedef Algo::Gemm::Blocked algo_tag_type; - test_batched_gemm(); -} -TEST_F(TestCategory, batched_scalar_serial_gemm_t_nt_float_float) { - typedef ::Test::Gemm::ParamTag - param_tag_type; - typedef Algo::Gemm::Blocked algo_tag_type; - test_batched_gemm(); -} -TEST_F(TestCategory, batched_scalar_serial_gemm_nt_t_float_float) { - typedef ::Test::Gemm::ParamTag - param_tag_type; - typedef Algo::Gemm::Blocked algo_tag_type; - test_batched_gemm(); -} -TEST_F(TestCategory, batched_scalar_serial_gemm_t_t_float_float) { - typedef ::Test::Gemm::ParamTag - param_tag_type; - typedef Algo::Gemm::Blocked algo_tag_type; - test_batched_gemm(); -} -#endif - -#if defined(KOKKOSKERNELS_INST_DOUBLE) -TEST_F(TestCategory, batched_scalar_serial_gemm_nt_nt_double_double) { - typedef ::Test::Gemm::ParamTag - param_tag_type; - typedef Algo::Gemm::Blocked algo_tag_type; - test_batched_gemm(); -} -TEST_F(TestCategory, batched_scalar_serial_gemm_t_nt_double_double) { - typedef ::Test::Gemm::ParamTag - param_tag_type; - typedef Algo::Gemm::Blocked algo_tag_type; - test_batched_gemm(); -} -TEST_F(TestCategory, batched_scalar_serial_gemm_nt_t_double_double) { - typedef ::Test::Gemm::ParamTag - param_tag_type; - typedef Algo::Gemm::Blocked algo_tag_type; - test_batched_gemm(); -} -TEST_F(TestCategory, batched_scalar_serial_gemm_t_t_double_double) { - typedef ::Test::Gemm::ParamTag - param_tag_type; - typedef Algo::Gemm::Blocked algo_tag_type; - test_batched_gemm(); -} -#endif diff --git a/batched/dense/unit_test/Test_Batched_SerialGesv.hpp b/batched/dense/unit_test/Test_Batched_SerialGesv.hpp index 7f9723a4b9..019ee32922 100644 --- a/batched/dense/unit_test/Test_Batched_SerialGesv.hpp +++ b/batched/dense/unit_test/Test_Batched_SerialGesv.hpp @@ -6,7 +6,7 @@ #include "KokkosBatched_Gesv.hpp" #include "KokkosBatched_Dot.hpp" -#include "KokkosBlas2_serial_gemv_impl.hpp" +#include "KokkosBlas2_gemv.hpp" #include "KokkosKernels_TestUtils.hpp" diff --git a/batched/dense/unit_test/Test_Batched_SerialInverseLU.hpp b/batched/dense/unit_test/Test_Batched_SerialInverseLU.hpp index fd7d0478fc..0413e0a452 100644 --- a/batched/dense/unit_test/Test_Batched_SerialInverseLU.hpp +++ b/batched/dense/unit_test/Test_Batched_SerialInverseLU.hpp @@ -6,8 +6,7 @@ //#include "KokkosBatched_Vector.hpp" -#include "KokkosBatched_Gemm_Decl.hpp" -#include "KokkosBatched_Gemm_Serial_Impl.hpp" +#include "KokkosBlas3_gemm.hpp" #include "KokkosBatched_LU_Decl.hpp" #include "KokkosBatched_LU_Serial_Impl.hpp" #include "KokkosBatched_InverseLU_Decl.hpp" @@ -47,8 +46,9 @@ struct Functor_BatchedSerialGemm { for (int i = 0; i < static_cast(aa.extent(0)); ++i) aa(i, i) += 10.0; - SerialGemm::invoke(_alpha, aa, bb, _beta, cc); + KokkosBlas::SerialGemm::invoke(_alpha, aa, bb, _beta, cc); } inline void run() { diff --git a/batched/dense/unit_test/Test_Batched_SerialSolveLU.hpp b/batched/dense/unit_test/Test_Batched_SerialSolveLU.hpp index b6d8e1aecf..323d5be709 100644 --- a/batched/dense/unit_test/Test_Batched_SerialSolveLU.hpp +++ b/batched/dense/unit_test/Test_Batched_SerialSolveLU.hpp @@ -6,8 +6,7 @@ //#include "KokkosBatched_Vector.hpp" -#include "KokkosBatched_Gemm_Decl.hpp" -#include "KokkosBatched_Gemm_Serial_Impl.hpp" +#include "KokkosBlas3_gemm.hpp" #include "KokkosBatched_LU_Decl.hpp" #include "KokkosBatched_LU_Serial_Impl.hpp" #include "KokkosBatched_SolveLU_Decl.hpp" @@ -47,8 +46,9 @@ struct Functor_BatchedSerialGemm { for (int i = 0; i < static_cast(aa.extent(0)); ++i) aa(i, i) += 10.0; - SerialGemm::invoke(_alpha, aa, bb, _beta, cc); + KokkosBlas::SerialGemm::invoke(_alpha, aa, bb, _beta, cc); } inline void run() { diff --git a/batched/dense/unit_test/Test_Batched_TeamGemm.hpp b/batched/dense/unit_test/Test_Batched_TeamGemm.hpp deleted file mode 100644 index d5aa853482..0000000000 --- a/batched/dense/unit_test/Test_Batched_TeamGemm.hpp +++ /dev/null @@ -1,264 +0,0 @@ -/// \author Kyungjoo Kim (kyukim@sandia.gov) - -#include "gtest/gtest.h" -#include "Kokkos_Core.hpp" -#include "Kokkos_Random.hpp" - -//#include "KokkosBatched_Vector.hpp" - -#include "KokkosBatched_Gemm_Decl.hpp" -#include "KokkosBatched_Gemm_Serial_Impl.hpp" -#include "KokkosBatched_Gemm_Team_Impl.hpp" - -#include "KokkosKernels_TestUtils.hpp" - -using namespace KokkosBatched; - -namespace Test { -namespace TeamGemm { - -template -struct ParamTag { - typedef TA transA; - typedef TB transB; -}; - -template -struct Functor_TestBatchedTeamGemm { - ViewType _a, _b, _c; - - ScalarType _alpha, _beta; - - KOKKOS_INLINE_FUNCTION - Functor_TestBatchedTeamGemm(const ScalarType alpha, const ViewType &a, - const ViewType &b, const ScalarType beta, - const ViewType &c) - : _a(a), _b(b), _c(c), _alpha(alpha), _beta(beta) {} - - template - KOKKOS_INLINE_FUNCTION void operator()(const ParamTagType &, - const MemberType &member) const { - const int k = member.league_rank(); - - auto aa = Kokkos::subview(_a, k, Kokkos::ALL(), Kokkos::ALL()); - auto bb = Kokkos::subview(_b, k, Kokkos::ALL(), Kokkos::ALL()); - auto cc = Kokkos::subview(_c, k, Kokkos::ALL(), Kokkos::ALL()); - - KokkosBatched::TeamGemm::invoke(member, _alpha, aa, bb, _beta, - cc); - } - - inline void run() { - typedef typename ViewType::value_type value_type; - std::string name_region("KokkosBatched::Test::TeamGemm"); - const std::string name_value_type = Test::value_type_name(); - std::string name = name_region + name_value_type; - Kokkos::Profiling::pushRegion(name.c_str()); - const int league_size = _c.extent(0); - Kokkos::TeamPolicy policy(league_size, - Kokkos::AUTO); - Kokkos::parallel_for(name.c_str(), policy, *this); - Kokkos::Profiling::popRegion(); - } -}; - -template -void impl_test_batched_teamgemm(const int N, const int matAdim1, - const int matAdim2, const int matBdim1, - const int matBdim2, const int matCdim1, - const int matCdim2) { - using transA = typename ParamTagType::transA; - using transB = typename ParamTagType::transB; - using execution_space = typename DeviceType::execution_space; - using value_type = typename ViewType::value_type; - using ats = Kokkos::Details::ArithTraits; - - /// randomized input testing views - ScalarType alpha = ScalarType(1.5), beta = ScalarType(3.0); - - ViewType a_expected("a_expected", N, matAdim1, matAdim2), - a_actual("a_actual", N, matAdim1, matAdim2), - b_expected("b_expected", N, matBdim1, matBdim2), - b_actual("b_actual", N, matBdim1, matBdim2), - c_expected("c_expected", N, matCdim1, matCdim2), - c_actual("c_actual", N, matCdim1, matCdim2); - - Kokkos::Random_XorShift64_Pool random( - 13718); - - Kokkos::fill_random(a_expected, random, value_type(1.0)); - Kokkos::fill_random(b_expected, random, value_type(1.0)); - Kokkos::fill_random(c_expected, random, value_type(1.0)); - - Kokkos::fence(); - - Kokkos::deep_copy(a_actual, a_expected); - Kokkos::deep_copy(b_actual, b_expected); - Kokkos::deep_copy(c_actual, c_expected); - - Functor_BatchedVanillaGEMM - vgemm; - vgemm.A_t = std::is_same::value; - vgemm.B_t = std::is_same::value; - vgemm.A_c = vgemm.B_c = false; - vgemm.A = a_expected; - vgemm.B = b_expected; - vgemm.C = c_expected; - vgemm.alpha = alpha; - vgemm.beta = beta; - vgemm.run(); // Compute c_expected - - Functor_TestBatchedTeamGemm(alpha, a_actual, b_actual, beta, - c_actual) - .run(); - - Kokkos::fence(); - - typename ViewType::HostMirror c_expected_host = - Kokkos::create_mirror_view(c_expected); - typename ViewType::HostMirror c_actual_host = - Kokkos::create_mirror_view(c_actual); - - // Copy to host for comparision - Kokkos::deep_copy(c_expected_host, c_expected); - Kokkos::deep_copy(c_actual_host, c_actual); - - using mag_type = typename ats::mag_type; - mag_type sum(1), diff(0); - mag_type eps = ats::epsilon(); - - eps *= std::is_same::value || - std::is_same::value - ? 4 - : 1e3; - - for (int k = 0; k < N; ++k) - for (int i = 0; i < matCdim1; ++i) - for (int j = 0; j < matCdim2; ++j) { - sum += ats::abs(c_expected_host(k, i, j)); - diff += ats::abs(c_expected_host(k, i, j) - c_actual_host(k, i, j)); - } - EXPECT_NEAR_KK(diff / sum, 0, eps); -} -} // namespace TeamGemm -} // namespace Test - -// void (*impl_test)(const int, const int, const int, const int, const int, -// const int, const int) -template -int test_batched_teamgemm() { -#if defined(KOKKOSKERNELS_INST_LAYOUTLEFT) - { - typedef Kokkos::View - ViewType; - Test::TeamGemm::impl_test_batched_teamgemm( - 0, 10, 10, 10, 10, 10, 10); - for (int i = 0; i < 10; ++i) { - // printf("Testing: LayoutLeft, Blksize %d\n", i); - Test::TeamGemm::impl_test_batched_teamgemm< - DeviceType, ViewType, ScalarType, ParamTagType, AlgoTagType>( - 1024, i, i, i, i, i, i); - } - for (int i = 0; i < 10; ++i) { - // printf("Testing: LayoutLeft, Blksize %d\n", i); - int dimM = i; - int dimN = 2 * i; - int dimK = 3 * i; - if ((std::is_same::value) && - (std::is_same::value)) { - Test::TeamGemm::impl_test_batched_teamgemm< - DeviceType, ViewType, ScalarType, ParamTagType, AlgoTagType>( - 1024, dimM, dimK, dimK, dimN, dimM, dimN); - } - if ((std::is_same::value) && - (std::is_same::value)) { - Test::TeamGemm::impl_test_batched_teamgemm< - DeviceType, ViewType, ScalarType, ParamTagType, AlgoTagType>( - 1024, dimM, dimK, dimN, dimK, dimM, dimN); - } - if ((std::is_same::value) && - (std::is_same::value)) { - Test::TeamGemm::impl_test_batched_teamgemm< - DeviceType, ViewType, ScalarType, ParamTagType, AlgoTagType>( - 1024, dimK, dimM, dimK, dimN, dimM, dimN); - } - if ((std::is_same::value) && - (std::is_same::value)) { - Test::TeamGemm::impl_test_batched_teamgemm< - DeviceType, ViewType, ScalarType, ParamTagType, AlgoTagType>( - 1024, dimK, dimM, dimN, dimK, dimM, dimN); - } - } - } -#endif -#if defined(KOKKOSKERNELS_INST_LAYOUTRIGHT) - { - typedef Kokkos::View - ViewType; - Test::TeamGemm::impl_test_batched_teamgemm( - 0, 10, 10, 10, 10, 10, 10); - for (int i = 0; i < 10; ++i) { - // printf("Testing: LayoutRight, Blksize %d\n", i); - Test::TeamGemm::impl_test_batched_teamgemm< - DeviceType, ViewType, ScalarType, ParamTagType, AlgoTagType>( - 1024, i, i, i, i, i, i); - } - for (int i = 0; i < 10; ++i) { - // printf("Testing: LayoutLeft, Blksize %d\n", i); - int dimM = i; - int dimN = 2 * i; - int dimK = 3 * i; - if ((std::is_same::value) && - (std::is_same::value)) { - Test::TeamGemm::impl_test_batched_teamgemm< - DeviceType, ViewType, ScalarType, ParamTagType, AlgoTagType>( - 1024, dimM, dimK, dimK, dimN, dimM, dimN); - } - if ((std::is_same::value) && - (std::is_same::value)) { - Test::TeamGemm::impl_test_batched_teamgemm< - DeviceType, ViewType, ScalarType, ParamTagType, AlgoTagType>( - 1024, dimM, dimK, dimN, dimK, dimM, dimN); - } - if ((std::is_same::value) && - (std::is_same::value)) { - Test::TeamGemm::impl_test_batched_teamgemm< - DeviceType, ViewType, ScalarType, ParamTagType, AlgoTagType>( - 1024, dimK, dimM, dimK, dimN, dimM, dimN); - } - if ((std::is_same::value) && - (std::is_same::value)) { - Test::TeamGemm::impl_test_batched_teamgemm< - DeviceType, ViewType, ScalarType, ParamTagType, AlgoTagType>( - 1024, dimK, dimM, dimN, dimK, dimM, dimN); - } - } - } -#endif - - return 0; -} diff --git a/batched/dense/unit_test/Test_Batched_TeamGemm_Complex.hpp b/batched/dense/unit_test/Test_Batched_TeamGemm_Complex.hpp deleted file mode 100644 index 92852f45af..0000000000 --- a/batched/dense/unit_test/Test_Batched_TeamGemm_Complex.hpp +++ /dev/null @@ -1,90 +0,0 @@ - -#if defined(KOKKOSKERNELS_INST_COMPLEX_DOUBLE) - -/// dcomplex, dcomplex - -TEST_F(TestCategory, batched_scalar_team_gemm_nt_nt_dcomplex_dcomplex) { - typedef ::Test::TeamGemm::ParamTag - param_tag_type; - typedef Algo::Gemm::Blocked algo_tag_type; - test_batched_teamgemm, - Kokkos::complex, param_tag_type, - algo_tag_type>(); -} -TEST_F(TestCategory, batched_scalar_team_gemm_t_nt_dcomplex_dcomplex) { - typedef ::Test::TeamGemm::ParamTag - param_tag_type; - typedef Algo::Gemm::Blocked algo_tag_type; - test_batched_teamgemm, - Kokkos::complex, param_tag_type, - algo_tag_type>(); -} -TEST_F(TestCategory, batched_scalar_team_gemm_nt_t_dcomplex_dcomplex) { - typedef ::Test::TeamGemm::ParamTag - param_tag_type; - typedef Algo::Gemm::Blocked algo_tag_type; - test_batched_teamgemm, - Kokkos::complex, param_tag_type, - algo_tag_type>(); -} -TEST_F(TestCategory, batched_scalar_team_gemm_t_t_dcomplex_dcomplex) { - typedef ::Test::TeamGemm::ParamTag - param_tag_type; - typedef Algo::Gemm::Blocked algo_tag_type; - test_batched_teamgemm, - Kokkos::complex, param_tag_type, - algo_tag_type>(); -} -// TEST_F( TestCategory, batched_scalar_team_gemm_ct_nt_dcomplex_dcomplex ) { -// typedef ::Test::TeamGemm::ParamTag -// param_tag_type; typedef Algo::Gemm::Blocked algo_tag_type; -// test_batched_teamgemm,Kokkos::complex,param_tag_type,algo_tag_type>(); -// } -// TEST_F( TestCategory, batched_scalar_team_gemm_nt_ct_dcomplex_dcomplex ) { -// typedef ::Test::TeamGemm::ParamTag -// param_tag_type; typedef Algo::Gemm::Blocked algo_tag_type; -// test_batched_teamgemm,Kokkos::complex,param_tag_type,algo_tag_type>(); -// } - -/// dcomplex, double - -TEST_F(TestCategory, batched_scalar_team_gemm_nt_nt_dcomplex_double) { - typedef ::Test::TeamGemm::ParamTag - param_tag_type; - typedef Algo::Gemm::Blocked algo_tag_type; - test_batched_teamgemm, double, - param_tag_type, algo_tag_type>(); -} -TEST_F(TestCategory, batched_scalar_team_gemm_t_nt_dcomplex_double) { - typedef ::Test::TeamGemm::ParamTag - param_tag_type; - typedef Algo::Gemm::Blocked algo_tag_type; - test_batched_teamgemm, double, - param_tag_type, algo_tag_type>(); -} -TEST_F(TestCategory, batched_scalar_team_gemm_nt_t_dcomplex_double) { - typedef ::Test::TeamGemm::ParamTag - param_tag_type; - typedef Algo::Gemm::Blocked algo_tag_type; - test_batched_teamgemm, double, - param_tag_type, algo_tag_type>(); -} -TEST_F(TestCategory, batched_scalar_team_gemm_t_t_dcomplex_double) { - typedef ::Test::TeamGemm::ParamTag - param_tag_type; - typedef Algo::Gemm::Blocked algo_tag_type; - test_batched_teamgemm, double, - param_tag_type, algo_tag_type>(); -} -// TEST_F( TestCategory, batched_scalar_team_gemm_ct_nt_dcomplex_double ) { -// typedef ::Test::TeamGemm::ParamTag -// param_tag_type; typedef Algo::Gemm::Blocked algo_tag_type; -// test_batched_teamgemm,double,param_tag_type,algo_tag_type>(); -// } -// TEST_F( TestCategory, batched_scalar_team_gemm_nt_ct_dcomplex_double ) { -// typedef ::Test::TeamGemm::ParamTag -// param_tag_type; typedef Algo::Gemm::Blocked algo_tag_type; -// test_batched_teamgemm,double,param_tag_type,algo_tag_type>(); -// } - -#endif diff --git a/batched/dense/unit_test/Test_Batched_TeamGemm_Real.hpp b/batched/dense/unit_test/Test_Batched_TeamGemm_Real.hpp deleted file mode 100644 index 361675ed9c..0000000000 --- a/batched/dense/unit_test/Test_Batched_TeamGemm_Real.hpp +++ /dev/null @@ -1,155 +0,0 @@ -#if defined(KOKKOS_BHALF_T_IS_FLOAT) -TEST_F(TestCategory, batched_scalar_team_gemm_nt_nt_bhalf_bhalf) { - typedef ::Test::TeamGemm::ParamTag - param_tag_type; - - test_batched_teamgemm(); - test_batched_teamgemm(); -} -TEST_F(TestCategory, batched_scalar_team_gemm_t_nt_bhalf_bhalf) { - typedef ::Test::TeamGemm::ParamTag - param_tag_type; - - test_batched_teamgemm(); - test_batched_teamgemm(); -} -TEST_F(TestCategory, batched_scalar_team_gemm_nt_t_bhalf_bhalf) { - typedef ::Test::TeamGemm::ParamTag - param_tag_type; - - test_batched_teamgemm(); - test_batched_teamgemm(); -} -TEST_F(TestCategory, batched_scalar_team_gemm_t_t_bhalf_bhalf) { - typedef ::Test::TeamGemm::ParamTag - param_tag_type; - - test_batched_teamgemm(); - test_batched_teamgemm(); -} -#endif // KOKKOS_BHALF_T_IS_FLOAT - -#if defined(KOKKOS_HALF_T_IS_FLOAT) -TEST_F(TestCategory, batched_scalar_team_gemm_nt_nt_half_half) { - typedef ::Test::TeamGemm::ParamTag - param_tag_type; - - test_batched_teamgemm(); - test_batched_teamgemm(); -} -TEST_F(TestCategory, batched_scalar_team_gemm_t_nt_half_half) { - typedef ::Test::TeamGemm::ParamTag - param_tag_type; - - test_batched_teamgemm(); - test_batched_teamgemm(); -} -TEST_F(TestCategory, batched_scalar_team_gemm_nt_t_half_half) { - typedef ::Test::TeamGemm::ParamTag - param_tag_type; - - test_batched_teamgemm(); - test_batched_teamgemm(); -} -TEST_F(TestCategory, batched_scalar_team_gemm_t_t_half_half) { - typedef ::Test::TeamGemm::ParamTag - param_tag_type; - - test_batched_teamgemm(); - test_batched_teamgemm(); -} -#endif // KOKKOS_HALF_T_IS_FLOAT - -#if defined(KOKKOSKERNELS_INST_FLOAT) -TEST_F(TestCategory, batched_scalar_team_gemm_nt_nt_float_float) { - typedef ::Test::TeamGemm::ParamTag - param_tag_type; - typedef Algo::Gemm::Blocked algo_tag_type; - test_batched_teamgemm(); -} -TEST_F(TestCategory, batched_scalar_team_gemm_t_nt_float_float) { - typedef ::Test::TeamGemm::ParamTag - param_tag_type; - typedef Algo::Gemm::Blocked algo_tag_type; - test_batched_teamgemm(); -} -TEST_F(TestCategory, batched_scalar_team_gemm_nt_t_float_float) { - typedef ::Test::TeamGemm::ParamTag - param_tag_type; - typedef Algo::Gemm::Blocked algo_tag_type; - test_batched_teamgemm(); -} -TEST_F(TestCategory, batched_scalar_team_gemm_t_t_float_float) { - typedef ::Test::TeamGemm::ParamTag - param_tag_type; - typedef Algo::Gemm::Blocked algo_tag_type; - test_batched_teamgemm(); -} -#endif - -#if defined(KOKKOSKERNELS_INST_DOUBLE) -TEST_F(TestCategory, batched_scalar_team_gemm_nt_nt_double_double) { - typedef ::Test::TeamGemm::ParamTag - param_tag_type; - typedef Algo::Gemm::Blocked algo_tag_type; - test_batched_teamgemm(); -} -TEST_F(TestCategory, batched_scalar_team_gemm_t_nt_double_double) { - typedef ::Test::TeamGemm::ParamTag - param_tag_type; - typedef Algo::Gemm::Blocked algo_tag_type; - test_batched_teamgemm(); -} -TEST_F(TestCategory, batched_scalar_team_gemm_nt_t_double_double) { - typedef ::Test::TeamGemm::ParamTag - param_tag_type; - typedef Algo::Gemm::Blocked algo_tag_type; - test_batched_teamgemm(); -} -TEST_F(TestCategory, batched_scalar_team_gemm_t_t_double_double) { - typedef ::Test::TeamGemm::ParamTag - param_tag_type; - typedef Algo::Gemm::Blocked algo_tag_type; - test_batched_teamgemm(); -} -#endif diff --git a/batched/dense/unit_test/Test_Batched_TeamGesv.hpp b/batched/dense/unit_test/Test_Batched_TeamGesv.hpp index a7acfdcf9b..4947d31b37 100644 --- a/batched/dense/unit_test/Test_Batched_TeamGesv.hpp +++ b/batched/dense/unit_test/Test_Batched_TeamGesv.hpp @@ -6,7 +6,7 @@ #include "KokkosBatched_Gesv.hpp" #include "KokkosBatched_Dot.hpp" -#include "KokkosBlas2_serial_gemv_impl.hpp" +#include "KokkosBlas2_gemv.hpp" #include "KokkosKernels_TestUtils.hpp" diff --git a/batched/dense/unit_test/Test_Batched_TeamInverseLU.hpp b/batched/dense/unit_test/Test_Batched_TeamInverseLU.hpp index 4db8a69155..f5f991effc 100644 --- a/batched/dense/unit_test/Test_Batched_TeamInverseLU.hpp +++ b/batched/dense/unit_test/Test_Batched_TeamInverseLU.hpp @@ -6,8 +6,7 @@ //#include "KokkosBatched_Vector.hpp" -#include "KokkosBatched_Gemm_Decl.hpp" -#include "KokkosBatched_Gemm_Team_Impl.hpp" +#include "KokkosBlas3_gemm.hpp" #include "KokkosBatched_LU_Decl.hpp" #include "KokkosBatched_LU_Team_Impl.hpp" #include "KokkosBatched_InverseLU_Decl.hpp" @@ -53,10 +52,10 @@ struct Functor_BatchedTeamGemm { } member.team_barrier(); - KokkosBatched::TeamGemm::invoke(member, _alpha, aa, bb, _beta, - cc); + KokkosBlas::TeamGemm::invoke(member, _alpha, aa, bb, _beta, + cc); } inline void run() { diff --git a/batched/dense/unit_test/Test_Batched_TeamSolveLU.hpp b/batched/dense/unit_test/Test_Batched_TeamSolveLU.hpp index 201cc025fc..ee2a0b703d 100644 --- a/batched/dense/unit_test/Test_Batched_TeamSolveLU.hpp +++ b/batched/dense/unit_test/Test_Batched_TeamSolveLU.hpp @@ -6,8 +6,7 @@ //#include "KokkosBatched_Vector.hpp" -#include "KokkosBatched_Gemm_Decl.hpp" -#include "KokkosBatched_Gemm_Team_Impl.hpp" +#include "KokkosBlas3_gemm.hpp" #include "KokkosBatched_LU_Decl.hpp" #include "KokkosBatched_LU_Team_Impl.hpp" #include "KokkosBatched_SolveLU_Decl.hpp" @@ -53,10 +52,10 @@ struct Functor_BatchedTeamGemm { } member.team_barrier(); - KokkosBatched::TeamGemm::invoke(member, _alpha, aa, bb, _beta, - cc); + KokkosBlas::TeamGemm::invoke(member, _alpha, aa, bb, _beta, + cc); } inline void run() { diff --git a/batched/dense/unit_test/Test_Batched_TeamVectorGemm.hpp b/batched/dense/unit_test/Test_Batched_TeamVectorGemm.hpp deleted file mode 100644 index 8d10440bc2..0000000000 --- a/batched/dense/unit_test/Test_Batched_TeamVectorGemm.hpp +++ /dev/null @@ -1,263 +0,0 @@ -#include "gtest/gtest.h" -#include "Kokkos_Core.hpp" -#include "Kokkos_Random.hpp" - -#include "KokkosBatched_Gemm_Decl.hpp" -#include "KokkosBatched_Gemm_TeamVector_Impl.hpp" - -#include "KokkosKernels_TestUtils.hpp" - -using namespace KokkosBatched; - -namespace Test { -namespace TeamVectorGemm { - -template -struct ParamTag { - typedef TA transA; - typedef TB transB; -}; - -template -struct Functor_TestBatchedTeamVector { - ViewType _a, _b, _c; - - ScalarType _alpha, _beta; - - KOKKOS_INLINE_FUNCTION - Functor_TestBatchedTeamVector(const ScalarType alpha, const ViewType &a, - const ViewType &b, const ScalarType beta, - const ViewType &c) - : _a(a), _b(b), _c(c), _alpha(alpha), _beta(beta) {} - - template - KOKKOS_INLINE_FUNCTION void operator()(const ParamTagType &, - const MemberType &member) const { - const int k = member.league_rank(); - - auto aa = Kokkos::subview(_a, k, Kokkos::ALL(), Kokkos::ALL()); - auto bb = Kokkos::subview(_b, k, Kokkos::ALL(), Kokkos::ALL()); - auto cc = Kokkos::subview(_c, k, Kokkos::ALL(), Kokkos::ALL()); - - KokkosBatched::TeamVectorGemm::invoke(member, _alpha, aa, bb, - _beta, cc); - } - - inline void run() { - typedef typename ViewType::value_type value_type; - std::string name_region("KokkosBatched::Test::TeamVector"); - const std::string name_value_type = Test::value_type_name(); - std::string name = name_region + name_value_type; - Kokkos::Profiling::pushRegion(name.c_str()); - const int league_size = _c.extent(0); - Kokkos::TeamPolicy policy(league_size, - Kokkos::AUTO); - Kokkos::parallel_for(name.c_str(), policy, *this); - Kokkos::Profiling::popRegion(); - } -}; - -template -void impl_test_batched_teamvectorgemm(const int N, const int matAdim1, - const int matAdim2, const int matBdim1, - const int matBdim2, const int matCdim1, - const int matCdim2) { - using transA = typename ParamTagType::transA; - using transB = typename ParamTagType::transB; - using execution_space = typename DeviceType::execution_space; - using value_type = typename ViewType::value_type; - using ats = Kokkos::Details::ArithTraits; - - /// randomized input testing views - ScalarType alpha = ScalarType(1.5), beta = ScalarType(3.0); - - ViewType a_expected("a_expected", N, matAdim1, matAdim2), - a_actual("a_actual", N, matAdim1, matAdim2), - b_expected("b_expected", N, matBdim1, matBdim2), - b_actual("b_actual", N, matBdim1, matBdim2), - c_expected("c_expected", N, matCdim1, matCdim2), - c_actual("c_actual", N, matCdim1, matCdim2); - - Kokkos::Random_XorShift64_Pool random( - 13718); - - Kokkos::fill_random(a_expected, random, value_type(1.0)); - Kokkos::fill_random(b_expected, random, value_type(1.0)); - Kokkos::fill_random(c_expected, random, value_type(1.0)); - - Kokkos::fence(); - - Kokkos::deep_copy(a_actual, a_expected); - Kokkos::deep_copy(b_actual, b_expected); - Kokkos::deep_copy(c_actual, c_expected); - - // Functor_TestBatchedTeamVector(alpha, a_expected, b_expected, - // beta, c_expected).run(); - Functor_BatchedVanillaGEMM - vgemm; - vgemm.A_t = std::is_same::value; - vgemm.B_t = std::is_same::value; - vgemm.A_c = vgemm.B_c = false; - vgemm.A = a_expected; - vgemm.B = b_expected; - vgemm.C = c_expected; - vgemm.alpha = alpha; - vgemm.beta = beta; - vgemm.run(); // Compute c_expected - - Functor_TestBatchedTeamVector(alpha, a_actual, b_actual, beta, - c_actual) - .run(); - - Kokkos::fence(); - - typename ViewType::HostMirror c_expected_host = - Kokkos::create_mirror_view(c_expected); - typename ViewType::HostMirror c_actual_host = - Kokkos::create_mirror_view(c_actual); - - // Copy to host for comparison - Kokkos::deep_copy(c_expected_host, c_expected); - Kokkos::deep_copy(c_actual_host, c_actual); - - using mag_type = typename ats::mag_type; - mag_type sum(1), diff(0); - - mag_type eps = ats::epsilon(); - - eps *= std::is_same::value || - std::is_same::value - ? 4 - : 1e3; - - for (int k = 0; k < N; ++k) - for (int i = 0; i < matCdim1; ++i) - for (int j = 0; j < matCdim2; ++j) { - sum += ats::abs(c_expected_host(k, i, j)); - diff += ats::abs(c_expected_host(k, i, j) - c_actual_host(k, i, j)); - } - EXPECT_NEAR_KK(diff / sum, 0, eps); -} -} // namespace TeamVectorGemm -} // namespace Test - -// void (*impl_test)(const int, const int, const int, const int, const int, -// const int, const int) -template -int test_batched_teamvectorgemm() { -#if defined(KOKKOSKERNELS_INST_LAYOUTLEFT) - { - typedef Kokkos::View - ViewType; - Test::TeamVectorGemm::impl_test_batched_teamvectorgemm< - DeviceType, ViewType, ScalarType, ParamTagType, AlgoTagType>( - 0, 10, 10, 10, 10, 10, 10); - for (int i = 0; i < 10; ++i) { - // printf("Testing: LayoutLeft, Blksize %d\n", i); - Test::TeamVectorGemm::impl_test_batched_teamvectorgemm< - DeviceType, ViewType, ScalarType, ParamTagType, AlgoTagType>( - 1024, i, i, i, i, i, i); - } - for (int i = 0; i < 10; ++i) { - // printf("Testing: LayoutLeft, Blksize %d\n", i); - int dimM = i; - int dimN = 2 * i; - int dimK = 3 * i; - if ((std::is_same::value) && - (std::is_same::value)) { - Test::TeamVectorGemm::impl_test_batched_teamvectorgemm< - DeviceType, ViewType, ScalarType, ParamTagType, AlgoTagType>( - 1024, dimM, dimK, dimK, dimN, dimM, dimN); - } - if ((std::is_same::value) && - (std::is_same::value)) { - Test::TeamVectorGemm::impl_test_batched_teamvectorgemm< - DeviceType, ViewType, ScalarType, ParamTagType, AlgoTagType>( - 1024, dimM, dimK, dimN, dimK, dimM, dimN); - } - if ((std::is_same::value) && - (std::is_same::value)) { - Test::TeamVectorGemm::impl_test_batched_teamvectorgemm< - DeviceType, ViewType, ScalarType, ParamTagType, AlgoTagType>( - 1024, dimK, dimM, dimK, dimN, dimM, dimN); - } - if ((std::is_same::value) && - (std::is_same::value)) { - Test::TeamVectorGemm::impl_test_batched_teamvectorgemm< - DeviceType, ViewType, ScalarType, ParamTagType, AlgoTagType>( - 1024, dimK, dimM, dimN, dimK, dimM, dimN); - } - } - } -#endif -#if defined(KOKKOSKERNELS_INST_LAYOUTRIGHT) - { - typedef Kokkos::View - ViewType; - Test::TeamVectorGemm::impl_test_batched_teamvectorgemm< - DeviceType, ViewType, ScalarType, ParamTagType, AlgoTagType>( - 0, 10, 10, 10, 10, 10, 10); - for (int i = 0; i < 10; ++i) { - // printf("Testing: LayoutRight, Blksize %d\n", i); - Test::TeamVectorGemm::impl_test_batched_teamvectorgemm< - DeviceType, ViewType, ScalarType, ParamTagType, AlgoTagType>( - 1024, i, i, i, i, i, i); - } - for (int i = 0; i < 10; ++i) { - // printf("Testing: LayoutLeft, Blksize %d\n", i); - int dimM = i; - int dimN = 2 * i; - int dimK = 3 * i; - if ((std::is_same::value) && - (std::is_same::value)) { - Test::TeamVectorGemm::impl_test_batched_teamvectorgemm< - DeviceType, ViewType, ScalarType, ParamTagType, AlgoTagType>( - 1024, dimM, dimK, dimK, dimN, dimM, dimN); - } - if ((std::is_same::value) && - (std::is_same::value)) { - Test::TeamVectorGemm::impl_test_batched_teamvectorgemm< - DeviceType, ViewType, ScalarType, ParamTagType, AlgoTagType>( - 1024, dimM, dimK, dimN, dimK, dimM, dimN); - } - if ((std::is_same::value) && - (std::is_same::value)) { - Test::TeamVectorGemm::impl_test_batched_teamvectorgemm< - DeviceType, ViewType, ScalarType, ParamTagType, AlgoTagType>( - 1024, dimK, dimM, dimK, dimN, dimM, dimN); - } - if ((std::is_same::value) && - (std::is_same::value)) { - Test::TeamVectorGemm::impl_test_batched_teamvectorgemm< - DeviceType, ViewType, ScalarType, ParamTagType, AlgoTagType>( - 1024, dimK, dimM, dimN, dimK, dimM, dimN); - } - } - } -#endif - - return 0; -} diff --git a/batched/dense/unit_test/Test_Batched_TeamVectorGemm_Complex.hpp b/batched/dense/unit_test/Test_Batched_TeamVectorGemm_Complex.hpp deleted file mode 100644 index a348c35b98..0000000000 --- a/batched/dense/unit_test/Test_Batched_TeamVectorGemm_Complex.hpp +++ /dev/null @@ -1,79 +0,0 @@ -#if defined(KOKKOSKERNELS_INST_COMPLEX_FLOAT) -TEST_F(TestCategory, batched_scalar_team_vector_gemm_nt_nt_scomplex_scomplex) { - typedef ::Test::TeamVectorGemm::ParamTag - param_tag_type; - - // test_batched_teamvectorgemm,Kokkos::complex,param_tag_type,Algo::Gemm::Blocked>(); - test_batched_teamvectorgemm, - Kokkos::complex, param_tag_type, - Algo::Gemm::Unblocked>(); -} -TEST_F(TestCategory, batched_scalar_team_vector_gemm_t_nt_scomplex_scomplex) { - typedef ::Test::TeamVectorGemm::ParamTag - param_tag_type; - - // test_batched_teamvectorgemm,Kokkos::complex,param_tag_type,Algo::Gemm::Blocked>(); - test_batched_teamvectorgemm, - Kokkos::complex, param_tag_type, - Algo::Gemm::Unblocked>(); -} -TEST_F(TestCategory, batched_scalar_team_vector_gemm_nt_t_scomplex_scomplex) { - typedef ::Test::TeamVectorGemm::ParamTag - param_tag_type; - - // test_batched_teamvectorgemm,Kokkos::complex,param_tag_type,Algo::Gemm::Blocked>(); - test_batched_teamvectorgemm, - Kokkos::complex, param_tag_type, - Algo::Gemm::Unblocked>(); -} -TEST_F(TestCategory, batched_scalar_team_vector_gemm_t_t_scomplex_scomplex) { - typedef ::Test::TeamVectorGemm::ParamTag - param_tag_type; - - // test_batched_teamvectorgemm,Kokkos::complex,param_tag_type,Algo::Gemm::Blocked>(); - test_batched_teamvectorgemm, - Kokkos::complex, param_tag_type, - Algo::Gemm::Unblocked>(); -} -#endif - -#if defined(KOKKOSKERNELS_INST_COMPLEX_DOUBLE) -TEST_F(TestCategory, batched_scalar_team_vector_gemm_nt_nt_dcomplex_dcomplex) { - typedef ::Test::TeamVectorGemm::ParamTag - param_tag_type; - - // test_batched_teamvectorgemm,Kokkos::complex,param_tag_type,Algo::Gemm::Blocked>(); - test_batched_teamvectorgemm, - Kokkos::complex, param_tag_type, - Algo::Gemm::Unblocked>(); -} -TEST_F(TestCategory, batched_scalar_team_vector_gemm_t_nt_dcomplex_dcomplex) { - typedef ::Test::TeamVectorGemm::ParamTag - param_tag_type; - - // test_batched_teamvectorgemm,Kokkos::complex,param_tag_type,Algo::Gemm::Blocked>(); - test_batched_teamvectorgemm, - Kokkos::complex, param_tag_type, - Algo::Gemm::Unblocked>(); -} -TEST_F(TestCategory, batched_scalar_team_vector_gemm_nt_t_dcomplex_dcomplex) { - typedef ::Test::TeamVectorGemm::ParamTag - param_tag_type; - - // test_batched_teamvectorgemm,Kokkos::complex,param_tag_type,Algo::Gemm::Blocked>(); - test_batched_teamvectorgemm, - Kokkos::complex, param_tag_type, - Algo::Gemm::Unblocked>(); -} -TEST_F(TestCategory, batched_scalar_team_vector_gemm_t_t_dcomplex_dcomplex) { - typedef ::Test::TeamVectorGemm::ParamTag - param_tag_type; - - // test_batched_teamvectorgemm,Kokkos::complex,param_tag_type,Algo::Gemm::Blocked>(); - test_batched_teamvectorgemm, - Kokkos::complex, param_tag_type, - Algo::Gemm::Unblocked>(); -} -#endif diff --git a/batched/dense/unit_test/Test_Batched_TeamVectorGemm_Real.hpp b/batched/dense/unit_test/Test_Batched_TeamVectorGemm_Real.hpp deleted file mode 100644 index ed43e31bf7..0000000000 --- a/batched/dense/unit_test/Test_Batched_TeamVectorGemm_Real.hpp +++ /dev/null @@ -1,151 +0,0 @@ -#if defined(KOKKOS_BHALF_T_IS_FLOAT) -TEST_F(TestCategory, batched_scalar_team_vector_gemm_nt_nt_bhalf_bhalf) { - typedef ::Test::TeamVectorGemm::ParamTag - param_tag_type; - - // test_batched_teamvectorgemm(); - test_batched_teamvectorgemm(); -} -TEST_F(TestCategory, batched_scalar_team_vector_gemm_t_nt_bhalf_bhalf) { - typedef ::Test::TeamVectorGemm::ParamTag - param_tag_type; - - // test_batched_teamvectorgemm(); - test_batched_teamvectorgemm(); -} -TEST_F(TestCategory, batched_scalar_team_vector_gemm_nt_t_bhalf_bhalf) { - typedef ::Test::TeamVectorGemm::ParamTag - param_tag_type; - - // test_batched_teamvectorgemm(); - test_batched_teamvectorgemm(); -} -TEST_F(TestCategory, batched_scalar_team_vector_gemm_t_t_bhalf_bhalf) { - typedef ::Test::TeamVectorGemm::ParamTag - param_tag_type; - - // test_batched_teamvectorgemm(); - test_batched_teamvectorgemm(); -} -#endif // KOKKOS_BHALF_T_IS_FLOAT - -#if defined(KOKKOS_HALF_T_IS_FLOAT) -TEST_F(TestCategory, batched_scalar_team_vector_gemm_nt_nt_half_half) { - typedef ::Test::TeamVectorGemm::ParamTag - param_tag_type; - - // test_batched_teamvectorgemm(); - test_batched_teamvectorgemm(); -} -TEST_F(TestCategory, batched_scalar_team_vector_gemm_t_nt_half_half) { - typedef ::Test::TeamVectorGemm::ParamTag - param_tag_type; - - // test_batched_teamvectorgemm(); - test_batched_teamvectorgemm(); -} -TEST_F(TestCategory, batched_scalar_team_vector_gemm_nt_t_half_half) { - typedef ::Test::TeamVectorGemm::ParamTag - param_tag_type; - - // test_batched_teamvectorgemm(); - test_batched_teamvectorgemm(); -} -TEST_F(TestCategory, batched_scalar_team_vector_gemm_t_t_half_half) { - typedef ::Test::TeamVectorGemm::ParamTag - param_tag_type; - - // test_batched_teamvectorgemm(); - test_batched_teamvectorgemm(); -} -#endif // KOKKOS_HALF_T_IS_FLOAT - -#if defined(KOKKOSKERNELS_INST_FLOAT) -TEST_F(TestCategory, batched_scalar_team_vector_gemm_nt_nt_float_float) { - typedef ::Test::TeamVectorGemm::ParamTag - param_tag_type; - - // test_batched_teamvectorgemm(); - test_batched_teamvectorgemm(); -} -TEST_F(TestCategory, batched_scalar_team_vector_gemm_t_nt_float_float) { - typedef ::Test::TeamVectorGemm::ParamTag - param_tag_type; - - // test_batched_teamvectorgemm(); - test_batched_teamvectorgemm(); -} -TEST_F(TestCategory, batched_scalar_team_vector_gemm_nt_t_float_float) { - typedef ::Test::TeamVectorGemm::ParamTag - param_tag_type; - - // test_batched_teamvectorgemm(); - test_batched_teamvectorgemm(); -} -TEST_F(TestCategory, batched_scalar_team_vector_gemm_t_t_float_float) { - typedef ::Test::TeamVectorGemm::ParamTag - param_tag_type; - - // test_batched_teamvectorgemm(); - test_batched_teamvectorgemm(); -} -#endif - -#if defined(KOKKOSKERNELS_INST_DOUBLE) -TEST_F(TestCategory, batched_scalar_team_vector_gemm_nt_nt_double_double) { - typedef ::Test::TeamVectorGemm::ParamTag - param_tag_type; - - // test_batched_teamvectorgemm(); - test_batched_teamvectorgemm(); -} -TEST_F(TestCategory, batched_scalar_team_vector_gemm_t_nt_double_double) { - typedef ::Test::TeamVectorGemm::ParamTag - param_tag_type; - - // test_batched_teamvectorgemm(); - test_batched_teamvectorgemm(); -} -TEST_F(TestCategory, batched_scalar_team_vector_gemm_nt_t_double_double) { - typedef ::Test::TeamVectorGemm::ParamTag - param_tag_type; - - // test_batched_teamvectorgemm(); - test_batched_teamvectorgemm(); -} -TEST_F(TestCategory, batched_scalar_team_vector_gemm_t_t_double_double) { - typedef ::Test::TeamVectorGemm::ParamTag - param_tag_type; - - // test_batched_teamvectorgemm(); - test_batched_teamvectorgemm(); -} -#endif diff --git a/batched/dense/unit_test/Test_Batched_TeamVectorGesv.hpp b/batched/dense/unit_test/Test_Batched_TeamVectorGesv.hpp index c09e395edd..938308a4fa 100644 --- a/batched/dense/unit_test/Test_Batched_TeamVectorGesv.hpp +++ b/batched/dense/unit_test/Test_Batched_TeamVectorGesv.hpp @@ -6,7 +6,6 @@ #include "KokkosBatched_Gesv.hpp" #include "KokkosBatched_Dot.hpp" -#include "KokkosBlas2_serial_gemv_impl.hpp" #include "KokkosKernels_TestUtils.hpp" diff --git a/batched/dense/unit_test/Test_Batched_TeamVectorQR.hpp b/batched/dense/unit_test/Test_Batched_TeamVectorQR.hpp index bb5cd89c9b..9c51e27c9d 100644 --- a/batched/dense/unit_test/Test_Batched_TeamVectorQR.hpp +++ b/batched/dense/unit_test/Test_Batched_TeamVectorQR.hpp @@ -6,7 +6,7 @@ #include "KokkosBlas1_set.hpp" #include "KokkosBatched_Copy_Decl.hpp" -#include "KokkosBlas2_team_gemv_spec.hpp" +#include "KokkosBlas2_gemv.hpp" #include "KokkosBatched_Trsv_Decl.hpp" #include "KokkosBatched_QR_Decl.hpp" #include "KokkosBatched_ApplyQ_Decl.hpp" @@ -53,7 +53,7 @@ struct Functor_TestBatchedTeamVectorQR { member.team_barrier(); /// bb = AA*xx - KokkosBlas::TeamVectorGemv::invoke(member, one, aa, xx, zero, bb); member.team_barrier(); diff --git a/batched/dense/unit_test/Test_Batched_TeamVectorQR_WithColumnPivoting.hpp b/batched/dense/unit_test/Test_Batched_TeamVectorQR_WithColumnPivoting.hpp index 743810d4ce..ef87f5d57b 100644 --- a/batched/dense/unit_test/Test_Batched_TeamVectorQR_WithColumnPivoting.hpp +++ b/batched/dense/unit_test/Test_Batched_TeamVectorQR_WithColumnPivoting.hpp @@ -6,7 +6,7 @@ #include "KokkosBatched_Copy_Decl.hpp" #include "KokkosBatched_ApplyPivot_Decl.hpp" -#include "KokkosBlas2_team_gemv_spec.hpp" +#include "KokkosBlas2_gemv.hpp" #include "KokkosBatched_Trsv_Decl.hpp" #include "KokkosBatched_QR_WithColumnPivoting_Decl.hpp" #include "KokkosBatched_ApplyQ_Decl.hpp" @@ -53,7 +53,7 @@ struct Functor_TestBatchedTeamVectorQR_WithColumnPivoting { member.team_barrier(); /// bb = AA*xx - KokkosBlas::TeamVectorGemv::invoke(member, one, aa, xx, zero, bb); member.team_barrier(); diff --git a/batched/dense/unit_test/Test_Batched_TeamVectorSolveUTV.hpp b/batched/dense/unit_test/Test_Batched_TeamVectorSolveUTV.hpp index 08375a95f5..729513319d 100644 --- a/batched/dense/unit_test/Test_Batched_TeamVectorSolveUTV.hpp +++ b/batched/dense/unit_test/Test_Batched_TeamVectorSolveUTV.hpp @@ -6,7 +6,7 @@ #include "KokkosBatched_Copy_Decl.hpp" #include "KokkosBatched_ApplyPivot_Decl.hpp" -#include "KokkosBlas2_team_gemv_spec.hpp" +#include "KokkosBlas2_gemv.hpp" #include "KokkosBatched_Trsv_Decl.hpp" #include "KokkosBatched_UTV_Decl.hpp" #include "KokkosBatched_SolveUTV_Decl.hpp" @@ -79,7 +79,7 @@ struct Functor_TestBatchedTeamVectorSolveUTV { TeamVectorCopy::invoke(member, aa, ac); /// bb = AA*xx - KokkosBlas::TeamVectorGemv::invoke(member, one, aa, xx, zero, bb); member.team_barrier(); diff --git a/batched/dense/unit_test/Test_Batched_TeamVectorSolveUTV2.hpp b/batched/dense/unit_test/Test_Batched_TeamVectorSolveUTV2.hpp index 77bec61c28..5760a80c89 100644 --- a/batched/dense/unit_test/Test_Batched_TeamVectorSolveUTV2.hpp +++ b/batched/dense/unit_test/Test_Batched_TeamVectorSolveUTV2.hpp @@ -6,7 +6,7 @@ #include "KokkosBatched_Copy_Decl.hpp" #include "KokkosBatched_ApplyPivot_Decl.hpp" -#include "KokkosBatched_Gemm_Decl.hpp" +#include "KokkosBlas3_gemm.hpp" #include "KokkosBatched_Trsv_Decl.hpp" #include "KokkosBatched_UTV_Decl.hpp" #include "KokkosBatched_SolveUTV_Decl.hpp" @@ -81,11 +81,9 @@ struct Functor_TestBatchedTeamVectorSolveUTV2 { TeamVectorCopy::invoke(member, aa, ac); /// bb = AA*xx - KokkosBatched::TeamVectorGemm::invoke(member, one, - aa, xx, zero, - bb); + KokkosBlas::TeamVectorGemm::invoke(member, one, aa, + xx, zero, bb); member.team_barrier(); /// Solving Ax = b using UTV transformation diff --git a/batched/dense/unit_test/Test_Batched_TeamVectorUTV.hpp b/batched/dense/unit_test/Test_Batched_TeamVectorUTV.hpp index 06ca4b2fb8..c680fe9501 100644 --- a/batched/dense/unit_test/Test_Batched_TeamVectorUTV.hpp +++ b/batched/dense/unit_test/Test_Batched_TeamVectorUTV.hpp @@ -6,7 +6,7 @@ #include "KokkosBatched_Copy_Decl.hpp" #include "KokkosBatched_ApplyPivot_Decl.hpp" -#include "KokkosBlas2_team_gemv_spec.hpp" +#include "KokkosBlas2_gemv.hpp" #include "KokkosBatched_Trsv_Decl.hpp" #include "KokkosBatched_UTV_Decl.hpp" @@ -78,7 +78,7 @@ struct Functor_TestBatchedTeamVectorUTV { TeamVectorCopy::invoke(member, aa, ac); /// bb = AA*xx - KokkosBlas::TeamVectorGemv::invoke(member, one, aa, xx, zero, bb); member.team_barrier(); @@ -98,7 +98,7 @@ struct Functor_TestBatchedTeamVectorUTV { auto vm = Kokkos::subview(vv, range_upto_rank, Kokkos::ALL()); if (matrix_rank < m) { /// w = U^T b - KokkosBlas::TeamVectorGemv::invoke(member, one, um, bb, zero, ww); member.team_barrier(); @@ -109,13 +109,13 @@ struct Functor_TestBatchedTeamVectorUTV { member.team_barrier(); /// x = V^T w - KokkosBlas::TeamVectorGemv::invoke(member, one, vm, ww, zero, xx); member.team_barrier(); } else { /// x = U^T b - KokkosBlas::TeamVectorGemv::invoke(member, one, um, bb, zero, xx); member.team_barrier(); diff --git a/batched/sparse/impl/KokkosBatched_GMRES_Serial_Impl.hpp b/batched/sparse/impl/KokkosBatched_GMRES_Serial_Impl.hpp index f8435754f6..b3c880de53 100644 --- a/batched/sparse/impl/KokkosBatched_GMRES_Serial_Impl.hpp +++ b/batched/sparse/impl/KokkosBatched_GMRES_Serial_Impl.hpp @@ -54,7 +54,7 @@ #include "KokkosBatched_Givens_Serial_Internal.hpp" #include "KokkosBatched_Trsm_Decl.hpp" #include "KokkosBatched_Identity.hpp" -#include "KokkosBlas2_serial_gemv_impl.hpp" +#include "KokkosBlas2_gemv.hpp" namespace KokkosBatched { diff --git a/blas/impl/KokkosBlas1_scal_spec.hpp b/blas/impl/KokkosBlas1_scal_spec.hpp index 1ec18c7469..7c54ef1933 100644 --- a/blas/impl/KokkosBlas1_scal_spec.hpp +++ b/blas/impl/KokkosBlas1_scal_spec.hpp @@ -397,4 +397,4 @@ struct Scal #include -#endif // KOKKOS_BLAS1_MV_IMPL_SCAL_HPP_ +#endif // KOKKOS_BLAS1_IMPL_SCAL_SPEC_HPP_ diff --git a/blas/impl/KokkosBlas1_set_impl.hpp b/blas/impl/KokkosBlas1_set_impl.hpp index a3870a2e15..ec95f5deb8 100644 --- a/blas/impl/KokkosBlas1_set_impl.hpp +++ b/blas/impl/KokkosBlas1_set_impl.hpp @@ -160,6 +160,43 @@ struct TeamVectorSetInternal { } }; +/// +/// ThreadVector Internal Impl +/// ================== +struct ThreadVectorSetInternal { + template + KOKKOS_INLINE_FUNCTION static int invoke(const MemberType &member, + const int m, const ScalarType alpha, + /* */ ValueType *KOKKOS_RESTRICT A, + const int as0) { + Kokkos::parallel_for(Kokkos::ThreadVectorRange(member, m), + [&](const int &i) { A[i * as0] = alpha; }); + // member.team_barrier(); + return 0; + } + + template + KOKKOS_INLINE_FUNCTION static int invoke(const MemberType &member, + const int m, const int n, + const ScalarType alpha, + /* */ ValueType *KOKKOS_RESTRICT A, + const int as0, const int as1) { + if (m > n) { + Kokkos::parallel_for( + Kokkos::ThreadVectorRange(member, m), [&](const int &i) { + SerialSetInternal::invoke(n, alpha, A + i * as0, as1); + }); + } else { + Kokkos::parallel_for( + Kokkos::ThreadVectorRange(member, n), [&](const int &j) { + SerialSetInternal::invoke(m, alpha, A + j * as1, as0); + }); + } + // member.team_barrier(); + return 0; + } +}; + } // namespace Impl } // namespace KokkosBlas diff --git a/blas/impl/KokkosBlas1_team_scal_impl.hpp b/blas/impl/KokkosBlas1_team_scal_impl.hpp index 6f4fdf40b0..d89238906d 100644 --- a/blas/impl/KokkosBlas1_team_scal_impl.hpp +++ b/blas/impl/KokkosBlas1_team_scal_impl.hpp @@ -129,6 +129,43 @@ struct TeamVectorScaleInternal { } }; +/// +/// ThreadVector Internal Impl +/// ==================== +struct ThreadVectorScaleInternal { + template + KOKKOS_INLINE_FUNCTION static int invoke(const MemberType &member, + const int m, const ScalarType alpha, + /* */ ValueType *KOKKOS_RESTRICT A, + const int as0) { + Kokkos::parallel_for(Kokkos::ThreadVectorRange(member, m), + [&](const int &i) { A[i * as0] *= alpha; }); + // member.team_barrier(); + return 0; + } + + template + KOKKOS_INLINE_FUNCTION static int invoke(const MemberType &member, + const int m, const int n, + const ScalarType alpha, + /* */ ValueType *KOKKOS_RESTRICT A, + const int as0, const int as1) { + if (m > n) { + Kokkos::parallel_for( + Kokkos::ThreadVectorRange(member, m), [&](const int &i) { + SerialScaleInternal::invoke(n, alpha, A + i * as0, as1); + }); + } else { + Kokkos::parallel_for( + Kokkos::ThreadVectorRange(member, n), [&](const int &j) { + SerialScaleInternal::invoke(m, alpha, A + j * as1, as0); + }); + } + // member.team_barrier(); + return 0; + } +}; + } // namespace Impl } // namespace KokkosBlas diff --git a/blas/impl/KokkosBlas2_serial_gemv_impl.hpp b/blas/impl/KokkosBlas2_serial_gemv_impl.hpp index 0d7f52702b..d00e39a728 100644 --- a/blas/impl/KokkosBlas2_serial_gemv_impl.hpp +++ b/blas/impl/KokkosBlas2_serial_gemv_impl.hpp @@ -49,21 +49,6 @@ #include "KokkosBlas_util.hpp" #include "KokkosBlas2_serial_gemv_internal.hpp" -namespace KokkosBlas { - -template -struct SerialGemv { - template - KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType /*alpha*/, - const AViewType & /*A*/, - const xViewType & /*x*/, - const ScalarType /*beta*/, - const yViewType & /*y*/); -}; - -} // namespace KokkosBlas - #include "KokkosBlas2_serial_gemv_tpl_spec_decl.hpp" namespace KokkosBlas { @@ -72,88 +57,26 @@ namespace KokkosBlas { /// Serial Impl /// =========== -/// -/// NT -/// - -template <> -template -KOKKOS_INLINE_FUNCTION int -SerialGemv::invoke( - const ScalarType alpha, const AViewType &A, const xViewType &x, - const ScalarType beta, const yViewType &y) { - return Impl::SerialGemvInternal::invoke( - A.extent(0), A.extent(1), alpha, A.data(), A.stride_0(), A.stride_1(), - x.data(), x.stride_0(), beta, y.data(), y.stride_0()); -} - -template <> -template -KOKKOS_INLINE_FUNCTION int -SerialGemv::invoke( - const ScalarType alpha, const AViewType &A, const xViewType &x, - const ScalarType beta, const yViewType &y) { - return Impl::SerialGemvInternal::invoke( - A.extent(0), A.extent(1), alpha, A.data(), A.stride_0(), A.stride_1(), - x.data(), x.stride_0(), beta, y.data(), y.stride_0()); -} - -/// -/// T -/// - -template <> -template -KOKKOS_INLINE_FUNCTION int -SerialGemv::invoke( - const ScalarType alpha, const AViewType &A, const xViewType &x, - const ScalarType beta, const yViewType &y) { - return Impl::SerialGemvInternal::invoke( - A.extent(1), A.extent(0), alpha, A.data(), A.stride_1(), A.stride_0(), - x.data(), x.stride_0(), beta, y.data(), y.stride_0()); -} - -template <> -template -KOKKOS_INLINE_FUNCTION int -SerialGemv::invoke( - const ScalarType alpha, const AViewType &A, const xViewType &x, - const ScalarType beta, const yViewType &y) { - return Impl::SerialGemvInternal::invoke( - A.extent(1), A.extent(0), alpha, A.data(), A.stride_1(), A.stride_0(), - x.data(), x.stride_0(), beta, y.data(), y.stride_0()); -} - -/// -/// CT -/// - -template <> -template -KOKKOS_INLINE_FUNCTION int -SerialGemv::invoke( - const ScalarType alpha, const AViewType &A, const xViewType &x, - const ScalarType beta, const yViewType &y) { - return Impl::SerialGemvInternal::invoke( - Impl::OpConj(), A.extent(1), A.extent(0), alpha, A.data(), A.stride_1(), - A.stride_0(), x.data(), x.stride_0(), beta, y.data(), y.stride_0()); -} - -template <> +template template -KOKKOS_INLINE_FUNCTION int -SerialGemv::invoke( +KOKKOS_INLINE_FUNCTION int SerialGemv::invoke( const ScalarType alpha, const AViewType &A, const xViewType &x, const ScalarType beta, const yViewType &y) { - return Impl::SerialGemvInternal::invoke( - Impl::OpConj(), A.extent(1), A.extent(0), alpha, A.data(), A.stride_1(), - A.stride_0(), x.data(), x.stride_0(), beta, y.data(), y.stride_0()); + static_assert(std::is_same::value || + std::is_same::value || + std::is_same::value, + "Algorithm not supported"); + + using TransA = Impl::MatrixModeInfo; + const auto ae0 = TransA::extent(A, 0); + const auto ae1 = TransA::extent(A, 1); + const auto as0 = TransA::stride_0(A); + const auto as1 = TransA::stride_1(A); + + return Impl::SerialGemvInternal::invoke( + ae0, ae1, alpha, A.data(), as0, as1, x.data(), x.stride_0(), beta, + y.data(), y.stride_0()); } } // namespace KokkosBlas diff --git a/blas/impl/KokkosBlas2_serial_gemv_inner_multiple_dot.hpp b/blas/impl/KokkosBlas2_serial_gemv_inner_multiple_dot.hpp index 0e9d015dc8..d532b5afde 100644 --- a/blas/impl/KokkosBlas2_serial_gemv_inner_multiple_dot.hpp +++ b/blas/impl/KokkosBlas2_serial_gemv_inner_multiple_dot.hpp @@ -46,24 +46,11 @@ /// \author Kyungjoo Kim (kyukim@sandia.gov) +#include "KokkosBlas_util.hpp" + namespace KokkosBlas { namespace Impl { -struct OpID { - template - KOKKOS_INLINE_FUNCTION ValueType operator()(ValueType v) const { - return v; - } -}; - -struct OpConj { - template - KOKKOS_INLINE_FUNCTION ValueType operator()(ValueType v) const { - using KAT = Kokkos::Details::ArithTraits; - return KAT::conj(v); - } -}; - template struct InnerMultipleDotProduct { const int _as0, _as1, _xs0, _ys0; diff --git a/blas/impl/KokkosBlas2_team_gemv_impl.hpp b/blas/impl/KokkosBlas2_team_gemv_impl.hpp index a4cf662cc9..e452155ab2 100644 --- a/blas/impl/KokkosBlas2_team_gemv_impl.hpp +++ b/blas/impl/KokkosBlas2_team_gemv_impl.hpp @@ -45,197 +45,77 @@ #ifndef KOKKOSBLAS2_TEAM_GEMV_IMPL_HPP_ #define KOKKOSBLAS2_TEAM_GEMV_IMPL_HPP_ -#include "KokkosBlas1_set_impl.hpp" -#include "KokkosBlas1_team_scal_impl.hpp" -#include "KokkosBlas2_serial_gemv_inner_multiple_dot.hpp" +#include namespace KokkosBlas { -namespace Impl { -template -struct TeamGemvInternal { - template - KOKKOS_INLINE_FUNCTION static int invoke( - const MemberType &member, OpA op, const int m, const int n, - const ScalarType alpha, const ValueAType *KOKKOS_RESTRICT A, - const int as0, const int as1, const ValueXType *KOKKOS_RESTRICT x, - const int xs0, const ScalarType beta, - /**/ ValueYType *KOKKOS_RESTRICT y, const int ys0); - - // default OpA = OpID - template - KOKKOS_INLINE_FUNCTION static int invoke( - const MemberType &member, const int m, const int n, - const ScalarType alpha, const ValueAType *KOKKOS_RESTRICT A, - const int as0, const int as1, const ValueXType *KOKKOS_RESTRICT x, - const int xs0, const ScalarType beta, - /**/ ValueYType *KOKKOS_RESTRICT y, const int ys0) { - return invoke(member, OpID{}, m, n, alpha, A, as0, as1, x, xs0, beta, y, - ys0); - } -}; - -template -struct TeamVectorGemvInternal { - template - KOKKOS_INLINE_FUNCTION static int invoke( - const MemberType &member, OpA op, const int m, const int n, - const ScalarType alpha, const ValueAType *KOKKOS_RESTRICT A, - const int as0, const int as1, const ValueXType *KOKKOS_RESTRICT x, - const int xs0, const ScalarType beta, - /**/ ValueYType *KOKKOS_RESTRICT y, const int ys0); - - // default OpA = OpID - template - KOKKOS_INLINE_FUNCTION static int invoke( - const MemberType &member, const int m, const int n, - const ScalarType alpha, const ValueAType *KOKKOS_RESTRICT A, - const int as0, const int as1, const ValueXType *KOKKOS_RESTRICT x, - const int xs0, const ScalarType beta, - /**/ ValueYType *KOKKOS_RESTRICT y, const int ys0) { - return invoke(member, OpID{}, m, n, alpha, A, as0, as1, x, xs0, beta, y, - ys0); - } -}; - -/// -/// Team Internal Impl -/// ==================== - -template <> -template -KOKKOS_INLINE_FUNCTION int TeamGemvInternal::invoke( - const MemberType &member, OpA op, const int m, const int n, - const ScalarType alpha, const ValueAType *KOKKOS_RESTRICT A, const int as0, - const int as1, const ValueXType *KOKKOS_RESTRICT x, const int xs0, - const ScalarType beta, - /**/ ValueYType *KOKKOS_RESTRICT y, const int ys0) { - const ScalarType one(1.0), zero(0.0); - - // y = beta y + alpha A x - // y (m), A(m x n), B(n) - - if (beta == zero) - KokkosBlas::Impl::TeamSetInternal::invoke(member, m, zero, y, ys0); - else if (beta != one) - KokkosBlas::Impl::TeamScaleInternal::invoke(member, m, beta, y, ys0); - - if (alpha != zero) { - if (m <= 0 || n <= 0) return 0; - - if (beta != one) member.team_barrier(); - - Kokkos::parallel_for(Kokkos::TeamThreadRange(member, 0, m), - [&](const int &i) { - ValueYType t(0); - const ValueAType *KOKKOS_RESTRICT tA = (A + i * as0); -#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL) -#pragma unroll -#endif - for (int j = 0; j < n; ++j) - t += op(tA[j * as1]) * x[j * xs0]; - y[i * ys0] += alpha * t; - }); - } - return 0; +template +template +KOKKOS_INLINE_FUNCTION int TeamGemv::invoke( + const MemberType& member, const ScalarType alpha, const AViewType& A, + const xViewType& x, const ScalarType beta, const yViewType& y) { + static_assert(std::is_same::value || + std::is_same::value, + "Algorithm not supported"); + static_assert(AViewType::Rank == 2, + "KokkosBlas::TeamGemv requires rank-2 A matrix"); + + using TransA = Impl::MatrixModeInfo; + const auto ae0 = TransA::extent(A, 0); + const auto ae1 = TransA::extent(A, 1); + const auto as0 = TransA::stride_0(A); + const auto as1 = TransA::stride_1(A); + + return Impl::TeamGemvInternal::invoke( + member, ae0, ae1, alpha, A.data(), as0, as1, x.data(), x.stride_0(), beta, + y.data(), y.stride_0()); } -template <> -template -KOKKOS_INLINE_FUNCTION int TeamGemvInternal::invoke( - const MemberType &member, OpA /* op */, const int m, const int n, - const ScalarType alpha, const ValueAType *KOKKOS_RESTRICT A, const int as0, - const int as1, const ValueXType *KOKKOS_RESTRICT x, const int xs0, - const ScalarType beta, - /**/ ValueYType *KOKKOS_RESTRICT y, const int ys0) { - const ScalarType one(1.0), zero(0.0); - - // y = beta y + alpha A x - // y (m), A(m x n), B(n) - - constexpr int mbAlgo = Algo::Gemv::Blocked::mb(); - - if (beta == zero) - KokkosBlas::Impl::TeamSetInternal::invoke(member, m, zero, y, ys0); - else if (beta != one) - KokkosBlas::Impl::TeamScaleInternal::invoke(member, m, beta, y, ys0); - - if (alpha != zero) { - if (m <= 0 || n <= 0) return 0; - - if (beta != one) member.team_barrier(); - - KokkosBlas::Impl::InnerMultipleDotProduct inner(as0, as1, xs0, ys0); - const int tsize = member.team_size(); - const int mb_a = m / tsize + (m % tsize > 0), mb_b = mbAlgo; - // Made this non-const in order to WORKAROUND issue #349 - int mb = mb_a < mb_b ? mb_a : mb_b, mp = m % mb; - - Kokkos::parallel_for(Kokkos::TeamThreadRange(member, (m / mb) + (mp > 0)), - [&](const int &ii) { - const int i = ii * mb; - inner.serial_invoke(alpha, A + i * as0, x, - (i + mb) > m ? (m - i) : mb, - n, y + i * ys0); - }); - member.team_barrier(); - } - - return 0; +template +template +KOKKOS_INLINE_FUNCTION int TeamVectorGemv::invoke( + const MemberType& member, const ScalarType alpha, const AViewType& A, + const xViewType& x, const ScalarType beta, const yViewType& y) { + static_assert(std::is_same::value, + "Algorithm not supported"); + static_assert(AViewType::Rank == 2, + "KokkosBlas::TeamVectorGemv requires rank-2 A matrix"); + + using TransA = Impl::MatrixModeInfo; + const auto ae0 = TransA::extent(A, 0); + const auto ae1 = TransA::extent(A, 1); + const auto as0 = TransA::stride_0(A); + const auto as1 = TransA::stride_1(A); + + return Impl::TeamVectorGemvInternal::invoke( + member, ae0, ae1, alpha, A.data(), as0, as1, x.data(), x.stride_0(), beta, + y.data(), y.stride_0()); } -/// -/// TeamVector Internal Impl -/// ==================== - -template <> -template -KOKKOS_INLINE_FUNCTION int -TeamVectorGemvInternal::invoke( - const MemberType &member, OpA op, const int m, const int n, - const ScalarType alpha, const ValueAType *KOKKOS_RESTRICT A, const int as0, - const int as1, const ValueXType *KOKKOS_RESTRICT x, const int xs0, - const ScalarType beta, - /**/ ValueYType *KOKKOS_RESTRICT y, const int ys0) { - const ScalarType one(1.0), zero(0.0); - - // y = beta y + alpha A x - // y (m), A(m x n), B(n) - - if (beta == zero) - KokkosBlas::Impl::TeamVectorSetInternal::invoke(member, m, zero, y, ys0); - else if (beta != one) - KokkosBlas::Impl::TeamVectorScaleInternal::invoke(member, m, beta, y, ys0); - - if (alpha != zero) { - if (m <= 0 || n <= 0) return 0; - - if (beta != one) member.team_barrier(); - - Kokkos::parallel_for(Kokkos::TeamThreadRange(member, m), [&](const int &i) { - ValueYType t(0); - const ValueAType *KOKKOS_RESTRICT tA = (A + i * as0); - Kokkos::parallel_reduce( - Kokkos::ThreadVectorRange(member, n), - [&](const int &j, ValueYType &update) { - update += op(tA[j * as1]) * x[j * xs0]; - }, - t); - Kokkos::single(Kokkos::PerThread(member), - [&]() { y[i * ys0] += alpha * t; }); - }); - } - return 0; +template +template +KOKKOS_INLINE_FUNCTION int ThreadVectorGemv::invoke( + const MemberType& member, const ScalarType alpha, const AViewType& A, + const xViewType& x, const ScalarType beta, const yViewType& y) { + static_assert(std::is_same::value, + "Algorithm not supported"); + static_assert(AViewType::Rank == 2, + "Batched TeamVectorGemv requires rank-2 A matrix"); + + using TransA = Impl::MatrixModeInfo; + const auto ae0 = TransA::extent(A, 0); + const auto ae1 = TransA::extent(A, 1); + const auto as0 = TransA::stride_0(A); + const auto as1 = TransA::stride_1(A); + + return Impl::ThreadVectorGemvInternal::invoke( + member, ae0, ae1, alpha, A.data(), as0, as1, x.data(), x.stride_0(), beta, + y.data(), y.stride_0()); } -} // namespace Impl } // namespace KokkosBlas #endif diff --git a/blas/impl/KokkosBlas2_team_gemv_internal.hpp b/blas/impl/KokkosBlas2_team_gemv_internal.hpp new file mode 100644 index 0000000000..7262ca93d3 --- /dev/null +++ b/blas/impl/KokkosBlas2_team_gemv_internal.hpp @@ -0,0 +1,317 @@ +/* +//@HEADER +// ************************************************************************ +// +// Kokkos v. 3.0 +// Copyright (2020) National Technology & Engineering +// Solutions of Sandia, LLC (NTESS). +// +// Under the terms of Contract DE-NA0003525 with NTESS, +// the U.S. Government retains certain rights in this software. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// 1. Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of the Corporation nor the names of the +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY NTESS "AS IS" AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NTESS OR THE +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// Questions? Contact Siva Rajamanickam (srajama@sandia.gov) +// +// ************************************************************************ +//@HEADER +*/ + +#ifndef KOKKOSBLAS2_TEAM_GEMV_INTERNAL_HPP_ +#define KOKKOSBLAS2_TEAM_GEMV_INTERNAL_HPP_ + +#include "KokkosBlas1_set_impl.hpp" +#include "KokkosBlas1_team_scal_impl.hpp" +#include "KokkosBlas2_serial_gemv_inner_multiple_dot.hpp" + +namespace KokkosBlas { +namespace Impl { + +template +struct TeamGemvInternal { + template + KOKKOS_INLINE_FUNCTION static int invoke( + const MemberType &member, OpA op, const int m, const int n, + const ScalarType alpha, const ValueAType *KOKKOS_RESTRICT A, + const int as0, const int as1, const ValueXType *KOKKOS_RESTRICT x, + const int xs0, const ScalarType beta, + /**/ ValueYType *KOKKOS_RESTRICT y, const int ys0); + + // default OpA = OpID + template + KOKKOS_INLINE_FUNCTION static int invoke( + const MemberType &member, const int m, const int n, + const ScalarType alpha, const ValueAType *KOKKOS_RESTRICT A, + const int as0, const int as1, const ValueXType *KOKKOS_RESTRICT x, + const int xs0, const ScalarType beta, + /**/ ValueYType *KOKKOS_RESTRICT y, const int ys0) { + return invoke(member, OpID{}, m, n, alpha, A, as0, as1, x, xs0, beta, y, + ys0); + } +}; + +template +struct TeamVectorGemvInternal { + template + KOKKOS_INLINE_FUNCTION static int invoke( + const MemberType &member, OpA op, const int m, const int n, + const ScalarType alpha, const ValueAType *KOKKOS_RESTRICT A, + const int as0, const int as1, const ValueXType *KOKKOS_RESTRICT x, + const int xs0, const ScalarType beta, + /**/ ValueYType *KOKKOS_RESTRICT y, const int ys0); + + // default OpA = OpID + template + KOKKOS_INLINE_FUNCTION static int invoke( + const MemberType &member, const int m, const int n, + const ScalarType alpha, const ValueAType *KOKKOS_RESTRICT A, + const int as0, const int as1, const ValueXType *KOKKOS_RESTRICT x, + const int xs0, const ScalarType beta, + /**/ ValueYType *KOKKOS_RESTRICT y, const int ys0) { + return invoke(member, OpID{}, m, n, alpha, A, as0, as1, x, xs0, beta, y, + ys0); + } +}; + +template +struct ThreadVectorGemvInternal { + template + KOKKOS_INLINE_FUNCTION static int invoke( + const MemberType &member, OpA op, const int m, const int n, + const ScalarType alpha, const ValueAType *KOKKOS_RESTRICT A, + const int as0, const int as1, const ValueXType *KOKKOS_RESTRICT x, + const int xs0, const ScalarType beta, + /**/ ValueYType *KOKKOS_RESTRICT y, const int ys0); + + // default OpA = OpID + template + KOKKOS_INLINE_FUNCTION static int invoke( + const MemberType &member, const int m, const int n, + const ScalarType alpha, const ValueAType *KOKKOS_RESTRICT A, + const int as0, const int as1, const ValueXType *KOKKOS_RESTRICT x, + const int xs0, const ScalarType beta, + /**/ ValueYType *KOKKOS_RESTRICT y, const int ys0) { + return invoke(member, OpID{}, m, n, alpha, A, as0, as1, x, xs0, beta, y, + ys0); + } +}; + +/// +/// Team Internal Impl +/// ==================== + +template <> +template +KOKKOS_INLINE_FUNCTION int TeamGemvInternal::invoke( + const MemberType &member, OpA op, const int m, const int n, + const ScalarType alpha, const ValueAType *KOKKOS_RESTRICT A, const int as0, + const int as1, const ValueXType *KOKKOS_RESTRICT x, const int xs0, + const ScalarType beta, + /**/ ValueYType *KOKKOS_RESTRICT y, const int ys0) { + const ScalarType one(1.0), zero(0.0); + + // y = beta y + alpha A x + // y (m), A(m x n), B(n) + + if (beta == zero) + KokkosBlas::Impl::TeamSetInternal::invoke(member, m, zero, y, ys0); + else if (beta != one) + KokkosBlas::Impl::TeamScaleInternal::invoke(member, m, beta, y, ys0); + + if (alpha != zero) { + if (m <= 0 || n <= 0) return 0; + + if (beta != one) member.team_barrier(); + + Kokkos::parallel_for(Kokkos::TeamThreadRange(member, 0, m), + [&](const int &i) { + ValueYType t(0); + const ValueAType *KOKKOS_RESTRICT tA = (A + i * as0); +#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL) +#pragma unroll +#endif + for (int j = 0; j < n; ++j) + t += op(tA[j * as1]) * x[j * xs0]; + y[i * ys0] += alpha * t; + }); + } + return 0; +} + +template <> +template +KOKKOS_INLINE_FUNCTION int TeamGemvInternal::invoke( + const MemberType &member, OpA /* op */, const int m, const int n, + const ScalarType alpha, const ValueAType *KOKKOS_RESTRICT A, const int as0, + const int as1, const ValueXType *KOKKOS_RESTRICT x, const int xs0, + const ScalarType beta, + /**/ ValueYType *KOKKOS_RESTRICT y, const int ys0) { + const ScalarType one(1.0), zero(0.0); + + // y = beta y + alpha A x + // y (m), A(m x n), B(n) + + constexpr int mbAlgo = Algo::Gemv::Blocked::mb(); + + if (beta == zero) + KokkosBlas::Impl::TeamSetInternal::invoke(member, m, zero, y, ys0); + else if (beta != one) + KokkosBlas::Impl::TeamScaleInternal::invoke(member, m, beta, y, ys0); + + if (alpha != zero) { + if (m <= 0 || n <= 0) return 0; + + if (beta != one) member.team_barrier(); + + KokkosBlas::Impl::InnerMultipleDotProduct inner(as0, as1, xs0, ys0); + const int tsize = member.team_size(); + const int mb_a = m / tsize + (m % tsize > 0), mb_b = mbAlgo; + // Made this non-const in order to WORKAROUND issue #349 + int mb = mb_a < mb_b ? mb_a : mb_b, mp = m % mb; + + Kokkos::parallel_for(Kokkos::TeamThreadRange(member, (m / mb) + (mp > 0)), + [&](const int &ii) { + const int i = ii * mb; + inner.serial_invoke(alpha, A + i * as0, x, + (i + mb) > m ? (m - i) : mb, + n, y + i * ys0); + }); + member.team_barrier(); + } + + return 0; +} + +/// +/// TeamVector Internal Impl +/// ==================== + +template <> +template +KOKKOS_INLINE_FUNCTION int +TeamVectorGemvInternal::invoke( + const MemberType &member, OpA op, const int m, const int n, + const ScalarType alpha, const ValueAType *KOKKOS_RESTRICT A, const int as0, + const int as1, const ValueXType *KOKKOS_RESTRICT x, const int xs0, + const ScalarType beta, + /**/ ValueYType *KOKKOS_RESTRICT y, const int ys0) { + const ScalarType one(1.0), zero(0.0); + + // y = beta y + alpha A x + // y (m), A(m x n), B(n) + + if (beta == zero) + KokkosBlas::Impl::TeamVectorSetInternal::invoke(member, m, zero, y, ys0); + else if (beta != one) + KokkosBlas::Impl::TeamVectorScaleInternal::invoke(member, m, beta, y, ys0); + + if (alpha != zero) { + if (m <= 0 || n <= 0) return 0; + + if (beta != one) member.team_barrier(); + + Kokkos::parallel_for(Kokkos::TeamThreadRange(member, m), [&](const int &i) { + ValueYType t(0); + const ValueAType *KOKKOS_RESTRICT tA = (A + i * as0); + Kokkos::parallel_reduce( + Kokkos::ThreadVectorRange(member, n), + [&](const int &j, ValueYType &update) { + update += op(tA[j * as1]) * x[j * xs0]; + }, + t); + Kokkos::single(Kokkos::PerThread(member), + [&]() { y[i * ys0] += alpha * t; }); + }); + } + return 0; +} + +/// +/// ThreadVector Internal Impl +/// ==================== + +template <> +template +KOKKOS_INLINE_FUNCTION int +ThreadVectorGemvInternal::invoke( + const MemberType &member, OpA op, const int m, const int n, + const ScalarType alpha, const ValueAType *KOKKOS_RESTRICT A, const int as0, + const int as1, const ValueXType *KOKKOS_RESTRICT x, const int xs0, + const ScalarType beta, + /**/ ValueYType *KOKKOS_RESTRICT y, const int ys0) { + const ScalarType one(1.0), zero(0.0); + + // y = beta y + alpha A x + // y (m), A(m x n), B(n) + + constexpr int mbAlgo = Algo::Gemv::Blocked::mb(); + + if (beta == zero) + KokkosBlas::Impl::ThreadVectorSetInternal::invoke(member, m, zero, y, ys0); + else if (beta != one) + KokkosBlas::Impl::ThreadVectorScaleInternal::invoke(member, m, beta, y, + ys0); + + if (alpha != zero) { + if (m <= 0 || n <= 0) return 0; + + if (beta != one) member.team_barrier(); + + KokkosBlas::Impl::InnerMultipleDotProduct inner(as0, as1, xs0, ys0); + const int tsize = member.team_size(); + const int mb_a = m / tsize + (m % tsize > 0), mb_b = mbAlgo; + // Made this non-const in order to WORKAROUND issue #349 + int mb = mb_a < mb_b ? mb_a : mb_b, mp = m % mb; + + Kokkos::parallel_for(Kokkos::ThreadVectorRange(member, (m / mb) + (mp > 0)), + [&](const int &ii) { + const int i = ii * mb; + inner.serial_invoke(alpha, A + i * as0, x, + (i + mb) > m ? (m - i) : mb, + n, y + i * ys0); + }); + member.team_barrier(); + } + + return 0; +} + +} // namespace Impl +} // namespace KokkosBlas + +#endif diff --git a/blas/impl/KokkosBlas2_team_gemv_spec.hpp b/blas/impl/KokkosBlas2_team_gemv_spec.hpp deleted file mode 100644 index 92aac23f26..0000000000 --- a/blas/impl/KokkosBlas2_team_gemv_spec.hpp +++ /dev/null @@ -1,245 +0,0 @@ -/* -//@HEADER -// ************************************************************************ -// -// Kokkos v. 3.0 -// Copyright (2020) National Technology & Engineering -// Solutions of Sandia, LLC (NTESS). -// -// Under the terms of Contract DE-NA0003525 with NTESS, -// the U.S. Government retains certain rights in this software. -// -// Redistribution and use in source and binary forms, with or without -// modification, are permitted provided that the following conditions are -// met: -// -// 1. Redistributions of source code must retain the above copyright -// notice, this list of conditions and the following disclaimer. -// -// 2. Redistributions in binary form must reproduce the above copyright -// notice, this list of conditions and the following disclaimer in the -// documentation and/or other materials provided with the distribution. -// -// 3. Neither the name of the Corporation nor the names of the -// contributors may be used to endorse or promote products derived from -// this software without specific prior written permission. -// -// THIS SOFTWARE IS PROVIDED BY NTESS "AS IS" AND ANY -// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR -// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NTESS OR THE -// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, -// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, -// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR -// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF -// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING -// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS -// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -// -// Questions? Contact Siva Rajamanickam (srajama@sandia.gov) -// -// ************************************************************************ -//@HEADER -*/ - -#ifndef KOKKOSBLAS2_TEAM_GEMV_SPEC_HPP_ -#define KOKKOSBLAS2_TEAM_GEMV_SPEC_HPP_ - -#include -#include -#include -#include -#include - -namespace KokkosBlas { - -template -struct TeamGemv { - template - KOKKOS_INLINE_FUNCTION static int invoke(const MemberType& /*member*/, - const ScalarType /*alpha*/, - const AViewType& /*A*/, - const xViewType& /*x*/, - const ScalarType /*beta*/, - const yViewType& /*y*/); -}; - -template -struct TeamVectorGemv { - template - KOKKOS_INLINE_FUNCTION static int invoke(const MemberType& /*member*/, - const ScalarType /*alpha*/, - const AViewType& /*A*/, - const xViewType& /*x*/, - const ScalarType /*beta*/, - const yViewType& /*y*/); -}; - -/// -/// NT -/// - -template -struct TeamGemv { - template - KOKKOS_INLINE_FUNCTION static int invoke( - const MemberType& member, const ScalarType alpha, const AViewType& A, - const xViewType& x, const ScalarType beta, const yViewType& y) { - static_assert(AViewType::Rank == 2, - "KokkosBlas::TeamGemv requires rank-2 A matrix"); - return Impl::TeamGemvInternal::invoke( - member, A.extent(0), A.extent(1), alpha, A.data(), A.stride_0(), - A.stride_1(), x.data(), x.stride_0(), beta, y.data(), y.stride_0()); - } -}; - -template -struct TeamGemv { - template - KOKKOS_INLINE_FUNCTION static int invoke( - const MemberType& member, const ScalarType alpha, const AViewType& A, - const xViewType& x, const ScalarType beta, const yViewType& y) { - static_assert(AViewType::Rank == 2, - "KokkosBlas::TeamGemv requires rank-2 A matrix"); - return Impl::TeamGemvInternal::invoke( - member, A.extent(0), A.extent(1), alpha, A.data(), A.stride_0(), - A.stride_1(), x.data(), x.stride_0(), beta, y.data(), y.stride_0()); - } -}; - -/// -/// T -/// - -template -struct TeamGemv { - template - KOKKOS_INLINE_FUNCTION static int invoke( - const MemberType& member, const ScalarType alpha, const AViewType& A, - const xViewType& x, const ScalarType beta, const yViewType& y) { - static_assert(AViewType::Rank == 2, - "BLAS TeamGemv requires rank-2 A matrix"); - return Impl::TeamGemvInternal::invoke( - member, A.extent(1), A.extent(0), alpha, A.data(), A.stride_1(), - A.stride_0(), x.data(), x.stride_0(), beta, y.data(), y.stride_0()); - } -}; - -template -struct TeamGemv { - template - KOKKOS_INLINE_FUNCTION static int invoke( - const MemberType& member, const ScalarType alpha, const AViewType& A, - const xViewType& x, const ScalarType beta, const yViewType& y) { - static_assert(AViewType::Rank == 2, - "BLAS TeamGemv requires rank-2 A matrix"); - return Impl::TeamGemvInternal::invoke( - member, A.extent(1), A.extent(0), alpha, A.data(), A.stride_1(), - A.stride_0(), x.data(), x.stride_0(), beta, y.data(), y.stride_0()); - } -}; - -/// -/// CT -/// - -template -struct TeamGemv { - template - KOKKOS_INLINE_FUNCTION static int invoke( - const MemberType& member, const ScalarType alpha, const AViewType& A, - const xViewType& x, const ScalarType beta, const yViewType& y) { - static_assert(AViewType::Rank == 2, - "BLAS TeamGemv requires rank-2 A matrix"); - return Impl::TeamGemvInternal::invoke( - member, Impl::OpConj{}, A.extent(1), A.extent(0), alpha, A.data(), - A.stride_1(), A.stride_0(), x.data(), x.stride_0(), beta, y.data(), - y.stride_0()); - } -}; - -template -struct TeamGemv { - template - KOKKOS_INLINE_FUNCTION static int invoke( - const MemberType& member, const ScalarType alpha, const AViewType& A, - const xViewType& x, const ScalarType beta, const yViewType& y) { - static_assert(AViewType::Rank == 2, - "BLAS TeamGemv requires rank-2 A matrix"); - return Impl::TeamGemvInternal::invoke( - member, Impl::OpConj{}, A.extent(1), A.extent(0), alpha, A.data(), - A.stride_1(), A.stride_0(), x.data(), x.stride_0(), beta, y.data(), - y.stride_0()); - } -}; - -/// -/// NT -/// - -template -struct TeamVectorGemv { - template - KOKKOS_INLINE_FUNCTION static int invoke( - const MemberType& member, const ScalarType alpha, const AViewType& A, - const xViewType& x, const ScalarType beta, const yViewType& y) { - static_assert(AViewType::Rank == 2, - "Batched TeamVectorGemv requires rank-2 A matrix"); - return Impl::TeamVectorGemvInternal::invoke( - member, A.extent(0), A.extent(1), alpha, A.data(), A.stride_0(), - A.stride_1(), x.data(), x.stride_0(), beta, y.data(), y.stride_0()); - } -}; - -/// -/// T -/// - -template -struct TeamVectorGemv { - template - KOKKOS_INLINE_FUNCTION static int invoke( - const MemberType& member, const ScalarType alpha, const AViewType& A, - const xViewType& x, const ScalarType beta, const yViewType& y) { - static_assert(AViewType::Rank == 2, - "Batched TeamVectorGemv requires rank-2 A matrix"); - return Impl::TeamVectorGemvInternal::invoke( - member, A.extent(1), A.extent(0), alpha, A.data(), A.stride_1(), - A.stride_0(), x.data(), x.stride_0(), beta, y.data(), y.stride_0()); - } -}; - -/// -/// CT -/// - -template -struct TeamVectorGemv { - template - KOKKOS_INLINE_FUNCTION static int invoke( - const MemberType& member, const ScalarType alpha, const AViewType& A, - const xViewType& x, const ScalarType beta, const yViewType& y) { - static_assert(AViewType::Rank == 2, - "Batched TeamVectorGemv requires rank-2 A matrix"); - return Impl::TeamVectorGemvInternal::invoke( - member, Impl::OpConj{}, A.extent(1), A.extent(0), alpha, A.data(), - A.stride_1(), A.stride_0(), x.data(), x.stride_0(), beta, y.data(), - y.stride_0()); - } -}; - -} // namespace KokkosBlas - -#endif diff --git a/blas/src/KokkosBlas2_serial_gemv.hpp b/blas/impl/KokkosBlas3_serial_gemm_impl.hpp similarity index 55% rename from blas/src/KokkosBlas2_serial_gemv.hpp rename to blas/impl/KokkosBlas3_serial_gemm_impl.hpp index cb568095b2..b367bfc279 100644 --- a/blas/src/KokkosBlas2_serial_gemv.hpp +++ b/blas/impl/KokkosBlas3_serial_gemm_impl.hpp @@ -42,47 +42,46 @@ //@HEADER */ -#ifndef KOKKOSBLAS2_SERIAL_GEMV_HPP_ -#define KOKKOSBLAS2_SERIAL_GEMV_HPP_ +#ifndef KOKKOSBLAS3_SERIAL_GEMM_IMPL_HPP_ +#define KOKKOSBLAS3_SERIAL_GEMM_IMPL_HPP_ -#include "KokkosBlas2_serial_gemv_impl.hpp" +#include "KokkosBlas3_serial_gemm_internal.hpp" +#include "KokkosBlas3_serial_gemm_tpl_spec_decl.hpp" #include "KokkosBlas_util.hpp" namespace KokkosBlas { -namespace Experimental { -template -void KOKKOS_INLINE_FUNCTION serial_gemv(const char trans, - const ScalarType& alpha, - const MatrixType& A, const XVector& x, - const ScalarType& beta, - const YVector& y) { - if (trans == 'N' || trans == 'n') { - using mode = KokkosBlas::Trans::NoTranspose; - KokkosBlas::SerialGemv::invoke(alpha, A, x, beta, y); - } else if (trans == 'T' || trans == 't') { - using mode = KokkosBlas::Trans::Transpose; - KokkosBlas::SerialGemv::invoke(alpha, A, x, beta, y); - } else if (trans == 'C' || trans == 'c') { - using mode = KokkosBlas::Trans::ConjTranspose; - KokkosBlas::SerialGemv::invoke(alpha, A, x, beta, y); - } else { - Kokkos::abort("Matrix mode not supported"); - } -} +/// +/// Serial Impl +/// =========== -// default AlgoTag -template -void KOKKOS_INLINE_FUNCTION serial_gemv(const char trans, - const ScalarType& alpha, - const MatrixType& A, const XVector& x, - const ScalarType& beta, - const YVector& y) { - serial_gemv(trans, alpha, A, x, beta, y); -} +template +template +KOKKOS_INLINE_FUNCTION int SerialGemm::invoke( + const ScalarType alpha, const AViewType &A, const BViewType &B, + const ScalarType beta, const CViewType &C) { + // C = beta C + alpha opA(A) opB(B) + // C (m x n), A(m x k), B(k x n) + static_assert(std::is_same::value || + std::is_same::value || + std::is_same::value, + "Algorithm not supported"); -} // namespace Experimental + using OpA = typename Impl::MatrixModeInfo::Op; + using OpB = typename Impl::MatrixModeInfo::Op; + using TransA = Impl::MatrixModeInfo; + using TransB = Impl::MatrixModeInfo; + const auto ae1 = TransA::extent(A, 1); + const auto as0 = TransA::stride_0(A); + const auto as1 = TransA::stride_1(A); + const auto bs0 = TransB::stride_0(B); + const auto bs1 = TransB::stride_1(B); + + return Impl::SerialGemmInternal::invoke( + OpA{}, OpB{}, C.extent(0), C.extent(1), ae1, alpha, A.data(), as0, as1, + B.data(), bs0, bs1, beta, C.data(), C.stride_0(), C.stride_1()); +} } // namespace KokkosBlas #endif diff --git a/batched/dense/impl/KokkosBatched_InnerGemmFixA_Serial_Impl.hpp b/blas/impl/KokkosBlas3_serial_gemm_inner_fixa_impl.hpp similarity index 95% rename from batched/dense/impl/KokkosBatched_InnerGemmFixA_Serial_Impl.hpp rename to blas/impl/KokkosBlas3_serial_gemm_inner_fixa_impl.hpp index 17ea32aa74..1ee0c3d463 100644 --- a/batched/dense/impl/KokkosBatched_InnerGemmFixA_Serial_Impl.hpp +++ b/blas/impl/KokkosBlas3_serial_gemm_inner_fixa_impl.hpp @@ -1,12 +1,53 @@ -#ifndef __KOKKOSBATCHED_INNER_GEMM_FIX_A_SERIAL_IMPL_HPP__ -#define __KOKKOSBATCHED_INNER_GEMM_FIX_A_SERIAL_IMPL_HPP__ +/* +//@HEADER +// ************************************************************************ +// +// Kokkos v. 3.0 +// Copyright (2020) National Technology & Engineering +// Solutions of Sandia, LLC (NTESS). +// +// Under the terms of Contract DE-NA0003525 with NTESS, +// the U.S. Government retains certain rights in this software. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// 1. Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of the Corporation nor the names of the +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY NTESS "AS IS" AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NTESS OR THE +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// Questions? Contact Siva Rajamanickam (srajama@sandia.gov) +// +// ************************************************************************ +//@HEADER +*/ + +#ifndef KOKKOSBLAS3_INNER_GEMM_FIXA_SERIAL_IMPL_HPP +#define KOKKOSBLAS3_INNER_GEMM_FIXA_SERIAL_IMPL_HPP /// \author Kyungjoo Kim (kyukim@sandia.gov) -#include "KokkosBatched_Util.hpp" -#include "KokkosBatched_InnerGemmFixA_Decl.hpp" - -namespace KokkosBatched { +namespace KokkosBlas { /// /// Inner kernel (5x5) @@ -1305,6 +1346,6 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixA<1, 1>::serial_invoke( return 0; } -} // namespace KokkosBatched +} // namespace KokkosBlas #endif diff --git a/batched/dense/impl/KokkosBatched_InnerGemmFixB_Serial_Impl.hpp b/blas/impl/KokkosBlas3_serial_gemm_inner_fixb_impl.hpp similarity index 94% rename from batched/dense/impl/KokkosBatched_InnerGemmFixB_Serial_Impl.hpp rename to blas/impl/KokkosBlas3_serial_gemm_inner_fixb_impl.hpp index b948b115f8..c33724dccc 100644 --- a/batched/dense/impl/KokkosBatched_InnerGemmFixB_Serial_Impl.hpp +++ b/blas/impl/KokkosBlas3_serial_gemm_inner_fixb_impl.hpp @@ -1,12 +1,53 @@ -#ifndef __KOKKOSBATCHED_INNER_GEMM_FIX_B_SERIAL_IMPL_HPP__ -#define __KOKKOSBATCHED_INNER_GEMM_FIX_B_SERIAL_IMPL_HPP__ +/* +//@HEADER +// ************************************************************************ +// +// Kokkos v. 3.0 +// Copyright (2020) National Technology & Engineering +// Solutions of Sandia, LLC (NTESS). +// +// Under the terms of Contract DE-NA0003525 with NTESS, +// the U.S. Government retains certain rights in this software. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// 1. Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of the Corporation nor the names of the +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY NTESS "AS IS" AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NTESS OR THE +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// Questions? Contact Siva Rajamanickam (srajama@sandia.gov) +// +// ************************************************************************ +//@HEADER +*/ + +#ifndef KOKKOSBLAS3_INNER_GEMM_FIXB_SERIAL_IMPL_HPP +#define KOKKOSBLAS3_INNER_GEMM_FIXB_SERIAL_IMPL_HPP /// \author Kyungjoo Kim (kyukim@sandia.gov) -#include "KokkosBatched_Util.hpp" -#include "KokkosBatched_InnerGemmFixB_Decl.hpp" - -namespace KokkosBatched { +namespace KokkosBlas { /// /// Inner kernel (5x5) @@ -1272,6 +1313,7 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixB<0, 0>::serial_invoke( } return 0; } -} // namespace KokkosBatched + +} // namespace KokkosBlas #endif diff --git a/batched/dense/impl/KokkosBatched_InnerGemmFixC_Serial_Impl.hpp b/blas/impl/KokkosBlas3_serial_gemm_inner_fixc_impl.hpp similarity index 68% rename from batched/dense/impl/KokkosBatched_InnerGemmFixC_Serial_Impl.hpp rename to blas/impl/KokkosBlas3_serial_gemm_inner_fixc_impl.hpp index 247f232dce..b235249c80 100644 --- a/batched/dense/impl/KokkosBatched_InnerGemmFixC_Serial_Impl.hpp +++ b/blas/impl/KokkosBlas3_serial_gemm_inner_fixc_impl.hpp @@ -1,22 +1,64 @@ -#ifndef __KOKKOSBATCHED_INNER_GEMM_FIX_C_SERIAL_IMPL_HPP__ -#define __KOKKOSBATCHED_INNER_GEMM_FIX_C_SERIAL_IMPL_HPP__ +/* +//@HEADER +// ************************************************************************ +// +// Kokkos v. 3.0 +// Copyright (2020) National Technology & Engineering +// Solutions of Sandia, LLC (NTESS). +// +// Under the terms of Contract DE-NA0003525 with NTESS, +// the U.S. Government retains certain rights in this software. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// 1. Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of the Corporation nor the names of the +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY NTESS "AS IS" AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NTESS OR THE +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// Questions? Contact Siva Rajamanickam (srajama@sandia.gov) +// +// ************************************************************************ +//@HEADER +*/ + +#ifndef KOKKOSBLAS3_INNER_GEMM_FIXC_SERIAL_IMPL_HPP +#define KOKKOSBLAS3_INNER_GEMM_FIXC_SERIAL_IMPL_HPP /// \author Kyungjoo Kim (kyukim@sandia.gov) -#include "KokkosBatched_Util.hpp" -#include "KokkosBatched_InnerGemmFixC_Decl.hpp" - -namespace KokkosBatched { +namespace KokkosBlas { /// /// Inner kernel (5x5) /// ================== template <> -template +template KOKKOS_INLINE_FUNCTION int InnerGemmFixC<5, 5>::serial_invoke( - const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, - const ValueType *KOKKOS_RESTRICT B, const int k, + OpA opA, OpB opB, const ScalarType alpha, + const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, + const int k, /**/ ValueType *KOKKOS_RESTRICT C) { if (k <= 0) return 0; @@ -35,16 +77,16 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<5, 5>::serial_invoke( #pragma unroll #endif for (int p = 0; p < k; ++p) { - a_0p = A[i0 + p * _as1]; - b_p0 = B[p * _bs0 + j0]; - a_1p = A[i1 + p * _as1]; - b_p1 = B[p * _bs0 + j1]; - a_2p = A[i2 + p * _as1]; - b_p2 = B[p * _bs0 + j2]; - a_3p = A[i3 + p * _as1]; - b_p3 = B[p * _bs0 + j3]; - a_4p = A[i4 + p * _as1]; - b_p4 = B[p * _bs0 + j4]; + a_0p = opA(A[i0 + p * _as1]); + b_p0 = opB(B[p * _bs0 + j0]); + a_1p = opA(A[i1 + p * _as1]); + b_p1 = opB(B[p * _bs0 + j1]); + a_2p = opA(A[i2 + p * _as1]); + b_p2 = opB(B[p * _bs0 + j2]); + a_3p = opA(A[i3 + p * _as1]); + b_p3 = opB(B[p * _bs0 + j3]); + a_4p = opA(A[i4 + p * _as1]); + b_p4 = opB(B[p * _bs0 + j4]); c_00 += a_0p * b_p0; c_01 += a_0p * b_p1; @@ -103,10 +145,11 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<5, 5>::serial_invoke( } template <> -template +template KOKKOS_INLINE_FUNCTION int InnerGemmFixC<5, 4>::serial_invoke( - const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, - const ValueType *KOKKOS_RESTRICT B, const int k, + OpA opA, OpB opB, const ScalarType alpha, + const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, + const int k, /**/ ValueType *KOKKOS_RESTRICT C) { if (k <= 0) return 0; @@ -124,15 +167,15 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<5, 4>::serial_invoke( #pragma unroll #endif for (int p = 0; p < k; ++p) { - a_0p = A[i0 + p * _as1]; - b_p0 = B[p * _bs0 + j0]; - a_1p = A[i1 + p * _as1]; - b_p1 = B[p * _bs0 + j1]; - a_2p = A[i2 + p * _as1]; - b_p2 = B[p * _bs0 + j2]; - a_3p = A[i3 + p * _as1]; - b_p3 = B[p * _bs0 + j3]; - a_4p = A[i4 + p * _as1]; + a_0p = opA(A[i0 + p * _as1]); + b_p0 = opB(B[p * _bs0 + j0]); + a_1p = opA(A[i1 + p * _as1]); + b_p1 = opB(B[p * _bs0 + j1]); + a_2p = opA(A[i2 + p * _as1]); + b_p2 = opB(B[p * _bs0 + j2]); + a_3p = opA(A[i3 + p * _as1]); + b_p3 = opB(B[p * _bs0 + j3]); + a_4p = opA(A[i4 + p * _as1]); c_00 += a_0p * b_p0; c_01 += a_0p * b_p1; @@ -181,10 +224,11 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<5, 4>::serial_invoke( } template <> -template +template KOKKOS_INLINE_FUNCTION int InnerGemmFixC<5, 3>::serial_invoke( - const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, - const ValueType *KOKKOS_RESTRICT B, const int k, + OpA opA, OpB opB, const ScalarType alpha, + const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, + const int k, /**/ ValueType *KOKKOS_RESTRICT C) { if (k <= 0) return 0; @@ -200,14 +244,14 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<5, 3>::serial_invoke( #pragma unroll #endif for (int p = 0; p < k; ++p) { - a_0p = A[i0 + p * _as1]; - b_p0 = B[p * _bs0 + j0]; - a_1p = A[i1 + p * _as1]; - b_p1 = B[p * _bs0 + j1]; - a_2p = A[i2 + p * _as1]; - b_p2 = B[p * _bs0 + j2]; - a_3p = A[i3 + p * _as1]; - a_4p = A[i4 + p * _as1]; + a_0p = opA(A[i0 + p * _as1]); + b_p0 = opB(B[p * _bs0 + j0]); + a_1p = opA(A[i1 + p * _as1]); + b_p1 = opB(B[p * _bs0 + j1]); + a_2p = opA(A[i2 + p * _as1]); + b_p2 = opB(B[p * _bs0 + j2]); + a_3p = opA(A[i3 + p * _as1]); + a_4p = opA(A[i4 + p * _as1]); c_00 += a_0p * b_p0; c_01 += a_0p * b_p1; @@ -246,10 +290,11 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<5, 3>::serial_invoke( } template <> -template +template KOKKOS_INLINE_FUNCTION int InnerGemmFixC<5, 2>::serial_invoke( - const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, - const ValueType *KOKKOS_RESTRICT B, const int k, + OpA opA, OpB opB, const ScalarType alpha, + const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, + const int k, /**/ ValueType *KOKKOS_RESTRICT C) { if (k <= 0) return 0; @@ -264,13 +309,13 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<5, 2>::serial_invoke( #pragma unroll #endif for (int p = 0; p < k; ++p) { - a_0p = A[i0 + p * _as1]; - b_p0 = B[p * _bs0 + j0]; - a_1p = A[i1 + p * _as1]; - b_p1 = B[p * _bs0 + j1]; - a_2p = A[i2 + p * _as1]; - a_3p = A[i3 + p * _as1]; - a_4p = A[i4 + p * _as1]; + a_0p = opA(A[i0 + p * _as1]); + b_p0 = opB(B[p * _bs0 + j0]); + a_1p = opA(A[i1 + p * _as1]); + b_p1 = opB(B[p * _bs0 + j1]); + a_2p = opA(A[i2 + p * _as1]); + a_3p = opA(A[i3 + p * _as1]); + a_4p = opA(A[i4 + p * _as1]); c_00 += a_0p * b_p0; c_01 += a_0p * b_p1; @@ -299,10 +344,11 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<5, 2>::serial_invoke( } template <> -template +template KOKKOS_INLINE_FUNCTION int InnerGemmFixC<5, 1>::serial_invoke( - const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, - const ValueType *KOKKOS_RESTRICT B, const int k, + OpA opA, OpB opB, const ScalarType alpha, + const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, + const int k, /**/ ValueType *KOKKOS_RESTRICT C) { if (k <= 0) return 0; @@ -316,12 +362,12 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<5, 1>::serial_invoke( #pragma unroll #endif for (int p = 0; p < k; ++p) { - a_0p = A[i0 + p * _as1]; - b_p0 = B[p * _bs0 + j0]; - a_1p = A[i1 + p * _as1]; - a_2p = A[i2 + p * _as1]; - a_3p = A[i3 + p * _as1]; - a_4p = A[i4 + p * _as1]; + a_0p = opA(A[i0 + p * _as1]); + b_p0 = opB(B[p * _bs0 + j0]); + a_1p = opA(A[i1 + p * _as1]); + a_2p = opA(A[i2 + p * _as1]); + a_3p = opA(A[i3 + p * _as1]); + a_4p = opA(A[i4 + p * _as1]); c_00 += a_0p * b_p0; c_10 += a_1p * b_p0; @@ -340,10 +386,11 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<5, 1>::serial_invoke( } template <> -template +template KOKKOS_INLINE_FUNCTION int InnerGemmFixC<4, 5>::serial_invoke( - const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, - const ValueType *KOKKOS_RESTRICT B, const int k, + OpA opA, OpB opB, const ScalarType alpha, + const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, + const int k, /**/ ValueType *KOKKOS_RESTRICT C) { if (k <= 0) return 0; @@ -362,15 +409,15 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<4, 5>::serial_invoke( #pragma unroll #endif for (int p = 0; p < k; ++p) { - a_0p = A[i0 + p * _as1]; - b_p0 = B[p * _bs0 + j0]; - a_1p = A[i1 + p * _as1]; - b_p1 = B[p * _bs0 + j1]; - a_2p = A[i2 + p * _as1]; - b_p2 = B[p * _bs0 + j2]; - a_3p = A[i3 + p * _as1]; - b_p3 = B[p * _bs0 + j3]; - /**/ b_p4 = B[p * _bs0 + j4]; + a_0p = opA(A[i0 + p * _as1]); + b_p0 = opB(B[p * _bs0 + j0]); + a_1p = opA(A[i1 + p * _as1]); + b_p1 = opB(B[p * _bs0 + j1]); + a_2p = opA(A[i2 + p * _as1]); + b_p2 = opB(B[p * _bs0 + j2]); + a_3p = opA(A[i3 + p * _as1]); + b_p3 = opB(B[p * _bs0 + j3]); + /**/ b_p4 = opB(B[p * _bs0 + j4]); c_00 += a_0p * b_p0; c_01 += a_0p * b_p1; @@ -419,10 +466,11 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<4, 5>::serial_invoke( } template <> -template +template KOKKOS_INLINE_FUNCTION int InnerGemmFixC<3, 5>::serial_invoke( - const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, - const ValueType *KOKKOS_RESTRICT B, const int k, + OpA opA, OpB opB, const ScalarType alpha, + const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, + const int k, /**/ ValueType *KOKKOS_RESTRICT C) { if (k <= 0) return 0; @@ -440,14 +488,14 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<3, 5>::serial_invoke( #pragma unroll #endif for (int p = 0; p < k; ++p) { - a_0p = A[i0 + p * _as1]; - b_p0 = B[p * _bs0 + j0]; - a_1p = A[i1 + p * _as1]; - b_p1 = B[p * _bs0 + j1]; - a_2p = A[i2 + p * _as1]; - b_p2 = B[p * _bs0 + j2]; - /**/ b_p3 = B[p * _bs0 + j3]; - /**/ b_p4 = B[p * _bs0 + j4]; + a_0p = opA(A[i0 + p * _as1]); + b_p0 = opB(B[p * _bs0 + j0]); + a_1p = opA(A[i1 + p * _as1]); + b_p1 = opB(B[p * _bs0 + j1]); + a_2p = opA(A[i2 + p * _as1]); + b_p2 = opB(B[p * _bs0 + j2]); + /**/ b_p3 = opB(B[p * _bs0 + j3]); + /**/ b_p4 = opB(B[p * _bs0 + j4]); c_00 += a_0p * b_p0; c_01 += a_0p * b_p1; @@ -486,10 +534,11 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<3, 5>::serial_invoke( } template <> -template +template KOKKOS_INLINE_FUNCTION int InnerGemmFixC<2, 5>::serial_invoke( - const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, - const ValueType *KOKKOS_RESTRICT B, const int k, + OpA opA, OpB opB, const ScalarType alpha, + const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, + const int k, /**/ ValueType *KOKKOS_RESTRICT C) { if (k <= 0) return 0; @@ -506,13 +555,13 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<2, 5>::serial_invoke( #pragma unroll #endif for (int p = 0; p < k; ++p) { - a_0p = A[i0 + p * _as1]; - b_p0 = B[p * _bs0 + j0]; - a_1p = A[i1 + p * _as1]; - b_p1 = B[p * _bs0 + j1]; - /**/ b_p2 = B[p * _bs0 + j2]; - /**/ b_p3 = B[p * _bs0 + j3]; - /**/ b_p4 = B[p * _bs0 + j4]; + a_0p = opA(A[i0 + p * _as1]); + b_p0 = opB(B[p * _bs0 + j0]); + a_1p = opA(A[i1 + p * _as1]); + b_p1 = opB(B[p * _bs0 + j1]); + /**/ b_p2 = opB(B[p * _bs0 + j2]); + /**/ b_p3 = opB(B[p * _bs0 + j3]); + /**/ b_p4 = opB(B[p * _bs0 + j4]); c_00 += a_0p * b_p0; c_01 += a_0p * b_p1; @@ -541,10 +590,11 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<2, 5>::serial_invoke( } template <> -template +template KOKKOS_INLINE_FUNCTION int InnerGemmFixC<1, 5>::serial_invoke( - const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, - const ValueType *KOKKOS_RESTRICT B, const int k, + OpA opA, OpB opB, const ScalarType alpha, + const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, + const int k, /**/ ValueType *KOKKOS_RESTRICT C) { if (k <= 0) return 0; @@ -561,12 +611,12 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<1, 5>::serial_invoke( #pragma unroll #endif for (int p = 0; p < k; ++p) { - a_0p = A[i0 + p * _as1]; - b_p0 = B[p * _bs0 + j0]; - /**/ b_p1 = B[p * _bs0 + j1]; - /**/ b_p2 = B[p * _bs0 + j2]; - /**/ b_p3 = B[p * _bs0 + j3]; - /**/ b_p4 = B[p * _bs0 + j4]; + a_0p = opA(A[i0 + p * _as1]); + b_p0 = opB(B[p * _bs0 + j0]); + /**/ b_p1 = opB(B[p * _bs0 + j1]); + /**/ b_p2 = opB(B[p * _bs0 + j2]); + /**/ b_p3 = opB(B[p * _bs0 + j3]); + /**/ b_p4 = opB(B[p * _bs0 + j4]); c_00 += a_0p * b_p0; c_01 += a_0p * b_p1; @@ -588,10 +638,11 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<1, 5>::serial_invoke( /// ================== template <> -template +template KOKKOS_INLINE_FUNCTION int InnerGemmFixC<4, 4>::serial_invoke( - const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, - const ValueType *KOKKOS_RESTRICT B, const int k, + OpA opA, OpB opB, const ScalarType alpha, + const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, + const int k, /**/ ValueType *KOKKOS_RESTRICT C) { if (k <= 0) return 0; @@ -610,14 +661,14 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<4, 4>::serial_invoke( #pragma unroll #endif for (int p = 0; p < k; ++p) { - a_0p = A[i0 + p * _as1]; - b_p0 = B[p * _bs0 + j0]; - a_1p = A[i1 + p * _as1]; - b_p1 = B[p * _bs0 + j1]; - a_2p = A[i2 + p * _as1]; - b_p2 = B[p * _bs0 + j2]; - a_3p = A[i3 + p * _as1]; - b_p3 = B[p * _bs0 + j3]; + a_0p = opA(A[i0 + p * _as1]); + b_p0 = opB(B[p * _bs0 + j0]); + a_1p = opA(A[i1 + p * _as1]); + b_p1 = opB(B[p * _bs0 + j1]); + a_2p = opA(A[i2 + p * _as1]); + b_p2 = opB(B[p * _bs0 + j2]); + a_3p = opA(A[i3 + p * _as1]); + b_p3 = opB(B[p * _bs0 + j3]); c_00 += a_0p * b_p0; c_01 += a_0p * b_p1; @@ -658,10 +709,11 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<4, 4>::serial_invoke( } template <> -template +template KOKKOS_INLINE_FUNCTION int InnerGemmFixC<4, 3>::serial_invoke( - const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, - const ValueType *KOKKOS_RESTRICT B, const int k, + OpA opA, OpB opB, const ScalarType alpha, + const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, + const int k, /**/ ValueType *KOKKOS_RESTRICT C) { if (k <= 0) return 0; @@ -678,13 +730,13 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<4, 3>::serial_invoke( #pragma unroll #endif for (int p = 0; p < k; ++p) { - a_0p = A[i0 + p * _as1]; - b_p0 = B[p * _bs0 + j0]; - a_1p = A[i1 + p * _as1]; - b_p1 = B[p * _bs0 + j1]; - a_2p = A[i2 + p * _as1]; - b_p2 = B[p * _bs0 + j2]; - a_3p = A[i3 + p * _as1]; + a_0p = opA(A[i0 + p * _as1]); + b_p0 = opB(B[p * _bs0 + j0]); + a_1p = opA(A[i1 + p * _as1]); + b_p1 = opB(B[p * _bs0 + j1]); + a_2p = opA(A[i2 + p * _as1]); + b_p2 = opB(B[p * _bs0 + j2]); + a_3p = opA(A[i3 + p * _as1]); c_00 += a_0p * b_p0; c_01 += a_0p * b_p1; @@ -717,10 +769,11 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<4, 3>::serial_invoke( } template <> -template +template KOKKOS_INLINE_FUNCTION int InnerGemmFixC<4, 2>::serial_invoke( - const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, - const ValueType *KOKKOS_RESTRICT B, const int k, + OpA opA, OpB opB, const ScalarType alpha, + const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, + const int k, /**/ ValueType *KOKKOS_RESTRICT C) { if (k <= 0) return 0; @@ -736,12 +789,12 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<4, 2>::serial_invoke( #pragma unroll #endif for (int p = 0; p < k; ++p) { - a_0p = A[i0 + p * _as1]; - b_p0 = B[p * _bs0 + j0]; - a_1p = A[i1 + p * _as1]; - b_p1 = B[p * _bs0 + j1]; - a_2p = A[i2 + p * _as1]; - a_3p = A[i3 + p * _as1]; + a_0p = opA(A[i0 + p * _as1]); + b_p0 = opB(B[p * _bs0 + j0]); + a_1p = opA(A[i1 + p * _as1]); + b_p1 = opB(B[p * _bs0 + j1]); + a_2p = opA(A[i2 + p * _as1]); + a_3p = opA(A[i3 + p * _as1]); c_00 += a_0p * b_p0; c_01 += a_0p * b_p1; @@ -766,10 +819,11 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<4, 2>::serial_invoke( } template <> -template +template KOKKOS_INLINE_FUNCTION int InnerGemmFixC<4, 1>::serial_invoke( - const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, - const ValueType *KOKKOS_RESTRICT B, const int k, + OpA opA, OpB opB, const ScalarType alpha, + const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, + const int k, /**/ ValueType *KOKKOS_RESTRICT C) { if (k <= 0) return 0; @@ -783,11 +837,11 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<4, 1>::serial_invoke( #pragma unroll #endif for (int p = 0; p < k; ++p) { - a_0p = A[i0 + p * _as1]; - b_p0 = B[p * _bs0 + j0]; - a_1p = A[i1 + p * _as1]; - a_2p = A[i2 + p * _as1]; - a_3p = A[i3 + p * _as1]; + a_0p = opA(A[i0 + p * _as1]); + b_p0 = opB(B[p * _bs0 + j0]); + a_1p = opA(A[i1 + p * _as1]); + a_2p = opA(A[i2 + p * _as1]); + a_3p = opA(A[i3 + p * _as1]); c_00 += a_0p * b_p0; c_10 += a_1p * b_p0; @@ -804,10 +858,11 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<4, 1>::serial_invoke( } template <> -template +template KOKKOS_INLINE_FUNCTION int InnerGemmFixC<3, 4>::serial_invoke( - const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, - const ValueType *KOKKOS_RESTRICT B, const int k, + OpA opA, OpB opB, const ScalarType alpha, + const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, + const int k, /**/ ValueType *KOKKOS_RESTRICT C) { if (k <= 0) return 0; @@ -825,13 +880,13 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<3, 4>::serial_invoke( #pragma unroll #endif for (int p = 0; p < k; ++p) { - a_0p = A[i0 + p * _as1]; - b_p0 = B[p * _bs0 + j0]; - a_1p = A[i1 + p * _as1]; - b_p1 = B[p * _bs0 + j1]; - a_2p = A[i2 + p * _as1]; - b_p2 = B[p * _bs0 + j2]; - /**/ b_p3 = B[p * _bs0 + j3]; + a_0p = opA(A[i0 + p * _as1]); + b_p0 = opB(B[p * _bs0 + j0]); + a_1p = opA(A[i1 + p * _as1]); + b_p1 = opB(B[p * _bs0 + j1]); + a_2p = opA(A[i2 + p * _as1]); + b_p2 = opB(B[p * _bs0 + j2]); + /**/ b_p3 = opB(B[p * _bs0 + j3]); c_00 += a_0p * b_p0; c_01 += a_0p * b_p1; @@ -864,10 +919,11 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<3, 4>::serial_invoke( } template <> -template +template KOKKOS_INLINE_FUNCTION int InnerGemmFixC<2, 4>::serial_invoke( - const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, - const ValueType *KOKKOS_RESTRICT B, const int k, + OpA opA, OpB opB, const ScalarType alpha, + const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, + const int k, /**/ ValueType *KOKKOS_RESTRICT C) { if (k <= 0) return 0; @@ -885,12 +941,12 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<2, 4>::serial_invoke( #pragma unroll #endif for (int p = 0; p < k; ++p) { - a_0p = A[i0 + p * _as1]; - b_p0 = B[p * _bs0 + j0]; - a_1p = A[i1 + p * _as1]; - b_p1 = B[p * _bs0 + j1]; - /**/ b_p2 = B[p * _bs0 + j2]; - /**/ b_p3 = B[p * _bs0 + j3]; + a_0p = opA(A[i0 + p * _as1]); + b_p0 = opB(B[p * _bs0 + j0]); + a_1p = opA(A[i1 + p * _as1]); + b_p1 = opB(B[p * _bs0 + j1]); + /**/ b_p2 = opB(B[p * _bs0 + j2]); + /**/ b_p3 = opB(B[p * _bs0 + j3]); c_00 += a_0p * b_p0; c_01 += a_0p * b_p1; @@ -915,10 +971,11 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<2, 4>::serial_invoke( } template <> -template +template KOKKOS_INLINE_FUNCTION int InnerGemmFixC<1, 4>::serial_invoke( - const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, - const ValueType *KOKKOS_RESTRICT B, const int k, + OpA opA, OpB opB, const ScalarType alpha, + const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, + const int k, /**/ ValueType *KOKKOS_RESTRICT C) { if (k <= 0) return 0; @@ -935,11 +992,11 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<1, 4>::serial_invoke( #pragma unroll #endif for (int p = 0; p < k; ++p) { - a_0p = A[i0 + p * _as1]; - b_p0 = B[p * _bs0 + j0]; - /**/ b_p1 = B[p * _bs0 + j1]; - /**/ b_p2 = B[p * _bs0 + j2]; - /**/ b_p3 = B[p * _bs0 + j3]; + a_0p = opA(A[i0 + p * _as1]); + b_p0 = opB(B[p * _bs0 + j0]); + /**/ b_p1 = opB(B[p * _bs0 + j1]); + /**/ b_p2 = opB(B[p * _bs0 + j2]); + /**/ b_p3 = opB(B[p * _bs0 + j3]); c_00 += a_0p * b_p0; c_01 += a_0p * b_p1; @@ -960,10 +1017,11 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<1, 4>::serial_invoke( /// ================== template <> -template +template KOKKOS_INLINE_FUNCTION int InnerGemmFixC<3, 3>::serial_invoke( - const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, - const ValueType *KOKKOS_RESTRICT B, const int k, + OpA opA, OpB opB, const ScalarType alpha, + const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, + const int k, /**/ ValueType *KOKKOS_RESTRICT C) { if (k <= 0) return 0; @@ -979,12 +1037,12 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<3, 3>::serial_invoke( #pragma unroll #endif for (int p = 0; p < k; ++p) { - a_0p = A[i0 + p * _as1]; - b_p0 = B[p * _bs0 + j0]; - a_1p = A[i1 + p * _as1]; - b_p1 = B[p * _bs0 + j1]; - a_2p = A[i2 + p * _as1]; - b_p2 = B[p * _bs0 + j2]; + a_0p = opA(A[i0 + p * _as1]); + b_p0 = opB(B[p * _bs0 + j0]); + a_1p = opA(A[i1 + p * _as1]); + b_p1 = opB(B[p * _bs0 + j1]); + a_2p = opA(A[i2 + p * _as1]); + b_p2 = opB(B[p * _bs0 + j2]); c_00 += a_0p * b_p0; c_01 += a_0p * b_p1; @@ -1011,10 +1069,11 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<3, 3>::serial_invoke( } template <> -template +template KOKKOS_INLINE_FUNCTION int InnerGemmFixC<3, 2>::serial_invoke( - const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, - const ValueType *KOKKOS_RESTRICT B, const int k, + OpA opA, OpB opB, const ScalarType alpha, + const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, + const int k, /**/ ValueType *KOKKOS_RESTRICT C) { if (k <= 0) return 0; @@ -1029,11 +1088,11 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<3, 2>::serial_invoke( #pragma unroll #endif for (int p = 0; p < k; ++p) { - a_0p = A[i0 + p * _as1]; - b_p0 = B[p * _bs0 + j0]; - a_1p = A[i1 + p * _as1]; - b_p1 = B[p * _bs0 + j1]; - a_2p = A[i2 + p * _as1]; + a_0p = opA(A[i0 + p * _as1]); + b_p0 = opB(B[p * _bs0 + j0]); + a_1p = opA(A[i1 + p * _as1]); + b_p1 = opB(B[p * _bs0 + j1]); + a_2p = opA(A[i2 + p * _as1]); c_00 += a_0p * b_p0; c_01 += a_0p * b_p1; @@ -1054,10 +1113,11 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<3, 2>::serial_invoke( } template <> -template +template KOKKOS_INLINE_FUNCTION int InnerGemmFixC<3, 1>::serial_invoke( - const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, - const ValueType *KOKKOS_RESTRICT B, const int k, + OpA opA, OpB opB, const ScalarType alpha, + const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, + const int k, /**/ ValueType *KOKKOS_RESTRICT C) { if (k <= 0) return 0; @@ -1070,10 +1130,10 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<3, 1>::serial_invoke( #pragma unroll #endif for (int p = 0; p < k; ++p) { - a_0p = A[i0 + p * _as1]; - b_p0 = B[p * _bs0 + j0]; - a_1p = A[i1 + p * _as1]; - a_2p = A[i2 + p * _as1]; + a_0p = opA(A[i0 + p * _as1]); + b_p0 = opB(B[p * _bs0 + j0]); + a_1p = opA(A[i1 + p * _as1]); + a_2p = opA(A[i2 + p * _as1]); c_00 += a_0p * b_p0; c_10 += a_1p * b_p0; @@ -1088,10 +1148,11 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<3, 1>::serial_invoke( } template <> -template +template KOKKOS_INLINE_FUNCTION int InnerGemmFixC<2, 3>::serial_invoke( - const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, - const ValueType *KOKKOS_RESTRICT B, const int k, + OpA opA, OpB opB, const ScalarType alpha, + const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, + const int k, /**/ ValueType *KOKKOS_RESTRICT C) { if (k <= 0) return 0; @@ -1107,11 +1168,11 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<2, 3>::serial_invoke( #pragma unroll #endif for (int p = 0; p < k; ++p) { - a_0p = A[i0 + p * _as1]; - b_p0 = B[p * _bs0 + j0]; - a_1p = A[i1 + p * _as1]; - b_p1 = B[p * _bs0 + j1]; - /**/ b_p2 = B[p * _bs0 + j2]; + a_0p = opA(A[i0 + p * _as1]); + b_p0 = opB(B[p * _bs0 + j0]); + a_1p = opA(A[i1 + p * _as1]); + b_p1 = opB(B[p * _bs0 + j1]); + /**/ b_p2 = opB(B[p * _bs0 + j2]); c_00 += a_0p * b_p0; c_01 += a_0p * b_p1; @@ -1131,10 +1192,11 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<2, 3>::serial_invoke( return 0; } template <> -template +template KOKKOS_INLINE_FUNCTION int InnerGemmFixC<1, 3>::serial_invoke( - const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, - const ValueType *KOKKOS_RESTRICT B, const int k, + OpA opA, OpB opB, const ScalarType alpha, + const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, + const int k, /**/ ValueType *KOKKOS_RESTRICT C) { if (k <= 0) return 0; @@ -1149,10 +1211,10 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<1, 3>::serial_invoke( #pragma unroll #endif for (int p = 0; p < k; ++p) { - a_0p = A[i0 + p * _as1]; - b_p0 = B[p * _bs0 + j0]; - /**/ b_p1 = B[p * _bs0 + j1]; - /**/ b_p2 = B[p * _bs0 + j2]; + a_0p = opA(A[i0 + p * _as1]); + b_p0 = opB(B[p * _bs0 + j0]); + /**/ b_p1 = opB(B[p * _bs0 + j1]); + /**/ b_p2 = opB(B[p * _bs0 + j2]); c_00 += a_0p * b_p0; c_01 += a_0p * b_p1; @@ -1171,10 +1233,11 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<1, 3>::serial_invoke( /// ================== template <> -template +template KOKKOS_INLINE_FUNCTION int InnerGemmFixC<2, 2>::serial_invoke( - const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, - const ValueType *KOKKOS_RESTRICT B, const int k, + OpA opA, OpB opB, const ScalarType alpha, + const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, + const int k, /**/ ValueType *KOKKOS_RESTRICT C) { if (k <= 0) return 0; @@ -1187,10 +1250,10 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<2, 2>::serial_invoke( #pragma unroll #endif for (int p = 0; p < k; ++p) { - a_0p = A[i0 + p * _as1]; - b_p0 = B[p * _bs0 + j0]; - a_1p = A[i1 + p * _as1]; - b_p1 = B[p * _bs0 + j1]; + a_0p = opA(A[i0 + p * _as1]); + b_p0 = opB(B[p * _bs0 + j0]); + a_1p = opA(A[i1 + p * _as1]); + b_p1 = opB(B[p * _bs0 + j1]); c_00 += a_0p * b_p0; c_01 += a_0p * b_p1; @@ -1207,10 +1270,11 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<2, 2>::serial_invoke( } template <> -template +template KOKKOS_INLINE_FUNCTION int InnerGemmFixC<2, 1>::serial_invoke( - const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, - const ValueType *KOKKOS_RESTRICT B, const int k, + OpA opA, OpB opB, const ScalarType alpha, + const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, + const int k, /**/ ValueType *KOKKOS_RESTRICT C) { if (k <= 0) return 0; @@ -1222,9 +1286,9 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<2, 1>::serial_invoke( #pragma unroll #endif for (int p = 0; p < k; ++p) { - a_0p = A[i0 + p * _as1]; - b_p0 = B[p * _bs0 + j0]; - a_1p = A[i1 + p * _as1]; + a_0p = opA(A[i0 + p * _as1]); + b_p0 = opB(B[p * _bs0 + j0]); + a_1p = opA(A[i1 + p * _as1]); c_00 += a_0p * b_p0; c_10 += a_1p * b_p0; @@ -1237,10 +1301,11 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<2, 1>::serial_invoke( } template <> -template +template KOKKOS_INLINE_FUNCTION int InnerGemmFixC<1, 2>::serial_invoke( - const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, - const ValueType *KOKKOS_RESTRICT B, const int k, + OpA opA, OpB opB, const ScalarType alpha, + const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, + const int k, /**/ ValueType *KOKKOS_RESTRICT C) { if (k <= 0) return 0; @@ -1252,9 +1317,9 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<1, 2>::serial_invoke( #pragma unroll #endif for (int p = 0; p < k; ++p) { - a_0p = A[i0 + p * _as1]; - b_p0 = B[p * _bs0 + j0]; - /* */ b_p1 = B[p * _bs0 + j1]; + a_0p = opA(A[i0 + p * _as1]); + b_p0 = opB(B[p * _bs0 + j0]); + /* */ b_p1 = opB(B[p * _bs0 + j1]); c_00 += a_0p * b_p0; c_01 += a_0p * b_p1; } @@ -1270,10 +1335,11 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<1, 2>::serial_invoke( /// ================== template <> -template +template KOKKOS_INLINE_FUNCTION int InnerGemmFixC<1, 1>::serial_invoke( - const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, - const ValueType *KOKKOS_RESTRICT B, const int k, + OpA opA, OpB opB, const ScalarType alpha, + const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, + const int k, /**/ ValueType *KOKKOS_RESTRICT C) { if (k <= 0) return 0; @@ -1285,8 +1351,8 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<1, 1>::serial_invoke( #pragma unroll #endif for (int p = 0; p < k; ++p) { - a_0p = A[i0 + p * _as1]; - b_p0 = B[p * _bs0 + j0]; + a_0p = opA(A[i0 + p * _as1]); + b_p0 = opB(B[p * _bs0 + j0]); c_00 += a_0p * b_p0; } C[0 * _cs0 + 0 * _cs1] += alpha * c_00; @@ -1295,37 +1361,38 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<1, 1>::serial_invoke( } template <> -template +template KOKKOS_INLINE_FUNCTION int InnerGemmFixC<0, 1>::serial_invoke( - const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, - const ValueType *KOKKOS_RESTRICT B, const int m, const int k, + OpA opA, OpB opB, const ScalarType alpha, + const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, + const int m, const int k, /**/ ValueType *KOKKOS_RESTRICT C) { if (m <= 0 || k <= 0) return 0; switch (m) { case 5: { InnerGemmFixC<5, 1> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, k, C); + inner.serial_invoke(opA, opB, alpha, A, B, k, C); break; } case 4: { InnerGemmFixC<4, 1> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, k, C); + inner.serial_invoke(opA, opB, alpha, A, B, k, C); break; } case 3: { InnerGemmFixC<3, 1> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, k, C); + inner.serial_invoke(opA, opB, alpha, A, B, k, C); break; } case 2: { InnerGemmFixC<2, 1> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, k, C); + inner.serial_invoke(opA, opB, alpha, A, B, k, C); break; } case 1: { InnerGemmFixC<1, 1> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, k, C); + inner.serial_invoke(opA, opB, alpha, A, B, k, C); break; } default: { @@ -1337,10 +1404,11 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<0, 1>::serial_invoke( } template <> -template +template KOKKOS_INLINE_FUNCTION int InnerGemmFixC<5, 5>::serial_invoke( - const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, - const ValueType *KOKKOS_RESTRICT B, const int m, const int n, const int k, + OpA opA, OpB opB, const ScalarType alpha, + const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, + const int m, const int n, const int k, /**/ ValueType *KOKKOS_RESTRICT C) { if (m <= 0 || n <= 0 || k <= 0) return 0; if (!(m <= 5 && n <= 5)) @@ -1350,52 +1418,52 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<5, 5>::serial_invoke( switch (m * 10 + n) { case 55: { InnerGemmFixC<5, 5> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, k, C); + inner.serial_invoke(opA, opB, alpha, A, B, k, C); break; } case 54: { InnerGemmFixC<5, 4> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, k, C); + inner.serial_invoke(opA, opB, alpha, A, B, k, C); break; } case 53: { InnerGemmFixC<5, 3> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, k, C); + inner.serial_invoke(opA, opB, alpha, A, B, k, C); break; } case 52: { InnerGemmFixC<5, 2> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, k, C); + inner.serial_invoke(opA, opB, alpha, A, B, k, C); break; } case 51: { InnerGemmFixC<5, 1> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, k, C); + inner.serial_invoke(opA, opB, alpha, A, B, k, C); break; } case 45: { InnerGemmFixC<4, 5> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, k, C); + inner.serial_invoke(opA, opB, alpha, A, B, k, C); break; } case 35: { InnerGemmFixC<3, 5> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, k, C); + inner.serial_invoke(opA, opB, alpha, A, B, k, C); break; } case 25: { InnerGemmFixC<2, 5> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, k, C); + inner.serial_invoke(opA, opB, alpha, A, B, k, C); break; } case 15: { InnerGemmFixC<1, 5> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, k, C); + inner.serial_invoke(opA, opB, alpha, A, B, k, C); break; } default: { InnerGemmFixC<4, 4> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, m, n, k, C); + inner.serial_invoke(opA, opB, alpha, A, B, m, n, k, C); break; } } @@ -1403,10 +1471,11 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<5, 5>::serial_invoke( } template <> -template +template KOKKOS_INLINE_FUNCTION int InnerGemmFixC<4, 4>::serial_invoke( - const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, - const ValueType *KOKKOS_RESTRICT B, const int m, const int n, const int k, + OpA opA, OpB opB, const ScalarType alpha, + const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, + const int m, const int n, const int k, /**/ ValueType *KOKKOS_RESTRICT C) { if (m <= 0 || n <= 0 || k <= 0) return 0; if (!(m <= 4 && n <= 4)) @@ -1416,42 +1485,42 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<4, 4>::serial_invoke( switch (m * 10 + n) { case 44: { InnerGemmFixC<4, 4> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, k, C); + inner.serial_invoke(opA, opB, alpha, A, B, k, C); break; } case 43: { InnerGemmFixC<4, 3> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, k, C); + inner.serial_invoke(opA, opB, alpha, A, B, k, C); break; } case 42: { InnerGemmFixC<4, 2> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, k, C); + inner.serial_invoke(opA, opB, alpha, A, B, k, C); break; } case 41: { InnerGemmFixC<4, 1> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, k, C); + inner.serial_invoke(opA, opB, alpha, A, B, k, C); break; } case 34: { InnerGemmFixC<3, 4> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, k, C); + inner.serial_invoke(opA, opB, alpha, A, B, k, C); break; } case 24: { InnerGemmFixC<2, 4> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, k, C); + inner.serial_invoke(opA, opB, alpha, A, B, k, C); break; } case 14: { InnerGemmFixC<1, 4> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, k, C); + inner.serial_invoke(opA, opB, alpha, A, B, k, C); break; } default: { InnerGemmFixC<3, 3> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, m, n, k, C); + inner.serial_invoke(opA, opB, alpha, A, B, m, n, k, C); break; } } @@ -1459,10 +1528,11 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<4, 4>::serial_invoke( } template <> -template +template KOKKOS_INLINE_FUNCTION int InnerGemmFixC<3, 3>::serial_invoke( - const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, - const ValueType *KOKKOS_RESTRICT B, const int m, const int n, const int k, + OpA opA, OpB opB, const ScalarType alpha, + const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, + const int m, const int n, const int k, /**/ ValueType *KOKKOS_RESTRICT C) { if (m <= 0 || n <= 0 || k <= 0) return 0; if (!(m <= 3 && n <= 3)) @@ -1472,32 +1542,32 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<3, 3>::serial_invoke( switch (m * 10 + n) { case 33: { InnerGemmFixC<3, 3> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, k, C); + inner.serial_invoke(opA, opB, alpha, A, B, k, C); break; } case 32: { InnerGemmFixC<3, 2> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, k, C); + inner.serial_invoke(opA, opB, alpha, A, B, k, C); break; } case 31: { InnerGemmFixC<3, 1> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, k, C); + inner.serial_invoke(opA, opB, alpha, A, B, k, C); break; } case 23: { InnerGemmFixC<2, 3> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, k, C); + inner.serial_invoke(opA, opB, alpha, A, B, k, C); break; } case 13: { InnerGemmFixC<1, 3> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, k, C); + inner.serial_invoke(opA, opB, alpha, A, B, k, C); break; } default: { InnerGemmFixC<2, 2> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, m, n, k, C); + inner.serial_invoke(opA, opB, alpha, A, B, m, n, k, C); break; } } @@ -1505,10 +1575,11 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<3, 3>::serial_invoke( } template <> -template +template KOKKOS_INLINE_FUNCTION int InnerGemmFixC<2, 2>::serial_invoke( - const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, - const ValueType *KOKKOS_RESTRICT B, const int m, const int n, const int k, + OpA opA, OpB opB, const ScalarType alpha, + const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, + const int m, const int n, const int k, /**/ ValueType *KOKKOS_RESTRICT C) { if (m <= 0 || n <= 0 || k <= 0) return 0; if (!(m <= 2 && n <= 2)) @@ -1518,22 +1589,22 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<2, 2>::serial_invoke( switch (m * 10 + n) { case 22: { InnerGemmFixC<2, 2> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, k, C); + inner.serial_invoke(opA, opB, alpha, A, B, k, C); break; } case 21: { InnerGemmFixC<2, 1> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, k, C); + inner.serial_invoke(opA, opB, alpha, A, B, k, C); break; } case 12: { InnerGemmFixC<1, 2> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, k, C); + inner.serial_invoke(opA, opB, alpha, A, B, k, C); break; } case 11: { InnerGemmFixC<1, 1> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, k, C); + inner.serial_invoke(opA, opB, alpha, A, B, k, C); break; } } @@ -1541,10 +1612,11 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<2, 2>::serial_invoke( } template <> -template +template KOKKOS_INLINE_FUNCTION int InnerGemmFixC<1, 1>::serial_invoke( - const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, - const ValueType *KOKKOS_RESTRICT B, const int m, const int n, const int k, + OpA opA, OpB opB, const ScalarType alpha, + const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, + const int m, const int n, const int k, /**/ ValueType *KOKKOS_RESTRICT C) { if (m <= 0 || n <= 0 || k <= 0) return 0; if (!(m <= 1 && n <= 1)) @@ -1555,6 +1627,6 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<1, 1>::serial_invoke( ; } -} // namespace KokkosBatched +} // namespace KokkosBlas #endif diff --git a/blas/impl/KokkosBlas3_serial_gemm_internal.hpp b/blas/impl/KokkosBlas3_serial_gemm_internal.hpp new file mode 100644 index 0000000000..a0468f7361 --- /dev/null +++ b/blas/impl/KokkosBlas3_serial_gemm_internal.hpp @@ -0,0 +1,195 @@ +/* +//@HEADER +// ************************************************************************ +// +// Kokkos v. 3.0 +// Copyright (2020) National Technology & Engineering +// Solutions of Sandia, LLC (NTESS). +// +// Under the terms of Contract DE-NA0003525 with NTESS, +// the U.S. Government retains certain rights in this software. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// 1. Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of the Corporation nor the names of the +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY NTESS "AS IS" AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NTESS OR THE +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// Questions? Contact Siva Rajamanickam (srajama@sandia.gov) +// +// ************************************************************************ +//@HEADER +*/ + +#ifndef KOKKOSBLAS3_GEMM_SERIAL_INTERNAL_HPP +#define KOKKOSBLAS3_GEMM_SERIAL_INTERNAL_HPP + +/// \author Kyungjoo Kim (kyukim@sandia.gov) + +#include "KokkosBlas_util.hpp" +#include "KokkosBlas1_set_impl.hpp" +#include "KokkosBlas1_serial_scal_impl.hpp" +#include "KokkosBlas3_gemm_inner_fix.hpp" + +namespace KokkosBlas { +namespace Impl { + +/// +/// Serial Internal Impl +/// ==================== + +template +struct SerialGemmInternal { + template + KOKKOS_INLINE_FUNCTION static int invoke( + OpA opA, OpB opB, const int m, const int n, const int k, + const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const int as0, + const int as1, const ValueType *KOKKOS_RESTRICT B, const int bs0, + const int bs1, const ScalarType beta, + /**/ ValueType *KOKKOS_RESTRICT C, const int cs0, const int cs1); + + // default OpA=OpB=Impl::OpID + template + KOKKOS_INLINE_FUNCTION static int invoke( + const int m, const int n, const int k, const ScalarType alpha, + const ValueType *KOKKOS_RESTRICT A, const int as0, const int as1, + const ValueType *KOKKOS_RESTRICT B, const int bs0, const int bs1, + const ScalarType beta, + /**/ ValueType *KOKKOS_RESTRICT C, const int cs0, const int cs1) { + return invoke(OpID{}, OpID{}, m, n, k, alpha, A, as0, as1, B, bs0, bs1, + beta, C, cs0, cs1); + } +}; + +template <> +template +KOKKOS_INLINE_FUNCTION int SerialGemmInternal::invoke( + OpA opA, OpB opB, const int m, const int n, const int k, + const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const int as0, + const int as1, const ValueType *KOKKOS_RESTRICT B, const int bs0, + const int bs1, const ScalarType beta, + /**/ ValueType *KOKKOS_RESTRICT C, const int cs0, const int cs1) { + // C = beta C + alpha opA(A) opB(B) + // C (m x n), A(m x k), B(k x n) + + const ScalarType one(1.0), zero(0.0); + + if (beta == zero) + KokkosBlas::Impl::SerialSetInternal::invoke(m, n, zero, C, cs0, cs1); + else if (beta != one) + KokkosBlas::Impl::SerialScaleInternal::invoke(m, n, beta, C, cs0, cs1); + + if (alpha != zero) { + if (m <= 0 || n <= 0 || k <= 0) return 0; + + ValueType *KOKKOS_RESTRICT pC = C; + for (int p = 0; p < k; ++p) { + const ValueType *KOKKOS_RESTRICT pA = A + p * as1; + const ValueType *KOKKOS_RESTRICT pB = B + p * bs0; + for (int i = 0; i < m; ++i) { + const ValueType tA(alpha * opA(pA[i * as0])); +#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL) +#pragma unroll +#endif + for (int j = 0; j < n; ++j) + pC[i * cs0 + j * cs1] += tA * opB(pB[j * bs1]); + } + } + } + return 0; +} + +template <> +template +KOKKOS_INLINE_FUNCTION int SerialGemmInternal::invoke( + OpA opA, OpB opB, const int m, const int n, const int k, + const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const int as0, + const int as1, const ValueType *KOKKOS_RESTRICT B, const int bs0, + const int bs1, const ScalarType beta, + /**/ ValueType *KOKKOS_RESTRICT C, const int cs0, const int cs1) { + // C = beta C + alpha A B + // C (m x n), A(m x k), B(k x n) + + constexpr int mbAlgo = Algo::Gemm::Blocked::mb(); + constexpr int nbAlgo = Algo::Gemm::Blocked::mb(); + + const ScalarType one(1.0), zero(0.0); + + if (beta == zero) + KokkosBlas::Impl::SerialSetInternal::invoke(m, n, zero, C, cs0, cs1); + else if (beta != one) + KokkosBlas::Impl::SerialScaleInternal::invoke(m, n, beta, C, cs0, cs1); + + if (alpha != zero) { + if (m <= 0 || n <= 0 || k <= 0) return 0; + const ValueType alpha_value(alpha); + + KokkosBlas::InnerGemmFixC inner(as0, as1, bs0, bs1, cs0, + cs1); + auto gemm = [&](const int ib, const int jb, const int pb, + const ValueType *KOKKOS_RESTRICT AA, + const ValueType *KOKKOS_RESTRICT BB, + /**/ ValueType *KOKKOS_RESTRICT CC) { + const int mb = mbAlgo, nb = nbAlgo; + for (int i = 0; i < ib; i += mb) + for (int j = 0; j < jb; j += nb) + inner.serial_invoke(opA, opB, alpha_value, AA + i * as0, BB + j * bs1, + (i + mb) > ib ? (ib - i) : mb, + (j + nb) > jb ? (jb - j) : nb, pb, + CC + i * cs0 + j * cs1); + }; + + const bool is_small = true; //(m*n*k <= 64*64*64); + if (is_small) { + gemm(m, n, k, A, B, C); + } else { + // // cache blocking + // const int + // nc = nb*10, kc = mb*4, mc = mb*4; + + // for (int jj=0;jj +template +KOKKOS_INLINE_FUNCTION int TeamGemm::invoke( + const MemberType &member, const ScalarType alpha, const AViewType &A, + const BViewType &B, const ScalarType beta, const CViewType &C) { + // C = beta C + alpha A B + // C (m x n), A(m x k), B(k x n) + + static_assert(std::is_same::value || + std::is_same::value, + "Algorithm not supported"); + + using OpA = typename Impl::MatrixModeInfo::Op; + using OpB = typename Impl::MatrixModeInfo::Op; + using TransA = Impl::MatrixModeInfo; + using TransB = Impl::MatrixModeInfo; + const auto ae1 = TransA::extent(A, 1); + const auto as0 = TransA::stride_0(A); + const auto as1 = TransA::stride_1(A); + const auto bs0 = TransB::stride_0(B); + const auto bs1 = TransB::stride_1(B); + + return Impl::TeamGemmInternal::invoke( + OpA{}, OpB{}, member, C.extent(0), C.extent(1), ae1, alpha, A.data(), as0, + as1, B.data(), bs0, bs1, beta, C.data(), C.stride_0(), C.stride_1()); +} + +/// +/// Implemented: +/// NT/NT, T/NT, NT/T, T/T +/// +/// Not yet implemented (ConjTranspose) +/// CT/NT, NT/CT, CT/CT +/// + +template +template +KOKKOS_INLINE_FUNCTION int +TeamVectorGemm::invoke( + const MemberType &member, const ScalarType alpha, const AViewType &A, + const BViewType &B, const ScalarType beta, const CViewType &C) { + // C = beta C + alpha A B + // C (m x n), A(m x k), B(k x n) + + static_assert(std::is_same::value, + "Algorithm not supported"); + + using OpA = typename Impl::MatrixModeInfo::Op; + using OpB = typename Impl::MatrixModeInfo::Op; + using TransA = Impl::MatrixModeInfo; + using TransB = Impl::MatrixModeInfo; + const auto ae1 = TransA::extent(A, 1); + const auto as0 = TransA::stride_0(A); + const auto as1 = TransA::stride_1(A); + const auto bs0 = TransB::stride_0(B); + const auto bs1 = TransB::stride_1(B); + + return Impl::TeamVectorGemmInternal::invoke( + OpA{}, OpB{}, member, C.extent(0), C.extent(1), ae1, alpha, A.data(), as0, + as1, B.data(), bs0, bs1, beta, C.data(), C.stride_0(), C.stride_1()); +} + +/// +/// Implemented: +/// NT/NT, T/NT, NT/T, T/T +/// +/// Not yet implemented (ConjTranspose) +/// CT/NT, NT/CT, CT/CT +/// + +template +template +KOKKOS_INLINE_FUNCTION int +ThreadVectorGemm::invoke( + const MemberType &member, const ScalarType alpha, const AViewType &A, + const BViewType &B, const ScalarType beta, const CViewType &C) { + // C = beta C + alpha A B + // C (m x n), A(m x k), B(k x n) + + static_assert(std::is_same::value, + "Algorithm not supported"); + + using TransA = Impl::MatrixModeInfo; + using TransB = Impl::MatrixModeInfo; + const auto ae1 = TransA::extent(A, 1); + const auto as0 = TransA::stride_0(A); + const auto as1 = TransA::stride_1(A); + const auto bs0 = TransB::stride_0(B); + const auto bs1 = TransB::stride_1(B); + + return Impl::ThreadVectorGemmInternal::invoke( + member, C.extent(0), C.extent(1), ae1, alpha, A.data(), as0, as1, + B.data(), bs0, bs1, beta, C.data(), C.stride_0(), C.stride_1()); +} + +} // namespace KokkosBlas + +#endif diff --git a/blas/impl/KokkosBlas3_team_gemm_inner_fixc_impl.hpp b/blas/impl/KokkosBlas3_team_gemm_inner_fixc_impl.hpp new file mode 100644 index 0000000000..e784ff0033 --- /dev/null +++ b/blas/impl/KokkosBlas3_team_gemm_inner_fixc_impl.hpp @@ -0,0 +1,98 @@ +/* +//@HEADER +// ************************************************************************ +// +// Kokkos v. 3.0 +// Copyright (2020) National Technology & Engineering +// Solutions of Sandia, LLC (NTESS). +// +// Under the terms of Contract DE-NA0003525 with NTESS, +// the U.S. Government retains certain rights in this software. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// 1. Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of the Corporation nor the names of the +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY NTESS "AS IS" AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NTESS OR THE +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// Questions? Contact Siva Rajamanickam (srajama@sandia.gov) +// +// ************************************************************************ +//@HEADER +*/ + +#ifndef KOKKOSBLAS3_INNER_GEMM_FIXC_TEAM_IMPL_HPP +#define KOKKOSBLAS3_INNER_GEMM_FIXC_TEAM_IMPL_HPP + +/// \author Kyungjoo Kim (kyukim@sandia.gov) + +#include "KokkosBatched_Util.hpp" + +namespace KokkosBlas { + +template +template +KOKKOS_INLINE_FUNCTION int InnerGemmFixC::team_invoke( + const MemberType &member, const ScalarType alpha, + const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, + const int k, + /**/ ValueType *KOKKOS_RESTRICT C) { + Kokkos::parallel_for( + Kokkos::TeamThreadRange(member, 0, mb * nb), [&](const int &ij) { + const int i = ij / nb, j = ij % nb; + + const ValueType *KOKKOS_RESTRICT pA = A + i * _as0, + *KOKKOS_RESTRICT pB = B + j * _bs1; + + ValueType c = 0; + for (int p = 0; p < k; ++p) c += pA[p * _as1] * pB[p * _bs0]; + C[i * _cs0 + j * _cs1] += alpha * c; + }); + return 0; +} + +template +template +KOKKOS_INLINE_FUNCTION int InnerGemmFixC::team_invoke( + const MemberType &member, const ScalarType alpha, + const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, + const int m, const int n, const int k, + /**/ ValueType *KOKKOS_RESTRICT C) { + Kokkos::parallel_for( + Kokkos::TeamThreadRange(member, 0, m * n), [&](const int &ij) { + const int i = ij / n, j = ij % n; + + const ValueType *KOKKOS_RESTRICT pA = A + i * _as0, + *KOKKOS_RESTRICT pB = B + j * _bs1; + + ValueType c = 0; + for (int p = 0; p < k; ++p) c += pA[p * _as1] * pB[p * _bs0]; + C[i * _cs0 + j * _cs1] += alpha * c; + }); + return 0; +} + +} // namespace KokkosBlas + +#endif diff --git a/blas/impl/KokkosBlas3_team_gemm_internal.hpp b/blas/impl/KokkosBlas3_team_gemm_internal.hpp new file mode 100644 index 0000000000..61e30382fe --- /dev/null +++ b/blas/impl/KokkosBlas3_team_gemm_internal.hpp @@ -0,0 +1,412 @@ +/* +//@HEADER +// ************************************************************************ +// +// Kokkos v. 3.0 +// Copyright (2020) National Technology & Engineering +// Solutions of Sandia, LLC (NTESS). +// +// Under the terms of Contract DE-NA0003525 with NTESS, +// the U.S. Government retains certain rights in this software. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// 1. Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of the Corporation nor the names of the +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY NTESS "AS IS" AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NTESS OR THE +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// Questions? Contact Siva Rajamanickam (srajama@sandia.gov) +// +// ************************************************************************ +//@HEADER +*/ + +#ifndef KOKKOSBLAS3_TEAM_GEMM_INTERNAL_HPP_ +#define KOKKOSBLAS3_TEAM_GEMM_INTERNAL_HPP_ + +/// \author Kyungjoo Kim (kyukim@sandia.gov) + +#include "KokkosBatched_Util.hpp" +#include "KokkosBlas1_set_impl.hpp" +#include "KokkosBlas1_team_scal_impl.hpp" +#include "KokkosBlas3_serial_gemm_inner_fixc_impl.hpp" +#include "KokkosKernels_ExecSpaceUtils.hpp" + +namespace KokkosBlas { +namespace Impl { + +/// +/// Team Internal Impl +/// ==================== + +template +struct TeamGemmInternal { + template + KOKKOS_INLINE_FUNCTION static int invoke( + OpA opA, OpB opB, const MemberType &member, const int m, const int n, + const int k, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, + const int as0, const int as1, const ValueType *KOKKOS_RESTRICT B, + const int bs0, const int bs1, const ScalarType beta, + /**/ ValueType *KOKKOS_RESTRICT C, const int cs0, const int cs1); + + // default OpA=OpB=Impl::OpID + template + KOKKOS_INLINE_FUNCTION static int invoke( + const MemberType &member, const int m, const int n, const int k, + const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const int as0, + const int as1, const ValueType *KOKKOS_RESTRICT B, const int bs0, + const int bs1, const ScalarType beta, + /**/ ValueType *KOKKOS_RESTRICT C, const int cs0, const int cs1) { + return invoke(OpID{}, OpID{}, member, m, n, k, alpha, A, as0, as1, B, bs0, + bs1, beta, C, cs0, cs1); + } +}; + +template <> +template +KOKKOS_INLINE_FUNCTION int TeamGemmInternal::invoke( + OpA opA, OpB opB, const MemberType &member, const int m, const int n, + const int k, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, + const int as0, const int as1, const ValueType *KOKKOS_RESTRICT B, + const int bs0, const int bs1, const ScalarType beta, + /**/ ValueType *KOKKOS_RESTRICT C, const int cs0, const int cs1) { + // C = beta C + alpha A B + // C (m x n), A(m x k), B(k x n) + + const ScalarType one(1.0), zero(0.0); + + if (beta == zero) + KokkosBlas::Impl::TeamSetInternal::invoke(member, m, n, zero, C, cs0, cs1); + else if (beta != one) + KokkosBlas::Impl::TeamScaleInternal::invoke(member, m, n, beta, C, cs0, + cs1); + + if (alpha != ScalarType(0.0)) { + if (m <= 0 || n <= 0 || k <= 0) return 0; + + if (beta != one) member.team_barrier(); + + Kokkos::parallel_for(Kokkos::TeamThreadRange(member, 0, m * n), + [&](const int &ij) { + // assume layout right for batched computation + const int i = ij / n, j = ij % n; + const ValueType *KOKKOS_RESTRICT pA = A + i * as0; + const ValueType *KOKKOS_RESTRICT pB = B + j * bs1; + + ValueType c = ValueType(0); + for (int p = 0; p < k; ++p) + c += opA(pA[p * as1]) * opB(pB[p * bs0]); + C[i * cs0 + j * cs1] += alpha * c; + }); + } + return 0; +} + +template <> +template +KOKKOS_INLINE_FUNCTION int TeamGemmInternal::invoke( + OpA opA, OpB opB, const MemberType &member, const int m, const int n, + const int k, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, + const int as0, const int as1, const ValueType *KOKKOS_RESTRICT B, + const int bs0, const int bs1, const ScalarType beta, + /**/ ValueType *KOKKOS_RESTRICT C, const int cs0, const int cs1) { + // C = beta C + alpha A B + // C (m x n), A(m x k), B(k x n) + + constexpr int mbAlgo = Algo::Gemm::Blocked::mb(); + constexpr int nbAlgo = Algo::Gemm::Blocked::mb(); + + const ScalarType one(1.0), zero(0.0); + + if (beta == zero) + KokkosBlas::Impl::TeamSetInternal::invoke(member, m, n, zero, C, cs0, cs1); + else if (beta != one) + KokkosBlas::Impl::TeamScaleInternal::invoke(member, m, n, beta, C, cs0, + cs1); + + if (alpha != ScalarType(0.0)) { + if (m <= 0 || n <= 0 || k <= 0) return 0; + + if (beta != one) member.team_barrier(); + + /// + /// GPU case: team size is large and blocksize (mb,nb) is small + KokkosBlas::InnerGemmFixC inner(as0, as1, bs0, bs1, cs0, + cs1); + auto gemm = [&](const int ib, const int jb, const int pb, + const ValueType *KOKKOS_RESTRICT AA, + const ValueType *KOKKOS_RESTRICT BB, + /**/ ValueType *KOKKOS_RESTRICT CC) { + // Made this non-const in order to WORKAROUND issue #349 + int mb = mbAlgo, mp = (ib % mb), mq = (ib / mb) + (mp > 0), nb = nbAlgo, + np = (jb % nb), nq = (jb / nb) + (np > 0); + + // square tiling + Kokkos::parallel_for( + Kokkos::TeamThreadRange(member, mq * nq), [&](const int &ij) { + int i, j; + // note: the condition is constexpr + if (KokkosKernels::Impl::kk_is_gpu_exec_space< + typename MemberType::execution_space>()) { + i = ij % mq * mb; + j = ij / mq * nb; + } else { + i = ij / nq * mb; + j = ij % nq * nb; + } + inner.serial_invoke(opA, opB, alpha, AA + i * as0, BB + j * bs1, + (i + mb) > ib ? mp : mb, + (j + nb) > jb ? np : nb, pb, + CC + i * cs0 + j * cs1); + }); + }; + + const bool is_small = true; //(m*n*k <= 64*64*64); + if (is_small) { + gemm(m, n, k, A, B, C); + } else { + // // cache blocking + // const int + // nc = nb*10, kc = mb*4, mc = mb*4; + + // for (int jj=0;jj +struct TeamVectorGemmInternal { + template + KOKKOS_INLINE_FUNCTION static int invoke( + OpA opA, OpB opB, const MemberType &member, const int m, const int n, + const int k, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, + const int as0, const int as1, const ValueType *KOKKOS_RESTRICT B, + const int bs0, const int bs1, const ScalarType beta, + /**/ ValueType *KOKKOS_RESTRICT C, const int cs0, const int cs1); + + // default OpA=OpB=Impl::OpID + template + KOKKOS_INLINE_FUNCTION static int invoke( + const MemberType &member, const int m, const int n, const int k, + const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const int as0, + const int as1, const ValueType *KOKKOS_RESTRICT B, const int bs0, + const int bs1, const ScalarType beta, + /**/ ValueType *KOKKOS_RESTRICT C, const int cs0, const int cs1) { + return invoke(OpID{}, OpID{}, member, m, n, k, alpha, A, as0, as1, B, bs0, + bs1, beta, C, cs0, cs1); + } +}; + +template <> +template +KOKKOS_INLINE_FUNCTION int +TeamVectorGemmInternal::invoke( + OpA opA, OpB opB, const MemberType &member, const int m, const int n, + const int k, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, + const int as0, const int as1, const ValueType *KOKKOS_RESTRICT B, + const int bs0, const int bs1, const ScalarType beta, + /**/ ValueType *KOKKOS_RESTRICT C, const int cs0, const int cs1) { + // C = beta C + alpha A B + // C (m x n), A(m x k), B(k x n) + + const ScalarType one(1.0), zero(0.0); + + if (beta == zero) + KokkosBlas::Impl::TeamVectorSetInternal::invoke(member, m, n, zero, C, cs0, + cs1); + else if (beta != one) + KokkosBlas::Impl::TeamVectorScaleInternal::invoke(member, m, n, beta, C, + cs0, cs1); + + if (alpha != ScalarType(0.0)) { + if (m <= 0 || n <= 0 || k <= 0) return 0; + + if (beta != one) member.team_barrier(); + + Kokkos::parallel_for(Kokkos::TeamThreadRange(member, m), [&](const int &i) { + const ValueType *KOKKOS_RESTRICT pA = A + i * as0; + Kokkos::parallel_for(Kokkos::ThreadVectorRange(member, n), + [&](const int &j) { + const ValueType *KOKKOS_RESTRICT pB = B + j * bs1; + + ValueType c = ValueType(0); + for (int p = 0; p < k; ++p) + c += opA(pA[p * as1]) * opB(pB[p * bs0]); + C[i * cs0 + j * cs1] += alpha * c; + }); + }); + } + return 0; +} + +/// +/// ThreadVector Internal Impl +/// ==================== + +template +struct ThreadVectorGemmInternal { + template + KOKKOS_INLINE_FUNCTION static int invoke( + const MemberType &member, const int m, const int n, const int k, + const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const int as0, + const int as1, const ValueType *KOKKOS_RESTRICT B, const int bs0, + const int bs1, const ScalarType beta, + /**/ ValueType *KOKKOS_RESTRICT C, const int cs0, const int cs1); +}; + +template <> +template +KOKKOS_INLINE_FUNCTION int +ThreadVectorGemmInternal::invoke( + const MemberType &member, const int m, const int n, const int k, + const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const int as0, + const int as1, const ValueType *KOKKOS_RESTRICT B, const int bs0, + const int bs1, const ScalarType beta, + /**/ ValueType *KOKKOS_RESTRICT C, const int cs0, const int cs1) { + // C = beta C + alpha A B + // C (m x n), A(m x k), B(k x n) + + const ScalarType one(1.0), zero(0.0); + + if (beta == zero) + KokkosBlas::Impl::ThreadVectorSetInternal::invoke(member, m, n, zero, C, + cs0, cs1); + else if (beta != one) + KokkosBlas::Impl::ThreadVectorScaleInternal::invoke(member, m, n, beta, C, + cs0, cs1); + + if (alpha != ScalarType(0.0)) { + if (m <= 0 || n <= 0 || k <= 0) return 0; + + if (beta != one) member.team_barrier(); + + Kokkos::parallel_for( + Kokkos::ThreadVectorRange(member, 0, m * n), [&](const int &ij) { + // assume layout right for batched computation + const int i = ij / n, j = ij % n; + const ValueType *KOKKOS_RESTRICT pA = A + i * as0, + *KOKKOS_RESTRICT pB = B + j * bs1; + + ValueType c = ValueType(0); + for (int p = 0; p < k; ++p) c += pA[p * as1] * pB[p * bs0]; + C[i * cs0 + j * cs1] += alpha * c; + }); + } + return 0; +} + +template <> +template +KOKKOS_INLINE_FUNCTION int +ThreadVectorGemmInternal::invoke( + const MemberType &member, const int m, const int n, const int k, + const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const int as0, + const int as1, const ValueType *KOKKOS_RESTRICT B, const int bs0, + const int bs1, const ScalarType beta, + /**/ ValueType *KOKKOS_RESTRICT C, const int cs0, const int cs1) { + // C = beta C + alpha A B + // C (m x n), A(m x k), B(k x n) + + constexpr int mbAlgo = Algo::Gemm::Blocked::mb(); + constexpr int nbAlgo = Algo::Gemm::Blocked::mb(); + + const ScalarType one(1.0), zero(0.0); + + if (beta == zero) + KokkosBlas::Impl::ThreadVectorSetInternal::invoke(member, m, n, zero, C, + cs0, cs1); + else if (beta != one) + KokkosBlas::Impl::ThreadVectorScaleInternal::invoke(member, m, n, beta, C, + cs0, cs1); + + if (alpha != ScalarType(0.0)) { + if (m <= 0 || n <= 0 || k <= 0) return 0; + + if (beta != one) member.team_barrier(); + + /// + /// GPU case: team size is large and blocksize (mb,nb) is small + KokkosBlas::InnerGemmFixC inner(as0, as1, bs0, bs1, cs0, + cs1); + auto gemm = [&](const int ib, const int jb, const int pb, + const ValueType *KOKKOS_RESTRICT AA, + const ValueType *KOKKOS_RESTRICT BB, + /**/ ValueType *KOKKOS_RESTRICT CC) { + // Made this non-const in order to WORKAROUND issue #349 + int mb = mbAlgo, mp = (ib % mb), mq = (ib / mb) + (mp > 0), nb = nbAlgo, + np = (jb % nb), nq = (jb / nb) + (np > 0); + + // square tiling + Kokkos::parallel_for( + Kokkos::ThreadVectorRange(member, mq * nq), [&](const int &ij) { + int i, j; + // note: the condition is constexpr + if (KokkosKernels::Impl::kk_is_gpu_exec_space< + typename MemberType::execution_space>()) { + i = ij % mq * mb; + j = ij / mq * nb; + } else { + i = ij / nq * mb; + j = ij % nq * nb; + } + inner.serial_invoke( + alpha, AA + i * as0, BB + j * bs1, (i + mb) > ib ? mp : mb, + (j + nb) > jb ? np : nb, pb, CC + i * cs0 + j * cs1); + }); + }; + + gemm(m, n, k, A, B, C); + } + return 0; +} + +} // namespace Impl +} // namespace KokkosBlas + +#endif diff --git a/blas/impl/KokkosBlas_util.hpp b/blas/impl/KokkosBlas_util.hpp index dcee8283d6..0b4abc7590 100644 --- a/blas/impl/KokkosBlas_util.hpp +++ b/blas/impl/KokkosBlas_util.hpp @@ -61,6 +61,9 @@ struct Mode { struct TeamVector { static const char *name() { return "TeamVector"; } }; + struct ThreadVector { + static const char *name() { return "ThreadVector"; } + }; }; struct Trans { @@ -223,6 +226,72 @@ struct Algo { namespace Impl { +// matrix value fetching conenient for passing conjugation as template param +struct OpID { + template + KOKKOS_INLINE_FUNCTION ValueType operator()(ValueType v) const { + return v; + } +}; + +struct OpConj { + template + KOKKOS_INLINE_FUNCTION ValueType operator()(ValueType v) const { + using KAT = Kokkos::Details::ArithTraits; + return KAT::conj(v); + } +}; + +// This utility fetches matrix extents and strides based on transpose mode +template +struct MatrixModeInfo; + +template <> +struct MatrixModeInfo { + using Op = OpID; + + template + static size_t stride_0(ViewType v) { + return v.stride_0(); + } + template + static size_t stride_1(ViewType v) { + return v.stride_1(); + } + + template + static size_t extent(ViewType v, size_t i) { + assert(i == 0 || i == 1); + return v.extent(i); + } +}; + +template <> +struct MatrixModeInfo { + using Op = OpID; + + template + static size_t stride_0(ViewType v) { + return v.stride_1(); + } + template + static size_t stride_1(ViewType v) { + return v.stride_0(); + } + + template + static size_t extent(ViewType v, size_t i) { + assert(i == 0 || i == 1); + return v.extent(1 - i); + } +}; + +template <> +struct MatrixModeInfo + : public MatrixModeInfo { + using Op = OpConj; +}; + // Helper to choose the work distribution for a TeamPolicy computing multiple // reductions. Each team computes a partial reduction and atomically contributes // to the final result. diff --git a/blas/src/KokkosBlas1_scal.hpp b/blas/src/KokkosBlas1_scal.hpp index d533efe535..46c7d814d6 100644 --- a/blas/src/KokkosBlas1_scal.hpp +++ b/blas/src/KokkosBlas1_scal.hpp @@ -159,6 +159,22 @@ struct TeamVectorScale { } }; +/// +/// ThreadVector Scale +/// + +template +struct ThreadVectorScale { + template + KOKKOS_INLINE_FUNCTION static int invoke(const MemberType& member, + const ScalarType alpha, + const AViewType& A) { + return Impl::ThreadVectorScaleInternal::invoke(member, A.extent(0), + A.extent(1), alpha, A.data(), + A.stride_0(), A.stride_1()); + } +}; + } // namespace KokkosBlas #endif diff --git a/blas/src/KokkosBlas1_set.hpp b/blas/src/KokkosBlas1_set.hpp index 61c03ec17a..a99160511d 100644 --- a/blas/src/KokkosBlas1_set.hpp +++ b/blas/src/KokkosBlas1_set.hpp @@ -94,6 +94,22 @@ struct TeamVectorSet { } }; +/// +/// ThreadVector Set +/// + +template +struct ThreadVectorSet { + template + KOKKOS_INLINE_FUNCTION static int invoke(const MemberType &member, + const ScalarType alpha, + const AViewType &A) { + return Impl::ThreadVectorSetInternal::invoke(member, A.extent(0), + A.extent(1), alpha, A.data(), + A.stride_0(), A.stride_1()); + } +}; + } // namespace KokkosBlas #endif diff --git a/blas/src/KokkosBlas2_gemv.hpp b/blas/src/KokkosBlas2_gemv.hpp index fe8418cc40..448ca62fbf 100644 --- a/blas/src/KokkosBlas2_gemv.hpp +++ b/blas/src/KokkosBlas2_gemv.hpp @@ -49,8 +49,7 @@ /// Tpetra::MultiVector use cases. #include -#include -#include +#include #include #include #include @@ -208,7 +207,179 @@ void gemv(const char trans[], typename AViewType::const_value_type& alpha, gemv(space, trans, alpha, A, x, beta, y); } +/********************* BEGIN functor-level routines *********************/ + +template +struct SerialGemv { + template + KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType /*alpha*/, + const AViewType& /*A*/, + const xViewType& /*x*/, + const ScalarType /*beta*/, + const yViewType& /*y*/); +}; + +template +struct TeamGemv { + template + KOKKOS_INLINE_FUNCTION static int invoke(const MemberType& /*member*/, + const ScalarType /*alpha*/, + const AViewType& /*A*/, + const xViewType& /*x*/, + const ScalarType /*beta*/, + const yViewType& /*y*/); +}; + +template +struct TeamVectorGemv { + template + KOKKOS_INLINE_FUNCTION static int invoke(const MemberType& /*member*/, + const ScalarType /*alpha*/, + const AViewType& /*A*/, + const xViewType& /*x*/, + const ScalarType /*beta*/, + const yViewType& /*y*/); +}; + +template +struct ThreadVectorGemv { + template + KOKKOS_INLINE_FUNCTION static int invoke(const MemberType& /*member*/, + const ScalarType /*alpha*/, + const AViewType& /*A*/, + const xViewType& /*x*/, + const ScalarType /*beta*/, + const yViewType& /*y*/); +}; + namespace Experimental { + +template +void KOKKOS_INLINE_FUNCTION serial_gemv(const char trans, + const ScalarType& alpha, + const MatrixType& A, const XVector& x, + const ScalarType& beta, + const YVector& y) { + if (trans == 'N' || trans == 'n') { + using mode = KokkosBlas::Trans::NoTranspose; + KokkosBlas::SerialGemv::invoke(alpha, A, x, beta, y); + } else if (trans == 'T' || trans == 't') { + using mode = KokkosBlas::Trans::Transpose; + KokkosBlas::SerialGemv::invoke(alpha, A, x, beta, y); + } else if (trans == 'C' || trans == 'c') { + using mode = KokkosBlas::Trans::ConjTranspose; + KokkosBlas::SerialGemv::invoke(alpha, A, x, beta, y); + } else { + Kokkos::abort("Matrix mode not supported"); + } +} + +// default AlgoTag +template +void KOKKOS_INLINE_FUNCTION serial_gemv(const char trans, + const ScalarType& alpha, + const MatrixType& A, const XVector& x, + const ScalarType& beta, + const YVector& y) { + serial_gemv(trans, alpha, A, x, beta, y); +} + +template +void KOKKOS_INLINE_FUNCTION team_gemv(const MemberType& team, const char trans, + const ScalarType& alpha, + const MatrixType& A, const XVector& x, + const ScalarType& beta, + const YVector& y) { + if (trans == 'N' || trans == 'n') + TeamGemv::invoke(team, alpha, A, x, beta, y); + else if (trans == 'T' || trans == 't') + TeamGemv::invoke(team, alpha, A, x, beta, y); + else if (trans == 'C' || trans == 'c') + TeamGemv::invoke(team, alpha, A, x, beta, y); + else { + Kokkos::abort("Matrix mode not supported"); + } +} + +// default AlgoTag +template +void KOKKOS_INLINE_FUNCTION team_gemv(const MemberType& team, const char trans, + const ScalarType& alpha, + const MatrixType& A, const XVector& x, + const ScalarType& beta, + const YVector& y) { + team_gemv(team, trans, alpha, A, x, beta, y); +} + +template +void KOKKOS_INLINE_FUNCTION +teamvector_gemv(const MemberType& team, const char trans, + const ScalarType& alpha, const MatrixType& A, const XVector& x, + const ScalarType& beta, const YVector& y) { + if (trans == 'N' || trans == 'n') { + KokkosBlas::TeamVectorGemv::invoke( + team, alpha, A, x, beta, y); + } else if (trans == 'T' || trans == 't') { + KokkosBlas::TeamVectorGemv::invoke( + team, alpha, A, x, beta, y); + } else if (trans == 'C' || trans == 'c') { + KokkosBlas::TeamVectorGemv::invoke( + team, alpha, A, x, beta, y); + } else { + Kokkos::abort("Matrix mode not supported"); + } +} + +// default AlgoTag +template +void KOKKOS_INLINE_FUNCTION +team_vector_gemv(const MemberType& team, const char trans, + const ScalarType& alpha, const MatrixType& A, const XVector& x, + const ScalarType& beta, const YVector& y) { + teamvector_gemv(team, trans, alpha, A, x, + beta, y); +} + +template +void KOKKOS_INLINE_FUNCTION +threadvector_gemv(const MemberType& team, const char trans, + const ScalarType& alpha, const MatrixType& A, + const XVector& x, const ScalarType& beta, const YVector& y) { + if (trans == 'N' || trans == 'n') { + KokkosBlas::ThreadVectorGemv::invoke( + team, alpha, A, x, beta, y); + } else if (trans == 'T' || trans == 't') { + KokkosBlas::ThreadVectorGemv::invoke( + team, alpha, A, x, beta, y); + } else if (trans == 'C' || trans == 'c') { + KokkosBlas::ThreadVectorGemv::invoke( + team, alpha, A, x, beta, y); + } else { + Kokkos::abort("Matrix mode not supported"); + } +} + +// default AlgoTag +template +void KOKKOS_INLINE_FUNCTION +threadvector_gemv(const MemberType& team, const char trans, + const ScalarType& alpha, const MatrixType& A, + const XVector& x, const ScalarType& beta, const YVector& y) { + threadvector_gemv(team, trans, alpha, A, x, + beta, y); +} + /// /// Selective Interface /// @@ -258,7 +429,23 @@ struct Gemv { } }; +template +struct Gemv { + template + static void KOKKOS_INLINE_FUNCTION + invoke(const MemberType& member, const char trans, const ScalarType& alpha, + const MatrixType& A, const XVector& x, const ScalarType& beta, + const YVector& y) { + threadvector_gemv(member, trans, alpha, A, x, beta, y); + } +}; + } // namespace Experimental +/********************* END functor-level routines ***********************/ } // namespace KokkosBlas +#include "KokkosBlas2_serial_gemv_impl.hpp" +#include "KokkosBlas2_team_gemv_impl.hpp" + #endif // KOKKOS_BLAS2_MV_HPP_ diff --git a/blas/src/KokkosBlas2_team_gemv.hpp b/blas/src/KokkosBlas2_team_gemv.hpp deleted file mode 100644 index ddc216b8af..0000000000 --- a/blas/src/KokkosBlas2_team_gemv.hpp +++ /dev/null @@ -1,119 +0,0 @@ -/* -//@HEADER -// ************************************************************************ -// -// Kokkos v. 3.0 -// Copyright (2020) National Technology & Engineering -// Solutions of Sandia, LLC (NTESS). -// -// Under the terms of Contract DE-NA0003525 with NTESS, -// the U.S. Government retains certain rights in this software. -// -// Redistribution and use in source and binary forms, with or without -// modification, are permitted provided that the following conditions are -// met: -// -// 1. Redistributions of source code must retain the above copyright -// notice, this list of conditions and the following disclaimer. -// -// 2. Redistributions in binary form must reproduce the above copyright -// notice, this list of conditions and the following disclaimer in the -// documentation and/or other materials provided with the distribution. -// -// 3. Neither the name of the Corporation nor the names of the -// contributors may be used to endorse or promote products derived from -// this software without specific prior written permission. -// -// THIS SOFTWARE IS PROVIDED BY NTESS "AS IS" AND ANY -// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR -// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NTESS OR THE -// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, -// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, -// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR -// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF -// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING -// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS -// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -// -// Questions? Contact Siva Rajamanickam (srajama@sandia.gov) -// -// ************************************************************************ -//@HEADER -*/ - -#ifndef KOKKOSBLAS2_TEAM_GEMV_HPP_ -#define KOKKOSBLAS2_TEAM_GEMV_HPP_ - -#include - -namespace KokkosBlas { -namespace Experimental { - -template -void KOKKOS_INLINE_FUNCTION team_gemv(const TeamType& team, const char trans, - const ScalarType& alpha, - const MatrixType& A, const XVector& x, - const ScalarType& beta, - const YVector& y) { - if (trans == 'N' || trans == 'n') - TeamGemv::invoke(team, alpha, A, x, - beta, y); - else if (trans == 'T' || trans == 't') - TeamGemv::invoke(team, alpha, A, x, - beta, y); - else if (trans == 'C' || trans == 'c') - TeamGemv::invoke(team, alpha, A, x, - beta, y); - else { - Kokkos::abort("Matrix mode not supported"); - } -} - -// default AlgoTag -template -void KOKKOS_INLINE_FUNCTION team_gemv(const TeamType& team, const char trans, - const ScalarType& alpha, - const MatrixType& A, const XVector& x, - const ScalarType& beta, - const YVector& y) { - team_gemv(team, trans, alpha, A, x, beta, y); -} - -template -void KOKKOS_INLINE_FUNCTION -teamvector_gemv(const TeamType& team, const char trans, const ScalarType& alpha, - const MatrixType& A, const XVector& x, const ScalarType& beta, - const YVector& y) { - if (trans == 'N' || trans == 'n') { - KokkosBlas::TeamVectorGemv::invoke( - team, alpha, A, x, beta, y); - } else if (trans == 'T' || trans == 't') { - KokkosBlas::TeamVectorGemv::invoke( - team, alpha, A, x, beta, y); - } else if (trans == 'C' || trans == 'c') { - KokkosBlas::TeamVectorGemv::invoke( - team, alpha, A, x, beta, y); - } else { - Kokkos::abort("Matrix mode not supported"); - } -} - -// default AlgoTag -template -void KOKKOS_INLINE_FUNCTION -team_vector_gemv(const TeamType& team, const char trans, - const ScalarType& alpha, const MatrixType& A, const XVector& x, - const ScalarType& beta, const YVector& y) { - teamvector_gemv(team, trans, alpha, A, x, - beta, y); -} - -} // namespace Experimental -} // namespace KokkosBlas - -#endif diff --git a/blas/src/KokkosBlas3_gemm.hpp b/blas/src/KokkosBlas3_gemm.hpp index 3d36ad86c9..d93c378e67 100644 --- a/blas/src/KokkosBlas3_gemm.hpp +++ b/blas/src/KokkosBlas3_gemm.hpp @@ -256,6 +256,118 @@ void gemm(const char transA[], const char transB[], gemm(space, transA, transB, alpha, A, B, beta, C); } +/********************* BEGIN functor-level routines *********************/ + +/// +/// Serial Gemm +/// + +template +struct SerialGemm { + template + KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, + const AViewType& A, + const BViewType& B, + const ScalarType beta, + const CViewType& C); +}; + +/// +/// Team Impl +/// ========= + +template +struct TeamGemm { + template + KOKKOS_INLINE_FUNCTION static int invoke( + const MemberType& member, const ScalarType alpha, const AViewType& A, + const BViewType& B, const ScalarType beta, const CViewType& C); +}; + +/// +/// TeamVector Impl +/// ========= + +template +struct TeamVectorGemm { + template + KOKKOS_INLINE_FUNCTION static int invoke( + const MemberType& member, const ScalarType alpha, const AViewType& A, + const BViewType& B, const ScalarType beta, const CViewType& C); +}; + +/// +/// ThreadVector Impl +/// ========= + +template +struct ThreadVectorGemm { + template + KOKKOS_INLINE_FUNCTION static int invoke( + const MemberType& member, const ScalarType alpha, const AViewType& A, + const BViewType& B, const ScalarType beta, const CViewType& C); +}; + +/// +/// Selective Interface +/// +template +struct Gemm { + template + KOKKOS_FORCEINLINE_FUNCTION static int invoke( + const MemberType& member, const ScalarType alpha, const AViewType& A, + const BViewType& B, const ScalarType beta, const CViewType& C); +}; + +template +struct Gemm { + template + KOKKOS_FORCEINLINE_FUNCTION static int invoke(const MemberType& /* member */, + const ScalarType alpha, + const AViewType& A, + const BViewType& B, + const ScalarType beta, + const CViewType& C) { + return SerialGemm::invoke(alpha, A, B, beta, + C); + } +}; + +template +struct Gemm { + template + KOKKOS_FORCEINLINE_FUNCTION static int invoke( + const MemberType& member, const ScalarType alpha, const AViewType& A, + const BViewType& B, const ScalarType beta, const CViewType& C) { + return TeamGemm::invoke(member, alpha, A, B, + beta, C); + } +}; + +template +struct Gemm { + template + KOKKOS_FORCEINLINE_FUNCTION static int invoke( + const MemberType& member, const ScalarType alpha, const AViewType& A, + const BViewType& B, const ScalarType beta, const CViewType& C) { + return TeamVectorGemm::invoke(member, alpha, + A, B, beta, C); + } +}; + +/********************* END functor-level routines *********************/ } // namespace KokkosBlas +#include +#include + #endif // KOKKOS_BLAS3_MV_HPP_ diff --git a/blas/src/KokkosBlas3_gemm_inner_fix.hpp b/blas/src/KokkosBlas3_gemm_inner_fix.hpp new file mode 100644 index 0000000000..b4632dfa41 --- /dev/null +++ b/blas/src/KokkosBlas3_gemm_inner_fix.hpp @@ -0,0 +1,169 @@ +/* +//@HEADER +// ************************************************************************ +// +// Kokkos v. 3.0 +// Copyright (2020) National Technology & Engineering +// Solutions of Sandia, LLC (NTESS). +// +// Under the terms of Contract DE-NA0003525 with NTESS, +// the U.S. Government retains certain rights in this software. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// 1. Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of the Corporation nor the names of the +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY NTESS "AS IS" AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NTESS OR THE +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// Questions? Contact Siva Rajamanickam (srajama@sandia.gov) +// +// ************************************************************************ +//@HEADER +*/ + +#ifndef KOKKOSBLAS3_GEMM_INNER_FIX_HPP +#define KOKKOSBLAS3_GEMM_INNER_FIX_HPP + +/// \author Kyungjoo Kim (kyukim@sandia.gov) + +namespace KokkosBlas { + +template +struct InnerGemmFixA { + const int _as0, _as1, _bs0, _bs1, _cs0, _cs1; + + KOKKOS_INLINE_FUNCTION + InnerGemmFixA(const int as0, const int as1, const int bs0, const int bs1, + const int cs0, const int cs1) + : _as0(as0), _as1(as1), _bs0(bs0), _bs1(bs1), _cs0(cs0), _cs1(cs1) {} + + // serial rank update + template + KOKKOS_INLINE_FUNCTION int serial_invoke(const ScalarType alpha, + const ValueType *KOKKOS_RESTRICT A, + const ValueType *KOKKOS_RESTRICT B, + const int n, + /**/ ValueType *KOKKOS_RESTRICT C); + + // serial rank update for remainder + template + KOKKOS_INLINE_FUNCTION int serial_invoke(const ScalarType alpha, + const ValueType *KOKKOS_RESTRICT A, + const ValueType *KOKKOS_RESTRICT B, + const int m, const int n, + const int k, + /**/ ValueType *KOKKOS_RESTRICT C); +}; + +template +struct InnerGemmFixB { + const int _as0, _as1, _bs0, _bs1, _cs0, _cs1; + + KOKKOS_INLINE_FUNCTION + InnerGemmFixB(const int as0, const int as1, const int bs0, const int bs1, + const int cs0, const int cs1) + : _as0(as0), _as1(as1), _bs0(bs0), _bs1(bs1), _cs0(cs0), _cs1(cs1) {} + + // serial rank update + template + KOKKOS_INLINE_FUNCTION int serial_invoke(const ScalarType alpha, + const ValueType *KOKKOS_RESTRICT A, + const ValueType *KOKKOS_RESTRICT B, + const int n, + /**/ ValueType *KOKKOS_RESTRICT C); + + // serial rank update for remainder + template + KOKKOS_INLINE_FUNCTION int serial_invoke(const ScalarType alpha, + const ValueType *KOKKOS_RESTRICT A, + const ValueType *KOKKOS_RESTRICT B, + const int m, const int n, + const int k, + /**/ ValueType *KOKKOS_RESTRICT C); +}; + +template +struct InnerGemmFixC { + const int _as0, _as1, _bs0, _bs1, _cs0, _cs1; + + KOKKOS_INLINE_FUNCTION + InnerGemmFixC(const int as0, const int as1, const int bs0, const int bs1, + const int cs0, const int cs1) + : _as0(as0), _as1(as1), _bs0(bs0), _bs1(bs1), _cs0(cs0), _cs1(cs1) {} + + // serial rank update + template + KOKKOS_INLINE_FUNCTION int serial_invoke(OpA opA, OpB opB, + const ScalarType alpha, + const ValueType *KOKKOS_RESTRICT A, + const ValueType *KOKKOS_RESTRICT B, + const int k, + /**/ ValueType *KOKKOS_RESTRICT C); + + // serial rank update for remainder + template + KOKKOS_INLINE_FUNCTION int serial_invoke(OpA opA, OpB opB, + const ScalarType alpha, + const ValueType *KOKKOS_RESTRICT A, + const ValueType *KOKKOS_RESTRICT B, + const int m, const int k, + /**/ ValueType *KOKKOS_RESTRICT C); + + // serial rank update for remainder + template + KOKKOS_INLINE_FUNCTION int serial_invoke(OpA opA, OpB opB, + const ScalarType alpha, + const ValueType *KOKKOS_RESTRICT A, + const ValueType *KOKKOS_RESTRICT B, + const int m, const int n, + const int k, + /**/ ValueType *KOKKOS_RESTRICT C); + + template + KOKKOS_INLINE_FUNCTION int team_invoke(const MemberType &member, + const ScalarType alpha, + const ValueType *KOKKOS_RESTRICT A, + const ValueType *KOKKOS_RESTRICT B, + const int k, + /**/ ValueType *KOKKOS_RESTRICT C); + + // team rank update for remainder + template + KOKKOS_INLINE_FUNCTION int team_invoke(const MemberType &member, + const ScalarType alpha, + const ValueType *KOKKOS_RESTRICT A, + const ValueType *KOKKOS_RESTRICT B, + const int m, const int n, const int k, + /**/ ValueType *KOKKOS_RESTRICT C); +}; + +} // namespace KokkosBlas + +#include "KokkosBlas3_serial_gemm_inner_fixa_impl.hpp" +// TODO: fix compilation errors InnerGemmFixB (not used internally, not tested) +// #include "KokkosBlas3_serial_gemm_inner_fixb_impl.hpp" +#include "KokkosBlas3_serial_gemm_inner_fixc_impl.hpp" +#include "KokkosBlas3_team_gemm_inner_fixc_impl.hpp" + +#endif diff --git a/blas/tpls/KokkosBlas2_serial_gemv_tpl_spec_decl.hpp b/blas/tpls/KokkosBlas2_serial_gemv_tpl_spec_decl.hpp index 77aa5a6713..0fa958b3f7 100644 --- a/blas/tpls/KokkosBlas2_serial_gemv_tpl_spec_decl.hpp +++ b/blas/tpls/KokkosBlas2_serial_gemv_tpl_spec_decl.hpp @@ -46,18 +46,10 @@ #include "KokkosBlas_util.hpp" #include "KokkosBatched_Vector.hpp" - -#if defined(KOKKOSKERNELS_ENABLE_TPL_MKL) -#include "mkl_version.h" -#if __INTEL_MKL__ >= 2018 -#define __KOKKOSBLAS_ENABLE_INTEL_MKL_COMPACT__ 1 -#endif -#endif +#include "KokkosKernels_MKLUtils.hpp" #ifdef __KOKKOSBLAS_ENABLE_INTEL_MKL_COMPACT__ -#include "mkl_compact.h" - namespace KokkosBlas { namespace Impl { diff --git a/blas/tpls/KokkosBlas3_serial_gemm_tpl_spec_decl.hpp b/blas/tpls/KokkosBlas3_serial_gemm_tpl_spec_decl.hpp new file mode 100644 index 0000000000..e608841967 --- /dev/null +++ b/blas/tpls/KokkosBlas3_serial_gemm_tpl_spec_decl.hpp @@ -0,0 +1,122 @@ +/* +//@HEADER +// ************************************************************************ +// +// Kokkos v. 3.0 +// Copyright (2020) National Technology & Engineering +// Solutions of Sandia, LLC (NTESS). +// +// Under the terms of Contract DE-NA0003525 with NTESS, +// the U.S. Government retains certain rights in this software. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// 1. Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of the Corporation nor the names of the +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY NTESS "AS IS" AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NTESS OR THE +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// Questions? Contact Siva Rajamanickam (srajama@sandia.gov) +// +// ************************************************************************ +//@HEADER +*/ +#ifndef KOKKOSBLAS3_SERIAL_GEMM_TPL_SPEC_DECL_HPP_ +#define KOKKOSBLAS3_SERIAL_GEMM_TPL_SPEC_DECL_HPP_ + +#include "KokkosBlas_util.hpp" +#include "KokkosBatched_Vector.hpp" +#include "KokkosKernels_MKLUtils.hpp" + +#ifdef __KOKKOSBLAS_ENABLE_INTEL_MKL_COMPACT__ + +namespace KokkosBlas { + +namespace Impl { + +template +constexpr MKL_TRANSPOSE trans2mkl = + std::is_same::value + ? MKL_NOTRANS + : (std::is_same::value + ? MKL_TRANS + : (std::is_same::value + ? MKL_CONJTRANS + : MKL_CONJ)); // Note: CONJ is not supported by MKL GEMM + +} +/// +/// Serial Impl +/// =========== + +template +struct SerialGemm { + template + KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, + const AViewType &A, + const BViewType &B, + const ScalarType beta, + const CViewType &C) { + typedef typename CViewType::value_type vector_type; + // typedef typename vector_type::value_type value_type; + + const int m = C.extent(0), n = C.extent(1); + const int k = + A.extent(std::is_same::value ? 1 : 0); + const MKL_TRANSPOSE trans_A = Impl::trans2mkl; + const MKL_TRANSPOSE trans_B = Impl::trans2mkl; + + static_assert(KokkosBatched::is_vector::value, + "value type is not vector type"); + static_assert( + vector_type::vector_length == 4 || vector_type::vector_length == 8, + "AVX, AVX2 and AVX512 is supported"); + const MKL_COMPACT_PACK format = + vector_type::vector_length == 8 ? MKL_COMPACT_AVX512 : MKL_COMPACT_AVX; + + // no error check + int r_val = 0; + if (A.stride_0() == 1 && B.stride_0() == 1 && C.stride_0() == 1) { + mkl_dgemm_compact(MKL_COL_MAJOR, trans_A, trans_B, m, n, k, alpha, + (const double *)A.data(), A.stride_1(), + (const double *)B.data(), B.stride_1(), beta, + (double *)C.data(), C.stride_1(), format, + (MKL_INT)vector_type::vector_length); + } else if (A.stride_1() == 1 && B.stride_1() == 1 && C.stride_1() == 1) { + mkl_dgemm_compact(MKL_ROW_MAJOR, trans_A, trans_B, m, n, k, alpha, + (const double *)A.data(), A.stride_0(), + (const double *)B.data(), B.stride_0(), beta, + (double *)C.data(), C.stride_0(), format, + (MKL_INT)vector_type::vector_length); + } else { + r_val = -1; + } + return r_val; + } +}; + +} // namespace KokkosBlas + +#endif // __KOKKOSBLAS_ENABLE_INTEL_MKL_COMPACT__ +#endif // KOKKOSBLAS2_SERIAL_GEMV_TPL_SPEC_DECL_HPP_ diff --git a/blas/unit_test/Test_Blas.hpp b/blas/unit_test/Test_Blas.hpp index b794d74bde..79d7ea27f5 100644 --- a/blas/unit_test/Test_Blas.hpp +++ b/blas/unit_test/Test_Blas.hpp @@ -31,6 +31,7 @@ // Team Blas 1 #include "Test_Blas1_team_setscal.hpp" +#include "Test_Blas1_threadvector_setscal.hpp" #include "Test_Blas1_team_abs.hpp" #include "Test_Blas1_team_axpby.hpp" #include "Test_Blas1_team_axpy.hpp" @@ -49,9 +50,14 @@ // Team Blas 2 #include "Test_Blas2_team_gemv.hpp" #include "Test_Blas2_teamvector_gemv.hpp" +#include "Test_Blas2_threadvector_gemv.hpp" // Blas 3 #include "Test_Blas3_gemm.hpp" +#include "Test_Blas3_serial_gemm.hpp" +#include "Test_Blas3_team_gemm.hpp" +#include "Test_Blas3_teamvector_gemm.hpp" +#include "Test_Blas3_threadvector_gemm.hpp" #include "Test_Blas3_trmm.hpp" #include "Test_Blas3_trsm.hpp" diff --git a/blas/unit_test/Test_Blas1_threadvector_setscal.hpp b/blas/unit_test/Test_Blas1_threadvector_setscal.hpp new file mode 100644 index 0000000000..b2c96617de --- /dev/null +++ b/blas/unit_test/Test_Blas1_threadvector_setscal.hpp @@ -0,0 +1,271 @@ +/* +//@HEADER +// ************************************************************************ +// +// Kokkos v. 3.0 +// Copyright (2020) National Technology & Engineering +// Solutions of Sandia, LLC (NTESS). +// +// Under the terms of Contract DE-NA0003525 with NTESS, +// the U.S. Government retains certain rights in this software. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// 1. Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of the Corporation nor the names of the +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY NTESS "AS IS" AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NTESS OR THE +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// Questions? Contact Siva Rajamanickam (srajama@sandia.gov) +// +// ************************************************************************ +//@HEADER +*/ + +#include "gtest/gtest.h" +#include "Kokkos_Core.hpp" +#include "Kokkos_Random.hpp" + +#include "KokkosBlas1_set.hpp" +#include "KokkosBlas1_scal.hpp" + +#include "KokkosKernels_TestUtils.hpp" + +namespace Test { +namespace ThreadVectorMatUtil { + +enum : int { BlasSet = 0, BlasScale = 1 }; + +struct KokkosKernelTag {}; +struct NaiveTag {}; + +template +struct Functor_TestBlasThreadVectorMatUtil { + ScalarType _alpha; + ViewType _a; + + KOKKOS_INLINE_FUNCTION + Functor_TestBlasThreadVectorMatUtil(const ScalarType alpha, const ViewType &a) + : _alpha(alpha), _a(a) {} + + template + KOKKOS_INLINE_FUNCTION void operator()(const KokkosKernelTag &, + const MemberType &member) const { + const int i = member.league_rank(); + auto A = Kokkos::subview(_a, i, Kokkos::ALL(), Kokkos::ALL()); + switch (TestID) { + case BlasSet: + KokkosBlas::ThreadVectorSet::invoke(member, _alpha, A); + break; + case BlasScale: + KokkosBlas::ThreadVectorScale::invoke(member, _alpha, A); + break; + } + } + + template + KOKKOS_INLINE_FUNCTION void operator()(const NaiveTag &, + const MemberType &member) const { + if (member.team_rank() == 0) { + const int k = member.league_rank(); + auto A = Kokkos::subview(_a, k, Kokkos::ALL(), Kokkos::ALL()); + const int m = A.extent(0), n = A.extent(1); + switch (TestID) { + case BlasSet: { + for (int i = 0; i < m; ++i) + for (int j = 0; j < n; ++j) A(i, j) = _alpha; + break; + } + case BlasScale: { + for (int i = 0; i < m; ++i) + for (int j = 0; j < n; ++j) A(i, j) *= _alpha; + break; + } + } + } + } + + inline int run() { + typedef typename ViewType::value_type value_type; + std::string name_region("KokkosBlas::Test::SerialMatUtil"); + const std::string name_value_type = Test::value_type_name(); + std::string name_work_tag = + (std::is_same::value + ? "::KokkosBlas" + : std::is_same::value ? "::Naive" + : "::UnknownWorkTag"); + std::string name_test_id = + (TestID == BlasSet ? "Set" + : TestID == BlasScale ? "Scale" : "UnknownTest"); + std::string name = + name_region + name_value_type + name_work_tag + name_test_id; + Kokkos::Profiling::pushRegion(name.c_str()); + + const int league_size = _a.extent(0); + Kokkos::TeamPolicy policy(league_size, + Kokkos::AUTO); + Kokkos::parallel_for(name.c_str(), policy, *this); + Kokkos::Profiling::popRegion(); + + return 0; + } +}; + +template +void impl_test_blas_matutil(const int N, const int BlkSize) { + /// typedefs + typedef typename ViewType::value_type value_type; + typedef Kokkos::Details::ArithTraits ats; + + /// radomized input testing views + const ScalarType alpha = 11.1; + ViewType a("a", N, BlkSize, BlkSize); + ViewType b("b", N, BlkSize, BlkSize); + + Kokkos::Random_XorShift64_Pool random( + 13718); + Kokkos::fill_random(a, random, value_type(1.0)); + + Kokkos::fence(); + + Kokkos::deep_copy(b, a); + + /// test body + Functor_TestBlasThreadVectorMatUtil(alpha, a) + .run(); + Functor_TestBlasThreadVectorMatUtil(alpha, b) + .run(); + + Kokkos::fence(); + + /// for comparison send it to host + typename ViewType::HostMirror a_host = Kokkos::create_mirror_view(a); + typename ViewType::HostMirror b_host = Kokkos::create_mirror_view(b); + + Kokkos::deep_copy(a_host, a); + Kokkos::deep_copy(b_host, b); + + /// check a = b + typename ats::mag_type eps = + 100 * std::numeric_limits::epsilon(); + for (int k = 0; k < N; ++k) + for (int i = 0; i < BlkSize; ++i) + for (int j = 0; j < BlkSize; ++j) + EXPECT_NEAR_KK(b_host(k, i, j), a_host(k, i, j), eps); +} +} // namespace ThreadVectorMatUtil +} // namespace Test + +template +int test_blas_threadvector_matutil() { +#if defined(KOKKOSKERNELS_INST_LAYOUTLEFT) + { + typedef Kokkos::View + ViewType; + Test::ThreadVectorMatUtil::impl_test_blas_matutil(0, + 10); + Test::ThreadVectorMatUtil::impl_test_blas_matutil(10, + 15); + Test::ThreadVectorMatUtil::impl_test_blas_matutil(1024, + 9); + Test::ThreadVectorMatUtil::impl_test_blas_matutil( + 132231, 3); + } +#endif +#if defined(KOKKOSKERNELS_INST_LAYOUTRIGHT) + { + typedef Kokkos::View + ViewType; + Test::ThreadVectorMatUtil::impl_test_blas_matutil(0, + 10); + Test::ThreadVectorMatUtil::impl_test_blas_matutil(10, + 15); + Test::ThreadVectorMatUtil::impl_test_blas_matutil(1024, + 9); + Test::ThreadVectorMatUtil::impl_test_blas_matutil( + 132231, 3); + } +#endif + + return 0; +} + +// Real test cases + +#if defined(KOKKOSKERNELS_INST_FLOAT) +TEST_F(TestCategory, blas_scalar_threadvector_set_float_float) { + test_blas_threadvector_matutil(); +} +TEST_F(TestCategory, blas_scalar_threadvector_scale_float_float) { + test_blas_threadvector_matutil(); +} +#endif + +#if defined(KOKKOSKERNELS_INST_DOUBLE) +TEST_F(TestCategory, blas_scalar_threadvector_set_double_double) { + test_blas_threadvector_matutil(); +} +TEST_F(TestCategory, blas_scalar_threadvector_scale_double_double) { + test_blas_threadvector_matutil(); +} +#endif + +// Complex test cases + +#if defined(KOKKOSKERNELS_INST_COMPLEX_DOUBLE) +TEST_F(TestCategory, blas_scalar_threadvector_set_dcomplex_dcomplex) { + test_blas_threadvector_matutil, + Kokkos::complex, + ::Test::ThreadVectorMatUtil::BlasSet>(); +} +TEST_F(TestCategory, blas_scalar_threadvector_scale_dcomplex_dcomplex) { + test_blas_threadvector_matutil, + Kokkos::complex, + ::Test::ThreadVectorMatUtil::BlasScale>(); +} +TEST_F(TestCategory, blas_scalar_threadvector_set_dcomplex_double) { + test_blas_threadvector_matutil, double, + ::Test::ThreadVectorMatUtil::BlasSet>(); +} +TEST_F(TestCategory, blas_scalar_threadvector_scale_dcomplex_double) { + test_blas_threadvector_matutil, double, + ::Test::ThreadVectorMatUtil::BlasScale>(); +} +#endif diff --git a/blas/unit_test/Test_Blas2_gemv_util.hpp b/blas/unit_test/Test_Blas2_gemv_util.hpp index 635f02c558..c45602d0ec 100644 --- a/blas/unit_test/Test_Blas2_gemv_util.hpp +++ b/blas/unit_test/Test_Blas2_gemv_util.hpp @@ -299,24 +299,24 @@ struct GEMVTest { } // namespace Test -#define TEST_CASE4(PREFIX, FACTORY, NAME, SCALAR_A, SCALAR_X, SCALAR_Y, \ - SCALAR_COEF) \ - using PREFIX##_##NAME##_gemv_test = \ - ::Test::GEMVTest<::Test::FACTORY, SCALAR_A, SCALAR_X, SCALAR_Y, \ - TestExecSpace, SCALAR_COEF>; \ - TEST_F(TestCategory, PREFIX##_gemv_nt_##NAME) { \ - PREFIX##_##NAME##_gemv_test::run("N"); \ - } \ - TEST_F(TestCategory, PREFIX##_gemv_t_##NAME) { \ - PREFIX##_##NAME##_gemv_test::run("T"); \ - } \ - TEST_F(TestCategory, PREFIX##_gemv_ct_##NAME) { \ - PREFIX##_##NAME##_gemv_test::run("C"); \ +#define TEST_GEMV_CASE4(PREFIX, FACTORY, NAME, SCALAR_A, SCALAR_X, SCALAR_Y, \ + SCALAR_COEF) \ + using PREFIX##_##NAME##_gemv_test = \ + ::Test::GEMVTest<::Test::FACTORY, SCALAR_A, SCALAR_X, SCALAR_Y, \ + TestExecSpace, SCALAR_COEF>; \ + TEST_F(TestCategory, PREFIX##_gemv_nt_##NAME) { \ + PREFIX##_##NAME##_gemv_test::run("N"); \ + } \ + TEST_F(TestCategory, PREFIX##_gemv_t_##NAME) { \ + PREFIX##_##NAME##_gemv_test::run("T"); \ + } \ + TEST_F(TestCategory, PREFIX##_gemv_ct_##NAME) { \ + PREFIX##_##NAME##_gemv_test::run("C"); \ } -#define TEST_CASE2(PREFIX, FACTORY, NAME, SCALAR, SCALAR_COEF) \ - TEST_CASE4(PREFIX, FACTORY, NAME, SCALAR, SCALAR, SCALAR, SCALAR_COEF) -#define TEST_CASE(PREFIX, FACTORY, NAME, SCALAR) \ - TEST_CASE2(PREFIX, FACTORY, NAME, SCALAR, SCALAR) +#define TEST_GEMV_CASE2(PREFIX, FACTORY, NAME, SCALAR, SCALAR_COEF) \ + TEST_GEMV_CASE4(PREFIX, FACTORY, NAME, SCALAR, SCALAR, SCALAR, SCALAR_COEF) +#define TEST_GEMV_CASE(PREFIX, FACTORY, NAME, SCALAR) \ + TEST_GEMV_CASE2(PREFIX, FACTORY, NAME, SCALAR, SCALAR) #endif // TEST_BLAS2_GEMV_UTIL_HPP diff --git a/blas/unit_test/Test_Blas2_serial_gemv.hpp b/blas/unit_test/Test_Blas2_serial_gemv.hpp index fd73707c9a..a7b9526722 100644 --- a/blas/unit_test/Test_Blas2_serial_gemv.hpp +++ b/blas/unit_test/Test_Blas2_serial_gemv.hpp @@ -47,10 +47,12 @@ struct SerialMKLGemvFactory { } // namespace Test #define TEST_SERIAL_CASE4(N, A, X, Y, SC) \ - TEST_CASE4(serial, SerialGemvFactory, N, A, X, Y, SC) + TEST_GEMV_CASE4(serial, SerialGemvFactory, N, A, X, Y, SC) #define TEST_SERIAL_CASE2(N, S, SC) \ - TEST_CASE2(serial, SerialGemvFactory, N, S, SC) -#define TEST_SERIAL_CASE(N, S) TEST_CASE(serial, SerialGemvFactory, N, S) + TEST_GEMV_CASE2(serial, SerialGemvFactory, N, S, SC) +#define TEST_SERIAL_CASE(N, S) TEST_GEMV_CASE(serial, SerialGemvFactory, N, S) +#define TEST_SERIAL_GEMV_MKL_CASE(N, S, SC) \ + TEST_GEMV_CASE2(serial, SerialMKLGemvFactory, N, S, SC) #ifdef KOKKOSKERNELS_TEST_FLOAT TEST_SERIAL_CASE(float, float) @@ -59,10 +61,9 @@ TEST_SERIAL_CASE(float, float) using simd_float_sse = ::Test::simd_vector; using simd_float_avx = ::Test::simd_vector; using simd_float_avx512 = ::Test::simd_vector; -TEST_CASE2(serial, SerialMKLGemvFactory, mkl_float_sse, simd_float_sse, float) -TEST_CASE2(serial, SerialMKLGemvFactory, mkl_float_avx, simd_float_avx, float) -TEST_CASE2(serial, SerialMKLGemvFactory, mkl_float_avx512, simd_float_avx512, - float) +TEST_SERIAL_GEMV_MKL_CASE(mkl_float_sse, simd_float_sse, float) +TEST_SERIAL_GEMV_MKL_CASE(mkl_float_avx, simd_float_avx, float) +TEST_SERIAL_GEMV_MKL_CASE(mkl_float_avx512, simd_float_avx512, float) #endif #endif @@ -73,12 +74,9 @@ TEST_SERIAL_CASE(double, double) using simd_double_sse = ::Test::simd_vector; using simd_double_avx = ::Test::simd_vector; using simd_double_avx512 = ::Test::simd_vector; -TEST_CASE2(serial, SerialMKLGemvFactory, mkl_double_sse, simd_double_sse, - double) -TEST_CASE2(serial, SerialMKLGemvFactory, mkl_double_avx, simd_double_avx, - double) -TEST_CASE2(serial, SerialMKLGemvFactory, mkl_double_avx512, simd_double_avx512, - double) +TEST_SERIAL_GEMV_MKL_CASE(mkl_double_sse, simd_double_sse, double) +TEST_SERIAL_GEMV_MKL_CASE(mkl_double_avx, simd_double_avx, double) +TEST_SERIAL_GEMV_MKL_CASE(mkl_double_avx512, simd_double_avx512, double) #endif #endif @@ -105,3 +103,4 @@ TEST_SERIAL_CASE2(alphabeta, Kokkos::complex, double) #undef TEST_SERIAL_CASE4 #undef TEST_SERIAL_CASE2 #undef TEST_SERIAL_CASE +#undef TEST_SERIAL_GEMV_MKL_CASE diff --git a/blas/unit_test/Test_Blas2_team_gemv.hpp b/blas/unit_test/Test_Blas2_team_gemv.hpp index 722aca1938..d4b1733521 100644 --- a/blas/unit_test/Test_Blas2_team_gemv.hpp +++ b/blas/unit_test/Test_Blas2_team_gemv.hpp @@ -42,9 +42,10 @@ struct TeamGemvFactory { } // namespace Test #define TEST_TEAM_CASE4(N, A, X, Y, SC) \ - TEST_CASE4(team, TeamGemvFactory, N, A, X, Y, SC) -#define TEST_TEAM_CASE2(N, S, SC) TEST_CASE2(team, TeamGemvFactory, N, S, SC) -#define TEST_TEAM_CASE(N, S) TEST_CASE(team, TeamGemvFactory, N, S) + TEST_GEMV_CASE4(team, TeamGemvFactory, N, A, X, Y, SC) +#define TEST_TEAM_CASE2(N, S, SC) \ + TEST_GEMV_CASE2(team, TeamGemvFactory, N, S, SC) +#define TEST_TEAM_CASE(N, S) TEST_GEMV_CASE(team, TeamGemvFactory, N, S) #ifdef KOKKOSKERNELS_TEST_FLOAT TEST_TEAM_CASE(float, float) diff --git a/blas/unit_test/Test_Blas2_teamvector_gemv.hpp b/blas/unit_test/Test_Blas2_teamvector_gemv.hpp index 5814541bb2..490c543d97 100644 --- a/blas/unit_test/Test_Blas2_teamvector_gemv.hpp +++ b/blas/unit_test/Test_Blas2_teamvector_gemv.hpp @@ -44,11 +44,11 @@ struct TeamVectorGemvFactory { } // namespace Test #define TEST_TEAMVECTOR_CASE4(N, A, X, Y, SC) \ - TEST_CASE4(teamvector, TeamVectorGemvFactory, N, A, X, Y, SC) + TEST_GEMV_CASE4(teamvector, TeamVectorGemvFactory, N, A, X, Y, SC) #define TEST_TEAMVECTOR_CASE2(N, S, SC) \ - TEST_CASE2(teamvector, TeamVectorGemvFactory, N, S, SC) + TEST_GEMV_CASE2(teamvector, TeamVectorGemvFactory, N, S, SC) #define TEST_TEAMVECTOR_CASE(N, S) \ - TEST_CASE(teamvector, TeamVectorGemvFactory, N, S) + TEST_GEMV_CASE(teamvector, TeamVectorGemvFactory, N, S) #ifdef KOKKOSKERNELS_TEST_FLOAT TEST_TEAMVECTOR_CASE(float, float) diff --git a/blas/unit_test/Test_Blas2_threadvector_gemv.hpp b/blas/unit_test/Test_Blas2_threadvector_gemv.hpp new file mode 100644 index 0000000000..a14f7dca8b --- /dev/null +++ b/blas/unit_test/Test_Blas2_threadvector_gemv.hpp @@ -0,0 +1,85 @@ +// Note: Luc Berger-Vergiat 04/14/21 +// This tests uses KOKKOS_LAMBDA so we need +// to make sure that these are enabled in +// the CUDA backend before including this test. +#if !defined(TEST_CUDA_BLAS_CPP) || defined(KOKKOS_ENABLE_CUDA_LAMBDA) + +#include +#include // for test/inst guards +// Note: include serial gemv before util so it knows if CompactMKL is available +#include +#include + +namespace Test { + +template +struct ThreadVectorGEMVOp : public GemvOpBase { + using params = GemvOpBase; + + ThreadVectorGEMVOp(char trans_, ScalarType alpha_, AType A_, XType x_, + ScalarType beta_, YType y_) + : params(trans_, alpha_, A_, x_, beta_, y_) {} + + template + KOKKOS_INLINE_FUNCTION void operator()(const TeamMember& member) const { + KokkosBlas::Experimental::Gemv::invoke(member, params::trans, + params::alpha, params::A, + params::x, params::beta, + params::y); + } +}; + +struct ThreadVectorGemvFactory { + template + using functor_type = + ThreadVectorGEMVOp; + + // no Blocked implementation + using algorithms = std::tuple; +}; + +} // namespace Test + +#define TEST_THREADVECTOR_CASE4(N, A, X, Y, SC) \ + TEST_GEMV_CASE4(threadvector, ThreadVectorGemvFactory, N, A, X, Y, SC) +#define TEST_THREADVECTOR_CASE2(N, S, SC) \ + TEST_GEMV_CASE2(threadvector, ThreadVectorGemvFactory, N, S, SC) +#define TEST_THREADVECTOR_CASE(N, S) \ + TEST_GEMV_CASE(threadvector, ThreadVectorGemvFactory, N, S) + +#ifdef KOKKOSKERNELS_TEST_FLOAT +TEST_THREADVECTOR_CASE(float, float) +#endif + +#ifdef KOKKOSKERNELS_TEST_DOUBLE +TEST_THREADVECTOR_CASE(double, double) +#endif + +#ifdef KOKKOSKERNELS_TEST_COMPLEX_DOUBLE +TEST_THREADVECTOR_CASE(complex_double, Kokkos::complex) +#endif + +#ifdef KOKKOSKERNELS_TEST_COMPLEX_FLOAT +TEST_THREADVECTOR_CASE(complex_float, Kokkos::complex) +#endif + +#ifdef KOKKOSKERNELS_TEST_INT +TEST_THREADVECTOR_CASE(int, int) +#endif + +#ifdef KOKKOSKERNELS_TEST_ALL_TYPES +// test mixed scalar types (void -> default alpha/beta) +TEST_THREADVECTOR_CASE4(mixed, double, int, float, void) + +// test arbitrary double alpha/beta with complex values +TEST_THREADVECTOR_CASE2(alphabeta, Kokkos::complex, double) +#endif + +#undef TEST_THREADVECTOR_CASE4 +#undef TEST_THREADVECTOR_CASE2 +#undef TEST_THREADVECTOR_CASE + +#endif // Check for lambda availability on CUDA backend diff --git a/blas/unit_test/Test_Blas3_gemm_util.hpp b/blas/unit_test/Test_Blas3_gemm_util.hpp new file mode 100644 index 0000000000..39f9adbe99 --- /dev/null +++ b/blas/unit_test/Test_Blas3_gemm_util.hpp @@ -0,0 +1,54 @@ +#ifndef TEST_BLAS2_GEMM_UTIL_HPP +#define TEST_BLAS2_GEMM_UTIL_HPP + +#include "gtest/gtest.h" +#include "Kokkos_Core.hpp" +#include "Kokkos_Random.hpp" + +#include "KokkosBlas3_gemm.hpp" +#include "KokkosKernels_TestUtils.hpp" + +namespace Test { +namespace Gemm { + +using KokkosBlas::Algo; +using KokkosBlas::Trans; + +template +struct ParamTag { + typedef TA transA; + typedef TB transB; +}; + +#define TEST_GEMM_ALGO(NAME, FUNC, TRANS_A, TRANS_B, VALUE, SCALAR) \ + TEST_F(TestCategory, batched_scalar_##NAME) { \ + typedef ::Test::Gemm::ParamTag param_tag_type; \ + FUNC(); \ + FUNC(); \ + } + +#define TEST_GEMM_CASE(PREFIX, NAME, FUNC, VALUE, SCALAR) \ + TEST_GEMM_ALGO(PREFIX##_gemm_nt_nt_##NAME, FUNC, Trans::NoTranspose, \ + Trans::NoTranspose, VALUE, SCALAR) \ + TEST_GEMM_ALGO(PREFIX##_gemm_t_nt_##NAME, FUNC, Trans::Transpose, \ + Trans::NoTranspose, VALUE, SCALAR) \ + TEST_GEMM_ALGO(PREFIX##_gemm_ct_nt_##NAME, FUNC, Trans::ConjTranspose, \ + Trans::NoTranspose, VALUE, SCALAR) \ + TEST_GEMM_ALGO(PREFIX##_gemm_nt_t_##NAME, FUNC, Trans::NoTranspose, \ + Trans::Transpose, VALUE, SCALAR) \ + TEST_GEMM_ALGO(PREFIX##_gemm_t_t_##NAME, FUNC, Trans::Transpose, \ + Trans::Transpose, VALUE, SCALAR) \ + TEST_GEMM_ALGO(PREFIX##_gemm_ct_t_##NAME, FUNC, Trans::ConjTranspose, \ + Trans::Transpose, VALUE, SCALAR) \ + TEST_GEMM_ALGO(PREFIX##_gemm_nt_ct_##NAME, FUNC, Trans::NoTranspose, \ + Trans::ConjTranspose, VALUE, SCALAR) \ + TEST_GEMM_ALGO(PREFIX##_gemm_t_ct_##NAME, FUNC, Trans::Transpose, \ + Trans::ConjTranspose, VALUE, SCALAR) \ + TEST_GEMM_ALGO(PREFIX##_gemm_ct_ct_##NAME, FUNC, Trans::ConjTranspose, \ + Trans::ConjTranspose, VALUE, SCALAR) + +} // namespace Gemm +} // namespace Test + +#endif // TEST_BLAS2_GEMM_UTIL_HPP diff --git a/blas/unit_test/Test_Blas3_serial_gemm.hpp b/blas/unit_test/Test_Blas3_serial_gemm.hpp new file mode 100644 index 0000000000..5c780a3ccb --- /dev/null +++ b/blas/unit_test/Test_Blas3_serial_gemm.hpp @@ -0,0 +1,238 @@ +/// \author Kyungjoo Kim (kyukim@sandia.gov) + +#include "Test_Blas3_gemm_util.hpp" + +namespace Test { +namespace Gemm { + +template +struct Functor_TestBlasSerialGemm { + ViewType _a, _b, _c; + + ScalarType _alpha, _beta; + + KOKKOS_INLINE_FUNCTION + Functor_TestBlasSerialGemm(const ScalarType alpha, const ViewType &a, + const ViewType &b, const ScalarType beta, + const ViewType &c) + : _a(a), _b(b), _c(c), _alpha(alpha), _beta(beta) {} + + KOKKOS_INLINE_FUNCTION + void operator()(const ParamTagType &, const int k) const { + auto aa = Kokkos::subview(_a, k, Kokkos::ALL(), Kokkos::ALL()); + auto bb = Kokkos::subview(_b, k, Kokkos::ALL(), Kokkos::ALL()); + auto cc = Kokkos::subview(_c, k, Kokkos::ALL(), Kokkos::ALL()); + + KokkosBlas::SerialGemm::invoke(_alpha, aa, bb, _beta, cc); + } + + inline void run() { + typedef typename ViewType::value_type value_type; + std::string name_region("KokkosBlas::Test::SerialGemm"); + const std::string name_value_type = Test::value_type_name(); + std::string name = name_region + name_value_type; + Kokkos::Profiling::pushRegion(name.c_str()); + Kokkos::RangePolicy policy(0, _c.extent(0)); + Kokkos::parallel_for(name.c_str(), policy, *this); + Kokkos::Profiling::popRegion(); + } +}; + +template +void impl_test_blas_gemm(const int N, const int dimM, const int dimN, + const int dimK) { + using execution_space = typename DeviceType::execution_space; + using transA = typename ParamTagType::transA; + using transB = typename ParamTagType::transB; + using value_type = typename ViewType::value_type; + using ats = Kokkos::Details::ArithTraits; + + const auto transposed_A = !std::is_same::value; + const auto transposed_B = !std::is_same::value; + + const int matAdim1 = transposed_A ? dimM : dimK; + const int matAdim2 = transposed_A ? dimK : dimM; + const int matBdim1 = transposed_B ? dimK : dimN; + const int matBdim2 = transposed_B ? dimN : dimK; + const int matCdim1 = dimM; + const int matCdim2 = dimN; + + /// randomized input testing views + ScalarType alpha = ScalarType(1.5); + ScalarType beta = ScalarType(3.0); + + ViewType a_expected("a_expected", N, matAdim1, matAdim2), + a_actual("a_actual", N, matAdim1, matAdim2), + b_expected("b_expected", N, matBdim1, matBdim2), + b_actual("b_actual", N, matBdim1, matBdim2), + c_expected("c_expected", N, matCdim1, matCdim2), + c_actual("c_actual", N, matCdim1, matCdim2); + + Kokkos::Random_XorShift64_Pool random(13718); + + Kokkos::fill_random(a_expected, random, value_type(1.0)); + Kokkos::fill_random(b_expected, random, value_type(1.0)); + Kokkos::fill_random(c_expected, random, value_type(1.0)); + + Kokkos::fence(); + + Kokkos::deep_copy(a_actual, a_expected); + Kokkos::deep_copy(b_actual, b_expected); + Kokkos::deep_copy(c_actual, c_expected); + + Functor_BatchedVanillaGEMM + vgemm; + vgemm.A_t = transposed_A; + vgemm.B_t = transposed_B; + vgemm.A_c = std::is_same::value; + vgemm.B_c = std::is_same::value; + vgemm.A = a_expected; + vgemm.B = b_expected; + vgemm.C = c_expected; + vgemm.alpha = alpha; + vgemm.beta = beta; + vgemm.run(); // Compute c_expected + Functor_TestBlasSerialGemm(alpha, a_actual, b_actual, beta, + c_actual) + .run(); + + typename ViewType::HostMirror c_expected_host = + Kokkos::create_mirror_view(c_expected); + typename ViewType::HostMirror c_actual_host = + Kokkos::create_mirror_view(c_actual); + + // Copy to host for comparison + Kokkos::deep_copy(c_expected_host, c_expected); + Kokkos::deep_copy(c_actual_host, c_actual); + + Kokkos::fence(); + + // check c_expected = c_actual + // std::conditional<, float, + using mag_type = typename ats::mag_type; + mag_type sum(1), diff(0); + + mag_type eps = ats::epsilon(); + + eps *= std::is_same::value || + std::is_same::value + ? 4 + : 1e3; + + for (int k = 0; k < N; ++k) + for (int i = 0; i < matCdim1; ++i) + for (int j = 0; j < matCdim2; ++j) { + sum += ats::abs(c_expected_host(k, i, j)); + diff += ats::abs(c_expected_host(k, i, j) - c_actual_host(k, i, j)); + } + EXPECT_NEAR_KK(diff / sum, 0, eps); +} + +template +int test_blas_gemm() { +#if defined(KOKKOSKERNELS_INST_LAYOUTLEFT) + { + typedef Kokkos::View + ViewType; + Test::Gemm::impl_test_blas_gemm(0, 10, 10, 10); + for (int i = 0; i < 10; ++i) { + // printf("Testing: LayoutLeft, Blksize %d\n", i); + Test::Gemm::impl_test_blas_gemm(1024, i, i, i); + } + for (int i = 0; i < 10; ++i) { + // printf("Testing: LayoutLeft, Blksize %d\n", i); + int dimM = i; + int dimN = 2 * i; + int dimK = 3 * i; + Test::Gemm::impl_test_blas_gemm(1024, dimM, + dimN, dimK); + } + } +#endif +#if defined(KOKKOSKERNELS_INST_LAYOUTRIGHT) + { + typedef Kokkos::View + ViewType; + Test::Gemm::impl_test_blas_gemm(0, 10, 10, 10); + for (int i = 0; i < 10; ++i) { + // printf("Testing: LayoutRight, Blksize %d\n", i); + Test::Gemm::impl_test_blas_gemm(1024, i, i, i); + } + for (int i = 0; i < 10; ++i) { + // printf("Testing: LayoutLeft, Blksize %d\n", i); + int dimM = i; + int dimN = 2 * i; + int dimK = 3 * i; + Test::Gemm::impl_test_blas_gemm(1024, dimM, + dimN, dimK); + } + } +#endif + + return 0; +} + +#define TEST_SERIAL_CASE2(NAME, VALUE, SCALAR) \ + TEST_GEMM_CASE(serial, NAME, test_blas_gemm, VALUE, SCALAR) +#define TEST_SERIAL_CASE(NAME, VALUE) TEST_SERIAL_CASE2(NAME, VALUE, VALUE) + +#if defined(KOKKOS_BHALF_T_IS_FLOAT) +TEST_SERIAL_CASE(bhalf_bhalf, ::Test::bhalfScalarType) +#endif + +#if defined(KOKKOS_HALF_T_IS_FLOAT) +TEST_SERIAL_CASE(half_half, ::Test::halfScalarType) +#endif + +#if defined(KOKKOSKERNELS_INST_FLOAT) +TEST_SERIAL_CASE(float_float, float) +#endif + +#if defined(KOKKOSKERNELS_INST_DOUBLE) +TEST_SERIAL_CASE(double_double, double) +#endif + +#if defined(KOKKOSKERNELS_INST_COMPLEX_DOUBLE) +TEST_SERIAL_CASE(dcomplex_dcomplex, Kokkos::complex) +// TEST_F( TestCategory, blas_scalar_serial_gemm_ct_nt_dcomplex_dcomplex ) { +// typedef ::Test::Gemm::ParamTag +// param_tag_type; typedef Algo::Gemm::Blocked algo_tag_type; +// test_blas_gemm,Kokkos::complex,param_tag_type,algo_tag_type>(); +// } +// TEST_F( TestCategory, blas_scalar_serial_gemm_nt_ct_dcomplex_dcomplex ) { +// typedef ::Test::Gemm::ParamTag +// param_tag_type; typedef Algo::Gemm::Blocked algo_tag_type; +// test_blas_gemm,Kokkos::complex,param_tag_type,algo_tag_type>(); +// } +TEST_SERIAL_CASE2(dcomplex_double, Kokkos::complex, double) +// TEST_F( TestCategory, blas_scalar_serial_gemm_ct_nt_dcomplex_double ) { +// typedef ::Test::Gemm::ParamTag +// param_tag_type; typedef Algo::Gemm::Blocked algo_tag_type; +// test_blas_gemm,double,param_tag_type,algo_tag_type>(); +// } +// TEST_F( TestCategory, blas_scalar_serial_gemm_nt_ct_dcomplex_double ) { +// typedef ::Test::Gemm::ParamTag +// param_tag_type; typedef Algo::Gemm::Blocked algo_tag_type; +// test_blas_gemm,double,param_tag_type,algo_tag_type>(); +// } +#endif + +#if defined(KOKKOSKERNELS_INST_COMPLEX_FLOAT) +TEST_SERIAL_CASE(fcomplex_fcomplex, Kokkos::complex) +TEST_SERIAL_CASE2(fcomplex_float, Kokkos::complex, float) +#endif + +} // namespace Gemm +} // namespace Test diff --git a/blas/unit_test/Test_Blas3_team_gemm.hpp b/blas/unit_test/Test_Blas3_team_gemm.hpp new file mode 100644 index 0000000000..a35ac0ffcb --- /dev/null +++ b/blas/unit_test/Test_Blas3_team_gemm.hpp @@ -0,0 +1,223 @@ +/// \author Kyungjoo Kim (kyukim@sandia.gov) + +#include "Test_Blas3_gemm_util.hpp" + +namespace Test { +namespace Gemm { + +template +struct Functor_TestBlasTeamGemm { + ViewType _a, _b, _c; + + ScalarType _alpha, _beta; + + KOKKOS_INLINE_FUNCTION + Functor_TestBlasTeamGemm(const ScalarType alpha, const ViewType &a, + const ViewType &b, const ScalarType beta, + const ViewType &c) + : _a(a), _b(b), _c(c), _alpha(alpha), _beta(beta) {} + + template + KOKKOS_INLINE_FUNCTION void operator()(const ParamTagType &, + const MemberType &member) const { + const int k = member.league_rank(); + + auto aa = Kokkos::subview(_a, k, Kokkos::ALL(), Kokkos::ALL()); + auto bb = Kokkos::subview(_b, k, Kokkos::ALL(), Kokkos::ALL()); + auto cc = Kokkos::subview(_c, k, Kokkos::ALL(), Kokkos::ALL()); + + KokkosBlas::TeamGemm::invoke(member, _alpha, aa, bb, _beta, + cc); + } + + inline void run() { + typedef typename ViewType::value_type value_type; + std::string name_region("KokkosBlas::Test::TeamGemm"); + const std::string name_value_type = Test::value_type_name(); + std::string name = name_region + name_value_type; + Kokkos::Profiling::pushRegion(name.c_str()); + const int league_size = _c.extent(0); + Kokkos::TeamPolicy policy(league_size, + Kokkos::AUTO); + Kokkos::parallel_for(name.c_str(), policy, *this); + Kokkos::Profiling::popRegion(); + } +}; + +template +void impl_test_blas_teamgemm(const int N, const int dimM, const int dimN, + const int dimK) { + using transA = typename ParamTagType::transA; + using transB = typename ParamTagType::transB; + using execution_space = typename DeviceType::execution_space; + using value_type = typename ViewType::value_type; + using ats = Kokkos::Details::ArithTraits; + + const auto transposed_A = !std::is_same::value; + const auto transposed_B = !std::is_same::value; + + const int matAdim1 = transposed_A ? dimM : dimK; + const int matAdim2 = transposed_A ? dimK : dimM; + const int matBdim1 = transposed_B ? dimK : dimN; + const int matBdim2 = transposed_B ? dimN : dimK; + const int matCdim1 = dimM; + const int matCdim2 = dimN; + + /// randomized input testing views + ScalarType alpha = ScalarType(1.5), beta = ScalarType(3.0); + + ViewType a_expected("a_expected", N, matAdim1, matAdim2), + a_actual("a_actual", N, matAdim1, matAdim2), + b_expected("b_expected", N, matBdim1, matBdim2), + b_actual("b_actual", N, matBdim1, matBdim2), + c_expected("c_expected", N, matCdim1, matCdim2), + c_actual("c_actual", N, matCdim1, matCdim2); + + Kokkos::Random_XorShift64_Pool random( + 13718); + + Kokkos::fill_random(a_expected, random, value_type(1.0)); + Kokkos::fill_random(b_expected, random, value_type(1.0)); + Kokkos::fill_random(c_expected, random, value_type(1.0)); + + Kokkos::fence(); + + Kokkos::deep_copy(a_actual, a_expected); + Kokkos::deep_copy(b_actual, b_expected); + Kokkos::deep_copy(c_actual, c_expected); + + Functor_BatchedVanillaGEMM + vgemm; + vgemm.A_t = transposed_A; + vgemm.B_t = transposed_B; + vgemm.A_c = std::is_same::value; + vgemm.B_c = std::is_same::value; + vgemm.A = a_expected; + vgemm.B = b_expected; + vgemm.C = c_expected; + vgemm.alpha = alpha; + vgemm.beta = beta; + vgemm.run(); // Compute c_expected + + Functor_TestBlasTeamGemm(alpha, a_actual, b_actual, beta, + c_actual) + .run(); + + Kokkos::fence(); + + typename ViewType::HostMirror c_expected_host = + Kokkos::create_mirror_view(c_expected); + typename ViewType::HostMirror c_actual_host = + Kokkos::create_mirror_view(c_actual); + + // Copy to host for comparision + Kokkos::deep_copy(c_expected_host, c_expected); + Kokkos::deep_copy(c_actual_host, c_actual); + + using mag_type = typename ats::mag_type; + mag_type sum(1), diff(0); + mag_type eps = ats::epsilon(); + + eps *= std::is_same::value || + std::is_same::value + ? 4 + : 1e3; + + for (int k = 0; k < N; ++k) + for (int i = 0; i < matCdim1; ++i) + for (int j = 0; j < matCdim2; ++j) { + sum += ats::abs(c_expected_host(k, i, j)); + diff += ats::abs(c_expected_host(k, i, j) - c_actual_host(k, i, j)); + } + EXPECT_NEAR_KK(diff / sum, 0, eps); +} + +template +int test_blas_teamgemm() { +#if defined(KOKKOSKERNELS_INST_LAYOUTLEFT) + { + typedef Kokkos::View + ViewType; + impl_test_blas_teamgemm(0, 10, 10, 10); + for (int i = 0; i < 10; ++i) { + // printf("Testing: LayoutLeft, Blksize %d\n", i); + impl_test_blas_teamgemm(1024, i, i, i); + } + for (int i = 0; i < 10; ++i) { + // printf("Testing: LayoutLeft, Blksize %d\n", i); + int dimM = i; + int dimN = 2 * i; + int dimK = 3 * i; + impl_test_blas_teamgemm(1024, dimM, dimN, dimK); + } + } +#endif +#if defined(KOKKOSKERNELS_INST_LAYOUTRIGHT) + { + typedef Kokkos::View + ViewType; + impl_test_blas_teamgemm(0, 10, 10, 10); + for (int i = 0; i < 10; ++i) { + // printf("Testing: LayoutRight, Blksize %d\n", i); + impl_test_blas_teamgemm(1024, i, i, i); + } + for (int i = 0; i < 10; ++i) { + // printf("Testing: LayoutLeft, Blksize %d\n", i); + int dimM = i; + int dimN = 2 * i; + int dimK = 3 * i; + impl_test_blas_teamgemm(1024, dimM, dimN, dimK); + } + } +#endif + + return 0; +} + +#define TEST_TEAM_CASE2(NAME, VALUE, SCALAR) \ + TEST_GEMM_CASE(team, NAME, test_blas_teamgemm, VALUE, SCALAR) +#define TEST_TEAM_CASE(NAME, VALUE) TEST_TEAM_CASE2(NAME, VALUE, VALUE) + +#if defined(KOKKOS_BHALF_T_IS_FLOAT) +TEST_TEAM_CASE(bhalf_bhalf, ::Test::bhalfScalarType) +#endif + +#if defined(KOKKOS_HALF_T_IS_FLOAT) +TEST_TEAM_CASE(half_half, ::Test::halfScalarType) +#endif + +#if defined(KOKKOSKERNELS_INST_FLOAT) +TEST_TEAM_CASE(float_float, float) +#endif + +#if defined(KOKKOSKERNELS_INST_DOUBLE) +TEST_TEAM_CASE(double_double, double) +#endif + +#if defined(KOKKOSKERNELS_INST_COMPLEX_DOUBLE) +TEST_TEAM_CASE(dcomplex_dcomplex, Kokkos::complex) +TEST_TEAM_CASE2(dcomplex_double, Kokkos::complex, double) +#endif + +#if defined(KOKKOSKERNELS_INST_COMPLEX_FLOAT) +TEST_TEAM_CASE(fcomplex_fcomplex, Kokkos::complex) +TEST_TEAM_CASE2(fcomplex_float, Kokkos::complex, float) +#endif + +#undef TEST_TEAM_CASE +#undef TEST_TEAM_CASE2 + +} // namespace Gemm +} // namespace Test diff --git a/blas/unit_test/Test_Blas3_teamvector_gemm.hpp b/blas/unit_test/Test_Blas3_teamvector_gemm.hpp new file mode 100644 index 0000000000..481a2158cc --- /dev/null +++ b/blas/unit_test/Test_Blas3_teamvector_gemm.hpp @@ -0,0 +1,241 @@ +/// \author Kyungjoo Kim (kyukim@sandia.gov) + +#include "Test_Blas3_gemm_util.hpp" + +namespace Test { +namespace Gemm { + +template +struct Functor_TestBlasTeamVector { + ViewType _a, _b, _c; + + ScalarType _alpha, _beta; + + KOKKOS_INLINE_FUNCTION + Functor_TestBlasTeamVector(const ScalarType alpha, const ViewType &a, + const ViewType &b, const ScalarType beta, + const ViewType &c) + : _a(a), _b(b), _c(c), _alpha(alpha), _beta(beta) {} + + template + KOKKOS_INLINE_FUNCTION void operator()(const ParamTagType &, + const MemberType &member) const { + const int k = member.league_rank(); + + auto aa = Kokkos::subview(_a, k, Kokkos::ALL(), Kokkos::ALL()); + auto bb = Kokkos::subview(_b, k, Kokkos::ALL(), Kokkos::ALL()); + auto cc = Kokkos::subview(_c, k, Kokkos::ALL(), Kokkos::ALL()); + + KokkosBlas::TeamVectorGemm::invoke(member, _alpha, aa, bb, + _beta, cc); + } + + inline void run() { + typedef typename ViewType::value_type value_type; + std::string name_region("KokkosBlas::Test::TeamVector"); + const std::string name_value_type = Test::value_type_name(); + std::string name = name_region + name_value_type; + Kokkos::Profiling::pushRegion(name.c_str()); + const int league_size = _c.extent(0); + Kokkos::TeamPolicy policy(league_size, + Kokkos::AUTO); + Kokkos::parallel_for(name.c_str(), policy, *this); + Kokkos::Profiling::popRegion(); + } +}; + +template +void impl_test_blas_teamvectorgemm(const int N, const int dimM, const int dimN, + const int dimK) { + using transA = typename ParamTagType::transA; + using transB = typename ParamTagType::transB; + using execution_space = typename DeviceType::execution_space; + using value_type = typename ViewType::value_type; + using ats = Kokkos::Details::ArithTraits; + + const auto transposed_A = !std::is_same::value; + const auto transposed_B = !std::is_same::value; + + const int matAdim1 = transposed_A ? dimM : dimK; + const int matAdim2 = transposed_A ? dimK : dimM; + const int matBdim1 = transposed_B ? dimK : dimN; + const int matBdim2 = transposed_B ? dimN : dimK; + const int matCdim1 = dimM; + const int matCdim2 = dimN; + + /// randomized input testing views + ScalarType alpha = ScalarType(1.5), beta = ScalarType(3.0); + + ViewType a_expected("a_expected", N, matAdim1, matAdim2), + a_actual("a_actual", N, matAdim1, matAdim2), + b_expected("b_expected", N, matBdim1, matBdim2), + b_actual("b_actual", N, matBdim1, matBdim2), + c_expected("c_expected", N, matCdim1, matCdim2), + c_actual("c_actual", N, matCdim1, matCdim2); + + Kokkos::Random_XorShift64_Pool random( + 13718); + + Kokkos::fill_random(a_expected, random, value_type(1.0)); + Kokkos::fill_random(b_expected, random, value_type(1.0)); + Kokkos::fill_random(c_expected, random, value_type(1.0)); + + Kokkos::fence(); + + Kokkos::deep_copy(a_actual, a_expected); + Kokkos::deep_copy(b_actual, b_expected); + Kokkos::deep_copy(c_actual, c_expected); + + // Functor_TestBlasTeamVector(alpha, a_expected, b_expected, + // beta, c_expected).run(); + Functor_BatchedVanillaGEMM + vgemm; + vgemm.A_t = transposed_A; + vgemm.B_t = transposed_B; + vgemm.A_c = std::is_same::value; + vgemm.B_c = std::is_same::value; + vgemm.A = a_expected; + vgemm.B = b_expected; + vgemm.C = c_expected; + vgemm.alpha = alpha; + vgemm.beta = beta; + vgemm.run(); // Compute c_expected + + Functor_TestBlasTeamVector(alpha, a_actual, b_actual, beta, + c_actual) + .run(); + + Kokkos::fence(); + + typename ViewType::HostMirror c_expected_host = + Kokkos::create_mirror_view(c_expected); + typename ViewType::HostMirror c_actual_host = + Kokkos::create_mirror_view(c_actual); + + // Copy to host for comparison + Kokkos::deep_copy(c_expected_host, c_expected); + Kokkos::deep_copy(c_actual_host, c_actual); + + using mag_type = typename ats::mag_type; + mag_type sum(1), diff(0); + + mag_type eps = ats::epsilon(); + + eps *= std::is_same::value || + std::is_same::value + ? 4 + : 1e3; + + for (int k = 0; k < N; ++k) + for (int i = 0; i < matCdim1; ++i) + for (int j = 0; j < matCdim2; ++j) { + sum += ats::abs(c_expected_host(k, i, j)); + diff += ats::abs(c_expected_host(k, i, j) - c_actual_host(k, i, j)); + } + EXPECT_NEAR_KK(diff / sum, 0, eps); +} + +template +typename std::enable_if< + !std::is_same::value, int>::type +test_blas_teamvectorgemm() { + // skip algorithms not supported by TeamVectorGemm + return 0; +} + +template +typename std::enable_if::value, + int>::type +test_blas_teamvectorgemm() { +#if defined(KOKKOSKERNELS_INST_LAYOUTLEFT) + { + typedef Kokkos::View + ViewType; + impl_test_blas_teamvectorgemm(0, 10, 10, 10); + for (int i = 0; i < 10; ++i) { + // printf("Testing: LayoutLeft, Blksize %d\n", i); + impl_test_blas_teamvectorgemm(1024, i, i, i); + } + for (int i = 0; i < 10; ++i) { + // printf("Testing: LayoutLeft, Blksize %d\n", i); + int dimM = i; + int dimN = 2 * i; + int dimK = 3 * i; + impl_test_blas_teamvectorgemm(1024, dimM, dimN, + dimK); + } + } +#endif +#if defined(KOKKOSKERNELS_INST_LAYOUTRIGHT) + { + typedef Kokkos::View + ViewType; + impl_test_blas_teamvectorgemm(0, 10, 10, 10); + for (int i = 0; i < 10; ++i) { + // printf("Testing: LayoutRight, Blksize %d\n", i); + impl_test_blas_teamvectorgemm(1024, i, i, i); + } + for (int i = 0; i < 10; ++i) { + // printf("Testing: LayoutLeft, Blksize %d\n", i); + int dimM = i; + int dimN = 2 * i; + int dimK = 3 * i; + impl_test_blas_teamvectorgemm(1024, dimM, dimN, + dimK); + } + } +#endif + + return 0; +} + +#define TEST_TEAMVECTOR_CASE2(NAME, VALUE, SCALAR) \ + TEST_GEMM_CASE(team_vector, NAME, test_blas_teamvectorgemm, VALUE, SCALAR) +#define TEST_TEAMVECTOR_CASE(NAME, VALUE) \ + TEST_TEAMVECTOR_CASE2(NAME, VALUE, VALUE) + +#if defined(KOKKOS_BHALF_T_IS_FLOAT) +TEST_TEAMVECTOR_CASE(bhalf_bhalf, ::Test::bhalfScalarType) +#endif + +#if defined(KOKKOS_HALF_T_IS_FLOAT) +TEST_TEAMVECTOR_CASE(half_half, ::Test::halfScalarType) +#endif + +#if defined(KOKKOSKERNELS_INST_FLOAT) +TEST_TEAMVECTOR_CASE(float_float, float) +#endif + +#if defined(KOKKOSKERNELS_INST_DOUBLE) +TEST_TEAMVECTOR_CASE(double_double, double) +#endif + +#if defined(KOKKOSKERNELS_INST_COMPLEX_DOUBLE) +TEST_TEAMVECTOR_CASE(dcomplex_dcomplex, Kokkos::complex) +TEST_TEAMVECTOR_CASE2(dcomplex_double, Kokkos::complex, double) +#endif + +#if defined(KOKKOSKERNELS_INST_COMPLEX_FLOAT) +TEST_TEAMVECTOR_CASE(fcomplex_fcomplex, Kokkos::complex) +TEST_TEAMVECTOR_CASE2(fcomplex_float, Kokkos::complex, float) +#endif + +#undef TEST_TEAMVECTOR_CASE +#undef TEST_TEAMVECTOR_CASE2 + +} // namespace Gemm +} // namespace Test diff --git a/blas/unit_test/Test_Blas3_threadvector_gemm.hpp b/blas/unit_test/Test_Blas3_threadvector_gemm.hpp new file mode 100644 index 0000000000..f5efec6ba9 --- /dev/null +++ b/blas/unit_test/Test_Blas3_threadvector_gemm.hpp @@ -0,0 +1,244 @@ +#include "Test_Blas3_gemm_util.hpp" + +namespace Test { +namespace Gemm { + +template +struct Functor_TestBatchedThreadVector { + ViewType _a, _b, _c; + + ScalarType _alpha, _beta; + + KOKKOS_INLINE_FUNCTION + Functor_TestBatchedThreadVector(const ScalarType alpha, const ViewType &a, + const ViewType &b, const ScalarType beta, + const ViewType &c) + : _a(a), _b(b), _c(c), _alpha(alpha), _beta(beta) {} + + template + KOKKOS_INLINE_FUNCTION void operator()(const ParamTagType &, + const MemberType &member) const { + const int k = member.league_rank(); + + auto aa = Kokkos::subview(_a, k, Kokkos::ALL(), Kokkos::ALL()); + auto bb = Kokkos::subview(_b, k, Kokkos::ALL(), Kokkos::ALL()); + auto cc = Kokkos::subview(_c, k, Kokkos::ALL(), Kokkos::ALL()); + + KokkosBlas::ThreadVectorGemm::invoke(member, _alpha, aa, bb, + _beta, cc); + } + + inline void run() { + typedef typename ViewType::value_type value_type; + std::string name_region("KokkosBatched::Test::ThreadVector"); + const std::string name_value_type = Test::value_type_name(); + std::string name = name_region + name_value_type; + Kokkos::Profiling::pushRegion(name.c_str()); + const int league_size = _c.extent(0); + Kokkos::TeamPolicy policy(league_size, + Kokkos::AUTO); + Kokkos::parallel_for(name.c_str(), policy, *this); + Kokkos::Profiling::popRegion(); + } +}; + +template +void impl_test_batched_threadvectorgemm(const int N, const int dimM, + const int dimN, const int dimK) { + using transA = typename ParamTagType::transA; + using transB = typename ParamTagType::transB; + using execution_space = typename DeviceType::execution_space; + using value_type = typename ViewType::value_type; + using ats = Kokkos::Details::ArithTraits; + + const auto transposed_A = !std::is_same::value; + const auto transposed_B = !std::is_same::value; + + const int matAdim1 = transposed_A ? dimM : dimK; + const int matAdim2 = transposed_A ? dimK : dimM; + const int matBdim1 = transposed_B ? dimK : dimN; + const int matBdim2 = transposed_B ? dimN : dimK; + const int matCdim1 = dimM; + const int matCdim2 = dimN; + + /// randomized input testing views + ScalarType alpha = ScalarType(1.5), beta = ScalarType(3.0); + + ViewType a_expected("a_expected", N, matAdim1, matAdim2), + a_actual("a_actual", N, matAdim1, matAdim2), + b_expected("b_expected", N, matBdim1, matBdim2), + b_actual("b_actual", N, matBdim1, matBdim2), + c_expected("c_expected", N, matCdim1, matCdim2), + c_actual("c_actual", N, matCdim1, matCdim2); + + Kokkos::Random_XorShift64_Pool random( + 13718); + + Kokkos::fill_random(a_expected, random, value_type(1.0)); + Kokkos::fill_random(b_expected, random, value_type(1.0)); + Kokkos::fill_random(c_expected, random, value_type(1.0)); + + Kokkos::fence(); + + Kokkos::deep_copy(a_actual, a_expected); + Kokkos::deep_copy(b_actual, b_expected); + Kokkos::deep_copy(c_actual, c_expected); + + // Functor_TestBatchedThreadVector(alpha, a_expected, b_expected, + // beta, c_expected).run(); + Functor_BatchedVanillaGEMM + vgemm; + vgemm.A_t = transposed_A; + vgemm.B_t = transposed_B; + vgemm.A_c = false; + vgemm.B_c = false; + vgemm.A = a_expected; + vgemm.B = b_expected; + vgemm.C = c_expected; + vgemm.alpha = alpha; + vgemm.beta = beta; + vgemm.run(); // Compute c_expected + + Functor_TestBatchedThreadVector( + alpha, a_actual, b_actual, beta, c_actual) + .run(); + + Kokkos::fence(); + + typename ViewType::HostMirror c_expected_host = + Kokkos::create_mirror_view(c_expected); + typename ViewType::HostMirror c_actual_host = + Kokkos::create_mirror_view(c_actual); + + // Copy to host for comparison + Kokkos::deep_copy(c_expected_host, c_expected); + Kokkos::deep_copy(c_actual_host, c_actual); + + using mag_type = typename ats::mag_type; + mag_type sum(1), diff(0); + + mag_type eps = ats::epsilon(); + + eps *= std::is_same::value || + std::is_same::value + ? 4 + : 1e3; + + for (int k = 0; k < N; ++k) + for (int i = 0; i < matCdim1; ++i) + for (int j = 0; j < matCdim2; ++j) { + sum += ats::abs(c_expected_host(k, i, j)); + diff += ats::abs(c_expected_host(k, i, j) - c_actual_host(k, i, j)); + } + EXPECT_NEAR_KK(diff / sum, 0, eps); +} + +template +typename std::enable_if< + !std::is_same::value, int>::type +test_batched_threadvectorgemm() { + // skip algorithms not supported by ThreadVectorGemm + return 0; +} + +template +typename std::enable_if::value, + int>::type +test_batched_threadvectorgemm() { +#if defined(KOKKOSKERNELS_INST_LAYOUTLEFT) + { + typedef Kokkos::View + ViewType; + impl_test_batched_threadvectorgemm(0, 10, 10, + 10); + for (int i = 0; i < 10; ++i) { + // printf("Testing: LayoutLeft, Blksize %d\n", i); + impl_test_batched_threadvectorgemm(1024, i, i, + i); + } + for (int i = 0; i < 10; ++i) { + // printf("Testing: LayoutLeft, Blksize %d\n", i); + int dimM = i; + int dimN = 2 * i; + int dimK = 3 * i; + impl_test_batched_threadvectorgemm(1024, dimM, + dimN, dimK); + } + } +#endif +#if defined(KOKKOSKERNELS_INST_LAYOUTRIGHT) + { + typedef Kokkos::View + ViewType; + impl_test_batched_threadvectorgemm(0, 10, 10, + 10); + for (int i = 0; i < 10; ++i) { + // printf("Testing: LayoutRight, Blksize %d\n", i); + impl_test_batched_threadvectorgemm(1024, i, i, + i); + } + for (int i = 0; i < 10; ++i) { + // printf("Testing: LayoutLeft, Blksize %d\n", i); + int dimM = i; + int dimN = 2 * i; + int dimK = 3 * i; + impl_test_batched_threadvectorgemm(1024, dimM, + dimN, dimK); + } + } +#endif + + return 0; +} + +#define TEST_THREADVECTOR_CASE2(NAME, VALUE, SCALAR) \ + TEST_GEMM_CASE(thread_vector, NAME, test_batched_threadvectorgemm, VALUE, \ + SCALAR) +#define TEST_THREADVECTOR_CASE(NAME, VALUE) \ + TEST_THREADVECTOR_CASE2(NAME, VALUE, VALUE) + +#if defined(KOKKOS_BHALF_T_IS_FLOAT) +TEST_THREADVECTOR_CASE(bhalf_bhalf, ::Test::bhalfScalarType) +#endif + +#if defined(KOKKOS_HALF_T_IS_FLOAT) +TEST_THREADVECTOR_CASE(half_half, ::Test::halfScalarType) +#endif + +#if defined(KOKKOSKERNELS_INST_FLOAT) +TEST_THREADVECTOR_CASE(float_float, float) +#endif + +#if defined(KOKKOSKERNELS_INST_DOUBLE) +TEST_THREADVECTOR_CASE(double_double, double) +#endif + +#if defined(KOKKOSKERNELS_INST_COMPLEX_DOUBLE) +TEST_THREADVECTOR_CASE(dcomplex_dcomplex, Kokkos::complex) +TEST_THREADVECTOR_CASE2(dcomplex_double, Kokkos::complex, double) +#endif + +#if defined(KOKKOSKERNELS_INST_COMPLEX_FLOAT) +TEST_THREADVECTOR_CASE(fcomplex_fcomplex, Kokkos::complex) +TEST_THREADVECTOR_CASE2(fcomplex_float, Kokkos::complex, float) +#endif + +#undef TEST_THREADVECTOR_CASE +#undef TEST_THREADVECTOR_CASE2 + +} // namespace Gemm +} // namespace Test diff --git a/common/CMakeLists.txt b/common/CMakeLists.txt index 72972b5cd7..c566747d36 100644 --- a/common/CMakeLists.txt +++ b/common/CMakeLists.txt @@ -1,7 +1,7 @@ # Adding source directory to the build LIST(APPEND KK_INCLUDE_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/common/src) -LIST(APPEND KK_INCLUDE_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/common/src/impl) -LIST(APPEND KK_INCLUDE_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/common/src/tpls) +LIST(APPEND KK_INCLUDE_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/common/impl) +LIST(APPEND KK_INCLUDE_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/common/tpls) LIST(APPEND KK_INCLUDE_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/common/unit_test) # Adding unit-tests diff --git a/common/impl/KokkosKernels_MKLUtils.hpp b/common/impl/KokkosKernels_MKLUtils.hpp new file mode 100644 index 0000000000..8c29ce456a --- /dev/null +++ b/common/impl/KokkosKernels_MKLUtils.hpp @@ -0,0 +1,109 @@ +/* +//@HEADER +// ************************************************************************ +// +// Kokkos v. 3.0 +// Copyright (2020) National Technology & Engineering +// Solutions of Sandia, LLC (NTESS). +// +// Under the terms of Contract DE-NA0003525 with NTESS, +// the U.S. Government retains certain rights in this software. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// 1. Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of the Corporation nor the names of the +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY NTESS "AS IS" AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NTESS OR THE +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// Questions? Contact Siva Rajamanickam (srajama@sandia.gov) +// +// ************************************************************************ +//@HEADER +*/ + +#ifndef _KOKKOSKERNELS_UTILS_MKL_HPP +#define _KOKKOSKERNELS_UTILS_MKL_HPP + +#include "KokkosKernels_config.h" + +#ifdef KOKKOSKERNELS_ENABLE_TPL_MKL + +#include + +#if defined(KOKKOSKERNELS_ENABLE_TPL_MKL) +#include "mkl_version.h" +#if __INTEL_MKL__ >= 2018 +#define __KOKKOSBLAS_ENABLE_INTEL_MKL_COMPACT__ 1 +#include "mkl_compact.h" +#endif +#endif + +namespace KokkosKernels { +namespace Impl { + +inline void mkl_internal_safe_call(sparse_status_t mkl_status, const char *name, + const char *file = nullptr, + const int line = 0) { + if (SPARSE_STATUS_SUCCESS != mkl_status) { + std::ostringstream oss; + oss << "MKL call \"" << name << "\" at " << file << ":" << line + << " encountered error: "; + switch (mkl_status) { + case SPARSE_STATUS_NOT_INITIALIZED: + oss << "SPARSE_STATUS_NOT_INITIALIZED (empty handle or matrix arrays)"; + break; + case SPARSE_STATUS_ALLOC_FAILED: + oss << "SPARSE_STATUS_ALLOC_FAILED (internal error: memory allocation " + "failed)"; + break; + case SPARSE_STATUS_INVALID_VALUE: + oss << "SPARSE_STATUS_INVALID_VALUE (invalid input value)"; + break; + case SPARSE_STATUS_EXECUTION_FAILED: + oss << "SPARSE_STATUS_EXECUTION_FAILED (e.g. 0-diagonal element for " + "triangular solver)"; + break; + case SPARSE_STATUS_INTERNAL_ERROR: + oss << "SPARSE_STATUS_INTERNAL_ERROR"; + break; + case SPARSE_STATUS_NOT_SUPPORTED: + oss << "SPARSE_STATUS_NOT_SUPPORTED (e.g. operation for double " + "precision doesn't support other types)"; + break; + default: oss << "unknown (code " << (int)mkl_status << ")"; break; + } + oss << '\n'; + Kokkos::abort(oss.str().c_str()); + } +} + +#define KOKKOSKERNELS_MKL_SAFE_CALL(call) \ + KokkosKernels::Impl::mkl_internal_safe_call(call, #call, __FILE__, __LINE__) + +} // namespace Impl +} // namespace KokkosKernels + +#endif // KOKKOSKERNELS_ENABLE_TPL_MKL + +#endif // _KOKKOSKERNELS_UTILS_MKL_HPP \ No newline at end of file diff --git a/common/src/KokkosKernels_BlockUtils.hpp b/common/src/KokkosKernels_BlockUtils.hpp index 0c001ce115..00402a1ab1 100644 --- a/common/src/KokkosKernels_BlockUtils.hpp +++ b/common/src/KokkosKernels_BlockUtils.hpp @@ -46,7 +46,7 @@ // #include // #include -#include "KokkosBatched_Gemm_Serial_Internal.hpp" +#include "KokkosBlas3_gemm.hpp" namespace KokkosSparse { namespace Impl { @@ -85,8 +85,8 @@ KOKKOS_INLINE_FUNCTION void kk_block_add(const size_type block_dim, // Note: block is assumed to be row-major, dense matrix (no extra padding) // Note: set clear=true to set C = 0 before increment template > + typename DGEMM = KokkosBlas::Impl::SerialGemmInternal< + KokkosBlas::Algo::Gemm::Unblocked>> KOKKOS_INLINE_FUNCTION void kk_block_dgemm(const size_type block_dim, value_type *dst, const value_type *valA, diff --git a/docs/developer/apidocs/batched_dense.rst b/docs/developer/apidocs/batched_dense.rst index 1d65842061..4d98d3cdc1 100644 --- a/docs/developer/apidocs/batched_dense.rst +++ b/docs/developer/apidocs/batched_dense.rst @@ -247,11 +247,11 @@ trsv gemm ---- -.. doxygenstruct:: KokkosBatched::SerialGemm +.. doxygenstruct:: KokkosBlas::SerialGemm :members: -.. doxygenstruct:: KokkosBatched::TeamGemm +.. doxygenstruct:: KokkosBlas::TeamGemm :members: -.. doxygenstruct:: KokkosBatched::TeamVectorGemm +.. doxygenstruct:: KokkosBlas::TeamVectorGemm :members: -.. doxygenstruct:: KokkosBatched::Gemm +.. doxygenstruct:: KokkosBlas::Gemm :members: \ No newline at end of file diff --git a/perf_test/blas/blas3/KokkosBlas3_gemm_perf_test.hpp b/perf_test/blas/blas3/KokkosBlas3_gemm_perf_test.hpp index d1855573e4..f2b488e96b 100644 --- a/perf_test/blas/blas3/KokkosBlas3_gemm_perf_test.hpp +++ b/perf_test/blas/blas3/KokkosBlas3_gemm_perf_test.hpp @@ -54,9 +54,6 @@ #include #include "KokkosBatched_Gemm_Decl.hpp" -#include "KokkosBatched_Gemm_Serial_Impl.hpp" -//#include "KokkosBatched_Gemm_Team_Impl.hpp" -//#include "KokkosBatched_Gemm_TeamVector_Impl.hpp" #include "KokkosBatched_Util.hpp" #include "gtest/gtest.h" // EXPECT_NEAR #include "KokkosKernels_TestUtils.hpp" @@ -418,7 +415,7 @@ void __do_gemm_serial_batched_template(options_t options, C = Kokkos::subview(_gemm_args.C, Kokkos::ALL(), Kokkos::ALL(), j); } - KokkosBatched::SerialGemm::invoke( + KokkosBlas::SerialGemm::invoke( _gemm_args.alpha, A, B, _gemm_args.beta, C); } } @@ -615,7 +612,7 @@ struct parallel_batched_gemm_range_policy { auto svB = Kokkos::subview(gemm_args_.B, i, Kokkos::ALL(), Kokkos::ALL()); auto svC = Kokkos::subview(gemm_args_.C, i, Kokkos::ALL(), Kokkos::ALL()); - KokkosBatched::SerialGemm::invoke( + KokkosBlas::SerialGemm::invoke( gemm_args_.alpha, svA, svB, gemm_args_.beta, svC); } @@ -625,7 +622,7 @@ struct parallel_batched_gemm_range_policy { auto svB = Kokkos::subview(gemm_args_.B, Kokkos::ALL(), Kokkos::ALL(), i); auto svC = Kokkos::subview(gemm_args_.C, Kokkos::ALL(), Kokkos::ALL(), i); - KokkosBatched::SerialGemm::invoke( + KokkosBlas::SerialGemm::invoke( gemm_args_.alpha, svA, svB, gemm_args_.beta, svC); } @@ -638,7 +635,7 @@ struct parallel_batched_gemm_range_policy { auto svC = Kokkos::subview(gemm_args_.Cv.vec_3d, i, Kokkos::ALL(), Kokkos::ALL()); - KokkosBatched::SerialGemm::invoke( + KokkosBlas::SerialGemm::invoke( gemm_args_.alpha, svA, svB, gemm_args_.beta, svC); } @@ -651,7 +648,7 @@ struct parallel_batched_gemm_range_policy { auto svC = Kokkos::subview(gemm_args_.Cv.vec_3d, Kokkos::ALL(), Kokkos::ALL(), i); - KokkosBatched::SerialGemm::invoke( + KokkosBlas::SerialGemm::invoke( gemm_args_.alpha, svA, svB, gemm_args_.beta, svC); } @@ -700,7 +697,7 @@ struct parallel_batched_gemm { auto svB = Kokkos::subview(gemm_args_.B, i, Kokkos::ALL(), Kokkos::ALL()); auto svC = Kokkos::subview(gemm_args_.C, i, Kokkos::ALL(), Kokkos::ALL()); - KokkosBatched::SerialGemm::invoke( + KokkosBlas::SerialGemm::invoke( gemm_args_.alpha, svA, svB, gemm_args_.beta, svC); } @@ -711,7 +708,7 @@ struct parallel_batched_gemm { auto svB = Kokkos::subview(gemm_args_.B, Kokkos::ALL(), Kokkos::ALL(), i); auto svC = Kokkos::subview(gemm_args_.C, Kokkos::ALL(), Kokkos::ALL(), i); - KokkosBatched::SerialGemm::invoke( + KokkosBlas::SerialGemm::invoke( gemm_args_.alpha, svA, svB, gemm_args_.beta, svC); } @@ -722,9 +719,8 @@ struct parallel_batched_gemm { auto svB = Kokkos::subview(gemm_args_.B, i, Kokkos::ALL(), Kokkos::ALL()); auto svC = Kokkos::subview(gemm_args_.C, i, Kokkos::ALL(), Kokkos::ALL()); - KokkosBatched::TeamGemm::invoke(member, gemm_args_.alpha, svA, - svB, gemm_args_.beta, svC); + KokkosBlas::TeamGemm::invoke( + member, gemm_args_.alpha, svA, svB, gemm_args_.beta, svC); } KOKKOS_INLINE_FUNCTION @@ -734,9 +730,8 @@ struct parallel_batched_gemm { auto svB = Kokkos::subview(gemm_args_.B, Kokkos::ALL(), Kokkos::ALL(), i); auto svC = Kokkos::subview(gemm_args_.C, Kokkos::ALL(), Kokkos::ALL(), i); - KokkosBatched::TeamGemm::invoke(member, gemm_args_.alpha, svA, - svB, gemm_args_.beta, svC); + KokkosBlas::TeamGemm::invoke( + member, gemm_args_.alpha, svA, svB, gemm_args_.beta, svC); } KOKKOS_INLINE_FUNCTION @@ -749,11 +744,8 @@ struct parallel_batched_gemm { auto svC = Kokkos::subview(gemm_args_.C, team_idx, Kokkos::ALL(), Kokkos::ALL()); - KokkosBatched::TeamVectorGemm::invoke(member, - gemm_args_.alpha, svA, - svB, gemm_args_.beta, - svC); + KokkosBlas::TeamVectorGemm::invoke( + member, gemm_args_.alpha, svA, svB, gemm_args_.beta, svC); } KOKKOS_INLINE_FUNCTION @@ -767,11 +759,8 @@ struct parallel_batched_gemm { auto svC = Kokkos::subview(gemm_args_.C, Kokkos::ALL(), Kokkos::ALL(), team_idx); - KokkosBatched::TeamVectorGemm::invoke(member, - gemm_args_.alpha, svA, - svB, gemm_args_.beta, - svC); + KokkosBlas::TeamVectorGemm::invoke( + member, gemm_args_.alpha, svA, svB, gemm_args_.beta, svC); } KOKKOS_INLINE_FUNCTION @@ -787,10 +776,9 @@ struct parallel_batched_gemm { auto svC = Kokkos::subview(gemm_args_.Cv.ivec_4d, i, Kokkos::ALL(), Kokkos::ALL(), vector_lane); - KokkosBatched::Gemm::invoke(member, gemm_args_.alpha, - svA, svB, gemm_args_.beta, - svC); + KokkosBlas::Gemm::invoke(member, gemm_args_.alpha, svA, + svB, gemm_args_.beta, svC); }); } @@ -808,10 +796,9 @@ struct parallel_batched_gemm { auto svC = Kokkos::subview(gemm_args_.Cv.ivec_4d, vector_lane, Kokkos::ALL(), Kokkos::ALL(), i); - KokkosBatched::Gemm::invoke(member, gemm_args_.alpha, - svA, svB, gemm_args_.beta, - svC); + KokkosBlas::Gemm::invoke(member, gemm_args_.alpha, svA, + svB, gemm_args_.beta, svC); }); } @@ -956,7 +943,7 @@ void __do_gemm_parallel_batched(options_t options, gemm_args_t gemm_args) { char b = gemm_args.transB; using N = KokkosBatched::Trans::NoTranspose; using T = KokkosBatched::Trans::Transpose; - // using C = KokkosBatched::Trans::ConjTranspose; + using C = KokkosBatched::Trans::ConjTranspose; STATUS; @@ -968,9 +955,10 @@ void __do_gemm_parallel_batched(options_t options, gemm_args_t gemm_args) { __do_gemm_parallel_batched_template(options, gemm_args); - //} else if (a == 'N' && b == 'C') { - // __do_gemm_parallel_batched_template(options, gemm_args); + } else if (a == 'N' && b == 'C') { + __do_gemm_parallel_batched_template(options, + gemm_args); } else if (a == 'T' && b == 'N') { __do_gemm_parallel_batched_template(options, @@ -979,18 +967,22 @@ void __do_gemm_parallel_batched(options_t options, gemm_args_t gemm_args) { __do_gemm_parallel_batched_template(options, gemm_args); - //} else if (a == 'T' && b == 'C') { - // __do_gemm_parallel_batched_template(options, gemm_args); - //} else if (a == 'C' && b == 'N') { - // __do_gemm_parallel_batched_template(options, gemm_args); - //} else if (a == 'C' && b == 'T') { - // __do_gemm_parallel_batched_template(options, gemm_args); - //} else if (a == 'C' && b == 'C') { - // __do_gemm_parallel_batched_template(options, gemm_args); + } else if (a == 'T' && b == 'C') { + __do_gemm_parallel_batched_template(options, + gemm_args); + } else if (a == 'C' && b == 'N') { + __do_gemm_parallel_batched_template(options, + gemm_args); + } else if (a == 'C' && b == 'T') { + __do_gemm_parallel_batched_template(options, + gemm_args); + } else if (a == 'C' && b == 'C') { + __do_gemm_parallel_batched_template(options, + gemm_args); } else { FATAL_ERROR("Bad gemm_args TransA or TransB value"); } @@ -1013,7 +1005,7 @@ struct parallel_batched_gemm_experiment1 { auto svC = Kokkos::subview(gemm_args_.C, i, Kokkos::ALL(), Kokkos::ALL()); // Uses two serial for-loops internally - KokkosBatched::SerialGemm::invoke( + KokkosBlas::SerialGemm::invoke( gemm_args_.alpha, svA, svB, gemm_args_.beta, svC); } }; @@ -1073,11 +1065,8 @@ struct parallel_batched_gemm_experiment2_3_4 { // Uses TeamThreadRange over C-rows // ThreadVectorRange over C-cols - KokkosBatched::TeamVectorGemm::invoke(member, - gemm_args_.alpha, svA, - svB, gemm_args_.beta, - svC); + KokkosBlas::TeamVectorGemm::invoke( + member, gemm_args_.alpha, svA, svB, gemm_args_.beta, svC); } // Experiment 3 @@ -1104,12 +1093,8 @@ struct parallel_batched_gemm_experiment2_3_4 { auto svC_col = Kokkos::subview(svC, Kokkos::ALL(), lane_idx); // TeamGemm Calls TeamThreadRange over M*N meaning the flat M*N array // is split over all threads of the team - KokkosBatched::TeamGemm::invoke(member, - gemm_args_.alpha, svA, - svB_col, - gemm_args_.beta, - svC_col); + KokkosBlas::TeamGemm::invoke( + member, gemm_args_.alpha, svA, svB_col, gemm_args_.beta, svC_col); }); } @@ -1138,12 +1123,8 @@ struct parallel_batched_gemm_experiment2_3_4 { auto svC_row = Kokkos::subview(svC, lane_idx, Kokkos::ALL()); // TeamGemm Calls TeamThreadRange over M*N meaning the flat M*N array // is split over all threads of the team - KokkosBatched::TeamGemm::invoke(member, - gemm_args_.alpha, - svA_row, svB, - gemm_args_.beta, - svC_row); + KokkosBlas::TeamGemm::invoke( + member, gemm_args_.alpha, svA_row, svB, gemm_args_.beta, svC_row); }); } }; @@ -1316,7 +1297,7 @@ class parallel_batched_gemm_experiment5 { auto svC = Kokkos::subview(C, i, Kokkos::ALL(), Kokkos::ALL()); // Uses two serial for-loops internally - KokkosBatched::SerialGemm::invoke( + KokkosBlas::SerialGemm::invoke( gemm_args.alpha, svA, svB, gemm_args.beta, svC); } }; @@ -1424,10 +1405,8 @@ class parallel_batched_gemm_experiment6 { auto svC = Kokkos::subview(C, i, Kokkos::ALL(), Kokkos::ALL()); // Uses two serial for-loops internally - KokkosBatched::TeamVectorGemm::invoke(member, gemm_args.alpha, - svA, svB, - gemm_args.beta, svC); + KokkosBlas::TeamVectorGemm::invoke( + member, gemm_args.alpha, svA, svB, gemm_args.beta, svC); } }; @@ -1868,9 +1847,10 @@ static inline void __gemm_do_verify(options_t options, gemm_args_t gemm_args, Test::Functor_BatchedVanillaGEMM vgemm; - vgemm.A_t = toupper(gemm_args.transA) == 'T'; - vgemm.B_t = toupper(gemm_args.transB) == 'T'; - vgemm.A_c = vgemm.B_c = false; + vgemm.A_t = toupper(gemm_args.transA) != 'N'; + vgemm.B_t = toupper(gemm_args.transB) != 'N'; + vgemm.A_c = toupper(gemm_args.transA) == 'C'; + vgemm.B_c = toupper(gemm_args.transB) == 'C'; vgemm.batch_size_last_dim = options.blas_args.batch_size_last_dim; vgemm.A = A_expected; vgemm.B = B_expected; diff --git a/perf_test/sparse/KokkosSparse_spadd.cpp b/perf_test/sparse/KokkosSparse_spadd.cpp index 5448843168..efb3e32ffa 100644 --- a/perf_test/sparse/KokkosSparse_spadd.cpp +++ b/perf_test/sparse/KokkosSparse_spadd.cpp @@ -56,7 +56,6 @@ #endif #ifdef KOKKOSKERNELS_ENABLE_TPL_MKL -#include #include #endif diff --git a/sparse/impl/KokkosSparse_bspgemm_impl_seq.hpp b/sparse/impl/KokkosSparse_bspgemm_impl_seq.hpp index 312ba22f8a..984633d950 100644 --- a/sparse/impl/KokkosSparse_bspgemm_impl_seq.hpp +++ b/sparse/impl/KokkosSparse_bspgemm_impl_seq.hpp @@ -44,7 +44,7 @@ #ifndef KOKKOSSPARSE_BSPGEMM_DEBUG_HPP_ #define KOKKOSSPARSE_BSPGEMM_DEBUG_HPP_ #include "KokkosKernels_helpers.hpp" -#include "KokkosBatched_Gemm_Serial_Internal.hpp" +#include "KokkosBlas3_gemm.hpp" #include namespace KokkosSparse { @@ -113,8 +113,8 @@ void bspgemm_debug_numeric(KernelHandle* /* handle */, typedef typename KernelHandle::nnz_lno_t lno_t; typedef typename KernelHandle::size_type size_type; typedef typename KernelHandle::nnz_scalar_t scalar_t; - typedef KokkosBatched::SerialGemmInternal< - KokkosBatched::Algo::Gemm::Unblocked> + typedef KokkosBlas::Impl::SerialGemmInternal< + KokkosBlas::Algo::Gemm::Unblocked> GEMM; const auto block_size = block_dim * block_dim; diff --git a/sparse/impl/KokkosSparse_spmv_bsrmatrix_impl.hpp b/sparse/impl/KokkosSparse_spmv_bsrmatrix_impl.hpp index a0bf8c96ec..cb94d9e4e9 100644 --- a/sparse/impl/KokkosSparse_spmv_bsrmatrix_impl.hpp +++ b/sparse/impl/KokkosSparse_spmv_bsrmatrix_impl.hpp @@ -47,6 +47,8 @@ #include "KokkosKernels_Error.hpp" #include "KokkosKernels_ExecSpaceUtils.hpp" +#include "KokkosBlas3_serial_gemm_internal.hpp" +#include "KokkosBlas3_team_gemm_internal.hpp" #if defined(KOKKOS_ENABLE_CUDA) && \ (defined(KOKKOS_ARCH_VOLTA) || defined(KOKKOS_ARCH_AMPERE)) @@ -539,9 +541,8 @@ struct BsrMatrixSpMVTensorCoreDispatcher { #include "KokkosBlas.hpp" #include "KokkosBlas2_serial_gemv_internal.hpp" -#include "KokkosBlas2_team_gemv_impl.hpp" -#include "KokkosBatched_Gemm_Serial_Internal.hpp" -#include "KokkosBatched_Gemm_TeamVector_Internal.hpp" +#include "KokkosBlas2_team_gemv_internal.hpp" +#include "KokkosBlas3_gemm.hpp" #include "KokkosBlas1_team_scal_impl.hpp" #include "KokkosKernels_ExecSpaceUtils.hpp" @@ -947,7 +948,7 @@ struct BSR_GEMV_Transpose_Functor { const auto A_cur = myRow.block(jBlock); // KokkosBlas::TeamVectorGemv< - team_member, KokkosBlas::Trans::ConjTranspose, + KokkosBlas::Trans::ConjTranspose, KokkosBlas::Algo::Gemv::Default>::invoke(dev, alpha, A_cur, X_cur, val_zero, shared_view); // @@ -1238,7 +1239,7 @@ struct BSR_GEMM_Functor { for (ordinal_type ic = 0; ic < count; ++ic) { const auto Aview = row.block(ic); const auto xstart = row.block_colidx(ic) * block_dim; - KokkosBatched::SerialGemmInternal:: + KokkosBlas::Impl::SerialGemmInternal:: invoke( static_cast(block_dim), static_cast(num_rhs), @@ -1281,17 +1282,17 @@ struct BSR_GEMM_Functor { const auto X_cur = Kokkos::subview( m_x, ::Kokkos::make_pair(X_ptBeg, X_ptBeg + block_dim), Kokkos::ALL()); - KokkosBatched::TeamVectorGemmInternal< - KokkosBatched::Algo::Gemm::Unblocked, - true>::invoke(dev, static_cast(block_dim), - static_cast(num_rhs), - static_cast(block_dim), alpha, A_cur.data(), - static_cast(A_cur.stride_0()), - static_cast(A_cur.stride_1()), X_cur.data(), - static_cast(X_cur.stride_0()), - static_cast(X_cur.stride_1()), val_one, - Y_cur.data(), static_cast(Y_cur.stride_0()), - static_cast(Y_cur.stride_1())); + KokkosBlas::Impl:: + TeamVectorGemmInternal::invoke( + KokkosBlas::Impl::OpConj{}, KokkosBlas::Impl::OpID{}, dev, + static_cast(block_dim), static_cast(num_rhs), + static_cast(block_dim), alpha, A_cur.data(), + static_cast(A_cur.stride_0()), + static_cast(A_cur.stride_1()), X_cur.data(), + static_cast(X_cur.stride_0()), + static_cast(X_cur.stride_1()), val_one, Y_cur.data(), + static_cast(Y_cur.stride_0()), + static_cast(Y_cur.stride_1())); } } else { for (ordinal_type jBlock = 0; jBlock < count; ++jBlock) { @@ -1301,15 +1302,15 @@ struct BSR_GEMM_Functor { const auto X_cur = Kokkos::subview( m_x, ::Kokkos::make_pair(X_ptBeg, X_ptBeg + block_dim), Kokkos::ALL()); - KokkosBatched::TeamVectorGemmInternal< - KokkosBatched::Algo::Gemm::Unblocked, - false>::invoke(dev, block_dim, num_rhs, block_dim, alpha, - A_cur.data(), static_cast(A_cur.stride_0()), - static_cast(A_cur.stride_1()), X_cur.data(), - static_cast(X_cur.stride_0()), - static_cast(X_cur.stride_1()), val_one, - Y_cur.data(), static_cast(Y_cur.stride_0()), - static_cast(Y_cur.stride_1())); + KokkosBlas::Impl:: + TeamVectorGemmInternal::invoke( + dev, block_dim, num_rhs, block_dim, alpha, A_cur.data(), + static_cast(A_cur.stride_0()), + static_cast(A_cur.stride_1()), X_cur.data(), + static_cast(X_cur.stride_0()), + static_cast(X_cur.stride_1()), val_one, Y_cur.data(), + static_cast(Y_cur.stride_0()), + static_cast(Y_cur.stride_1())); } } } @@ -1582,14 +1583,15 @@ struct BSR_GEMM_Transpose_Functor { for (ordinal_type jBlock = 0; jBlock < count; ++jBlock) { const auto A_cur = myRow.block(jBlock); // - KokkosBatched::TeamVectorGemmInternal< - KokkosBatched::Algo::Gemm::Unblocked, - true>::invoke(dev, block_dim, num_rhs, block_dim, alpha, - A_cur.data(), static_cast(A_cur.stride_1()), - static_cast(A_cur.stride_0()), X_cur.data(), - static_cast(X_cur.stride_0()), - static_cast(X_cur.stride_1()), val_zero, - shared_y, 1, block_dim); + KokkosBlas::Impl:: + TeamVectorGemmInternal::invoke( + KokkosBlas::Impl::OpConj{}, KokkosBlas::Impl::OpID{}, dev, + block_dim, num_rhs, block_dim, alpha, A_cur.data(), + static_cast(A_cur.stride_1()), + static_cast(A_cur.stride_0()), X_cur.data(), + static_cast(X_cur.stride_0()), + static_cast(X_cur.stride_1()), val_zero, shared_y, 1, + block_dim); // dev.team_barrier(); // @@ -1614,14 +1616,14 @@ struct BSR_GEMM_Transpose_Functor { for (ordinal_type jBlock = 0; jBlock < count; ++jBlock) { const auto A_cur = myRow.block(jBlock); // - KokkosBatched::TeamVectorGemmInternal< - KokkosBatched::Algo::Gemm::Unblocked, - false>::invoke(dev, block_dim, num_rhs, block_dim, alpha, - A_cur.data(), static_cast(A_cur.stride_1()), - static_cast(A_cur.stride_0()), X_cur.data(), - static_cast(X_cur.stride_0()), - static_cast(X_cur.stride_1()), val_zero, - shared_y, 1, block_dim); + KokkosBlas::Impl:: + TeamVectorGemmInternal::invoke( + dev, block_dim, num_rhs, block_dim, alpha, A_cur.data(), + static_cast(A_cur.stride_1()), + static_cast(A_cur.stride_0()), X_cur.data(), + static_cast(X_cur.stride_0()), + static_cast(X_cur.stride_1()), val_zero, shared_y, 1, + block_dim); // dev.team_barrier(); // diff --git a/sparse/impl/KokkosSparse_sptrsv_solve_impl.hpp b/sparse/impl/KokkosSparse_sptrsv_solve_impl.hpp index 0e3ed0235a..6b32b6ced0 100644 --- a/sparse/impl/KokkosSparse_sptrsv_solve_impl.hpp +++ b/sparse/impl/KokkosSparse_sptrsv_solve_impl.hpp @@ -62,7 +62,7 @@ #include "KokkosBatched_Util.hpp" -#include "KokkosBlas2_team_gemv_spec.hpp" +#include "KokkosBlas2_gemv.hpp" #include "KokkosBatched_Trsm_Team_Impl.hpp" #endif @@ -894,7 +894,7 @@ struct LowerTriSupernodalFunctor { workoffset, workoffset + nsrow)); // needed for gemv instead of trmv/trsv auto Ljj = Kokkos::subview(viewL, range_type(0, nsrow), Kokkos::ALL()); - KokkosBlas::TeamGemv::invoke(team, one, Ljj, Xj, @@ -922,7 +922,7 @@ struct LowerTriSupernodalFunctor { team.team_barrier(); // calling team-level "Unblocked" gemv on small-size diagonal in // KokkosBatched - KokkosBlas::TeamGemv::invoke(team, one, Ljj, @@ -955,7 +955,7 @@ struct LowerTriSupernodalFunctor { /* GEMM to update with off diagonal blocks */ auto Lij = Kokkos::subview(viewL, range_type(nscol, nsrow), Kokkos::ALL()); - KokkosBlas::TeamGemv::invoke(team, one, Lij, Xj, @@ -1081,8 +1081,7 @@ struct UpperTriSupernodalFunctor { SupernodeView viewU(&dataU[i1], nsrow, nscol); // extract part of solution, corresponding to the diagonal block U(s, s) - auto Xj = Kokkos::subview(X, range_type(j1, j2)); - using Xj_type = decltype(Xj); + auto Xj = Kokkos::subview(X, range_type(j1, j2)); // workspaces int workoffset = work_offset(s); @@ -1095,7 +1094,6 @@ struct UpperTriSupernodalFunctor { work, range_type(workoffset + nscol, workoffset + nsrow)); // needed with gemv for update&scatter - using Z_type = decltype(Z); for (int ii = team_rank; ii < nsrow2; ii += team_size) { int i = rowind(i2 + ii); Z(ii) = X(i); @@ -1106,17 +1104,16 @@ struct UpperTriSupernodalFunctor { // not device-level GEMV-udpate auto Uij = Kokkos::subview(viewU, range_type(nscol, nsrow), Kokkos::ALL()); - using Uij_type = decltype(Uij); - KokkosBlas::TeamGemv:: - template invoke( - team, -one, Uij, Z, one, Xj); + KokkosBlas::TeamGemv::invoke(team, + -one, Uij, + Z, one, + Xj); team.team_barrier(); /* TRSM with diagonal block */ // extract diagonal and off-diagonal blocks of U auto Ujj = Kokkos::subview(viewU, range_type(0, nscol), Kokkos::ALL()); - using Ujj_type = decltype(Ujj); if (invert_diagonal) { // workspace @@ -1125,17 +1122,18 @@ struct UpperTriSupernodalFunctor { range_type( workoffset, workoffset + nscol)); // needed for gemv instead of trmv/trsv - using Y_type = decltype(Y); for (int ii = team_rank; ii < nscol; ii += team_size) { Y(ii) = Xj(ii); } team.team_barrier(); // caling team-level kernel in KokkosBatched on a small-size diagonal - KokkosBlas::TeamGemv:: - template invoke( - team, one, Ujj, Y, zero, Xj); + KokkosBlas::TeamGemv::invoke(team, + one, + Ujj, Y, + zero, + Xj); } else { // NOTE: we currently supports only default_layout = LayoutLeft Kokkos::View::invoke(team, one, Uij, Xj, @@ -1295,7 +1293,7 @@ struct UpperTriTranSupernodalFunctor { Y(ii) = Xj(ii); } team.team_barrier(); - KokkosBlas::TeamGemv::invoke(team, one, Ujj, @@ -1325,7 +1323,7 @@ struct UpperTriTranSupernodalFunctor { // not device-level TRSM-solve auto Uij = Kokkos::subview(viewU, range_type(nscol, nsrow), Kokkos::ALL()); - KokkosBlas::TeamGemv::invoke(team, one, Uij, Xj, diff --git a/sparse/src/KokkosSparse_Utils_mkl.hpp b/sparse/src/KokkosSparse_Utils_mkl.hpp index b9eb3a9bd2..36a2ddc005 100644 --- a/sparse/src/KokkosSparse_Utils_mkl.hpp +++ b/sparse/src/KokkosSparse_Utils_mkl.hpp @@ -46,53 +46,13 @@ #define _KOKKOSKERNELS_SPARSEUTILS_MKL_HPP #include "KokkosKernels_config.h" +#include "KokkosKernels_MKLUtils.hpp" #ifdef KOKKOSKERNELS_ENABLE_TPL_MKL -#include - namespace KokkosSparse { namespace Impl { -inline void mkl_internal_safe_call(sparse_status_t mkl_status, const char *name, - const char *file = nullptr, - const int line = 0) { - if (SPARSE_STATUS_SUCCESS != mkl_status) { - std::ostringstream oss; - oss << "MKL call \"" << name << "\" at " << file << ":" << line - << " encountered error: "; - switch (mkl_status) { - case SPARSE_STATUS_NOT_INITIALIZED: - oss << "SPARSE_STATUS_NOT_INITIALIZED (empty handle or matrix arrays)"; - break; - case SPARSE_STATUS_ALLOC_FAILED: - oss << "SPARSE_STATUS_ALLOC_FAILED (internal error: memory allocation " - "failed)"; - break; - case SPARSE_STATUS_INVALID_VALUE: - oss << "SPARSE_STATUS_INVALID_VALUE (invalid input value)"; - break; - case SPARSE_STATUS_EXECUTION_FAILED: - oss << "SPARSE_STATUS_EXECUTION_FAILED (e.g. 0-diagonal element for " - "triangular solver)"; - break; - case SPARSE_STATUS_INTERNAL_ERROR: - oss << "SPARSE_STATUS_INTERNAL_ERROR"; - break; - case SPARSE_STATUS_NOT_SUPPORTED: - oss << "SPARSE_STATUS_NOT_SUPPORTED (e.g. operation for double " - "precision doesn't support other types)"; - break; - default: oss << "unknown (code " << (int)mkl_status << ")"; break; - } - oss << '\n'; - Kokkos::abort(oss.str().c_str()); - } -} - -#define KOKKOSKERNELS_MKL_SAFE_CALL(call) \ - KokkosSparse::Impl::mkl_internal_safe_call(call, #call, __FILE__, __LINE__) - inline sparse_operation_t mode_kk_to_mkl(char mode_kk) { switch (toupper(mode_kk)) { case 'N': return SPARSE_OPERATION_NON_TRANSPOSE; diff --git a/sparse/tpls/KokkosSparse_spmv_bsrmatrix_tpl_spec_decl.hpp b/sparse/tpls/KokkosSparse_spmv_bsrmatrix_tpl_spec_decl.hpp index 93457f9837..01d32bd744 100644 --- a/sparse/tpls/KokkosSparse_spmv_bsrmatrix_tpl_spec_decl.hpp +++ b/sparse/tpls/KokkosSparse_spmv_bsrmatrix_tpl_spec_decl.hpp @@ -49,7 +49,6 @@ #include "KokkosSparse_Utils_mkl.hpp" #ifdef KOKKOSKERNELS_ENABLE_TPL_MKL -#include namespace KokkosSparse { namespace Experimental {