-
Notifications
You must be signed in to change notification settings - Fork 18
mul/ewise rules for basic arithmetic semiring #26
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
20 commits
Select commit
Hold shift + click to select a range
11bdec4
arithmetic groundwork
42c8670
arithmetic rules for mul and elwise 1st pass
8f595b9
tests and a few fixes
4e52852
Merge branch 'master' into arithmeticchains
c58a4a8
Add mask function, fix eadd(PLUS)
980d7d5
correct mul rrules
4b2e00c
test folder structure
952e7a0
mask and vector transpose v1
b2289bf
Broken constructor rules
b665fa7
arithmetic groundwork
0f4509e
arithmetic rules for mul and elwise 1st pass
b4ec8c5
tests and a few fixes
7991aa5
Add mask function, fix eadd(PLUS)
bb4dc6e
correct mul rrules
fd8433b
test folder structure
965a983
Broken constructor rules
2f810da
Merge branch 'arithmeticchains' of https://github.com/JuliaSparse/Sui…
9369b60
Move out constructor rules for now
9c6f478
compat
c769833
rm constructorrule includes
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
import FiniteDifferences | ||
import LinearAlgebra | ||
import ChainRulesCore: frule, rrule | ||
using ChainRulesCore | ||
const RealOrComplex = Union{Real, Complex} | ||
|
||
#Required for ChainRulesTestUtils | ||
function FiniteDifferences.to_vec(M::GBMatrix) | ||
rayegun marked this conversation as resolved.
Show resolved
Hide resolved
|
||
I, J, X = findnz(M) | ||
function backtomat(xvec) | ||
return GBMatrix(I, J, xvec; nrows = size(M, 1), ncols = size(M, 2)) | ||
end | ||
return X, backtomat | ||
end | ||
rayegun marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
function FiniteDifferences.to_vec(v::GBVector) | ||
i, x = findnz(v) | ||
function backtovec(xvec) | ||
return GBVector(i, xvec; nrows=size(v, 1)) | ||
end | ||
return x, backtovec | ||
end | ||
|
||
function FiniteDifferences.rand_tangent( | ||
rng::AbstractRNG, | ||
x::GBMatrix{T} | ||
) where {T <: Union{AbstractFloat, Complex}} | ||
n = nnz(x) | ||
v = rand(rng, -9:0.01:9, n) | ||
I, J, _ = findnz(x) | ||
return GBMatrix(I, J, v; nrows = size(x, 1), ncols = size(x, 2)) | ||
end | ||
|
||
function FiniteDifferences.rand_tangent( | ||
rng::AbstractRNG, | ||
x::GBVector{T} | ||
) where {T <: Union{AbstractFloat, Complex}} | ||
n = nnz(x) | ||
v = rand(rng, -9:0.01:9, n) | ||
I, _ = findnz(x) | ||
return GBVector(I, v; nrows = size(x, 1)) | ||
end | ||
|
||
FiniteDifferences.rand_tangent(rng::AbstractRNG, x::AbstractOp) = NoTangent() | ||
# LinearAlgebra.norm freaks over the nothings. | ||
LinearAlgebra.norm(A::GBArray, p::Real=2) = norm(nonzeros(A), p) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
#emul TIMES | ||
function frule( | ||
(_, ΔA, ΔB, _), | ||
::typeof(emul), | ||
A::GBArray, | ||
B::GBArray, | ||
::typeof(BinaryOps.TIMES) | ||
) | ||
Ω = emul(A, B, BinaryOps.TIMES) | ||
∂Ω = emul(ΔA, B, BinaryOps.TIMES) + emul(ΔB, A, BinaryOps.TIMES) | ||
return Ω, ∂Ω | ||
end | ||
function frule((_, ΔA, ΔB), ::typeof(emul), A::GBArray, B::GBArray) | ||
return frule((nothing, ΔA, ΔB, nothing), emul, A, B, BinaryOps.TIMES) | ||
end | ||
|
||
function rrule(::typeof(emul), A::GBArray, B::GBArray, ::typeof(BinaryOps.TIMES)) | ||
function timespullback(ΔΩ) | ||
∂A = emul(ΔΩ, B) | ||
∂B = emul(ΔΩ, A) | ||
return NoTangent(), ∂A, ∂B, NoTangent() | ||
end | ||
return emul(A, B, BinaryOps.TIMES), timespullback | ||
end | ||
|
||
function rrule(::typeof(emul), A::GBArray, B::GBArray) | ||
Ω, fullpb = rrule(emul, A, B, BinaryOps.TIMES) | ||
emulpb(ΔΩ) = fullpb(ΔΩ)[1:3] | ||
return Ω, emulpb | ||
end | ||
|
||
############ | ||
# eadd rules | ||
############ | ||
|
||
# PLUS | ||
###### | ||
|
||
function frule( | ||
(_, ΔA, ΔB, _), | ||
::typeof(eadd), | ||
A::GBArray, | ||
B::GBArray, | ||
::typeof(BinaryOps.PLUS) | ||
) | ||
Ω = eadd(A, B, BinaryOps.PLUS) | ||
∂Ω = eadd(ΔA, ΔB, BinaryOps.PLUS) | ||
return Ω, ∂Ω | ||
end | ||
function frule((_, ΔA, ΔB), ::typeof(eadd), A::GBArray, B::GBArray) | ||
return frule((nothing, ΔA, ΔB, nothing), eadd, A, B, BinaryOps.PLUS) | ||
end | ||
|
||
function rrule(::typeof(eadd), A::GBArray, B::GBArray, ::typeof(BinaryOps.PLUS)) | ||
function pluspullback(ΔΩ) | ||
return ( | ||
NoTangent(), | ||
mask(ΔΩ, A; structural = true), | ||
mask(ΔΩ, B; structural = true), | ||
NoTangent() | ||
) | ||
end | ||
return eadd(A, B, BinaryOps.PLUS), pluspullback | ||
end | ||
|
||
# Do I have to duplicate this? I get 4 tangents instead of 3 if I call the previous rule. | ||
rayegun marked this conversation as resolved.
Show resolved
Hide resolved
|
||
function rrule(::typeof(eadd), A::GBArray, B::GBArray) | ||
Ω, fullpb = rrule(eadd, A, B, BinaryOps.PLUS) | ||
eaddpb(ΔΩ) = fullpb(ΔΩ)[1:3] | ||
return Ω, eaddpb | ||
end |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
# Per Lyndon. Needs adaptation, and/or needs redefinition of map to use functions rather | ||
# than AbstractOp. | ||
#function rrule(map, f, xs) | ||
# # Rather than 3 maps really want 1 multimap | ||
# ys_and_pullbacks = map(x->rrule(f, x), xs) #Take this to ys = map(f, x) | ||
# ys = map(first, ys_and_pullbacks) | ||
# pullbacks = map(last, ys_and_pullbacks) | ||
# function map_pullback(dys) | ||
# _call(f, x) = f(x) | ||
# dfs_and_dxs = map(_call, pullbacks, dys) | ||
# # but in your case you know it will be NoTangent() so can skip | ||
# df = sum(first, dfs_and_dxs) | ||
# dxs = map(last, dfs_and_dxs) | ||
# return NoTangent(), df, dxs | ||
# end | ||
# return ys, map_pullback | ||
#end |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
# Standard arithmetic mul: | ||
function frule( | ||
(_, ΔA, ΔB), | ||
::typeof(mul), | ||
A::GBMatOrTranspose, | ||
B::GBMatOrTranspose | ||
) | ||
frule((nothing, ΔA, ΔB, nothing), mul, A, B, Semirings.PLUS_TIMES) | ||
end | ||
function frule( | ||
(_, ΔA, ΔB, _), | ||
::typeof(mul), | ||
A::GBMatOrTranspose, | ||
B::GBMatOrTranspose, | ||
::typeof(Semirings.PLUS_TIMES) | ||
) | ||
Ω = mul(A, B) | ||
∂Ω = mul(ΔA, B) + mul(A, ΔB) | ||
return Ω, ∂Ω | ||
end | ||
# Tests will not pass for this. For two reasons. | ||
# First is #25, the output inference is not type stable. | ||
# That's it's own issue. | ||
|
||
# Second, to_vec currently works by mapping materialized values back and forth, ie. it knows nothing about nothings. | ||
# This means they give different answers. FiniteDifferences is probably "incorrect", but I have no proof. | ||
|
||
function rrule( | ||
::typeof(mul), | ||
A::GBMatOrTranspose, | ||
B::GBMatOrTranspose, | ||
::typeof(Semirings.PLUS_TIMES) | ||
) | ||
function mulpullback(ΔΩ) | ||
∂A = mul(ΔΩ, B'; mask=A) | ||
∂B = mul(A', ΔΩ; mask=B) | ||
return NoTangent(), ∂A, ∂B, NoTangent() | ||
end | ||
return mul(A, B), mulpullback | ||
end | ||
|
||
|
||
function rrule( | ||
::typeof(mul), | ||
A::GBMatOrTranspose, | ||
B::GBMatOrTranspose | ||
) | ||
Ω, mulpullback = rrule(mul, A, B, Semirings.PLUS_TIMES) | ||
pullback(ΔΩ) = mulpullback(ΔΩ)[1:3] | ||
return Ω, pullback | ||
end |
Empty file.
Empty file.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -59,7 +59,6 @@ function LinearAlgebra.mul!( | |
return w | ||
end | ||
|
||
|
||
""" | ||
mul(A::GBArray, B::GBArray; kwargs...)::GBArray | ||
|
||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.