Skip to content

Commit 70a2940

Browse files
authored
Fix EnzymeAD#2182: avoid double accumulation in *MixedDuplicated (EnzymeAD#2262)
* Add nonzero dval tests to expose jit bug
* Fix jit bug
1 parent 8a0bff4 commit 70a2940

File tree

3 files changed

+271
-203
lines changed

3 files changed

+271
-203
lines changed

src/rules/jitrules.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1488,7 +1488,7 @@ end
14881488
end
14891489
elseif args[i] <: MixedDuplicated
14901490
:(args[$i].dval[])
1491-
else
1491+
else # args[i] <: BatchMixedDuplicated
14921492
:(args[$i].dval[$w][])
14931493
end
14941494

@@ -1500,9 +1500,11 @@ end
15001500
T = Core.Typeof(vecld)
15011501
@assert !(vecld isa Base.RefValue)
15021502
vec[] = recursive_index_add(T, vecld, Val(idx_in_vec), $expr)
1503-
else
1503+
elseif $(args[i] <: Active)
15041504
val = @inbounds vec[idx_in_vec]
15051505
add_into_vec!(Base.inferencebarrier(val), $expr, vec, idx_in_vec)
1506+
else # args[i] <: MixedDuplicated || args[i] <: BatchMixedDuplicated
1507+
@inbounds vec[idx_in_vec] = $expr
15061508
end
15071509
end
15081510
else

test/applyiter.jl

Lines changed: 157 additions & 138 deletions
Original file line numberDiff line numberDiff line change
@@ -105,155 +105,174 @@ end
105105

106106
@testset "Reverse Apply iterate" begin
107107
x = [(2.0, 3.0), (7.9, 11.2)]
108-
dx = [(0.0, 0.0), (0.0, 0.0)]
109-
res = Enzyme.autodiff(Reverse, metasumsq, Active, Const(metaconcat), Duplicated(x, dx))
110-
@test tupapprox(dx, [(4.0, 6.0), (15.8, 22.4)])
111-
112-
dx = [(0.0, 0.0), (0.0, 0.0)]
113-
res = Enzyme.autodiff(ReverseWithPrimal, metasumsq, Active, Const(metaconcat), Duplicated(x, dx))
114-
@test res[2] ≈ 200.84999999999997
115-
@test tupapprox(dx, [(4.0, 6.0), (15.8, 22.4)])
116-
117-
x = [[2.0, 3.0], [7.9, 11.2]]
118-
dx = [[0.0, 0.0], [0.0, 0.0]]
119-
120-
res = Enzyme.autodiff(Reverse, metasumsq2, Active, Const(metaconcat), Duplicated(x, dx))
121-
@test dx ≈ [[4.0, 6.0], [15.8, 22.4]]
122-
123-
dx = [[0.0, 0.0], [0.0, 0.0]]
124-
125-
res = Enzyme.autodiff(ReverseWithPrimal, metasumsq2, Active, Const(metaconcat), Duplicated(x, dx))
126-
127-
@test res[2] ≈ 200.84999999999997
128-
@test tupapprox(dx, [[4.0, 6.0], [15.8, 22.4]])
129-
130-
131-
x = [(2.0, 3.0), (7.9, 11.2)]
132-
dx = [(0.0, 0.0), (0.0, 0.0)]
133-
134108
y = [(13, 17), (25, 31)]
135-
res = Enzyme.autodiff(Reverse, metasumsq3, Active, Const(metaconcat2), Duplicated(x, dx), Const(y))
136-
@test tupapprox(dx, [(4.0, 6.0), (15.8, 22.4)])
137-
138-
139-
x = [(2.0, 3.0), (7.9, 11.2)]
140-
dx = [(0.0, 0.0), (0.0, 0.0)]
141-
y = [(13, 17), (25, 31)]
142-
dy = [(0, 0), (0, 0)]
143-
res = Enzyme.autodiff(Reverse, metasumsq3, Active, Const(metaconcat2), Duplicated(x, dx), Duplicated(y, dy))
144-
@test tupapprox(dx, [(4.0, 6.0), (15.8, 22.4)])
145-
146-
147-
148-
x = [[2.0, 3.0], [7.9, 11.2]]
149-
dx = [[0.0, 0.0], [0.0, 0.0]]
150-
y = [[13, 17], [25, 31]]
151-
res = Enzyme.autodiff(Reverse, metasumsq4, Active, Const(metaconcat2), Duplicated(x, dx), Const(y))
152-
@test tupapprox(dx, [[4.0, 6.0], [15.8, 22.4]])
153-
109+
dy_const = [(0, 0), (0, 0)]
110+
primal = 200.84999999999997
111+
@testset "tuple $label" for (label, dx_pre, dx_post) in [
112+
("dx == 0", [(0.0, 0.0), (0.0, 0.0)], [(4.0, 6.0), (15.8, 22.4)]),
113+
("dx != 0", [(1.0, -2.0), (-3.0, 4.0)], [(5.0, 4.0), (12.8, 26.4)]),
114+
]
115+
dx = deepcopy(dx_pre)
116+
Enzyme.autodiff(Reverse, metasumsq, Active, Const(metaconcat), Duplicated(x, dx))
117+
@test tupapprox(dx, dx_post)
118+
119+
dx = deepcopy(dx_pre)
120+
res = Enzyme.autodiff(ReverseWithPrimal, metasumsq, Active, Const(metaconcat), Duplicated(x, dx))
121+
@test res[2] ≈ primal
122+
@test tupapprox(dx, dx_post)
123+
124+
dx = deepcopy(dx_pre)
125+
Enzyme.autodiff(Reverse, metasumsq3, Active, Const(metaconcat2), Duplicated(x, dx), Const(y))
126+
@test tupapprox(dx, dx_post)
127+
128+
dx = deepcopy(dx_pre)
129+
dy = deepcopy(dy_const)
130+
Enzyme.autodiff(Reverse, metasumsq3, Active, Const(metaconcat2), Duplicated(x, dx), Duplicated(y, dy))
131+
@test tupapprox(dx, dx_post)
132+
@test tupapprox(dy, dy_const)
133+
end
154134

