
test_rrule trying to find Tangent on axes #268

Closed
@theogf


This is more a question than a bug report.

I defined the following rrule for mapping over a Fill, wrapped in a helper _map (Zygote does not let you play with map directly :) ):

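using ChainRulesCore, FillArrays

_map(f, args...) = map(f, args...)
function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(_map), f::Tf, x::F) where {Tf,F<:Fill}
    # Every element of a Fill is x.value, so f only needs to be run (and its pullback built) once.
    y_el, back = ChainRulesCore.rrule_via_ad(config, f, x.value)
    function _map_Fill_rrule(Δ)
        Δf, Δx_el = back(Δ.value)
        # Only the value field gets a tangent; the axes are not differentiable.
        return NoTangent(), Δf, Tangent{F}(value = Δx_el, axes = NoTangent())
    end
    return Fill(y_el, axes(x)), _map_Fill_rrule
end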

The result seems correct, but test_rrule fails on it:

using ChainRulesTestUtils
test_rrule(_map, sum, Fill(randn(3, 4), 4))

The error traces back to the jacobian function from FiniteDifferences, which tries to differentiate through the axes field of the Fill.
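To make the distinction concrete, here is a minimal base-Julia sketch, assuming a hypothetical MyFill stand-in (not FillArrays.Fill) and a hypothetical helper fd_value_grad. It only perturbs the entries of the value field and reports no tangent for the integer-valued axes, which is the behaviour the rrule above encodes with axes = NoTangent(). This is an illustration of the idea, not how FiniteDifferences.jacobian works internally.

```julia
# Hypothetical stand-in for FillArrays.Fill: one stored value plus its axes.
struct MyFill{T}
    value::T
    axes::Tuple{Base.OneTo{Int}}
end

# Hypothetical stand-in for map(f, ::Fill): apply f once and keep the axes.
mymap(f, x::MyFill) = MyFill(f(x.value), x.axes)

# Central finite differences of a scalar function g with respect to x.value only.
# The axes field holds integer ranges, so there is nothing to perturb; its
# "tangent" is reported as `nothing` (playing the role of NoTangent()).
function fd_value_grad(g, x::MyFill; h = 1e-6)
    v = x.value
    grad = similar(v, Float64)
    for i in eachindex(v)
        e = zeros(size(v))
        e[i] = h
        grad[i] = (g(MyFill(v .+ e, x.axes)) - g(MyFill(v .- e, x.axes))) / (2h)
    end
    return (value = grad, axes = nothing)
end

# Usage: the Fill-like input from the issue, 4 copies of a 3x4 matrix.
x = MyFill(randn(3, 4), (Base.OneTo(4),))
g = fd_value_grad(y -> mymap(sum, y).value, x)
# g.value is approximately ones(3, 4), since sum is linear in every entry.
```

The design choice is simply that a finite-difference check over a struct has to know which fields are differentiable; perturbing a field like axes is meaningless.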

I also tried passing an explicit tangent for the Fill input with the ⊢ operator, i.e. x ⊢ Tangent{typeof(x)}(value=randn(3, 4), axes=NoTangent()) where x is the Fill, but without success...

Could you help me figure out what I need to do?
