Skip to content

Commit

Permalink
move code from NonconvexCore
Browse files Browse the repository at this point in the history
  • Loading branch information
mohamed82008 committed Oct 17, 2022
1 parent 720aa0b commit b3714e5
Show file tree
Hide file tree
Showing 3 changed files with 414 additions and 6 deletions.
16 changes: 15 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,25 @@ uuid = "c78775a3-ee38-4681-b694-0504db4f5dc7"
authors = ["Mohamed Tarek <[email protected]> and contributors"]
version = "0.1.0"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
NamedTupleTools = "d9ec5142-1e00-5aa0-9d6a-321866360f50"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

[compat]
ChainRulesCore = "1"
NamedTupleTools = "0.14"
OrderedCollections = "1"
Requires = "1"
julia = "1"

[extras]
JuMP = "4076af6c-e467-56ae-b986-b466b2749572"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Test"]
test = ["JuMP", "Test", "Zygote"]
338 changes: 337 additions & 1 deletion src/DifferentiableFlatten.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,341 @@
module DifferentiableFlatten

# Write your package code here.
using SparseArrays, ChainRulesCore, NamedTupleTools, Requires, OrderedCollections

# Adapted from ParameterHandling.jl with the following license.
#=
Copyright (c) 2020 Invenia Technical Computing Corporation
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
=#

"""
    flatten(x)

Return a "flattened" representation of `x` as a vector of real numbers, together with a
function `unflatten` that takes a vector of reals of the same length and returns an object
of the same type as `x`.

`unflatten` is the inverse of `flatten`, so

```julia
julia> x = (randn(5), 5.0, (a=5.0, b=randn(2, 3)));
julia> v, unflatten = flatten(x);
julia> x == unflatten(v)
true
```
"""
function flatten end

# Reals are passed through untouched; every other value is handed to `flatten`.
function maybeflatten(x::Real)
    return x
end

function maybeflatten(x)
    return flatten(x)
end

# Wrap a scalar in a one-element vector. The returned closure extracts the single entry of
# the vector it is given (via `only`).
function flatten(x::Real)
    scalar_from_vec(vals) = only(vals)
    return [x], scalar_from_vec
end

# A vector of reals is already flat: return an element-wise copy (`identity.`) and the
# identity function as its unflatten.
function flatten(x::Vector{<:Real})
    return identity.(x), identity
end

# Flatten a vector whose elements are not plain reals (e.g. a vector of arrays or tuples).
# Each element is flattened on its own and the pieces are concatenated. The returned
# closure cuts a flat vector back into per-element slices and rebuilds every element with
# that element's own unflatten function.
function flatten(x::AbstractVector)
    # `reduce(vcat, ...)` throws on an empty collection, so an empty input is answered
    # directly: an empty flat vector (the same value `flatten(::Tuple{})` returns) and an
    # unflatten that hands back an empty copy of `x`.
    if isempty(x)
        return Float64[], (_ -> copy(x))
    end
    x_vecs_and_backs = map(flatten, identity.(x))
    x_vecs, backs = first.(x_vecs_and_backs), last.(x_vecs_and_backs)
    # Slice bounds depend only on the pieces, so compute them once here rather than on
    # every call of the closure.
    lengths = map(_length, x_vecs)
    stops = _cumsum(lengths)
    function Vector_from_vec(x_vec)
        return [
            backs[n](x_vec[(stops[n] - lengths[n] + 1):stops[n]]) for n in eachindex(x)
        ]
    end
    return reduce(vcat, x_vecs), Vector_from_vec
end

# Flatten a general array by flattening `vec` of it; the returned closure rebuilds the
# vector form and reshapes it to `size(x)`.
function flatten(x::AbstractArray)
    flat, vector_from_flat = flatten(vec(identity.(x)))
    function Array_from_vec(x_vec)
        rebuilt = vector_from_flat(x_vec)
        return reshape(rebuilt, size(x))
    end
    return identity.(flat), Array_from_vec
end

