Skip to content

Commit 32722b7

Browse files
committed
circulant implementation to obtain structure of CGMRF
1 parent db3d2f8 commit 32722b7

File tree

7 files changed

+61
-6
lines changed

7 files changed

+61
-6
lines changed

README.md

+4
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,7 @@ Implementation of GMRF for spatial analysis.
1414
## Graphs
1515

1616
![graphs](figures/03-graph.png)
17+
18+
## Todo
19+
20+
- [ ] Check structure_base to return always `Integers`

src/GMRFs.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
module GMRFs
22

33
using Meshes
4-
import SparseArrays: sparse, sparsevec, spdiagm, spzeros
4+
import SparseArrays: sparse, sparsevec, spdiagm, spzeros, findnz, SparseVector, SparseMatrixCSC
55
import SuiteSparse # this is only for fixes
66
import FFTW
77
import Distributions: InverseGamma, Gamma, Distributions

src/gmrf/cgmrf.jl

+10-1
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,16 @@ CGMRF(domain::CartesianGrid, order::Integer, κ::Real, δ::Real) =
3131
Base.length(d::CGMRF) = length(d.base)
3232
scale(d::CGMRF) = d.κ
3333
structure_base(d::CGMRF) = d.base
34-
# structure(d::CGMRF) = structure(d.g; δ = d.δ, order = d.order, circular = true)
34+
35+
function structure(d::CGMRF)
36+
# get base and transpose for 2d
37+
base = structure_base(d)
38+
if base isa SparseMatrixCSC
39+
base = sparse(base')
40+
end
41+
# convert to circulant
42+
spcirculant(base)
43+
end
3544

3645
## Random generator
3746

src/utils.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
for filename in ["adjacency.jl", "difference.jl", "structure.jl"]
1+
for filename in ["adjacency.jl", "difference.jl", "structure.jl", "circulant.jl"]
22
include(joinpath("utils", filename))
33
end
44

src/utils/circulant.jl

+40
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
"""
2+
spcirculant(x)
3+
4+
Return a sparse circulant or block-circulant matrix based on the input sparse vector or
5+
matrix `x`.
6+
7+
The size of the resulting matrix is `n×n`, where `n` represents the total number of
8+
elements in `x`. If `x` is a `SparseVector`, the function generates a `n×n` sparse
9+
circulant matrix. Conversely, if `x` is a `SparseMatrixCSC`, the function creates a `n×n`
10+
sparse block circulant matrix.
11+
"""
12+
13+
function spcirculant(x::SparseVector)
14+
# get indices and values
15+
n = length(x)
16+
(is, vals) = findnz(x)
17+
18+
# compute the indices for the full structure matrix
19+
Is = repeat([i for i in 1:n], length(vals))
20+
Js = [mod1(i -is[k]+1, n) for i in 1:n, k in 1:length(vals)]
21+
Vs = repeat(vals, inner = n)
22+
23+
# spare structure matrix
24+
sparse(Is, vec(Js), Vs, n, n)
25+
end
26+
27+
function spcirculant(x::SparseMatrixCSC)
28+
# get base and values
29+
(n1,n2) = size(x)
30+
(is, js, vals) = findnz(x)
31+
32+
# compute the indices for the full structure matrix
33+
Is = repeat(vec([i + n1 * (b-1) for i in 1:n1, b in 1:n2]), length(vals))
34+
Js = [mod1(n1*(js[k]-1) + n1 * (b-1) + mod1(i -is[k]+1, n1), n1*n2) for i in 1:n1, b in 1:n2, k in 1:length(vals)]
35+
Vs = repeat(vals, inner = n1 * n2)
36+
37+
# spare structure matrix
38+
sparse(Is, vec(Js), Vs, n1 * n2, n1 * n2)
39+
end
40+

test/cgmrf.jl

+2
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
@test length(X) == n
1515
@test GMRFs.scale(X) == kappa[i]
1616
@test GMRFs.structure_base(X) == GMRFs.structure_base(grid, order = order[i], δ = δ[i])
17+
@test GMRFs.structure(X) == GMRFs.structure(grid, δ = δ[i], order = order[i], circular = true)
1718
# rand and logpdf: single
1819
x = rand(X)
1920
@test length(x) == n
@@ -41,6 +42,7 @@ end
4142
@test length(X) == n1 * n2
4243
@test GMRFs.scale(X) == kappa[i]
4344
@test GMRFs.structure_base(X) == GMRFs.structure_base(grid, order = order[i], δ = δ[i])
45+
@test GMRFs.structure(X) == GMRFs.structure(grid, δ = δ[i], order = order[i], circular = true)
4446
# rand and logpdf: single
4547
x = rand(X)
4648
@test length(x) == n1 * n2

test/utils.jl

+3-3
Original file line numberDiff line numberDiff line change
@@ -112,12 +112,12 @@ end
112112
# circular: structure
113113
S = GMRFs.structure(grid, order = 1, circular = true)
114114
@test S == GMRFs.structure(graphc, order = 1)
115-
@test S[1, :] == vec(GMRFs.structure_base(grid, order = 1))
115+
@test S[1, :] == vec(GMRFs.structure_base(grid, order = 1)')
116116
S = GMRFs.structure(grid, order = 2, circular = true)
117117
@test S == GMRFs.structure(graphc, order = 2)
118-
@test S[1, :] == vec(GMRFs.structure_base(grid, order = 2))
118+
@test S[1, :] == vec(GMRFs.structure_base(grid, order = 2)')
119119
S = GMRFs.structure(grid, order = 3, circular = true)
120120
@test S == GMRFs.structure(graphc, order = 3)
121-
@test S[1, :] == vec(GMRFs.structure_base(grid, order = 3))
121+
@test S[1, :] == vec(GMRFs.structure_base(grid, order = 3)')
122122

123123
end

0 commit comments

Comments
 (0)