1
1
using OrdinaryDiffEq, SciMLSensitivity, ForwardDiff, Zygote, ReverseDiff, Tracker, Enzyme,
2
- FiniteDiff
2
+ FiniteDiff, Mooncake
3
3
using Test
4
4
Enzyme. API. typeWarning! (false )
5
5
6
+ mooncake_gradient (f, x) = Mooncake. value_and_gradient!! (Mooncake. build_rrule (f, x), f, x)[2 ][2 ]
7
+
6
8
odef (du, u, p, t) = du .= u .* p
7
9
const prob = ODEProblem (odef, [2.0 ], (0.0 , 1.0 ), [3.0 ])
8
10
@@ -17,7 +19,9 @@ u0p = [2.0, 3.0]
17
19
du0p = zeros (2 )
18
20
dup = Zygote. gradient (senseloss0 (InterpolatingAdjoint ()), u0p)[1 ]
19
21
Enzyme. autodiff (Reverse, senseloss0 (InterpolatingAdjoint ()), Active, Duplicated (u0p, du0p))
22
+ dup_mc = mooncake_gradient (senseloss0 (InterpolatingAdjoint ()), u0p)
20
23
@test du0p ≈ dup
24
+ @test dup_mc ≈ dup
21
25
22
26
struct senseloss{T}
23
27
sense:: T
@@ -56,6 +60,14 @@ dup = Zygote.gradient(senseloss(InterpolatingAdjoint()), u0p)[1]
56
60
@test only (Enzyme. gradient (Reverse, senseloss (ForwardDiffSensitivity ()), u0p)) ≈ dup
57
61
@test_broken only (Enzyme. gradient (Reverse, senseloss (ForwardSensitivity ()), u0p)) ≈ dup # broken because ForwardSensitivity not compatible with perturbing u0
58
62
63
+ @test mooncake_gradient (senseloss (InterpolatingAdjoint ()), u0p) ≈ dup
64
+ @test_throws TypeError mooncake_gradient (senseloss (ReverseDiffAdjoint ()), u0p) ≈ dup
65
+ @test_throws TypeError mooncake_gradient (senseloss (TrackerAdjoint ()), u0p) ≈ dup
66
+ # @test_throws SciMLSensitivity.MooncakeTrackedRealError mooncake_gradient(senseloss(ReverseDiffAdjoint()), u0p) ≈ dup
67
+ # @test_throws SciMLSensitivity.MooncakeTrackedRealError mooncake_gradient(senseloss(TrackerAdjoint()), u0p) ≈ dup
68
+ @test mooncake_gradient (senseloss (ForwardDiffSensitivity ()), u0p) ≈ dup
69
+ @test_broken mooncake_gradient (senseloss (ForwardSensitivity ()), u0p) ≈ dup # broken because ForwardSensitivity not compatible with perturbing u0
70
+
59
71
struct senseloss2{T}
60
72
sense:: T
61
73
end
@@ -90,6 +102,14 @@ dup = Zygote.gradient(senseloss2(InterpolatingAdjoint()), u0p)[1]
90
102
@test_broken only (Enzyme. gradient (Reverse, senseloss2 (ForwardDiffSensitivity ()), u0p)) ≈ dup
91
103
@test_broken only (Enzyme. gradient (Reverse, senseloss2 (ForwardSensitivity ()), u0p)) ≈ dup # broken because ForwardSensitivity not compatible with perturbing u0
92
104
105
+ @test mooncake_gradient (senseloss2 (InterpolatingAdjoint ()), u0p) ≈ dup
106
+ @test_throws TypeError mooncake_gradient (senseloss2 (ReverseDiffAdjoint ()), u0p) ≈ dup
107
+ @test_throws TypeError mooncake_gradient (senseloss2 (TrackerAdjoint ()), u0p) ≈ dup
108
+ # @test_throws SciMLSensitivity.MooncakeTrackedRealError mooncake_gradient(senseloss2(ReverseDiffAdjoint()), u0p) ≈ dup
109
+ # @test_throws SciMLSensitivity.MooncakeTrackedRealError mooncake_gradient(senseloss2(TrackerAdjoint()), u0p) ≈ dup
110
+ @test mooncake_gradient (senseloss2 (ForwardDiffSensitivity ()), u0p) ≈ dup
111
+ @test_broken mooncake_gradient (senseloss2 (ForwardSensitivity ()), u0p) ≈ dup # broken because ForwardSensitivity not compatible with perturbing u0
112
+
93
113
struct senseloss3{T}
94
114
sense:: T
95
115
end
@@ -122,6 +142,14 @@ dup = Zygote.gradient(senseloss3(InterpolatingAdjoint()), u0p)[1]
122
142
@test_broken only (Enzyme. gradient (Reverse, senseloss3 (ForwardDiffSensitivity ()), u0p)) ≈ dup
123
143
@test_broken only (Enzyme. gradient (Reverse, senseloss3 (ForwardSensitivity ()), u0p)) ≈ dup
124
144
145
+ @test mooncake_gradient (senseloss3 (InterpolatingAdjoint ()), u0p) ≈ dup
146
+ @test_throws TypeError mooncake_gradient (senseloss3 (ReverseDiffAdjoint ()), u0p) ≈ dup
147
+ @test_throws TypeError mooncake_gradient (senseloss3 (TrackerAdjoint ()), u0p) ≈ dup
148
+ # @test_throws SciMLSensitivity.MooncakeTrackedRealError mooncake_gradient(senseloss3(ReverseDiffAdjoint()), u0p) ≈ dup
149
+ # @test_throws SciMLSensitivity.MooncakeTrackedRealError mooncake_gradient(senseloss3(TrackerAdjoint()), u0p) ≈ dup
150
+ @test mooncake_gradient (senseloss3 (ForwardDiffSensitivity ()), u0p) ≈ dup
151
+ @test_broken mooncake_gradient (senseloss3 (ForwardSensitivity ()), u0p) ≈ dup
152
+
125
153
struct senseloss4{T}
126
154
sense:: T
127
155
end
@@ -156,6 +184,14 @@ dup = Zygote.gradient(senseloss4(InterpolatingAdjoint()), u0p)[1]
156
184
@test only (Enzyme. gradient (Reverse, senseloss4 (ForwardDiffSensitivity ()), u0p)) ≈ dup
157
185
@test_broken only (Enzyme. gradient (Reverse, senseloss4 (ForwardSensitivity ()), u0p)) ≈ dup # broken because ForwardSensitivity not compatible with perturbing u0
158
186
187
+ @test mooncake_gradient (senseloss4 (InterpolatingAdjoint ()), u0p) ≈ dup
188
+ @test_throws TypeError mooncake_gradient (senseloss4 (ReverseDiffAdjoint ()), u0p) ≈ dup
189
+ @test_throws TypeError mooncake_gradient (senseloss4 (TrackerAdjoint ()), u0p) ≈ dup
190
+ # @test_throws SciMLSensitivity.MooncakeTrackedRealError mooncake_gradient(senseloss4(ReverseDiffAdjoint()), u0p) ≈ dup
191
+ # @test_throws SciMLSensitivity.MooncakeTrackedRealError mooncake_gradient(senseloss4(TrackerAdjoint()), u0p) ≈ dup
192
+ @test mooncake_gradient (senseloss4 (ForwardDiffSensitivity ()), u0p) ≈ dup
193
+ @test_broken mooncake_gradient (senseloss4 (ForwardSensitivity ()), u0p) ≈ dup
194
+
159
195
solvealg_test = Tsit5 ()
160
196
sensealg_test = InterpolatingAdjoint ()
161
197
tspan = (0.0 , 1.0 )
@@ -186,6 +222,9 @@ res4 = ReverseDiff.gradient(loss2, p0)
186
222
@test_broken res2≈ Enzyme. gradient (Reverse, loss, p0) atol= 1e-14
187
223
@test_broken res4≈ Enzyme. gradient (Reverse, loss2, p0) atol= 1e-14
188
224
225
+ @test res2 ≈ mooncake_gradient (loss, p0)
226
+ @test res4 ≈ mooncake_gradient (loss2, p0)
227
+
189
228
# Test for recursion https://discourse.julialang.org/t/diffeqsensitivity-jl-issues-with-reversediffadjoint-sensealg/88774
190
229
function ode! (derivative, state, parameters, t)
191
230
derivative .= parameters
@@ -205,6 +244,7 @@ const initial_state = ones(2)
205
244
const solution_times = [1.0 , 2.0 ]
206
245
ReverseDiff. gradient (p -> sum (sum (solve_euler (initial_state, solution_times, p))), zeros (2 ))
207
246
# Enzyme.gradient(Reverse, p -> sum(sum(solve_euler(initial_state, solution_times, p))), zeros(2))
247
+ # mooncake_gradient(p -> sum(sum(solve_euler(initial_state, solution_times, p))), zeros(2))
208
248
209
249
# https://github.com/SciML/SciMLSensitivity.jl/issues/943
210
250
@@ -249,3 +289,4 @@ grad_rd = ReverseDiff.gradient(loss2, p)
249
289
@test grad_fd≈ grad_fi atol= 1e-2
250
290
@test grad_fd≈ grad_zg atol= 1e-4
251
291
@test grad_fd≈ grad_rd atol= 1e-4
292
+ @test_broken mooncake_gradient (loss2, p) ≈ grad_rd atol= 1e-4
0 commit comments