ChainRules pullback doesn't support ZeroTangent. #20

@torfjelde

It seems the rrule for mapped/broadcasted f doesn't support ZeroTangent:

julia> Zygote.gradient(x -> Bijectors.with_logabsdet_jacobian(stacked, x)[2], randn(3))
ERROR: MethodError: no method matching length(::ChainRulesCore.ZeroTangent)
Closest candidates are:
  length(::Union{Base.KeySet, Base.ValueIterator}) at abstractdict.jl:58
  length(::Union{DataStructures.OrderedRobinDict, DataStructures.RobinDict}) at ~/.julia/packages/DataStructures/59MD0/src/ordered_robin_dict.jl:86
  length(::Union{DataStructures.SortedDict, DataStructures.SortedMultiDict, DataStructures.SortedSet}) at ~/.julia/packages/DataStructures/59MD0/src/container_loops.jl:322
  ...
Stacktrace:
  [1] length(g::Base.Generator{ChainRulesCore.ZeroTangent, ChangesOfVariables.ChangesOfVariablesChainRulesCoreExt.var"#1#2"{Tuple{Float64, Float64}, Float64}})
    @ Base ./generator.jl:50
  [2] _similar_shape(itr::Base.Generator{ChainRulesCore.ZeroTangent, ChangesOfVariables.ChangesOfVariablesChainRulesCoreExt.var"#1#2"{Tuple{Float64, Float64}, Float64}}, #unused#::Base.HasLength)
    @ Base ./array.jl:663
  [3] collect(itr::Base.Generator{ChainRulesCore.ZeroTangent, ChangesOfVariables.ChangesOfVariablesChainRulesCoreExt.var"#1#2"{Tuple{Float64, Float64}, Float64}})
    @ Base ./array.jl:786
  [4] map(f::Function, A::ChainRulesCore.ZeroTangent)
    @ Base ./abstractarray.jl:2961
  [5] (::ChangesOfVariables.ChangesOfVariablesChainRulesCoreExt.WithLadjOnMappedPullback{Tuple{Float64, Float64}})(thunked_ΔΩ::ChainRulesCore.Tangent{Any, Tuple{ChainRulesCore.ZeroTangent, Float64}})
    @ ChangesOfVariables.ChangesOfVariablesChainRulesCoreExt ~/.julia/packages/ChangesOfVariables/qC6bf/ext/ChangesOfVariablesChainRulesCoreExt.jl:12
  [6] ZBack
    @ ~/.julia/packages/Zygote/TSj5C/src/compiler/chainrules.jl:211 [inlined]
  [7] Pullback
    @ ~/.julia/packages/ChangesOfVariables/qC6bf/src/with_ladj.jl:121 [inlined]
  [8] (::Zygote.Pullback{Tuple{typeof(with_logabsdet_jacobian), Base.Fix1{typeof(broadcast), typeof(exp)}, Vector{Float64}}, Tuple{Zygote.var"#2149#back#299"{Zygote.var"#back#298"{:x, Zygote.Context{false}, Base.Fix1{typeof(broadcast), typeof(exp)}, typeof(exp)}}, Zygote.Pullback{Tuple{Type{Base.Fix1}, typeof(with_logabsdet_jacobian), typeof(exp)}, Tuple{Zygote.ZBack{ChainRules.var"#fieldtype_pullback#420"}, Zygote.Pullback{Tuple{typeof(Base._stable_typeof), typeof(exp)}, Tuple{Zygote.ZBack{ChainRules.var"#typeof_pullback#45"}}}, Zygote.var"#2176#back#309"{Zygote.Jnew{Base.Fix1{typeof(with_logabsdet_jacobian), typeof(exp)}, Nothing, false}}, Zygote.Pullback{Tuple{typeof(convert), Type{typeof(with_logabsdet_jacobian)}, typeof(with_logabsdet_jacobian)}, Tuple{}}, Zygote.Pullback{Tuple{typeof(convert), Type{typeof(exp)}, typeof(exp)}, Tuple{}}, Zygote.ZBack{ChainRules.var"#fieldtype_pullback#420"}}}, Zygote.var"#4118#back#1381"{Zygote.Pullback{Tuple{typeof(Base.Broadcast.broadcasted), Base.Fix1{typeof(with_logabsdet_jacobian), typeof(exp)}, Vector{Float64}}, Tuple{Zygote.var"#2138#back#289"{Zygote.var"#287#288"{Tuple{Tuple{Nothing, Nothing, Nothing}, Tuple{}}, Zygote.var"#4086#back#1368"{Zygote.var"#∇broadcasted#1364"{Tuple{Vector{Float64}}, Vector{Tuple{Tuple{Float64, Float64}, Zygote.var"#2379#back#440"{Zygote.Pullback{Tuple{Zygote.var"#fallback_Fix1#439"{typeof(exp), typeof(with_logabsdet_jacobian)}, Float64}, Tuple{Zygote.var"#2149#back#299"{Zygote.var"#back#298"{:x, Zygote.Context{false}, Zygote.var"#fallback_Fix1#439"{typeof(exp), typeof(with_logabsdet_jacobian)}, typeof(exp)}}, Zygote.var"#2149#back#299"{Zygote.var"#back#298"{:f, Zygote.Context{false}, Zygote.var"#fallback_Fix1#439"{typeof(exp), typeof(with_logabsdet_jacobian)}, typeof(with_logabsdet_jacobian)}}, Zygote.Pullback{Tuple{typeof(with_logabsdet_jacobian), typeof(exp), Float64}, Tuple{Zygote.var"#1982#back#200"{typeof(identity)}, Zygote.ZBack{ChainRules.var"#exp_pullback#1319"{Float64, 
ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}}}}}}}}}, Val{2}}}}}, Zygote.var"#2841#back#683"{Zygote.var"#map_back#677"{typeof(Base.Broadcast.broadcastable), 1, Tuple{Tuple{}}, Tuple{Val{0}}, Tuple{}}}, Zygote.var"#1982#back#200"{typeof(identity)}, Zygote.var"#2138#back#289"{Zygote.var"#287#288"{Tuple{Tuple{Nothing}, Tuple{}}, Zygote.var"#combine_styles_pullback#1169"{Tuple{Nothing, Nothing}}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.broadcastable), Vector{Float64}}, Tuple{}}, Zygote.var"#1982#back#200"{typeof(identity)}}}}, Zygote.ZBack{ChangesOfVariables.ChangesOfVariablesChainRulesCoreExt.WithLadjOnMappedPullback{Tuple{Float64, Float64}}}, Zygote.var"#2149#back#299"{Zygote.var"#back#298"{:f, Zygote.Context{false}, Base.Fix1{typeof(broadcast), typeof(exp)}, typeof(broadcast)}}}})(Δ::Tuple{Nothing, Float64})
    @ Zygote ./compiler/interface2.jl:0
...
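
Frame [4] of the trace shows the pullback calling `map` on a `ZeroTangent`, which has no `length`. One way such a case could be guarded is by dispatching on `ChainRulesCore.AbstractZero` before mapping. The following is only a minimal sketch under that assumption: `map_tangent` is a hypothetical helper made up for illustration, not a function in ChangesOfVariables or ChainRulesCore, and it is not necessarily how the rrule should be fixed.

```julia
using ChainRulesCore

# Hypothetical helper (not part of any package): apply `f` elementwise to a
# cotangent, but pass structural zeros (ZeroTangent, NoTangent) through
# unchanged instead of trying to iterate over them.
map_tangent(f, Δ::AbstractZero) = Δ

# Fallback for ordinary cotangents; `unthunk` is a no-op unless Δ is a thunk.
map_tangent(f, Δ) = map(f, unthunk(Δ))

zero_result  = map_tangent(x -> 2x, ZeroTangent())  # stays a ZeroTangent
array_result = map_tangent(x -> 2x, [1.0, 2.0])     # ordinary elementwise map
```

With a guard like this, a zero cotangent for the mapped outputs never reaches `map`, so the `length(::ZeroTangent)` MethodError above would not be triggered at that call.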