# Flatten every entry of a tuple and concatenate the results. The closure slices a flat
# vector into one chunk per entry (using the cumulative chunk lengths) and applies each
# entry's own unflatten function to its chunk.
function flatten(x::Tuple)
    flattened = map(flatten, x)
    pieces = first.(flattened)
    rebuilders = last.(flattened)
    lengths = map(_length, pieces)
    stops = _cumsum(lengths)
    function unflatten_to_Tuple(v)
        return map(rebuilders, lengths, stops) do rebuild, len, stop
            start = stop - len + 1
            return rebuild(v[start:stop])
        end
    end
    return reduce(vcat, pieces), unflatten_to_Tuple
end

# Flatten a NamedTuple through its tuple of values; the closure rebuilds the values and
# re-attaches the original field names.
function flatten(x::NamedTuple)
    flat, values_from_vec = flatten(values(x))
    function unflatten_to_NamedTuple(v)
        rebuilt_values = values_from_vec(v)
        return NamedTuple{keys(x)}(rebuilt_values)
    end
    return identity.(flat), unflatten_to_NamedTuple
end

# Flatten a dictionary. The entries are first copied into an `OrderedDict` in the order
# given by `ks` (by default the dictionary's own key order), the values are flattened as a
# vector, and the closure rebuilds an `OrderedDict` with the same keys.
function flatten(d::AbstractDict, ks = collect(keys(d)))
    ordered = OrderedDict(k => d[k] for k in ks)
    flat, values_from_vec = flatten(identity.(collect(values(ordered))))
    function unflatten_to_Dict(v)
        rebuilt_values = values_from_vec(v)
        return _build_ordered_dict(rebuilt_values, keys(ordered))
    end
    return identity.(flat), unflatten_to_Dict
end
# Pair the n-th key with `vals[n]` and collect the pairs into an `OrderedDict`.
function _build_ordered_dict(vals, keys)
    entries = (k => vals[i] for (i, k) in enumerate(keys))
    return OrderedDict(entries)
end
# Reverse rule for `_build_ordered_dict`: the cotangent with respect to `vals` is the
# values of the incoming cotangent; the function and the keys get `NoTangent()`.
function ChainRulesCore.rrule(::typeof(_build_ordered_dict), vals, keys)
    function build_ordered_dict_pullback(Δ)
        return NoTangent(), values(Δ), NoTangent()
    end
    return _build_ordered_dict(vals, keys), build_ordered_dict_pullback
end

# Fallback for arbitrary structs: convert the struct to a NamedTuple of its fields, flatten
# that, and wrap the inverse in an `Unflatten` that remembers `x` and rebuilds a value of
# `typeof(x)` from the unflattened NamedTuple.
function flatten(x)
    flat, nt_from_vec = flatten(ntfromstruct(x))
    rebuild = y -> structfromnt(typeof(x), nt_from_vec(y))
    return identity.(flat), Unflatten(x, rebuild)
end

# Scalar case: only the second argument is used. It is wrapped in a one-element vector and
# the closure extracts the single entry of the vector it is given.
function zygote_flatten(::Real, x::Real)
    scalar_from_vec(vals) = only(vals)
    return [x], scalar_from_vec
end

# Real vectors are already flat: return an element-wise copy of the second argument and the
# identity function.
function zygote_flatten(::Vector{<:Real}, x::Vector{<:Real})
    return identity.(x), identity
end

# x_vecs_and_backs = map(val -> flatten(val), identity.(x))
# x_vecs, backs = first.(x_vecs_and_backs), last.(x_vecs_and_backs)
# function Vector_from_vec(x_vec)
# sz = _cumsum(map(_length, x_vecs))
# x_Vec = [backs[n](x_vec[sz[n] - _length(x_vecs[n]) + 1:sz[n]]) for n in eachindex(x)]
# return x_Vec
# end
# return reduce(vcat, x_vecs), Vector_from_vec

# Element-wise `zygote_flatten` of two vectors of non-real elements: the entries of `x1`
# and `x2` are paired up, each pair is flattened, and the pieces are concatenated. The
# returned closure slices a flat vector into per-element chunks and rebuilds each element
# with that element's own unflatten function.
function zygote_flatten(x1::AbstractVector, x2::AbstractVector)
    # `reduce(vcat, ...)` throws on an empty collection, so two empty inputs are answered
    # directly with an empty flat vector and an unflatten returning an empty copy of `x2`.
    if isempty(x1) && isempty(x2)
        return Float64[], (_ -> copy(x2))
    end
    x_vecs_and_backs = map(tuple.(identity.(x1), identity.(x2))) do val
        zygote_flatten(val[1], val[2])
    end
    x_vecs, backs = first.(x_vecs_and_backs), last.(x_vecs_and_backs)
    # Slice bounds depend only on the pieces, so compute them once here rather than on
    # every call of the closure.
    lengths = map(_length, x_vecs)
    stops = _cumsum(lengths)
    function Vector_from_vec(x_vec)
        return [
            backs[n](x_vec[(stops[n] - lengths[n] + 1):stops[n]]) for n in eachindex(x2)
        ]
    end
    return reduce(vcat, x_vecs), Vector_from_vec
