diff --git a/CMakeLists.txt b/CMakeLists.txt index f13f707f98..793cd767a8 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -152,6 +152,9 @@ endif () if (NOT DEFINED BUILD_BFLOAT16) set (BUILD_BFLOAT16 false) endif () +if (NOT DEFINED BUILD_HFLOAT16) + set (BUILD_HFLOAT16 false) +endif () # set which float types we want to build for if (NOT DEFINED BUILD_SINGLE AND NOT DEFINED BUILD_DOUBLE AND NOT DEFINED BUILD_COMPLEX AND NOT DEFINED BUILD_COMPLEX16) # if none are defined, build for all diff --git a/Makefile.prebuild b/Makefile.prebuild index b7d695a750..b6c8d552f9 100644 --- a/Makefile.prebuild +++ b/Makefile.prebuild @@ -64,11 +64,11 @@ TARGET_FLAGS = -march=rv64imafdcv_zba_zbb_zfh -mabi=lp64d endif ifeq ($(TARGET), RISCV64_ZVL256B) -TARGET_FLAGS = -march=rv64imafdcv -mabi=lp64d +TARGET_FLAGS = -march=rv64imafdcv_zvfh_zfh -mabi=lp64d endif ifeq ($(TARGET), RISCV64_ZVL128B) -TARGET_FLAGS = -march=rv64imafdcv -mabi=lp64d +TARGET_FLAGS = -march=rv64imafdcv_zvfh_zfh -mabi=lp64d endif ifeq ($(TARGET), RISCV64_GENERIC) diff --git a/Makefile.riscv64 b/Makefile.riscv64 index 0ee26c1b5c..8fe734186b 100644 --- a/Makefile.riscv64 +++ b/Makefile.riscv64 @@ -7,12 +7,12 @@ CCOMMON_OPT += -march=rv64imafdcv_zba_zbb_zfh_zvl512b -mabi=lp64d FCOMMON_OPT += -march=rv64imafdcv_zba_zbb_zfh -mabi=lp64d -static endif ifeq ($(CORE), RISCV64_ZVL256B) -CCOMMON_OPT += -march=rv64imafdcv_zvl256b -mabi=lp64d -FCOMMON_OPT += -march=rv64imafdcv -mabi=lp64d +CCOMMON_OPT += -march=rv64imafdcv_zvl256b_zvfh_zfh -mabi=lp64d +FCOMMON_OPT += -march=rv64imafdcv_zvfh_zfh -mabi=lp64d endif ifeq ($(CORE), RISCV64_ZVL128B) -CCOMMON_OPT += -march=rv64imafdcv -mabi=lp64d -FCOMMON_OPT += -march=rv64imafdcv -mabi=lp64d +CCOMMON_OPT += -march=rv64imafdcv_zvfh_zfh -mabi=lp64d +FCOMMON_OPT += -march=rv64imafdcv_zvfh_zfh -mabi=lp64d endif ifeq ($(CORE), RISCV64_GENERIC) CCOMMON_OPT += -march=rv64imafdc -mabi=lp64d diff --git a/Makefile.rule b/Makefile.rule index 1472ed938b..24b34d1c20 100644 --- 
a/Makefile.rule +++ b/Makefile.rule @@ -308,6 +308,8 @@ COMMON_PROF = -pg # If you want to enable the experimental BFLOAT16 support # BUILD_BFLOAT16 = 1 +# If you want to enable the experimental HFLOAT16 support +# BUILD_HFLOAT16 = 1 # Set the thread number threshold beyond which the job array for the threaded level3 BLAS # will be allocated on the heap rather than the stack. (This array alone requires diff --git a/Makefile.system b/Makefile.system index 38646c3c6b..be31d05ef4 100644 --- a/Makefile.system +++ b/Makefile.system @@ -1547,6 +1547,9 @@ endif ifeq ($(BUILD_BFLOAT16), 1) CCOMMON_OPT += -DBUILD_BFLOAT16 endif +ifeq ($(BUILD_HFLOAT16), 1) +CCOMMON_OPT += -DBUILD_HFLOAT16 +endif ifeq ($(BUILD_SINGLE), 1) CCOMMON_OPT += -DBUILD_SINGLE=1 endif @@ -1889,6 +1892,7 @@ export TARGET_CORE export NO_AVX512 export NO_AVX2 export BUILD_BFLOAT16 +export BUILD_HFLOAT16 export NO_LSX export NO_LASX diff --git a/Makefile.tail b/Makefile.tail index 54ba649dbf..ed2c0e5073 100644 --- a/Makefile.tail +++ b/Makefile.tail @@ -1,4 +1,5 @@ SBBLASOBJS_P = $(SBBLASOBJS:.$(SUFFIX)=.$(PSUFFIX)) +SHBLASOBJS_P = $(SHBLASOBJS:.$(SUFFIX)=.$(PSUFFIX)) SBLASOBJS_P = $(SBLASOBJS:.$(SUFFIX)=.$(PSUFFIX)) DBLASOBJS_P = $(DBLASOBJS:.$(SUFFIX)=.$(PSUFFIX)) QBLASOBJS_P = $(QBLASOBJS:.$(SUFFIX)=.$(PSUFFIX)) @@ -11,8 +12,8 @@ COMMONOBJS_P = $(COMMONOBJS:.$(SUFFIX)=.$(PSUFFIX)) HPLOBJS_P = $(HPLOBJS:.$(SUFFIX)=.$(PSUFFIX)) -BLASOBJS = $(SBEXTOBJS) $(SBBLASOBJS) $(SBLASOBJS) $(DBLASOBJS) $(CBLASOBJS) $(ZBLASOBJS) $(CBAUXOBJS) -BLASOBJS_P = $(SBEXTOBJS_P) $(SBBLASOBJS_P) $(SBLASOBJS_P) $(DBLASOBJS_P) $(CBLASOBJS_P) $(ZBLASOBJS_P) $(CBAUXOBJS_P) +BLASOBJS = $(SHBLASOBJS) $(SBEXTOBJS) $(SBBLASOBJS) $(SBLASOBJS) $(DBLASOBJS) $(CBLASOBJS) $(ZBLASOBJS) $(CBAUXOBJS) +BLASOBJS_P = $(SHBLASOBJS_P) $(SBEXTOBJS_P) $(SBBLASOBJS_P) $(SBLASOBJS_P) $(DBLASOBJS_P) $(CBLASOBJS_P) $(ZBLASOBJS_P) $(CBAUXOBJS_P) ifdef EXPRECISION BLASOBJS += $(QBLASOBJS) $(XBLASOBJS) @@ -24,6 +25,7 @@ BLASOBJS += $(QBLASOBJS) 
$(XBLASOBJS) BLASOBJS_P += $(QBLASOBJS_P) $(XBLASOBJS_P) endif +$(SHBLASOBJS) $(SHBLASOBJS_P) : override CFLAGS += -DHFLOAT16 -UDOUBLE -UCOMPLEX $(SBBLASOBJS) $(SBBLASOBJS_P) : override CFLAGS += -DBFLOAT16 -UDOUBLE -UCOMPLEX $(SBLASOBJS) $(SBLASOBJS_P) : override CFLAGS += -UDOUBLE -UCOMPLEX $(DBLASOBJS) $(DBLASOBJS_P) : override CFLAGS += -DDOUBLE -UCOMPLEX @@ -33,6 +35,7 @@ $(ZBLASOBJS) $(ZBLASOBJS_P) : override CFLAGS += -DDOUBLE -DCOMPLEX $(XBLASOBJS) $(XBLASOBJS_P) : override CFLAGS += -DXDOUBLE -DCOMPLEX $(SBEXTOBJS) $(SBEXTOBJS_P) : override CFLAGS += -DBFLOAT16 -UDOUBLE -UCOMPLEX +$(SHBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF) $(SBBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF) $(SBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF) $(DBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF) diff --git a/benchmark/gemm.c b/benchmark/gemm.c index 35f5096f35..6662c26e97 100644 --- a/benchmark/gemm.c +++ b/benchmark/gemm.c @@ -35,6 +35,8 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
#define GEMM BLASFUNC(dgemm) #elif defined(HALF) #define GEMM BLASFUNC(sbgemm) +#elif defined(HFLOAT16) +#define GEMM BLASFUNC(shgemm) #else #define GEMM BLASFUNC(sgemm) #endif diff --git a/cblas.h b/cblas.h index 83686f7433..0364b216fc 100644 --- a/cblas.h +++ b/cblas.h @@ -446,6 +446,10 @@ void cblas_sbgemm(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum C void cblas_sbgemm_batch(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransA_array, OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransB_array, OPENBLAS_CONST blasint * M_array, OPENBLAS_CONST blasint * N_array, OPENBLAS_CONST blasint * K_array, OPENBLAS_CONST float * alpha_array, OPENBLAS_CONST bfloat16 ** A_array, OPENBLAS_CONST blasint * lda_array, OPENBLAS_CONST bfloat16 ** B_array, OPENBLAS_CONST blasint * ldb_array, OPENBLAS_CONST float * beta_array, float ** C_array, OPENBLAS_CONST blasint * ldc_array, OPENBLAS_CONST blasint group_count, OPENBLAS_CONST blasint * group_size); +/*** FLOAT16 extensions ***/ +void cblas_shgemm(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransA, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransB, OPENBLAS_CONST blasint M, OPENBLAS_CONST blasint N, OPENBLAS_CONST blasint K, + OPENBLAS_CONST float alpha, OPENBLAS_CONST hfloat16 *A, OPENBLAS_CONST blasint lda, OPENBLAS_CONST hfloat16 *B, OPENBLAS_CONST blasint ldb, OPENBLAS_CONST float beta, float *C, OPENBLAS_CONST blasint ldc); + #ifdef __cplusplus } #endif /* __cplusplus */ diff --git a/cmake/system.cmake b/cmake/system.cmake index 14b2c65b11..bac756901f 100644 --- a/cmake/system.cmake +++ b/cmake/system.cmake @@ -640,6 +640,9 @@ endif() if (BUILD_BFLOAT16) set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DBUILD_BFLOAT16") endif() +if (BUILD_HFLOAT16) + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DBUILD_HFLOAT16") +endif() if(NOT MSVC) set(CMAKE_ASM_FLAGS "${CMAKE_ASM_FLAGS} ${CCOMMON_OPT}") endif() @@ -647,14 +650,14 @@ endif() set(PFLAGS "${PFLAGS} ${CCOMMON_OPT} -I${TOPDIR} -DPROFILE 
${COMMON_PROF}") if ("${CMAKE_BUILD_TYPE}" STREQUAL "Release") -if ("${F_COMPILER}" STREQUAL "FLANG") -if (${CMAKE_Fortran_COMPILER_VERSION} VERSION_LESS_EQUAL 3) - set(CMAKE_Fortran_FLAGS_RELEASE "${CMAKE_Fortran_FLAGS_RELEASE} -fno-unroll-loops") -endif () -endif () -if (ARM64 AND CMAKE_Fortran_COMPILER_ID MATCHES "LLVMFlang.*" AND CMAKE_SYSTEM_NAME STREQUAL "Windows") - set(CMAKE_Fortran_FLAGS_RELEASE "${CMAKE_Fortran_FLAGS_RELEASE} -O2") -endif () + if ("${F_COMPILER}" STREQUAL "FLANG") + if (${CMAKE_Fortran_COMPILER_VERSION} VERSION_LESS_EQUAL 3) + set(CMAKE_Fortran_FLAGS_RELEASE "${CMAKE_Fortran_FLAGS_RELEASE} -fno-unroll-loops") + endif () + endif () + if (ARM64 AND CMAKE_Fortran_COMPILER_ID MATCHES "LLVMFlang.*" AND CMAKE_SYSTEM_NAME STREQUAL "Windows") + set(CMAKE_Fortran_FLAGS_RELEASE "${CMAKE_Fortran_FLAGS_RELEASE} -O2") + endif () endif () diff --git a/common.h b/common.h index 8d002c4aa0..23a08aaa98 100644 --- a/common.h +++ b/common.h @@ -266,6 +266,14 @@ typedef uint16_t bfloat16; #define BFLOAT16CONVERSION 1 #endif +#ifdef BUILD_HFLOAT16 + #ifndef hfloat16 + typedef _Float16 hfloat16; + #endif +#else + typedef uint16_t hfloat16; +#endif + #ifdef USE64BITINT typedef BLASLONG blasint; #if defined(OS_WINDOWS) && defined(__64BIT__) @@ -313,6 +321,13 @@ typedef int blasint; #define SIZE 2 #define BASE_SHIFT 1 #define ZBASE_SHIFT 2 +#elif defined(HFLOAT16) +#define IFLOAT hfloat16 +#define XFLOAT IFLOAT +#define FLOAT float +#define SIZE 2 +#define BASE_SHIFT 1 +#define ZBASE_SHIFT 2 #else #define FLOAT float #define SIZE 4 diff --git a/common_interface.h b/common_interface.h index efd3c6649d..23d86871fc 100644 --- a/common_interface.h +++ b/common_interface.h @@ -481,6 +481,8 @@ void BLASFUNC(xhbmv)(char *, blasint *, blasint *, xdouble *, xdouble *, blasint /* Level 3 routines */ +void BLASFUNC(shgemm)(char *, char *, blasint *, blasint *, blasint *, float *, + hfloat16 *, blasint *, hfloat16 *, blasint *, float *, float *, blasint *); void 
BLASFUNC(sbgemm)(char *, char *, blasint *, blasint *, blasint *, float *, bfloat16 *, blasint *, bfloat16 *, blasint *, float *, float *, blasint *); void BLASFUNC(sgemm)(char *, char *, blasint *, blasint *, blasint *, float *, diff --git a/common_level3.h b/common_level3.h index d370c1f96a..1838b4bf6a 100644 --- a/common_level3.h +++ b/common_level3.h @@ -54,7 +54,8 @@ void sgemm_direct(BLASLONG M, BLASLONG N, BLASLONG K, int sgemm_direct_performant(BLASLONG M, BLASLONG N, BLASLONG K); - +int shgemm_beta(BLASLONG, BLASLONG, BLASLONG, float, + hfloat16 *, BLASLONG, hfloat16 *, BLASLONG, float *, BLASLONG); int sbgemm_beta(BLASLONG, BLASLONG, BLASLONG, float, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, float *, BLASLONG); int sgemm_beta(BLASLONG, BLASLONG, BLASLONG, float, @@ -78,6 +79,10 @@ int xgemm_beta(BLASLONG, BLASLONG, BLASLONG, xdouble *, xdouble *, BLASLONG, xdouble *, BLASLONG, xdouble *, BLASLONG); #endif +int shgemm_incopy(BLASLONG m, BLASLONG n, hfloat16 *a, BLASLONG lda, hfloat16 *b); +int shgemm_itcopy(BLASLONG m, BLASLONG n, hfloat16 *a, BLASLONG lda, hfloat16 *b); +int shgemm_oncopy(BLASLONG m, BLASLONG n, hfloat16 *a, BLASLONG lda, hfloat16 *b); +int shgemm_otcopy(BLASLONG m, BLASLONG n, hfloat16 *a, BLASLONG lda, hfloat16 *b); int sbgemm_incopy(BLASLONG m, BLASLONG n, bfloat16 *a, BLASLONG lda, bfloat16 *b); int sbgemm_itcopy(BLASLONG m, BLASLONG n, bfloat16 *a, BLASLONG lda, bfloat16 *b); int sbgemm_oncopy(BLASLONG m, BLASLONG n, bfloat16 *a, BLASLONG lda, bfloat16 *b); @@ -505,6 +510,7 @@ int xher2k_kernel_UC(BLASLONG m, BLASLONG n, BLASLONG k, xdouble alpha_r, xdoubl int xher2k_kernel_LN(BLASLONG m, BLASLONG n, BLASLONG k, xdouble alpha_r, xdouble alpha_i, xdouble *a, xdouble *b, xdouble *c, BLASLONG ldc, BLASLONG offset, int flag); int xher2k_kernel_LC(BLASLONG m, BLASLONG n, BLASLONG k, xdouble alpha_r, xdouble alpha_i, xdouble *a, xdouble *b, xdouble *c, BLASLONG ldc, BLASLONG offset, int flag); +int shgemm_kernel(BLASLONG, BLASLONG, 
BLASLONG, float, hfloat16 *, hfloat16 *, float *, BLASLONG); int sbgemm_kernel(BLASLONG, BLASLONG, BLASLONG, float, bfloat16 *, bfloat16 *, float *, BLASLONG); int sgemm_kernel(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG); int dgemm_kernel(BLASLONG, BLASLONG, BLASLONG, double, double *, double *, double *, BLASLONG); @@ -657,6 +663,11 @@ int cgemm3m_kernel(BLASLONG, BLASLONG, BLASLONG, float, float, float *, float int zgemm3m_kernel(BLASLONG, BLASLONG, BLASLONG, double, double, double *, double *, double *, BLASLONG); int xgemm3m_kernel(BLASLONG, BLASLONG, BLASLONG, xdouble, xdouble, xdouble *, xdouble *, xdouble *, BLASLONG); +int shgemm_nn(blas_arg_t *, BLASLONG *, BLASLONG *, hfloat16 *, hfloat16 *, BLASLONG); +int shgemm_nt(blas_arg_t *, BLASLONG *, BLASLONG *, hfloat16 *, hfloat16 *, BLASLONG); +int shgemm_tn(blas_arg_t *, BLASLONG *, BLASLONG *, hfloat16 *, hfloat16 *, BLASLONG); +int shgemm_tt(blas_arg_t *, BLASLONG *, BLASLONG *, hfloat16 *, hfloat16 *, BLASLONG); + int sbgemm_nn(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG); int sbgemm_nt(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG); int sbgemm_tn(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG); @@ -754,6 +765,11 @@ int xgemm_cr(blas_arg_t *, BLASLONG *, BLASLONG *, xdouble *, xdouble *, BLASLON int xgemm_cc(blas_arg_t *, BLASLONG *, BLASLONG *, xdouble *, xdouble *, BLASLONG); #endif +int shgemm_thread_nn(blas_arg_t *, BLASLONG *, BLASLONG *, hfloat16 *, hfloat16 *, BLASLONG); +int shgemm_thread_nt(blas_arg_t *, BLASLONG *, BLASLONG *, hfloat16 *, hfloat16 *, BLASLONG); +int shgemm_thread_tn(blas_arg_t *, BLASLONG *, BLASLONG *, hfloat16 *, hfloat16 *, BLASLONG); +int shgemm_thread_tt(blas_arg_t *, BLASLONG *, BLASLONG *, hfloat16 *, hfloat16 *, BLASLONG); + int sbgemm_thread_nn(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG); int sbgemm_thread_nt(blas_arg_t *, BLASLONG *, 
BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG); int sbgemm_thread_tn(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG); @@ -1944,6 +1960,7 @@ int dgemm_batch_thread(blas_arg_t * queue, BLASLONG nums); int cgemm_batch_thread(blas_arg_t * queue, BLASLONG nums); int zgemm_batch_thread(blas_arg_t * queue, BLASLONG nums); int sbgemm_batch_thread(blas_arg_t * queue, BLASLONG nums); +// int shgemm_batch_thread(blas_arg_t * queue, BLASLONG nums); #ifdef __CUDACC__ } diff --git a/common_macro.h b/common_macro.h index 820cb472a6..b967c2e603 100644 --- a/common_macro.h +++ b/common_macro.h @@ -39,6 +39,7 @@ #ifndef COMMON_MACRO #define COMMON_MACRO +#include "common_sh.h" #include "common_sb.h" #include "common_s.h" #include "common_d.h" @@ -656,6 +657,50 @@ #define GEMM_SMALL_KERNEL_B0_NT DGEMM_SMALL_KERNEL_B0_NT #define GEMM_SMALL_KERNEL_B0_TN DGEMM_SMALL_KERNEL_B0_TN #define GEMM_SMALL_KERNEL_B0_TT DGEMM_SMALL_KERNEL_B0_TT +#elif defined(HFLOAT16) +#define GEMM_BETA SHGEMM_BETA +#define GEMM_KERNEL_N SHGEMM_KERNEL +#define GEMM_KERNEL_L SHGEMM_KERNEL +#define GEMM_KERNEL_R SHGEMM_KERNEL +#define GEMM_KERNEL_B SHGEMM_KERNEL +#define GEMM_NN SHGEMM_NN +#define GEMM_CN SHGEMM_TN +#define GEMM_TN SHGEMM_TN +#define GEMM_NC SHGEMM_NT +#define GEMM_NT SHGEMM_NT +#define GEMM_CC SHGEMM_TT +#define GEMM_CT SHGEMM_TT +#define GEMM_TC SHGEMM_TT +#define GEMM_TT SHGEMM_TT +#define GEMM_NR SHGEMM_NN +#define GEMM_TR SHGEMM_TN +#define GEMM_CR SHGEMM_TN +#define GEMM_RN SHGEMM_NN +#define GEMM_RT SHGEMM_NT +#define GEMM_RC SHGEMM_NT +#define GEMM_RR SHGEMM_NN +#define GEMM_ONCOPY SHGEMM_ONCOPY +#define GEMM_OTCOPY SHGEMM_OTCOPY +#define GEMM_INCOPY SHGEMM_INCOPY +#define GEMM_ITCOPY SHGEMM_ITCOPY + +#define GEMM_THREAD_NN SHGEMM_THREAD_NN +#define GEMM_THREAD_CN SHGEMM_THREAD_TN +#define GEMM_THREAD_TN SHGEMM_THREAD_TN +#define GEMM_THREAD_NC SHGEMM_THREAD_NT +#define GEMM_THREAD_NT SHGEMM_THREAD_NT +#define GEMM_THREAD_CC SHGEMM_THREAD_TT +#define GEMM_THREAD_CT 
SHGEMM_THREAD_TT +#define GEMM_THREAD_TC SHGEMM_THREAD_TT +#define GEMM_THREAD_TT SHGEMM_THREAD_TT +#define GEMM_THREAD_NR SHGEMM_THREAD_NN +#define GEMM_THREAD_TR SHGEMM_THREAD_TN +#define GEMM_THREAD_CR SHGEMM_THREAD_TN +#define GEMM_THREAD_RN SHGEMM_THREAD_NN +#define GEMM_THREAD_RT SHGEMM_THREAD_NT +#define GEMM_THREAD_RC SHGEMM_THREAD_NT +#define GEMM_THREAD_RR SHGEMM_THREAD_NN + #elif defined(BFLOAT16) diff --git a/common_param.h b/common_param.h index 2d771a27da..f82b73a72b 100644 --- a/common_param.h +++ b/common_param.h @@ -48,6 +48,21 @@ typedef struct { int dtb_entries; int switch_ratio; int offsetA, offsetB, align; +#if BUILD_HFLOAT16 == 1 +int shgemm_p, shgemm_q, shgemm_r; +int shgemm_unroll_m, shgemm_unroll_n, shgemm_unroll_mn; + +int (*shgemm_kernel )(BLASLONG, BLASLONG, BLASLONG, float, hfloat16 *, hfloat16 *, float *, BLASLONG); +int (*shgemm_beta )(BLASLONG, BLASLONG, BLASLONG, float, hfloat16 *, BLASLONG, hfloat16 *, BLASLONG, float *, BLASLONG); + +int (*shgemm_incopy )(BLASLONG, BLASLONG, hfloat16 *, BLASLONG, hfloat16 *); +int (*shgemm_itcopy )(BLASLONG, BLASLONG, hfloat16 *, BLASLONG, hfloat16 *); +int (*shgemm_oncopy )(BLASLONG, BLASLONG, hfloat16 *, BLASLONG, hfloat16 *); +int (*shgemm_otcopy )(BLASLONG, BLASLONG, hfloat16 *, BLASLONG, hfloat16 *); + + +#endif + #if BUILD_BFLOAT16 == 1 int sbgemm_p, sbgemm_q, sbgemm_r; @@ -64,10 +79,10 @@ typedef struct { float (*sbamin_k) (BLASLONG, float *, BLASLONG); float (*sbmax_k) (BLASLONG, float *, BLASLONG); float (*sbmin_k) (BLASLONG, float *, BLASLONG); -BLASLONG (*isbamax_k)(BLASLONG, float *, BLASLONG); -BLASLONG (*isbamin_k)(BLASLONG, float *, BLASLONG); -BLASLONG (*isbmax_k) (BLASLONG, float *, BLASLONG); -BLASLONG (*isbmin_k) (BLASLONG, float *, BLASLONG); + BLASLONG (*isbamax_k)(BLASLONG, float *, BLASLONG); + BLASLONG (*isbamin_k)(BLASLONG, float *, BLASLONG); + BLASLONG (*isbmax_k) (BLASLONG, float *, BLASLONG); + BLASLONG (*isbmin_k) (BLASLONG, float *, BLASLONG); float (*sbnrm2_k) 
(BLASLONG, float *, BLASLONG); float (*sbasum_k) (BLASLONG, float *, BLASLONG); @@ -180,12 +195,12 @@ BLASLONG (*isbmin_k) (BLASLONG, float *, BLASLONG); #endif #if (BUILD_SINGLE==1) || (BUILD_DOUBLE ==1) || (BUILD_COMPLEX==1) -BLASLONG (*isamax_k)(BLASLONG, float *, BLASLONG); + BLASLONG (*isamax_k)(BLASLONG, float *, BLASLONG); #endif #if (BUILD_SINGLE==1) || (BUILD_COMPLEX==1) -BLASLONG (*isamin_k)(BLASLONG, float *, BLASLONG); -BLASLONG (*ismax_k) (BLASLONG, float *, BLASLONG); -BLASLONG (*ismin_k) (BLASLONG, float *, BLASLONG); + BLASLONG (*isamin_k)(BLASLONG, float *, BLASLONG); + BLASLONG (*ismax_k) (BLASLONG, float *, BLASLONG); + BLASLONG (*ismin_k) (BLASLONG, float *, BLASLONG); float (*snrm2_k) (BLASLONG, float *, BLASLONG); float (*sasum_k) (BLASLONG, float *, BLASLONG); #endif @@ -316,10 +331,10 @@ BLASLONG (*ismin_k) (BLASLONG, float *, BLASLONG); double (*damin_k) (BLASLONG, double *, BLASLONG); double (*dmax_k) (BLASLONG, double *, BLASLONG); double (*dmin_k) (BLASLONG, double *, BLASLONG); -BLASLONG (*idamax_k)(BLASLONG, double *, BLASLONG); -BLASLONG (*idamin_k)(BLASLONG, double *, BLASLONG); -BLASLONG (*idmax_k) (BLASLONG, double *, BLASLONG); -BLASLONG (*idmin_k) (BLASLONG, double *, BLASLONG); + BLASLONG (*idamax_k)(BLASLONG, double *, BLASLONG); + BLASLONG (*idamin_k)(BLASLONG, double *, BLASLONG); + BLASLONG (*idmax_k) (BLASLONG, double *, BLASLONG); + BLASLONG (*idmin_k) (BLASLONG, double *, BLASLONG); double (*dnrm2_k) (BLASLONG, double *, BLASLONG); double (*dasum_k) (BLASLONG, double *, BLASLONG); @@ -435,10 +450,10 @@ BLASLONG (*idmin_k) (BLASLONG, double *, BLASLONG); xdouble (*qamin_k) (BLASLONG, xdouble *, BLASLONG); xdouble (*qmax_k) (BLASLONG, xdouble *, BLASLONG); xdouble (*qmin_k) (BLASLONG, xdouble *, BLASLONG); -BLASLONG (*iqamax_k)(BLASLONG, xdouble *, BLASLONG); -BLASLONG (*iqamin_k)(BLASLONG, xdouble *, BLASLONG); -BLASLONG (*iqmax_k) (BLASLONG, xdouble *, BLASLONG); -BLASLONG (*iqmin_k) (BLASLONG, xdouble *, BLASLONG); + 
BLASLONG (*iqamax_k)(BLASLONG, xdouble *, BLASLONG); + BLASLONG (*iqamin_k)(BLASLONG, xdouble *, BLASLONG); + BLASLONG (*iqmax_k) (BLASLONG, xdouble *, BLASLONG); + BLASLONG (*iqmin_k) (BLASLONG, xdouble *, BLASLONG); xdouble (*qnrm2_k) (BLASLONG, xdouble *, BLASLONG); xdouble (*qasum_k) (BLASLONG, xdouble *, BLASLONG); @@ -528,8 +543,8 @@ BLASLONG (*iqmin_k) (BLASLONG, xdouble *, BLASLONG); float (*camax_k) (BLASLONG, float *, BLASLONG); float (*camin_k) (BLASLONG, float *, BLASLONG); -BLASLONG (*icamax_k)(BLASLONG, float *, BLASLONG); -BLASLONG (*icamin_k)(BLASLONG, float *, BLASLONG); + BLASLONG (*icamax_k)(BLASLONG, float *, BLASLONG); + BLASLONG (*icamin_k)(BLASLONG, float *, BLASLONG); float (*cnrm2_k) (BLASLONG, float *, BLASLONG); float (*casum_k) (BLASLONG, float *, BLASLONG); @@ -739,8 +754,8 @@ BLASLONG (*icamin_k)(BLASLONG, float *, BLASLONG); double (*zamax_k) (BLASLONG, double *, BLASLONG); double (*zamin_k) (BLASLONG, double *, BLASLONG); -BLASLONG (*izamax_k)(BLASLONG, double *, BLASLONG); -BLASLONG (*izamin_k)(BLASLONG, double *, BLASLONG); + BLASLONG (*izamax_k)(BLASLONG, double *, BLASLONG); + BLASLONG (*izamin_k)(BLASLONG, double *, BLASLONG); double (*znrm2_k) (BLASLONG, double *, BLASLONG); double (*zasum_k) (BLASLONG, double *, BLASLONG); @@ -950,8 +965,8 @@ BLASLONG (*izamin_k)(BLASLONG, double *, BLASLONG); xdouble (*xamax_k) (BLASLONG, xdouble *, BLASLONG); xdouble (*xamin_k) (BLASLONG, xdouble *, BLASLONG); -BLASLONG (*ixamax_k)(BLASLONG, xdouble *, BLASLONG); -BLASLONG (*ixamin_k)(BLASLONG, xdouble *, BLASLONG); + BLASLONG (*ixamax_k)(BLASLONG, xdouble *, BLASLONG); + BLASLONG (*ixamin_k)(BLASLONG, xdouble *, BLASLONG); xdouble (*xnrm2_k) (BLASLONG, xdouble *, BLASLONG); xdouble (*xasum_k) (BLASLONG, xdouble *, BLASLONG); @@ -1229,6 +1244,15 @@ extern gotoblas_t *gotoblas; #define HAVE_EX_L2 gotoblas -> exclusive_cache +#if (BUILD_HFLOAT16==1) +#define SHGEMM_P gotoblas -> shgemm_p +#define SHGEMM_Q gotoblas -> shgemm_q +#define SHGEMM_R 
gotoblas -> shgemm_r +#define SHGEMM_UNROLL_M gotoblas -> shgemm_unroll_m +#define SHGEMM_UNROLL_N gotoblas -> shgemm_unroll_n +#define SHGEMM_UNROLL_MN gotoblas -> shgemm_unroll_mn +#endif + #if (BUILD_BFLOAT16==1) #define SBGEMM_P gotoblas -> sbgemm_p #define SBGEMM_Q gotoblas -> sbgemm_q @@ -1357,6 +1381,19 @@ extern gotoblas_t *gotoblas; #define HAVE_EX_L2 0 #endif +#if (BUILD_HFLOAT16 == 1) +#define SHGEMM_P SHGEMM_DEFAULT_P +#define SHGEMM_Q SHGEMM_DEFAULT_Q +#define SHGEMM_R SHGEMM_DEFAULT_R +#define SHGEMM_UNROLL_M SHGEMM_DEFAULT_UNROLL_M +#define SHGEMM_UNROLL_N SHGEMM_DEFAULT_UNROLL_N +#ifdef SHGEMM_DEFAULT_UNROLL_MN +#define SHGEMM_UNROLL_MN SHGEMM_DEFAULT_UNROLL_MN +#else +#define SHGEMM_UNROLL_MN MAX((SHGEMM_UNROLL_M), (SHGEMM_UNROLL_N)) +#endif +#endif + #if (BUILD_BFLOAT16 == 1) #define SBGEMM_P SBGEMM_DEFAULT_P #define SBGEMM_Q SBGEMM_DEFAULT_Q @@ -1478,6 +1515,7 @@ extern gotoblas_t *gotoblas; #endif + #endif #ifndef COMPLEX @@ -1505,6 +1543,18 @@ extern gotoblas_t *gotoblas; #define GEMM_DEFAULT_R DGEMM_DEFAULT_R #define GEMM_DEFAULT_UNROLL_M DGEMM_DEFAULT_UNROLL_M #define GEMM_DEFAULT_UNROLL_N DGEMM_DEFAULT_UNROLL_N +#elif defined(HFLOAT16) +#define GEMM_P SHGEMM_P +#define GEMM_Q SHGEMM_Q +#define GEMM_R SHGEMM_R +#define GEMM_UNROLL_M SHGEMM_UNROLL_M +#define GEMM_UNROLL_N SHGEMM_UNROLL_N +#define GEMM_UNROLL_MN SHGEMM_UNROLL_MN +#define GEMM_DEFAULT_P SHGEMM_DEFAULT_P +#define GEMM_DEFAULT_Q SHGEMM_DEFAULT_Q +#define GEMM_DEFAULT_R SHGEMM_DEFAULT_R +#define GEMM_DEFAULT_UNROLL_M SHGEMM_DEFAULT_UNROLL_M +#define GEMM_DEFAULT_UNROLL_N SHGEMM_DEFAULT_UNROLL_N #elif defined(BFLOAT16) #define GEMM_P SBGEMM_P #define GEMM_Q SBGEMM_Q diff --git a/common_sh.h b/common_sh.h new file mode 100644 index 0000000000..69734d1dc2 --- /dev/null +++ b/common_sh.h @@ -0,0 +1,72 @@ +#ifndef COMMON_SH_H +#define COMMON_SH_H + +#ifndef DYNAMIC_ARCH + +#define SHGEMM_ONCOPY shgemm_oncopy +#define SHGEMM_OTCOPY shgemm_otcopy + +#if SHGEMM_DEFAULT_UNROLL_M == 
SHGEMM_DEFAULT_UNROLL_N +#define SHGEMM_INCOPY shgemm_oncopy +#define SHGEMM_ITCOPY shgemm_otcopy +#else +#define SHGEMM_INCOPY shgemm_incopy +#define SHGEMM_ITCOPY shgemm_itcopy +#endif + +#define SHGEMM_BETA shgemm_beta +#define SHGEMM_KERNEL shgemm_kernel + + +#else // #DYNAMIC_ARCH + +#define SHGEMM_ONCOPY gotoblas -> shgemm_oncopy +#define SHGEMM_OTCOPY gotoblas -> shgemm_otcopy +#if SHGEMM_DEFAULT_UNROLL_M == SHGEMM_DEFAULT_UNROLL_N +#define SHGEMM_INCOPY gotoblas -> shgemm_oncopy +#define SHGEMM_ITCOPY gotoblas -> shgemm_otcopy +#else +#define SHGEMM_INCOPY gotoblas -> shgemm_incopy +#define SHGEMM_ITCOPY gotoblas -> shgemm_itcopy +#endif + +#define SHGEMM_BETA gotoblas -> shgemm_beta +#define SHGEMM_KERNEL gotoblas -> shgemm_kernel +#endif // #DYNAMIC_ARCH + +#define SHGEMM_NN shgemm_nn +#define SHGEMM_CN shgemm_tn +#define SHGEMM_TN shgemm_tn +#define SHGEMM_NC shgemm_nt +#define SHGEMM_NT shgemm_nt +#define SHGEMM_CC shgemm_tt +#define SHGEMM_CT shgemm_tt +#define SHGEMM_TC shgemm_tt +#define SHGEMM_TT shgemm_tt +#define SHGEMM_NR shgemm_nn +#define SHGEMM_TR shgemm_tn +#define SHGEMM_CR shgemm_tn +#define SHGEMM_RN shgemm_nn +#define SHGEMM_RT shgemm_nt +#define SHGEMM_RC shgemm_nt +#define SHGEMM_RR shgemm_nn + +#define SHGEMM_THREAD_NN shgemm_thread_nn +#define SHGEMM_THREAD_CN shgemm_thread_tn +#define SHGEMM_THREAD_TN shgemm_thread_tn +#define SHGEMM_THREAD_NC shgemm_thread_nt +#define SHGEMM_THREAD_NT shgemm_thread_nt +#define SHGEMM_THREAD_CC shgemm_thread_tt +#define SHGEMM_THREAD_CT shgemm_thread_tt +#define SHGEMM_THREAD_TC shgemm_thread_tt +#define SHGEMM_THREAD_TT shgemm_thread_tt +#define SHGEMM_THREAD_NR shgemm_thread_nn +#define SHGEMM_THREAD_TR shgemm_thread_tn +#define SHGEMM_THREAD_CR shgemm_thread_tn +#define SHGEMM_THREAD_RN shgemm_thread_nn +#define SHGEMM_THREAD_RT shgemm_thread_nt +#define SHGEMM_THREAD_RC shgemm_thread_nt +#define SHGEMM_THREAD_RR shgemm_thread_nn + + +#endif // #COMMON_SH_H diff --git 
a/driver/level3/CMakeLists.txt b/driver/level3/CMakeLists.txt index eabfeed24a..b4c6125315 100644 --- a/driver/level3/CMakeLists.txt +++ b/driver/level3/CMakeLists.txt @@ -18,6 +18,12 @@ foreach (GEMM_DEFINE ${GEMM_DEFINES}) GenerateNamedObjects("gemm.c" "${GEMM_DEFINE};THREADED_LEVEL3" "gemm_thread_${GEMM_DEFINE_LC}" 0 "" "" false "BFLOAT16") endif () endif () + if (BUILD_HFLOAT16) + GenerateNamedObjects("gemm.c" "${GEMM_DEFINE}" "gemm_${GEMM_DEFINE_LC}" 0 "" "" false "HFLOAT16") + if (USE_THREAD AND NOT USE_SIMPLE_THREADED_LEVEL3) + GenerateNamedObjects("gemm.c" "${GEMM_DEFINE};THREADED_LEVEL3" "gemm_thread_${GEMM_DEFINE_LC}" 0 "" "" false "HFLOAT16") + endif () + endif () endforeach () if ( BUILD_COMPLEX16 AND NOT BUILD_DOUBLE) diff --git a/driver/level3/Makefile b/driver/level3/Makefile index c304838423..b0d1d6b623 100644 --- a/driver/level3/Makefile +++ b/driver/level3/Makefile @@ -23,6 +23,10 @@ ifeq ($(BUILD_BFLOAT16),1) SBBLASOBJS += sbgemm_nn.$(SUFFIX) sbgemm_nt.$(SUFFIX) sbgemm_tn.$(SUFFIX) sbgemm_tt.$(SUFFIX) endif +ifeq ($(BUILD_HFLOAT16),1) +SHBLASOBJS += shgemm_nn.$(SUFFIX) shgemm_nt.$(SUFFIX) shgemm_tn.$(SUFFIX) shgemm_tt.$(SUFFIX) +endif + SBLASOBJS += \ sgemm_nn.$(SUFFIX) sgemm_nt.$(SUFFIX) sgemm_tn.$(SUFFIX) sgemm_tt.$(SUFFIX) \ strmm_LNUU.$(SUFFIX) strmm_LNUN.$(SUFFIX) strmm_LNLU.$(SUFFIX) strmm_LNLN.$(SUFFIX) \ @@ -210,6 +214,9 @@ ifneq ($(USE_SIMPLE_THREADED_LEVEL3), 1) ifeq ($(BUILD_BFLOAT16),1) SBBLASOBJS += sbgemm_thread_nn.$(SUFFIX) sbgemm_thread_nt.$(SUFFIX) sbgemm_thread_tn.$(SUFFIX) sbgemm_thread_tt.$(SUFFIX) endif +ifeq ($(BUILD_HFLOAT16),1) +SHBLASOBJS += shgemm_thread_nn.$(SUFFIX) shgemm_thread_nt.$(SUFFIX) shgemm_thread_tn.$(SUFFIX) shgemm_thread_tt.$(SUFFIX) +endif SBLASOBJS += sgemm_thread_nn.$(SUFFIX) sgemm_thread_nt.$(SUFFIX) sgemm_thread_tn.$(SUFFIX) sgemm_thread_tt.$(SUFFIX) DBLASOBJS += dgemm_thread_nn.$(SUFFIX) dgemm_thread_nt.$(SUFFIX) dgemm_thread_tn.$(SUFFIX) dgemm_thread_tt.$(SUFFIX) QBLASOBJS += qgemm_thread_nn.$(SUFFIX) 
qgemm_thread_nt.$(SUFFIX) qgemm_thread_tn.$(SUFFIX) qgemm_thread_tt.$(SUFFIX) @@ -355,6 +362,18 @@ sbgemm_tn.$(SUFFIX) : gemm.c level3.c ../../param.h sbgemm_tt.$(SUFFIX) : gemm.c level3.c ../../param.h $(CC) $(CFLAGS) $(BLOCKS) -c -DHALF -UDOUBLE -UCOMPLEX -DTT $< -o $(@F) +shgemm_nn.$(SUFFIX) : gemm.c level3.c ../../param.h + $(CC) $(CFLAGS) $(BLOCKS) -c -DHFLOAT16 -UDOUBLE -UCOMPLEX -DNN $< -o $(@F) + +shgemm_nt.$(SUFFIX) : gemm.c level3.c ../../param.h + $(CC) $(CFLAGS) $(BLOCKS) -c -DHFLOAT16 -UDOUBLE -UCOMPLEX -DNT $< -o $(@F) + +shgemm_tn.$(SUFFIX) : gemm.c level3.c ../../param.h + $(CC) $(CFLAGS) $(BLOCKS) -c -DHFLOAT16 -UDOUBLE -UCOMPLEX -DTN $< -o $(@F) + +shgemm_tt.$(SUFFIX) : gemm.c level3.c ../../param.h + $(CC) $(CFLAGS) $(BLOCKS) -c -DHFLOAT16 -UDOUBLE -UCOMPLEX -DTT $< -o $(@F) + sgemm_nn.$(SUFFIX) : gemm.c level3.c ../../param.h $(CC) $(CFLAGS) $(BLOCKS) -c -UDOUBLE -UCOMPLEX -DNN $< -o $(@F) @@ -562,6 +581,18 @@ sbgemm_thread_tn.$(SUFFIX) : gemm.c level3_thread.c ../../param.h sbgemm_thread_tt.$(SUFFIX) : gemm.c level3_thread.c ../../param.h $(CC) $(CFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHALF -UDOUBLE -UCOMPLEX -DTT $< -o $(@F) +shgemm_thread_nn.$(SUFFIX) : gemm.c level3_thread.c ../../param.h + $(CC) $(CFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHFLOAT16 -UDOUBLE -UCOMPLEX -DNN $< -o $(@F) + +shgemm_thread_nt.$(SUFFIX) : gemm.c level3_thread.c ../../param.h + $(CC) $(CFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHFLOAT16 -UDOUBLE -UCOMPLEX -DNT $< -o $(@F) + +shgemm_thread_tn.$(SUFFIX) : gemm.c level3_thread.c ../../param.h + $(CC) $(CFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHFLOAT16 -UDOUBLE -UCOMPLEX -DTN $< -o $(@F) + +shgemm_thread_tt.$(SUFFIX) : gemm.c level3_thread.c ../../param.h + $(CC) $(CFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHFLOAT16 -UDOUBLE -UCOMPLEX -DTT $< -o $(@F) + sgemm_thread_nn.$(SUFFIX) : gemm.c level3_thread.c ../../param.h $(CC) $(CFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -UDOUBLE -UCOMPLEX -DNN $< -o $(@F) @@ -2747,6 +2778,18 @@ 
sbgemm_tn.$(PSUFFIX) : gemm.c level3.c ../../param.h sbgemm_tt.$(PSUFFIX) : gemm.c level3.c ../../param.h $(CC) $(PFLAGS) $(BLOCKS) -c -DHALF -UDOUBLE -UCOMPLEX -DTT $< -o $(@F) +shgemm_nn.$(PSUFFIX) : gemm.c level3.c ../../param.h + $(CC) $(PFLAGS) $(BLOCKS) -c -DHFLOAT16 -UDOUBLE -UCOMPLEX -DNN $< -o $(@F) + +shgemm_nt.$(PSUFFIX) : gemm.c level3.c ../../param.h + $(CC) $(PFLAGS) $(BLOCKS) -c -DHFLOAT16 -UDOUBLE -UCOMPLEX -DNT $< -o $(@F) + +shgemm_tn.$(PSUFFIX) : gemm.c level3.c ../../param.h + $(CC) $(PFLAGS) $(BLOCKS) -c -DHFLOAT16 -UDOUBLE -UCOMPLEX -DTN $< -o $(@F) + +shgemm_tt.$(PSUFFIX) : gemm.c level3.c ../../param.h + $(CC) $(PFLAGS) $(BLOCKS) -c -DHFLOAT16 -UDOUBLE -UCOMPLEX -DTT $< -o $(@F) + sgemm_nn.$(PSUFFIX) : gemm.c level3.c ../../param.h $(CC) $(PFLAGS) $(BLOCKS) -c -UDOUBLE -UCOMPLEX -DNN $< -o $(@F) @@ -2970,6 +3013,18 @@ sbgemm_thread_tn.$(PSUFFIX) : gemm.c level3_thread.c ../../param.h sbgemm_thread_tt.$(PSUFFIX) : gemm.c level3_thread.c ../../param.h $(CC) $(PFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHALF -UDOUBLE -UCOMPLEX -DTT $< -o $(@F) +shgemm_thread_nn.$(PSUFFIX) : gemm.c level3_thread.c ../../param.h + $(CC) $(PFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHFLOAT16 -UDOUBLE -UCOMPLEX -DNN $< -o $(@F) + +shgemm_thread_nt.$(PSUFFIX) : gemm.c level3_thread.c ../../param.h + $(CC) $(PFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHFLOAT16 -UDOUBLE -UCOMPLEX -DNT $< -o $(@F) + +shgemm_thread_tn.$(PSUFFIX) : gemm.c level3_thread.c ../../param.h + $(CC) $(PFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHFLOAT16 -UDOUBLE -UCOMPLEX -DTN $< -o $(@F) + +shgemm_thread_tt.$(PSUFFIX) : gemm.c level3_thread.c ../../param.h + $(CC) $(PFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHFLOAT16 -UDOUBLE -UCOMPLEX -DTT $< -o $(@F) + sgemm_thread_nn.$(PSUFFIX) : gemm.c level3_thread.c ../../param.h $(CC) $(PFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -UDOUBLE -UCOMPLEX -DNN $< -o $(@F) diff --git a/driver/others/Makefile b/driver/others/Makefile index 719d617c45..6a53638bc1 100644 --- 
a/driver/others/Makefile +++ b/driver/others/Makefile @@ -218,7 +218,7 @@ mulx.$(SUFFIX) : $(ARCH)/mulx.c $(CC) $(CFLAGS) -c -DXDOUBLE -UCOMPLEX $< -o $(@F) detect_riscv64.$(SUFFIX): detect_riscv64.c - $(CC) $(CFLAGS) -c -march=rv64imafdcv $< -o $(@F) + $(CC) $(CFLAGS) -c -march=rv64imafdcv_zvfh_zfh $< -o $(@F) xerbla.$(PSUFFIX) : xerbla.c $(CC) $(PFLAGS) -c $< -o $(@F) diff --git a/driver/others/parameter.c b/driver/others/parameter.c index 597e5cac7e..3bcb0d4343 100644 --- a/driver/others/parameter.c +++ b/driver/others/parameter.c @@ -67,6 +67,11 @@ BLASLONG sbgemm_p = DEFAULT_GEMM_P; #else BLASLONG sbgemm_p = SBGEMM_P; #endif +#if SHGEMM_P == shgemm_p +BLASLONG shgemm_p = DEFAULT_GEMM_P; +#else +BLASLONG shgemm_p = SHGEMM_P; +#endif #if SGEMM_P == sgemm_p BLASLONG sgemm_p = DEFAULT_GEMM_P; #else @@ -93,6 +98,11 @@ BLASLONG sbgemm_q = DEFAULT_GEMM_Q; #else BLASLONG sbgemm_q = SBGEMM_Q; #endif +#if SHGEMM_Q == shgemm_q +BLASLONG shgemm_q = DEFAULT_GEMM_Q; +#else +BLASLONG shgemm_q = SHGEMM_Q; +#endif #if SGEMM_Q == sgemm_q BLASLONG sgemm_q = DEFAULT_GEMM_Q; #else @@ -119,6 +129,11 @@ BLASLONG sbgemm_r = DEFAULT_GEMM_R; #else BLASLONG sbgemm_r = SBGEMM_R; #endif +#if SHGEMM_R == shgemm_r +BLASLONG shgemm_r = DEFAULT_GEMM_R; +#else +BLASLONG shgemm_r = SHGEMM_R; +#endif #if SGEMM_R == sgemm_r BLASLONG sgemm_r = DEFAULT_GEMM_R; #else @@ -526,6 +541,9 @@ void blas_set_parameter(void){ #ifdef BUILD_BFLOAT16 sbgemm_r = (((BUFFER_SIZE - ((SBGEMM_P * SBGEMM_Q * 4 + GEMM_OFFSET_A + GEMM_ALIGN) & ~GEMM_ALIGN)) / (SBGEMM_Q * 4)) - 15) & ~15; +#endif +#ifdef BUILD_HFLOAT16 + shgemm_r = (((BUFFER_SIZE - ((SHGEMM_P * SHGEMM_Q * 4 + GEMM_OFFSET_A + GEMM_ALIGN) & ~GEMM_ALIGN)) / (SHGEMM_Q * 4)) - 15) & ~15; #endif sgemm_r = (((BUFFER_SIZE - ((SGEMM_P * SGEMM_Q * 4 + GEMM_OFFSET_A + GEMM_ALIGN) & ~GEMM_ALIGN)) / (SGEMM_Q * 4)) - 15) & ~15; dgemm_r = (((BUFFER_SIZE - ((DGEMM_P * DGEMM_Q * 8 + GEMM_OFFSET_A + GEMM_ALIGN) & ~GEMM_ALIGN)) / (DGEMM_Q * 8)) - 15) & ~15; @@ -619,6 
+637,7 @@ void blas_set_parameter(void){ size = BITMASK(cpuid3, 16, 0xff); sbgemm_p = 192 * (size + 1); + shgemm_p = 192 * (size + 1); sgemm_p = 192 * (size + 1); dgemm_p = 96 * (size + 1); cgemm_p = 96 * (size + 1); @@ -634,6 +653,9 @@ void blas_set_parameter(void){ #ifdef BUILD_BFLOAT16 sbgemm_r = (((BUFFER_SIZE - ((SBGEMM_P * SBGEMM_Q * 4 + GEMM_OFFSET_A + GEMM_ALIGN) & ~GEMM_ALIGN)) / (SBGEMM_Q * 4)) - 15) & ~15; +#endif +#ifdef BUILD_HFLOAT16 + shgemm_r = (((BUFFER_SIZE - ((SHGEMM_P * SHGEMM_Q * 4 + GEMM_OFFSET_A + GEMM_ALIGN) & ~GEMM_ALIGN)) / (SHGEMM_Q * 4)) - 15) & ~15; #endif sgemm_r = (((BUFFER_SIZE - ((SGEMM_P * SGEMM_Q * 4 + GEMM_OFFSET_A + GEMM_ALIGN) & ~GEMM_ALIGN)) / (SGEMM_Q * 4)) - 15) & ~15; dgemm_r = (((BUFFER_SIZE - ((DGEMM_P * DGEMM_Q * 8 + GEMM_OFFSET_A + GEMM_ALIGN) & ~GEMM_ALIGN)) / (DGEMM_Q * 8)) - 15) & ~15; diff --git a/exports/Makefile b/exports/Makefile index 04fc64cfe0..b4b391a197 100644 --- a/exports/Makefile +++ b/exports/Makefile @@ -39,6 +39,9 @@ endif ifndef BUILD_BFLOAT16 BUILD_BFLOAT16 = 0 endif +ifndef BUILD_HFLOAT16 +BUILD_HFLOAT16 = 0 +endif ifndef BUILD_SINGLE BUILD_SINGLE = 0 endif diff --git a/exports/gensymbol b/exports/gensymbol index f747dd091f..231e72f48d 100755 --- a/exports/gensymbol +++ b/exports/gensymbol @@ -52,6 +52,7 @@ blasobjsz=" blasobjs="lsame xerbla" bfblasobjs="sbgemm sbgemmt sbgemmtr sbgemv sbdot sbstobf16 sbdtobf16 sbf16tos dbf16tod" +hfblasobjs="shgemm" cblasobjsc=" cblas_caxpy cblas_ccopy cblas_cdotc cblas_cdotu cblas_cgbmv cblas_cgemm cblas_cgemv cblas_cgerc cblas_cgeru cblas_chbmv cblas_chemm cblas_chemv cblas_cher2 cblas_cher2k @@ -100,6 +101,7 @@ cblasobjsz=" cblasobjs="cblas_xerbla" bfcblasobjs="cblas_sbgemm cblas_sbgemv cblas_sbdot cblas_sbstobf16 cblas_sbdtobf16 cblas_sbf16tos cblas_dbf16tod cblas_sbgemm_batch" +hfcblasobjs="cblas_shgemm" exblasobjs=" qamax qamin qasum qaxpy qcabs1 qcopy qdot qgbmv qgemm @@ -3816,8 +3818,8 @@ shift p17=$9 if [ $p13 -eq 1 ]; then - blasobjs="$blasobjs 
$bfblasobjs" - cblasobjs="$cblasobjs $bfcblasobjs" + blasobjs="$blasobjs $bfblasobjs $hfblasobjs" + cblasobjs="$cblasobjs $bfcblasobjs $hfcblasobjs" fi if [ $p14 -eq 1 ]; then diff --git a/exports/gensymbol.pl b/exports/gensymbol.pl index 5597306343..1c4e912f2b 100644 --- a/exports/gensymbol.pl +++ b/exports/gensymbol.pl @@ -52,6 +52,7 @@ @blasobjs = (lsame, xerbla); @bfblasobjs = (sbgemm, sbgemmt, sbgemmtr, sbgemv, sbdot, sbstobf16, sbdtobf16, sbf16tos, dbf16tod); +@hfblasobjs = (shgemm); @cblasobjsc = ( cblas_caxpy, cblas_ccopy, cblas_cdotc, cblas_cdotu, cblas_cgbmv, cblas_cgemm, cblas_cgemv, cblas_cgerc, cblas_cgeru, cblas_chbmv, cblas_chemm, cblas_chemv, cblas_cher2, cblas_cher2k, @@ -97,7 +98,7 @@ @cblasobjs = ( cblas_xerbla ); @bfcblasobjs = (cblas_sbgemm, cblas_sbgemmt, cblas_sbgemmtr, cblas_sbgemv, cblas_sbdot, cblas_sbstobf16, cblas_sbdtobf16, cblas_sbf16tos, cblas_dbf16tod, cblas_sbgemm_batch); - +@hfcblasobjs = (cblas_shgemm); @exblasobjs = ( qamax,qamin,qasum,qaxpy,qcabs1,qcopy,qdot,qgbmv,qgemm, qgemv,qger,qmax,qmin, @@ -3773,8 +3774,8 @@ my $dirname = File::Spec->catfile(dirname(dirname(File::Spec->rel2abs(__FILE__))), "lapack-netlib"); if ($ARGV[12] == 1) { - @blasobjs = (@blasobjs, @bfblasobjs); - @cblasobjs = (@cblasobjs, @bfcblasobjs); + @blasobjs = (@blasobjs, @bfblasobjs, @hfblasobjs); + @cblasobjs = (@cblasobjs, @bfcblasobjs, @hfcblasobjs); } if ($ARGV[13] == 1) { @blasobjs = (@blasobjs, @blasobjss); diff --git a/getarch_2nd.c b/getarch_2nd.c index dd1f830895..8170e9cf33 100644 --- a/getarch_2nd.c +++ b/getarch_2nd.c @@ -19,6 +19,8 @@ int main(int argc, char **argv) { if ( (argc <= 1) || ((argc >= 2) && (*argv[1] == '0'))) { printf("SBGEMM_UNROLL_M=%d\n", SBGEMM_DEFAULT_UNROLL_M); printf("SBGEMM_UNROLL_N=%d\n", SBGEMM_DEFAULT_UNROLL_N); + printf("SHGEMM_UNROLL_M=%d\n", SHGEMM_DEFAULT_UNROLL_M); + printf("SHGEMM_UNROLL_N=%d\n", SHGEMM_DEFAULT_UNROLL_N); printf("SGEMM_UNROLL_M=%d\n", SGEMM_DEFAULT_UNROLL_M); printf("SGEMM_UNROLL_N=%d\n", 
SGEMM_DEFAULT_UNROLL_N); printf("DGEMM_UNROLL_M=%d\n", DGEMM_DEFAULT_UNROLL_M); diff --git a/interface/CMakeLists.txt b/interface/CMakeLists.txt index a3ee6559e9..b4c1b769d7 100644 --- a/interface/CMakeLists.txt +++ b/interface/CMakeLists.txt @@ -136,6 +136,9 @@ if (BUILD_BFLOAT16) GenerateNamedObjects("gemm_batch.c" "" "sbgemm_batch" ${CBLAS_FLAG} "" "" true "BFLOAT16") endif () endif () +if (BUILD_HFLOAT16) + GenerateNamedObjects("gemm.c" "" "shgemm" ${CBLAS_FLAG} "" "" true "HFLOAT16") +endif () # complex-specific sources foreach (float_type ${FLOAT_TYPES}) diff --git a/interface/Makefile b/interface/Makefile index f09a6f46b9..3ac54628f6 100644 --- a/interface/Makefile +++ b/interface/Makefile @@ -53,6 +53,10 @@ SBBLAS3OBJS = sbgemm.$(SUFFIX) sbgemmt.$(SUFFIX) sbgemmtr.$(SUFFIX) SBEXTOBJS = sbstobf16.$(SUFFIX) sbdtobf16.$(SUFFIX) sbf16tos.$(SUFFIX) dbf16tod.$(SUFFIX) endif +ifeq ($(BUILD_HFLOAT16),1) +SHBLAS3OBJS = shgemm.$(SUFFIX) +endif + DBLAS1OBJS = \ daxpy.$(SUFFIX) dswap.$(SUFFIX) \ dcopy.$(SUFFIX) dscal.$(SUFFIX) \ @@ -291,6 +295,10 @@ CSBBLAS3OBJS = cblas_sbgemm.$(SUFFIX) cblas_sbgemmt.$(SUFFIX) cblas_sbgemmtr.$(S CSBEXTOBJS = cblas_sbstobf16.$(SUFFIX) cblas_sbdtobf16.$(SUFFIX) cblas_sbf16tos.$(SUFFIX) cblas_dbf16tod.$(SUFFIX) endif +ifeq ($(BUILD_HFLOAT16),1) +CSHBLAS3OBJS = cblas_shgemm.$(SUFFIX) +endif + CDBLAS1OBJS = \ cblas_idamax.$(SUFFIX) cblas_idamin.$(SUFFIX) cblas_dasum.$(SUFFIX) cblas_daxpy.$(SUFFIX) \ cblas_dcopy.$(SUFFIX) cblas_ddot.$(SUFFIX) \ @@ -388,6 +396,7 @@ SBLAS3OBJS += $(CSBLAS3OBJS) SBBLAS1OBJS += $(CSBBLAS1OBJS) SBBLAS2OBJS += $(CSBBLAS2OBJS) SBBLAS3OBJS += $(CSBBLAS3OBJS) +SHBLAS3OBJS += $(CSHBLAS3OBJS) DBLAS1OBJS += $(CDBLAS1OBJS) DBLAS2OBJS += $(CDBLAS2OBJS) DBLAS3OBJS += $(CDBLAS3OBJS) @@ -405,6 +414,7 @@ endif SBLASOBJS = $(SBLAS1OBJS) $(SBLAS2OBJS) $(SBLAS3OBJS) SBBLASOBJS = $(SBBLAS1OBJS) $(SBBLAS2OBJS) $(SBBLAS3OBJS) +SHBLASOBJS = $(SHBLAS3OBJS) DBLASOBJS = $(DBLAS1OBJS) $(DBLAS2OBJS) $(DBLAS3OBJS) QBLASOBJS = 
$(QBLAS1OBJS) $(QBLAS2OBJS) $(QBLAS3OBJS) CBLASOBJS = $(CBLAS1OBJS) $(CBLAS2OBJS) $(CBLAS3OBJS) @@ -512,7 +522,7 @@ ifneq ($(BUILD_COMPLEX16),1) ZBLASOBJS= endif -FUNCOBJS = $(SBEXTOBJS) $(CXERBLAOBJS) $(SBBLASOBJS) $(SBLASOBJS) $(DBLASOBJS) $(CBLASOBJS) $(ZBLASOBJS) +FUNCOBJS = $(SBEXTOBJS) $(CXERBLAOBJS) $(SBBLASOBJS) $(SBLASOBJS) $(DBLASOBJS) $(CBLASOBJS) $(ZBLASOBJS) $(SHBLASOBJS) ifeq ($(EXPRECISION), 1) FUNCOBJS += $(QBLASOBJS) $(XBLASOBJS) @@ -550,7 +560,7 @@ level1 : $(SBEXTOBJS) $(SBBLAS1OBJS) $(SBLAS1OBJS) $(DBLAS1OBJS) $(QBLAS1OBJS) $ level2 : $(SBBLAS2OBJS) $(SBLAS2OBJS) $(DBLAS2OBJS) $(QBLAS2OBJS) $(CBLAS2OBJS) $(ZBLAS2OBJS) $(XBLAS2OBJS) $(AR) $(ARFLAGS) -ru $(TOPDIR)/$(LIBNAME) $^ -level3 : $(SBBLAS3OBJS) $(SBLAS3OBJS) $(DBLAS3OBJS) $(QBLAS3OBJS) $(CBLAS3OBJS) $(ZBLAS3OBJS) $(XBLAS3OBJS) +level3 : $(SBBLAS3OBJS) $(SBLAS3OBJS) $(DBLAS3OBJS) $(QBLAS3OBJS) $(CBLAS3OBJS) $(ZBLAS3OBJS) $(XBLAS3OBJS) $(SHBLAS3OBJS) $(AR) $(ARFLAGS) -ru $(TOPDIR)/$(LIBNAME) $^ aux : $(CBAUXOBJS) @@ -1309,6 +1319,11 @@ sbgemmtr.$(SUFFIX) sbgemmtr.$(PSUFFIX) : sbgemmt.c ../param.h $(CC) -c $(CFLAGS) -DRNAME $< -o $(@F) endif +ifeq ($(BUILD_HFLOAT16),1) +shgemm.$(SUFFIX) shgemm.$(PSUFFIX) : gemm.c ../param.h + $(CC) -c $(CFLAGS) $< -o $(@F) +endif + sgemm.$(SUFFIX) sgemm.$(PSUFFIX) : gemm.c ../param.h $(CC) -c $(CFLAGS) $< -o $(@F) @@ -1968,6 +1983,11 @@ cblas_sbgemm.$(SUFFIX) cblas_sbgemm.$(PSUFFIX) : gemm.c ../param.h $(CC) -DCBLAS -c $(CFLAGS) $< -o $(@F) endif +ifeq ($(BUILD_HFLOAT16),1) +cblas_shgemm.$(SUFFIX) cblas_shgemm.$(PSUFFIX) : gemm.c ../param.h + $(CC) -DCBLAS -c $(CFLAGS) $< -o $(@F) +endif + cblas_dgemm.$(SUFFIX) cblas_dgemm.$(PSUFFIX) : gemm.c ../param.h $(CC) -DCBLAS -c $(CFLAGS) $< -o $(@F) diff --git a/kernel/CMakeLists.txt b/kernel/CMakeLists.txt index 48c8955888..9434f114ea 100644 --- a/kernel/CMakeLists.txt +++ b/kernel/CMakeLists.txt @@ -351,6 +351,22 @@ function (build_core TARGET_CORE KDIR TSUFFIX KERNEL_DEFINITIONS) 
GenerateNamedObjects("${KERNELDIR}/${SBGEMMKERNEL}" "" "gemm_kernel" false "" "" false "BFLOAT16") GenerateNamedObjects("${KERNELDIR}/${SBGEMM_BETA}" "" "gemm_beta" false "" "" false "BFLOAT16") endif () + if (BUILD_HFLOAT16) + if (SHGEMMINCOPY) + GenerateNamedObjects("${KERNELDIR}/${SHGEMMINCOPY}" "" "${SHGEMMINCOPYOBJ}" false "" "" true "HFLOAT16") + endif () + if (SHGEMMITCOPY) + GenerateNamedObjects("${KERNELDIR}/${SHGEMMITCOPY}" "" "${SHGEMMITCOPYOBJ}" false "" "" true "HFLOAT16") + endif () + if (SHGEMMONCOPY) + GenerateNamedObjects("${KERNELDIR}/${SHGEMMONCOPY}" "" "${SHGEMMONCOPYOBJ}" false "" "" true "HFLOAT16") + endif () + if (SHGEMMOTCOPY) + GenerateNamedObjects("${KERNELDIR}/${SHGEMMOTCOPY}" "" "${SHGEMMOTCOPYOBJ}" false "" "" true "HFLOAT16") + endif () + GenerateNamedObjects("${KERNELDIR}/${SHGEMMKERNEL}" "" "gemm_kernel" false "" "" false "HFLOAT16") + GenerateNamedObjects("${KERNELDIR}/${SHGEMM_BETA}" "" "gemm_beta" false "" "" false "HFLOAT16") + endif () foreach (float_type ${FLOAT_TYPES}) string(SUBSTRING ${float_type} 0 1 float_char) if (${float_char}GEMMINCOPY) @@ -769,6 +785,45 @@ endif () GenerateNamedObjects("${KERNELDIR}/${SBGEMM_SMALL_K_B0_TN}" "B0" "gemm_small_kernel_b0_tn" false "" "" false "BFLOAT16") GenerateNamedObjects("${KERNELDIR}/${SBGEMM_SMALL_K_B0_TT}" "B0" "gemm_small_kernel_b0_tt" false "" "" false "BFLOAT16") endif () + + if (BUILD_HFLOAT16) + if (NOT DEFINED SHGEMM_SMALL_M_PERMIT) + set(SHGEMM_SMALL_M_PERMIT ../generic/gemm_small_matrix_permit.c) + endif () + if (NOT DEFINED SHGEMM_SMALL_K_NN) + set(SHGEMM_SMALL_K_NN ../generic/gemm_small_matrix_kernel_nn.c) + endif () + if (NOT DEFINED SHGEMM_SMALL_K_NT) + set(SHGEMM_SMALL_K_NT ../generic/gemm_small_matrix_kernel_nt.c) + endif () + if (NOT DEFINED SHGEMM_SMALL_K_TN) + set(SHGEMM_SMALL_K_TN ../generic/gemm_small_matrix_kernel_tn.c) + endif () + if (NOT DEFINED SHGEMM_SMALL_K_TT) + set(SHGEMM_SMALL_K_TT ../generic/gemm_small_matrix_kernel_tt.c) + endif () + if (NOT DEFINED 
SHGEMM_SMALL_K_B0_NN) + set(SHGEMM_SMALL_K_B0_NN ../generic/gemm_small_matrix_kernel_nn.c) + endif () + if (NOT DEFINED SHGEMM_SMALL_K_B0_NT) + set(SHGEMM_SMALL_K_B0_NT ../generic/gemm_small_matrix_kernel_nt.c) + endif () + if (NOT DEFINED SHGEMM_SMALL_K_B0_TN) + set(SHGEMM_SMALL_K_B0_TN ../generic/gemm_small_matrix_kernel_tn.c) + endif () + if (NOT DEFINED SHGEMM_SMALL_K_B0_TT) + set(SHGEMM_SMALL_K_B0_TT ../generic/gemm_small_matrix_kernel_tt.c) + endif () + GenerateNamedObjects("${KERNELDIR}/${SHGEMM_SMALL_M_PERMIT}" "" "gemm_small_matrix_permit" false "" "" false "HFLOAT16") + GenerateNamedObjects("${KERNELDIR}/${SHGEMM_SMALL_K_NN}" "" "gemm_small_kernel_nn" false "" "" false "HFLOAT16") + GenerateNamedObjects("${KERNELDIR}/${SHGEMM_SMALL_K_NT}" "" "gemm_small_kernel_nt" false "" "" false "HFLOAT16") + GenerateNamedObjects("${KERNELDIR}/${SHGEMM_SMALL_K_TN}" "" "gemm_small_kernel_tn" false "" "" false "HFLOAT16") + GenerateNamedObjects("${KERNELDIR}/${SHGEMM_SMALL_K_TT}" "" "gemm_small_kernel_tt" false "" "" false "HFLOAT16") + GenerateNamedObjects("${KERNELDIR}/${SHGEMM_SMALL_K_B0_NN}" "B0" "gemm_small_kernel_b0_nn" false "" "" false "HFLOAT16") + GenerateNamedObjects("${KERNELDIR}/${SHGEMM_SMALL_K_B0_NT}" "B0" "gemm_small_kernel_b0_nt" false "" "" false "HFLOAT16") + GenerateNamedObjects("${KERNELDIR}/${SHGEMM_SMALL_K_B0_TN}" "B0" "gemm_small_kernel_b0_tn" false "" "" false "HFLOAT16") + GenerateNamedObjects("${KERNELDIR}/${SHGEMM_SMALL_K_B0_TT}" "B0" "gemm_small_kernel_b0_tt" false "" "" false "HFLOAT16") + endif () endif () if (NOT DEFINED ${float_char}OMATCOPY_CN) diff --git a/kernel/Makefile.L3 b/kernel/Makefile.L3 index 2bd6b294fb..71d66f8f34 100644 --- a/kernel/Makefile.L3 +++ b/kernel/Makefile.L3 @@ -129,6 +129,26 @@ SBKERNELOBJS += \ $(SBGEMMONCOPYOBJ) $(SBGEMMOTCOPYOBJ) endif +ifeq ($(BUILD_HFLOAT16), 1) +ifndef SHGEMMKERNEL +SHGEMM_BETA = ../generic/gemm_beta.c +SHGEMMKERNEL = ../generic/gemmkernel_2x2.c +SHGEMMINCOPY = ../generic/gemm_ncopy_2.c 
+SHGEMMITCOPY = ../generic/gemm_tcopy_2.c +SHGEMMONCOPY = ../generic/gemm_ncopy_2.c +SHGEMMOTCOPY = ../generic/gemm_tcopy_2.c +SHGEMMINCOPYOBJ = shgemm_incopy$(TSUFFIX).$(SUFFIX) +SHGEMMITCOPYOBJ = shgemm_itcopy$(TSUFFIX).$(SUFFIX) +SHGEMMONCOPYOBJ = shgemm_oncopy$(TSUFFIX).$(SUFFIX) +SHGEMMOTCOPYOBJ = shgemm_otcopy$(TSUFFIX).$(SUFFIX) +endif + +SHKERNELOBJS += \ + shgemm_kernel$(TSUFFIX).$(SUFFIX) \ + $(SHGEMMINCOPYOBJ) $(SHGEMMITCOPYOBJ) \ + $(SHGEMMONCOPYOBJ) $(SHGEMMOTCOPYOBJ) +endif + ifneq "$(or $(BUILD_SINGLE),$(BUILD_DOUBLE),$(BUILD_COMPLEX))" "" SKERNELOBJS += \ sgemm_kernel$(TSUFFIX).$(SUFFIX) \ @@ -192,6 +212,9 @@ XKERNELOBJS += \ ifeq ($(BUILD_BFLOAT16),1) SBBLASOBJS += $(SBKERNELOBJS) endif +ifeq ($(BUILD_HFLOAT16),1) +SHBLASOBJS += $(SHKERNELOBJS) +endif SBLASOBJS += $(SKERNELOBJS) DBLASOBJS += $(DKERNELOBJS) QBLASOBJS += $(QKERNELOBJS) @@ -202,6 +225,9 @@ XBLASOBJS += $(XKERNELOBJS) ifeq ($(BUILD_BFLOAT16),1) SBBLASOBJS += sbgemm_beta$(TSUFFIX).$(SUFFIX) endif +ifeq ($(BUILD_HFLOAT16),1) +SHBLASOBJS += shgemm_beta$(TSUFFIX).$(SUFFIX) +endif ifneq "$(or $(BUILD_SINGLE),$(BUILD_DOUBLE),$(BUILD_COMPLEX))" "" SBLASOBJS += \ @@ -493,6 +519,15 @@ SBBLASOBJS += \ sbgemm_small_kernel_b0_tn$(TSUFFIX).$(SUFFIX) sbgemm_small_kernel_b0_tt$(TSUFFIX).$(SUFFIX) endif +ifeq ($(BUILD_HFLOAT16),1) +SHBLASOBJS += \ + shgemm_small_matrix_permit$(TSUFFIX).$(SUFFIX) \ + shgemm_small_kernel_nn$(TSUFFIX).$(SUFFIX) shgemm_small_kernel_nt$(TSUFFIX).$(SUFFIX) \ + shgemm_small_kernel_tn$(TSUFFIX).$(SUFFIX) shgemm_small_kernel_tt$(TSUFFIX).$(SUFFIX) \ + shgemm_small_kernel_b0_nn$(TSUFFIX).$(SUFFIX) shgemm_small_kernel_b0_nt$(TSUFFIX).$(SUFFIX) \ + shgemm_small_kernel_b0_tn$(TSUFFIX).$(SUFFIX) shgemm_small_kernel_b0_tt$(TSUFFIX).$(SUFFIX) +endif + SBLASOBJS += \ sgemm_small_matrix_permit$(TSUFFIX).$(SUFFIX) \ sgemm_small_kernel_nn$(TSUFFIX).$(SUFFIX) sgemm_small_kernel_nt$(TSUFFIX).$(SUFFIX) \ @@ -599,6 +634,13 @@ SBGEMMONCOPYOBJ_P = $(SBGEMMONCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX)) 
SBGEMMOTCOPYOBJ_P = $(SBGEMMOTCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX)) endif +ifeq ($(BUILD_HFLOAT16), 1) +SHGEMMINCOPYOBJ_P = $(SHGEMMINCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX)) +SHGEMMITCOPYOBJ_P = $(SHGEMMITCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX)) +SHGEMMONCOPYOBJ_P = $(SHGEMMONCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX)) +SHGEMMOTCOPYOBJ_P = $(SHGEMMOTCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX)) +endif + SGEMMINCOPYOBJ_P = $(SGEMMINCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX)) SGEMMITCOPYOBJ_P = $(SGEMMITCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX)) SGEMMONCOPYOBJ_P = $(SGEMMONCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX)) @@ -629,6 +671,11 @@ $(KDIR)sbgemm_beta$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SBGEMM_BETA) $(CC) $(CFLAGS) -c -DBFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@ endif +ifeq ($(BUILD_HFLOAT16),1) +$(KDIR)shgemm_beta$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SHGEMM_BETA) + $(CC) $(CFLAGS) -c -DHFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@ +endif + $(KDIR)sgemm_beta$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMM_BETA) $(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@ @@ -671,6 +718,25 @@ $(KDIR)$(SBGEMMITCOPYOBJ) : $(KERNELDIR)/$(SBGEMMITCOPY) endif endif +ifeq ($(BUILD_HFLOAT16), 1) + +$(KDIR)$(SHGEMMONCOPYOBJ) : $(KERNELDIR)/$(SHGEMMONCOPY) + $(CC) $(CFLAGS) -c -DHFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@ + +$(KDIR)$(SHGEMMOTCOPYOBJ) : $(KERNELDIR)/$(SHGEMMOTCOPY) + $(CC) $(CFLAGS) -c -DHFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@ + +ifneq ($(SHGEMM_UNROLL_M), $(SHGEMM_UNROLL_N)) + +$(KDIR)$(SHGEMMINCOPYOBJ) : $(KERNELDIR)/$(SHGEMMINCOPY) + $(CC) $(CFLAGS) -c -DHFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@ + +$(KDIR)$(SHGEMMITCOPYOBJ) : $(KERNELDIR)/$(SHGEMMITCOPY) + $(CC) $(CFLAGS) -c -DHFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@ + +endif +endif + $(KDIR)$(SGEMMONCOPYOBJ) : $(KERNELDIR)/$(SGEMMONCOPY) $(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@ @@ -853,6 +919,12 @@ $(KDIR)sbgemm_kernel$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SBGEMMKERNEL) $(SBGEMM $(CC) $(CFLAGS) -c -DBFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@ endif +ifeq ($(BUILD_HFLOAT16), 1) + +$(KDIR)shgemm_kernel$(TSUFFIX).$(SUFFIX) : 
$(KERNELDIR)/$(SHGEMMKERNEL) $(SHGEMMDEPEND) + $(CC) $(CFLAGS) -c -DHFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@ +endif + $(KDIR)dgemm_kernel$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(DGEMMKERNEL) $(DGEMMDEPEND) ifeq ($(OS), AIX) $(CC) $(CFLAGS) -S -DDOUBLE -UCOMPLEX $< -o - > dgemm_kernel$(TSUFFIX).s @@ -2840,6 +2912,11 @@ $(KDIR)sbgemm_beta$(TSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(SBGEMM_BETA) $(CC) $(PFLAGS) -c -DBFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@ endif +ifeq ($(BUILD_HFLOAT16),1) +$(KDIR)shgemm_beta$(TSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(SHGEMM_BETA) + $(CC) $(PFLAGS) -c -DHFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@ +endif + $(KDIR)dgemm_beta$(TSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(DGEMM_BETA) $(CC) $(PFLAGS) -c -DDOUBLE -UCOMPLEX $< -o $@ @@ -2873,6 +2950,23 @@ $(SBGEMMITCOPYOBJ_P) : $(KERNELDIR)/$(SBGEMMITCOPY) endif endif +ifeq ($(BUILD_HFLOAT16), 1) +$(SHGEMMONCOPYOBJ_P) : $(KERNELDIR)/$(SHGEMMONCOPY) + $(CC) $(PFLAGS) -c -DHFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@ + +$(SHGEMMOTCOPYOBJ_P) : $(KERNELDIR)/$(SHGEMMOTCOPY) + $(CC) $(PFLAGS) -c -DHFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@ + +ifneq ($(SHGEMM_UNROLL_M), $(SHGEMM_UNROLL_N)) +$(SHGEMMINCOPYOBJ_P) : $(KERNELDIR)/$(SHGEMMINCOPY) + $(CC) $(PFLAGS) -c -DHFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@ + +$(SHGEMMITCOPYOBJ_P) : $(KERNELDIR)/$(SHGEMMITCOPY) + $(CC) $(PFLAGS) -c -DHFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@ + +endif +endif + $(SGEMMONCOPYOBJ_P) : $(KERNELDIR)/$(SGEMMONCOPY) $(CC) $(PFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@ @@ -2983,6 +3077,11 @@ $(KDIR)sbgemm_kernel$(TSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(SBGEMMKERNEL) $(SBGEM $(CC) $(PFLAGS) -c -DBFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@ endif +ifeq ($(BUILD_HFLOAT16), 1) +$(KDIR)shgemm_kernel$(TSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(SHGEMMKERNEL) $(SHGEMMDEPEND) + $(CC) $(PFLAGS) -c -DHFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@ +endif + $(KDIR)sgemm_kernel$(TSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(SGEMMKERNEL) $(SGEMMDEPEND) $(CC) $(PFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@ @@ -4843,6 +4942,71 @@ 
$(KDIR)sbgemm_small_kernel_b0_tt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SBGEMM_SMA $(CC) $(CFLAGS) -c -DBFLOAT16 -UDOUBLE -UCOMPLEX -DB0 $< -o $@ endif +ifeq ($(BUILD_HFLOAT16), 1) +ifndef SHGEMM_SMALL_M_PERMIT +SHGEMM_SMALL_M_PERMIT = ../generic/gemm_small_matrix_permit.c +endif + +ifndef SHGEMM_SMALL_K_NN +SHGEMM_SMALL_K_NN = ../generic/gemm_small_matrix_kernel_nn.c +endif + +ifndef SHGEMM_SMALL_K_NT +SHGEMM_SMALL_K_NT = ../generic/gemm_small_matrix_kernel_nt.c +endif + +ifndef SHGEMM_SMALL_K_TN +SHGEMM_SMALL_K_TN = ../generic/gemm_small_matrix_kernel_tn.c +endif + +ifndef SHGEMM_SMALL_K_TT +SHGEMM_SMALL_K_TT = ../generic/gemm_small_matrix_kernel_tt.c +endif + +$(KDIR)shgemm_small_matrix_permit$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SHGEMM_SMALL_M_PERMIT) + $(CC) $(CFLAGS) -c -DHFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@ + +$(KDIR)shgemm_small_kernel_nn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SHGEMM_SMALL_K_NN) + $(CC) $(CFLAGS) -c -DHFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@ + +$(KDIR)shgemm_small_kernel_nt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SHGEMM_SMALL_K_NT) + $(CC) $(CFLAGS) -c -DHFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@ + +$(KDIR)shgemm_small_kernel_tn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SHGEMM_SMALL_K_TN) + $(CC) $(CFLAGS) -c -DHFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@ + +$(KDIR)shgemm_small_kernel_tt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SHGEMM_SMALL_K_TT) + $(CC) $(CFLAGS) -c -DHFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@ + +ifndef SHGEMM_SMALL_K_B0_NN +SHGEMM_SMALL_K_B0_NN = ../generic/gemm_small_matrix_kernel_nn.c +endif + +ifndef SHGEMM_SMALL_K_B0_NT +SHGEMM_SMALL_K_B0_NT = ../generic/gemm_small_matrix_kernel_nt.c +endif + +ifndef SHGEMM_SMALL_K_B0_TN +SHGEMM_SMALL_K_B0_TN = ../generic/gemm_small_matrix_kernel_tn.c +endif + +ifndef SHGEMM_SMALL_K_B0_TT +SHGEMM_SMALL_K_B0_TT = ../generic/gemm_small_matrix_kernel_tt.c +endif + +$(KDIR)shgemm_small_kernel_b0_nn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SHGEMM_SMALL_K_B0_NN) + $(CC) $(CFLAGS) -c -DHFLOAT16 -UDOUBLE -UCOMPLEX -DB0 $< -o $@ + 
+$(KDIR)shgemm_small_kernel_b0_nt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SHGEMM_SMALL_K_B0_NT) + $(CC) $(CFLAGS) -c -DHFLOAT16 -UDOUBLE -UCOMPLEX -DB0 $< -o $@ + +$(KDIR)shgemm_small_kernel_b0_tn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SHGEMM_SMALL_K_B0_TN) + $(CC) $(CFLAGS) -c -DHFLOAT16 -UDOUBLE -UCOMPLEX -DB0 $< -o $@ + +$(KDIR)shgemm_small_kernel_b0_tt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SHGEMM_SMALL_K_B0_TT) + $(CC) $(CFLAGS) -c -DHFLOAT16 -UDOUBLE -UCOMPLEX -DB0 $< -o $@ +endif + ifndef CGEMM_SMALL_M_PERMIT CGEMM_SMALL_M_PERMIT = ../generic/zgemm_small_matrix_permit.c endif diff --git a/kernel/riscv64/KERNEL.RISCV64_ZVL128B b/kernel/riscv64/KERNEL.RISCV64_ZVL128B index 7fbc26d213..d2a2d35786 100644 --- a/kernel/riscv64/KERNEL.RISCV64_ZVL128B +++ b/kernel/riscv64/KERNEL.RISCV64_ZVL128B @@ -245,3 +245,12 @@ endif ifndef ZGEMM_BETA ZGEMM_BETA = zgemm_beta_rvv.c endif + +SHGEMMKERNEL = shgemm_kernel_$(SHGEMM_UNROLL_M)x$(SHGEMM_UNROLL_N)_zvl128b.c +SHGEMMONCOPY = ../generic/gemm_ncopy_$(SHGEMM_UNROLL_N).c +SHGEMMOTCOPY = ../generic/gemm_tcopy_$(SHGEMM_UNROLL_N).c +SHGEMMONCOPYOBJ = shgemm_oncopy$(TSUFFIX).$(SUFFIX) +SHGEMMOTCOPYOBJ = shgemm_otcopy$(TSUFFIX).$(SUFFIX) +ifndef SHGEMM_BETA +SHGEMM_BETA = gemm_beta_rvv.c +endif \ No newline at end of file diff --git a/kernel/riscv64/KERNEL.RISCV64_ZVL256B b/kernel/riscv64/KERNEL.RISCV64_ZVL256B index 2b4f0a5455..847ebff705 100644 --- a/kernel/riscv64/KERNEL.RISCV64_ZVL256B +++ b/kernel/riscv64/KERNEL.RISCV64_ZVL256B @@ -209,5 +209,21 @@ COMATCOPY_CN = zomatcopy_cn_vector.c DOMATCOPY_CN = omatcopy_cn_vector.c SOMATCOPY_CN = omatcopy_cn_vector.c + +SHGEMMKERNEL = shgemm_kernel_$(SHGEMM_UNROLL_M)x$(SHGEMM_UNROLL_N)_zvl256b.c +ifneq ($(SHGEMM_UNROLL_M), $(SHGEMM_UNROLL_N)) +SHGEMMINCOPY = ../generic/gemm_ncopy_$(SHGEMM_UNROLL_M).c +SHGEMMITCOPY = ../generic/gemm_tcopy_$(SHGEMM_UNROLL_M).c +SHGEMMINCOPYOBJ = shgemm_incopy$(TSUFFIX).$(SUFFIX) +SHGEMMITCOPYOBJ = shgemm_itcopy$(TSUFFIX).$(SUFFIX) +endif +SHGEMMONCOPY = 
../generic/gemm_ncopy_$(SHGEMM_UNROLL_N).c +SHGEMMOTCOPY = ../generic/gemm_tcopy_$(SHGEMM_UNROLL_N).c +SHGEMMONCOPYOBJ = shgemm_oncopy$(TSUFFIX).$(SUFFIX) +SHGEMMOTCOPYOBJ = shgemm_otcopy$(TSUFFIX).$(SUFFIX) +ifndef SHGEMM_BETA +SHGEMM_BETA = gemm_beta_rvv.c +endif + SAXPBYKERNEL = axpby_vector_v2.c DAXPBYKERNEL = axpby_vector_v2.c diff --git a/kernel/riscv64/shgemm_kernel_16x8_zvl256b.c b/kernel/riscv64/shgemm_kernel_16x8_zvl256b.c new file mode 100644 index 0000000000..fb98f564c3 --- /dev/null +++ b/kernel/riscv64/shgemm_kernel_16x8_zvl256b.c @@ -0,0 +1,969 @@ + +#include "common.h" +#include +int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, IFLOAT *A, IFLOAT *B, FLOAT *C, BLASLONG ldc) +{ + BLASLONG gvl = 0; + BLASLONG m_top = 0; + BLASLONG n_top = 0; + + // -- MAIN PASS + for (BLASLONG j=0; j + +int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, IFLOAT *A, IFLOAT *B, FLOAT *C, BLASLONG ldc) +{ + BLASLONG gvl = 0; + BLASLONG m_top = 0; + BLASLONG n_top = 0; + + // -- MAIN PASS + for (BLASLONG j=0; j= 12) +typedef _Float16 hfloat16; +#else +#include +typedef uint16_t hfloat16; +#endif + #ifdef OPENBLAS_USE64BITINT typedef BLASLONG blasint; #else diff --git a/param.h b/param.h index 48b64fd2ae..cdc48cbe92 100644 --- a/param.h +++ b/param.h @@ -72,6 +72,12 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #ifndef PARAM_H #define PARAM_H +#define SHGEMM_DEFAULT_UNROLL_N 8 +#define SHGEMM_DEFAULT_UNROLL_M 8 +#define SHGEMM_DEFAULT_UNROLL_MN 32 +#define SHGEMM_DEFAULT_P 128 +#define SHGEMM_DEFAULT_R 240 +#define SHGEMM_DEFAULT_Q 12288 #define SBGEMM_DEFAULT_UNROLL_N 4 #define SBGEMM_DEFAULT_UNROLL_M 8 @@ -3138,10 +3144,16 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
#endif #ifdef RISCV64_ZVL128B + #define GEMM_DEFAULT_OFFSET_A 0 #define GEMM_DEFAULT_OFFSET_B 0 #define GEMM_DEFAULT_ALIGN (BLASLONG)0x03fffUL +#undef SHGEMM_DEFAULT_UNROLL_M +#undef SHGEMM_DEFAULT_UNROLL_N +#define SHGEMM_DEFAULT_UNROLL_M 8 +#define SHGEMM_DEFAULT_UNROLL_N 8 + #define SGEMM_DEFAULT_UNROLL_M 8 #define SGEMM_DEFAULT_UNROLL_N 8 @@ -3154,16 +3166,22 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #define ZGEMM_DEFAULT_UNROLL_M 4 #define ZGEMM_DEFAULT_UNROLL_N 4 +#undef SHGEMM_DEFAULT_P +#define SHGEMM_DEFAULT_P 128 #define SGEMM_DEFAULT_P 128 #define DGEMM_DEFAULT_P 128 #define CGEMM_DEFAULT_P 96 #define ZGEMM_DEFAULT_P 64 +#undef SHGEMM_DEFAULT_Q +#define SHGEMM_DEFAULT_Q 240 #define SGEMM_DEFAULT_Q 240 #define DGEMM_DEFAULT_Q 120 #define CGEMM_DEFAULT_Q 120 #define ZGEMM_DEFAULT_Q 120 +#undef SHGEMM_DEFAULT_R +#define SHGEMM_DEFAULT_R 12288 #define SGEMM_DEFAULT_R 12288 #define DGEMM_DEFAULT_R 8192 #define CGEMM_DEFAULT_R 4096 @@ -3181,6 +3199,11 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #define GEMM_DEFAULT_OFFSET_B 0 #define GEMM_DEFAULT_ALIGN 0x03fffUL +#undef SHGEMM_DEFAULT_UNROLL_M +#undef SHGEMM_DEFAULT_UNROLL_N +#define SHGEMM_DEFAULT_UNROLL_M 16 +#define SHGEMM_DEFAULT_UNROLL_N 8 + #define SGEMM_DEFAULT_UNROLL_M 16 #define SGEMM_DEFAULT_UNROLL_N 8 @@ -3193,16 +3216,22 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #define ZGEMM_DEFAULT_UNROLL_M 8 #define ZGEMM_DEFAULT_UNROLL_N 4 +#undef SHGEMM_DEFAULT_P +#define SHGEMM_DEFAULT_P 128 #define SGEMM_DEFAULT_P 128 #define DGEMM_DEFAULT_P 64 #define CGEMM_DEFAULT_P 64 #define ZGEMM_DEFAULT_P 64 +#undef SHGEMM_DEFAULT_Q +#define SHGEMM_DEFAULT_Q 128 #define SGEMM_DEFAULT_Q 128 #define DGEMM_DEFAULT_Q 128 #define CGEMM_DEFAULT_Q 128 #define ZGEMM_DEFAULT_Q 64 +#undef SHGEMM_DEFAULT_R +#define SHGEMM_DEFAULT_R 16384 #define SGEMM_DEFAULT_R 16384 #define DGEMM_DEFAULT_R 8192 #define CGEMM_DEFAULT_R 8192