155135
x = [[2.0, 3.0], [7.9, 11.2]]
156-
dx = [[0.0, 0.0], [0.0, 0.0]]
157136
y = [[13, 17], [25, 31]]
158-
dy = [[0, 0], [0, 0]]
159-
res = Enzyme.autodiff(Reverse, metasumsq4, Active, Const(metaconcat2), Duplicated(x, dx), Duplicated(y, dy))
160-
@test tupapprox(dx, [[4.0, 6.0], [15.8, 22.4]])
137+
dy_const = [[0, 0], [0, 0]]
138+
primal = 200.84999999999997
139+
@testset "list $label" for (label, dx_pre, dx_post) in [
140+
("dx == 0", [[0.0, 0.0], [0.0, 0.0]], [[4.0, 6.0], [15.8, 22.4]]),
141+
("dx != 0", [[1.0, -2.0], [-3.0, 4.0]], [[5.0, 4.0], [12.8, 26.4]]),
142+
]
143+
dx = deepcopy(dx_pre)
144+
Enzyme.autodiff(Reverse, metasumsq2, Active, Const(metaconcat), Duplicated(x, dx))
145+
@test dx ≈ dx_post
146+
147+
dx = deepcopy(dx_pre)
148+
res = Enzyme.autodiff(ReverseWithPrimal, metasumsq2, Active, Const(metaconcat), Duplicated(x, dx))
149+
@test res[2] ≈ primal
150+
@test dx ≈ dx_post
151+
152+
dx = deepcopy(dx_pre)
153+
Enzyme.autodiff(Reverse, metasumsq4, Active, Const(metaconcat2), Duplicated(x, dx), Const(y))
154+
@test dx ≈ dx_post
155+
156+
dx = deepcopy(dx_pre)
157+
dy = deepcopy(dy_const)
158+
Enzyme.autodiff(Reverse, metasumsq4, Active, Const(metaconcat2), Duplicated(x, dx), Duplicated(y, dy))
159+
@test dx ≈ dx_post
160+
@test dy ≈ dy_const
161+
end
161162
end
162163

