
Introduce an isa_wrapped_array function #460

Open
@avik-pal

Description


Most of the details are in a comment on EnzymeAD/Reactant.jl#369; I will copy over the important parts here.

We introduce a function `isa_wrapped_array` that downstream packages can extend to mark that their array type wraps another array. Using a `Union` type from Adapt.jl (such as `WrappedArray`) doesn't solve this problem, because a fixed union of known wrapper types fundamentally cannot be extended to cover new array types defined in other packages.
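To make the contrast with a fixed `Union` concrete, here is a minimal sketch of how such a trait could look. The default-`false` fallback, the set of Base/LinearAlgebra wrappers that opt in, the recursive `ancestor` helper, and the `MyWrapper` type are all assumptions for illustration, not the proposed implementation:

```julia
using LinearAlgebra: Adjoint, Transpose

# Hypothetical trait: an array is not a wrapper unless its type opts in.
isa_wrapped_array(::AbstractArray) = false

# Opt in a few wrapper types from Base and LinearAlgebra.
isa_wrapped_array(::SubArray) = true
isa_wrapped_array(::Base.ReshapedArray) = true
isa_wrapped_array(::PermutedDimsArray) = true
isa_wrapped_array(::Adjoint) = true
isa_wrapped_array(::Transpose) = true

# Hypothetical helper: follow `parent` until reaching an array that is not a wrapper.
ancestor(x::AbstractArray) = isa_wrapped_array(x) ? ancestor(parent(x)) : x

# A hypothetical downstream wrapper type. Opting in takes a single method
# definition, which a fixed `Union` of known wrappers cannot accommodate.
struct MyWrapper{T,N,A<:AbstractArray{T,N}} <: AbstractArray{T,N}
    data::A
end
MyWrapper(data::AbstractArray{T,N}) where {T,N} = MyWrapper{T,N,typeof(data)}(data)
Base.size(w::MyWrapper) = size(w.data)
Base.getindex(w::MyWrapper{T,N}, I::Vararg{Int,N}) where {T,N} = w.data[I...]
Base.parent(w::MyWrapper) = w.data
isa_wrapped_array(::MyWrapper) = true
```

With this in place, `ancestor(view(transpose(A), 1:2, :))` and `ancestor(MyWrapper(A'))` both return the underlying `A`. The opt-in happens in the package that owns the type, so nothing upstream has to know about `MyWrapper`.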

With this function, we can overlay methods inside our custom interpreter so that they also catch arrays that wrap traced arrays. Consider this simple example of extending `LinearAlgebra.diag`:

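```julia
using LinearAlgebra

# REACTANT_METHOD_TABLE, TracedRArray, ancestor and materialize_traced_array are
# assumed to be provided by Reactant.jl; isa_wrapped_array is the proposed function.
Base.Experimental.@overlay REACTANT_METHOD_TABLE function LinearAlgebra.diag(
    x::AbstractArray{T,2}, k::Integer=0
) where {T}
    if isa_wrapped_array(x) && ancestor(x) isa TracedRArray
        # The wrapper hides a traced array: convert it to a known type
        # so that the TracedRArray-specific `diag` method can handle it.
        y = materialize_traced_array(x)
        return diag(y, k)
    else
        # Otherwise, invoke `diag(x, k)` on the NativeInterpreter (elided here).
    end
end
```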
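The branch in the overlay can be exercised outside Reactant with plain dispatch. The sketch below is not an overlay and does not use a custom interpreter; `TracedStub` stands in for `TracedRArray`, `materialize_stub` for `materialize_traced_array`, and `my_diag` for the overlaid `diag`. All of these names are hypothetical:

```julia
using LinearAlgebra: diag, Adjoint, Transpose

# Hypothetical stand-in for Reactant's TracedRArray.
struct TracedStub{T,N} <: AbstractArray{T,N}
    data::Array{T,N}
end
Base.size(x::TracedStub) = size(x.data)
Base.getindex(x::TracedStub{T,N}, I::Vararg{Int,N}) where {T,N} = x.data[I...]

# Minimal versions of the trait and the unwrapping helper (hypothetical).
isa_wrapped_array(::AbstractArray) = false
isa_wrapped_array(::Union{SubArray,Adjoint,Transpose}) = true
ancestor(x::AbstractArray) = isa_wrapped_array(x) ? ancestor(parent(x)) : x

# Hypothetical analogue of materialize_traced_array: copy the wrapped view
# into a fresh TracedStub, so only one concrete type needs a dedicated method.
materialize_stub(x::AbstractArray) = TracedStub(collect(x))

# Dedicated method for the known type; here it forwards to the dense data.
my_diag(x::TracedStub{T,2}, k::Integer=0) where {T} = diag(x.data, k)

# The same branch as in the overlay above.
function my_diag(x::AbstractMatrix, k::Integer=0)
    if isa_wrapped_array(x) && ancestor(x) isa TracedStub
        return my_diag(materialize_stub(x), k)
    else
        return diag(x, k)  # stands in for invoking the native `diag`
    end
end
```

For example, `my_diag(transpose(TracedStub([1.0 2.0; 3.0 4.0])))` takes the first branch and materializes the transpose before calling the `TracedStub` method, while `my_diag(M')` for an ordinary matrix `M` falls through to the regular `diag`.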
