From 022f2748d9ced2ff29e575f7bf217258994bf1a5 Mon Sep 17 00:00:00 2001 From: Mohamed Tarek Date: Wed, 15 Nov 2023 22:29:24 +0200 Subject: [PATCH] formatting --- src/DifferentiableFlatten.jl | 53 ++++++++++++++++++++++++------------ test/runtests.jl | 30 ++++++++++---------- 2 files changed, 51 insertions(+), 32 deletions(-) diff --git a/src/DifferentiableFlatten.jl b/src/DifferentiableFlatten.jl index bea71d3..49bf4c1 100644 --- a/src/DifferentiableFlatten.jl +++ b/src/DifferentiableFlatten.jl @@ -54,7 +54,7 @@ function flatten(x::AbstractVector) x_vecs, backs = first.(x_vecs_and_backs), last.(x_vecs_and_backs) function Vector_from_vec(x_vec) sz = _cumsum(map(_length, x_vecs)) - x_Vec = [backs[n](x_vec[sz[n] - _length(x_vecs[n]) + 1:sz[n]]) for n in eachindex(x)] + x_Vec = [backs[n](x_vec[sz[n]-_length(x_vecs[n])+1:sz[n]]) for n in eachindex(x)] return x_Vec end return reduce(vcat, x_vecs), Vector_from_vec @@ -73,7 +73,7 @@ function flatten(x::Tuple) sz = _cumsum(lengths) function unflatten_to_Tuple(v) map(x_backs, lengths, sz) do x_back, l, s - return x_back(v[s - l + 1:s]) + return x_back(v[s-l+1:s]) end end return reduce(vcat, x_vecs), unflatten_to_Tuple @@ -139,7 +139,7 @@ function zygote_flatten(x1::AbstractVector, x2::AbstractVector) x_vecs, backs = first.(x_vecs_and_backs), last.(x_vecs_and_backs) function Vector_from_vec(x_vec) sz = _cumsum(map(_length, x_vecs)) - x_Vec = [backs[n](x_vec[sz[n] - _length(x_vecs[n]) + 1:sz[n]]) for n in eachindex(x2)] + x_Vec = [backs[n](x_vec[sz[n]-_length(x_vecs[n])+1:sz[n]]) for n in eachindex(x2)] return x_Vec end return reduce(vcat, x_vecs), Vector_from_vec @@ -160,7 +160,7 @@ function zygote_flatten(x1::Tuple, x2::Tuple) sz = _cumsum(lengths) function unflatten_to_Tuple(v) map(x_backs, lengths, sz) do x_back, l, s - return x_back(v[s - l + 1:s]) + return x_back(v[s-l+1:s]) end end return reduce(vcat, x_vecs), unflatten_to_Tuple @@ -185,7 +185,8 @@ end function zygote_flatten(d1::AbstractDict, d2::AbstractDict, ks = collect(keys(d2))) _d1 = OrderedDict(k => d1[k] for k in ks) _d2 = OrderedDict(k => d2[k] for k in ks) - d_vec, unflatten = zygote_flatten(identity.(collect(values(_d1))), identity.(collect(values(_d2)))) + d_vec, unflatten = + zygote_flatten(identity.(collect(values(_d1))), identity.(collect(values(_d2)))) function unflatten_to_Dict(v) v_vec_vec = unflatten(v) return OrderedDict(key => v_vec_vec[n] for (n, key) in enumerate(ks)) @@ -215,7 +216,7 @@ function ChainRulesCore.rrule(::typeof(flatten), d::AbstractDict, ks) end end -struct Unflatten{X, F} <: Function +struct Unflatten{X,F} <: Function x::X unflatten::F end @@ -228,9 +229,9 @@ _zero(x::NamedTuple) = map(_zero, x) _zero(x::Tuple) = map(_zero, x) _zero(x) = structfromnt(typeof(x), _zero(ntfromstruct(x))) -function _merge(d1::AbstractDict{K, V}, d2::AbstractDict) where {K, V} - _d = OrderedDict{K, V}(k => _zero(v) for (k, v) in d1) - return sort!(merge(_d, OrderedDict{K, V}(d2))) +function _merge(d1::AbstractDict{K,V}, d2::AbstractDict) where {K,V} + _d = OrderedDict{K,V}(k => _zero(v) for (k, v) in d1) + return sort!(merge(_d, OrderedDict{K,V}(d2))) end function _merge(d1::Tuple, d2::Tangent) return _merge.(d1, d2.backing) @@ -289,7 +290,8 @@ flatten_expr(T, C) = quote v, un = zygote_flatten(ntfromstruct(x1), ntfromstruct(x2)) return identity.(v), Unflatten(x2, y -> structfromnt($(esc(C)), un(y))) end - DifferentiableFlatten._zero(x::$(esc(T))) = structfromnt($(esc(C)), _zero(ntfromstruct(x))) + DifferentiableFlatten._zero(x::$(esc(T))) = + structfromnt($(esc(C)), _zero(ntfromstruct(x))) end _cumsum(x) = cumsum(x) @@ -312,28 +314,43 @@ zygote_flatten(x1::SparseVector, x2::SparseVector) = zygote_flatten(Array(x1), A function zygote_flatten(x1::SparseMatrixCSC, x2::SparseMatrixCSC) x_vec, from_vec = zygote_flatten(x1.nzval, x2.nzval) - Array_from_vec(x_vec) = SparseMatrixCSC(x1.m, x1.n, x1.colptr, x1.rowval, from_vec(x_vec)) + Array_from_vec(x_vec) = + SparseMatrixCSC(x1.m, x1.n, x1.colptr, x1.rowval, from_vec(x_vec)) return identity.(x_vec), Unflatten(x1, Array_from_vec) end -@init @require JuMP="4076af6c-e467-56ae-b986-b466b2749572" begin +@init @require JuMP = "4076af6c-e467-56ae-b986-b466b2749572" begin import .JuMP @eval begin function flatten(x::JuMP.Containers.DenseAxisArray) x_vec, from_vec = flatten(vec(identity.(x.data))) - Array_from_vec(x_vec) = JuMP.Containers.DenseAxisArray(reshape(from_vec(x_vec), size(x)), axes(x)...) + Array_from_vec(x_vec) = JuMP.Containers.DenseAxisArray( + reshape(from_vec(x_vec), size(x)), + axes(x)..., + ) return identity.(x_vec), Array_from_vec end function zygote_flatten(x1::JuMP.Containers.DenseAxisArray, x2::NamedTuple) - x_vec, from_vec = zygote_flatten(vec(identity.(x1.data)), vec(identity.(x2.data))) - Array_from_vec(x_vec) = JuMP.Containers.DenseAxisArray(reshape(from_vec(x_vec), size(x2)), axes(x2)...) + x_vec, from_vec = + zygote_flatten(vec(identity.(x1.data)), vec(identity.(x2.data))) + Array_from_vec(x_vec) = JuMP.Containers.DenseAxisArray( + reshape(from_vec(x_vec), size(x2)), + axes(x2)..., + ) return identity.(x_vec), Array_from_vec end - function zygote_flatten(x1::JuMP.Containers.DenseAxisArray, x2::JuMP.Containers.DenseAxisArray) - x_vec, from_vec = zygote_flatten(vec(identity.(x1.data)), vec(identity.(x2.data))) - Array_from_vec(x_vec) = JuMP.Containers.DenseAxisArray(reshape(from_vec(x_vec), size(x2)), axes(x2)...) + function zygote_flatten( + x1::JuMP.Containers.DenseAxisArray, + x2::JuMP.Containers.DenseAxisArray, + ) + x_vec, from_vec = + zygote_flatten(vec(identity.(x1.data)), vec(identity.(x2.data))) + Array_from_vec(x_vec) = JuMP.Containers.DenseAxisArray( + reshape(from_vec(x_vec), size(x2)), + axes(x2)..., + ) return identity.(x_vec), Array_from_vec end end diff --git a/test/runtests.jl b/test/runtests.jl index 979f3fa..e701ae6 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -4,18 +4,18 @@ using OrderedCollections, JuMP, Zygote, SparseArrays, LinearAlgebra, Test using ChainRulesCore struct SS - a - b + a::Any + b::Any end -struct MyStruct{T, T1, T2} +struct MyStruct{T,T1,T2} a::T1 b::T2 end -MyStruct(a, b) = MyStruct{typeof(a), typeof(a), typeof(b)}(a, b) +MyStruct(a, b) = MyStruct{typeof(a),typeof(a),typeof(b)}(a, b) @constructor MyStruct MyStruct -@testset "DifferentiableFlatten.jl" begin +@testset "DifferentiableFlatten.jl" begin xs = [ 1.0, [1.0], @@ -24,14 +24,18 @@ MyStruct(a, b) = MyStruct{typeof(a), typeof(a), typeof(b)}(a, b) [1.0, (1.0, 2.0)], [1.0, OrderedDict(1 => Float64[1.0, 2.0])], [[1.0], OrderedDict(1 => Float64[1.0, 2.0])], - [(1.0,), [1.0,], OrderedDict(1 => Float64[1.0, 2.0])], + [(1.0,), [1.0], OrderedDict(1 => Float64[1.0, 2.0])], [1.0 1.0; 1.0 1.0], rand(2, 2, 2), [Float64[1.0, 2.0], Float64[3.0, 4.0]], OrderedDict(1 => 1.0), OrderedDict(1 => Float64[1.0]), OrderedDict(1 => 1.0, 2 => Float64[2.0]), - OrderedDict(1 => 1.0, 2 => Float64[2.0], 3 => [Float64[1.0, 2.0], Float64[3.0, 4.0]]), + OrderedDict( + 1 => 1.0, + 2 => Float64[2.0], + 3 => [Float64[1.0, 2.0], Float64[3.0, 4.0]], + ), JuMP.Containers.DenseAxisArray(reshape(Float64[1.0, 1.0], (2,)), 1), (1.0,), (1.0, 2.0), @@ -39,14 +43,14 @@ MyStruct(a, b) = MyStruct{typeof(a), typeof(a), typeof(b)}(a, b) (1.0, Float64[1.0, 2.0]), (1.0, OrderedDict(1 => Float64[1.0, 2.0])), ([1.0], OrderedDict(1 => Float64[1.0, 2.0])), - ((1.0,), [1.0,], OrderedDict(1 => Float64[1.0, 2.0])), + ((1.0,), [1.0], OrderedDict(1 => Float64[1.0, 2.0])), (a = 1.0,), (a = 1.0, b = 2.0), (a = 1.0, b = (1.0, 2.0)), (a = 1.0, b = Float64[1.0, 2.0]), (a = 1.0, b = OrderedDict(1 => Float64[1.0, 2.0])), (a = [1.0], b = OrderedDict(1 => Float64[1.0, 2.0])), - (a = (1.0,), b = [1.0,], c = OrderedDict(1 => Float64[1.0, 2.0])), + (a = (1.0,), b = [1.0], c = OrderedDict(1 => Float64[1.0, 2.0])), sparsevec(Float64[1.0, 2.0], [1, 3], 10), sparse([1, 2, 2, 3], [2, 3, 1, 4], Float64[1.0, 2.0, 3.0, 4.0], 10, 10), SS(1.0, 2.0), @@ -87,15 +91,13 @@ MyStruct(a, b) = MyStruct{typeof(a), typeof(a), typeof(b)}(a, b) @test unflatten(xvec) isa NamedTuple @test DifferentiableFlatten._length(nothing) == 0 - @test DifferentiableFlatten._merge( - OrderedDict(:a => 1.0), - OrderedDict(:b => 2.0), - ) == OrderedDict(:a => 0.0, :b => 2.0) + @test DifferentiableFlatten._merge(OrderedDict(:a => 1.0), OrderedDict(:b => 2.0)) == + OrderedDict(:a => 0.0, :b => 2.0) @test DifferentiableFlatten._merge(1, SS(1.0, 2.0)) == SS(1.0, 2.0) x = OrderedDict(:a => 1.0) @test DifferentiableFlatten._merge( (1.0,), - Tangent{NamedTuple{(:b,), Tuple{Float64}}}(b = 1.0), + Tangent{NamedTuple{(:b,),Tuple{Float64}}}(b = 1.0), ) == (ZeroTangent(),) @test flatten(nothing)[1] == Float64[]