Skip to content

Commit

Permalink
fix singleton type support
Browse files Browse the repository at this point in the history
  • Loading branch information
mohamed82008 committed Oct 20, 2022
1 parent 886c16f commit 4d15f51
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 2 deletions.
8 changes: 6 additions & 2 deletions src/DifferentiableFlatten.jl
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,12 @@ function ChainRulesCore.rrule(::typeof(_build_ordered_dict), vals, keys)
end

function flatten(x)
v, un = flatten(ntfromstruct(x))
return identity.(v), Unflatten(x, y -> structfromnt(typeof(x), un(y)))
if Base.issingletontype(typeof(x))
v, un = Union{}[], _ -> x
else
v, un = flatten(ntfromstruct(x))
return identity.(v), Unflatten(x, y -> structfromnt(typeof(x), un(y)))
end
end

function zygote_flatten(::Real, x::Real)
Expand Down
3 changes: 3 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -110,4 +110,7 @@ MyStruct(a, b) = MyStruct{typeof(a), typeof(a), typeof(b)}(a, b)

x = JuMP.Containers.DenseAxisArray(reshape(Float64[1.0, 1.0], (2,)), 1)
@test zygote_flatten(x, (data = [1.0, 1.0],))[1] == [1.0, 1.0]

@test flatten(exp)[1] == Union{}[]
@test flatten(exp)[2]([]) === exp
end

0 comments on commit 4d15f51

Please sign in to comment.