163164
@testset "BatchReverse Apply iterate" begin
164165
x = [(2.0, 3.0), (7.9, 11.2)]
165-
dx = [(0.0, 0.0), (0.0, 0.0)]
166-
dx2 = [(0.0, 0.0), (0.0, 0.0)]
167-
out = Ref(0.0)
168-
dout = Ref(1.0)
169-
dout2 = Ref(3.0)
170-
Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicatedNoNeed(out, (dout, dout2)), Const(metasumsq), Const(metaconcat), BatchDuplicated(x, (dx, dx2)))
171-
@test tupapprox(dx, [(4.0, 6.0), (15.8, 22.4)])
172-
@test tupapprox(dx2, [(3*4.0, 3*6.0), (3*15.8, 3*22.4)])
173-
174-
dx = [(0.0, 0.0), (0.0, 0.0)]
175-
dx2 = [(0.0, 0.0), (0.0, 0.0)]
176-
out = Ref(0.0)
177-
dout = Ref(1.0)
178-
dout2 = Ref(3.0)
179-
Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicated(out, (dout, dout2)), Const(metasumsq), Const(metaconcat), BatchDuplicated(x, (dx, dx2)))
180-
@test out[] ≈ 200.84999999999997
181-
@test tupapprox(dx, [(4.0, 6.0), (15.8, 22.4)])
182-
@test tupapprox(dx2, [(3*4.0, 3*6.0), (3*15.8, 3*22.4)])
183-
184-
x = [[2.0, 3.0], [7.9, 11.2]]
185-
dx = [[0.0, 0.0], [0.0, 0.0]]
186-
dx2 = [[0.0, 0.0], [0.0, 0.0]]
187-
out = Ref(0.0)
188-
dout = Ref(1.0)
189-
dout2 = Ref(3.0)
190-
191-
Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicatedNoNeed(out, (dout, dout2)), Const(metasumsq2), Const(metaconcat), BatchDuplicated(x, (dx, dx2)))
192-
@test dx ≈ [[4.0, 6.0], [15.8, 22.4]]
193-
@test dx2 ≈ [[3*4.0, 3*6.0], [3*15.8, 3*22.4]]
194-
195-
dx = [[0.0, 0.0], [0.0, 0.0]]
196-
dx2 = [[0.0, 0.0], [0.0, 0.0]]
197-
out = Ref(0.0)
198-
dout = Ref(1.0)
199-
dout2 = Ref(3.0)
200-
Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicated(out, (dout, dout2)), Const(metasumsq2), Const(metaconcat), BatchDuplicated(x, (dx, dx2)))
201-
202-
@test out[] ≈ 200.84999999999997
203-
@test tupapprox(dx, [[4.0, 6.0], [15.8, 22.4]])
204-
@test tupapprox(dx2, [[3*4.0, 3*6.0], [3*15.8, 3*22.4]])
205-
206-
207-
x = [(2.0, 3.0), (7.9, 11.2)]
208-
dx = [(0.0, 0.0), (0.0, 0.0)]
209-
dx2 = [(0.0, 0.0), (0.0, 0.0)]
210-
211166
y = [(13, 17), (25, 31)]
212-
out = Ref(0.0)
213-
dout = Ref(1.0)
214-
dout2 = Ref(3.0)
215-
Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicatedNoNeed(out, (dout, dout2)), Const(metasumsq3), Const(metaconcat2), BatchDuplicated(x, (dx, dx2)), Const(y))
216-
@test tupapprox(dx, [(4.0, 6.0), (15.8, 22.4)])
217-
@test tupapprox(dx2, [(3*4.0, 3*6.0), (3*15.8, 3*22.4)])
218-
219-
220-
x = [(2.0, 3.0), (7.9, 11.2)]
221-
dx = [(0.0, 0.0), (0.0, 0.0)]
222-
dx2 = [(0.0, 0.0), (0.0, 0.0)]
223-
y = [(13, 17), (25, 31)]
224-
dy = [(0, 0), (0, 0)]
225-
dy2 = [(0, 0), (0, 0)]
226-
out = Ref(0.0)
227-
dout = Ref(1.0)
228-
dout2 = Ref(3.0)
229-
Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicatedNoNeed(out, (dout, dout2)), Const(metasumsq3),Const(metaconcat2), BatchDuplicated(x, (dx, dx2)), BatchDuplicated(y, (dy, dy2)))
230-
@test tupapprox(dx, [(4.0, 6.0), (15.8, 22.4)])
231-
@test tupapprox(dx2, [(3*4.0, 3*6.0), (3*15.8, 3*22.4)])
232-
233-
234-
x = [[2.0, 3.0], [7.9, 11.2]]
235-
dx = [[0.0, 0.0], [0.0, 0.0]]
236-
dx2 = [[0.0, 0.0], [0.0, 0.0]]
237-
y = [[13, 17], [25, 31]]
238-
out = Ref(0.0)
239-
dout = Ref(1.0)
240-
dout2 = Ref(3.0)
241-
Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicated(out, (dout, dout2)), Const(metasumsq4), Const(metaconcat2), BatchDuplicated(x, (dx, dx2)), Const(y))
242-
@test tupapprox(dx, [[4.0, 6.0], [15.8, 22.4]])
243-
@test tupapprox(dx2, [[3*4.0, 3*6.0], [3*15.8, 3*22.4]])
167+
dy_const = [(0, 0), (0, 0)]
168+
primal = 200.84999999999997
169+
out_pre, dout_pre, dout2_pre = 0.0, 1.0, 3.0
170+
@testset "tuple $label" for (label, dx_pre, dx_post, dx2_post) in [
171+
(
172+
"dx == 0",
173+
[(0.0, 0.0), (0.0, 0.0)],
174+
[(4.0, 6.0), (15.8, 22.4)],
175+
[(3 * 4.0, 3 * 6.0), (3 * 15.8, 3 * 22.4)],
176+
),
177+
(
178+
"dx != 0",
179+
[(1.0, -2.0), (-3.0, 4.0)],
180+
[(5.0, 4.0), (12.8, 26.4)],
181+
[(1.0 + 3 * 4.0, -2.0 + 3 * 6.0), (-3.0 + 3 * 15.8, 4.0 + 3 * 22.4)],
182+
),
183+
]
184+
dx, dx2 = deepcopy.((dx_pre, dx_pre))
185+
out, dout, dout2 = Ref.((out_pre, dout_pre, dout2_pre))
186+
Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicatedNoNeed(out, (dout, dout2)), Const(metasumsq), Const(metaconcat), BatchDuplicated(x, (dx, dx2)))
187+
@test dout[] ≈ 0
188+
@test dout2[] ≈ 0
189+
@test tupapprox(dx, dx_post)
190+
@test tupapprox(dx2, dx2_post)
191+
192+
dx, dx2 = deepcopy.((dx_pre, dx_pre))
193+
out, dout, dout2 = Ref.((out_pre, dout_pre, dout2_pre))
194+
Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicated(out, (dout, dout2)), Const(metasumsq), Const(metaconcat), BatchDuplicated(x, (dx, dx2)))
195+
@test out[] ≈ primal
196+
@test dout[] ≈ 0
197+
@test dout2[] ≈ 0
198+
@test tupapprox(dx, dx_post)
199+
@test tupapprox(dx2, dx2_post)
200+
201+
dx, dx2 = deepcopy.((dx_pre, dx_pre))
202+
out, dout, dout2 = Ref.((out_pre, dout_pre, dout2_pre))
203+
Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicatedNoNeed(out, (dout, dout2)), Const(metasumsq3), Const(metaconcat2), BatchDuplicated(x, (dx, dx2)), Const(y))
204+
@test dout[] ≈ 0
205+
@test dout2[] ≈ 0
206+
@test tupapprox(dx, dx_post)
207+
@test tupapprox(dx2, dx2_post)
208+
209+
dx, dx2 = deepcopy.((dx_pre, dx_pre))
210+
dy, dy2 = deepcopy.((dy_const, dy_const))
211+
out, dout, dout2 = Ref.((out_pre, dout_pre, dout2_pre))
212+
Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicatedNoNeed(out, (dout, dout2)), Const(metasumsq3), Const(metaconcat2), BatchDuplicated(x, (dx, dx2)), BatchDuplicated(y, (dy, dy2)))
213+
@test dout[] ≈ 0
214+
@test dout2[] ≈ 0
215+
@test tupapprox(dx, dx_post)
216+
@test tupapprox(dx2, dx2_post)
217+
@test tupapprox(dy, dy_const)
218+
@test tupapprox(dy2, dy_const)
219+
end
244220

