Description
I am writing reverse-mode custom rules for operations on sparse matrices with Enzyme. However, the gradient I get from my custom rule does not match the finite-difference result. The following is a minimal working example (MWE):
using Enzyme
import Enzyme.EnzymeCore
using Random
using Test
using SparseArrays
"""
    mul_internal!(S, A, B)

Overwrite every element of `S` with the product `A * B`.

The value of the function is the freshly computed product (not `S` itself).
This variant has no custom rule attached, so Enzyme differentiates it directly.
"""
function mul_internal!(S, A, B)
    product = A * B
    setindex!(S, product, :)
    return product
end
"""
    mul_custom!(S, A, B)

Overwrite every element of `S` with the product `A * B`.

The value of the function is the freshly computed product (not `S` itself).
This is the variant that the custom `EnzymeRules` methods in this file target.
"""
function mul_custom!(S, A, B)
    product = A * B
    setindex!(S, product, :)
    return product
end
# Forward (augmented primal) pass of the custom rule for `mul_custom!(S, A, B)`
# with a constant `B`.
#
# Arguments
# - `config`: rule configuration, queried with `needs_primal` / `needs_shadow`.
# - `func`:   the `Const`-wrapped `mul_custom!`.
# - `RT`:     the return-activity type requested by the caller (not inspected).
# - `S`, `A`: annotated output and left factor; `B` is the constant right factor.
#
# Returns an `AugmentedReturn` holding `S.val` when the primal is requested,
# `S.dval` when the shadow is requested, and no tape (`nothing`).
function EnzymeRules.augmented_primal(
    config,
    func::EnzymeRules.Const{typeof(mul_custom!)},
    ::Type{RT},
    S::EnzymeCore.Annotation{<:AbstractArray{T,N}},
    A::EnzymeCore.Annotation{<:AbstractArray{T,N}},
    B::EnzymeCore.Const{<:AbstractArray{T,N}},
) where {RT,T,N}
    println("In custom augmented primal rule.")
    # This rule stands in for the original call, so `S` has to be overwritten
    # with `A * B` whatever its activity annotation is. Previously the call was
    # skipped unless `S` was `Duplicated`/`BatchDuplicated`, which left `S`
    # stale for any other annotation (e.g. a `Const` output with an active `A`).
    func.val(S.val, A.val, B.val)
    primal = EnzymeRules.needs_primal(config) ? S.val : nothing
    shadow = EnzymeRules.needs_shadow(config) ? S.dval : nothing
    return EnzymeRules.AugmentedReturn(primal, shadow, nothing)
end
# Reverse pass of the custom rule for `mul_custom!(S, A, B)` with a constant `B`.
#
# For every shadow pair `(dy, dx)` of `(S, A)` it accumulates
# `dx .+= dy * B'` and then resets `dy` to zero. With a width-1 configuration
# the single shadows are wrapped in 1-tuples so one loop serves both the plain
# and the batched case.
#
# Arguments
# - `config`: rule configuration, queried with `width`.
# - `func`, `RT`, `cache`: required by the rule interface; not used here.
# - `S`, `A`: annotated output and left factor; `B` is the constant right factor.
#
# Returns `(nothing, nothing, nothing)`, one entry per argument.
function EnzymeRules.reverse(
    config,
    func::EnzymeRules.Const{typeof(mul_custom!)},
    ::Type{RT},
    cache,
    S::EnzymeCore.Annotation{<:AbstractArray{T,N}},
    A::EnzymeCore.Annotation{<:AbstractArray{T,N}},
    B::EnzymeCore.Const{<:AbstractArray{T,N}},
) where {RT,T,N}
    println("In custom reverse rule.")
    # A `Const` annotation has no `dval` field. The shadows used to be read
    # before the `Const` checks, so a `Const` `S` or `A` threw on field access.
    # With a `Const` output there is nothing to propagate or to reset.
    S isa EnzymeCore.Const && return (nothing, nothing, nothing)

    single = EnzymeRules.width(config) == 1
    dys = single ? (S.dval,) : S.dval
    # `nothing` placeholders keep the `zip` below aligned when `A` is constant.
    dxs = if A isa EnzymeCore.Const
        ntuple(_ -> nothing, length(dys))
    elseif single
        (A.dval,)
    else
        A.dval
    end

    for (dy, dx) in zip(dys, dxs)
        # Skip a shadow that is the very same object as the primal output.
        dy === S.val && continue
        # NOTE(review): for sparse shadows the two broadcasts below appear to
        # re-pack the stored entries of `dx`/`dy` instead of updating them in
        # place (Case 4 of the sparse test ends up with a gradient whose
        # columns 2 and 3 can no longer be indexed). Updating only the stored
        # values, e.g. through `nonzeros(dx)` with `rowvals`/`nzrange`, may be
        # required for sparse arguments — TODO confirm.
        if dx !== nothing && dx !== A.val
            dx .+= dy * (B.val)'
        end
        dy .= 0
    end
    return (nothing, nothing, nothing)
