diff --git a/src/cyclic_reduction.jl b/src/cyclic_reduction.jl index df312c2..de5caff 100644 --- a/src/cyclic_reduction.jl +++ b/src/cyclic_reduction.jl @@ -1,67 +1,28 @@ -using LinearAlgebra.BLAS: gemm! - -if VERSION < v"1.9.0-" - @show VERSION - @eval begin """ CRSolverWs Workspace used for solving with the cyclic reduction algorithm of [Bini et al.](https://link.springer.com/article/10.1007/s11075-008-9253-0). """ - mutable struct CRSolverWs{T, WS} <: Workspace - linsolve_ws::WS - ahat1::Matrix{T} - a1copy::Matrix{T} - x::Matrix{T} - m::Matrix{T} - m1::Matrix{T} - m2::Matrix{T} - end - - function CRSolverWs(a0::AbstractMatrix{T}) where {T<:AbstractFloat} - n = size(a0,1) - linsolve_ws = LUWs(n) - ahat1 = Matrix{T}(undef, n,n) - a1copy = Matrix{T}(undef, n,n) - m = Matrix{T}(undef, 2*n,2*n) - m1 = Matrix{T}(undef, n, 2*n) - m2 = Matrix{T}(undef, 2*n, n) - x = Matrix{T}(undef, n,n) - CRSolverWs(linsolve_ws, ahat1, a1copy, x, m, m1, m2) - end - - end - @show methods(CRSolverWs) -else - - @eval begin -""" - CRSolverWs +mutable struct CRSolverWs{T, WS} <: Workspace + linsolve_ws::WS + ahat1::Matrix{T} + a1copy::Matrix{T} + x::Matrix{T} + m::Matrix{T} + m1::Matrix{T} + m2::Matrix{T} +end -Workspace used for solving with the cyclic reduction algorithm of [Bini et al.](https://link.springer.com/article/10.1007/s11075-008-9253-0). -""" - mutable struct CRSolverWs{T, WS, MT<:AbstractMatrix{T}} <: Workspace - linsolve_ws::WS - ahat1::Matrix{T} - a1copy::Matrix{T} - x::MT - m::Matrix{T} - m1::Matrix{T} - m2::Matrix{T} - end - - function CRSolverWs(a0::AbstractMatrix{T}) where {T<:AbstractFloat} - n = size(a0,1) - linsolve_ws = LUWs(n) - ahat1 = Matrix{T}(undef, n,n) - a1copy = Matrix{T}(undef, n,n) - m = Matrix{T}(undef, 2*n,2*n) - m1 = Matrix{T}(undef, n, 2*n) - m2 = Matrix{T}(undef, 2*n, n) - CRSolverWs(linsolve_ws, ahat1, a1copy, similar(a0), m, m1, m2) - end - end - +function CRSolverWs(a0::AbstractMatrix{T}) where {T<:AbstractFloat} + n = size(a0,1) + linsolve_ws = LUWs(n) + ahat1 = Matrix{T}(undef, n,n) + a1copy = Matrix{T}(undef, n,n) + m = Matrix{T}(undef, 2*n,2*n) + m1 = Matrix{T}(undef, n, 2*n) + m2 = Matrix{T}(undef, 2*n, n) + x = Matrix{T}(undef, n,n) + CRSolverWs(linsolve_ws, ahat1, a1copy, x, m, m1, m2) end function solve!(ws::CRSolverWs{T}, a0::Matrix{T}, a1::Matrix{T}, a2::Matrix{T}; @@ -324,7 +285,7 @@ function solve!(ws::CRSolverWs, a0::SparseMatrixCSC, a1_::AbstractMatrix, a2::Sp lu_t = LU(factorize!(ws.linsolve_ws, ws.ahat1)...) ldiv!(lu_t, x) @inbounds lmul!(-1.0, x) - return x + return sparse(x) end function check_convergence!(x, it, crit1, crit2, m1, tolerance, max_iterations)