end

# General arrays: flatten the `vec` forms of both arguments; the closure rebuilds the
# vector form and reshapes it to `size(x2)`.
function zygote_flatten(x1::AbstractArray, x2::AbstractArray)
    flat, vector_from_flat = zygote_flatten(vec(identity.(x1)), vec(identity.(x2)))
    function Array_from_vec(x_vec)
        rebuilt = vector_from_flat(x_vec)
        return reshape(rebuilt, size(x2))
    end
    return identity.(flat), Array_from_vec
end

# Tuples: pair the entries of `x1` and `x2`, flatten each pair, and concatenate. The
# closure slices a flat vector into one chunk per entry and applies each entry's own
# unflatten function to its chunk.
function zygote_flatten(x1::Tuple, x2::Tuple)
    flattened = map(tuple.(x1, x2)) do entry
        zygote_flatten(entry[1], entry[2])
    end
    pieces = first.(flattened)
    rebuilders = last.(flattened)
    lengths = map(_length, pieces)
    stops = _cumsum(lengths)
    function unflatten_to_Tuple(v)
        return map(rebuilders, lengths, stops) do rebuild, len, stop
            start = stop - len + 1
            return rebuild(v[start:stop])
        end
    end
    return reduce(vcat, pieces), unflatten_to_Tuple
end

# Structural `Tangent`: read the raw fields of `x2` with `ntfromstruct` and flatten its
# `backing` entry against `x1`.
function zygote_flatten(x1, x2::Tangent)
    tangent_fields = ntfromstruct(x2)
    return zygote_flatten(x1, tangent_fields.backing)
end
# A struct paired with a NamedTuple: convert the struct to a NamedTuple of its fields and
# continue with the NamedTuple/NamedTuple method.
function zygote_flatten(x1, x2::NamedTuple)
    x1_fields = ntfromstruct(x1)
    return zygote_flatten(x1_fields, x2)
end

# NamedTuples: flatten the two value tuples against each other; the closure rebuilds the
# values and attaches the field names of `x1`.
function zygote_flatten(x1::NamedTuple, x2::NamedTuple)
    flat, values_from_vec = zygote_flatten(values(x1), values(x2))
    function unflatten_to_NamedTuple(v)
        rebuilt_values = values_from_vec(v)
        return NamedTuple{keys(x1)}(rebuilt_values)
    end
    return identity.(flat), unflatten_to_NamedTuple
end

# Dictionaries: both inputs are copied into `OrderedDict`s in the order given by `ks` (by
# default the key order of `d2`), their values are flattened against each other, and the
# inverse, wrapped in an `Unflatten` that remembers `d1`, rebuilds an `OrderedDict` over
# `ks`.
function zygote_flatten(d1::AbstractDict, d2::AbstractDict, ks = collect(keys(d2)))
    ordered1 = OrderedDict(k => d1[k] for k in ks)
    ordered2 = OrderedDict(k => d2[k] for k in ks)
    vals1 = identity.(collect(values(ordered1)))
    vals2 = identity.(collect(values(ordered2)))
    flat, values_from_vec = zygote_flatten(vals1, vals2)
    function unflatten_to_Dict(v)
        rebuilt = values_from_vec(v)
        return OrderedDict(key => rebuilt[n] for (n, key) in enumerate(ks))
    end
    return identity.(flat), Unflatten(d1, unflatten_to_Dict)
end

# Fallback for two arbitrary structs: convert both to NamedTuples of their fields, flatten
# those, and wrap the inverse in an `Unflatten` that remembers `x1` and rebuilds a value of
# `typeof(x2)`.
function zygote_flatten(x1, x2)
    flat, nt_from_vec = zygote_flatten(ntfromstruct(x1), ntfromstruct(x2))
    rebuild = y -> structfromnt(typeof(x2), nt_from_vec(y))
    return identity.(flat), Unflatten(x1, rebuild)
