Skip to content

Commit

Permalink
Merge pull request #2 from jvdp1/blas_jv
Browse files Browse the repository at this point in the history
Support of MKL
  • Loading branch information
zoziha authored Jul 17, 2024
2 parents 3254fd0 + b867c2b commit 9de0589
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 16 deletions.
2 changes: 0 additions & 2 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,6 @@ jobs:
-DCMAKE_MAXIMUM_RANK:String=4
-DCMAKE_INSTALL_PREFIX=$PWD/_dist
-DFIND_BLAS:STRING=TRUE
-DBLAS_LIBRARIES="/opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_rt.so"
-DLAPACK_LIBRARIES="/opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_rt.so"
-S . -B build_mkl
- name: Build and compile with MKL
Expand Down
16 changes: 14 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,23 @@ option(FIND_BLAS "Find external BLAS and LAPACK" ON)

# --- find BLAS and LAPACK
if(FIND_BLAS)
find_package(BLAS)
if(NOT BLAS_FOUND)
#Required for MKL
if(DEFINED ENV{MKLROOT} OR "${BLA_VENDOR}" MATCHES "^Intel")
enable_language("C")
endif()
find_package("BLAS")
endif()
if(BLAS_FOUND)
add_compile_definitions(STDLIB_EXTERNAL_BLAS)
endif()
find_package(LAPACK)
if(NOT LAPACK_FOUND)
#Required for MKL
if(DEFINED ENV{MKLROOT} OR "${BLA_VENDOR}" MATCHES "^Intel")
enable_language("C")
endif()
find_package("LAPACK")
endif()
if(LAPACK_FOUND)
add_compile_definitions(STDLIB_EXTERNAL_LAPACK)
endif()
Expand Down
4 changes: 2 additions & 2 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -122,10 +122,10 @@ add_library(${PROJECT_NAME} ${SRC})

# Link to BLAS and LAPACK
if(BLAS_FOUND)
target_link_libraries(${PROJECT_NAME} BLAS::BLAS)
target_link_libraries(${PROJECT_NAME} "BLAS::BLAS")
endif()
if(LAPACK_FOUND)
target_link_libraries(${PROJECT_NAME} LAPACK::LAPACK)
target_link_libraries(${PROJECT_NAME} "LAPACK::LAPACK")
endif()

set_target_properties(
Expand Down
20 changes: 10 additions & 10 deletions test/linalg/test_linalg_svd.fypp
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ module test_linalg_svd
if (allocated(error)) return
call check(error, all(abs(s-s_sol)<=tol), test//': S')
if (allocated(error)) return
call check(error, all(abs(u-u_sol)<=tol) .or. all(abs(u+u_sol)<=tol), test//': U')
call check(error, all(abs(abs(u)-abs(u_sol))<=tol), test//': U')
if (allocated(error)) return

!> [S, U]. Overwrite A matrix
Expand All @@ -104,7 +104,7 @@ module test_linalg_svd
if (allocated(error)) return
call check(error, all(abs(s-s_sol)<=tol), test//': S')
if (allocated(error)) return
call check(error, all(abs(u-u_sol)<=tol) .or. all(abs(u+u_sol)<=tol), test//': U')
call check(error, all(abs(abs(u)-abs(u_sol))<=tol), test//': U')
if (allocated(error)) return

!> [S, U, V^T]
Expand All @@ -116,9 +116,9 @@ module test_linalg_svd
if (allocated(error)) return
call check(error, all(abs(s-s_sol)<=tol), test//': S')
if (allocated(error)) return
call check(error, all(abs(u-u_sol)<=tol) .or. all(abs(u+u_sol)<=tol), test//': U')
call check(error, all(abs(abs(u)-abs(u_sol))<=tol), test//': U')
if (allocated(error)) return
call check(error, all(abs(vt-vt_sol)<=tol) .or. all(abs(vt+vt_sol)<=tol), test//': V^T')
call check(error, all(abs(abs(vt)-abs(vt_sol))<=tol), test//': V^T')
if (allocated(error)) return

!> [S, V^T]. Do not overwrite A matrix
Expand All @@ -130,7 +130,7 @@ module test_linalg_svd
if (allocated(error)) return
call check(error, all(abs(s-s_sol)<=tol), test//': S')
if (allocated(error)) return
call check(error, all(abs(vt-vt_sol)<=tol) .or. all(abs(vt+vt_sol)<=tol), test//': V^T')
call check(error, all(abs(abs(vt)-abs(vt_sol))<=tol), test//': V^T')
if (allocated(error)) return

!> [S, V^T]. Overwrite A matrix
Expand All @@ -141,7 +141,7 @@ module test_linalg_svd
if (allocated(error)) return
call check(error, all(abs(s-s_sol)<=tol), test//': S')
if (allocated(error)) return
call check(error, all(abs(vt-vt_sol)<=tol) .or. all(abs(vt+vt_sol)<=tol), test//': V^T')
call check(error, all(abs(abs(vt)-abs(vt_sol))<=tol), test//': V^T')
if (allocated(error)) return

!> [U, S, V^T].
Expand All @@ -151,11 +151,11 @@ module test_linalg_svd
test = '[U, S, V^T]'
call check(error,state%ok(),test//': '//state%print())
if (allocated(error)) return
call check(error, all(abs(u-u_sol)<=tol) .or. all(abs(u+u_sol)<=tol), test//': U')
call check(error, all(abs(abs(u)-abs(u_sol))<=tol), test//': U')
if (allocated(error)) return
call check(error, all(abs(s-s_sol)<=tol), test//': S')
if (allocated(error)) return
call check(error, all(abs(vt-vt_sol)<=tol) .or. all(abs(vt+vt_sol)<=tol), test//': V^T')
call check(error, all(abs(abs(vt)-abs(vt_sol))<=tol), test//': V^T')
if (allocated(error)) return

!> [U, S, V^T]. Partial storage -> compare until k=2 columns of U rows of V^T
Expand All @@ -167,11 +167,11 @@ module test_linalg_svd
test = '[U, S, V^T], partial storage'
call check(error,state%ok(),test//': '//state%print())
if (allocated(error)) return
call check(error, all(abs(u(:,:2)-u_sol(:,:2))<=tol) .or. all(abs(u(:,:2)+u_sol(:,:2))<=tol), test//': U(:,:2)')
call check(error, all(abs(abs(u(:,:2))-abs(u_sol(:,:2)))<=tol), test//': U(:,:2)')
if (allocated(error)) return
call check(error, all(abs(s-s_sol)<=tol), test//': S')
if (allocated(error)) return
call check(error, all(abs(vt(:2,:)-vt_sol(:2,:))<=tol) .or. all(abs(vt(:2,:)+vt_sol(:2,:))<=tol), test//': V^T(:2,:)')
call check(error, all(abs(abs(vt(:2,:))-abs(vt_sol(:2,:)))<=tol), test//': V^T(:2,:)')
if (allocated(error)) return

end subroutine test_svd_${ri}$
Expand Down

0 comments on commit 9de0589

Please sign in to comment.