Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove custom rrule for _with_ladj_on_mapped #21

Merged
merged 1 commit into from
Apr 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 1 addition & 11 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,15 @@ uuid = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0"
version = "0.1.6"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

[extensions]
ChangesOfVariablesChainRulesCoreExt = "ChainRulesCore"

[compat]
ChainRulesCore = "1"
julia = "1"

[extras]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"

[targets]
test = ["ChainRulesCore", "ChainRulesTestUtils", "Documenter", "ForwardDiff"]
test = ["Documenter", "ForwardDiff"]
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ changes for functions that perform a change of variables (like coordinate
transformations).

`ChangesOfVariables` is a very lightweight package and has no dependencies
beyond `Base`, `LinearAlgebra`, `Test` and `ChainRulesCore`.
beyond `Base`, `LinearAlgebra`, `Test`.

## Documentation

Expand Down
20 changes: 0 additions & 20 deletions ext/ChangesOfVariablesChainRulesCoreExt.jl

This file was deleted.

3 changes: 0 additions & 3 deletions src/ChangesOfVariables.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,5 @@ using Test

include("with_ladj.jl")
include("test.jl")
if !isdefined(Base, :get_extension)
include("../ext/ChangesOfVariablesChainRulesCoreExt.jl")
end

end # module
8 changes: 6 additions & 2 deletions src/with_ladj.jl
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,13 @@ function _with_ladj_on_mapped(map_or_bc::F, y_with_ladj::Tuple{Any,Real}) where
return y_with_ladj
end

_get_all_first(x) = map(first, x)
# Use x -> x[2] instead of last, using last causes horrible performance in Zygote here:
_sum_over_second(x) = sum(x -> x[2], x)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Woaaah, really? Do you have any idea why?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No idea - something in Zygote? Maybe because the default pullback for last needs to store which index was targeted, and doesn't to that efficiently? IDK ... in any case, using x -> x[2] seems to work fine.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you have the benchmark code?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, with this test code (see #4)

using ChangesOfVariables, LinearAlgebra, Zygote, BenchmarkTools

function foo(xs)
    ys, ladj = with_logabsdet_jacobian(Base.Fix1(broadcast, log), xs)
    dot(ys, ys) + ladj
end

grad_foo(xs) = Zygote.gradient(foo, xs)[1]

xs = rand(10^3);
grad_foo(xs);

@benchmark grad_foo($xs)

Current package version:

BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  18.846 μs …  2.534 ms  ┊ GC (min … max):  0.00% … 92.30%
 Time  (median):     25.390 μs              ┊ GC (median):     0.00%
 Time  (mean ± σ):   31.439 μs ± 66.282 μs  ┊ GC (mean ± σ):  11.07% ±  5.28%

This PR - a bit slower, but covers the NoTangent case correctly:

BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  25.685 μs …  1.801 ms  ┊ GC (min … max):  0.00% … 95.85%
 Time  (median):     37.468 μs              ┊ GC (median):     0.00%
 Time  (mean ± σ):   47.158 μs ± 86.955 μs  ┊ GC (mean ± σ):  11.88% ±  6.31%

But this PR with _sum_over_second(x) = sum(last, x) instead of _sum_over_second(x) = sum(x -> x[2], x) - Zygote performance is horrible:

BenchmarkTools.Trial: 4059 samples with 1 evaluation.
 Range (min … max):  901.134 μs …   5.065 ms  ┊ GC (min … max): 0.00% … 75.73%
 Time  (median):       1.178 ms               ┊ GC (median):    0.00%
 Time  (mean ± σ):     1.229 ms ± 358.982 μs  ┊ GC (mean ± σ):  3.25% ±  8.32%


function _with_ladj_on_mapped(map_or_bc::F, y_with_ladj) where {F<:Union{typeof(map),typeof(broadcast)}}
y = map_or_bc(first, y_with_ladj)
ladj = sum(last, y_with_ladj)
y = _get_all_first(y_with_ladj)
ladj = _sum_over_second(y_with_ladj)
(y, ladj)
end

Expand Down
7 changes: 0 additions & 7 deletions test/test_with_ladj.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ using LinearAlgebra

using ChangesOfVariables
using ChangesOfVariables: test_with_logabsdet_jacobian
using ChainRulesTestUtils

include("getjacobian.jl")

Expand Down Expand Up @@ -66,10 +65,4 @@ include("getjacobian.jl")
test_with_logabsdet_jacobian(f, x, getjacobian)
end
end

@testset "rrules" begin
for map_or_bc in (map, broadcast)
test_rrule(ChangesOfVariables._with_ladj_on_mapped, map_or_bc, [(1.0, 2.0), (3.0, 4.0), (5.0, 6.0)])
end
end
end