end

# `length` that additionally treats `nothing` as having zero entries.
function _length(x)
    return length(x)
end

function _length(::Nothing)
    return 0
end

# Reverse rule for `flatten(x)`. The primal output is the pair `(vector, unflatten)`. The
# cotangent of `x` is obtained by unflattening the cotangent of the vector part (`Δ[1]`).
function ChainRulesCore.rrule(::typeof(flatten), x)
    d_vec, un = flatten(x)
    function flatten_pullback(Δ)
        # One tangent per primal input: `NoTangent()` for the function `flatten` itself and
        # the structured cotangent for `x`. The call has no further inputs, so no further
        # tangent is returned.
        return (NoTangent(), un(Δ[1]))
    end
    return (d_vec, un), flatten_pullback
end
# Reverse rule for `flatten(d, ks)` on dictionaries. The cotangent of `d` is obtained by
# unflattening the cotangent of the vector part (`Δ[1]`); the function and `ks` receive
# `NoTangent()`.
function ChainRulesCore.rrule(::typeof(flatten), d::AbstractDict, ks)
    ordered = OrderedDict(k => d[k] for k in ks)
    d_vec, un = flatten(ordered, ks)
    function flatten_dict_pullback(Δ)
        return (NoTangent(), un(Δ[1]), NoTangent())
    end
    return (d_vec, un), flatten_dict_pullback
end

# Callable wrapper around an unflatten function. It stores a value `x` next to the function
# `unflatten`; calling the wrapper forwards its argument to `unflatten`.
struct Unflatten{X, F} <: Function
    x::X
    unflatten::F
end

function (f::Unflatten)(x)
    return f.unflatten(x)
end

# Structural zero: build a value shaped like `x` with every real replaced by its `zero`.
# Containers recurse into their entries; any other value is treated as a struct, converted
# to a NamedTuple of its fields, zeroed, and converted back to `typeof(x)`.
function _zero(x::Real)
    return zero(x)
end

function _zero(x::AbstractArray)
    return _zero.(x)
end

function _zero(x::AbstractDict)
    return Dict(keys(x) .=> map(_zero, values(x)))
end

function _zero(x::NamedTuple)
    return map(_zero, x)
end

function _zero(x::Tuple)
    return map(_zero, x)
end

function _zero(x)
    zeroed_fields = _zero(ntfromstruct(x))
    return structfromnt(typeof(x), zeroed_fields)
end

# Merge a dictionary cotangent `d2` into a zeroed copy of `d1`: start from an `OrderedDict`
# holding `_zero` of every value in `d1`, overlay the entries of `d2`, and sort the result.
function _merge(d1::AbstractDict{K, V}, d2::AbstractDict) where {K, V}
    zeroed = OrderedDict{K, V}(k => _zero(v) for (k, v) in d1)
    overlay = OrderedDict{K, V}(d2)
    merged = merge(zeroed, overlay)
    return sort!(merged)
end
# Tuple primal with a structural `Tangent`: merge element-wise against the tangent's
# `backing` property.
function _merge(d1::Tuple, d2::Tangent)
    tangent_entries = d2.backing
    return _merge.(d1, tangent_entries)
end
# Fallback: nothing to merge, the second argument is returned unchanged.
function _merge(::Any, d2)
    return d2
end

# Reverse rule for calling an `Unflatten` on a flat vector `v`. The incoming structured
# cotangent `Δ` is first merged with the primal output `x` (see `_merge`), then flattened
# against the value stored in `un.x`; the vector part of that result is the cotangent of
# `v`. The callable itself receives `NoTangent()`.
function ChainRulesCore.rrule(un::Unflatten, v)
    x = un(v)
    function unflatten_pullback(Δ)
        # The merged cotangent must be bound to `_Δ`, which is what gets flattened below.
        _Δ = _merge(x, Δ)
        return (NoTangent(), zygote_flatten(un.x, _Δ)[1])
    end
    return x, unflatten_pullback
end

# `nothing` contributes no numbers: an empty `Float64` vector and a closure that ignores
# its argument and returns `nothing`.
function flatten(::Nothing)
    to_nothing(_) = nothing
    return Float64[], to_nothing
