Skip to content

Commit

Permalink
Added example with redoing the factorization through KLU (#77)
Browse files Browse the repository at this point in the history
* example with KLU redoing factorization

* fixing memory leaks in LinSolverDirectRocSolverRf
  • Loading branch information
kswirydo authored Nov 22, 2023
1 parent fa10d84 commit d1e82eb
Show file tree
Hide file tree
Showing 5 changed files with 301 additions and 15 deletions.
7 changes: 6 additions & 1 deletion examples/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,11 @@ if(RESOLVE_USE_HIP)
# Build example with KLU factorization, rocsolver Rf refactorization, and FGMRES iterative refinement
add_executable(klu_rocsolverrf_fgmres.exe r_KLU_rocSolverRf_FGMRES.cpp)
target_link_libraries(klu_rocsolverrf_fgmres.exe PRIVATE ReSolve)

# Example in which factorization is redone if solution is bad
add_executable(klu_rocsolverrf_check_redo.exe r_KLU_rocsolverrf_redo_factorization.cpp)
target_link_libraries(klu_rocsolverrf_check_redo.exe PRIVATE ReSolve)

endif(RESOLVE_USE_HIP)

# Install all examples in bin directory
Expand All @@ -58,7 +63,7 @@ if(RESOLVE_USE_CUDA)
endif(RESOLVE_USE_CUDA)

if(RESOLVE_USE_HIP)
set(installable_executables ${installable_executables} klu_rocsolverrf.exe)
set(installable_executables ${installable_executables} klu_rocsolverrf.exe klu_rocsolverrf_fgmres.exe klu_rocsolverrf_check_redo.exe)
endif(RESOLVE_USE_HIP)

install(TARGETS ${installable_executables}
Expand Down
210 changes: 210 additions & 0 deletions examples/r_KLU_rocsolverrf_redo_factorization.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,210 @@
#include <string>
#include <iostream>
#include <iomanip>

#include <resolve/matrix/Coo.hpp>
#include <resolve/matrix/Csr.hpp>
#include <resolve/matrix/Csc.hpp>
#include <resolve/vector/Vector.hpp>
#include <resolve/matrix/io.hpp>
#include <resolve/matrix/MatrixHandler.hpp>
#include <resolve/vector/VectorHandler.hpp>
#include <resolve/LinSolverDirectKLU.hpp>
#include <resolve/LinSolverDirectRocSolverRf.hpp>
#include <resolve/workspace/LinAlgWorkspace.hpp>

using namespace ReSolve::constants;

int main(int argc, char *argv[] )
{
// Use the same data types as those you specified in ReSolve build.
using index_type = ReSolve::index_type;
using real_type = ReSolve::real_type;
using vector_type = ReSolve::vector::Vector;

(void) argc; // TODO: Check if the number of input parameters is correct.
std::string matrixFileName = argv[1];
std::string rhsFileName = argv[2];

index_type numSystems = atoi(argv[3]);
std::cout<<"Family mtx file name: "<< matrixFileName << ", total number of matrices: "<<numSystems<<std::endl;
std::cout<<"Family rhs file name: "<< rhsFileName << ", total number of RHSes: " << numSystems<<std::endl;

std::string fileId;
std::string rhsId;
std::string matrixFileNameFull;
std::string rhsFileNameFull;

ReSolve::matrix::Coo* A_coo;
ReSolve::matrix::Csr* A;

ReSolve::LinAlgWorkspaceHIP* workspace_HIP = new ReSolve::LinAlgWorkspaceHIP;
workspace_HIP->initializeHandles();
ReSolve::MatrixHandler* matrix_handler = new ReSolve::MatrixHandler(workspace_HIP);
ReSolve::VectorHandler* vector_handler = new ReSolve::VectorHandler(workspace_HIP);
real_type* rhs = nullptr;
real_type* x = nullptr;

vector_type* vec_rhs;
vector_type* vec_x;
vector_type* vec_r;

ReSolve::LinSolverDirectKLU* KLU = new ReSolve::LinSolverDirectKLU;
ReSolve::LinSolverDirectRocSolverRf* Rf = new ReSolve::LinSolverDirectRocSolverRf(workspace_HIP);

real_type res_nrm;
real_type b_nrm;

for (int i = 0; i < numSystems; ++i)
{
index_type j = 4 + i * 2;
fileId = argv[j];
rhsId = argv[j + 1];

matrixFileNameFull = "";
rhsFileNameFull = "";

// Read matrix first
matrixFileNameFull = matrixFileName + fileId + ".mtx";
rhsFileNameFull = rhsFileName + rhsId + ".mtx";
std::cout << std::endl << std::endl << std::endl;
std::cout << "========================================================================================================================"<<std::endl;
std::cout << "Reading: " << matrixFileNameFull << std::endl;
std::cout << "========================================================================================================================"<<std::endl;
std::cout << std::endl;
// Read first matrix
std::ifstream mat_file(matrixFileNameFull);
if(!mat_file.is_open())
{
std::cout << "Failed to open file " << matrixFileNameFull << "\n";
return -1;
}
std::ifstream rhs_file(rhsFileNameFull);
if(!rhs_file.is_open())
{
std::cout << "Failed to open file " << rhsFileNameFull << "\n";
return -1;
}
if (i == 0) {
A_coo = ReSolve::io::readMatrixFromFile(mat_file);
A = new ReSolve::matrix::Csr(A_coo->getNumRows(),
A_coo->getNumColumns(),
A_coo->getNnz(),
A_coo->symmetric(),
A_coo->expanded());

rhs = ReSolve::io::readRhsFromFile(rhs_file);
x = new real_type[A->getNumRows()];
vec_rhs = new vector_type(A->getNumRows());
vec_x = new vector_type(A->getNumRows());
vec_r = new vector_type(A->getNumRows());
}
else {
ReSolve::io::readAndUpdateMatrix(mat_file, A_coo);
ReSolve::io::readAndUpdateRhs(rhs_file, &rhs);
}
std::cout<<"Finished reading the matrix and rhs, size: "<<A->getNumRows()<<" x "<<A->getNumColumns()<< ", nnz: "<< A->getNnz()<< ", symmetric? "<<A->symmetric()<< ", Expanded? "<<A->expanded()<<std::endl;
mat_file.close();
rhs_file.close();

//Now convert to CSR.
if (i < 2) {
matrix_handler->coo2csr(A_coo, A, "cpu");
vec_rhs->update(rhs, ReSolve::memory::HOST, ReSolve::memory::HOST);
vec_rhs->setDataUpdated(ReSolve::memory::HOST);
} else {
matrix_handler->coo2csr(A_coo, A, "hip");
vec_rhs->update(rhs, ReSolve::memory::HOST, ReSolve::memory::DEVICE);
}
std::cout<<"COO to CSR completed. Expanded NNZ: "<< A->getNnzExpanded()<<std::endl;
//Now call direct solver
if (i == 0) {
KLU->setupParameters(1, 0.1, false);
}
int status;
if (i < 2){
KLU->setup(A);
status = KLU->analyze();
std::cout<<"KLU analysis status: "<<status<<std::endl;
status = KLU->factorize();
std::cout<<"KLU factorization status: "<<status<<std::endl;
status = KLU->solve(vec_rhs, vec_x);
std::cout<<"KLU solve status: "<<status<<std::endl;
if (i == 1) {
ReSolve::matrix::Csc* L = (ReSolve::matrix::Csc*) KLU->getLFactor();
ReSolve::matrix::Csc* U = (ReSolve::matrix::Csc*) KLU->getUFactor();
index_type* P = KLU->getPOrdering();
index_type* Q = KLU->getQOrdering();
vec_rhs->update(rhs, ReSolve::memory::HOST, ReSolve::memory::DEVICE);
Rf->setup(A, L, U, P, Q, vec_rhs);
Rf->refactorize();
}
} else {
std::cout<<"Using rocsolver rf"<<std::endl;
status = Rf->refactorize();
std::cout<<"rocsolver rf refactorization status: "<<status<<std::endl;
status = Rf->solve(vec_rhs, vec_x);
std::cout<<"rocsolver rf solve status: "<<status<<std::endl;
}
vec_r->update(rhs, ReSolve::memory::HOST, ReSolve::memory::DEVICE);

matrix_handler->setValuesChanged(true, "hip");

matrix_handler->matvec(A, vec_x, vec_r, &ONE, &MINUSONE,"csr", "hip");
res_nrm = sqrt(vector_handler->dot(vec_r, vec_r, "hip"));
b_nrm = sqrt(vector_handler->dot(vec_rhs, vec_rhs, "hip"));
std::cout << "\t 2-Norm of the residual: "
<< std::scientific << std::setprecision(16)
<< res_nrm/b_nrm << "\n";
if (!isnan(res_nrm)) {
if (res_nrm/b_nrm > 1e-7 ) {
std::cout << "\n \t !!! ALERT !!! Residual norm is too large; redoing KLU symbolic and numeric factorization. !!! ALERT !!! \n \n";

KLU->setup(A);
status = KLU->analyze();
std::cout<<"KLU analysis status: "<<status<<std::endl;
status = KLU->factorize();
std::cout<<"KLU factorization status: "<<status<<std::endl;
status = KLU->solve(vec_rhs, vec_x);
std::cout<<"KLU solve status: "<<status<<std::endl;

vec_rhs->update(rhs, ReSolve::memory::HOST, ReSolve::memory::DEVICE);
vec_r->update(rhs, ReSolve::memory::HOST, ReSolve::memory::DEVICE);

matrix_handler->setValuesChanged(true, "hip");

matrix_handler->matvec(A, vec_x, vec_r, &ONE, &MINUSONE,"csr", "hip");
res_nrm = sqrt(vector_handler->dot(vec_r, vec_r, "hip"));

std::cout<<"\t New residual norm: "
<< std::scientific << std::setprecision(16)
<< res_nrm/b_nrm << "\n";


ReSolve::matrix::Csc* L = (ReSolve::matrix::Csc*) KLU->getLFactor();
ReSolve::matrix::Csc* U = (ReSolve::matrix::Csc*) KLU->getUFactor();

index_type* P = KLU->getPOrdering();
index_type* Q = KLU->getQOrdering();

Rf->setup(A, L, U, P, Q, vec_rhs);
}
}


} // for (int i = 0; i < numSystems; ++i)

//now DELETE
delete A;
delete A_coo;
delete KLU;
delete Rf;
delete [] x;
delete [] rhs;
delete vec_r;
delete vec_x;
delete workspace_HIP;
delete matrix_handler;
delete vector_handler;
return 0;
}
48 changes: 48 additions & 0 deletions resolve/LinSolverDirectKLU.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ namespace ReSolve
{
Symbolic_ = nullptr;
Numeric_ = nullptr;

L_ = nullptr;
U_ = nullptr;

klu_defaults(&Common_) ;
}

Expand Down Expand Up @@ -40,8 +44,24 @@ namespace ReSolve

int LinSolverDirectKLU::analyze()
{
// in case we called this function AGAIN
if (Symbolic_ != nullptr) {
klu_free_symbolic(&Symbolic_, &Common_);
}

Symbolic_ = klu_analyze(A_->getNumRows(), A_->getRowData(memory::HOST), A_->getColData(memory::HOST), &Common_) ;

factors_extracted_ = false;
if (L_ != nullptr) {
delete L_;
L_ = nullptr;
}

if (U_ != nullptr) {
delete U_;
U_ = nullptr;
}

if (Symbolic_ == nullptr){
printf("Symbolic_ factorization crashed withCommon_.status = %d \n", Common_.status);
return 1;
Expand All @@ -51,8 +71,24 @@ namespace ReSolve

int LinSolverDirectKLU::factorize()
{
if (Numeric_ != nullptr) {
klu_free_numeric(&Numeric_, &Common_);
}

Numeric_ = klu_factor(A_->getRowData(memory::HOST), A_->getColData(memory::HOST), A_->getValues(memory::HOST), Symbolic_, &Common_);

factors_extracted_ = false;

if (L_ != nullptr) {
delete L_;
L_ = nullptr;
}

if (U_ != nullptr) {
delete U_;
U_ = nullptr;
}

if (Numeric_ == nullptr){
return 1;
}
Expand All @@ -63,6 +99,18 @@ namespace ReSolve
{
int kluStatus = klu_refactor (A_->getRowData(memory::HOST), A_->getColData(memory::HOST), A_->getValues(memory::HOST), Symbolic_, Numeric_, &Common_);

factors_extracted_ = false;

if (L_ != nullptr) {
delete L_;
L_ = nullptr;
}

if (U_ != nullptr) {
delete U_;
U_ = nullptr;
}

if (!kluStatus){
//display error
return 1;
Expand Down
41 changes: 32 additions & 9 deletions resolve/LinSolverDirectRocSolverRf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,19 @@ namespace ReSolve
//set matrix info
rocsolver_create_rfinfo(&infoM_, workspace_->getRocblasHandle());
//create combined factor
addFactors(L,U);

addFactors(L, U);

M_->setUpdated(ReSolve::memory::HOST);
M_->copyData(ReSolve::memory::DEVICE);
mem_.allocateArrayOnDevice(&d_P_, n);
mem_.allocateArrayOnDevice(&d_Q_, n);

if (d_P_ == nullptr) {
mem_.allocateArrayOnDevice(&d_P_, n);
}

if (d_Q_ == nullptr) {
mem_.allocateArrayOnDevice(&d_Q_, n);
}
mem_.copyArrayHostToDevice(d_P_, P, n);
mem_.copyArrayHostToDevice(d_Q_, Q, n);

Expand All @@ -70,12 +77,22 @@ namespace ReSolve

// tri solve setup
if (solve_mode_ == 1) { // fast mode
L_csr_ = new ReSolve::matrix::Csr(L->getNumRows(), L->getNumColumns(), L->getNnz());
U_csr_ = new ReSolve::matrix::Csr(U->getNumRows(), U->getNumColumns(), U->getNnz());

if (L_csr_ != nullptr) {
delete L_csr_;
}

L_csr_ = new ReSolve::matrix::Csr(L->getNumRows(), L->getNumColumns(), L->getNnz());
L_csr_->allocateMatrixData(ReSolve::memory::DEVICE);

if (U_csr_ != nullptr) {
delete U_csr_;
}

U_csr_ = new ReSolve::matrix::Csr(U->getNumRows(), U->getNumColumns(), U->getNnz());
U_csr_->allocateMatrixData(ReSolve::memory::DEVICE);


rocsparse_create_mat_descr(&(descr_L_));
rocsparse_set_mat_fill_mode(descr_L_, rocsparse_fill_mode_lower);
rocsparse_set_mat_index_base(descr_L_, rocsparse_index_base_zero);
Expand Down Expand Up @@ -161,9 +178,12 @@ namespace ReSolve
error_sum += status_rocsparse_;
if (status_rocsparse_!=0)printf("status after analysis 2 %d \n", status_rocsparse_);
//allocate aux data

mem_.allocateArrayOnDevice(&d_aux1_,n);
mem_.allocateArrayOnDevice(&d_aux2_,n);
if (d_aux1_ == nullptr) {
mem_.allocateArrayOnDevice(&d_aux1_,n);
}
if (d_aux2_ == nullptr) {
mem_.allocateArrayOnDevice(&d_aux2_,n);
}

}
return error_sum;
Expand Down Expand Up @@ -193,7 +213,7 @@ namespace ReSolve

if (solve_mode_ == 1) {
//split M, fill L and U with correct values
printf("solve mode 1, splitting the factors again \n");
printf("solve mode 1, splitting the factors again \n");
status_rocblas_ = rocsolver_dcsrrf_splitlu(workspace_->getRocblasHandle(),
A_->getNumRows(),
M_->getNnzExpanded(),
Expand Down Expand Up @@ -360,6 +380,9 @@ printf("solve mode 1, splitting the factors again \n");
index_type* Li = L->getRowData(ReSolve::memory::HOST);
index_type* Up = U->getColData(ReSolve::memory::HOST);
index_type* Ui = U->getRowData(ReSolve::memory::HOST);
if (M_ != nullptr) {
delete M_;
}

index_type nnzM = ( L->getNnz() + U->getNnz() - n );
M_ = new matrix::Csr(n, n, nnzM);
Expand Down
Loading

0 comments on commit d1e82eb

Please sign in to comment.