From dd295d6a92103c72b702bb3b4ac6cb449ccc28ab Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Tue, 4 Apr 2023 21:54:25 +0200 Subject: [PATCH] Remove custom rrule for _with_ladj_on_mapped Tricky to implement a correct rrule here that handles tangents which contain NoTangent or ZeroTangent. --- Project.toml | 12 +----------- README.md | 2 +- ext/ChangesOfVariablesChainRulesCoreExt.jl | 20 -------------------- src/ChangesOfVariables.jl | 3 --- src/with_ladj.jl | 8 ++++++-- test/test_with_ladj.jl | 7 ------- 6 files changed, 8 insertions(+), 44 deletions(-) delete mode 100644 ext/ChangesOfVariablesChainRulesCoreExt.jl diff --git a/Project.toml b/Project.toml index 3c9826c..aa39dce 100644 --- a/Project.toml +++ b/Project.toml @@ -3,25 +3,15 @@ uuid = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" version = "0.1.6" [deps] -ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" -[weakdeps] -ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" - -[extensions] -ChangesOfVariablesChainRulesCoreExt = "ChainRulesCore" - [compat] -ChainRulesCore = "1" julia = "1" [extras] -ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" [targets] -test = ["ChainRulesCore", "ChainRulesTestUtils", "Documenter", "ForwardDiff"] +test = ["Documenter", "ForwardDiff"] diff --git a/README.md b/README.md index 6eccec7..11c636c 100644 --- a/README.md +++ b/README.md @@ -12,7 +12,7 @@ changes for functions that perform a change of variables (like coordinate transformations). `ChangesOfVariables` is a very lightweight package and has no dependencies -beyond `Base`, `LinearAlgebra`, `Test` and `ChainRulesCore`. +beyond `Base`, `LinearAlgebra`, `Test`. ## Documentation diff --git a/ext/ChangesOfVariablesChainRulesCoreExt.jl b/ext/ChangesOfVariablesChainRulesCoreExt.jl deleted file mode 100644 index 8c7c5a5..0000000 --- a/ext/ChangesOfVariablesChainRulesCoreExt.jl +++ /dev/null @@ -1,20 +0,0 @@ -module ChangesOfVariablesChainRulesCoreExt - -using ChainRulesCore - -import ChangesOfVariables: _with_ladj_on_mapped - -# Need to use a type for this, type inference fails when using a pullback -# closure over YLT in the rrule, resulting in bad performance: -struct WithLadjOnMappedPullback{YLT} <: Function end -function (::WithLadjOnMappedPullback{YLT})(thunked_ΔΩ) where YLT - ys, ladj = unthunk(thunked_ΔΩ) - return NoTangent(), NoTangent(), map(y -> Tangent{YLT}(y, ladj), ys) -end - -function ChainRulesCore.rrule(::typeof(_with_ladj_on_mapped), map_or_bc::F, y_with_ladj) where {F<:Union{typeof(map),typeof(broadcast)}} - YLT = eltype(y_with_ladj) - return _with_ladj_on_mapped(map_or_bc, y_with_ladj), WithLadjOnMappedPullback{YLT}() -end - -end diff --git a/src/ChangesOfVariables.jl b/src/ChangesOfVariables.jl index b246a91..b2b034d 100644 --- a/src/ChangesOfVariables.jl +++ b/src/ChangesOfVariables.jl @@ -14,8 +14,5 @@ using Test include("with_ladj.jl") include("test.jl") -if !isdefined(Base, :get_extension) - include("../ext/ChangesOfVariablesChainRulesCoreExt.jl") -end end # module diff --git a/src/with_ladj.jl b/src/with_ladj.jl index 33d40e0..90732b0 100644 --- a/src/with_ladj.jl +++ b/src/with_ladj.jl @@ -107,9 +107,13 @@ function _with_ladj_on_mapped(map_or_bc::F, y_with_ladj::Tuple{Any,Real}) where return y_with_ladj end +_get_all_first(x) = map(first, x) +# Use x -> x[2] instead of last, using last causes horrible performance in Zygote here: +_sum_over_second(x) = sum(x -> x[2], x) + function _with_ladj_on_mapped(map_or_bc::F, y_with_ladj) where {F<:Union{typeof(map),typeof(broadcast)}} - y = map_or_bc(first, y_with_ladj) - ladj = sum(last, y_with_ladj) + y = _get_all_first(y_with_ladj) + ladj = _sum_over_second(y_with_ladj) (y, ladj) end diff --git a/test/test_with_ladj.jl b/test/test_with_ladj.jl index 3cd83c0..f27ec86 100644 --- a/test/test_with_ladj.jl +++ b/test/test_with_ladj.jl @@ -7,7 +7,6 @@ using LinearAlgebra using ChangesOfVariables using ChangesOfVariables: test_with_logabsdet_jacobian -using ChainRulesTestUtils include("getjacobian.jl") @@ -66,10 +65,4 @@ include("getjacobian.jl") test_with_logabsdet_jacobian(f, x, getjacobian) end end - - @testset "rrules" begin - for map_or_bc in (map, broadcast) - test_rrule(ChangesOfVariables._with_ladj_on_mapped, map_or_bc, [(1.0, 2.0), (3.0, 4.0), (5.0, 6.0)]) - end - end end