end
# `NoTangent()` and `ZeroTangent()` contribute no numbers: each yields an empty `Float64`
# vector and a closure that ignores its argument and returns the same kind of tangent.
function flatten(::NoTangent)
    to_no_tangent(_) = NoTangent()
    return Float64[], to_no_tangent
end

function flatten(::ZeroTangent)
    to_zero_tangent(_) = ZeroTangent()
    return Float64[], to_zero_tangent
end
# The empty tuple contributes no numbers: an empty `Float64` vector and a closure that
# ignores its argument and returns `()`.
function flatten(::Tuple{})
    to_empty_tuple(_) = ()
    return Float64[], to_empty_tuple
end

# A `nothing` cotangent: flatten `x` itself to learn the layout, then return an all-zero
# vector of that layout together with `x`'s unflatten function.
function zygote_flatten(x, ::Nothing)
    v, un = flatten(x)
    # Return the unflatten function itself. `Base.tail` on the `(vector, unflatten)` pair
    # would wrap it in a 1-tuple, which is not callable like the second return value of the
    # other `zygote_flatten` methods.
    return zero(v), un
end
# A `NoTangent()` cotangent: flatten `x` itself to learn the layout, then return an
# all-zero vector of that layout together with `x`'s unflatten function.
function zygote_flatten(x, ::NoTangent)
    v, un = flatten(x)
    # Return the unflatten function itself. `Base.tail` on the `(vector, unflatten)` pair
    # would wrap it in a 1-tuple, which is not callable like the second return value of the
    # other `zygote_flatten` methods.
    return zero(v), un
end
# A `ZeroTangent()` cotangent: flatten `x` itself to learn the layout, then return an
# all-zero vector of that layout together with `x`'s unflatten function.
function zygote_flatten(x, ::ZeroTangent)
    v, un = flatten(x)
    # Return the unflatten function itself. `Base.tail` on the `(vector, unflatten)` pair
    # would wrap it in a 1-tuple, which is not callable like the second return value of the
    # other `zygote_flatten` methods.
    return zero(v), un
end
# An empty-tuple cotangent contributes no numbers, whatever the primal is: an empty
# `Float64` vector and a closure that ignores its argument and returns `()`.
function zygote_flatten(::Any, ::Tuple{})
    return Float64[], (_ -> ())
end

# Without this method the call `zygote_flatten((), ())` matches both the `(::Tuple, ::Tuple)`
# method and the `(::Any, ::Tuple{})` method with neither more specific, which Julia
# reports as a method ambiguity error. Same result as the method above.
function zygote_flatten(::Tuple{}, ::Tuple{})
    return Float64[], (_ -> ())
end

# `@constructor T` / `@constructor T C` define `flatten`, `zygote_flatten` and `_zero`
# methods specialised for the type `T`. They mirror the generic struct fallbacks but
# rebuild values through `C` (which defaults to `T` itself) instead of `typeof(x)`.
macro constructor(T)
    return flatten_expr(T, T)
end

macro constructor(T, C)
    return flatten_expr(T, C)
end

# Build the expression emitted by `@constructor`. `T` and `C` are escaped so that they are
# resolved in the scope of the macro caller.
function flatten_expr(T, C)
    StructType = esc(T)
    Rebuilder = esc(C)
    return quote
        function DifferentiableFlatten.flatten(x::$StructType)
            v, un = flatten(ntfromstruct(x))
            return identity.(v), Unflatten(x, y -> structfromnt($Rebuilder, un(y)))
        end
        function DifferentiableFlatten.zygote_flatten(x1::$StructType, x2::$StructType)
            v, un = zygote_flatten(ntfromstruct(x1), ntfromstruct(x2))
            return identity.(v), Unflatten(x2, y -> structfromnt($Rebuilder, un(y)))
        end
        function DifferentiableFlatten._zero(x::$StructType)
            return structfromnt($Rebuilder, _zero(ntfromstruct(x)))
        end
    end
end

# Cumulative sum. On Julia versions before 1.5 a tuple input is collected into a vector,
# summed, and splatted back into a tuple; otherwise `cumsum` is called directly.
function _cumsum(x)
    return cumsum(x)
end

