Skip to content

Commit 05d5c6a

Browse files
committed
Parallel general matrix inversion and test
1 parent cb6d689 commit 05d5c6a

File tree

4 files changed

+237
-3
lines changed

4 files changed

+237
-3
lines changed

lib/scalapack.fpp

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ module scalapack_module
88
private
99

1010
public :: psygst, phegst, psyev, pheev, psyevd, pheevd, psyevr, pheevr
11-
public :: ptrsm, ppotrf, ppotri, ptrtri, pgetrf, pgesvd
11+
public :: ptrsm, ppotrf, ppotri, ptrtri, pgetrf, pgetri, pgesvd
1212
public :: sl_init, numroc, infog2l, indxl2g, descinit, indxg2p
1313

1414
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
@@ -76,12 +76,30 @@ module scalapack_module
7676
end subroutine p${TYPEABBREV}$getrf
7777
#:enddef interface_pgetrf_template
7878

79+
7980
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
80-
!!! psygst
81+
!!! pgetri
8182
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
8283

84+
#:def interface_pgetri_template(TYPEABBREV, FTYPES)
  !> Inverse of an LU factorized general matrix (${TYPE}$).
  !!
  !! Explicit interface of the ScaLAPACK routine p${TYPEABBREV}$getri. The routine expects the LU
  !! factors in aa and the matching pivot indices in ipiv, and overwrites aa with the inverse.
  !! Passing lwork = -1 and liwork = -1 performs a workspace query only: the required sizes are
  !! then returned in work(1) and iwork(1).
  subroutine p${TYPEABBREV}$getri(nn, aa, ia, ja, desca, ipiv, work, lwork, iwork, liwork, info)
    import
    !> Order of the distributed submatrix to invert
    integer, intent(in) :: nn
    !> First row / first column of the submatrix and descriptor of the distributed matrix
    integer, intent(in) :: ia, ja, desca(*)
    !> Local part of the matrix: LU factors on entry, inverse on exit
    ${FTYPES}$, intent(inout) :: aa(desca(LLD_), *)
    !> Pivot indices of the LU factorization (only read by the routine)
    integer, intent(in) :: ipiv(*)
    !> Floating point workspace; work(1) carries the optimal lwork on exit
    ${FTYPES}$, intent(inout) :: work(*)
    !> Integer workspace; iwork(1) carries the optimal liwork on exit
    integer, intent(inout) :: iwork(*)
    !> Sizes of work and iwork, or -1 for a workspace query
    integer, intent(in) :: lwork, liwork
    !> Status flag, zero on success
    integer, intent(out) :: info
  end subroutine p${TYPEABBREV}$getri
#:enddef interface_pgetri_template
8398

8499

100+
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
101+
!!! psygst
102+
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
85103

86104
#:def interface_psygst_template(TYPEABBREV, KIND)
87105
!> Reduces generalized symmetric eigenvalue problem to standard form (${TYPE}$).
@@ -388,6 +406,15 @@ module scalapack_module
388406
#:endfor
389407
end interface pgetrf
390408

409+
!> Inversion of an LU-decomposed general matrix with pivoting
410+
interface pgetri
411+
#:for TYPE in TYPES
412+
#:set TYPEABBREV = TYPE_ABBREVS[TYPE]
413+
#:set FTYPE = FORTRAN_TYPES[TYPE]
414+
$:interface_pgetri_template(TYPEABBREV, FTYPE)
415+
#:endfor
416+
end interface pgetri
417+
391418
!> Reduces generalized symmetric eigenvalue problem to standard form.
392419
interface psygst
393420
#:for TYPE in REAL_TYPES
@@ -533,4 +560,3 @@ module scalapack_module
533560

534561

535562
end module scalapack_module
536-

lib/scalapackfx.fpp

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ module scalapackfx_module
1313
public :: scalafx_ppotri
1414
public :: scalafx_ptrtri
1515
public :: scalafx_pgetrf
16+
public :: scalafx_pgetri
1617
public :: scalafx_psygst
1718
public :: scalafx_phegst
1819
public :: scalafx_psyev
@@ -64,6 +65,12 @@ module scalapackfx_module
6465
module procedure scalafx_pgetrf_complex, scalafx_pgetrf_dcomplex
6566
end interface scalafx_pgetrf
6667

68+
!> Inverse of a LU decomposed general matrix.
69+
interface scalafx_pgetri
70+
module procedure scalafx_pgetri_real, scalafx_pgetri_dreal
71+
module procedure scalafx_pgetri_complex, scalafx_pgetri_dcomplex
72+
end interface scalafx_pgetri
73+
6774
!> Reduces symmetric definite generalized eigenvalue problem to standard form.
6875
interface scalafx_psygst
6976
module procedure scalafx_psygst_real, scalafx_psygst_dreal
@@ -354,6 +361,75 @@ module scalapackfx_module
354361
#:enddef scalafx_pgetrf_template
355362

