Remove custom rrule for `_with_ladj_on_mapped` #21
It is tricky to implement a correct rrule here that handles tangents containing `NoTangent` or `ZeroTangent`.

Closes #20

We can always add a (better) rrule again later, but slightly slower AD is better than AD that fails (#20).
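To make the difficulty concrete, here is a sketch of the math a hand-written pullback has to reproduce. It assumes the forward pass splits the per-element outputs $z_i$ into values $y$ and a summed log-abs-det-Jacobian $L$, as the `_get_all_first` / `_sum_over_second` helpers in this PR do. Given incoming cotangents $\bar{y}$ and $\bar{L}$:

```math
\begin{aligned}
z_i &= (y_i, \ell_i), \qquad y = (y_1, \dots, y_n), \qquad L = \sum_{i=1}^{n} \ell_i,\\
\bar{z}_i &= \left(\bar{y}_i,\ \bar{L}\right), \qquad i = 1, \dots, n.
\end{aligned}
```

The formula itself is trivial. The hard part is that either $\bar{y}$ or $\bar{L}$ may arrive as a structural zero (which ChainRulesCore can represent as `ZeroTangent()` or `NoTangent()`), so a custom rule must build $\bar{z}_i = (0, \bar{L})$ or $(\bar{y}_i, 0)$ for those cases without indexing into the zero as if it were an array.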
Codecov Report

```diff
@@            Coverage Diff             @@
##            master      #21     +/-   ##
=========================================
  Coverage   100.00%   100.00%
=========================================
  Files            3         2       -1
  Lines           61        56       -5
=========================================
- Hits            61        56       -5
```
```julia
_get_all_first(x) = map(first, x)
# Use x -> x[2] instead of last, using last causes horrible performance in Zygote here:
_sum_over_second(x) = sum(x -> x[2], x)
```
Woaaah, really? Do you have any idea why?
No idea, something in Zygote? Maybe the default pullback for `last` needs to store which index was targeted and doesn't do that efficiently? IDK... In any case, using `x -> x[2]` seems to work fine.
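For reference, a minimal sketch of what the two helpers compute, in plain Julia without Zygote. `log_with_ladj` is a hypothetical stand-in for `with_logabsdet_jacobian(log, x)`, assuming x > 0 so that log|d(log x)/dx| = log(1/x) = -log x. The last line checks that `sum(last, x)` and `sum(x -> x[2], x)` give the same value on 2-tuples, so the swap discussed here should only change how Zygote differentiates the call, not the result.

```julia
# Helpers as in the review snippet above
_get_all_first(x) = map(first, x)
_sum_over_second(x) = sum(x -> x[2], x)

# Hypothetical stand-in for with_logabsdet_jacobian(log, x) on a positive real:
# returns (log(x), log|d log(x)/dx|) = (log(x), -log(x))
log_with_ladj(x::Real) = (y = log(x); (y, -y))

xs = [1.0, 2.0, 4.0]
y_with_ladj = map(log_with_ladj, xs)   # Vector of (value, ladj) tuples

ys = _get_all_first(y_with_ladj)       # log of each element of xs
ladj = _sum_over_second(y_with_ladj)   # minus the sum of those logs

# `last` and `x -> x[2]` agree on 2-tuples; only the AD cost differs
@assert sum(last, y_with_ladj) == ladj
```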
Sure, with this test code (see #4):

```julia
using ChangesOfVariables, LinearAlgebra, Zygote, BenchmarkTools

function foo(xs)
    ys, ladj = with_logabsdet_jacobian(Base.Fix1(broadcast, log), xs)
    dot(ys, ys) + ladj
end

grad_foo(xs) = Zygote.gradient(foo, xs)[1]

xs = rand(10^3);
grad_foo(xs);  # warm-up call so compilation time is not benchmarked
@benchmark grad_foo($xs)
```

Current package version:
```
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
Range (min … max): 18.846 μs … 2.534 ms ┊ GC (min … max): 0.00% … 92.30%
Time (median): 25.390 μs ┊ GC (median): 0.00%
Time (mean ± σ): 31.439 μs ± 66.282 μs ┊ GC (mean ± σ): 11.07% ± 5.28%
```
This PR: roughly 1.5× slower (median 37.468 μs vs. 25.390 μs), but handles the `NoTangent` case correctly:
```
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
Range (min … max): 25.685 μs … 1.801 ms ┊ GC (min … max): 0.00% … 95.85%
Time (median): 37.468 μs ┊ GC (median): 0.00%
Time (mean ± σ): 47.158 μs ± 86.955 μs ┊ GC (mean ± σ): 11.88% ± 6.31%
```
But with `_sum_over_second(x) = sum(last, x)` instead of `_sum_over_second(x) = sum(x -> x[2], x)` in this PR, Zygote performance is horrible (median 1.178 ms, about 30× slower than with `x -> x[2]`):
```
BenchmarkTools.Trial: 4059 samples with 1 evaluation.
Range (min … max): 901.134 μs … 5.065 ms ┊ GC (min … max): 0.00% … 75.73%
Time (median): 1.178 ms ┊ GC (median): 0.00%
Time (mean ± σ): 1.229 ms ± 358.982 μs ┊ GC (mean ± σ): 3.25% ± 8.32%
```
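Since the point of this PR is correctness rather than speed, one way to sanity-check what `grad_foo` returns is to compare it against a gradient worked out by hand. Assuming all xs > 0, log|d(log x)/dx| = -log x, so foo(xs) = Σ log(x_i)² − Σ log(x_i) and ∂foo/∂x_i = (2 log x_i − 1)/x_i. The sketch uses only `LinearAlgebra`; `foo_ref`, `grad_ref` and `grad_fd` are hypothetical names, and `grad_ref(xs)` could then be compared with `grad_foo(xs)` from the script above.

```julia
using LinearAlgebra

# Same objective as `foo`, with with_logabsdet_jacobian(Base.Fix1(broadcast, log), xs)
# written out by hand (assumes all xs > 0): ys = log.(xs), ladj = -sum(ys)
foo_ref(xs) = (ys = log.(xs); dot(ys, ys) - sum(ys))

# Analytic gradient: d/dx_i [log(x_i)^2 - log(x_i)] = (2 log(x_i) - 1) / x_i
grad_ref(xs) = (2 .* log.(xs) .- 1) ./ xs

# Central finite differences, perturbing one coordinate at a time
function grad_fd(f, xs; h = 1e-6)
    g = similar(xs)
    for i in eachindex(xs)
        e = zero(xs)
        e[i] = h
        g[i] = (f(xs .+ e) - f(xs .- e)) / (2h)
    end
    return g
end

xs = rand(10) .+ 0.5   # keep values away from 0 so log is well-behaved
@assert isapprox(grad_ref(xs), grad_fd(foo_ref, xs); rtol = 1e-5)
```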
I may not have understood the discussion in #20 correctly, but doesn't #20 (comment) indicate that the rule is fine, insofar as rules are currently not guaranteed to handle …
Ah, that discussion was mainly about … That discussion was about a rule I mentioned takes a … So for now, I'd like to drop our custom rrule and take the moderate performance hit in exchange for correctness. We can always add an improved version back later, after all.
If you have no objections, I'll merge and release this, @devmotion.