end
# Dense check: differentiate S = A*B (B constant) in reverse mode, seeding the
# pullback at S[2,2], and compare dA[2,1] with a central finite difference.
# Case 1 uses Enzyme's own differentiation, Case 2 the custom rule.
@testset "Test dense matrix S = A*B: Computing S[2,2] pullback from tuning A[2,1]" begin
    Random.seed!(1234) # Make the result reproducible
    tuning_ind_i = 2
    tuning_ind_j = 1
    resulting_ind_i = 2
    resulting_ind_j = 2
    # A, B, and S are dense matrices
    A_internal = rand(3,3)
    dA_internal = make_zero(A_internal)
    B_internal = rand(3,3)
    S_internal = zeros(3,3)
    dS_internal = make_zero(S_internal)
    dS_internal[resulting_ind_i,resulting_ind_j] = 1 # Set the pullback at S[2,2]
    # Independent copies so that Case 1 and Case 2 cannot influence each other.
    A_custom = deepcopy(A_internal)
    dA_custom = deepcopy(dA_internal)
    B_custom = deepcopy(B_internal)
    S_custom = deepcopy(S_internal)
    dS_custom = deepcopy(dS_internal)
    # Central difference of A*B with respect to A[2,1], step ϵ.
    ϵ = 1e-5
    r_matrix = zeros(3,3)
    r_matrix[tuning_ind_i,tuning_ind_j] = ϵ/2
    finite_difference = ((A_internal + r_matrix)*B_internal - (A_internal - r_matrix)*B_internal)/ϵ
    # Case 1: internal Enzyme rule for mul in dense matrix (it works)
    autodiff(
        Reverse,
        mul_internal!,
        Const,
        Duplicated(S_internal, dS_internal),
        Duplicated(A_internal, dA_internal),
        Const(B_internal),
    )
    dA_internal[1,1] # 0.0
    dA_internal[2,1] # 0.08344008943212289
    dA_internal[3,1] # 0.0
    dA_internal[1,2] # 0.0
    dA_internal[2,2] # 0.525795663891226
    dA_internal[3,2] # 0.0
    dA_internal[1,3] # 0.0
    dA_internal[2,3] # 0.8406409194782338
    dA_internal[3,3] # 0.0
    # 0.08344008943212289 ≈ 0.0834400894378362
    @test dA_internal[tuning_ind_i,tuning_ind_j] ≈ finite_difference[resulting_ind_i,resulting_ind_j] rtol = 1e-2
    # Case 2: custom rule for mul in dense matrix (it works)
    autodiff(
        Reverse,
        mul_custom!,
        Const,
        Duplicated(S_custom, dS_custom),
        Duplicated(A_custom, dA_custom),
        Const(B_custom),
    )
    # Inspect the gradient produced by the custom rule. These lines used to
    # read `dA_internal`, i.e. they re-displayed the Case 1 result.
    dA_custom[1,1] # 0.0
    dA_custom[2,1] # 0.08344008943212289
    dA_custom[3,1] # 0.0
    dA_custom[1,2] # 0.0
    dA_custom[2,2] # 0.525795663891226
    dA_custom[3,2] # 0.0
    dA_custom[1,3] # 0.0
    dA_custom[2,3] # 0.8406409194782338
    dA_custom[3,3] # 0.0
    # 0.08344008943212289 ≈ 0.0834400894378362
    @test dA_custom[tuning_ind_i,tuning_ind_j] ≈ finite_difference[resulting_ind_i,resulting_ind_j] rtol = 1e-2
