Implementing the Greenkhorn #159

Open
davibarreira wants to merge 7 commits into master

Conversation

davibarreira
Member

This PR is related to #151.
I've implemented the Greenkhorn algorithm which is a greedy version of the Sinkhorn algorithm. The method I've implemented is actually the one in POT, which is a bit different from the one in the original paper (https://arxiv.org/pdf/1705.09634.pdf).

The implementation needs improvement. I was not able to get it to work with AD or with the batch tests. I was not very involved in the coding of the original Sinkhorn algorithm, and had some difficulty finding my way around the whole step!, solve! and cache structure.

Another point: each iteration of the Greenkhorn algorithm updates only a single entry of u or v, so it needs many more iterations to converge than the original Sinkhorn. Some preliminary benchmarks showed that this implementation of Greenkhorn is slower than the original Sinkhorn_Gibbs, which seems to contradict the claims in the paper. I believe the reason might be that the package's Sinkhorn implementation is heavily optimized compared to my version of Greenkhorn. Another possibility is that the Sinkhorn version in the paper is not very efficient (if you look at the paper, the Sinkhorn algorithm they present computes diagm(u) K diagm(v) in each iteration).

I've compared the results from my algorithm against POT, and it returns exactly the same result at each iteration, i.e. the Greenkhorn implementation seems to be correct but not optimal.
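
For context, here is a rough self-contained sketch of the kind of greedy update Greenkhorn performs, including the incremental marginal bookkeeping discussed further down (the function name and structure are made up for illustration; this is not the code in this PR nor the package's API):

```julia
using LinearAlgebra

# Illustrative sketch: greedily fix the worst-violated marginal constraint,
# updating a single entry of u or v per iteration.
function greenkhorn_sketch(μ, ν, K; maxiter=10_000, atol=1e-8)
    u = fill(1 / length(μ), length(μ))
    v = fill(1 / length(ν), length(ν))
    G = u .* K .* v'                   # current plan, G[i, j] = u[i] * K[i, j] * v[j]
    Δμ = vec(sum(G; dims=2)) .- μ      # row-marginal violations
    Δν = vec(sum(G; dims=1)) .- ν      # column-marginal violations
    for _ in 1:maxiter
        δμ, i = findmax(abs.(Δμ))
        δν, j = findmax(abs.(Δν))
        max(δμ, δν) < atol && break
        if δμ > δν                     # fix row i exactly
            old = u[i]
            Ki = view(K, i, :)
            u[i] = μ[i] / dot(Ki, v)
            Δν .+= (u[i] - old) .* Ki .* v   # only row i of the plan changed
            Δμ[i] = 0
            G[i, :] .= u[i] .* Ki .* v
        else                           # fix column j exactly
            old = v[j]
            Kj = view(K, :, j)
            v[j] = ν[j] / dot(Kj, u)
            Δμ .+= (v[j] - old) .* Kj .* u
            Δν[j] = 0
            G[:, j] .= v[j] .* Kj .* u
        end
    end
    return G
end
```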

@coveralls

coveralls commented Jan 16, 2022

Pull Request Test Coverage Report for Build 1704769432

Warning: This coverage report may be inaccurate.

This pull request's base commit is no longer the HEAD commit of its target branch. This means it includes changes from outside the original pull request, including, potentially, unrelated coverage changes.

Details

  • 45 of 46 (97.83%) changed or added relevant lines in 1 file are covered.
  • No unchanged relevant lines lost coverage.
  • Overall coverage increased (+0.2%) to 95.588%

Changes Missing Coverage | Covered Lines | Changed/Added Lines | %
src/entropic/greenkhorn.jl | 45 | 46 | 97.83%

Totals Coverage Status
Change from base Build 1620585433: +0.2%
Covered Lines: 650
Relevant Lines: 680

💛 - Coveralls

@codecov-commenter

Codecov Report

Merging #159 (82b1026) into master (de56119) will increase coverage by 0.16%.
The diff coverage is 97.82%.

Impacted file tree graph

@@            Coverage Diff             @@
##           master     #159      +/-   ##
==========================================
+ Coverage   95.42%   95.58%   +0.16%     
==========================================
  Files          14       15       +1     
  Lines         634      680      +46     
==========================================
+ Hits          605      650      +45     
- Misses         29       30       +1     
Impacted Files Coverage Δ
src/entropic/greenkhorn.jl 97.82% <97.82%> (ø)

Continue to review full report at Codecov.

Legend
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update de56119...82b1026. Read the comment docs.

Comment on lines +1 to +3
# Greenkhorn is a greedy version of the Sinkhorn algorithm
# This method is from https://arxiv.org/pdf/1705.09634.pdf
# Code is based on implementation from package POT
Member

It would be helpful to describe what the differences are (if there are any apart from implementation details).

Member Author

Yeah, there are. The paper's version actually differs by just a couple of lines, which are commented out. I'll point them out in the code.

Member Author

Just realized that it's already pointed out in the code, inside the step! function.

u::U
v::V
K::KT
Kv::U #placeholder
Member

Why "placeholder"?

Member Author

This is a partial solution. Without this Kv I got an error; it seems that the Sinkhorn structs used further on require it. I could not figure out how to get rid of it without changing the Sinkhorn code.

Member

I mean you use it explicitly below for checking convergence?

Member Author

Yeah, it's used to check convergence. I think the convergence check could be done more efficiently in another way, but I was having trouble getting that to work with the existing "API" for Sinkhorn, so I update Kv and use the convergence check already in place.

You are right that, as is, Kv is not actually just a placeholder, since it's used in the convergence check.
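
For illustration, a marginal-based check along these lines could look like the following generic sketch (my own example; the setup is made up and this is not the package's internal API):

```julia
using LinearAlgebra

# With plan G = diagm(u) * K * diagm(v), the row marginals are u .* (K * v),
# so a cached Kv can be reused to measure the source-marginal violation.
n = 5
μ = fill(1 / n, n)
K = rand(n, n); u = rand(n); v = rand(n)
Kv = similar(u)

mul!(Kv, K, v)                   # refresh the cached K * v in place
err = norm(u .* Kv .- μ, Inf)    # sup-norm violation of the source marginal
converged = err < 1e-8
```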

Comment on lines +87 to +88
i₁ = argmax(abs.(Δμ))
i₂ = argmax(abs.(Δν))
Member

Suggested change
i₁ = argmax(abs.(Δμ))
i₂ = argmax(abs.(Δν))
Δμ_max, Δμ_max_idx = findmax(abs, Δμ)
Δν_max, Δν_max_idx = findmax(abs, Δν)

Comment on lines +90 to +91
# if ρμ[i₁]> ρν[i₂]
if abs(Δμ[i₁]) > abs(Δν[i₂])
Member

Suggested change
# if ρμ[i₁]> ρν[i₂]
if abs(Δμ[i₁]) > abs(Δν[i₂])
if Δμ_max > Δν_max


# if ρμ[i₁]> ρν[i₂]
if abs(Δμ[i₁]) > abs(Δν[i₂])
old_u = u[i₁]
Member

Also has to be changed for batch support it seems.

Suggested change
old_u = u[i₁]
old_u = u[Δμ_max_idx]

# if ρμ[i₁]> ρν[i₂]
if abs(Δμ[i₁]) > abs(Δν[i₂])
old_u = u[i₁]
u[i₁] = μ[i₁]/ (K[i₁,:] ⋅ v)
Member

Suggested change
u[i₁] = μ[i₁]/ (K[i₁,:] ⋅ v)
u[Δμ_max_idx] = μ[Δμ_max_idx] / dot(K[Δμ_max_idx, :], v)

It would be better to select columns instead of rows and to use views. Julia uses column-major order, so the entries of a column are contiguous in memory and hence accessing columns is faster.
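
For example (generic illustration, not code from this PR):

```julia
using LinearAlgebra

K = rand(4, 4); v = rand(4)

dot(K[1, :], v)         # K[1, :] allocates a copy of the row before the dot product
dot(view(K, 1, :), v)   # a view avoids the copy, but a row is strided in memory
dot(view(K, :, 1), v)   # a column is contiguous (column-major), so this is the fast layout
```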

Member Author

True! I'll try to come up with something.

if abs(Δμ[i₁]) > abs(Δν[i₂])
old_u = u[i₁]
u[i₁] = μ[i₁]/ (K[i₁,:] ⋅ v)
Δ = u[i₁] - old_u
Member

Suggested change
Δ = u[i₁] - old_u
Δ = u[Δμ_max_idx] - old_u

old_u = u[i₁]
u[i₁] = μ[i₁]/ (K[i₁,:] ⋅ v)
Δ = u[i₁] - old_u
G[i₁, :] = u[i₁] * K[i₁,:] .* v
Member

Again, better to work with columns than with rows. Also some unnecessary allocations here:

Suggested change
G[i₁, :] = u[i₁] * K[i₁,:] .* v
G[Δμ_max_idx, :] .= u[Δμ_max_idx] .* K[Δμ_max_idx, :] .* v

Member Author

What is the unnecessary allocation?

Member

Oh there are multiple unnecessary allocations.

First of all, K[i₁, :] creates, i.e., allocates, a row vector. Then u[i₁] * K[i₁, :] scales it and allocates a new row vector. And finally u[i₁] * K[i₁,:] .* v multiplies the entries of u[i₁] * K[i₁,:] elementwise with v and allocates yet another row vector.

Whereas the alternative suggestion allocates only K[i₁, :] (could be avoided by using a view) and then fuses all multiplications and writes the result directly to G without allocating any other row vector.
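
A generic illustration of the three variants (made-up setup, not this PR's code):

```julia
n = 4
K = rand(n, n); G = similar(K)
u = rand(n); v = rand(n)
i₁ = 1

G[i₁, :] = u[i₁] * K[i₁, :] .* v             # allocates three temporary row vectors
G[i₁, :] .= u[i₁] .* K[i₁, :] .* v           # fused broadcast: only the K[i₁, :] slice is copied
@views G[i₁, :] .= u[i₁] .* K[i₁, :] .* v    # the view removes that remaining copy as well
```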

Member Author

Thanks for the answer, and sorry for the bad code. I'm still pretty inexperienced with code optimization.

u[i₁] = μ[i₁]/ (K[i₁,:] ⋅ v)
Δ = u[i₁] - old_u
G[i₁, :] = u[i₁] * K[i₁,:] .* v
Δμ[i₁] = u[i₁] * (K[i₁,:] ⋅ v) - μ[i₁]
Member

Same here.

Δ = u[i₁] - old_u
G[i₁, :] = u[i₁] * K[i₁,:] .* v
Δμ[i₁] = u[i₁] * (K[i₁,:] ⋅ v) - μ[i₁]
@. Δν = Δν + Δ * K[i₁,:] * v
Member

And here.

G[i₁, :] = u[i₁] * K[i₁,:] .* v
Δμ[i₁] = u[i₁] * (K[i₁,:] ⋅ v) - μ[i₁]
@. Δν = Δν + Δ * K[i₁,:] * v
else
Member

Same comments in the second branch.
