Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove custom rrule for _with_ladj_on_mapped #21

Merged
merged 1 commit into from
Apr 29, 2023
Merged

Conversation

oschulz
Copy link
Collaborator

@oschulz oschulz commented Apr 4, 2023

Tricky to implement a correct rrule here that handles tangents which contain NoTangent or ZeroTangent.

Tricky to implement a correct rrule here that handles
tangents which contain NoTangent or ZeroTangent.
@oschulz
Copy link
Collaborator Author

oschulz commented Apr 4, 2023

Closes #20

@oschulz oschulz requested a review from devmotion April 4, 2023 20:31
@oschulz
Copy link
Collaborator Author

oschulz commented Apr 4, 2023

We can always add a (better) rrule again later, but better a slightly slower AD than one that fails (#20).

@codecov
Copy link

codecov bot commented Apr 4, 2023

Codecov Report

Patch coverage: 100.00% and no project coverage change.

Comparison is base (1c265b5) 100.00% compared to head (dd295d6) 100.00%.

Additional details and impacted files
@@            Coverage Diff            @@
##            master       #21   +/-   ##
=========================================
  Coverage   100.00%   100.00%           
=========================================
  Files            3         2    -1     
  Lines           61        56    -5     
=========================================
- Hits            61        56    -5     
Impacted Files Coverage Δ
src/with_ladj.jl 100.00% <100.00%> (ø)

Help us with your feedback. Take ten seconds to tell us how you rate us. Have a feature suggestion? Share it here.

☔ View full report in Codecov by Sentry.
📢 Do you have feedback about the report comment? Let us know in this issue.

@@ -107,9 +107,13 @@ function _with_ladj_on_mapped(map_or_bc::F, y_with_ladj::Tuple{Any,Real}) where
return y_with_ladj
end

_get_all_first(x) = map(first, x)
# Use x -> x[2] instead of last, using last causes horrible performance in Zygote here:
_sum_over_second(x) = sum(x -> x[2], x)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Woaaah, really? Do you have any idea why?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No idea - something in Zygote? Maybe because the default pullback for last needs to store which index was targeted, and doesn't to that efficiently? IDK ... in any case, using x -> x[2] seems to work fine.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you have the benchmark code?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, with this test code (see #4)

using ChangesOfVariables, LinearAlgebra, Zygote, BenchmarkTools

function foo(xs)
    ys, ladj = with_logabsdet_jacobian(Base.Fix1(broadcast, log), xs)
    dot(ys, ys) + ladj
end

grad_foo(xs) = Zygote.gradient(foo, xs)[1]

xs = rand(10^3);
grad_foo(xs);

@benchmark grad_foo($xs)

Current package version:

BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  18.846 μs …  2.534 ms  ┊ GC (min … max):  0.00% … 92.30%
 Time  (median):     25.390 μs              ┊ GC (median):     0.00%
 Time  (mean ± σ):   31.439 μs ± 66.282 μs  ┊ GC (mean ± σ):  11.07% ±  5.28%

This PR - a bit slower, but covers the NoTangent case correctly:

BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  25.685 μs …  1.801 ms  ┊ GC (min … max):  0.00% … 95.85%
 Time  (median):     37.468 μs              ┊ GC (median):     0.00%
 Time  (mean ± σ):   47.158 μs ± 86.955 μs  ┊ GC (mean ± σ):  11.88% ±  6.31%

But this PR with _sum_over_second(x) = sum(last, x) instead of _sum_over_second(x) = sum(x -> x[2], x) - Zygote performance is horrible:

BenchmarkTools.Trial: 4059 samples with 1 evaluation.
 Range (min … max):  901.134 μs …   5.065 ms  ┊ GC (min … max): 0.00% … 75.73%
 Time  (median):       1.178 ms               ┊ GC (median):    0.00%
 Time  (mean ± σ):     1.229 ms ± 358.982 μs  ┊ GC (mean ± σ):  3.25% ±  8.32%

@devmotion
Copy link
Member

I think I did not understand the discussion in #20 correctly - but doesn't #20 (comment) indicate that the rule is fine in so far as rules are not guaranteed currently anyway to handle NoTangent/ZeroTangent explicitly? And that the whole issue is rather a bug in Zygote which calls a pullback with a ZeroTangent?

@oschulz
Copy link
Collaborator Author

oschulz commented Apr 7, 2023

indicate that the rule is fine in so far as rules are not guaranteed currently anyway to handle NoTangent/ZeroTangent explicitly

Ah, that discussion was mainly about That discussion was about a rule I mentioned takes a NoTangent but returns a ZeroTangent. Custom rules definitely have to be handle NoTangent/ZeroTangent correctly themselves, at least in some cases. In our case definitely, since only the tangent on y is NoTangent in @torfjelde's case, not the tangent on ladj. So the pullback doesn't see NoTangent, is sees a Tuple{NoTangent, <:Real}. I've tried, and it not trivial to handle this right, esp. if one want to preserve array types (keep GPU stuff on the GPU, etc.).

So for now, I'd like to drop our custom rrule and take the moderate performance hit in exchange for correctness. We can always an improved version back later, after all.

@oschulz
Copy link
Collaborator Author

oschulz commented Apr 21, 2023

If you have no objections, I'll merge and release this, @devmotion .

@oschulz oschulz merged commit dd295d6 into master Apr 29, 2023
@oschulz oschulz deleted the no-custom-rrule branch April 29, 2023 12:56
@oschulz
Copy link
Collaborator Author

oschulz commented Apr 29, 2023

@oschulz
Copy link
Collaborator Author

oschulz commented Apr 29, 2023

Closes #20 .

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants