diff --git a/src/DifferentiableFlatten.jl b/src/DifferentiableFlatten.jl index 12af175..50296a9 100644 --- a/src/DifferentiableFlatten.jl +++ b/src/DifferentiableFlatten.jl @@ -289,9 +289,6 @@ flatten_expr(T, C) = quote end _cumsum(x) = cumsum(x) -if VERSION < v"1.5" - _cumsum(x::Tuple) = (_cumsum(collect(x))..., ) -end # Zygote can return a sparse vector co-tangent # even if the input is a vector. This is causing diff --git a/test/runtests.jl b/test/runtests.jl index f22f164..90cd46a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,11 +1,20 @@ -using DifferentiableFlatten: flatten, zygote_flatten +using DifferentiableFlatten: flatten, zygote_flatten, maybeflatten +using DifferentiableFlatten: DifferentiableFlatten, @constructor using OrderedCollections, JuMP, Zygote, SparseArrays, LinearAlgebra, Test +using ChainRulesCore struct SS a b end +struct MyStruct{T, T1, T2} + a::T1 + b::T2 +end +MyStruct(a, b) = MyStruct{typeof(a), typeof(a), typeof(b)}(a, b) +@constructor MyStruct MyStruct + @testset "DifferentiableFlatten.jl" begin xs = [ 1.0, @@ -42,6 +51,7 @@ end sparse([1, 2, 2, 3], [2, 3, 1, 4], Float64[1.0, 2.0, 3.0, 4.0], 10, 10), SS(1.0, 2.0), [SS(1.0, 2.0), 1.0], + MyStruct(1.0, 1.0), ] for x in xs @show x @@ -60,5 +70,44 @@ end zygote_flatten(x, x)[1] end[1] @test logabsdet(J) == (0.0, 1.0) + if x isa Real + @test maybeflatten(x) == x + else + xvec, unflatten = maybeflatten(x) + @test x == unflatten(xvec) + end + @show DifferentiableFlatten._zero(x) + @test all(==(0), flatten(DifferentiableFlatten._zero(x))[1]) end + + xvec, unflatten = zygote_flatten(SS(1.0, 2.0), Tangent{SS}(a = 1.0, b = 2.0)) + @test unflatten(xvec) isa NamedTuple + + xvec, unflatten = zygote_flatten(SS(1.0, 2.0), (a = 1.0, b = 2.0)) + @test unflatten(xvec) isa NamedTuple + @test DifferentiableFlatten._length(nothing) == 0 + + @test DifferentiableFlatten._merge( + OrderedDict(:a => 1.0), + OrderedDict(:b => 2.0), + ) == OrderedDict(:a => 0.0, :b => 2.0) + @test DifferentiableFlatten._merge(1, SS(1.0, 2.0)) == SS(1.0, 2.0) + x = OrderedDict(:a => 1.0) + @test DifferentiableFlatten._merge( + (1.0,), + Tangent{NamedTuple{(:b,), Tuple{Float64}}}(b = 1.0), + ) == (ZeroTangent(),) + + @test flatten(nothing)[1] == Float64[] + @test flatten(NoTangent())[1] == Float64[] + @test flatten(ZeroTangent())[1] == Float64[] + @test flatten(())[1] == Float64[] + + @test zygote_flatten(1.0, nothing)[1] == [0.0] + @test zygote_flatten(1.0, NoTangent())[1] == [0.0] + @test zygote_flatten(1.0, ZeroTangent())[1] == [0.0] + @test zygote_flatten(1.0, ())[1] == Float64[] + + x = JuMP.Containers.DenseAxisArray(reshape(Float64[1.0, 1.0], (2,)), 1) + @test zygote_flatten(x, (data = [1.0, 1.0],))[1] == [1.0, 1.0] end