Custom rule in SparseMatrixCSC #2013

Closed
@hochunlin


I am writing reverse-mode custom rules for sparse-matrix operations with Enzyme, but the gradients I get do not match the finite-difference result. The following is the MWE:

using Enzyme
import Enzyme.EnzymeCore
using Random
using Test
using SparseArrays

function mul_internal!(S, A, B)
    S[:] = A*B
end

function mul_custom!(S, A, B)
    S[:] = A*B
end

function EnzymeRules.augmented_primal(
    config,
    func::EnzymeRules.Const{typeof(mul_custom!)},
    ::Type{RT},
    S::EnzymeCore.Annotation{<:AbstractArray{T,N}},
    A::EnzymeCore.Annotation{<:AbstractArray{T,N}},
    B::EnzymeCore.Const{<:AbstractArray{T,N}},
    ) where {RT,T,N}
    println("In custom augmented primal rule.")

    if typeof(S) <: EnzymeCore.Duplicated || typeof(S) <: EnzymeCore.BatchDuplicated
        func.val(S.val, A.val, B.val)
    end

    primal = if EnzymeRules.needs_primal(config)
        S.val
    else
        nothing
    end
    shadow = if EnzymeRules.needs_shadow(config)
        S.dval
    else
        nothing
    end
    return EnzymeRules.AugmentedReturn(primal, shadow, nothing)
end

function EnzymeRules.reverse(
    config,
    func::EnzymeRules.Const{typeof(mul_custom!)},
    ::Type{RT},
    cache,
    S::EnzymeCore.Annotation{<:AbstractArray{T,N}},
    A::EnzymeCore.Annotation{<:AbstractArray{T,N}},
    B::EnzymeCore.Const{<:AbstractArray{T,N}},
    ) where {RT,T,N}
    println("In custom reverse rule.")

    dys = S.dval
    dxs = A.dval
    
    if EnzymeRules.width(config) == 1
        dys = (dys,)
        dxs = (dxs,)
    end

    for (dy, dx) in zip(dys, dxs)
        if !(typeof(S) <: EnzymeCore.Const) && dy !== S.val
            if !(typeof(A) <: EnzymeCore.Const) && dx !== A.val
                dx .+=  dy * (B.val)'
            end
            dy .= 0
        end
    end
    return (nothing, nothing, nothing)
end


@testset "Test dense matrix S = A*B: Computing S[2,2] pullback from tuning A[2,1]" begin
    Random.seed!(1234) # Make the result reproducible

    tuning_ind_i = 2
    tuning_ind_j = 1
    resulting_ind_i = 2
    resulting_ind_j = 2

    # A, B, and S are dense matrices
    A_internal = rand(3,3)
    dA_internal = make_zero(A_internal)
    B_internal = rand(3,3)
    S_internal = zeros(3,3)
    dS_internal = make_zero(S_internal)
    
    dS_internal[resulting_ind_i,resulting_ind_j] = 1 # Set the pullback at S[2,2]

    A_custom = deepcopy(A_internal)
    dA_custom = deepcopy(dA_internal)
    B_custom = deepcopy(B_internal)
    S_custom = deepcopy(S_internal)
    dS_custom = deepcopy(dS_internal)

    ϵ = 1e-5
    r_matrix = zeros(3,3)
    r_matrix[tuning_ind_i,tuning_ind_j] = ϵ/2

    finite_difference = ((A_internal + r_matrix)*B_internal - (A_internal - r_matrix)*B_internal)/ϵ

    # Case 1: internal Enzyme rule for mul in dense matrix (it works)
    autodiff(
        Reverse, 
        mul_internal!, 
        Const,
        Duplicated(S_internal, dS_internal),
        Duplicated(A_internal, dA_internal), 
        Const(B_internal), 
    ) 
    dA_internal[1,1] # 0.0
    dA_internal[2,1] # 0.08344008943212289
    dA_internal[3,1] # 0.0
    dA_internal[1,2] # 0.0
    dA_internal[2,2] # 0.525795663891226
    dA_internal[3,2] # 0.0
    dA_internal[1,3] # 0.0
    dA_internal[2,3] # 0.8406409194782338
    dA_internal[3,3] # 0.0

    # 0.08344008943212289 ≈ 0.0834400894378362
    @test dA_internal[tuning_ind_i,tuning_ind_j] ≈ finite_difference[resulting_ind_i,resulting_ind_j] rtol = 1e-2

    # Case 2: custom rule for mul in dense matrix (it works)
    autodiff(
        Reverse, 
        mul_custom!, 
        Const,
        Duplicated(S_custom, dS_custom),
        Duplicated(A_custom, dA_custom), 
        Const(B_custom), 
    ) 
    dA_custom[1,1] # 0.0
    dA_custom[2,1] # 0.08344008943212289
    dA_custom[3,1] # 0.0
    dA_custom[1,2] # 0.0
    dA_custom[2,2] # 0.525795663891226
    dA_custom[3,2] # 0.0
    dA_custom[1,3] # 0.0
    dA_custom[2,3] # 0.8406409194782338
    dA_custom[3,3] # 0.0

    # 0.08344008943212289 ≈ 0.0834400894378362
    @test dA_custom[tuning_ind_i,tuning_ind_j] ≈ finite_difference[resulting_ind_i,resulting_ind_j] rtol = 1e-2
end

@testset "Test sparse matrix S = A*B: Computing S[2,2] pullback from tuning A[2,1]" begin
    Random.seed!(1234) # Make the result reproducible

    tuning_ind_i = 2
    tuning_ind_j = 1
    resulting_ind_i = 2
    resulting_ind_j = 2

    # A, B, and S are sparse matrices
    A_internal = sprand(3,3,1.0)
    dA_internal = make_zero(A_internal)
    B_internal = sprand(3,3,1.0)
    S_internal = sparse(ones(Float64,3,3))
    dS_internal = make_zero(S_internal)

    dS_internal[resulting_ind_i,resulting_ind_j] = 1 # Set the pullback at S[2,2]

    A_custom = deepcopy(A_internal)
    dA_custom = deepcopy(dA_internal)
    B_custom = deepcopy(B_internal)
    S_custom = deepcopy(S_internal)
    dS_custom = deepcopy(dS_internal)

    ϵ = 1e-5
    r_matrix = zeros(3,3)
    r_matrix[tuning_ind_i,tuning_ind_j] = ϵ/2

    finite_difference = ((A_internal + r_matrix)*B_internal - (A_internal - r_matrix)*B_internal)/ϵ

    # Case 3: internal Enzyme rule for mul in sparse matrix (it works)
    autodiff(
        Reverse, 
        mul_internal!, 
        Const,
        Duplicated(S_internal, dS_internal),
        Duplicated(A_internal, dA_internal), 
        Const(B_internal), 
    ) 
    dA_internal[1,1] # 0.0
    dA_internal[2,1] # 0.08344008943212289
    dA_internal[3,1] # 0.0
    dA_internal[1,2] # 0.0
    dA_internal[2,2] # 0.525795663891226
    dA_internal[3,2] # 0.0
    dA_internal[1,3] # 0.0
    dA_internal[2,3] # 0.8406409194782338
    dA_internal[3,3] # 0.0

    # 0.08344008943212289 ≈ 0.0834400894378362
    @test dA_internal[tuning_ind_i,tuning_ind_j] ≈ finite_difference[resulting_ind_i,resulting_ind_j] rtol = 1e-2

    # Case 4: custom rule for mul in sparse matrix (it does not work)
    autodiff(
        Reverse, 
        mul_custom!, 
        Const,
        Duplicated(S_custom, dS_custom),
        Duplicated(A_custom, dA_custom), 
        Const(B_custom), 
    ) 
    dA_custom[1,1] # 0.08344008943212289 (which should be dA[2,1])
    dA_custom[2,1] # 0.525795663891226 (which should be dA[2,2])
    dA_custom[3,1] # 0.8406409194782338 (which should be dA[2,3])
    @test_throws BoundsError dA_custom[1,2] # ERROR: BoundsError
    @test_throws BoundsError dA_custom[2,2] # ERROR: BoundsError
    @test_throws BoundsError dA_custom[3,2] # ERROR: BoundsError
    @test_throws BoundsError dA_custom[1,3] # ERROR: BoundsError
    @test_throws BoundsError dA_custom[2,3] # ERROR: BoundsError
    @test_throws BoundsError dA_custom[3,3] # ERROR: BoundsError

    # 0.525795663891226 ≉ 0.0834400894378362 (mismatch, hence @test_broken)
    @test_broken dA_custom[tuning_ind_i,tuning_ind_j] ≈ finite_difference[resulting_ind_i,resulting_ind_j] rtol = 1e-2
end

In this example, I differentiate a matrix multiplication S = A*B with respect to A, where B is a constant matrix.
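For reference, the reverse-mode rule that all four cases are expected to reproduce is dA += dS * B'. The following is a minimal sketch that checks this formula against a central finite difference without involving Enzyme at all; the helper `pullback_A` and the matrix values are made up for illustration and are not taken from the MWE.

```julia
using LinearAlgebra
using SparseArrays

# Hypothetical helper (not part of Enzyme): pullback of S = A*B with respect to A.
pullback_A(dS, B) = dS * B'

# Illustrative values, not the random matrices from the MWE.
A = [0.5 0.1 0.2; 0.3 0.7 0.4; 0.6 0.8 0.9]
B = [1.0 2.0 3.0; 4.0 5.0 6.0; 7.0 8.0 9.0]

dS = zeros(3, 3)
dS[2, 2] = 1.0                      # seed the adjoint at S[2,2]

dA_dense  = pullback_A(dS, B)                   # dense inputs
dA_sparse = pullback_A(sparse(dS), sparse(B))   # same formula on sparse inputs

# Central finite difference of S[2,2] with respect to A[2,1].
ϵ = 1e-5
E = zeros(3, 3)
E[2, 1] = ϵ / 2
fd = ((A + E) * B - (A - E) * B)[2, 2] / ϵ

println(dA_dense[2, :])   # row 2 of dA equals column 2 of B
println(fd)               # should be close to dA_dense[2, 1]
```

With the adjoint seeded at S[2,2], only row 2 of dA is nonzero, and dA[2,k] = B[k,2]. This is the pattern seen in cases 1 to 3.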

I tested four cases:

Case 1. Dense matrix multiplication with the internal Enzyme rule
Case 2. Dense matrix multiplication with my custom Enzyme rule
Case 3. Sparse matrix multiplication with the internal Enzyme rule
Case 4. Sparse matrix multiplication with my custom Enzyme rule

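Case 4 does not give the right result. The custom rule for sparse matrix multiplication appears to fail because the pullback values end up in the wrong entries: for example, the value that belongs in dA[2,1] shows up in dA[1,1], and indexing any entry outside the first column throws a BoundsError.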
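One way to narrow this down is to look at the raw CSC storage of the shadow after the rule has run. The sketch below is only a diagnostic aid: `csc_is_consistent` is a hypothetical helper, and the truncated matrix is an artificial example of an inconsistent `SparseMatrixCSC`, not a confirmed reproduction of what happens to `dA_custom`.

```julia
using SparseArrays

# Hypothetical helper: check basic invariants of the CSC storage
# (colptr length and ordering, buffer lengths, row indices in range).
function csc_is_consistent(M::SparseMatrixCSC)
    colptr = SparseArrays.getcolptr(M)
    rowval = rowvals(M)
    nzval  = nonzeros(M)
    m, n   = size(M)
    length(colptr) == n + 1 || return false
    colptr[1] == 1 || return false
    issorted(colptr) || return false
    nstored = colptr[end] - 1
    (length(rowval) >= nstored && length(nzval) >= nstored) || return false
    return all(r -> 1 <= r <= m, view(rowval, 1:nstored))
end

good = sparse(ones(3, 3))            # 9 stored entries, consistent

# Artificial breakage: shrink the value and row-index buffers to 3 entries
# while colptr still describes 9 stored entries.
bad = sparse(ones(3, 3))
resize!(nonzeros(bad), 3)
resize!(rowvals(bad), 3)

println(csc_is_consistent(good))     # true
println(csc_is_consistent(bad))      # false
```

Calling the helper on `dA_custom` after the `autodiff` call in case 4, and printing `SparseArrays.getcolptr(dA_custom)`, `rowvals(dA_custom)`, and `nonzeros(dA_custom)`, would show whether the shadow's buffers still agree with each other.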

I am confused about why case 3 (internal Enzyme rule) works but case 4 (custom Enzyme rule) does not. Is there a subtlety in implementing a custom rule for sparse matrices that I am missing? Thanks in advance for any suggestions or insight!
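If the cause turns out to be that the in-place broadcast `dx .+= dy * (B.val)'` changes the storage of the sparse shadow (an unconfirmed assumption), one alternative is to accumulate directly into the stored values and leave `colptr` and `rowval` untouched. The sketch below shows that idea outside of Enzyme; `accumulate_on_pattern!` is a hypothetical helper and the inputs are illustrative.

```julia
using SparseArrays

# Hypothetical helper: add (dS * B') into dA, visiting only the entries that
# dA already stores, so the sparsity structure of dA is never modified.
function accumulate_on_pattern!(dA::SparseMatrixCSC, dS::AbstractMatrix, B::AbstractMatrix)
    rows = rowvals(dA)
    vals = nonzeros(dA)
    for j in 1:size(dA, 2)
        for p in nzrange(dA, j)
            i = rows[p]
            # (dS * B')[i, j] = sum over k of dS[i, k] * B[j, k]
            acc = zero(eltype(vals))
            for k in 1:size(dS, 2)
                acc += dS[i, k] * B[j, k]
            end
            vals[p] += acc
        end
    end
    return dA
end

B  = sparse([1.0 2.0 3.0; 4.0 5.0 6.0; 7.0 8.0 9.0])
dS = sparse(ones(3, 3)); fill!(nonzeros(dS), 0.0); dS[2, 2] = 1.0
dA = sparse(ones(3, 3)); fill!(nonzeros(dA), 0.0)   # full pattern, zero values

colptr_before = copy(SparseArrays.getcolptr(dA))
accumulate_on_pattern!(dA, dS, B)

println(Matrix(dA))
```

Contributions that fall outside the stored pattern of `dA` are skipped, which is appropriate only if the gradient is wanted with respect to the stored entries of A. The scalar indexing into `dS` and `B` is kept for clarity and is not tuned for performance.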
