Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ jobs:
fast_finish: true
include:
- stage: "Documentation"
julia: 1.0
julia: 1
os: linux
script:
- julia --project=docs/ -e 'using Pkg; Pkg.instantiate()'
Expand Down
5 changes: 4 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,19 @@ uuid = "26cc04aa-876d-5657-8c51-4c34ba976000"
version = "0.10.2"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

[compat]
ChainRulesCore = "0.9"
julia = "1"

[extras]
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Random", "StaticArrays", "Test"]
9 changes: 7 additions & 2 deletions src/FiniteDifferences.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
module FiniteDifferences

using Printf, LinearAlgebra
using ChainRulesCore
using LinearAlgebra
using Printf
using Random

export to_vec, grad, jacobian, jvp, j′vp
export to_vec, grad, jacobian, jvp, j′vp, difference, rand_tangent

include("rand_tangent.jl")
include("difference.jl")
include("methods.jl")
include("numerics.jl")
include("to_vec.jl")
Expand Down
45 changes: 45 additions & 0 deletions src/difference.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
"""
difference(ε::Real, y::T, x::T) where {T}

Computes `dx` where `dx` is defined s.t.
```julia
y = x + ε * dx
```
where `dx` is a valid tangent type for `x`.

If `(y - x) / ε` is defined, then this operation is equivalent to doing that. For functions
where these operations aren't defined, `difference` can still be defined without commiting
type piracy while `-` and `/` cannot.
"""
difference(::Real, ::T, ::T) where {T<:Symbol} = DoesNotExist()
difference(::Real, ::T, ::T) where {T<:AbstractChar} = DoesNotExist()
difference(::Real, ::T, ::T) where {T<:AbstractString} = DoesNotExist()
difference(::Real, ::T, ::T) where {T<:Integer} = DoesNotExist()

difference(ε::Real, y::T, x::T) where {T<:Number} = (y - x) / ε

difference(ε::Real, y::T, x::T) where {T<:StridedArray} = difference.(ε, y, x)

function difference(ε::Real, y::T, x::T) where {T<:Tuple}
return Composite{T}(difference.(ε, y, x)...)
end

function difference(ε::Real, ys::T, xs::T) where {T<:NamedTuple}
return Composite{T}(; map((y, x) -> difference(ε, y, x), ys, xs)...)
end

function difference(ε::Real, y::T, x::T) where {T}
if !isstructtype(T)
throw(ArgumentError("Non-struct types are not supported by this fallback."))
end

field_names = fieldnames(T)
if length(field_names) > 0
tangents = map(field_names) do field_name
difference(ε, getfield(y, field_name), getfield(x, field_name))
end
return Composite{T}(; NamedTuple{field_names}(tangents)...)
else
return NO_FIELDS
end
end
41 changes: 41 additions & 0 deletions src/rand_tangent.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
"""
rand_tangent([rng::AbstractRNG,] x)

Returns a randomly generated tangent vector appropriate for the primal value `x`.
"""
rand_tangent(x) = rand_tangent(Random.GLOBAL_RNG, x)

rand_tangent(rng::AbstractRNG, x::Symbol) = DoesNotExist()
rand_tangent(rng::AbstractRNG, x::AbstractChar) = DoesNotExist()
rand_tangent(rng::AbstractRNG, x::AbstractString) = DoesNotExist()
rand_tangent(rng::AbstractRNG, ::Nothing) = DoesNotExist()

rand_tangent(rng::AbstractRNG, x::Integer) = DoesNotExist()

rand_tangent(rng::AbstractRNG, x::T) where {T<:Number} = randn(rng, T)

rand_tangent(rng::AbstractRNG, x::StridedArray) = rand_tangent.(Ref(rng), x)

function rand_tangent(rng::AbstractRNG, x::T) where {T<:Tuple}
return Composite{T}(rand_tangent.(Ref(rng), x)...)
end

function rand_tangent(rng::AbstractRNG, xs::T) where {T<:NamedTuple}
return Composite{T}(; map(x -> rand_tangent(rng, x), xs)...)
end

function rand_tangent(rng::AbstractRNG, x::T) where {T}
if !isstructtype(T)
throw(ArgumentError("Non-struct types are not supported by this fallback."))
end

field_names = fieldnames(T)
if length(field_names) > 0
tangents = map(field_names) do field_name
rand_tangent(rng, getfield(x, field_name))
end
return Composite{T}(; NamedTuple{field_names}(tangents)...)
else
return NO_FIELDS
end
end
57 changes: 57 additions & 0 deletions test/difference.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
function test_difference(ε::Real, x, dx)
y = x + ε * dx
dx_diff = difference(ε, y, x)
@test typeof(dx) == typeof(dx_diff)
end

@testset "difference" begin

