Skip to content

Commit

Permalink
Add mm_unbalanced function (#22)
Browse files Browse the repository at this point in the history
* Add mm_unbalanced function

* Update api.md

* Update src/lib.jl

Co-authored-by: David Widmann <[email protected]>

* Update src/lib.jl

Co-authored-by: David Widmann <[email protected]>

* Update src/lib.jl

Co-authored-by: David Widmann <[email protected]>

* Update src/lib.jl

Co-authored-by: David Widmann <[email protected]>

* Update src/lib.jl

Co-authored-by: David Widmann <[email protected]>

* Update src/lib.jl

Co-authored-by: David Widmann <[email protected]>

* Update src/lib.jl

Co-authored-by: David Widmann <[email protected]>

* Add doctest in mm_unbalanced function

---------

Co-authored-by: David Widmann <[email protected]>
  • Loading branch information
pnavaro and devmotion authored Jul 10, 2024
1 parent a3c1d24 commit 6510029
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 15 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "PythonOT"
uuid = "3c485715-4278-42b2-9b5f-8f00e43c12ef"
authors = ["David Widmann"]
version = "0.1.5"
version = "0.1.6"

[deps]
PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"
Expand Down
1 change: 1 addition & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,5 @@ PythonOT.Smooth.smooth_ot_dual
sinkhorn_unbalanced
sinkhorn_unbalanced2
barycenter_unbalanced
mm_unbalanced
```
3 changes: 2 additions & 1 deletion src/PythonOT.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ export emd,
barycenter_unbalanced,
sinkhorn_unbalanced,
sinkhorn_unbalanced2,
empirical_sinkhorn_divergence
empirical_sinkhorn_divergence,
mm_unbalanced

const pot = PyCall.PyNULL()

Expand Down
76 changes: 63 additions & 13 deletions src/lib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -312,11 +312,11 @@ julia> ν = [0.0, 1.0];
julia> C = [0.0 1.0; 2.0 0.0; 0.5 1.5];
julia> sinkhorn_unbalanced(μ, ν, C, 0.01, 1_000)
julia> round.(sinkhorn_unbalanced(μ, ν, C, 0.01, 1_000); sigdigits=4)
3×2 Matrix{Float64}:
0.0 0.499964
0.0 0.200188
0.0 0.29983
0.0 0.5
0.0 0.2002
0.0 0.2998
```
It is possible to provide multiple target marginals as columns of a matrix. In this case the
Expand All @@ -325,10 +325,10 @@ optimal transport costs are returned:
```jldoctest sinkhorn_unbalanced
julia> ν = [0.0 0.5; 1.0 0.5];
julia> round.(sinkhorn_unbalanced(μ, ν, C, 0.01, 1_000); sigdigits=6)
julia> round.(sinkhorn_unbalanced(μ, ν, C, 0.01, 1_000); sigdigits=4)
2-element Vector{Float64}:
0.949709
0.449411
0.9497
0.4494
```
See also: [`sinkhorn_unbalanced2`](@ref)
Expand Down Expand Up @@ -371,20 +371,19 @@ julia> ν = [0.0, 1.0];
julia> C = [0.0 1.0; 2.0 0.0; 0.5 1.5];
julia> round.(sinkhorn_unbalanced2(μ, ν, C, 0.01, 1_000); sigdigits=6)
1-element Vector{Float64}:
0.949709
julia> round.(sinkhorn_unbalanced2(μ, ν, C, 0.01, 1_000); sigdigits=4)
0.9497
```
It is possible to provide multiple target marginals as columns of a matrix:
```jldoctest sinkhorn_unbalanced2
julia> ν = [0.0 0.5; 1.0 0.5];
julia> round.(sinkhorn_unbalanced2(μ, ν, C, 0.01, 1_000); sigdigits=6)
julia> round.(sinkhorn_unbalanced2(μ, ν, C, 0.01, 1_000); sigdigits=4)
2-element Vector{Float64}:
0.949709
0.449411
0.9497
0.4494
```
See also: [`sinkhorn_unbalanced`](@ref)
Expand Down Expand Up @@ -516,3 +515,54 @@ Python function.
function entropic_gromov_wasserstein(μ, ν, Cμ, Cν, ε, loss="square_loss"; kwargs...)
return pot.gromov.entropic_gromov_wasserstein(Cμ, Cν, μ, ν, loss, ε; kwargs...)
end

"""
mm_unbalanced(a, b, M, reg_m; reg=0, c=a*b', kwargs...)
Solve the unbalanced optimal transport problem and return the OT plan.
The function solves the following optimization problem:
```math
W = \\min_{\\gamma \\geq 0} \\langle \\gamma, M \\rangle_F +
\\mathrm{reg_{m1}} \\cdot \\operatorname{div}(\\gamma \\mathbf{1}, a) +
\\mathrm{reg_{m2}} \\cdot \\operatorname{div}(\\gamma^\\mathsf{T} \\mathbf{1}, b) +
\\mathrm{reg} \\cdot \\operatorname{div}(\\gamma, c)
```
where
- `M` is the metric cost matrix,
- `a` and `b` are source and target unbalanced distributions,
- `c` is a reference distribution for the regularization,
- `reg_m` is the marginal relaxation term (if it is a scalar or an indexable object of length 1, then the same term is applied to both marginal relaxations), and
- `reg` is a regularization term.
This function is a wrapper of the function
[`mm_unbalanced`](https://pythonot.github.io/gen_modules/ot.unbalanced.html#ot.unbalanced.mm_unbalanced) in the
Python Optimal Transport package. Keyword arguments are listed in the documentation of the
Python function.
# Examples
```jldoctest
julia> a=[.5, .5];
julia> b=[.5, .5];
julia> M=[1. 36.; 9. 4.];
julia> round.(mm_unbalanced(a, b, M, 5, div="kl"), digits=2)
2×2 Matrix{Float64}:
0.45 0.0
0.0 0.34
julia> round.(mm_unbalanced(a, b, M, 5, div="l2"), digits=2)
2×2 Matrix{Float64}:
0.4 0.0
0.0 0.1
```
"""
function mm_unbalanced(a, b, M, reg_m; kwargs...)
return pot.unbalanced.mm_unbalanced(a, b, M, reg_m; kwargs...)
end

2 comments on commit 6510029

@devmotion
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/110799

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.1.6 -m "<description of version>" 6510029ae259947a9d419e6dd542f77db766cb99
git push origin v0.1.6

Please sign in to comment.