if VERSION < v"1.5"
    function _cumsum(x::Tuple)
        summed = _cumsum(collect(x))
        return (summed...,)
    end
end

# Zygote may produce a sparse-vector cotangent even when the primal input is a dense
# vector, which causes problems in the `rrule` of `Unflatten`. Sparse vectors are therefore
# converted to dense arrays before flattening.
function flatten(x::SparseVector)
    dense = Array(x)
    return flatten(dense)
end

# Sparse matrices: only the stored values (`nzval`) are flattened. The closure rebuilds a
# `SparseMatrixCSC` that reuses the dimensions, `colptr` and `rowval` of `x` with the
# unflattened values.
function flatten(x::SparseMatrixCSC)
    flat, nzval_from_vec = flatten(x.nzval)
    function Array_from_vec(x_vec)
        stored = nzval_from_vec(x_vec)
        return SparseMatrixCSC(x.m, x.n, x.colptr, x.rowval, stored)
    end
    return identity.(flat), Array_from_vec
end

# Zygote may produce a sparse-vector cotangent even when the primal input is a dense
# vector, which causes problems in the `rrule` of `Unflatten`. Both sparse vectors are
# therefore converted to dense arrays before flattening.
function zygote_flatten(x1::SparseVector, x2::SparseVector)
    dense1 = Array(x1)
    dense2 = Array(x2)
    return zygote_flatten(dense1, dense2)
end

# Sparse matrices: the stored values of `x1` and `x2` are flattened against each other. The
# inverse, wrapped in an `Unflatten` that remembers `x1`, rebuilds a `SparseMatrixCSC` with
# the dimensions, `colptr` and `rowval` of `x1`.
# NOTE(review): the stored values are paired positionally, so this presumably expects `x1`
# and `x2` to share one sparsity pattern; confirm against callers.
function zygote_flatten(x1::SparseMatrixCSC, x2::SparseMatrixCSC)
    flat, nzval_from_vec = zygote_flatten(x1.nzval, x2.nzval)
    function Array_from_vec(x_vec)
        stored = nzval_from_vec(x_vec)
        return SparseMatrixCSC(x1.m, x1.n, x1.colptr, x1.rowval, stored)
    end
    return identity.(flat), Unflatten(x1, Array_from_vec)
end

# Optional JuMP support, loaded through Requires only once JuMP itself is loaded. Adds
# `flatten` / `zygote_flatten` methods for `JuMP.Containers.DenseAxisArray`: the underlying
# `data` array is flattened and the closures rebuild a `DenseAxisArray` with the original
# shape and axes.
@init @require JuMP="4076af6c-e467-56ae-b986-b466b2749572" begin
    import .JuMP
    @eval begin
        function flatten(x::JuMP.Containers.DenseAxisArray)
            x_vec, from_vec = flatten(vec(identity.(x.data)))
            function Array_from_vec(x_vec)
                data = reshape(from_vec(x_vec), size(x))
                return JuMP.Containers.DenseAxisArray(data, axes(x)...)
            end
            return identity.(x_vec), Array_from_vec
        end

        # Here the cotangent `x2` is a NamedTuple with a `data` entry. A NamedTuple has no
        # `size`, so the shape and axes used for rebuilding are taken from the primal
        # `DenseAxisArray` `x1`.
        function zygote_flatten(x1::JuMP.Containers.DenseAxisArray, x2::NamedTuple)
            x_vec, from_vec = zygote_flatten(
                vec(identity.(x1.data)), vec(identity.(x2.data))
            )
            function Array_from_vec(x_vec)
                data = reshape(from_vec(x_vec), size(x1))
                return JuMP.Containers.DenseAxisArray(data, axes(x1)...)
            end
            return identity.(x_vec), Array_from_vec
        end

        function zygote_flatten(
            x1::JuMP.Containers.DenseAxisArray, x2::JuMP.Containers.DenseAxisArray
        )
            x_vec, from_vec = zygote_flatten(
                vec(identity.(x1.data)), vec(identity.(x2.data))
            )
            function Array_from_vec(x_vec)
                data = reshape(from_vec(x_vec), size(x2))
                return JuMP.Containers.DenseAxisArray(data, axes(x2)...)
            end
            return identity.(x_vec), Array_from_vec
        end
    end
end

end
Loading

0 comments on commit b3714e5

Please sign in to comment.