Skip to content

Commit a3fbd5d

Browse files
authored
Support mixed-precision for SpMV (#907)
* Support mixed-precision for SpMV
1 parent 6b20e04 commit a3fbd5d

File tree

1 file changed

+17
-12
lines changed

1 file changed

+17
-12
lines changed

include/matx/transforms/matmul/matvec_cusparse.h

Lines changed: 17 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -77,6 +77,9 @@ class MatVecCUSPARSEHandle_t {
7777
using TB = typename TensorTypeB::value_type;
7878
using TC = typename TensorTypeC::value_type;
7979

80+
// Mixed-precision compute type: half-precision value types use float as the compute type; every other type computes in TC itself.
81+
using TCOMP = std::conditional_t<is_matx_half_v<TC>, float, TC>;
82+
8083
/**
8184
* Construct a SpMV handle
8285
*/
@@ -87,12 +90,12 @@ class MatVecCUSPARSEHandle_t {
8790
params_ = GetSpMVParams(c, a, b, stream, alpha, beta);
8891

8992
// Properly typed alpha, beta.
90-
if constexpr (std::is_same_v<TC, cuda::std::complex<float>> ||
91-
std::is_same_v<TC, cuda::std::complex<double>>) {
93+
if constexpr (std::is_same_v<TCOMP, cuda::std::complex<float>> ||
94+
std::is_same_v<TCOMP, cuda::std::complex<double>>) {
9295
salpha_ = {alpha, 0};
9396
sbeta_ = {beta, 0};
94-
} else if constexpr (std::is_same_v<TC, float> ||
95-
std::is_same_v<TC, double>) {
97+
} else if constexpr (std::is_same_v<TCOMP, float> ||
98+
std::is_same_v<TCOMP, double>) {
9699
salpha_ = alpha;
97100
sbeta_ = beta;
98101
} else {
@@ -139,7 +142,7 @@ class MatVecCUSPARSEHandle_t {
139142

140143
// Allocate a workspace for SpMV.
141144
const cusparseSpMVAlg_t algo = CUSPARSE_SPMV_ALG_DEFAULT;
142-
const cudaDataType comptp = dtc; // TODO: support separate comp type?!
145+
const cudaDataType comptp = MatXTypeToCudaType<TCOMP>();
143146
ret =
144147
cusparseSpMV_bufferSize(handle_, params_.opA, &salpha_, matA_, vecB_,
145148
&sbeta_, vecC_, comptp, algo, &workspaceSize_);
@@ -188,7 +191,7 @@ class MatVecCUSPARSEHandle_t {
188191
[[maybe_unused]] const TensorTypeB &b) {
189192
MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL);
190193
const cusparseSpMVAlg_t algo = CUSPARSE_SPMV_ALG_DEFAULT;
191-
const cudaDataType comptp = MatXTypeToCudaType<TC>(); // TODO: see above
194+
const cudaDataType comptp = MatXTypeToCudaType<TCOMP>();
192195
[[maybe_unused]] cusparseStatus_t ret =
193196
cusparseSpMV(handle_, params_.opA, &salpha_, matA_, vecB_, &sbeta_,
194197
vecC_, comptp, algo, workspace_);
@@ -203,8 +206,8 @@ class MatVecCUSPARSEHandle_t {
203206
size_t workspaceSize_ = 0;
204207
void *workspace_ = nullptr;
205208
detail::MatVecCUSPARSEParams_t params_;
206-
TC salpha_;
207-
TC sbeta_;
209+
TCOMP salpha_;
210+
TCOMP sbeta_;
208211
};
209212

210213
/**
@@ -287,10 +290,12 @@ void sparse_matvec_impl(TensorTypeC &C, const TensorTypeA &a,
287290
"tensors must have SpMV rank");
288291
static_assert(std::is_same_v<TC, TA> && std::is_same_v<TC, TB>,
289292
"tensors must have the same data type");
290-
// TODO: allow MIXED-PRECISION computation!
291-
static_assert(std::is_same_v<TC, float> || std::is_same_v<TC, double> ||
292-
std::is_same_v<TC, cuda::std::complex<float>> ||
293-
std::is_same_v<TC, cuda::std::complex<double>>,
293+
static_assert(std::is_same_v<TC, matx::matxFp16> ||
294+
std::is_same_v<TC, matx::matxBf16> ||
295+
std::is_same_v<TC, float> ||
296+
std::is_same_v<TC, double> ||
297+
std::is_same_v<TC, cuda::std::complex<float>> ||
298+
std::is_same_v<TC, cuda::std::complex<double>>,
294299
"unsupported data type");
295300
MATX_ASSERT(a.Size(RANKA - 1) == b.Size(RANKB - 1) &&
296301
a.Size(RANKA - 2) == c.Size(RANKC - 1),

0 commit comments

Comments
 (0)