Skip to content

Add support for FP16 to openBLAS and shgemm on RISCV #5290

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,9 @@ endif ()
if (NOT DEFINED BUILD_BFLOAT16)
set (BUILD_BFLOAT16 false)
endif ()
if (NOT DEFINED BUILD_HFLOAT16)
set (BUILD_HFLOAT16 false)
endif ()
# set which float types we want to build for
if (NOT DEFINED BUILD_SINGLE AND NOT DEFINED BUILD_DOUBLE AND NOT DEFINED BUILD_COMPLEX AND NOT DEFINED BUILD_COMPLEX16)
# if none are defined, build for all
Expand Down
4 changes: 2 additions & 2 deletions Makefile.prebuild
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,11 @@ TARGET_FLAGS = -march=rv64imafdcv_zba_zbb_zfh -mabi=lp64d
endif

ifeq ($(TARGET), RISCV64_ZVL256B)
TARGET_FLAGS = -march=rv64imafdcv -mabi=lp64d
TARGET_FLAGS = -march=rv64imafdcv_zvfh_zfh -mabi=lp64d
endif

ifeq ($(TARGET), RISCV64_ZVL128B)
TARGET_FLAGS = -march=rv64imafdcv -mabi=lp64d
TARGET_FLAGS = -march=rv64imafdcv_zvfh_zfh -mabi=lp64d
endif

ifeq ($(TARGET), RISCV64_GENERIC)
Expand Down
8 changes: 4 additions & 4 deletions Makefile.riscv64
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@ CCOMMON_OPT += -march=rv64imafdcv_zba_zbb_zfh_zvl512b -mabi=lp64d
FCOMMON_OPT += -march=rv64imafdcv_zba_zbb_zfh -mabi=lp64d -static
endif
ifeq ($(CORE), RISCV64_ZVL256B)
CCOMMON_OPT += -march=rv64imafdcv_zvl256b -mabi=lp64d
FCOMMON_OPT += -march=rv64imafdcv -mabi=lp64d
CCOMMON_OPT += -march=rv64imafdcv_zvl256b_zvfh_zfh -mabi=lp64d
FCOMMON_OPT += -march=rv64imafdcv_zvfh_zfh -mabi=lp64d
endif
ifeq ($(CORE), RISCV64_ZVL128B)
CCOMMON_OPT += -march=rv64imafdcv -mabi=lp64d
FCOMMON_OPT += -march=rv64imafdcv -mabi=lp64d
CCOMMON_OPT += -march=rv64imafdcv_zvfh_zfh -mabi=lp64d
FCOMMON_OPT += -march=rv64imafdcv_zvfh_zfh -mabi=lp64d
endif
ifeq ($(CORE), RISCV64_GENERIC)
CCOMMON_OPT += -march=rv64imafdc -mabi=lp64d
Expand Down
2 changes: 2 additions & 0 deletions Makefile.rule
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,8 @@ COMMON_PROF = -pg
# If you want to enable the experimental BFLOAT16 support
# BUILD_BFLOAT16 = 1

# If you want to enable the experimental HFLOAT16 support
# BUILD_HFLOAT16 = 1

# Set the thread number threshold beyond which the job array for the threaded level3 BLAS
# will be allocated on the heap rather than the stack. (This array alone requires
Expand Down
4 changes: 4 additions & 0 deletions Makefile.system
Original file line number Diff line number Diff line change
Expand Up @@ -1547,6 +1547,9 @@ endif
ifeq ($(BUILD_BFLOAT16), 1)
CCOMMON_OPT += -DBUILD_BFLOAT16
endif
ifeq ($(BUILD_HFLOAT16), 1)
CCOMMON_OPT += -DBUILD_HFLOAT16
endif
ifeq ($(BUILD_SINGLE), 1)
CCOMMON_OPT += -DBUILD_SINGLE=1
endif
Expand Down Expand Up @@ -1889,6 +1892,7 @@ export TARGET_CORE
export NO_AVX512
export NO_AVX2
export BUILD_BFLOAT16
export BUILD_HFLOAT16
export NO_LSX
export NO_LASX

