Skip to content

Commit

Permalink
Refactor optimization
Browse files Browse the repository at this point in the history
  • Loading branch information
Micki-D committed Feb 3, 2024
1 parent 758803f commit 3246a88
Show file tree
Hide file tree
Showing 2 changed files with 202 additions and 143 deletions.
18 changes: 14 additions & 4 deletions src/adaptive_flows.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,17 @@ function ChangesOfVariables.with_logabsdet_jacobian(
return vec(y), ladj[1]
end

function ChangesOfVariables.with_logabsdet_jacobian(
f::F,
x::ArrayOfSimilarArrays
) where F <: AbstractFlow
y, ladj = with_logabsdet_jacobian(f, flatview(x))
return nestedview(y), ladj
end

(f::AbstractFlow)(x::Matrix) = f.flow(x)
(f::AbstractFlow)(x::Vector) = vec(f.flow(reshape(x, :, 1)))
(f::AbstractFlow)(x::ArrayOfSimilarArrays) = nestedview(f(flatview(x)))
(f::AbstractFlow)(vs::AbstractValueShape) = vs

function InverseFunctions.inverse(f::CompositeFlow)
Expand Down Expand Up @@ -117,15 +126,16 @@ function InverseFunctions.inverse(f::FlowModule)
end

"""
AbstractFlowBlock <: AbstractFlowModule
AbstractFlowBlock <: AbstractFlow
A flow block is a normalizing flow that can only transform a fraction of the dimensions of
a multidimensional input.
To transform all components of the input, several flow blocks must be composed to a flow
module (see `AbstractFlowModule`).
"""
abstract type AbstractFlowBlock <: AbstractFlowModule
abstract type AbstractFlowBlock <: AbstractFlow
end
# is not a subtype of AbstractFlowModule to facilitate distinction in flow optimization

export AbstractFlowBlock

Expand Down Expand Up @@ -239,8 +249,8 @@ function _is_trainable(flow)
end

if typeof(flow) <: Function
return flow isa AbstractFlowModule
return flow isa AbstractFlowModule || flow isa AbstractFlowBlock
end

return flow <: AbstractFlowModule
return flow <: AbstractFlowModule || flow <: AbstractFlowBlock
end
Loading

0 comments on commit 3246a88

Please sign in to comment.