This is more a question than a bug report.
I defined the following rrule for Fill with _map (Zygote does not allow playing with map :) ):
using ChainRulesCore, FillArrays

_map(f, args...) = map(f, args...)

function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(_map), f::Tf, x::F) where {Tf,F<:Fill}
    # Differentiate f once on the single stored value instead of on every element.
    y_el, back = ChainRulesCore.rrule_via_ad(config, f, x.value)
    function _map_Fill_rrule(Δ)
        Δf, Δx_el = back(Δ.value)
        # The axes field carries no derivative information.
        return NoTangent(), Δf, Tangent{F}(value = Δx_el, axes = NoTangent())
    end
    return Fill(y_el, axes(x)), _map_Fill_rrule
end
The result seems correct, but I cannot call test_rrule on it:
test_rrule(_map, sum, Fill(randn(3, 4), 4))
The error narrows down to the jacobian function from FiniteDifferences trying to differentiate through the axes field of Fill.
I tried to pass a Tangent to Fill via ⊢ Tangent{typeof(x)}(value=randn(3, 4), axes=NoTangent()), but without success...
Could you help me figure out what I need to do?