Custom rule in SparseMatrixCSC #2013

Closed
hochunlin opened this issue Oct 24, 2024 · 7 comments · Fixed by #2109

Comments

@hochunlin

I am writing some reverse-mode custom rules for manipulating sparse matrices with Enzyme. However, in the sparse case I do not get the correct result when compared against a finite-difference reference. Here is the MWE:

using Enzyme
import Enzyme.EnzymeCore
using Random
using Test
using SparseArrays

function mul_internal!(S, A, B)
    S[:] = A*B
end

function mul_custom!(S, A, B)
    S[:] = A*B
end

function EnzymeRules.augmented_primal(
    config,
    func::EnzymeRules.Const{typeof(mul_custom!)},
    ::Type{RT},
    S::EnzymeCore.Annotation{<:AbstractArray{T,N}},
    A::EnzymeCore.Annotation{<:AbstractArray{T,N}},
    B::EnzymeCore.Const{<:AbstractArray{T,N}},
    ) where {RT,T,N}
    println("In custom augmented primal rule.")

    if typeof(S) <: EnzymeCore.Duplicated || typeof(S) <: EnzymeCore.BatchDuplicated
        func.val(S.val, A.val, B.val)
    end

    primal = if EnzymeRules.needs_primal(config)
        S.val
    else
        nothing
    end
    shadow = if EnzymeRules.needs_shadow(config)
        S.dval
    else
        nothing
    end
    return EnzymeRules.AugmentedReturn(primal, shadow, nothing)
end

function EnzymeRules.reverse(
    config,
    func::EnzymeRules.Const{typeof(mul_custom!)},
    ::Type{RT},
    cache,
    S::EnzymeCore.Annotation{<:AbstractArray{T,N}},
    A::EnzymeCore.Annotation{<:AbstractArray{T,N}},
    B::EnzymeCore.Const{<:AbstractArray{T,N}},
    ) where {RT,T,N}
    println("In custom reverse rule.")

    dys = S.dval
    dxs = A.dval
    
    if EnzymeRules.width(config) == 1
        dys = (dys,)
        dxs = (dxs,)
    end

    for (dy, dx) in zip(dys, dxs)
        if !(typeof(S) <: EnzymeCore.Const) && dy !== S.val
            if !(typeof(A) <: EnzymeCore.Const) && dx !== A.val
                dx .+=  dy * (B.val)'
            end
            dy .= 0
        end
    end
    return (nothing, nothing, nothing)
end


@testset "Test dense matrix S = A*B: Computing S[2,2] pullback from tuning A[2,1]" begin
    Random.seed!(1234) # Make the result reproducible

    tuning_ind_i = 2
    tuning_ind_j = 1
    resulting_ind_i = 2
    resulting_ind_j = 2

    # A, B, and S are dense matrices
    A_internal = rand(3,3)
    dA_internal = make_zero(A_internal)
    B_internal = rand(3,3)
    S_internal = zeros(3,3)
    dS_internal = make_zero(S_internal)
    
    dS_internal[resulting_ind_i,resulting_ind_j] = 1 # Set the pullback at S[2,2]

    A_custom = deepcopy(A_internal)
    dA_custom = deepcopy(dA_internal)
    B_custom = deepcopy(B_internal)
    S_custom = deepcopy(S_internal)
    dS_custom = deepcopy(dS_internal)

    ϵ = 1e-5
    r_matrix = zeros(3,3)
    r_matrix[tuning_ind_i,tuning_ind_j] = ϵ/2

    finite_difference = ((A_internal + r_matrix)*B_internal - (A_internal - r_matrix)*B_internal)/ϵ

    # Case 1: internal Enzyme rule for mul in dense matrix (it works)
    autodiff(
        Reverse, 
        mul_internal!, 
        Const,
        Duplicated(S_internal, dS_internal),
        Duplicated(A_internal, dA_internal), 
        Const(B_internal), 
    ) 
    dA_internal[1,1] # 0.0
    dA_internal[2,1] # 0.08344008943212289
    dA_internal[3,1] # 0.0
    dA_internal[1,2] # 0.0
    dA_internal[2,2] # 0.525795663891226
    dA_internal[3,2] # 0.0
    dA_internal[1,3] # 0.0
    dA_internal[2,3] # 0.8406409194782338
    dA_internal[3,3] # 0.0

    # 0.08344008943212289 ≈ 0.0834400894378362
    @test dA_internal[tuning_ind_i,tuning_ind_j] ≈ finite_difference[resulting_ind_i,resulting_ind_j] rtol = 1e-2

    # Case 2: custom rule for mul in dense matrix (it works)
    autodiff(
        Reverse, 
        mul_custom!, 
        Const,
        Duplicated(S_custom, dS_custom),
        Duplicated(A_custom, dA_custom), 
        Const(B_custom), 
    ) 
    dA_custom[1,1] # 0.0
    dA_custom[2,1] # 0.08344008943212289
    dA_custom[3,1] # 0.0
    dA_custom[1,2] # 0.0
    dA_custom[2,2] # 0.525795663891226
    dA_custom[3,2] # 0.0
    dA_custom[1,3] # 0.0
    dA_custom[2,3] # 0.8406409194782338
    dA_custom[3,3] # 0.0
     
     # 0.08344008943212289  ≈ 0.0834400894378362
     @test dA_custom[tuning_ind_i,tuning_ind_j] ≈ finite_difference[resulting_ind_i,resulting_ind_j] rtol = 1e-2
end

@testset "Test sparse matrix S = A*B: Computing S[2,2] pullback from tuning A[2,1]" begin
    Random.seed!(1234) # Make the result reproducible

    tuning_ind_i = 2
    tuning_ind_j = 1
    resulting_ind_i = 2
    resulting_ind_j = 2

    # A, B, and S are sparse matrices
    A_internal = sprand(3,3,1.0)
    dA_internal = make_zero(A_internal)
    B_internal = sprand(3,3,1.0)
    S_internal = sparse(ones(Float64,3,3))
    dS_internal = make_zero(S_internal)

    dS_internal[resulting_ind_i,resulting_ind_j] = 1 # Set the pullback at S[2,2]

    A_custom = deepcopy(A_internal)
    dA_custom = deepcopy(dA_internal)
    B_custom = deepcopy(B_internal)
    S_custom = deepcopy(S_internal)
    dS_custom = deepcopy(dS_internal)

    ϵ = 1e-5
    r_matrix = zeros(3,3)
    r_matrix[tuning_ind_i,tuning_ind_j] = ϵ/2

    finite_difference = ((A_internal + r_matrix)*B_internal - (A_internal - r_matrix)*B_internal)/ϵ

    # Case 3: internal Enzyme rule for mul in sparse matrix (it works)
    autodiff(
        Reverse, 
        mul_internal!, 
        Const,
        Duplicated(S_internal, dS_internal),
        Duplicated(A_internal, dA_internal), 
        Const(B_internal), 
    ) 
    dA_internal[1,1] # 0.0
    dA_internal[2,1] # 0.08344008943212289
    dA_internal[3,1] # 0.0
    dA_internal[1,2] # 0.0
    dA_internal[2,2] # 0.525795663891226
    dA_internal[3,2] # 0.0
    dA_internal[1,3] # 0.0
    dA_internal[2,3] # 0.8406409194782338
    dA_internal[3,3] # 0.0

    # 0.08344008943212289 ≈ 0.0834400894378362
    @test dA_internal[tuning_ind_i,tuning_ind_j] ≈ finite_difference[resulting_ind_i,resulting_ind_j] rtol = 1e-2

    # Case 4: custom rule for mul in sparse matrix (it does not work)
    autodiff(
        Reverse, 
        mul_custom!, 
        Const,
        Duplicated(S_custom, dS_custom),
        Duplicated(A_custom, dA_custom), 
        Const(B_custom), 
    ) 
     dA_custom[1,1] # 0.08344008943212289 (which should be dA[2,1])
     dA_custom[2,1] # 0.525795663891226 (which should be dA[2,2])
     dA_custom[3,1] # 0.8406409194782338 (which should be dA[2,3])
     @test_throws BoundsError dA_custom[1,2] # ERROR: BoundsError
     @test_throws BoundsError dA_custom[2,2] # ERROR: BoundsError
     @test_throws BoundsError dA_custom[3,2] # ERROR: BoundsError
     @test_throws BoundsError dA_custom[1,3] # ERROR: BoundsError
     @test_throws BoundsError dA_custom[2,3] # ERROR: BoundsError
     @test_throws BoundsError dA_custom[3,3] # ERROR: BoundsError
     
     # 0.525795663891226 ≈ 0.0834400894378362
     @test_broken dA_custom[tuning_ind_i,tuning_ind_j] ≈ finite_difference[resulting_ind_i,resulting_ind_j] rtol = 1e-2
end

In this example, I compute the reverse-mode derivative of a matrix multiplication S = A*B, where B is a constant matrix.
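
(For reference, the standard reverse-mode result being implemented here, nothing Enzyme-specific: for S = A*B with B held constant, the pullback of A is dA += dS * B', after which dS is zeroed. This is what the dx .+= dy * (B.val)' and dy .= 0 lines in the custom reverse rule implement.)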

I test the following four cases:

Case 1. Dense matrix multiplication with the internal Enzyme rule
Case 2. Dense matrix multiplication with my custom Enzyme rule
Case 3. Sparse matrix multiplication with the internal Enzyme rule
Case 4. Sparse matrix multiplication with my custom Enzyme rule

Case 4 does not give the right result. The custom rule for sparse matrix multiplication fails because the gradient entries end up in the wrong positions: for example, the value that should be dA[2,1] shows up at dA[1,1], and indexing several other entries throws a BoundsError.

I am confused about why case 3 (internal Enzyme rule) works while case 4 (custom Enzyme rule) does not. Is there a subtlety in implementing a custom rule for sparse matrices that I am missing? Thanks in advance for any suggestions or insights!
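
(For reference when reading the numbers above, a minimal illustration of the standard SparseMatrixCSC layout, nothing Enzyme-specific. A SparseMatrixCSC stores its entries in three internal arrays, and getindex walks them to locate a value, so if a rule rewrites the shadow's internal arrays, indexing can return values from the wrong positions or go out of bounds:)

using SparseArrays

A = sparse([1.0 0.0; 2.0 3.0])
# Internal fields, shown for illustration only:
A.colptr  # [1, 3, 4]       -> where each column's entries start in rowval/nzval
A.rowval  # [1, 2, 2]       -> row index of each stored entry
A.nzval   # [1.0, 2.0, 3.0] -> stored values, column by column
# A[i, j] walks colptr/rowval to find the stored entry for (i, j); if these
# arrays no longer match nzval, indexing maps (i, j) to the wrong value or
# throws a BoundsError.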

@hochunlin
Author

Hi, I did some more testing with sparse matrices and ran into further issues.

In the following MWE, I differentiate a matrix inversion, y = x⁻¹, where the element x[3,2] is shifted by a parameter r (i.e. x[3,2] = x[3,2] + r):

using Enzyme
import EnzymeCore
using SparseArrays
using Test
using Random

function inv_plus_r_without_custom_rule_for_plus!(y::AbstractArray, x::AbstractArray, r::AbstractArray)
    x[3,2] = x[3,2] + r[1]
    inv!(y,x)
end

function inv_plus_r!(y::AbstractArray, x::AbstractArray, r::AbstractArray)
    x[3,2] = x[3,2] + r[1]
    inv!(y,x)
end

function inv!(y::AbstractArray, x::AbstractArray)
    y[:] = inv(Matrix(x))
end


function EnzymeRules.augmented_primal(config, func::EnzymeRules.Const{typeof(inv_plus_r!)}, ::Type{RT},
    y::EnzymeCore.Annotation{<:AbstractArray{T,N}},
    x::EnzymeCore.Annotation{<:AbstractArray{T,N}},
    r::EnzymeCore.Annotation{<:AbstractArray{T,1}}
) where {RT,T,N}
    println("In custom augmented primal rule in inv_plus_r!.")

    if typeof(y) <: EnzymeCore.Duplicated || typeof(y) <: EnzymeCore.BatchDuplicated
        func.val(y.val, x.val, r.val)
    end

    primal = if EnzymeRules.needs_primal(config)
        y.val
    else
        nothing
    end
    shadow = if EnzymeRules.needs_shadow(config)
        y.dval
    else
        nothing
    end

    return EnzymeRules.AugmentedReturn(primal, shadow, nothing)
end

function EnzymeRules.reverse(config, func::EnzymeRules.Const{typeof(inv_plus_r!)}, ::Type{RT}, cache,
    y::EnzymeCore.Annotation{<:AbstractArray{T,N}},
    x::EnzymeCore.Annotation{<:AbstractArray{T,N}},
    r::EnzymeCore.Annotation{<:AbstractArray{T,1}}
) where {RT,T,N}
    println("In custom reverse rule in inv_plus_r!.")

    dys = y.dval
    dxs = (typeof(x) <: EnzymeCore.Const) ? dys : x.dval
    if EnzymeRules.width(config) == 1
        dys = (dys,)
        dxs = (dxs,)
    end
    for (dy, dx) in zip(dys, dxs)
        if !(typeof(y) <: EnzymeCore.Const) && dy !== y.val
            if !(typeof(x) <: EnzymeCore.Const) && dx !== x.val
                dx .+=  - inv(Matrix(x.val))' * dy * inv(Matrix(x.val))'
            end
            dy .= 0
        end
    end
    r.dval[1] = dxs[1][3,2]
    return (nothing, nothing,nothing)
end

function EnzymeRules.augmented_primal(config, func::EnzymeRules.Const{typeof(inv!)}, ::Type{RT},
    y::EnzymeCore.Annotation{<:AbstractArray{T,N}},
    x::EnzymeCore.Annotation{<:AbstractArray{T,N}}
) where {RT,T,N}
    println("In custom augmented primal rule in inv!.")

    if typeof(y) <: EnzymeCore.Duplicated || typeof(y) <: EnzymeCore.BatchDuplicated
        func.val(y.val, x.val)
    end

    primal = if EnzymeRules.needs_primal(config)
        y.val
    else
        nothing
    end
    shadow = if EnzymeRules.needs_shadow(config)
        y.dval
    else
        nothing
    end
    return EnzymeRules.AugmentedReturn(primal, shadow, nothing)
end

function EnzymeRules.reverse(config, func::EnzymeRules.Const{typeof(inv!)}, ::Type{RT}, cache,
    y::EnzymeCore.Annotation{<:AbstractArray{T,N}},
    x::EnzymeCore.Annotation{<:AbstractArray{T,N}}
) where {RT,T,N}
    println("In custom reverse rule in inv!.")

    dys = y.dval
    dxs = (typeof(x) <: EnzymeCore.Const) ? dys : x.dval

    if EnzymeRules.width(config) == 1
        dys = (dys,)
        dxs = (dxs,)
    end

    for (dy, dx) in zip(dys, dxs)
        if !(typeof(y) <: EnzymeCore.Const) && dy !== y.val
            if !(typeof(x) <: EnzymeCore.Const) && dx !== x.val
                dx .+=  - inv(Matrix(x.val))' * dy * inv(Matrix(x.val))'
            end
            dy .= 0
        end
    end
    return (nothing, nothing)
end

@testset "Test: Computing y[2,2] (y = x⁻¹) pullback at x[3,2] = x[3,2] + r from x = sprand(Float64, 100, 100, 0.7) with a partial custom rule and a whole custom rule" begin
    Random.seed!(1234) # Make the result reproducible

    x_partial_rule = sprand(Float64,100,100,0.7)
    dx_partial_rule = make_zero(x_partial_rule)
    y_partial_rule = make_zero(sparse(ones(Float64, 100,100)))
    dy_partial_rule = make_zero(y_partial_rule);
    dy_partial_rule[2, 2] = 1; # seed the pullback at index [2, 2]
    r_partial_rule = [0.0]
    dr_partial_rule = [0.0]
    

    Random.seed!(1234) # Make the result reproducible

    x_whole_rule = sprand(Float64,100,100,0.7)
    dx_whole_rule = make_zero(x_whole_rule)
    y_whole_rule = make_zero(sparse(ones(Float64, 100,100)))
    dy_whole_rule = make_zero(y_whole_rule);
    dy_whole_rule[2, 2] = 1; # seed the pullback at index [2, 2]
    r_whole_rule = [0.0]
    dr_whole_rule = [0.0]

    delta = spzeros(size(x_whole_rule));
    fd_delta = 1e-5;
    delta[3, 2] = fd_delta;
    delta;
    grad_fd = (inv(Matrix(x_whole_rule .+ delta / 2)) - inv(Matrix(x_whole_rule .- delta / 2)) ) / fd_delta

    Enzyme.autodiff(
        Reverse,
        inv_plus_r_without_custom_rule_for_plus!,
        Const,
        Duplicated(y_partial_rule, dy_partial_rule),
        Duplicated(x_partial_rule, dx_partial_rule),
        Duplicated(r_partial_rule, dr_partial_rule),
    )

    Enzyme.autodiff(
        Reverse,
        inv_plus_r!,
        Const,
        Duplicated(y_whole_rule, dy_whole_rule),
        Duplicated(x_whole_rule, dx_whole_rule),
        Duplicated(r_whole_rule, dr_whole_rule),
    )

    @test dx_partial_rule[3,2] ≈ grad_fd[2, 2] rtol = 1e-3
    @test_broken dr_partial_rule[1] ≈ grad_fd[2, 2] rtol = 1e-3
    @test_broken dr_partial_rule[1] ≈ dx_partial_rule[3,2] rtol = 1e-3

    @test dx_whole_rule[3,2] ≈ grad_fd[2, 2] rtol = 1e-3
    @test dr_whole_rule[1] ≈ grad_fd[2, 2] rtol = 1e-3
    @test dr_whole_rule[1] ≈ dx_whole_rule[3,2] rtol = 1e-3
end

Here I compute the pullback in two ways:

  • with a custom rule for the inversion part only (inv!)
  • with a custom rule for both the addition and the inversion (inv_plus_r!)

Since x[3,2] depends on r with unit coefficient, the chain rule gives dr[1] = dx[3,2], so I expect the two pullbacks to agree in this simple case. However, when I only write the custom rule for the inversion part, dx[3,2] != dr.

After further testing, this issue also only occurs with sparse matrices; any thoughts or insights would be appreciated. Thanks!

@ChrisRackauckas
Contributor

CC DARPA-ASKEM/sciml-service#182 and SciML/SciMLSensitivity.jl#1139; it seems a lot of things hit this.

@wsmoses, is a custom rule the right direction here? mul! is defined directly in Julia, so it's a bit odd that it doesn't work right out of the box.

@wsmoses
Member

wsmoses commented Nov 20, 2024

cc @ptiede who was looking at the sparse rules recently

@ptiede
Contributor

ptiede commented Nov 20, 2024

OK, so the SciML issue is just a duplicate of #1970.

The reason for the custom rule, @ChrisRackauckas, was that Enzyme's native gradient was slow for sparse matmul (#1682), so I implemented one.

The problem with the rule I defined is that it explicitly does not handle the case where the sparse matrix is Duplicated. This is entirely because I lacked the cycles to implement it properly, since doing so requires something like ChainRules.ProjectTo. The naive gradient accumulation

dx .+=  dy * (B.val)'

does not preserve the correct sparsity pattern for dx. This is noted in a comment in Enzyme's internal rule.

I had initially hoped that, since the internal rule was written for the case where the sparse matrix activity is Const, Enzyme would only use a custom rule for mul! when the activities matched. Unfortunately, I was wrong: once you define a rule for a function signature with one set of activities, Enzyme expects the rule to be defined for all activities.

This is why @hochunlin's rules are incorrect: they need to preserve the matrix structure, which is subtle for sparse matrices. Specifically, the returned gradient looks scrambled because Enzyme ends up pointing at different memory, since dx .+= silently changes the underlying sparse matrix representation.
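
To make "preserve the matrix structure" concrete, here is a minimal sketch of a structure-preserving accumulation (the helper name accumulate_into_pattern! is made up for illustration; it is not Enzyme's API and not necessarily what PR #2109 implements). It adds a dense or sparse contribution into a sparse shadow while only touching entries the shadow already stores, similar in spirit to ChainRules.ProjectTo:

using SparseArrays

# Accumulate `contribution` into `dx` without changing dx's sparsity pattern:
# only the entries dx already stores are updated; structural zeros stay zero.
function accumulate_into_pattern!(dx::SparseMatrixCSC, contribution::AbstractMatrix)
    rows = rowvals(dx)
    vals = nonzeros(dx)
    for j in 1:size(dx, 2)
        for k in nzrange(dx, j)
            vals[k] += contribution[rows[k], j]
        end
    end
    return dx
end

In the matmul reverse rule above, the accumulation would then read accumulate_into_pattern!(dx, dy * (B.val)') instead of dx .+= dy * (B.val)', so the shadow's internal arrays are never replaced.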

I'll add a Duplicated SparseMatrixCSC rule, but it won't be optimal since I don't have the cycles for that right now.

@ChrisRackauckas
Contributor

I'll add a Duplicated SparseMatrixCSC rule, but it won't be optimal since I don't have the cycles for that right now.

That's fine. I think the key for us is that this is likely the last thing blocking a full end-to-end PDE application from working, so a fallback that just differentiates the code directly is actually what I was hoping someone could come up with, at least to see whether the full applications work.

does not preserve the correct sparsity pattern for dx. This is stated in the internal rule in Enzyme in a comment

Yeah, if you couldn't tell, I'm not a fan of this behavior.

@ptiede
Contributor

ptiede commented Nov 20, 2024

OK, PR #2109 should fix the SciML issue, @ChrisRackauckas. @hochunlin, does the above discussion clarify why the custom rules you wrote aren't correct? The main thing is that broadcasting does not preserve the structural zeros. On the bright side, PR #2109 means your matmul rule likely isn't needed anymore. For inv, the accumulation

dx .+=  - inv(Matrix(x.val))' * dy * inv(Matrix(x.val))'

will need to change.
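
As a sketch only (reusing the illustrative accumulate_into_pattern! helper from the earlier comment, which is an assumption for illustration and not part of Enzyme or PR #2109), one structure-preserving way to write that accumulation is

xinvt = inv(Matrix(x.val))'                      # dense inverse, transposed
accumulate_into_pattern!(dx, -(xinvt * dy * xinvt))

so the dense adjoint contribution is only written into the entries dx already stores, rather than being broadcast into the sparse shadow.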

@hochunlin
Author

@ptiede Yes, it is now clear that broadcasting does not preserve the structural zeros and that this causes the issue. How should I change the custom rule for inv to make it correct?
