Skip to content

Commit

Permalink
formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
mohamed82008 committed Nov 15, 2023
1 parent ae4fe90 commit 022f274
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 32 deletions.
53 changes: 35 additions & 18 deletions src/DifferentiableFlatten.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ function flatten(x::AbstractVector)
x_vecs, backs = first.(x_vecs_and_backs), last.(x_vecs_and_backs)
function Vector_from_vec(x_vec)
sz = _cumsum(map(_length, x_vecs))
x_Vec = [backs[n](x_vec[sz[n] - _length(x_vecs[n]) + 1:sz[n]]) for n in eachindex(x)]
x_Vec = [backs[n](x_vec[sz[n]-_length(x_vecs[n])+1:sz[n]]) for n in eachindex(x)]
return x_Vec
end
return reduce(vcat, x_vecs), Vector_from_vec
Expand All @@ -73,7 +73,7 @@ function flatten(x::Tuple)
sz = _cumsum(lengths)
function unflatten_to_Tuple(v)
map(x_backs, lengths, sz) do x_back, l, s
return x_back(v[s - l + 1:s])
return x_back(v[s-l+1:s])
end
end
return reduce(vcat, x_vecs), unflatten_to_Tuple
Expand Down Expand Up @@ -139,7 +139,7 @@ function zygote_flatten(x1::AbstractVector, x2::AbstractVector)
x_vecs, backs = first.(x_vecs_and_backs), last.(x_vecs_and_backs)
function Vector_from_vec(x_vec)
sz = _cumsum(map(_length, x_vecs))
x_Vec = [backs[n](x_vec[sz[n] - _length(x_vecs[n]) + 1:sz[n]]) for n in eachindex(x2)]
x_Vec = [backs[n](x_vec[sz[n]-_length(x_vecs[n])+1:sz[n]]) for n in eachindex(x2)]
return x_Vec
end
return reduce(vcat, x_vecs), Vector_from_vec
Expand All @@ -160,7 +160,7 @@ function zygote_flatten(x1::Tuple, x2::Tuple)
sz = _cumsum(lengths)
function unflatten_to_Tuple(v)
map(x_backs, lengths, sz) do x_back, l, s
return x_back(v[s - l + 1:s])
return x_back(v[s-l+1:s])
end
end
return reduce(vcat, x_vecs), unflatten_to_Tuple
Expand All @@ -185,7 +185,8 @@ end
function zygote_flatten(d1::AbstractDict, d2::AbstractDict, ks = collect(keys(d2)))
_d1 = OrderedDict(k => d1[k] for k in ks)
_d2 = OrderedDict(k => d2[k] for k in ks)
d_vec, unflatten = zygote_flatten(identity.(collect(values(_d1))), identity.(collect(values(_d2))))
d_vec, unflatten =
zygote_flatten(identity.(collect(values(_d1))), identity.(collect(values(_d2))))
function unflatten_to_Dict(v)
v_vec_vec = unflatten(v)
return OrderedDict(key => v_vec_vec[n] for (n, key) in enumerate(ks))
Expand Down Expand Up @@ -215,7 +216,7 @@ function ChainRulesCore.rrule(::typeof(flatten), d::AbstractDict, ks)
end
end

struct Unflatten{X, F} <: Function
struct Unflatten{X,F} <: Function
x::X
unflatten::F
end
Expand All @@ -228,9 +229,9 @@ _zero(x::NamedTuple) = map(_zero, x)
_zero(x::Tuple) = map(_zero, x)
_zero(x) = structfromnt(typeof(x), _zero(ntfromstruct(x)))

