Skip to content

Commit b5b73d3

Browse files
Merge pull request #1151 from willtebbutt/wct/mooncake
Add Mooncake to Alternative AD Frontends
2 parents d34bd12 + 0629224 commit b5b73d3

File tree

5 files changed

+76
-8
lines changed

5 files changed

+76
-8
lines changed

Project.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
2424
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
2525
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
2626
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
27+
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
2728
OrdinaryDiffEqCore = "bbf590c4-e513-4bbe-9b18-05decba2e5d8"
2829
PreallocationTools = "d236fae5-4411-538c-8e31-a6e3d9e00b46"
2930
QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
@@ -59,7 +60,7 @@ Calculus = "0.5.1"
5960
ChainRulesCore = "0.10.7, 1"
6061
ComponentArrays = "0.15.5"
6162
DelayDiffEq = "5.43.2"
62-
DiffEqBase = "6.166.1"
63+
DiffEqBase = "6.175"
6364
DiffEqCallbacks = "4"
6465
DiffEqNoiseProcess = "5.19"
6566
Distributed = "1"
@@ -94,7 +95,7 @@ RecursiveArrayTools = "3.27.2"
9495
Reexport = "1.0"
9596
ReverseDiff = "1.15.1"
9697
SafeTestsets = "0.1.0"
97-
SciMLBase = "2.79"
98+
SciMLBase = "2.94"
9899
SciMLJacobianOperators = "0.1"
99100
SciMLStructures = "1.3"
100101
SparseArrays = "1.10"

docs/src/manual/differential_equation_sensitivities.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,10 @@ Current AD libraries whose calls are captured by the sensitivity
1212
system are:
1313