Expand Down
7 changes: 5 additions & 2 deletions Makefile.tail
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
SBBLASOBJS_P = $(SBBLASOBJS:.$(SUFFIX)=.$(PSUFFIX))
SHBLASPBJS_P = $(SHBLASOBJS:.$(SUFFIX)=.$(PSUFFIX))
SBLASOBJS_P = $(SBLASOBJS:.$(SUFFIX)=.$(PSUFFIX))
DBLASOBJS_P = $(DBLASOBJS:.$(SUFFIX)=.$(PSUFFIX))
QBLASOBJS_P = $(QBLASOBJS:.$(SUFFIX)=.$(PSUFFIX))
Expand All @@ -11,8 +12,8 @@ COMMONOBJS_P = $(COMMONOBJS:.$(SUFFIX)=.$(PSUFFIX))

HPLOBJS_P = $(HPLOBJS:.$(SUFFIX)=.$(PSUFFIX))

BLASOBJS = $(SBEXTOBJS) $(SBBLASOBJS) $(SBLASOBJS) $(DBLASOBJS) $(CBLASOBJS) $(ZBLASOBJS) $(CBAUXOBJS)
BLASOBJS_P = $(SBEXTOBJS_P) $(SBBLASOBJS_P) $(SBLASOBJS_P) $(DBLASOBJS_P) $(CBLASOBJS_P) $(ZBLASOBJS_P) $(CBAUXOBJS_P)
BLASOBJS = $(SHBLASOBJS) $(SBEXTOBJS) $(SBBLASOBJS) $(SBLASOBJS) $(DBLASOBJS) $(CBLASOBJS) $(ZBLASOBJS) $(CBAUXOBJS)
BLASOBJS_P = $(SHBLASPBJS_P) $(SBEXTOBJS_P) $(SBBLASOBJS_P) $(SBLASOBJS_P) $(DBLASOBJS_P) $(CBLASOBJS_P) $(ZBLASOBJS_P) $(CBAUXOBJS_P)

ifdef EXPRECISION
BLASOBJS += $(QBLASOBJS) $(XBLASOBJS)
Expand All @@ -24,6 +25,7 @@ BLASOBJS += $(QBLASOBJS) $(XBLASOBJS)
BLASOBJS_P += $(QBLASOBJS_P) $(XBLASOBJS_P)
endif

$(SHBLASOBJS) $(SHBLASOBJS_P) : override CFLAGS += -DHFLOAT16 -UDOUBLE -UCOMPLEX
$(SBBLASOBJS) $(SBBLASOBJS_P) : override CFLAGS += -DBFLOAT16 -UDOUBLE -UCOMPLEX
$(SBLASOBJS) $(SBLASOBJS_P) : override CFLAGS += -UDOUBLE -UCOMPLEX
$(DBLASOBJS) $(DBLASOBJS_P) : override CFLAGS += -DDOUBLE -UCOMPLEX
Expand All @@ -33,6 +35,7 @@ $(ZBLASOBJS) $(ZBLASOBJS_P) : override CFLAGS += -DDOUBLE -DCOMPLEX
$(XBLASOBJS) $(XBLASOBJS_P) : override CFLAGS += -DXDOUBLE -DCOMPLEX
$(SBEXTOBJS) $(SBEXTOBJS_P) : override CFLAGS += -DBFLOAT16 -UDOUBLE -UCOMPLEX

$(SHBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF)
$(SBBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF)
$(SBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF)
$(DBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF)
Expand Down
2 changes: 2 additions & 0 deletions benchmark/gemm.c
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#define GEMM BLASFUNC(dgemm)
#elif defined(HALF)
#define GEMM BLASFUNC(sbgemm)
#elif defined(HFLOAT16)
#define GEMM BLASFUNC(shgemm)
#else
#define GEMM BLASFUNC(sgemm)
#endif
Expand Down
4 changes: 4 additions & 0 deletions cblas.h
Original file line number Diff line number Diff line change
Expand Up @@ -446,6 +446,10 @@ void cblas_sbgemm(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum C
void cblas_sbgemm_batch(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransA_array, OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransB_array, OPENBLAS_CONST blasint * M_array, OPENBLAS_CONST blasint * N_array, OPENBLAS_CONST blasint * K_array,
OPENBLAS_CONST float * alpha_array, OPENBLAS_CONST bfloat16 ** A_array, OPENBLAS_CONST blasint * lda_array, OPENBLAS_CONST bfloat16 ** B_array, OPENBLAS_CONST blasint * ldb_array, OPENBLAS_CONST float * beta_array, float ** C_array, OPENBLAS_CONST blasint * ldc_array, OPENBLAS_CONST blasint group_count, OPENBLAS_CONST blasint * group_size);

/*** FLOAT16 extensions ***/
void cblas_shgemm(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransA, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransB, OPENBLAS_CONST blasint M, OPENBLAS_CONST blasint N, OPENBLAS_CONST blasint K,
OPENBLAS_CONST float alpha, OPENBLAS_CONST hfloat16 *A, OPENBLAS_CONST blasint lda, OPENBLAS_CONST hfloat16 *B, OPENBLAS_CONST blasint ldb, OPENBLAS_CONST float beta, float *C, OPENBLAS_CONST blasint ldc);

#ifdef __cplusplus
}
#endif /* __cplusplus */
Expand Down
19 changes: 11 additions & 8 deletions cmake/system.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -640,21 +640,24 @@ endif()
if (BUILD_BFLOAT16)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DBUILD_BFLOAT16")
endif()
if (BUILD_HFLOAT16)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DBUILD_HFLOAT16")
endif()
if(NOT MSVC)
set(CMAKE_ASM_FLAGS "${CMAKE_ASM_FLAGS} ${CCOMMON_OPT}")
endif()
# TODO: not sure what PFLAGS is -hpa
set(PFLAGS "${PFLAGS} ${CCOMMON_OPT} -I${TOPDIR} -DPROFILE ${COMMON_PROF}")
if ("${CMAKE_BUILD_TYPE}" STREQUAL "Release")