@testset "$(typeof(x))" for (ε, x) in [
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These tests cover everything covered in the to_vec tests other than Dict, which doesn't currently play nicely / I'm not entirely sure how best to construct a tangent, nor how to add it to the primal.


# Test things that don't have tangents.
(randn(), :a),
(randn(), 'a'),
(randn(), "a"),
(randn(), 0),

# Test Numbers.
(randn(Float32), randn(Float32)),
(randn(Float64), randn(Float64)),
(randn(Float32), randn(ComplexF32)),
(randn(Float64), randn(ComplexF64)),

# Test StridedArrays.
(randn(), randn(5)),
(randn(Float32), randn(ComplexF32, 5, 2)),
(randn(), [randn(1) for _ in 1:3]),
(randn(), [randn(5, 4), "a"]),

# Tuples.
(randn(), (randn(5, 2), )),
(randn(), (randn(), 4)),
(randn(), (4, 3, 2)),

# NamedTuples.
(randn(), (a=randn(5, 2),)),
(randn(), (a=randn(), b=4)),
(randn(), (a=4, b=3, c=2)),

# Arbitrary structs.
(randn(), sin),
(randn(), cos),
(randn(), Foo(5.0, 4, randn(5, 2))),
(randn(), Foo(randn(), 1, Foo(randn(), 1, 1))),

# LinearAlgebra types (also just structs).
(randn(), UpperTriangular(randn(2, 2))),
(randn(), Diagonal(randn(4))),
(randn(), SVector{2, Float64}(1.0, 2.0)),
(randn(), SMatrix{2, 2, ComplexF64}(1.0, 2.0, 3.0, 4.0)),
(randn(), Symmetric(randn(2, 2))),
(randn(), Hermitian(randn(ComplexF64, 1, 1))),
(randn(), Adjoint(randn(ComplexF64, 3, 3))),
(randn(), Transpose(randn(3))),
]
test_difference(ε, x, rand_tangent(x))
end
end
85 changes: 85 additions & 0 deletions test/rand_tangent.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# Test struct for `rand_tangent`.
struct Foo
a::Float64
b::Int
c::Any
end

@testset "generate_tangent" begin
rng = MersenneTwister(123456)

foreach([

# Things without sensible tangents.
("hi", DoesNotExist),
('a', DoesNotExist),
(:a, DoesNotExist),
(true, DoesNotExist),
(4, DoesNotExist),

# Numbers.
(5.0, Float64),
(5.0 + 0.4im, Complex{Float64}),

# StridedArrays.
(randn(Float32, 3), Vector{Float32}),
(randn(Complex{Float64}, 2), Vector{Complex{Float64}}),
(randn(5, 4), Matrix{Float64}),
(randn(Complex{Float32}, 5, 4), Matrix{Complex{Float32}}),
([randn(5, 4), 4.0], Vector{Any}),

# Tuples.
((4.0, ), Composite{Tuple{Float64}}),
((5.0, randn(3)), Composite{Tuple{Float64, Vector{Float64}}}),

# NamedTuples.
((a=4.0, ), Composite{NamedTuple{(:a,), Tuple{Float64}}}),
((a=5.0, b=1), Composite{NamedTuple{(:a, :b), Tuple{Float64, Int}}}),

# structs.
(sin, typeof(NO_FIELDS)),
(Foo(5.0, 4, rand(rng, 3)), Composite{Foo}),
(Foo(4.0, 3, Foo(5.0, 2, 4)), Composite{Foo}),

# LinearAlgebra types (also just structs).
(
UpperTriangular(randn(3, 3)),
Composite{UpperTriangular{Float64, Matrix{Float64}}},
),
(
Diagonal(randn(2)),
Composite{Diagonal{Float64, Vector{Float64}}},
),
(
SVector{2, Float64}(1.0, 2.0),
Composite{typeof(SVector{2, Float64}(1.0, 2.0))},
),
(
SMatrix{2, 2, ComplexF64}(1.0, 2.0, 3.0, 4.0),
Composite{typeof(SMatrix{2, 2, ComplexF64}(1.0, 2.0, 3.0, 4.0))},
),
(
Symmetric(randn(2, 2)),
Composite{Symmetric{Float64, Matrix{Float64}}},
),
(
Hermitian(randn(ComplexF64, 1, 1)),
Composite{Hermitian{ComplexF64, Matrix{ComplexF64}}},
),
(
Adjoint(randn(ComplexF64, 3, 3)),
Composite{Adjoint{ComplexF64, Matrix{ComplexF64}}},
),
(
Transpose(randn(3)),
Composite{Transpose{Float64, Vector{Float64}}},
),
]) do (x, T_tangent)
@test rand_tangent(rng, x) isa T_tangent
@test rand_tangent(x) isa T_tangent
@test x + rand_tangent(rng, x) isa typeof(x)
end

# Ensure struct fallback errors for non-struct types.
@test_throws ArgumentError invoke(rand_tangent, Tuple{AbstractRNG, Any}, rng, 5.0)
end
10 changes: 9 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
using FiniteDifferences, Test, Random, Printf, LinearAlgebra, StaticArrays
using ChainRulesCore
using FiniteDifferences
using LinearAlgebra
using Printf
using Random
using StaticArrays
using Test

@testset "FiniteDifferences" begin
include("rand_tangent.jl")
include("difference.jl")
include("methods.jl")
include("numerics.jl")
include("to_vec.jl")
Expand Down