From 0a941246191859cd4b39ace717b4920a9bc91933 Mon Sep 17 00:00:00 2001 From: Michael Dudkowiak Date: Fri, 1 Sep 2023 18:35:53 +0200 Subject: [PATCH] Add ScaleShiftModule and tests --- Project.toml | 6 +++++- src/AdaptiveFlows.jl | 4 ++++ src/adaptive_flows.jl | 10 ++++++++++ src/scale_shift.jl | 34 +++++++++++++++++++++++++++++++++ test/Project.toml | 3 +++ test/runtests.jl | 3 ++- test/test_adaptive_flows.jl | 5 +++++ test/test_scale_shift.jl | 38 +++++++++++++++++++++++++++++++++++++ 8 files changed, 101 insertions(+), 2 deletions(-) create mode 100644 src/scale_shift.jl create mode 100644 test/test_scale_shift.jl diff --git a/Project.toml b/Project.toml index 973c0d3..a6ab2e4 100644 --- a/Project.toml +++ b/Project.toml @@ -3,6 +3,7 @@ uuid = "24d2106d-e7e1-4641-aa0a-4a5934943aa1" version = "0.1.0" [deps] +AffineMaps = "2c83c9a8-abf5-4329-a0d7-deffaf474661" ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" ArraysOfArrays = "65a8f2f4-9b39-5baf-92e2-a9cc46fdf018" ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" @@ -10,15 +11,18 @@ FunctionChains = "8e6b2b91-af83-483e-ba35-d00930e4cf9b" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" HeterogeneousComputing = "2182be2a-124f-4a91-8389-f06db5907a21" InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" MonotonicSplines = "568f7cb4-8305-41bc-b90d-d32b39cc99d1" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" ValueShapes = "136a8f8c-c49b-4edb-8b98-f3d64d48be8f" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] +AffineMaps = "0.1, 0.2" ArgCheck = "2" ArraysOfArrays = "0.5.1, 0.6" ChangesOfVariables = "0.1.3" @@ -28,7 +32,7 @@ HeterogeneousComputing = "0.1, 0.2" InverseFunctions = "0.1" Lux = "0.5" MonotonicSplines = "0.1" -Optimisers = "0.2" +Optimisers = "0.2, 0.3" StatsFuns = "1" ValueShapes = "0.8.3, 0.9, 0.10" Zygote = "0.6" diff --git a/src/AdaptiveFlows.jl b/src/AdaptiveFlows.jl index 558c63a..6c52dae 100644 --- a/src/AdaptiveFlows.jl +++ b/src/AdaptiveFlows.jl @@ -7,6 +7,7 @@ Adaptive normalizing flows. """ module AdaptiveFlows +using AffineMaps using ArgCheck using ArraysOfArrays using ChangesOfVariables @@ -14,10 +15,12 @@ using FunctionChains using Functors using HeterogeneousComputing using InverseFunctions +using LinearAlgebra using Lux using MonotonicSplines using Optimisers using Random +using Statistics using StatsFuns using ValueShapes using Zygote @@ -25,5 +28,6 @@ using Zygote include("adaptive_flows.jl") include("optimize_flow.jl") include("rqspline_coupling.jl") +include("scale_shift.jl") include("utils.jl") end # module diff --git a/src/adaptive_flows.jl b/src/adaptive_flows.jl index 95f60d4..35650ee 100644 --- a/src/adaptive_flows.jl +++ b/src/adaptive_flows.jl @@ -47,6 +47,16 @@ function InverseFunctions.inverse(f::CompositeFlow) return CompositeFlow(InverseFunctions.inverse(f.flow).fs) end +function prepend_flow_module(f::CompositeFlow, new_module::F) where F<:AbstractFlow + return CompositeFlow([new_module, f.flow.fs...]) +end +export prepend_flow_module + +function append_flow_module(f::CompositeFlow, new_module::F) where F<:AbstractFlow + return CompositeFlow([f.flow.fs..., new_module]) +end +export append_flow_module + """ AbstractFlowModule <: AbstractFlow diff --git a/src/scale_shift.jl b/src/scale_shift.jl new file mode 100644 index 0000000..a95232f --- /dev/null +++ b/src/scale_shift.jl @@ -0,0 +1,34 @@ +# This file is a part of AdaptiveFlows.jl, licensed under the MIT License (MIT). + +struct ScaleShiftModule <: AbstractFlowModule + A::Matrix{Real} + b::Vector{Real} +end + +export ScaleShiftModule +@functor ScaleShiftModule + +function ScaleShiftModule(stds::AbstractVector, means::AbstractVector) + A = Diagonal(inv.(stds)) + return ScaleShiftModule(A, .- A * means) +end + +function ScaleShiftModule(x::AbstractArray) + stds = vec(std(x, dims = 2)) + means = vec(mean(x, dims = 2)) + ScaleShiftModule(stds, means) +end + +function ChangesOfVariables.with_logabsdet_jacobian(f::ScaleShiftModule, x::Any) + y, ladj = ChangesOfVariables.with_logabsdet_jacobian(MulAdd(f.A, f.b), x) + + return y, fill(ladj, 1, size(y,2)) +end + +(f::ScaleShiftModule)(x::AbstractMatrix) = MulAdd(f.A, f.b)(x) +(f::ScaleShiftModule)(vs::AbstractValueShape) = vs + +function InverseFunctions.inverse(f::ScaleShiftModule) + A = inv(f.A) + return ScaleShiftModule(A, .- A * f.b) +end diff --git a/test/Project.toml b/test/Project.toml index 67f95d6..8914214 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,4 +1,5 @@ [deps] +AffineMaps = "2c83c9a8-abf5-4329-a0d7-deffaf474661" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" ArraysOfArrays = "65a8f2f4-9b39-5baf-92e2-a9cc46fdf018" ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" @@ -7,9 +8,11 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" FunctionChains = "8e6b2b91-af83-483e-ba35-d00930e4cf9b" HeterogeneousComputing = "2182be2a-124f-4a91-8389-f06db5907a21" InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" ValueShapes = "136a8f8c-c49b-4edb-8b98-f3d64d48be8f" diff --git a/test/runtests.jl b/test/runtests.jl index c3d1197..038c217 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3,9 +3,10 @@ import Test Test.@testset "Package AdaptiveFlows" begin - include("test_aqua.jl") include("test_adaptive_flows.jl") + include("test_aqua.jl") include("test_docs.jl") + include("test_scale_shift.jl") include("test_optimize_flow.jl") include("test_rqspline_coupling.jl") end # testset diff --git a/test/test_adaptive_flows.jl b/test/test_adaptive_flows.jl index 00bcd24..d3a6014 100644 --- a/test/test_adaptive_flows.jl +++ b/test/test_adaptive_flows.jl @@ -19,6 +19,8 @@ x = randn(rng, n_dims, n_smpls) vs_test = valshape(x) comp_flow_test = CompositeFlow([RQSplineCouplingModule(4), RQSplineCouplingModule(4)]) +prepended_flow_test = prepend_flow_module(comp_flow_test, ScaleShiftModule(ones(4), zeros(4))) +appended_flow_test = append_flow_module(comp_flow_test, ScaleShiftModule(ones(4), zeros(4))) # test outputs # comp_flow_y_test, comp_flow_ladj_test = with_logabsdet_jacobian(comp_flow_test, x) @@ -32,4 +34,7 @@ comp_flow_ladj_test = readdlm("test_outputs/comp_flow_ladj_test.txt") @test all(isapprox.(ChangesOfVariables.with_logabsdet_jacobian(InverseFunctions.inverse(comp_flow_test), comp_flow_y_test), (x, .- comp_flow_ladj_test))) @test isapprox(InverseFunctions.inverse(comp_flow_test)(comp_flow_y_test), x) + + @test prepended_flow_test.flow.fs[1] isa ScaleShiftModule + @test appended_flow_test.flow.fs[end] isa ScaleShiftModule end diff --git a/test/test_scale_shift.jl b/test/test_scale_shift.jl new file mode 100644 index 0000000..b378be3 --- /dev/null +++ b/test/test_scale_shift.jl @@ -0,0 +1,38 @@ +# This file is a part of AdaptiveFlows.jl, licensed under the MIT License (MIT). + +using AdaptiveFlows +using Test + +using ArraysOfArrays +using InverseFunctions +using LinearAlgebra +using Random +using Statistics +using ValueShapes + +# test inputs +n_dims = 4 +n_smpls = 10 + +rng = MersenneTwister(1234) +x = muladd(Diagonal(randn(rng, n_dims)), randn(rng, n_dims, n_smpls), randn(rng, n_dims)) + +smpls = nestedview(x) +vs_test = valshape(x) + +scale_shift_test = ScaleShiftModule(x) +inv_scale_shift_test = inverse(scale_shift_test) + +y_test = scale_shift_test(x) +x_inverted_test = inv_scale_shift_test(y_test) + +stds_test = vec(std(y_test, dims = 2)) +means_test = vec(mean(y_test, dims = 2)) + +@testset "ScaleShiftModule" begin + @test all(isapprox.(stds_test, 1)) && all(isapprox.(means_test, 0, atol = 1f-15)) + @test all(isapprox.(x_inverted_test, x)) + + @test scale_shift_test(vs_test) == vs_test + @test all(isapprox.(with_logabsdet_jacobian(scale_shift_test, x)[2], 10.637223371435223)) +end