Skip to content

Commit af740ed

Browse files
authored
Support algebra operations on expansions (#54)
* Support algebra operations on expansions * Start new simplify * Tests pass! * v0.3 * Remove \ simplify * * for AlephInfinity * Update Project.toml * Update runtests.jl * Use new simplifiable from LazyArrays.jl * Update ContinuumArrays.jl * Update Project.toml * Update bases.jl * Update Project.toml * Update bases.jl * require Julia v1.5 * Update bases.jl * Increase coverage
1 parent c0d13f4 commit af740ed

File tree

7 files changed

+241
-170
lines changed

7 files changed

+241
-170
lines changed

.travis.yml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,7 @@ os:
55
- osx
66
- windows
77
julia:
8-
- 1.3
9-
- 1.4
8+
- 1.5
109
- nightly
1110
matrix:
1211
allow_failures:

Project.toml

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,28 @@
11
name = "ContinuumArrays"
22
uuid = "7ae1f121-cc2c-504b-ac30-9b923412ae5c"
3-
version = "0.2.5"
3+
version = "0.3.0"
44

55
[deps]
66
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
77
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
8+
FastTransforms = "057dd010-8810-581a-b7be-e3fc3b93f78c"
89
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
10+
InfiniteArrays = "4858937d-0d70-526a-a4dd-2d5cb5dd786c"
911
IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953"
1012
LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02"
13+
LazyBandedMatrices = "d7e5e226-e90b-4449-9968-0f923699bf6f"
1114
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1215
QuasiArrays = "c4ea9172-b204-11e9-377d-29865faadc5c"
1316

1417
[compat]
15-
ArrayLayouts = "0.2.4, 0.3"
16-
BandedMatrices = "0.15"
17-
FillArrays = "0.8.2, 0.9"
18+
ArrayLayouts = "0.4.3"
19+
BandedMatrices = "0.15.17"
20+
FillArrays = "0.9.3"
21+
InfiniteArrays = "0.8"
1822
IntervalSets = "0.3.2, 0.4, 0.5"
19-
LazyArrays = "0.16"
20-
QuasiArrays = "0.2.2"
21-
julia = "1.3"
23+
LazyArrays = "0.17.1"
24+
QuasiArrays = "0.3"
25+
julia = "1.5"
2226

2327
[extras]
2428
FastTransforms = "057dd010-8810-581a-b7be-e3fc3b93f78c"

src/ContinuumArrays.jl

Lines changed: 45 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,41 +1,75 @@
11
module ContinuumArrays
2-
using IntervalSets, LinearAlgebra, LazyArrays, FillArrays, BandedMatrices, QuasiArrays
2+
using IntervalSets, LinearAlgebra, LazyArrays, FillArrays, BandedMatrices, QuasiArrays, InfiniteArrays
33
import Base: @_inline_meta, @_propagate_inbounds_meta, axes, getindex, convert, prod, *, /, \, +, -, ==,
44
IndexStyle, IndexLinear, ==, OneTo, tail, similar, copyto!, copy, diff,
55
first, last, show, isempty, findfirst, findlast, findall, Slice, union, minimum, maximum, sum, _sum,
6-
getproperty
6+
getproperty, isone, iszero, zero, abs, <, , >, , string
77
import Base.Broadcast: materialize, BroadcastStyle, broadcasted
8-
import LazyArrays: MemoryLayout, Applied, ApplyStyle, flatten, _flatten, colsupport,
9-
adjointlayout, LdivApplyStyle, arguments, _arguments, call, broadcastlayout, layout_getindex,
10-
sublayout, sub_materialize, ApplyLayout, BroadcastLayout, combine_mul_styles
8+
import LazyArrays: MemoryLayout, Applied, ApplyStyle, flatten, _flatten, colsupport, most, combine_mul_styles, AbstractArrayApplyStyle,
9+
adjointlayout, arguments, _mul_arguments, call, broadcastlayout, layout_getindex,
10+
sublayout, sub_materialize, ApplyLayout, BroadcastLayout, combine_mul_styles, applylayout,
11+
simplifiable, _simplify
1112
import LinearAlgebra: pinv
1213
import BandedMatrices: AbstractBandedLayout, _BandedMatrix
1314
import FillArrays: AbstractFill, getindex_value, SquareEye
14-
15+
import ArrayLayouts: mul
1516
import QuasiArrays: cardinality, checkindex, QuasiAdjoint, QuasiTranspose, Inclusion, SubQuasiArray,
16-
QuasiDiagonal, MulQuasiArray, MulQuasiMatrix, MulQuasiVector, QuasiMatMulMat, quasimulapplystyle,
17-
ApplyQuasiArray, ApplyQuasiMatrix, LazyQuasiArrayApplyStyle, AbstractQuasiArrayApplyStyle,
18-
LazyQuasiArray, LazyQuasiVector, LazyQuasiMatrix, LazyLayout, LazyQuasiArrayStyle, quasildivapplystyle, _factorize
17+
QuasiDiagonal, MulQuasiArray, MulQuasiMatrix, MulQuasiVector, QuasiMatMulMat,
18+
ApplyQuasiArray, ApplyQuasiMatrix, LazyQuasiArrayApplyStyle, AbstractQuasiArrayApplyStyle, AbstractQuasiLazyLayout,
19+
LazyQuasiArray, LazyQuasiVector, LazyQuasiMatrix, LazyLayout, LazyQuasiArrayStyle, _factorize
20+
import InfiniteArrays: Infinity
1921

20-
export Spline, LinearSpline, HeavisideSpline, DiracDelta, Derivative, fullmaterialize, ℵ₁, Inclusion, Basis, WeightedBasis, grid, transform, affine
22+
export Spline, LinearSpline, HeavisideSpline, DiracDelta, Derivative, ℵ₁, Inclusion, Basis, WeightedBasis, grid, transform, affine
2123

2224
####
2325
# Interval indexing support
2426
####
2527
struct AlephInfinity{N} <: Integer end
2628

29+
isone(::AlephInfinity) = false
30+
iszero(::AlephInfinity) = false
31+
2732
==(::AlephInfinity, ::Int) = false
2833
==(::Int, ::AlephInfinity) = false
2934

3035
*(::AlephInfinity{N}, ::AlephInfinity{N}) where N = AlephInfinity{N}()
36+
*(::AlephInfinity{N}, ::Infinity) where N = AlephInfinity{N}()
37+
*(::Infinity, ::AlephInfinity{N}) where N = AlephInfinity{N}()
38+
function *(a::Integer, b::AlephInfinity)
39+
a > 0 || throw(ArgumentError("$a is negative"))
40+
b
41+
end
42+
43+
*(a::AlephInfinity, b::Integer) = b*a
44+
45+
46+
abs(a::AlephInfinity) = a
47+
zero(::AlephInfinity) = 0
48+
49+
for OP in (:<, :)
50+
@eval begin
51+
$OP(::Real, ::AlephInfinity) = true
52+
$OP(::AlephInfinity, ::Real) = false
53+
end
54+
end
55+
56+
for OP in (:>, :)
57+
@eval begin
58+
$OP(::Real, ::AlephInfinity) = false
59+
$OP(::AlephInfinity, ::Real) = true
60+
end
61+
end
62+
3163

3264
const ℵ₁ = AlephInfinity{1}()
3365

66+
string(::AlephInfinity{1}) = "ℵ₁"
67+
3468
show(io::IO, F::AlephInfinity{1}) where N =
3569
print(io, "ℵ₁")
3670

3771

38-
const QMul2{A,B} = Mul{<:AbstractQuasiArrayApplyStyle, <:Tuple{A,B}}
72+
const QMul2{A,B} = Mul{<:Any, <:Any, <:A,<:B}
3973
const QMul3{A,B,C} = Mul{<:AbstractQuasiArrayApplyStyle, <:Tuple{A,B,C}}
4074

4175
cardinality(::AbstractInterval) = ℵ₁

src/bases/bases.jl

Lines changed: 66 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,14 @@ abstract type Weight{T} <: LazyQuasiVector{T} end
44

55
const WeightedBasis{T, A<:AbstractQuasiVector, B<:Basis} = BroadcastQuasiMatrix{T,typeof(*),<:Tuple{A,B}}
66

7-
struct WeightLayout <: MemoryLayout end
8-
abstract type AbstractBasisLayout <: MemoryLayout end
7+
struct WeightLayout <: AbstractQuasiLazyLayout end
8+
abstract type AbstractBasisLayout <: AbstractQuasiLazyLayout end
99
struct BasisLayout <: AbstractBasisLayout end
1010
struct SubBasisLayout <: AbstractBasisLayout end
1111
struct MappedBasisLayout <: AbstractBasisLayout end
1212
struct WeightedBasisLayout <: AbstractBasisLayout end
1313

14-
abstract type AbstractAdjointBasisLayout <: MemoryLayout end
14+
abstract type AbstractAdjointBasisLayout <: AbstractQuasiLazyLayout end
1515
struct AdjointBasisLayout <: AbstractAdjointBasisLayout end
1616
struct AdjointSubBasisLayout <: AbstractAdjointBasisLayout end
1717
struct AdjointMappedBasisLayout <: AbstractAdjointBasisLayout end
@@ -25,9 +25,7 @@ adjointlayout(::Type, ::MappedBasisLayout) = AdjointMappedBasisLayout()
2525
broadcastlayout(::Type{typeof(*)}, ::WeightLayout, ::BasisLayout) = WeightedBasisLayout()
2626
broadcastlayout(::Type{typeof(*)}, ::WeightLayout, ::SubBasisLayout) = WeightedBasisLayout()
2727

28-
combine_mul_styles(::AbstractBasisLayout) = LazyQuasiArrayApplyStyle()
29-
combine_mul_styles(::AbstractAdjointBasisLayout) = LazyQuasiArrayApplyStyle()
30-
28+
# Default is lazy
3129
ApplyStyle(::typeof(pinv), ::Type{<:Basis}) = LazyQuasiArrayApplyStyle()
3230
pinv(J::Basis) = apply(pinv,J)
3331

@@ -37,10 +35,6 @@ function ==(A::Basis, B::Basis)
3735
false
3836
end
3937

40-
@inline quasildivapplystyle(::AbstractBasisLayout, ::AbstractBasisLayout) = LdivApplyStyle()
41-
@inline quasildivapplystyle(::AbstractBasisLayout, _) = LdivApplyStyle()
42-
@inline quasildivapplystyle(_, ::AbstractBasisLayout) = LdivApplyStyle()
43-
4438

4539
@inline copy(L::Ldiv{<:AbstractBasisLayout,BroadcastLayout{typeof(+)}}) = +(broadcast(\,Ref(L.A),arguments(L.B))...)
4640
@inline copy(L::Ldiv{<:AbstractBasisLayout,BroadcastLayout{typeof(+)},<:Any,<:AbstractQuasiVector}) =
@@ -138,7 +132,7 @@ copy(L::Ldiv{<:AbstractBasisLayout,ApplyLayout{typeof(*)},<:Any,<:AbstractQuasiV
138132
transform_ldiv(L.A, L.B)
139133

140134
function copy(L::Ldiv{ApplyLayout{typeof(*)},<:AbstractBasisLayout})
141-
args = arguments(L.A)
135+
args = arguments(ApplyLayout{typeof(*)}(), L.A)
142136
@assert length(args) == 2 # temporary
143137
apply(\, last(args), apply(\, first(args), L.B))
144138
end
@@ -149,6 +143,55 @@ function copy(L::Ldiv{<:AbstractBasisLayout,BroadcastLayout{typeof(*)},<:Abstrac
149143
T \ L.B[p]
150144
end
151145

146+
147+
##
148+
# Algebra
149+
##
150+
151+
# struct ExpansionLayout <: MemoryLayout end
152+
# applylayout(::Type{typeof(*)}, ::BasisLayout, _) = ExpansionLayout()
153+
154+
const Expansion{T,Space<:Basis,Coeffs<:AbstractVector} = ApplyQuasiVector{T,typeof(*),<:Tuple{Space,Coeffs}}
155+
156+
basis(v::AbstractQuasiVector) = v.args[1]
157+
158+
for op in (:*, :\)
159+
@eval function broadcasted(::LazyQuasiArrayStyle{1}, ::typeof($op), x::Number, f::Expansion)
160+
S,c = arguments(f)
161+
S * broadcast($op, x, c)
162+
end
163+
end
164+
for op in (:*, :/)
165+
@eval function broadcasted(::LazyQuasiArrayStyle{1}, ::typeof($op), f::Expansion, x::Number)
166+
S,c = arguments(f)
167+
S * broadcast($op, c, x)
168+
end
169+
end
170+
171+
172+
function broadcastbasis(::typeof(+), a, b)
173+
a b && error("Overload broadcastbasis(::typeof(+), ::$(typeof(a)), ::$(typeof(b)))")
174+
a
175+
end
176+
177+
broadcastbasis(::typeof(-), a, b) = broadcastbasis(+, a, b)
178+
179+
for op in (:+, :-)
180+
@eval function broadcasted(::LazyQuasiArrayStyle{1}, ::typeof($op), f::Expansion, g::Expansion)
181+
S,c = arguments(f)
182+
T,d = arguments(g)
183+
ST = broadcastbasis($op, S, T)
184+
ST * $op((ST \ S) * c , (ST \ T) * d)
185+
end
186+
end
187+
188+
@eval function ==(f::Expansion, g::Expansion)
189+
S,c = arguments(f)
190+
T,d = arguments(g)
191+
ST = broadcastbasis(+, S, T)
192+
(ST \ S) * c == (ST \ T) * d
193+
end
194+
152195
## materialize views
153196

154197
# materialize(S::SubQuasiArray{<:Any,2,<:ApplyQuasiArray{<:Any,2,typeof(*),<:Tuple{<:Basis,<:Any}}}) =
@@ -164,9 +207,8 @@ end
164207
_sub_getindex(A, kr, jr) = A[kr, jr]
165208
_sub_getindex(A, ::Slice, ::Slice) = A
166209

167-
function copy(M::QMul2{<:QuasiAdjoint{<:Any,<:SubQuasiArray{<:Any,2,<:AbstractQuasiMatrix,<:Tuple{<:AbstractAffineQuasiVector,<:Any}}},
168-
<:SubQuasiArray{<:Any,2,<:AbstractQuasiMatrix,<:Tuple{<:AbstractAffineQuasiVector,<:Any}}})
169-
Ac, B = M.args
210+
@simplify function *(Ac::QuasiAdjoint{<:Any,<:SubQuasiArray{<:Any,2,<:AbstractQuasiMatrix,<:Tuple{<:AbstractAffineQuasiVector,<:Any}}},
211+
B::SubQuasiArray{<:Any,2,<:AbstractQuasiMatrix,<:Tuple{<:AbstractAffineQuasiVector,<:Any}})
170212
A = Ac'
171213
PA,PB = parent(A),parent(B)
172214
kr,jr = parentindices(B)
@@ -175,14 +217,12 @@ end
175217

176218

177219
# Differentiation of sub-arrays
178-
function copy(M::QMul2{<:Derivative,<:SubQuasiArray{<:Any,2,<:AbstractQuasiMatrix,<:Tuple{<:Inclusion,<:Any}}})
179-
A, B = M.args
220+
@simplify function *(A::Derivative, B::SubQuasiArray{<:Any,2,<:AbstractQuasiMatrix,<:Tuple{<:Inclusion,<:Any}})
180221
P = parent(B)
181222
(Derivative(axes(P,1))*P)[parentindices(B)...]
182223
end
183224

184-
function copy(M::QMul2{<:Derivative,<:SubQuasiArray{<:Any,2,<:AbstractQuasiMatrix,<:Tuple{<:AbstractAffineQuasiVector,<:Any}}})
185-
A, B = M.args
225+
@simplify function *(A::Derivative, B::SubQuasiArray{<:Any,2,<:AbstractQuasiMatrix,<:Tuple{<:AbstractAffineQuasiVector,<:Any}})
186226
P = parent(B)
187227
kr,jr = parentindices(B)
188228
(Derivative(axes(P,1))*P*kr.A)[kr,jr]
@@ -220,15 +260,18 @@ end
220260
# SubLayout behaves like ApplyLayout{typeof(*)}
221261

222262
combine_mul_styles(::SubBasisLayout) = combine_mul_styles(ApplyLayout{typeof(*)}())
223-
_arguments(::SubBasisLayout, A) = _arguments(ApplyLayout{typeof(*)}(), A)
263+
_mul_arguments(::SubBasisLayout, A) = _mul_arguments(ApplyLayout{typeof(*)}(), A)
264+
arguments(::SubBasisLayout, A) = arguments(ApplyLayout{typeof(*)}(), A)
224265
call(::SubBasisLayout, ::SubQuasiArray) = *
225266

226267
combine_mul_styles(::AdjointSubBasisLayout) = combine_mul_styles(ApplyLayout{typeof(*)}())
227-
_arguments(::AdjointSubBasisLayout, A) = _arguments(ApplyLayout{typeof(*)}(), A)
268+
_mul_arguments(::AdjointSubBasisLayout, A) = _mul_arguments(ApplyLayout{typeof(*)}(), A)
228269
arguments(::AdjointSubBasisLayout, A) = arguments(ApplyLayout{typeof(*)}(), A)
229270
call(::AdjointSubBasisLayout, ::SubQuasiArray) = *
230271

231-
function arguments(V::SubQuasiArray{<:Any,2,<:Any,<:Tuple{<:Inclusion,<:AbstractUnitRange}})
272+
copy(M::Mul{AdjointSubBasisLayout,SubBasisLayout}) = apply(*, arguments(M.A)..., arguments(M.B)...)
273+
274+
function arguments(::ApplyLayout{typeof(*)}, V::SubQuasiArray{<:Any,2,<:Any,<:Tuple{<:Inclusion,<:AbstractUnitRange}})
232275
A = parent(V)
233276
_,jr = parentindices(V)
234277
first(jr) 1 || throw(BoundsError())
@@ -249,8 +292,8 @@ function __sum(::SubBasisLayout, Vm, dims)
249292
@assert dims == 1
250293
sum(parent(Vm); dims=dims)[:,parentindices(Vm)[2]]
251294
end
252-
function __sum(::ApplyLayout{typeof(*)}, V::AbstractQuasiVector, ::Colon)
253-
a = arguments(V)
295+
function __sum(LAY::ApplyLayout{typeof(*)}, V::AbstractQuasiVector, ::Colon)
296+
a = arguments(LAY, V)
254297
first(apply(*, sum(a[1]; dims=1), tail(a)...))
255298
end
256299

src/bases/splines.jl

Lines changed: 19 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -48,22 +48,19 @@ end
4848

4949

5050
## Mass matrix
51-
52-
ApplyStyle(::typeof(*), ::Type{<:QuasiAdjoint{<:Any,<:LinearSpline}}, ::Type{<:LinearSpline}) =
53-
SimplifyStyle()
54-
55-
5651
function similar(AB::QMul2{<:QuasiAdjoint{<:Any,<:LinearSpline},<:LinearSpline}, ::Type{T}) where T
5752
n = size(AB,1)
5853
SymTridiagonal(Vector{T}(undef, n), Vector{T}(undef, n-1))
5954
end
6055
#
61-
copy(M::QMul2{<:QuasiAdjoint{<:Any,<:LinearSpline},<:LinearSpline}) =
56+
@simplify function *(Ac::QuasiAdjoint{<:Any,<:LinearSpline}, B::LinearSpline)
57+
M = Mul(Ac, B)
6258
copyto!(similar(M, eltype(M)), M)
59+
end
6360

6461
function copyto!(dest::SymTridiagonal,
6562
AB::QMul2{<:QuasiAdjoint{<:Any,<:LinearSpline},<:LinearSpline}) where T
66-
Ac,B = AB.args
63+
Ac,B = AB.A,AB.B
6764
A = parent(Ac)
6865
A.points == B.points || throw(ArgumentError())
6966
dv,ev = dest.dv,dest.ev
@@ -92,13 +89,9 @@ end
9289

9390

9491
## Derivative
95-
ApplyStyle(::typeof(*), ::Type{<:Derivative}, ::Type{<:LinearSpline}) = SimplifyStyle()
96-
97-
98-
9992
function copyto!(dest::MulQuasiMatrix{<:Any,<:Tuple{<:HeavisideSpline,<:Any}},
10093
M::QMul2{<:Derivative,<:LinearSpline})
101-
D, L = M.args
94+
D, L = M.A, M.B
10295
H, A = dest.args
10396
x = H.points
10497

@@ -114,20 +107,15 @@ function copyto!(dest::MulQuasiMatrix{<:Any,<:Tuple{<:HeavisideSpline,<:Any}},
114107
end
115108

116109
function similar(M::QMul2{<:Derivative,<:LinearSpline}, ::Type{T}) where T
117-
D, B = M.args
110+
D, B = M.A, M.B
118111
n = size(B,2)
119112
ApplyQuasiMatrix(*, HeavisideSpline{T}(B.points),
120113
BandedMatrix{T}(undef, (n-1,n), (0,1)))
121114
end
122115

123-
copy(M::QMul2{<:Derivative,<:LinearSpline}) =
116+
@simplify function *(D::Derivative, L::LinearSpline)
117+
M = Mul(D, L)
124118
copyto!(similar(M, eltype(M)), M)
125-
126-
ApplyStyle(::typeof(*), ::Type{<:QuasiAdjoint{<:Any,<:LinearSpline}}, ::Type{<:QuasiAdjoint{<:Any,<:Derivative}}) = SimplifyStyle()
127-
128-
function copy(M::QMul2{<:QuasiAdjoint{<:Any,<:LinearSpline},<:QuasiAdjoint{<:Any,<:Derivative}})
129-
Bc,Ac = M.args
130-
apply(*,Ac',Bc')'
131119
end
132120

133121

@@ -138,4 +126,15 @@ end
138126
function _sum(A::HeavisideSpline, dims)
139127
@assert dims == 1
140128
permutedims(diff(A.points))
129+
end
130+
131+
function _sum(P::LinearSpline, dims)
132+
d = diff(P.points)
133+
ret = Array{float(eltype(d))}(undef, length(d)+1)
134+
ret[1] = d[1]/2
135+
for k = 2:length(d)
136+
ret[k] = (d[k-1] + d[k])/2
137+
end
138+
ret[end] = d[end]/2
139+
permutedims(ret)
141140
end

0 commit comments

Comments
 (0)