Commit

Fix gradient error
Micki-D committed Oct 7, 2023
1 parent cfcf5a3 commit 653ecbb
Showing 5 changed files with 207 additions and 54 deletions.
3 changes: 2 additions & 1 deletion src/AdaptiveFlows.jl
@@ -11,6 +11,7 @@ using AffineMaps
using ArgCheck
using ArraysOfArrays
using ChangesOfVariables
# using Flux # only for debugging
using FunctionChains
using Functors
using HeterogeneousComputing
@@ -28,6 +29,6 @@ using Zygote
include("adaptive_flows.jl")
include("optimize_flow.jl")
include("rqspline_coupling.jl")
include("scale_shift.jl")
#include("scale_shift.jl")
include("utils.jl")
end # module
124 changes: 118 additions & 6 deletions src/adaptive_flows.jl
@@ -3,10 +3,10 @@
"""
AbstractFlow <: Function
Abstract supertype for all functions that are used as "Normalizing Flows".

A Normalizing Flow is an invertible and differentiable function from
a `D`-dimensional space to a `D`-dimensional space.

A flow may be applied to a batch of samples from a target distribution.
Depending on the selected computing device and the specific flow, the input samples may be transformed in parallel.

Here, a flow returns a tuple of the transformed output of the flow and a row matrix, with the `i`-th entry
holding the log abs det jacobian of the transformation of the `i`-th sample.
@@ -19,7 +19,7 @@ end
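# Usage sketch (not part of this commit): the calling convention described above,
# assuming an already-constructed `flow::AbstractFlow`; shapes are placeholders.
using ChangesOfVariables: with_logabsdet_jacobian

x = randn(4, 100)                           # batch of 100 samples from a 4-dim target
y, ladj = with_logabsdet_jacobian(flow, x)  # y: transformed 4×100 batch
size(ladj)                                  # (1, 100): per-sample log abs det jacobian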
"""
CompositeFlow <: AbstractFlow
A `CompositeFlow` is a composition of several flow modules (see `AbstractFlowModule`[@ref]),
individual normalizing flows, that each transform all components of the input data.
"""
struct CompositeFlow <: AbstractFlow
@@ -33,6 +33,10 @@ function CompositeFlow(modules::Vector{F}) where F <: Function
return CompositeFlow(fchain(modules))
end

function CompositeFlow(n_dims::Integer, modules::Vector{F}) where F <: Function
build_flow(n_dims, modules)
end
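# Sketch (not part of this commit): composing a flow from already-initialized
# modules `m1`, `m2` (applied in that order), or delegating construction to
# `build_flow` via the convenience constructor above:
# flow = CompositeFlow([m1, m2])
flow = CompositeFlow(4, [RQSplineCouplingModule])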

function ChangesOfVariables.with_logabsdet_jacobian(
f::CompositeFlow,
x::Any
@@ -47,11 +51,23 @@ function InverseFunctions.inverse(f::CompositeFlow)
return CompositeFlow(InverseFunctions.inverse(f.flow).fs)
end

"""
prepend_flow_module(f::CompositeFlow, new_module::F) where F<:AbstractFlow
Prepend the chain of flow modules in `f` with `new_module`, so that `new_module`
is applied first in the resulting flow.
"""
function prepend_flow_module(f::CompositeFlow, new_module::F) where F<:AbstractFlow
return CompositeFlow([new_module, f.flow.fs...])
end
export prepend_flow_module

"""
append_flow_module(f::CompositeFlow, new_module::F) where F<:AbstractFlow
Append `new_module` to the chain of flow modules in `f`, so that `new_module`
is applied last in the resulting flow.
"""
function append_flow_module(f::CompositeFlow, new_module::F) where F<:AbstractFlow
return CompositeFlow([f.flow.fs..., new_module])
end
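# Sketch (not part of this commit): assuming an existing CompositeFlow `flow`
# and some module `m`, the two helpers place `m` at opposite ends of the chain:
flow_first = prepend_flow_module(flow, m)  # m runs before all existing modules
flow_last = append_flow_module(flow, m)    # m runs after all existing modules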
@@ -60,20 +76,116 @@ export append_flow_module
"""
AbstractFlowModule <: AbstractFlow
A flow module is a normalizing flow that transforms each of the input components.
A flow module may consist of a scaling and shifting operation of the input samples, or be a composition of
several flow blocks of a specific type (see `AbstractFlowBlock`).
"""
abstract type AbstractFlowModule <: AbstractFlow
end

# Wraps a function for use as a module in a `CompositeFlow`; `trainable` marks
# whether the wrapped function is subject to optimization during training.
struct FlowModule <: AbstractFlowModule
flow::Function
trainable::Bool
end

export FlowModule
@functor FlowModule

function ChangesOfVariables.with_logabsdet_jacobian(
f::FlowModule,
x::Any
)
with_logabsdet_jacobian(f.flow, x)
end

(f::FlowModule)(x::Any) = f.flow(x)
(f::FlowModule)(vs::AbstractValueShape) = vs

function InverseFunctions.inverse(f::FlowModule)
return FlowModule(InverseFunctions.inverse(f.flow), f.trainable)
end
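# Sketch (not part of this commit): wrapping a fixed standardization step as a
# non-trainable module, assuming InvMulAdd(A, b) from AffineMaps applies
# x -> A \ (x .- b); the statistics here are placeholders.
using AffineMaps, LinearAlgebra

m = FlowModule(InvMulAdd(Diagonal([1.5, 0.5]), [0.0, 2.0]), false)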

"""
AbstractFlowBlock <: AbstractFlowModule
A flow block is a normalizing flow that may only
transform a fraction of the components of the input samples. To transform all components of the input,
several flow blocks must be composed to a flow module (see `AbstractFlowModule`).
"""
abstract type AbstractFlowBlock <: AbstractFlowModule
end

"""
build_flow(n_dims::Integer, modules::Vector, compute_unit::AbstractComputeUnit = CPUnit())
Construct a `CompositeFlow` to transform samples from an `n_dims`-dimensional target distribution,
using the component modules in `modules`. The flow is initialized to target objects stored on `compute_unit` (defaults to CPU).
The first entry in `modules` is the function that is applied first to inputs of the resulting `CompositeFlow`.
The entries in `modules` may be actual functions or the types of the desired modules.
"""
function build_flow(n_dims::Integer, modules::Vector, compute_unit::AbstractComputeUnit = CPUnit())
@argcheck !any(broadcast(x -> x isa Type && x <: AffineMaps.AbstractAffineMap, modules)) throw(DomainError(modules, "One or more of the specified modules are uninitialized and depend on the target input. Please use `build_flow(target_samples, modules)` to initialize modules depending on the target samples."))
flow_modules = Function[flow_module isa Function ? flow_module : flow_module(n_dims, compute_unit = compute_unit) for flow_module in modules]

isa_flow = broadcast(flow_module -> flow_module isa AbstractFlow, flow_modules)
flow_modules[.! isa_flow] = broadcast(flow_module -> FlowModule(flow_module, _is_trainable(flow_module)), flow_modules[.! isa_flow])

return CompositeFlow(flow_modules)
end
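# Sketch (not part of this commit): building from module types that need no
# target statistics; passing an uninitialized affine map type here would raise
# the DomainError from the @argcheck above.
flow = build_flow(4, [RQSplineCouplingModule])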

function build_flow(target_samples::AbstractArray, modules::Vector = [InvMulAdd, RQSplineCouplingModule], compute_unit::AbstractComputeUnit = CPUnit())
# n_dims = target_samples isa Matrix ? size(target_samples, 1) : (target_samples isa ArrayOfSimilarArrays ? size(target_samples.data, 1) : throw(DomainError(target_samples, "Please input the target samples either as a `Matrix` or an `ArrayOfSimilarArrays`")))

flat_samples = flatview(target_samples)
n_dims = size(flat_samples, 1)

trainable = _is_trainable.(modules)
flow_modules = Vector{Function}(undef, length(modules))

if any(trainable)
flow_modules[trainable] = Function[flow_module isa Function ? flow_module : flow_module(n_dims, compute_unit = compute_unit) for flow_module in modules[trainable]]
end

if !trainable[1]
stds = vec(std(flat_samples, dims = 2))
means = vec(mean(flat_samples, dims = 2))

flow_modules[1] = modules[1] isa Function ? typeof(modules[1])(Diagonal(stds), means) : modules[1](Diagonal(stds), means)
end

for (i, flow_module) in enumerate(modules[2:end])
if !trainable[i + 1]
y_intermediate = fchain(flow_modules[1:i])(flat_samples)
stds = vec(std(y_intermediate, dims = 2))
means = vec(mean(y_intermediate, dims = 2))

flow_modules[i + 1] = flow_module isa Function ? typeof(flow_module)(Diagonal(stds), means) : flow_module(Diagonal(stds), means)
end
end

isa_flow = broadcast(flow_module -> flow_module isa AbstractFlow, flow_modules)
#broadcast!(flow_module -> FlowModule(flow_module), flow_modules[.! isa_flow])

flow_modules[.! isa_flow] = broadcast(flow_module -> FlowModule(flow_module, _is_trainable(flow_module)), flow_modules[.! isa_flow])

CompositeFlow(flow_modules)
end
export build_flow
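# Sketch (not part of this commit): the sample-based method with its default
# module list [InvMulAdd, RQSplineCouplingModule]; the affine module is
# initialized from the sample means and standard deviations (synthetic data).
samples = randn(4, 10_000)
flow = build_flow(samples)
y, ladj = with_logabsdet_jacobian(flow, samples)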

function _is_trainable(flow)

if flow isa FlowModule && !flow.trainable
return false
end

if flow isa CompositeFlow
return any(_is_trainable.(flow.flow.fs))
end

if typeof(flow) <: Function
return flow isa AbstractFlowModule
end

return flow <: AbstractFlowModule
end
104 changes: 71 additions & 33 deletions src/optimize_flow.jl
@@ -17,11 +17,17 @@ function mvnormal_negll_flow(flow::F, x::AbstractMatrix{<:Real}) where F<:AbstractFlow
return -ll
end

"""
mvnormal_negll_flow(flow::B, x::AbstractMatrix{<:Real}) where B<:AbstractFlowBlock
Calculate the negative log-likelihood (under a multivariate standard normal distribution) of the result
of applying `flow` to `x`.
"""
function mvnormal_negll_flow(flow::B, x::AbstractMatrix{<:Real}) where B<:AbstractFlowBlock
nsamples = size(x, 2)

y, ladj = with_logabsdet_jacobian(flow, x)
ll = (sum(std_normal_logpdf.(y[flow.mask,:])) + sum(ladj)) / nsamples

return -ll
end
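# Sketch (not part of this commit): a reference version of this loss for a full
# flow, assuming std_normal_logpdf(x) = -x^2/2 - log(2π)/2, the elementwise
# standard-normal log-density this package uses.
using ChangesOfVariables: with_logabsdet_jacobian

sketch_std_normal_logpdf(x::Real) = -x^2 / 2 - log(2π) / 2

function sketch_negll(flow, x::AbstractMatrix{<:Real})
    y, ladj = with_logabsdet_jacobian(flow, x)
    return -(sum(sketch_std_normal_logpdf.(y)) + sum(ladj)) / size(x, 2)
end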
@@ -41,7 +47,7 @@ end
export mvnormal_negll_flow_grad

"""
optimize_flow(samples::AbstractArray,
initial_flow::F where F<:AbstractFlow,
optimizer;
nbatches::Integer = 10,
@@ -51,94 +57,126 @@ export mvnormal_negll_flow_grad
shuffle_samples::Bool = false
)
Use `optimizer` to optimize the normalizing flow `initial_flow` to optimally transform `samples` to follow
a multivariate standard normal distribution. Use `nbatches` and `nepochs` respectively to specify the
number of batches and epochs to use during training.
If desired, set `shuffle_samples` to `true` to shuffle the samples between epochs. This may
improve the training, but increases training time.
Returns a `NamedTuple` `(result = optimized_flow, optimizer_state = final_optimizer_state, loss_hist = loss_history)` where `loss_history` is a vector
containing the values of the loss function for each iteration during training.
"""
function optimize_flow(samples::AbstractArray,
initial_flow::F where F<:AbstractFlow,
optimizer;
nbatches::Integer = 10,
nepochs::Integer = 100,
optstate = Optimisers.setup(optimizer, deepcopy(initial_flow)),
loss_history = Vector{Float64}(),
shuffle_samples::Bool = false
)
if !_is_trainable(initial_flow)
return (result = initial_flow, optimizer_state = nothing, loss_hist = nothing)
end

batchsize = round(Int, length(samples) / nbatches)
batches = collect(Iterators.partition(samples, batchsize))
flow = deepcopy(initial_flow)
state = deepcopy(optstate)
loss_hist = Vector{Float64}()

for i in 1:nepochs
for batch in batches
loss, d_flow = mvnormal_negll_flow_grad(flow, flatview(batch))
state, flow = Optimisers.update(state, flow, d_flow)
push!(loss_hist, loss)
end
if shuffle_samples
batches = collect(Iterators.partition(shuffle(samples), batchsize))
end
end

return (result = flow, optimizer_state = state, loss_hist = vcat(loss_history, loss_hist))
end
export optimize_flow
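# Sketch (not part of this commit): training with Adam from Optimisers.jl, which
# this file already relies on; learning rate, sizes, and data are placeholders.
using Optimisers, ArraysOfArrays

samples = nestedview(randn(4, 10_000))  # VectorOfSimilarVectors view of a 4×10_000 matrix
flow = build_flow(flatview(samples))

res = optimize_flow(samples, flow, Adam(1f-3); nbatches = 10, nepochs = 100)
trained_flow = res.result
losses = res.loss_hist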

"""
optimize_flow_sequentially(samples::AbstractArray,
initial_flow::CompositeFlow,
optimizer;
nbatches::Integer = 10,
nepochs::Integer = 100,
shuffle_samples::Bool = false
)
Use `optimizer` to optimize the normalizing flow `initial_flow` to optimally transform `samples` to follow
a multivariate standard normal distribution.
In contrast to `optimize_flow()`, this function optimizes each component of `initial_flow` in sequence: at first,
the first component of `initial_flow` is optimized; the resulting optimized component is then applied
to the input samples, which are then used to optimize the second component, and so on.
If a component of `initial_flow` is itself a composite of several component functions, these sub-component functions
are also optimized sequentially.
Use `nbatches` and `nepochs` respectively to specify the number of batches and epochs to use during training.
If desired, set `shuffle_samples` to `true` to shuffle the samples between epochs. This may
improve the training, but increases training time.
Returns a `NamedTuple` `(result = optimized_flow, optimizer_states = final_optimizer_states, loss_hists = loss_histories)` where `loss_hists` is a vector
containing vectors of the values of the loss function for each iteration during training for each of the components of the input flow.
"""
function optimize_flow_sequentially(samples::AbstractArray,
initial_flow::CompositeFlow,
optimizer;
nbatches::Integer = 10,
nepochs::Integer = 100,
shuffle_samples::Bool = false
)

optimized_modules = Vector{AbstractFlow}(undef, length(initial_flow.flow.fs))
module_optimizer_states = Vector(undef, length(initial_flow.flow.fs))
module_loss_histories = Vector(undef, length(initial_flow.flow.fs))

intermediate_samples = flatview(samples)

for (i,flow_module) in enumerate(initial_flow.flow.fs)
opt_module, opt_state, loss_hist = optimize_flow_sequentially(intermediate_samples, flow_module, optimizer; nbatches, nepochs, shuffle_samples)
optimized_modules[i] = opt_module
module_optimizer_states[i] = opt_state
module_loss_histories[i] = loss_hist

intermediate_samples = opt_module(intermediate_samples)
end

return (result = CompositeFlow(optimized_modules), optimizer_states = module_optimizer_states, loss_hists = module_loss_histories)
end

function optimize_flow_sequentially(samples::AbstractArray,
initial_flow::M where M<:AbstractFlowModule,
optimizer;
nbatches::Integer = 10,
nepochs::Integer = 100,
shuffle_samples::Bool = false
)
@argcheck !(initial_flow isa AbstractFlowBlock) DomainError(initial_flow, "The input flow is an individual flow block, please use `optimize_flow()`[@ref] to optimize flow blocks.")

if !_is_trainable(initial_flow)
return (result = initial_flow, optimizer_states = nothing, loss_hists = nothing)
end

optimized_blocks = Vector{Function}(undef, length(initial_flow.flow.fs))
block_optimizer_states = Vector{NamedTuple}(undef, length(initial_flow.flow.fs))
block_loss_histories = Vector{Vector}(undef, length(initial_flow.flow.fs))

intermediate_samples = samples

for (i, block) in enumerate(initial_flow.flow.fs)
optimized_block, optimizer_state, loss_history = optimize_flow(nestedview(intermediate_samples), block, optimizer; nbatches, nepochs, shuffle_samples = shuffle_samples)
optimized_blocks[i] = optimized_block
block_optimizer_states[i] = optimizer_state
block_loss_histories[i] = loss_history

intermediate_samples = optimized_block(intermediate_samples)
end

return (result = typeof(initial_flow)(optimized_blocks), optimizer_states = block_optimizer_states, loss_hists = block_loss_histories)
end
export optimize_flow_sequentially
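# Sketch (not part of this commit): sequential, module-by-module training of a
# composite flow built with the defaults above; Adam and sizes are placeholders.
using Optimisers

samples = randn(4, 10_000)
flow = build_flow(samples)

res = optimize_flow_sequentially(samples, flow, Adam(1f-3); nepochs = 50)
trained_flow = res.result
per_module_losses = res.loss_hists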