From bb712395cfa04f05008822a926c775551b820907 Mon Sep 17 00:00:00 2001 From: Vahid Tavanashad Date: Thu, 10 Jul 2025 11:02:59 -0700 Subject: [PATCH 1/2] simplify blas backend when using with USE_ONEMATH_CUBLAS --- dpnp/backend/extensions/blas/gemm.cpp | 16 ---------------- dpnp/backend/extensions/blas/gemm_batch.cpp | 17 ----------------- dpnp/backend/extensions/blas/gemv.cpp | 15 --------------- 3 files changed, 48 deletions(-) diff --git a/dpnp/backend/extensions/blas/gemm.cpp b/dpnp/backend/extensions/blas/gemm.cpp index af18ab3002f..a757db811c5 100644 --- a/dpnp/backend/extensions/blas/gemm.cpp +++ b/dpnp/backend/extensions/blas/gemm.cpp @@ -55,9 +55,7 @@ typedef sycl::event (*gemm_impl_fn_ptr_t)(sycl::queue &, const std::int64_t, char *, const std::int64_t, -#if !defined(USE_ONEMATH_CUBLAS) const bool, -#endif // !USE_ONEMATH_CUBLAS const std::vector &); static gemm_impl_fn_ptr_t gemm_dispatch_table[dpctl_td_ns::num_types] @@ -76,9 +74,7 @@ static sycl::event gemm_impl(sycl::queue &exec_q, const std::int64_t ldb, char *resultC, const std::int64_t ldc, -#if !defined(USE_ONEMATH_CUBLAS) const bool is_row_major, -#endif // !USE_ONEMATH_CUBLAS const std::vector &depends) { type_utils::validate_type_for_device(exec_q); @@ -100,11 +96,6 @@ static sycl::event gemm_impl(sycl::queue &exec_q, const Tab *a, const std::int64_t lda, const Tab *b, const std::int64_t ldb, Tab beta, Tc *c, const std::int64_t ldc, const std::vector &deps) -> sycl::event { -#if defined(USE_ONEMATH_CUBLAS) - return mkl_blas::column_major::gemm(q, transA, transB, m, n, k, - alpha, a, lda, b, ldb, beta, c, - ldc, deps); -#else if (is_row_major) { return mkl_blas::row_major::gemm(q, transA, transB, m, n, k, alpha, a, lda, b, ldb, beta, c, @@ -115,7 +106,6 @@ static sycl::event gemm_impl(sycl::queue &exec_q, alpha, a, lda, b, ldb, beta, c, ldc, deps); } -#endif // USE_ONEMATH_CUBLAS }; gemm_event = gemm_func( exec_q, @@ -320,15 +310,9 @@ std::tuple const char *b_typeless_ptr = matrixB.get_data(); char *r_typeless_ptr = resultC.get_data(); -#if defined(USE_ONEMATH_CUBLAS) - sycl::event gemm_ev = - gemm_fn(exec_q, transA, transB, m, n, k, a_typeless_ptr, lda, - b_typeless_ptr, ldb, r_typeless_ptr, ldc, depends); -#else sycl::event gemm_ev = gemm_fn(exec_q, transA, transB, m, n, k, a_typeless_ptr, lda, b_typeless_ptr, ldb, r_typeless_ptr, ldc, is_row_major, depends); -#endif // USE_ONEMATH_CUBLAS sycl::event args_ev = dpctl::utils::keep_args_alive( exec_q, {matrixA, matrixB, resultC}, {gemm_ev}); diff --git a/dpnp/backend/extensions/blas/gemm_batch.cpp b/dpnp/backend/extensions/blas/gemm_batch.cpp index 1e210aede9f..854192b12b3 100644 --- a/dpnp/backend/extensions/blas/gemm_batch.cpp +++ b/dpnp/backend/extensions/blas/gemm_batch.cpp @@ -60,9 +60,7 @@ typedef sycl::event (*gemm_batch_impl_fn_ptr_t)( const char *, const char *, char *, -#if !defined(USE_ONEMATH_CUBLAS) const bool, -#endif // !USE_ONEMATH_CUBLAS const std::vector &); static gemm_batch_impl_fn_ptr_t @@ -85,9 +83,7 @@ static sycl::event gemm_batch_impl(sycl::queue &exec_q, const char *matrixA, const char *matrixB, char *resultC, -#if !defined(USE_ONEMATH_CUBLAS) const bool is_row_major, -#endif // !USE_ONEMATH_CUBLAS const std::vector &depends) { type_utils::validate_type_for_device(exec_q); @@ -112,11 +108,6 @@ static sycl::event gemm_batch_impl(sycl::queue &exec_q, Tc *c, const std::int64_t ldc, const std::int64_t stridec, const std::int64_t batch_size, const std::vector &deps) -> sycl::event { -#if defined(USE_ONEMATH_CUBLAS) - return mkl_blas::column_major::gemm_batch( - q, transA, transB, m, n, k, alpha, a, lda, stridea, b, ldb, - strideb, beta, c, ldc, stridec, batch_size, deps); -#else if (is_row_major) { return mkl_blas::row_major::gemm_batch( q, transA, transB, m, n, k, alpha, a, lda, stridea, b, ldb, @@ -127,7 +118,6 @@ static sycl::event gemm_batch_impl(sycl::queue &exec_q, q, transA, transB, m, n, k, alpha, a, lda, stridea, b, ldb, strideb, beta, c, ldc, stridec, batch_size, deps); } -#endif // USE_ONEMATH_CUBLAS }; gemm_batch_event = gemm_batch_func( exec_q, @@ -396,17 +386,10 @@ std::tuple const char *b_typeless_ptr = matrixB.get_data(); char *r_typeless_ptr = resultC.get_data(); -#if defined(USE_ONEMATH_CUBLAS) - sycl::event gemm_batch_ev = - gemm_batch_fn(exec_q, m, n, k, batch_size, lda, ldb, ldc, stridea, - strideb, stridec, transA, transB, a_typeless_ptr, - b_typeless_ptr, r_typeless_ptr, depends); -#else sycl::event gemm_batch_ev = gemm_batch_fn(exec_q, m, n, k, batch_size, lda, ldb, ldc, stridea, strideb, stridec, transA, transB, a_typeless_ptr, b_typeless_ptr, r_typeless_ptr, is_row_major, depends); -#endif // USE_ONEMATH_CUBLAS sycl::event args_ev = dpctl::utils::keep_args_alive( exec_q, {matrixA, matrixB, resultC}, {gemm_batch_ev}); diff --git a/dpnp/backend/extensions/blas/gemv.cpp b/dpnp/backend/extensions/blas/gemv.cpp index 91057893aa5..4e293bf45df 100644 --- a/dpnp/backend/extensions/blas/gemv.cpp +++ b/dpnp/backend/extensions/blas/gemv.cpp @@ -53,9 +53,7 @@ typedef sycl::event (*gemv_impl_fn_ptr_t)(sycl::queue &, const std::int64_t, char *, const std::int64_t, -#if !defined(USE_ONEMATH_CUBLAS) const bool, -#endif // !USE_ONEMATH_CUBLAS const std::vector &); static gemv_impl_fn_ptr_t gemv_dispatch_vector[dpctl_td_ns::num_types]; @@ -71,9 +69,7 @@ static sycl::event gemv_impl(sycl::queue &exec_q, const std::int64_t incx, char *vectorY, const std::int64_t incy, -#if !defined(USE_ONEMATH_CUBLAS) const bool is_row_major, -#endif // !USE_ONEMATH_CUBLAS const std::vector &depends) { type_utils::validate_type_for_device(exec_q); @@ -93,10 +89,6 @@ static sycl::event gemv_impl(sycl::queue &exec_q, const std::int64_t lda, const T *x, const std::int64_t incx, T beta, T *y, const std::int64_t incy, const std::vector &deps) -> sycl::event { -#if defined(USE_ONEMATH_CUBLAS) - return mkl_blas::column_major::gemv(q, transA, m, n, alpha, a, lda, - x, incx, beta, y, incy, deps); -#else if (is_row_major) { return mkl_blas::row_major::gemv(q, transA, m, n, alpha, a, lda, x, incx, beta, y, incy, deps); @@ -106,7 +98,6 @@ static sycl::event gemv_impl(sycl::queue &exec_q, lda, x, incx, beta, y, incy, deps); } -#endif // USE_ONEMATH_CUBLAS }; gemv_event = gemv_func( exec_q, @@ -304,15 +295,9 @@ std::pair y_typeless_ptr -= (y_shape[0] - 1) * std::abs(incy) * y_elemsize; } -#if defined(USE_ONEMATH_CUBLAS) - sycl::event gemv_ev = - gemv_fn(exec_q, transA, m, n, a_typeless_ptr, lda, x_typeless_ptr, incx, - y_typeless_ptr, incy, depends); -#else sycl::event gemv_ev = gemv_fn(exec_q, transA, m, n, a_typeless_ptr, lda, x_typeless_ptr, incx, y_typeless_ptr, incy, is_row_major, depends); -#endif // USE_ONEMATH_CUBLAS sycl::event args_ev = dpctl::utils::keep_args_alive( exec_q, {matrixA, vectorX, vectorY}, {gemv_ev}); From a240c4a0dcf6b37adbfbd2bc9c1b484957e53366 Mon Sep 17 00:00:00 2001 From: Vahid Tavanashad Date: Fri, 11 Jul 2025 09:08:03 -0700 Subject: [PATCH 2/2] use constexpr for defining is_row_major --- dpnp/backend/extensions/blas/gemm.cpp | 2 +- dpnp/backend/extensions/blas/gemm_batch.cpp | 2 +- dpnp/backend/extensions/blas/gemv.cpp | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/dpnp/backend/extensions/blas/gemm.cpp b/dpnp/backend/extensions/blas/gemm.cpp index a757db811c5..c343c232b7a 100644 --- a/dpnp/backend/extensions/blas/gemm.cpp +++ b/dpnp/backend/extensions/blas/gemm.cpp @@ -232,7 +232,7 @@ std::tuple // cuBLAS supports only column-major storage #if defined(USE_ONEMATH_CUBLAS) - const bool is_row_major = false; + constexpr bool is_row_major = false; transA = is_matrixA_c_contig ? oneapi::mkl::transpose::T : oneapi::mkl::transpose::N; diff --git a/dpnp/backend/extensions/blas/gemm_batch.cpp b/dpnp/backend/extensions/blas/gemm_batch.cpp index 854192b12b3..95f5a1aaf32 100644 --- a/dpnp/backend/extensions/blas/gemm_batch.cpp +++ b/dpnp/backend/extensions/blas/gemm_batch.cpp @@ -307,7 +307,7 @@ std::tuple // cuBLAS supports only column-major storage #if defined(USE_ONEMATH_CUBLAS) - const bool is_row_major = false; + constexpr bool is_row_major = false; transA = A_base_is_c_contig ? oneapi::mkl::transpose::T : oneapi::mkl::transpose::N; diff --git a/dpnp/backend/extensions/blas/gemv.cpp b/dpnp/backend/extensions/blas/gemv.cpp index 4e293bf45df..29bce7a1099 100644 --- a/dpnp/backend/extensions/blas/gemv.cpp +++ b/dpnp/backend/extensions/blas/gemv.cpp @@ -187,7 +187,7 @@ std::pair // cuBLAS supports only column-major storage #if defined(USE_ONEMATH_CUBLAS) - const bool is_row_major = false; + constexpr bool is_row_major = false; std::int64_t m; std::int64_t n;