Skip to content

Commit

Permalink
Adjust tests for optimize_flow.jl
Browse files Browse the repository at this point in the history
  • Loading branch information
Micki-D committed Feb 24, 2024
1 parent ff1193a commit 7e43779
Show file tree
Hide file tree
Showing 5 changed files with 86 additions and 38 deletions.
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
MonotonicSplines = "568f7cb4-8305-41bc-b90d-d32b39cc99d1"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
ProgressBars = "49802e3a-d2f1-5c88-81d8-b72133a6f568"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
Expand All @@ -31,9 +30,12 @@ FunctionChains = "0.1"
Functors = "0.2, 0.3, 0.4"
HeterogeneousComputing = "0.1, 0.2"
InverseFunctions = "0.1"
LinearAlgebra = "1"
Lux = "0.5"
MonotonicSplines = "0.1.1"
Optimisers = "0.2, 0.3"
Random = "1"
Statistics = "1, 2"
StatsFuns = "1"
ValueShapes = "0.8.3, 0.9, 0.10"
Zygote = "0.6"
Expand Down
1 change: 0 additions & 1 deletion src/AdaptiveFlows.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ using LinearAlgebra
using Lux
using MonotonicSplines
using Optimisers
using ProgressBars
using Random
using Statistics
using StatsFuns
Expand Down
17 changes: 11 additions & 6 deletions src/optimize_flow.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
std_normal_logpdf(x::Real) = -(abs2(x) + log2π)/2
std_normal_logpdf(x::AbstractArray) = vec(sum(std_normal_logpdf.(flatview(x)), dims = 1))

function negll_flow_loss(flow::F, x::AbstractMatrix{<:Real}, logd_orig::AbstractVector, logpdf::Function) where F<:AbstractFlow
function negll_flow_loss(flow::F, x::AbstractMatrix{<:Real}, logpdf::Function) where F<:AbstractFlow
nsamples = size(x, 2)
flow_corr = fchain(flow,logpdf.f)
y, ladj = with_logabsdet_jacobian(flow_corr, x)
Expand All @@ -12,15 +12,15 @@ function negll_flow_loss(flow::F, x::AbstractMatrix{<:Real}, logd_orig::Abstract
end

function negll_flow(flow::F, x::AbstractMatrix{<:Real}, logd_orig::AbstractVector, logpdf::Tuple{Function, Function}) where F<:AbstractFlow
negll, back = Zygote.pullback(negll_flow, flow, x, logd_orig, logpdf[2])
negll, back = Zygote.pullback(negll_flow_loss, flow, x, logpdf[2])
d_flow = back(one(eltype(x)))[1]
return negll, d_flow
end
export negll_flow

function KLDiv_flow_loss(flow::F, x::AbstractMatrix{<:Real}, logd_orig::AbstractVector, logpdfs::Tuple{Function, Function}) where F<:AbstractFlow
nsamples = size(x, 2)
flow_corr = fchain(flow,logpdfs[2].f)
flow_corr = fchain(flow, logpdfs[2].f)
logpdf_y = logpdfs[2].logdensity
y, ladj = with_logabsdet_jacobian(flow_corr, x)
KLDiv = sum(exp.(logd_orig - vec(ladj)) .* (logd_orig - vec(ladj) - logpdf_y(y))) / nsamples
Expand All @@ -38,7 +38,7 @@ function optimize_flow(samples::Union{Matrix, Tuple{Matrix, Matrix}},
initial_flow::F where F<:AbstractFlow,
optimizer;
sequential::Bool = true,
loss::Function = negll_flow_grad,
loss::Function = negll_flow,
logpdf::Union{Function, Tuple{Function, Function}} = std_normal_logpdf,
nbatches::Integer = 10,
nepochs::Integer = 100,
Expand Down Expand Up @@ -75,12 +75,17 @@ function optimize_flow(samples::Union{AbstractArray, Tuple{AbstractArray, Abstra

n_dims = _get_n_dims(samples)
logd_orig = samples isa Tuple ? logpdf[1](samples[1]) : logpdf[1](samples)
pushfwd_logpdf = logpdf[2] == std_normal_logpdf ? (PushForwardLogDensity(first(initial_flow.flow.fs), logpdf[1]), PushForwardLogDensity(FlowModule(InvMulAdd(I(n_dims), zeros(n_dims)), false), logpdf[2])) : (PushForwardLogDensity(first(initial_flow.flow.fs), logpdf[1]), PushForwardLogDensity(last(initial_flow.flow.fs), logpdf[2]))

if !(initial_flow isa AbstractFlowBlock)
pushfwd_logpdf = logpdf[2] == std_normal_logpdf ? (PushForwardLogDensity(first(initial_flow.flow.fs), logpdf[1]), PushForwardLogDensity(FlowModule(InvMulAdd(I(n_dims), zeros(n_dims)), false), logpdf[2])) : (PushForwardLogDensity(first(initial_flow.flow.fs), logpdf[1]), PushForwardLogDensity(last(initial_flow.flow.fs), logpdf[2]))
else
pushfwd_logpdf = (PushForwardLogDensity(InvMulAdd(I(n_dims), zeros(n_dims)), logpdf[1]), PushForwardLogDensity(InvMulAdd(I(n_dims), zeros(n_dims)), logpdf[2]))
end

if sequential
flow, state, loss_hist = _train_flow_sequentially(samples, initial_flow, optimizer, nepochs, nbatches, loss, pushfwd_logpdf, logd_orig, shuffle_samples)
else
flow, state, loss_hist = _train_flow(samples, initial_flow, optimizer, nepochs, nbatches, loss, pushfwd_logpd, logd_orig, shuffle_samples)
flow, state, loss_hist = _train_flow(samples, initial_flow, optimizer, nepochs, nbatches, loss, pushfwd_logpdf, logd_orig, shuffle_samples)
end

return (result = flow, optimizer_state = state, loss_hist = vcat(loss_history, loss_hist))
Expand Down
10 changes: 6 additions & 4 deletions test/test_aqua.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,17 @@ import Test
import Aqua
import AdaptiveFlows

# ToDo: Fix ambiguities and enable ambiguity testing:
#=
Test.@testset "Package ambiguities" begin
Test.@test isempty(Test.detect_ambiguities(AdaptiveFlows))
end # testset
end
=#

Test.@testset "Aqua tests" begin
Aqua.test_all(
AdaptiveFlows,
ambiguities = false,
piracy = false,
project_toml_formatting = VERSIONv"1.7"
unbound_args = false
)
end # testset
end
Loading

0 comments on commit 7e43779

Please sign in to comment.