1414
- [Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl)
15-
- [Zygote.jl](https://fluxml.ai/Zygote.jl/stable/)
16-
- [Diffractor.jl](https://github.com/JuliaDiff/Diffractor.jl)
15+
- [Mooncake.jl](https://github.com/chalk-lab/Mooncake.jl)
1716
- [ReverseDiff.jl](https://github.com/JuliaDiff/ReverseDiff.jl)
17+
- [Tracker.jl](https://github.com/FluxML/Tracker.jl)
18+
- [Zygote.jl](https://fluxml.ai/Zygote.jl/stable/)
1819

1920
## Using and Controlling Sensitivity Algorithms within AD
2021

ext/SciMLSensitivityMooncakeExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
module SciMLSensitivityMooncakeExt
22

33
using SciMLSensitivity, Mooncake
4-
import SciMLSensitivity: get_paramjac_config, mooncake_run_ad, MooncakeVJP, MooncakeLoaded
4+
import SciMLSensitivity: get_paramjac_config, mooncake_run_ad, MooncakeVJP, MooncakeLoaded, DiffEqBase
55

66
function get_paramjac_config(::MooncakeLoaded, ::MooncakeVJP, pf, p, f, y, _t)
77
dy_mem = zero(y)

src/concrete_solve.jl

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -460,6 +460,7 @@ function DiffEqBase._concrete_solve_adjoint(
460460
end
461461

462462
_prob = remake(_prob, u0 = new_u0, p = new_p)
463+
463464

464465
if sensealg isa BacksolveAdjoint
465466
sol = solve(_prob, alg, args...; initializealg = new_initializealg, save_noise = true,
@@ -870,8 +871,9 @@ function DiffEqBase._concrete_solve_adjoint(
870871
end
871872

872873
# use the callback in kwargs, not prob
873-
sol = solve(remake(prob, p = p, u0 = u0, callback = nothing),
874-
alg, args...; saveat = _saveat, kwargs...)
874+
kwargs_prob = NamedTuple(filter(x -> x[1] != :callback, prob.kwargs))
875+
_prob = remake(prob, p = p, u0 = u0, kwargs = kwargs_prob)
876+
sol = solve(_prob, alg, args...; saveat = _saveat, kwargs...)
875877

876878
if originator isa SciMLBase.EnzymeOriginator
877879
@reset sol.prob = prob
@@ -1273,6 +1275,21 @@ function Base.showerror(io::IO, e::EnzymeTrackedRealError)
12731275
println(io, ENZYME_TRACKED_REAL_ERROR_MESSAGE)
12741276
end
12751277

1278+
const MOONCAKE_TRACKED_REAL_ERROR_MESSAGE = """
1279+
`Mooncake` is not compatible with `ReverseDiffAdjoint` nor with `TrackerAdjoint`.
1280+
Either choose a different adjoint method like `GaussAdjoint`,
1281+
or use a different AD system like `ReverseDiff`.
1282+
For more details, on these methods see
1283+
https://docs.sciml.ai/SciMLSensitivity/stable/.
1284+
"""
1285+
1286+
struct MooncakeTrackedRealError <: Exception
1287+
end
1288+
1289+
function Base.showerror(io::IO, e::MooncakeTrackedRealError)
1290+
println(io, MOONCAKE_TRACKED_REAL_ERROR_MESSAGE)
1291+
end
1292+
12761293
function DiffEqBase._concrete_solve_adjoint(
12771294
prob::Union{SciMLBase.AbstractDiscreteProblem,
12781295
SciMLBase.AbstractODEProblem,
@@ -1290,6 +1307,10 @@ function DiffEqBase._concrete_solve_adjoint(
12901307
throw(EnzymeTrackedRealError())
12911308
end
12921309

1310+
if originator isa SciMLBase.MooncakeOriginator
1311+
throw(MooncakeTrackedRealError())
1312+
end
1313+
12931314
if !(p === nothing || p isa SciMLBase.NullParameters)
12941315
if !isscimlstructure(p)
12951316
throw(SciMLStructuresCompatibilityError())
@@ -1514,6 +1535,10 @@ function DiffEqBase._concrete_solve_adjoint(
15141535
throw(EnzymeTrackedRealError())
15151536
end
15161537

1538+
if originator isa SciMLBase.MooncakeOriginator
1539+
throw(MooncakeTrackedRealError())
1540+
end
1541+
15171542
t = eltype(prob.tspan)[]
15181543
u = typeof(u0)[]
15191544

test/alternative_ad_frontend.jl

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
using OrdinaryDiffEq, SciMLSensitivity, ForwardDiff, Zygote, ReverseDiff, Tracker, Enzyme,
2-
FiniteDiff
2+
FiniteDiff, Mooncake
33
using Test
44
Enzyme.API.typeWarning!(false)
55

6+
# Test helper: differentiate `f` at `x` with Mooncake and return the gradient with
# respect to `x` only.
# Builds a reverse rule for the call `f(x)`, runs `value_and_gradient!!` with it, and
# picks element 2 of element 2 of the result.
# NOTE(review): presumably the result is `(value, (grad_f, grad_x))`, so the final
# index skips the gradient with respect to `f` itself — confirm against Mooncake docs.
function mooncake_gradient(f, x)
    rule = Mooncake.build_rrule(f, x)
    result = Mooncake.value_and_gradient!!(rule, f, x)
    grads = result[2]
    return grads[2]
end
7+
68
odef(du, u, p, t) = du .= u .* p
79
const prob = ODEProblem(odef, [2.0], (0.0, 1.0), [3.0])
810

@@ -17,7 +19,9 @@ u0p = [2.0, 3.0]
1719
du0p = zeros(2)
1820
dup = Zygote.gradient(senseloss0(InterpolatingAdjoint()), u0p)[1]
1921
Enzyme.autodiff(Reverse, senseloss0(InterpolatingAdjoint()), Active, Duplicated(u0p, du0p))
22+
dup_mc = mooncake_gradient(senseloss0(InterpolatingAdjoint()), u0p)
2023
@test du0p ≈ dup
24+
@test dup_mc ≈ dup
2125

2226
struct senseloss{T}
2327
sense::T
@@ -56,6 +60,14 @@ dup = Zygote.gradient(senseloss(InterpolatingAdjoint()), u0p)[1]
5660
@test only(Enzyme.gradient(Reverse, senseloss(ForwardDiffSensitivity()), u0p)) ≈ dup
5761
@test_broken only(Enzyme.gradient(Reverse, senseloss(ForwardSensitivity()), u0p)) ≈ dup # broken because ForwardSensitivity not compatible with perturbing u0
5862

63+
@test mooncake_gradient(senseloss(InterpolatingAdjoint()), u0p) ≈ dup
64+
@test_throws TypeError mooncake_gradient(senseloss(ReverseDiffAdjoint()), u0p) ≈ dup
65+
@test_throws TypeError mooncake_gradient(senseloss(TrackerAdjoint()), u0p) ≈ dup
66+
#@test_throws SciMLSensitivity.MooncakeTrackedRealError mooncake_gradient(senseloss(ReverseDiffAdjoint()), u0p) ≈ dup
67+
#@test_throws SciMLSensitivity.MooncakeTrackedRealError mooncake_gradient(senseloss(TrackerAdjoint()), u0p) ≈ dup
68+
@test mooncake_gradient(senseloss(ForwardDiffSensitivity()), u0p) ≈ dup
69+
@test_broken mooncake_gradient(senseloss(ForwardSensitivity()), u0p) ≈ dup # broken because ForwardSensitivity not compatible with perturbing u0
70+
5971
struct senseloss2{T}
6072
sense::T
6173
end
@@ -90,6 +102,14 @@ dup = Zygote.gradient(senseloss2(InterpolatingAdjoint()), u0p)[1]
90102
@test_broken only(Enzyme.gradient(Reverse, senseloss2(ForwardDiffSensitivity()), u0p)) ≈ dup
91103
@test_broken only(Enzyme.gradient(Reverse, senseloss2(ForwardSensitivity()), u0p)) ≈ dup # broken because ForwardSensitivity not compatible with perturbing u0
92104

105+
@test mooncake_gradient(senseloss2(InterpolatingAdjoint()), u0p) ≈ dup
106+
@test_throws TypeError mooncake_gradient(senseloss2(ReverseDiffAdjoint()), u0p) ≈ dup
107+
@test_throws TypeError mooncake_gradient(senseloss2(TrackerAdjoint()), u0p) ≈ dup
108+
#@test_throws SciMLSensitivity.MooncakeTrackedRealError mooncake_gradient(senseloss2(ReverseDiffAdjoint()), u0p) ≈ dup
109+
#@test_throws SciMLSensitivity.MooncakeTrackedRealError mooncake_gradient(senseloss2(TrackerAdjoint()), u0p) ≈ dup
110+
@test mooncake_gradient(senseloss2(ForwardDiffSensitivity()), u0p) ≈ dup
111+
@test_broken mooncake_gradient(senseloss2(ForwardSensitivity()), u0p) ≈ dup # broken because ForwardSensitivity not compatible with perturbing u0
112+
93113
struct senseloss3{T}
94114
sense::T
95115
end
@@ -122,6 +142,14 @@ dup = Zygote.gradient(senseloss3(InterpolatingAdjoint()), u0p)[1]
122142
@test_broken only(Enzyme.gradient(Reverse, senseloss3(ForwardDiffSensitivity()), u0p)) ≈ dup
123143
@test_broken only(Enzyme.gradient(Reverse, senseloss3(ForwardSensitivity()), u0p)) ≈ dup
124144

145+
@test mooncake_gradient(senseloss3(InterpolatingAdjoint()), u0p) ≈ dup
146+
@test_throws TypeError mooncake_gradient(senseloss3(ReverseDiffAdjoint()), u0p) ≈ dup
147+
@test_throws TypeError mooncake_gradient(senseloss3(TrackerAdjoint()), u0p) ≈ dup
148+
#@test_throws SciMLSensitivity.MooncakeTrackedRealError mooncake_gradient(senseloss3(ReverseDiffAdjoint()), u0p) ≈ dup
149+
#@test_throws SciMLSensitivity.MooncakeTrackedRealError mooncake_gradient(senseloss3(TrackerAdjoint()), u0p) ≈ dup
150+
@test mooncake_gradient(senseloss3(ForwardDiffSensitivity()), u0p) ≈ dup
151+
@test_broken mooncake_gradient(senseloss3(ForwardSensitivity()), u0p) ≈ dup
152+
125153
struct senseloss4{T}
126154
sense::T
127155
end
@@ -156,6 +184,14 @@ dup = Zygote.gradient(senseloss4(InterpolatingAdjoint()), u0p)[1]
156184
@test only(Enzyme.gradient(Reverse, senseloss4(ForwardDiffSensitivity()), u0p)) ≈ dup
157185
@test_broken only(Enzyme.gradient(Reverse, senseloss4(ForwardSensitivity()), u0p)) ≈ dup # broken because ForwardSensitivity not compatible with perturbing u0
158186

187+
@test mooncake_gradient(senseloss4(InterpolatingAdjoint()), u0p) ≈ dup
188+
@test_throws TypeError mooncake_gradient(senseloss4(ReverseDiffAdjoint()), u0p) ≈ dup
189+
@test_throws TypeError mooncake_gradient(senseloss4(TrackerAdjoint()), u0p) ≈ dup
190+
#@test_throws SciMLSensitivity.MooncakeTrackedRealError mooncake_gradient(senseloss4(ReverseDiffAdjoint()), u0p) ≈ dup
191+
#@test_throws SciMLSensitivity.MooncakeTrackedRealError mooncake_gradient(senseloss4(TrackerAdjoint()), u0p) ≈ dup
192+
@test mooncake_gradient(senseloss4(ForwardDiffSensitivity()), u0p) ≈ dup
193+
@test_broken mooncake_gradient(senseloss4(ForwardSensitivity()), u0p) ≈ dup
194+
159195
solvealg_test = Tsit5()
160196
sensealg_test = InterpolatingAdjoint()
161197
tspan = (0.0, 1.0)
@@ -186,6 +222,9 @@ res4 = ReverseDiff.gradient(loss2, p0)
186222
@test_broken res2≈Enzyme.gradient(Reverse, loss, p0) atol=1e-14
187223
@test_broken res4≈Enzyme.gradient(Reverse, loss2, p0) atol=1e-14
188224

225+
@test res2 ≈ mooncake_gradient(loss, p0)
226+
@test res4 ≈ mooncake_gradient(loss2, p0)
227+
189228
# Test for recursion https://discourse.julialang.org/t/diffeqsensitivity-jl-issues-with-reversediffadjoint-sensealg/88774
190229
function ode!(derivative, state, parameters, t)
191230
derivative .= parameters
@@ -205,6 +244,7 @@ const initial_state = ones(2)
205244
const solution_times = [1.0, 2.0]
206245
ReverseDiff.gradient(p -> sum(sum(solve_euler(initial_state, solution_times, p))), zeros(2))
207246
# Enzyme.gradient(Reverse, p -> sum(sum(solve_euler(initial_state, solution_times, p))), zeros(2))
247+
# mooncake_gradient(p -> sum(sum(solve_euler(initial_state, solution_times, p))), zeros(2))
208248

209249
# https://github.com/SciML/SciMLSensitivity.jl/issues/943
210250

@@ -249,3 +289,4 @@ grad_rd = ReverseDiff.gradient(loss2, p)
249289
@test grad_fd≈grad_fi atol=1e-2
250290
@test grad_fd≈grad_zg atol=1e-4
251291
@test grad_fd≈grad_rd atol=1e-4
292+
@test_broken mooncake_gradient(loss2, p) ≈ grad_rd atol=1e-4

0 commit comments

Comments
 (0)