function _merge(d1::AbstractDict{K, V}, d2::AbstractDict) where {K, V}
_d = OrderedDict{K, V}(k => _zero(v) for (k, v) in d1)
return sort!(merge(_d, OrderedDict{K, V}(d2)))
function _merge(d1::AbstractDict{K,V}, d2::AbstractDict) where {K,V}
_d = OrderedDict{K,V}(k => _zero(v) for (k, v) in d1)
return sort!(merge(_d, OrderedDict{K,V}(d2)))
end
function _merge(d1::Tuple, d2::Tangent)
return _merge.(d1, d2.backing)
Expand Down Expand Up @@ -289,7 +290,8 @@ flatten_expr(T, C) = quote
v, un = zygote_flatten(ntfromstruct(x1), ntfromstruct(x2))
return identity.(v), Unflatten(x2, y -> structfromnt($(esc(C)), un(y)))
end
DifferentiableFlatten._zero(x::$(esc(T))) = structfromnt($(esc(C)), _zero(ntfromstruct(x)))
DifferentiableFlatten._zero(x::$(esc(T))) =
structfromnt($(esc(C)), _zero(ntfromstruct(x)))
end

_cumsum(x) = cumsum(x)
Expand All @@ -312,28 +314,43 @@ zygote_flatten(x1::SparseVector, x2::SparseVector) = zygote_flatten(Array(x1), A

function zygote_flatten(x1::SparseMatrixCSC, x2::SparseMatrixCSC)
x_vec, from_vec = zygote_flatten(x1.nzval, x2.nzval)
Array_from_vec(x_vec) = SparseMatrixCSC(x1.m, x1.n, x1.colptr, x1.rowval, from_vec(x_vec))
Array_from_vec(x_vec) =
SparseMatrixCSC(x1.m, x1.n, x1.colptr, x1.rowval, from_vec(x_vec))
return identity.(x_vec), Unflatten(x1, Array_from_vec)
end

@init @require JuMP="4076af6c-e467-56ae-b986-b466b2749572" begin
@init @require JuMP = "4076af6c-e467-56ae-b986-b466b2749572" begin
import .JuMP
@eval begin
function flatten(x::JuMP.Containers.DenseAxisArray)
x_vec, from_vec = flatten(vec(identity.(x.data)))
Array_from_vec(x_vec) = JuMP.Containers.DenseAxisArray(reshape(from_vec(x_vec), size(x)), axes(x)...)
Array_from_vec(x_vec) = JuMP.Containers.DenseAxisArray(
reshape(from_vec(x_vec), size(x)),
axes(x)...,
)
return identity.(x_vec), Array_from_vec
end

function zygote_flatten(x1::JuMP.Containers.DenseAxisArray, x2::NamedTuple)
x_vec, from_vec = zygote_flatten(vec(identity.(x1.data)), vec(identity.(x2.data)))
Array_from_vec(x_vec) = JuMP.Containers.DenseAxisArray(reshape(from_vec(x_vec), size(x2)), axes(x2)...)
x_vec, from_vec =
zygote_flatten(vec(identity.(x1.data)), vec(identity.(x2.data)))
Array_from_vec(x_vec) = JuMP.Containers.DenseAxisArray(
reshape(from_vec(x_vec), size(x2)),
axes(x2)...,
)
return identity.(x_vec), Array_from_vec
end

function zygote_flatten(x1::JuMP.Containers.DenseAxisArray, x2::JuMP.Containers.DenseAxisArray)
x_vec, from_vec = zygote_flatten(vec(identity.(x1.data)), vec(identity.(x2.data)))
Array_from_vec(x_vec) = JuMP.Containers.DenseAxisArray(reshape(from_vec(x_vec), size(x2)), axes(x2)...)
function zygote_flatten(
x1::JuMP.Containers.DenseAxisArray,
x2::JuMP.Containers.DenseAxisArray,
)
x_vec, from_vec =
zygote_flatten(vec(identity.(x1.data)), vec(identity.(x2.data)))
Array_from_vec(x_vec) = JuMP.Containers.DenseAxisArray(
reshape(from_vec(x_vec), size(x2)),
axes(x2)...,
)
return identity.(x_vec), Array_from_vec
end
end
Expand Down
30 changes: 16 additions & 14 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,18 @@ using OrderedCollections, JuMP, Zygote, SparseArrays, LinearAlgebra, Test
using ChainRulesCore

struct SS
a
b
a::Any
b::Any
end

struct MyStruct{T, T1, T2}
struct MyStruct{T,T1,T2}
a::T1
b::T2
end
MyStruct(a, b) = MyStruct{typeof(a), typeof(a), typeof(b)}(a, b)
MyStruct(a, b) = MyStruct{typeof(a),typeof(a),typeof(b)}(a, b)
@constructor MyStruct MyStruct

@testset "DifferentiableFlatten.jl" begin
@testset "DifferentiableFlatten.jl" begin
xs = [
1.0,
[1.0],
Expand All @@ -24,29 +24,33 @@ MyStruct(a, b) = MyStruct{typeof(a), typeof(a), typeof(b)}(a, b)
[1.0, (1.0, 2.0)],
[1.0, OrderedDict(1 => Float64[1.0, 2.0])],
[[1.0], OrderedDict(1 => Float64[1.0, 2.0])],
[(1.0,), [1.0,], OrderedDict(1 => Float64[1.0, 2.0])],
[(1.0,), [1.0], OrderedDict(1 => Float64[1.0, 2.0])],
[1.0 1.0; 1.0 1.0],
rand(2, 2, 2),
[Float64[1.0, 2.0], Float64[3.0, 4.0]],
OrderedDict(1 => 1.0),
OrderedDict(1 => Float64[1.0]),
OrderedDict(1 => 1.0, 2 => Float64[2.0]),
OrderedDict(1 => 1.0, 2 => Float64[2.0], 3 => [Float64[1.0, 2.0], Float64[3.0, 4.0]]),
OrderedDict(
1 => 1.0,
2 => Float64[2.0],
3 => [Float64[1.0, 2.0], Float64[3.0, 4.0]],
),
JuMP.Containers.DenseAxisArray(reshape(Float64[1.0, 1.0], (2,)), 1),
(1.0,),
(1.0, 2.0),
(1.0, (1.0, 2.0)),
(1.0, Float64[1.0, 2.0]),
(1.0, OrderedDict(1 => Float64[1.0, 2.0])),
([1.0], OrderedDict(1 => Float64[1.0, 2.0])),
((1.0,), [1.0,], OrderedDict(1 => Float64[1.0, 2.0])),
((1.0,), [1.0], OrderedDict(1 => Float64[1.0, 2.0])),
(a = 1.0,),
(a = 1.0, b = 2.0),
(a = 1.0, b = (1.0, 2.0)),
(a = 1.0, b = Float64[1.0, 2.0]),
(a = 1.0, b = OrderedDict(1 => Float64[1.0, 2.0])),
(a = [1.0], b = OrderedDict(1 => Float64[1.0, 2.0])),
(a = (1.0,), b = [1.0,], c = OrderedDict(1 => Float64[1.0, 2.0])),
(a = (1.0,), b = [1.0], c = OrderedDict(1 => Float64[1.0, 2.0])),
sparsevec(Float64[1.0, 2.0], [1, 3], 10),
sparse([1, 2, 2, 3], [2, 3, 1, 4], Float64[1.0, 2.0, 3.0, 4.0], 10, 10),
SS(1.0, 2.0),
Expand Down Expand Up @@ -87,15 +91,13 @@ MyStruct(a, b) = MyStruct{typeof(a), typeof(a), typeof(b)}(a, b)
@test unflatten(xvec) isa NamedTuple
@test DifferentiableFlatten._length(nothing) == 0

@test DifferentiableFlatten._merge(
OrderedDict(:a => 1.0),
OrderedDict(:b => 2.0),
) == OrderedDict(:a => 0.0, :b => 2.0)
@test DifferentiableFlatten._merge(OrderedDict(:a => 1.0), OrderedDict(:b => 2.0)) ==
OrderedDict(:a => 0.0, :b => 2.0)
@test DifferentiableFlatten._merge(1, SS(1.0, 2.0)) == SS(1.0, 2.0)
x = OrderedDict(:a => 1.0)
@test DifferentiableFlatten._merge(
(1.0,),
Tangent{NamedTuple{(:b,), Tuple{Float64}}}(b = 1.0),
Tangent{NamedTuple{(:b,),Tuple{Float64}}}(b = 1.0),
) == (ZeroTangent(),)

@test flatten(nothing)[1] == Float64[]
Expand Down

0 comments on commit 022f274

Please sign in to comment.