diff --git a/src/DifferentiableFlatten.jl b/src/DifferentiableFlatten.jl index 50296a9..bea71d3 100644 --- a/src/DifferentiableFlatten.jl +++ b/src/DifferentiableFlatten.jl @@ -107,8 +107,12 @@ function ChainRulesCore.rrule(::typeof(_build_ordered_dict), vals, keys) end function flatten(x) - v, un = flatten(ntfromstruct(x)) - return identity.(v), Unflatten(x, y -> structfromnt(typeof(x), un(y))) + if Base.issingletontype(typeof(x)) + v, un = Union{}[], _ -> x + else + v, un = flatten(ntfromstruct(x)) + return identity.(v), Unflatten(x, y -> structfromnt(typeof(x), un(y))) + end end function zygote_flatten(::Real, x::Real) diff --git a/test/runtests.jl b/test/runtests.jl index 90cd46a..979f3fa 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -110,4 +110,7 @@ MyStruct(a, b) = MyStruct{typeof(a), typeof(a), typeof(b)}(a, b) x = JuMP.Containers.DenseAxisArray(reshape(Float64[1.0, 1.0], (2,)), 1) @test zygote_flatten(x, (data = [1.0, 1.0],))[1] == [1.0, 1.0] + + @test flatten(exp)[1] == Union{}[] + @test flatten(exp)[2]([]) === exp end