Skip to content

Commit

Permalink
more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mohamed82008 committed Oct 17, 2022
1 parent a3170d6 commit 886c16f
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 4 deletions.
3 changes: 0 additions & 3 deletions src/DifferentiableFlatten.jl
Original file line number Diff line number Diff line change
Expand Up @@ -289,9 +289,6 @@ flatten_expr(T, C) = quote
end

_cumsum(x) = cumsum(x)
if VERSION < v"1.5"
_cumsum(x::Tuple) = (_cumsum(collect(x))..., )
end

# Zygote can return a sparse vector co-tangent
# even if the input is a vector. This is causing
Expand Down
51 changes: 50 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,20 @@
using DifferentiableFlatten: flatten, zygote_flatten
using DifferentiableFlatten: flatten, zygote_flatten, maybeflatten
using DifferentiableFlatten: DifferentiableFlatten, @constructor
using OrderedCollections, JuMP, Zygote, SparseArrays, LinearAlgebra, Test
using ChainRulesCore

struct SS
a
b
end

struct MyStruct{T, T1, T2}
a::T1
b::T2
end
MyStruct(a, b) = MyStruct{typeof(a), typeof(a), typeof(b)}(a, b)
@constructor MyStruct MyStruct

@testset "DifferentiableFlatten.jl" begin
xs = [
1.0,
Expand Down Expand Up @@ -42,6 +51,7 @@ end
sparse([1, 2, 2, 3], [2, 3, 1, 4], Float64[1.0, 2.0, 3.0, 4.0], 10, 10),
SS(1.0, 2.0),
[SS(1.0, 2.0), 1.0],
MyStruct(1.0, 1.0),
]
for x in xs
@show x
Expand All @@ -60,5 +70,44 @@ end
zygote_flatten(x, x)[1]
end[1]
@test logabsdet(J) == (0.0, 1.0)
if x isa Real
@test maybeflatten(x) == x
else
xvec, unflatten = maybeflatten(x)
@test x == unflatten(xvec)
end
@show DifferentiableFlatten._zero(x)
@test all(==(0), flatten(DifferentiableFlatten._zero(x))[1])
end

xvec, unflatten = zygote_flatten(SS(1.0, 2.0), Tangent{SS}(a = 1.0, b = 2.0))
@test unflatten(xvec) isa NamedTuple

xvec, unflatten = zygote_flatten(SS(1.0, 2.0), (a = 1.0, b = 2.0))
@test unflatten(xvec) isa NamedTuple
@test DifferentiableFlatten._length(nothing) == 0

@test DifferentiableFlatten._merge(
OrderedDict(:a => 1.0),
OrderedDict(:b => 2.0),
) == OrderedDict(:a => 0.0, :b => 2.0)
@test DifferentiableFlatten._merge(1, SS(1.0, 2.0)) == SS(1.0, 2.0)
x = OrderedDict(:a => 1.0)
@test DifferentiableFlatten._merge(
(1.0,),
Tangent{NamedTuple{(:b,), Tuple{Float64}}}(b = 1.0),
) == (ZeroTangent(),)

@test flatten(nothing)[1] == Float64[]
@test flatten(NoTangent())[1] == Float64[]
@test flatten(ZeroTangent())[1] == Float64[]
@test flatten(())[1] == Float64[]

@test zygote_flatten(1.0, nothing)[1] == [0.0]
@test zygote_flatten(1.0, NoTangent())[1] == [0.0]
@test zygote_flatten(1.0, ZeroTangent())[1] == [0.0]
@test zygote_flatten(1.0, ())[1] == Float64[]

x = JuMP.Containers.DenseAxisArray(reshape(Float64[1.0, 1.0], (2,)), 1)
@test zygote_flatten(x, (data = [1.0, 1.0],))[1] == [1.0, 1.0]
end

0 comments on commit 886c16f

Please sign in to comment.