245221
x = [[2.0, 3.0], [7.9, 11.2]]
246-
dx = [[0.0, 0.0], [0.0, 0.0]]
247-
dx2 = [[0.0, 0.0], [0.0, 0.0]]
248222
y = [[13, 17], [25, 31]]
249-
dy = [[0, 0], [0, 0]]
250-
dy2 = [[0, 0], [0, 0]]
251-
out = Ref(0.0)
252-
dout = Ref(1.0)
253-
dout2 = Ref(3.0)
254-
Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicated(out, (dout, dout2)), Const(metasumsq4), Const(metaconcat2), BatchDuplicated(x, (dx, dx2)), BatchDuplicated(y, (dy, dy2)))
255-
@test tupapprox(dx, [[4.0, 6.0], [15.8, 22.4]])
256-
@test tupapprox(dx2, [[3*4.0, 3*6.0], [3*15.8, 3*22.4]])
223+
dy_const = [[0, 0], [0, 0]]
224+
primal = 200.84999999999997
225+
out_pre, dout_pre, dout2_pre = 0.0, 1.0, 3.0
226+
@testset "tuple $label" for (label, dx_pre, dx_post, dx2_post) in [
227+
(
228+
"dx == 0",
229+
[[0.0, 0.0], [0.0, 0.0]],
230+
[[4.0, 6.0], [15.8, 22.4]],
231+
[[3 * 4.0, 3 * 6.0], [3 * 15.8, 3 * 22.4]],
232+
),
233+
(
234+
"dx != 0",
235+
[[1.0, -2.0], [-3.0, 4.0]],
236+
[[5.0, 4.0], [12.8, 26.4]],
237+
[[1.0 + 3 * 4.0, -2.0 + 3 * 6.0], [-3.0 + 3 * 15.8, 4.0 + 3 * 22.4]],
238+
),
239+
]
240+
dx, dx2 = deepcopy.((dx_pre, dx_pre))
241+
out, dout, dout2 = Ref.((out_pre, dout_pre, dout2_pre))
242+
Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicatedNoNeed(out, (dout, dout2)), Const(metasumsq2), Const(metaconcat), BatchDuplicated(x, (dx, dx2)))
243+
@test dout[] ≈ 0
244+
@test dout2[] ≈ 0
245+
@test dx ≈ dx_post
246+
@test dx2 ≈ dx2_post
247+
248+
dx, dx2 = deepcopy.((dx_pre, dx_pre))
249+
out, dout, dout2 = Ref.((out_pre, dout_pre, dout2_pre))
250+
Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicated(out, (dout, dout2)), Const(metasumsq2), Const(metaconcat), BatchDuplicated(x, (dx, dx2)))
251+
@test out[] ≈ primal
252+
@test dout[] ≈ 0
253+
@test dout2[] ≈ 0
254+
@test dx ≈ dx_post
255+
@test dx2 ≈ dx2_post
256+
257+
dx, dx2 = deepcopy.((dx_pre, dx_pre))
258+
out, dout, dout2 = Ref.((out_pre, dout_pre, dout2_pre))
259+
Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicatedNoNeed(out, (dout, dout2)), Const(metasumsq4), Const(metaconcat2), BatchDuplicated(x, (dx, dx2)), Const(y))
260+
@test dout[] ≈ 0
261+
@test dout2[] ≈ 0
262+
@test dx ≈ dx_post
263+
@test dx2 ≈ dx2_post
264+
265+
dx, dx2 = deepcopy.((dx_pre, dx_pre))
266+
dy, dy2 = deepcopy.((dy_const, dy_const))
267+
out, dout, dout2 = Ref.((out_pre, dout_pre, dout2_pre))
268+
Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicatedNoNeed(out, (dout, dout2)), Const(metasumsq4), Const(metaconcat2), BatchDuplicated(x, (dx, dx2)), BatchDuplicated(y, (dy, dy2)))
269+
@test dout[] ≈ 0
270+
@test dout2[] ≈ 0
271+
@test dx ≈ dx_post
272+
@test dx2 ≈ dx2_post
273+
@test dy ≈ dy_const
274+
@test dy2 ≈ dy_const
275+
end
257276
end
258277

259278
@testset "Forward Apply iterate" begin
@@ -502,4 +521,4 @@ end
502521
@test ddata[1][1] ≈ 6.0
503522
end
504523

505-
include("mixedapplyiter.jl")
524+
include("mixedapplyiter.jl")

0 commit comments

Comments
 (0)