end
# Sparse check: same experiment as the dense test set, but A, B and S are
# sparse matrices. Case 3 uses Enzyme's own differentiation, Case 4 the custom
# rule; Case 4 records the reported failure via `@test_throws`/`@test_broken`.
@testset "Test sparse matrix S = A*B: Computing S[2,2] pullback from tuning A[2,1]" begin
Random.seed!(1234) # Make the result reproducible
tuning_ind_i = 2
tuning_ind_j = 1
resulting_ind_i = 2
resulting_ind_j = 2
# A, B, and S are sparse matrices
A_internal =sprand(3,3,1.0)
dA_internal = make_zero(A_internal)
B_internal =sprand(3,3,1.0)
S_internal = sparse(ones(Float64,3,3))
dS_internal = make_zero(S_internal)
dS_internal[resulting_ind_i,resulting_ind_j] = 1 # Set the pullback at S[2,2]
# Independent copies so that Case 3 and Case 4 start from the same inputs.
A_custom = deepcopy(A_internal)
dA_custom = deepcopy(dA_internal)
B_custom = deepcopy(B_internal)
S_custom = deepcopy(S_internal)
dS_custom = deepcopy(dS_internal)
# Central difference of A*B with respect to A[2,1], step ϵ.
ϵ = 1e-5
r_matrix = zeros(3,3)
r_matrix[tuning_ind_i,tuning_ind_j] = ϵ/2
finite_difference = ((A_internal + r_matrix)*B_internal - (A_internal - r_matrix)*B_internal)/ϵ
# Case 3: internal Enzyme rule for mul in sparse matrix (it works)
autodiff(
Reverse,
mul_internal!,
Const,
Duplicated(S_internal, dS_internal),
Duplicated(A_internal, dA_internal),
Const(B_internal),
)
# NOTE(review): the values noted below are identical to the ones listed in the
# dense test set — confirm they are what the sparse run actually prints.
dA_internal[1,1] # 0.0
dA_internal[2,1] # 0.08344008943212289
dA_internal[3,1] # 0.0
dA_internal[1,2] # 0.0
dA_internal[2,2] # 0.525795663891226
dA_internal[3,2] # 0.0
dA_internal[1,3] # 0.0
dA_internal[2,3] # 0.8406409194782338
dA_internal[3,3] # 0.0
# 0.08344008943212289 ≈ 0.0834400894378362
@test dA_internal[tuning_ind_i,tuning_ind_j] ≈ finite_difference[resulting_ind_i,resulting_ind_j] rtol = 1e-2
# Case 4: custom rule for mul in sparse matrix (it does not work)
autodiff(
Reverse,
mul_custom!,
Const,
Duplicated(S_custom, dS_custom),
Duplicated(A_custom, dA_custom),
Const(B_custom),
)
# Observed symptom: the three values expected in row 2 show up in column 1,
# and every entry of columns 2 and 3 can no longer be indexed.
dA_custom[1,1] # 0.08344008943212289 (which should be dA[2,1])
dA_custom[2,1] # 0.525795663891226 (which should be dA[2,2])
dA_custom[3,1] # 0.8406409194782338 (which should be dA[2,3])
@test_throws BoundsError dA_custom[1,2] # ERROR: BoundsError
@test_throws BoundsError dA_custom[2,2] # ERROR: BoundsError
@test_throws BoundsError dA_custom[3,2] # ERROR: BoundsError
@test_throws BoundsError dA_custom[1,3] # ERROR: BoundsError
@test_throws BoundsError dA_custom[2,3] # ERROR: BoundsError
@test_throws BoundsError dA_custom[3,3] # ERROR: BoundsError
# 0.525795663891226 ≈ 0.0834400894378362
# Marked broken: the comparison with the finite difference currently fails.
@test_broken dA_custom[tuning_ind_i,tuning_ind_j] ≈ finite_difference[resulting_ind_i,resulting_ind_j] rtol = 1e-2
end
In this example, I differentiate a matrix multiplication S = A*B in reverse mode, where B is a constant matrix, and compare the gradient with respect to A against a finite-difference estimate.
I test four cases:
Case 1. Dense matrix multiplication with internal enzyme rule
Case 2. Dense matrix multiplication with my custom enzyme rule
Case 3. Sparse matrix multiplication with internal enzyme rule
Case 4. Sparse matrix multiplication with my custom enzyme rule
Case 4 does not give the right result. With the custom rule, the pullback for the sparse matrix ends up in the wrong entries: for example, the value expected at dA[2,1] appears at dA[1,1], and reading any entry in columns 2 or 3 of dA throws a BoundsError.
I do not understand why case 3 (internal Enzyme rule) works while case 4 (custom Enzyme rule) does not. Is there a subtlety in implementing a custom rule for sparse matrices that I am missing? Thanks in advance for any suggestion or insight!