Skip to content

Commit

Permalink
Unbalanced barycenter (#11)
Browse files Browse the repository at this point in the history
Co-authored-by: David Widmann <[email protected]>
  • Loading branch information
zsteve and devmotion authored May 24, 2021
1 parent b41007d commit 53806cf
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 2 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "PythonOT"
uuid = "3c485715-4278-42b2-9b5f-8f00e43c12ef"
authors = ["David Widmann"]
version = "0.1.1"
version = "0.1.2"

[deps]
PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"
Expand Down
1 change: 1 addition & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,5 @@ PythonOT.Smooth.smooth_ot_dual
```@docs
sinkhorn_unbalanced
sinkhorn_unbalanced2
barycenter_unbalanced
```
9 changes: 8 additions & 1 deletion src/PythonOT.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,14 @@ module PythonOT

using PyCall: PyCall

export emd, emd2, sinkhorn, sinkhorn2, barycenter, sinkhorn_unbalanced, sinkhorn_unbalanced2
export emd,
emd2,
sinkhorn,
sinkhorn2,
barycenter,
barycenter_unbalanced,
sinkhorn_unbalanced,
sinkhorn_unbalanced2

const pot = PyCall.PyNULL()

Expand Down
45 changes: 45 additions & 0 deletions src/lib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -322,3 +322,48 @@ true
```
"""
barycenter(A, C, ε; kwargs...) = pot.barycenter(A, C, ε; kwargs...)

"""
barycenter_unbalanced(A, C, ε, λ; kwargs...)
Compute the entropically regularized unbalanced Wasserstein barycenter with histograms `A`, cost matrix
`C`, entropic regularization parameter `ε` and marginal relaxation parameter `λ`.
The Wasserstein barycenter is a histogram and solves
```math
\\inf_{a} \\sum_{i} W_{\\varepsilon,C,\\lambda}(a, a_i),
```
where the histograms ``a_i`` are columns of matrix `A` and ``W_{\\varepsilon,C,\\lambda}(a, a_i)}``
is the optimal transport cost for the entropically regularized optimal transport problem
with marginals ``a`` and ``a_i``, cost matrix ``C``, entropic regularization parameter
``\\varepsilon`` and marginal relaxation parameter ``\\lambda``. Optionally, weights of the histograms ``a_i`` can be provided with the
keyword argument `weights`.
This function is a wrapper of the function
[`barycenter_unbalanced`](https://pythonot.github.io/gen_modules/ot.unbalanced.html#ot.unbalanced.barycenter_unbalanced) in the
Python Optimal Transport package. Keyword arguments are listed in the documentation of the
Python function.
# Examples
```jldoctest
julia> A = rand(10, 3);
julia> A ./= sum(A; dims=1);
julia> C = rand(10, 10);
julia> isapprox(sum(barycenter_unbalanced(A, C, 0.01, 1; method="sinkhorn_stabilized")), 1; atol=1e-4)
false
julia> isapprox(sum(barycenter_unbalanced(
A, C, 0.01, 10_000; method="sinkhorn_stabilized", numItermax=5_000
)), 1; atol=1e-4)
true
```
See also: [`barycenter`](@ref)
"""
function barycenter_unbalanced(A, C, ε, λ; kwargs...)
return pot.barycenter_unbalanced(A, C, ε, λ; kwargs...)
end

2 comments on commit 53806cf

@devmotion
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/37387

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.1.2 -m "<description of version>" 53806cff373c8ad0229b23ecd36822493fa842f6
git push origin v0.1.2

Please sign in to comment.