Remove custom rrule for _with_ladj_on_mapped #21
Tricky to implement a correct rrule here that handles tangents which contain NoTangent or ZeroTangent.
Closes #20
We can always add a (better) rrule again later, but better a slightly slower AD than one that fails (#20).
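To illustrate the kind of case analysis such a rule needs, here is a minimal sketch, assuming ChainRulesCore is installed. `split_cotangent` is a hypothetical helper, not part of ChangesOfVariables or of the removed rrule; it only shows that the cotangent of a `(y, ladj)` tuple may be a structural zero as a whole, or may carry a structural zero in either component.

```julia
using ChainRulesCore

# Hypothetical helper (illustration only): split the cotangent of a (y, ladj)
# tuple into its two parts without assuming either part is a plain number.

# Case 1: the whole cotangent is NoTangent() or ZeroTangent().
split_cotangent(Δ::AbstractZero) = (Δ, Δ)

# Case 2: an indexable cotangent (assumed to support integer indexing);
# each component may still be a structural zero.
function split_cotangent(Δ)
    Δ = unthunk(Δ)  # resolve a lazy thunk, if one was passed
    return (Δ[1], Δ[2])
end

# Usage: the second component is a structural zero when only y was used downstream.
ȳ, l̄adj = split_cotangent(([1.0, 2.0], ZeroTangent()))
```

A pullback that destructures its argument as `(ȳ, l̄adj)` unconditionally only covers the second case, and then has to handle zero components in every later operation as well.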
Codecov Report
Patch coverage:
Additional details and impacted files
@@            Coverage Diff            @@
##            master       #21   +/-   ##
=========================================
  Coverage   100.00%   100.00%
=========================================
  Files            3         2    -1
  Lines           61        56    -5
=========================================
- Hits            61        56    -5
@@ -107,9 +107,13 @@ function _with_ladj_on_mapped(map_or_bc::F, y_with_ladj::Tuple{Any,Real}) where
    return y_with_ladj
end

_get_all_first(x) = map(first, x)
# Use x -> x[2] instead of last, using last causes horrible performance in Zygote here:
_sum_over_second(x) = sum(x -> x[2], x)
Woaaah, really? Do you have any idea why?
No idea - something in Zygote? Maybe because the default pullback for last needs to store which index was targeted, and doesn't do that efficiently? IDK ... in any case, using x -> x[2] seems to work fine.
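For reference, the two variants return the same value in the forward pass; only the reverse-mode cost differs. A minimal sketch in plain Julia, with hypothetical input data and helper names standing in for the mapped output:

```julia
# Hypothetical stand-in for the mapped output: one (y, ladj) tuple per element.
y_with_ladj = [(log(x), -log(x)) for x in [0.5, 1.0, 2.0]]

# Same definitions as in the diff, under illustrative names.
get_all_first(x) = map(first, x)
sum_over_second_indexed(x) = sum(t -> t[2], x)

# The variant reported to be slow under Zygote.
sum_over_second_last(x) = sum(last, x)

ys = get_all_first(y_with_ladj)
ladj = sum_over_second_indexed(y_with_ladj)
```

Since each element is a 2-tuple, `last(t)` and `t[2]` pick the same entry, so swapping one for the other does not change the result.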
Do you have the benchmark code?
Sure, with this test code (see #4)
using ChangesOfVariables, LinearAlgebra, Zygote, BenchmarkTools
function foo(xs)
    ys, ladj = with_logabsdet_jacobian(Base.Fix1(broadcast, log), xs)
    dot(ys, ys) + ladj
end
grad_foo(xs) = Zygote.gradient(foo, xs)[1]
xs = rand(10^3);
grad_foo(xs);
@benchmark grad_foo($xs)
Current package version:
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
Range (min … max): 18.846 μs … 2.534 ms ┊ GC (min … max): 0.00% … 92.30%
Time (median): 25.390 μs ┊ GC (median): 0.00%
Time (mean ± σ): 31.439 μs ± 66.282 μs ┊ GC (mean ± σ): 11.07% ± 5.28%
This PR - a bit slower, but covers the NoTangent case correctly:
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
Range (min … max): 25.685 μs … 1.801 ms ┊ GC (min … max): 0.00% … 95.85%
Time (median): 37.468 μs ┊ GC (median): 0.00%
Time (mean ± σ): 47.158 μs ± 86.955 μs ┊ GC (mean ± σ): 11.88% ± 6.31%
But this PR with _sum_over_second(x) = sum(last, x) instead of _sum_over_second(x) = sum(x -> x[2], x) - Zygote performance is horrible:
BenchmarkTools.Trial: 4059 samples with 1 evaluation.
Range (min … max): 901.134 μs … 5.065 ms ┊ GC (min … max): 0.00% … 75.73%
Time (median): 1.178 ms ┊ GC (median): 0.00%
Time (mean ± σ): 1.229 ms ± 358.982 μs ┊ GC (mean ± σ): 3.25% ± 8.32%
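A sanity check for the gradient of the test function, as a sketch that does not need Zygote or ChangesOfVariables. It assumes positive inputs and that the log-abs-det-Jacobian of elementwise log is -sum(log, xs); `foo_reference`, `grad_foo_reference` and `finite_difference_gradient` are hypothetical names introduced here, not part of the package:

```julia
using LinearAlgebra

# Hand-written equivalent of foo for positive inputs
# (assumption: ladj of elementwise log is -sum(log, xs)).
function foo_reference(xs)
    ys = log.(xs)
    ladj = -sum(ys)
    return dot(ys, ys) + ladj
end

# Analytic gradient: d/dx (log(x)^2 - log(x)) = (2 * log(x) - 1) / x
grad_foo_reference(xs) = (2 .* log.(xs) .- 1) ./ xs

# Central finite differences, used only to cross-check the analytic formula.
function finite_difference_gradient(f, xs; h = 1e-6)
    return map(eachindex(xs)) do i
        step = zeros(length(xs))
        step[i] = h
        (f(xs .+ step) - f(xs .- step)) / (2h)
    end
end

xs_check = [0.5, 1.0, 2.0]
grad_analytic = grad_foo_reference(xs_check)
grad_numeric = finite_difference_gradient(foo_reference, xs_check)
```

Comparing `grad_foo(xs)` from the benchmark code against `grad_foo_reference(xs)` should show whether a given variant still produces the right gradient, independent of its speed.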
I think I did not understand the discussion in #20 correctly - but doesn't #20 (comment) indicate that the rule is fine in so far as rules are not guaranteed currently anyway to handle
Ah, that discussion was mainly about That discussion was about a rule I mentioned takes a So for now, I'd like to drop our custom rrule and take the moderate performance hit in exchange for correctness. We can always add an improved version back later, after all.
If you have no objections, I'll merge and release this, @devmotion.