356363

364+
!************************************************************************
365+
!*** pgetri
366+
!************************************************************************
367+
368+
#:def scalafx_pgetri_template(TYPE, FTYPE)

  !> Inversion of an LU-factorized general matrix with pivoting.
  !!
  !! The matrix must already contain its LU factors and ipiv the corresponding pivot indices
  !! (as delivered by scalafx_pgetrf). The routine first issues a workspace query and then
  !! performs the actual inversion with suitably sized work arrays.
  !!
  subroutine scalafx_pgetri_${TYPE}$(aa, desca, ipiv, ia, ja, nn, work, iwork, info)

    !> LU factors of the matrix on entry, inverse of the matrix on exit
    ${FTYPE}$, intent(inout) :: aa(:,:)

    !> Descriptor of A.
    integer, intent(in) :: desca(DLEN_)

    !> Pivot indices of the LU factorization. They are read (not produced) here, therefore they
    !! must not be intent(out); intent(inout) keeps the incoming values defined.
    integer, intent(inout) :: ipiv(:)

    !> First row of the submatrix of A. Default: 1
    integer, intent(in), optional :: ia

    !> First column of the submatrix of A. Default: 1
    integer, intent(in), optional :: ja

    !> Order (number of rows and columns) of the submatrix of A. Default: desca(N_)
    integer, intent(in), optional :: nn

    !> Work array, if provided externally
    ${FTYPE}$, intent(inout), allocatable, optional :: work(:)

    !> Integer work array, if provided externally
    integer, intent(inout), allocatable, optional :: iwork(:)

    !> Info flag. If not specified and error occurs, the subroutine stops.
    integer, intent(out), optional :: info

    !------------------------------------------------------------------------

    ${FTYPE}$, allocatable :: work0(:)
    integer, allocatable :: iwork0(:)
    integer :: lwork, liwork, info0
    integer :: ia0, ja0, nn0
    ${FTYPE}$ :: rtmp(1)
    integer :: itmp(1)

    ! Resolve the optional arguments. nn0 must not be modified afterwards, otherwise an order
    ! passed by the caller would be silently ignored.
    @:inoptflags(ia0, ia, 1)
    @:inoptflags(ja0, ja, 1)
    @:inoptflags(nn0, nn, desca(N_))

    ! Workspace query (lwork = liwork = -1): required sizes are returned in rtmp(1) and itmp(1)
    call pgetri(nn0, aa, ia0, ja0, desca, ipiv, rtmp, -1, itmp, -1, info0)
    call handle_infoflag(info0, "pgetri in scalafx_pgetri_${TYPE}$", info)
    @:move_minoptalloc(work0, int(rtmp(1)), lwork, work)
    @:move_minoptalloc(iwork0, itmp(1), liwork, iwork)

    ! Actual inversion
    call pgetri(nn0, aa, ia0, ja0, desca, ipiv, work0, lwork, iwork0, liwork, info0)
    call handle_infoflag(info0, "pgetri in scalafx_pgetri_${TYPE}$", info)

    ! Save work space allocations, if dummy argument present
    @:optmovealloc(work0, work)
    @:optmovealloc(iwork0, iwork)

  end subroutine scalafx_pgetri_${TYPE}$

#:enddef scalafx_pgetri_template
431+
432+
357433
!************************************************************************
358434
!*** psygst / phegst
359435
!************************************************************************
@@ -1879,6 +1955,7 @@ contains
18791955
$:scalafx_ppotri_template(TYPE, FTYPE)
18801956
$:scalafx_ptrtri_template(TYPE, FTYPE)
18811957
$:scalafx_pgetrf_template(TYPE, FTYPE)
1958+
$:scalafx_pgetri_template(TYPE, FTYPE)
18821959
$:scalafx_psygst_phegst_template(TYPE, FTYPE)
18831960

18841961
#:if TYPE in REAL_TYPES