if ("${F_COMPILER}" STREQUAL "FLANG")
if (${CMAKE_Fortran_COMPILER_VERSION} VERSION_LESS_EQUAL 3)
set(CMAKE_Fortran_FLAGS_RELEASE "${CMAKE_Fortran_FLAGS_RELEASE} -fno-unroll-loops")
endif ()
endif ()
if (ARM64 AND CMAKE_Fortran_COMPILER_ID MATCHES "LLVMFlang.*" AND CMAKE_SYSTEM_NAME STREQUAL "Windows")
set(CMAKE_Fortran_FLAGS_RELEASE "${CMAKE_Fortran_FLAGS_RELEASE} -O2")
endif ()
if ("${F_COMPILER}" STREQUAL "FLANG")
if (${CMAKE_Fortran_COMPILER_VERSION} VERSION_LESS_EQUAL 3)
set(CMAKE_Fortran_FLAGS_RELEASE "${CMAKE_Fortran_FLAGS_RELEASE} -fno-unroll-loops")
endif ()
endif ()
if (ARM64 AND CMAKE_Fortran_COMPILER_ID MATCHES "LLVMFlang.*" AND CMAKE_SYSTEM_NAME STREQUAL "Windows")
set(CMAKE_Fortran_FLAGS_RELEASE "${CMAKE_Fortran_FLAGS_RELEASE} -O2")
endif ()
endif ()


Expand Down
15 changes: 15 additions & 0 deletions common.h
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,14 @@ typedef uint16_t bfloat16;
#define BFLOAT16CONVERSION 1
#endif

#ifdef BUILD_HFLOAT16
#ifndef hfloat16
typedef _Float16 hfloat16;
#endif
#else
typedef uint16_t hfloat16;
#endif

