@@ -77,6 +77,9 @@ class MatVecCUSPARSEHandle_t {
77
77
using TB = typename TensorTypeB::value_type;
78
78
using TC = typename TensorTypeC::value_type;
79
79
80
+ // Mixed-precision compute type.
81
+ using TCOMP = std::conditional_t <is_matx_half_v<TC>, float , TC>;
82
+
80
83
/* *
81
84
* Construct a SpMV handle
82
85
*/
@@ -87,12 +90,12 @@ class MatVecCUSPARSEHandle_t {
87
90
params_ = GetSpMVParams (c, a, b, stream, alpha, beta);
88
91
89
92
// Properly typed alpha, beta.
90
- if constexpr (std::is_same_v<TC , cuda::std::complex<float >> ||
91
- std::is_same_v<TC , cuda::std::complex<double >>) {
93
+ if constexpr (std::is_same_v<TCOMP , cuda::std::complex<float >> ||
94
+ std::is_same_v<TCOMP , cuda::std::complex<double >>) {
92
95
salpha_ = {alpha, 0 };
93
96
sbeta_ = {beta, 0 };
94
- } else if constexpr (std::is_same_v<TC , float > ||
95
- std::is_same_v<TC , double >) {
97
+ } else if constexpr (std::is_same_v<TCOMP , float > ||
98
+ std::is_same_v<TCOMP , double >) {
96
99
salpha_ = alpha;
97
100
sbeta_ = beta;
98
101
} else {
@@ -139,7 +142,7 @@ class MatVecCUSPARSEHandle_t {
139
142
140
143
// Allocate a workspace for SpMV.
141
144
const cusparseSpMVAlg_t algo = CUSPARSE_SPMV_ALG_DEFAULT;
142
- const cudaDataType comptp = dtc; // TODO: support separate comp type?!
145
+ const cudaDataType comptp = MatXTypeToCudaType<TCOMP>();
143
146
ret =
144
147
cusparseSpMV_bufferSize (handle_, params_.opA , &salpha_, matA_, vecB_,
145
148
&sbeta_, vecC_, comptp, algo, &workspaceSize_);
@@ -188,7 +191,7 @@ class MatVecCUSPARSEHandle_t {
188
191
[[maybe_unused]] const TensorTypeB &b) {
189
192
MATX_NVTX_START (" " , matx::MATX_NVTX_LOG_INTERNAL);
190
193
const cusparseSpMVAlg_t algo = CUSPARSE_SPMV_ALG_DEFAULT;
191
- const cudaDataType comptp = MatXTypeToCudaType<TC >(); // TODO: see above
194
+ const cudaDataType comptp = MatXTypeToCudaType<TCOMP >();
192
195
[[maybe_unused]] cusparseStatus_t ret =
193
196
cusparseSpMV (handle_, params_.opA , &salpha_, matA_, vecB_, &sbeta_,
194
197
vecC_, comptp, algo, workspace_);
@@ -203,8 +206,8 @@ class MatVecCUSPARSEHandle_t {
203
206
size_t workspaceSize_ = 0 ;
204
207
void *workspace_ = nullptr ;
205
208
detail::MatVecCUSPARSEParams_t params_;
206
- TC salpha_;
207
- TC sbeta_;
209
+ TCOMP salpha_;
210
+ TCOMP sbeta_;
208
211
};
209
212
210
213
/* *
@@ -287,10 +290,12 @@ void sparse_matvec_impl(TensorTypeC &C, const TensorTypeA &a,
287
290
" tensors must have SpMV rank" );
288
291
static_assert (std::is_same_v<TC, TA> && std::is_same_v<TC, TB>,
289
292
" tensors must have the same data type" );
290
- // TODO: allow MIXED-PRECISION computation!
291
- static_assert (std::is_same_v<TC, float > || std::is_same_v<TC, double > ||
292
- std::is_same_v<TC, cuda::std::complex<float >> ||
293
- std::is_same_v<TC, cuda::std::complex<double >>,
293
+ static_assert (std::is_same_v<TC, matx::matxFp16> ||
294
+ std::is_same_v<TC, matx::matxBf16> ||
295
+ std::is_same_v<TC, float > ||
296
+ std::is_same_v<TC, double > ||
297
+ std::is_same_v<TC, cuda::std::complex<float >> ||
298
+ std::is_same_v<TC, cuda::std::complex<double >>,
294
299
" unsupported data type" );
295
300
MATX_ASSERT (a.Size (RANKA - 1 ) == b.Size (RANKB - 1 ) &&
296
301
a.Size (RANKA - 2 ) == c.Size (RANKC - 1 ),
0 commit comments