test/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ set(common-dep-targets
88
test_diag
99
test_gemr2d
1010
test_linecomm
11+
test_matinv
1112
test_psyr_pher
1213
test_remoteelements
1314
test_svd

test/test_matinv.f90

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
!> Testing matrix inversion via getrf and getri
!!
!! A matrix is read from 'hamsqr1.dat', LU factorized and inverted. The inverse is multiplied
!! with the original matrix and every locally stored element of the product is compared with
!! the corresponding element of the identity matrix. The same test is then repeated with the
!! matrix converted to complex.
program test_pmatinv
  use, intrinsic :: iso_fortran_env, stdout => output_unit
  use test_common_module
  use libscalapackfx_module
  implicit none

  ! Block size (using an extremely small value for test purposes)
  integer, parameter :: bsize = 1

  call main()

contains

  !> Runs the real and the complex inversion test and reports the outcome on every process.
  subroutine main()

    ! Largest accepted deviation from the identity matrix. Deliberately the epsilon of the
    ! default real kind: the product of a numerically inverted matrix with the original one
    ! can not be expected to be exact to double precision.
    real(dp), parameter :: tol = real(epsilon(0.0), dp)

    type(blacsgrid) :: mygrid
    real(dp), allocatable :: r_xx(:,:), r_xxOld(:,:), r_matI(:,:), ident(:,:)
    complex(dp), allocatable :: c_xx(:,:), c_xxOld(:,:), c_matI(:,:)

    integer :: descx(DLEN_)
    integer :: nprow, npcol, mm, nn, iproc, nproc, ii, iLoc, jLoc, jGlob
    integer, allocatable :: ipiv(:)
    logical :: isLocal, failed

    ! Initialize blacs and create a processor grid which is as close to square as possible
    call blacsfx_pinfo(iproc, nproc)
    do nprow = int(sqrt(real(nproc, dp))), nproc
      if (mod(nproc, nprow) == 0) then
        exit
      end if
    end do
    npcol = nproc / nprow
    call mygrid%initgrid(nprow, npcol)
    if (mygrid%lead) then
      write(stdout, "(A,2(1X,I0))") "# processor grid:", nprow, npcol
      write(stdout, "(A,1X,I0)") "# block size:", bsize
    end if

    ! Set up matrix
    if (mygrid%lead) then
      write(stdout, "(A)") "Matrix read from file 'hamsqr1.dat'."
    end if
    call readfromfile(mygrid, "hamsqr1.dat", bsize, bsize, r_xx, descx)
    mm = descx(M_)
    nn = descx(N_)

    ! ScaLAPACK documents the pivot array as having (local rows + row block size) entries,
    ! which exceeds min(mm, nn) on a grid with a single process row. Never allocate less.
    allocate(ipiv(max(min(mm, nn), size(r_xx, dim=1) + bsize)), source=0)

    if (mygrid%lead) then
      write(stdout, "(A,2(1X,I0))") "# global matrix size:", mm, nn
      write(stdout, "(A,2(1X,I0))") "# local matrix size on leader:",&
          & size(r_xx, dim=1), size(r_xx, dim=2)
    end if

    ! Locally stored part of the global identity matrix, used as reference for both tests:
    ! for every local column look up whether (and where) its diagonal element is stored here.
    allocate(ident(size(r_xx, dim=1), size(r_xx, dim=2)), source=0.0_dp)
    do ii = 1, size(ident, dim=2)
      jGlob = scalafx_indxl2g(ii, descx(NB_), mygrid%mycol, descx(CSRC_), mygrid%ncol)
      call scalafx_islocal(mygrid, descx, jGlob, jGlob, isLocal, iLoc, jLoc)
      ! iLoc and jLoc are only meaningful if the diagonal element is stored on this process
      if (isLocal) then
        ident(iLoc, jLoc) = 1.0_dp
      end if
    end do

    r_xxOld = r_xx
    r_matI = r_xx ! just to get sizing

    call scalafx_pgetrf(r_xx, descx, ipiv)
    call scalafx_pgetri(r_xx, descx, ipiv)

    ! xx^-1 (*) xx
    call pblasfx_pgemm(r_xx, descx, r_xxOld, descx, r_matI, descx)

    ! All local elements (diagonal and off-diagonal) must match the identity matrix
    failed = any(abs(r_matI - ident) > tol)

    ! Would normally accumulate product via mpi calls, but messy to do via blacs operations only
    if (failed) then
      write(stdout,*)'Real matrix element(s) on processor', mygrid%iproc, ' non-identity matrix'
    else
      write(stdout,*)'Real matrix elements on processor', mygrid%iproc, ' are OK'
    end if

    ! Repeat the test with the complex version of the original matrix
    deallocate(r_xx)
    deallocate(r_matI)
    c_xx = r_xxOld
    deallocate(r_xxOld)
    c_xxOld = c_xx
    c_matI = c_xx

    call scalafx_pgetrf(c_xx, descx, ipiv)
    call scalafx_pgetri(c_xx, descx, ipiv)

    ! xx^-1 (*) xx
    call pblasfx_pgemm(c_xx, descx, c_xxOld, descx, c_matI, descx)

    failed = any(abs(c_matI - ident) > tol)

    ! Would normally accumulate product via mpi calls, but messy to do via blacs operations only
    if (failed) then
      write(stdout,*)'Complex matrix element(s) on processor', mygrid%iproc, ' non-identity matrix'
    else
      write(stdout,*)'Complex matrix elements on processor', mygrid%iproc, ' are OK'
    end if

    deallocate(c_xx)
    deallocate(c_matI)
    deallocate(c_xxOld)
    deallocate(ident)

    ! Finish blacs.
    call blacsfx_exit()

  end subroutine main

end program test_pmatinv

0 commit comments

Comments
 (0)