#ifdef USE64BITINT
typedef BLASLONG blasint;
#if defined(OS_WINDOWS) && defined(__64BIT__)
Expand Down Expand Up @@ -313,6 +321,13 @@ typedef int blasint;
#define SIZE 2
#define BASE_SHIFT 1
#define ZBASE_SHIFT 2
#elif defined(HFLOAT16)
#define IFLOAT hfloat16
#define XFLOAT IFLOAT
#define FLOAT float
#define SIZE 2
#define BASE_SHIFT 1
#define ZBASE_SHIFT 2
#else
#define FLOAT float
#define SIZE 4
Expand Down
2 changes: 2 additions & 0 deletions common_interface.h
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,8 @@ void BLASFUNC(xhbmv)(char *, blasint *, blasint *, xdouble *, xdouble *, blasint

/* Level 3 routines */

void BLASFUNC(shgemm)(char *, char *, blasint *, blasint *, blasint *, float *,
hfloat16 *, blasint *, hfloat16 *, blasint *, float *, float *, blasint *);
void BLASFUNC(sbgemm)(char *, char *, blasint *, blasint *, blasint *, float *,
bfloat16 *, blasint *, bfloat16 *, blasint *, float *, float *, blasint *);
void BLASFUNC(sgemm)(char *, char *, blasint *, blasint *, blasint *, float *,
Expand Down
19 changes: 18 additions & 1 deletion common_level3.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ void sgemm_direct(BLASLONG M, BLASLONG N, BLASLONG K,

int sgemm_direct_performant(BLASLONG M, BLASLONG N, BLASLONG K);


int shgemm_beta(BLASLONG, BLASLONG, BLASLONG, float,
hfloat16 *, BLASLONG, hfloat16 *, BLASLONG, float *, BLASLONG);
int sbgemm_beta(BLASLONG, BLASLONG, BLASLONG, float,
bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, float *, BLASLONG);
int sgemm_beta(BLASLONG, BLASLONG, BLASLONG, float,
Expand All @@ -78,6 +79,10 @@ int xgemm_beta(BLASLONG, BLASLONG, BLASLONG, xdouble *,
xdouble *, BLASLONG, xdouble *, BLASLONG, xdouble *, BLASLONG);
#endif

int shgemm_incopy(BLASLONG m, BLASLONG n, hfloat16 *a, BLASLONG lda, hfloat16 *b);
int shgemm_itcopy(BLASLONG m, BLASLONG n, hfloat16 *a, BLASLONG lda, hfloat16 *b);
int shgemm_oncopy(BLASLONG m, BLASLONG n, hfloat16 *a, BLASLONG lda, hfloat16 *b);
int shgemm_otcopy(BLASLONG m, BLASLONG n, hfloat16 *a, BLASLONG lda, hfloat16 *b);
int sbgemm_incopy(BLASLONG m, BLASLONG n, bfloat16 *a, BLASLONG lda, bfloat16 *b);
int sbgemm_itcopy(BLASLONG m, BLASLONG n, bfloat16 *a, BLASLONG lda, bfloat16 *b);
int sbgemm_oncopy(BLASLONG m, BLASLONG n, bfloat16 *a, BLASLONG lda, bfloat16 *b);
Expand Down Expand Up @@ -505,6 +510,7 @@ int xher2k_kernel_UC(BLASLONG m, BLASLONG n, BLASLONG k, xdouble alpha_r, xdoubl
int xher2k_kernel_LN(BLASLONG m, BLASLONG n, BLASLONG k, xdouble alpha_r, xdouble alpha_i, xdouble *a, xdouble *b, xdouble *c, BLASLONG ldc, BLASLONG offset, int flag);
int xher2k_kernel_LC(BLASLONG m, BLASLONG n, BLASLONG k, xdouble alpha_r, xdouble alpha_i, xdouble *a, xdouble *b, xdouble *c, BLASLONG ldc, BLASLONG offset, int flag);

int shgemm_kernel(BLASLONG, BLASLONG, BLASLONG, float, hfloat16 *, hfloat16 *, float *, BLASLONG);
int sbgemm_kernel(BLASLONG, BLASLONG, BLASLONG, float, bfloat16 *, bfloat16 *, float *, BLASLONG);
int sgemm_kernel(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG);
int dgemm_kernel(BLASLONG, BLASLONG, BLASLONG, double, double *, double *, double *, BLASLONG);
Expand Down Expand Up @@ -657,6 +663,11 @@ int cgemm3m_kernel(BLASLONG, BLASLONG, BLASLONG, float, float, float *, float
int zgemm3m_kernel(BLASLONG, BLASLONG, BLASLONG, double, double, double *, double *, double *, BLASLONG);
int xgemm3m_kernel(BLASLONG, BLASLONG, BLASLONG, xdouble, xdouble, xdouble *, xdouble *, xdouble *, BLASLONG);

int shgemm_nn(blas_arg_t *, BLASLONG *, BLASLONG *, hfloat16 *, hfloat16 *, BLASLONG);
int shgemm_nt(blas_arg_t *, BLASLONG *, BLASLONG *, hfloat16 *, hfloat16 *, BLASLONG);
int shgemm_tn(blas_arg_t *, BLASLONG *, BLASLONG *, hfloat16 *, hfloat16 *, BLASLONG);
int shgemm_tt(blas_arg_t *, BLASLONG *, BLASLONG *, hfloat16 *, hfloat16 *, BLASLONG);

int sbgemm_nn(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);
int sbgemm_nt(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);
int sbgemm_tn(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);
Expand Down Expand Up @@ -754,6 +765,11 @@ int xgemm_cr(blas_arg_t *, BLASLONG *, BLASLONG *, xdouble *, xdouble *, BLASLON
int xgemm_cc(blas_arg_t *, BLASLONG *, BLASLONG *, xdouble *, xdouble *, BLASLONG);
#endif

int shgemm_thread_nn(blas_arg_t *, BLASLONG *, BLASLONG *, hfloat16 *, hfloat16 *, BLASLONG);
int shgemm_thread_nt(blas_arg_t *, BLASLONG *, BLASLONG *, hfloat16 *, hfloat16 *, BLASLONG);
int shgemm_thread_tn(blas_arg_t *, BLASLONG *, BLASLONG *, hfloat16 *, hfloat16 *, BLASLONG);
int shgemm_thread_tt(blas_arg_t *, BLASLONG *, BLASLONG *, hfloat16 *, hfloat16 *, BLASLONG);

int sbgemm_thread_nn(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);
int sbgemm_thread_nt(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);
int sbgemm_thread_tn(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);
Expand Down Expand Up @@ -1944,6 +1960,7 @@ int dgemm_batch_thread(blas_arg_t * queue, BLASLONG nums);
int cgemm_batch_thread(blas_arg_t * queue, BLASLONG nums);
int zgemm_batch_thread(blas_arg_t * queue, BLASLONG nums);
int sbgemm_batch_thread(blas_arg_t * queue, BLASLONG nums);
// int shgemm_batch_thread(blas_arg_t * queue, BLASLONG nums);

#ifdef __CUDACC__
}
Expand Down
45 changes: 45 additions & 0 deletions common_macro.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
#ifndef COMMON_MACRO
#define COMMON_MACRO

#include "common_sh.h"
#include "common_sb.h"
#include "common_s.h"
#include "common_d.h"
Expand Down Expand Up @@ -656,6 +657,50 @@
#define GEMM_SMALL_KERNEL_B0_NT DGEMM_SMALL_KERNEL_B0_NT
#define GEMM_SMALL_KERNEL_B0_TN DGEMM_SMALL_KERNEL_B0_TN
#define GEMM_SMALL_KERNEL_B0_TT DGEMM_SMALL_KERNEL_B0_TT
#elif defined(HFLOAT16)
#define GEMM_BETA SHGEMM_BETA
#define GEMM_KERNEL_N SHGEMM_KERNEL
#define GEMM_KERNEL_L SHGEMM_KERNEL
#define GEMM_KERNEL_R SHGEMM_KERNEL
#define GEMM_KERNEL_B SHGEMM_KERNEL
#define GEMM_NN SHGEMM_NN
#define GEMM_CN SHGEMM_TN
#define GEMM_TN SHGEMM_TN
#define GEMM_NC SHGEMM_NT
#define GEMM_NT SHGEMM_NT
#define GEMM_CC SHGEMM_TT
#define GEMM_CT SHGEMM_TT
#define GEMM_TC SHGEMM_TT
#define GEMM_TT SHGEMM_TT
#define GEMM_NR SHGEMM_NN
#define GEMM_TR SHGEMM_TN
#define GEMM_CR SHGEMM_TN
#define GEMM_RN SHGEMM_NN
#define GEMM_RT SHGEMM_NT
#define GEMM_RC SHGEMM_NT
#define GEMM_RR SHGEMM_NN
#define GEMM_ONCOPY SHGEMM_ONCOPY
#define GEMM_OTCOPY SHGEMM_OTCOPY
#define GEMM_INCOPY SHGEMM_INCOPY
#define GEMM_ITCOPY SHGEMM_ITCOPY

#define GEMM_THREAD_NN SHGEMM_THREAD_NN
#define GEMM_THREAD_CN SHGEMM_THREAD_TN
#define GEMM_THREAD_TN SHGEMM_THREAD_TN
#define GEMM_THREAD_NC SHGEMM_THREAD_NT
#define GEMM_THREAD_NT SHGEMM_THREAD_NT
#define GEMM_THREAD_CC SHGEMM_THREAD_TT
#define GEMM_THREAD_CT SHGEMM_THREAD_TT
#define GEMM_THREAD_TC SHGEMM_THREAD_TT
#define GEMM_THREAD_TT SHGEMM_THREAD_TT
#define GEMM_THREAD_NR SHGEMM_THREAD_NN
#define GEMM_THREAD_TR SHGEMM_THREAD_TN
#define GEMM_THREAD_CR SHGEMM_THREAD_TN
#define GEMM_THREAD_RN SHGEMM_THREAD_NN
#define GEMM_THREAD_RT SHGEMM_THREAD_NT
#define GEMM_THREAD_RC SHGEMM_THREAD_NT
#define GEMM_THREAD_RR SHGEMM_THREAD_NN


#elif defined(BFLOAT16)

Expand Down
Loading