Skip to content

Commit

Permalink
Add rules for det and logdet of Cholesky
Browse files Browse the repository at this point in the history
  • Loading branch information
devmotion committed May 10, 2022
1 parent c5dbe03 commit 54995d6
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 3 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRules"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "1.29.0"
version = "1.30.0"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
4 changes: 2 additions & 2 deletions src/rulesets/LinearAlgebra/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -118,15 +118,15 @@ end
##### `det`
#####

function frule((_, Δx), ::typeof(det), x::AbstractMatrix)
function frule((_, Δx), ::typeof(det), x::StridedMatrix{<:Number})
Ω = det(x)
# TODO Performance optimization: probably there is an efficent
# way to compute this trace without during the full compution within
return Ω, Ω * tr(x \ Δx)
end
frule((_, Δx), ::typeof(det), x::Number) = (det(x), Δx)

function rrule(::typeof(det), x::Union{Number, AbstractMatrix})
function rrule(::typeof(det), x::Union{Number, StridedMatrix{<:Number}})
Ω = det(x)
function det_pullback(ΔΩ)
∂x = x isa Number ? ΔΩ : inv(x)' * dot(Ω, ΔΩ)
Expand Down
21 changes: 21 additions & 0 deletions src/rulesets/LinearAlgebra/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -551,3 +551,24 @@ function rrule(::typeof(getproperty), F::T, x::Symbol) where {T <: Cholesky}
end
return getproperty(F, x), getproperty_cholesky_pullback
end

# `det` and `logdet` for `Cholesky`
function rrule(::typeof(det), C::Cholesky)
y = det(C)
s = conj!((2 * y) ./ _diag_view(C.factors))
function det_Cholesky_pullback(ȳ)
ΔC = Tangent{typeof(C)}(; factors=Diagonal(ȳ .* s))
return NoTangent(), ΔC
end
return y, det_Cholesky_pullback
end

function rrule(::typeof(logdet), C::Cholesky)
y = logdet(C)
s = conj!((2 * one(eltype(C))) ./ _diag_view(C.factors))
function logdet_Cholesky_pullback(ȳ)
ΔC = Tangent{typeof(C)}(; factors=Diagonal(ȳ .* s))
return NoTangent(), ΔC
end
return y, logdet_Cholesky_pullback
end
19 changes: 19 additions & 0 deletions test/rulesets/LinearAlgebra/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -432,5 +432,24 @@ end
ΔX_symmetric = chol_back_sym(Δ)[2]
@test sym_back(ΔX_symmetric)[2] dX_pullback(Δ)[2]
end

@testset "det and logdet (uplo=$p)" for p in ['U', 'L']
@testset "$op" for op in (det, logdet)
@testset "$T" for T in (Float64, ComplexF64)
n = 5
# rand (not randn) so det will be postive, so logdet will be defined
A = 3 * rand(T, (n, n))
X = Cholesky((p === 'U' ? UpperTriangular : LowerTriangular)(A * A' + I))
X̄_acc = Tangent{typeof(X)}(; factors=Diagonal(randn(T, n))) # sensitivity is always a diagonal
test_rrule(op, X X̄_acc)

# return type
_, op_pullback = rrule(op, X)
= op_pullback(2.7)[2]
@testisa Tangent{<:Cholesky}
@test.factors isa Diagonal
end
end
end
end
end

0 comments on commit 54995d6

Please sign in to comment.