@@ -105,155 +105,174 @@ end
105
105
106
106
@testset " Reverse Apply iterate" begin
107
107
x = [(2.0 , 3.0 ), (7.9 , 11.2 )]
108
- dx = [(0.0 , 0.0 ), (0.0 , 0.0 )]
109
- res = Enzyme. autodiff (Reverse, metasumsq, Active, Const (metaconcat), Duplicated (x, dx))
110
- @test tupapprox (dx, [(4.0 , 6.0 ), (15.8 , 22.4 )])
111
-
112
- dx = [(0.0 , 0.0 ), (0.0 , 0.0 )]
113
- res = Enzyme. autodiff (ReverseWithPrimal, metasumsq, Active, Const (metaconcat), Duplicated (x, dx))
114
- @test res[2 ] ≈ 200.84999999999997
115
- @test tupapprox (dx, [(4.0 , 6.0 ), (15.8 , 22.4 )])
116
-
117
- x = [[2.0 , 3.0 ], [7.9 , 11.2 ]]
118
- dx = [[0.0 , 0.0 ], [0.0 , 0.0 ]]
119
-
120
- res = Enzyme. autodiff (Reverse, metasumsq2, Active, Const (metaconcat), Duplicated (x, dx))
121
- @test dx ≈ [[4.0 , 6.0 ], [15.8 , 22.4 ]]
122
-
123
- dx = [[0.0 , 0.0 ], [0.0 , 0.0 ]]
124
-
125
- res = Enzyme. autodiff (ReverseWithPrimal, metasumsq2, Active, Const (metaconcat), Duplicated (x, dx))
126
-
127
- @test res[2 ] ≈ 200.84999999999997
128
- @test tupapprox (dx, [[4.0 , 6.0 ], [15.8 , 22.4 ]])
129
-
130
-
131
- x = [(2.0 , 3.0 ), (7.9 , 11.2 )]
132
- dx = [(0.0 , 0.0 ), (0.0 , 0.0 )]
133
-
134
108
y = [(13 , 17 ), (25 , 31 )]
135
- res = Enzyme. autodiff (Reverse, metasumsq3, Active, Const (metaconcat2), Duplicated (x, dx), Const (y))
136
- @test tupapprox (dx, [(4.0 , 6.0 ), (15.8 , 22.4 )])
137
-
138
-
139
- x = [(2.0 , 3.0 ), (7.9 , 11.2 )]
140
- dx = [(0.0 , 0.0 ), (0.0 , 0.0 )]
141
- y = [(13 , 17 ), (25 , 31 )]
142
- dy = [(0 , 0 ), (0 , 0 )]
143
- res = Enzyme. autodiff (Reverse, metasumsq3, Active, Const (metaconcat2), Duplicated (x, dx), Duplicated (y, dy))
144
- @test tupapprox (dx, [(4.0 , 6.0 ), (15.8 , 22.4 )])
145
-
146
-
147
-
148
- x = [[2.0 , 3.0 ], [7.9 , 11.2 ]]
149
- dx = [[0.0 , 0.0 ], [0.0 , 0.0 ]]
150
- y = [[13 , 17 ], [25 , 31 ]]
151
- res = Enzyme. autodiff (Reverse, metasumsq4, Active, Const (metaconcat2), Duplicated (x, dx), Const (y))
152
- @test tupapprox (dx, [[4.0 , 6.0 ], [15.8 , 22.4 ]])
153
-
109
+ dy_const = [(0 , 0 ), (0 , 0 )]
110
+ primal = 200.84999999999997
111
+ @testset " tuple $label " for (label, dx_pre, dx_post) in [
112
+ (" dx == 0" , [(0.0 , 0.0 ), (0.0 , 0.0 )], [(4.0 , 6.0 ), (15.8 , 22.4 )]),
113
+ (" dx != 0" , [(1.0 , - 2.0 ), (- 3.0 , 4.0 )], [(5.0 , 4.0 ), (12.8 , 26.4 )]),
114
+ ]
115
+ dx = deepcopy (dx_pre)
116
+ Enzyme. autodiff (Reverse, metasumsq, Active, Const (metaconcat), Duplicated (x, dx))
117
+ @test tupapprox (dx, dx_post)
118
+
119
+ dx = deepcopy (dx_pre)
120
+ res = Enzyme. autodiff (ReverseWithPrimal, metasumsq, Active, Const (metaconcat), Duplicated (x, dx))
121
+ @test res[2 ] ≈ primal
122
+ @test tupapprox (dx, dx_post)
123
+
124
+ dx = deepcopy (dx_pre)
125
+ Enzyme. autodiff (Reverse, metasumsq3, Active, Const (metaconcat2), Duplicated (x, dx), Const (y))
126
+ @test tupapprox (dx, dx_post)
127
+
128
+ dx = deepcopy (dx_pre)
129
+ dy = deepcopy (dy_const)
130
+ Enzyme. autodiff (Reverse, metasumsq3, Active, Const (metaconcat2), Duplicated (x, dx), Duplicated (y, dy))
131
+ @test tupapprox (dx, dx_post)
132
+ @test tupapprox (dy, dy_const)
133
+ end
154
134
155
135
x = [[2.0 , 3.0 ], [7.9 , 11.2 ]]
156
- dx = [[0.0 , 0.0 ], [0.0 , 0.0 ]]
157
136
y = [[13 , 17 ], [25 , 31 ]]
158
- dy = [[0 , 0 ], [0 , 0 ]]
159
- res = Enzyme. autodiff (Reverse, metasumsq4, Active, Const (metaconcat2), Duplicated (x, dx), Duplicated (y, dy))
160
- @test tupapprox (dx, [[4.0 , 6.0 ], [15.8 , 22.4 ]])
137
+ dy_const = [[0 , 0 ], [0 , 0 ]]
138
+ primal = 200.84999999999997
139
+ @testset " list $label " for (label, dx_pre, dx_post) in [
140
+ (" dx == 0" , [[0.0 , 0.0 ], [0.0 , 0.0 ]], [[4.0 , 6.0 ], [15.8 , 22.4 ]]),
141
+ (" dx != 0" , [[1.0 , - 2.0 ], [- 3.0 , 4.0 ]], [[5.0 , 4.0 ], [12.8 , 26.4 ]]),
142
+ ]
143
+ dx = deepcopy (dx_pre)
144
+ Enzyme. autodiff (Reverse, metasumsq2, Active, Const (metaconcat), Duplicated (x, dx))
145
+ @test dx ≈ dx_post
146
+
147
+ dx = deepcopy (dx_pre)
148
+ res = Enzyme. autodiff (ReverseWithPrimal, metasumsq2, Active, Const (metaconcat), Duplicated (x, dx))
149
+ @test res[2 ] ≈ primal
150
+ @test dx ≈ dx_post
151
+
152
+ dx = deepcopy (dx_pre)
153
+ Enzyme. autodiff (Reverse, metasumsq4, Active, Const (metaconcat2), Duplicated (x, dx), Const (y))
154
+ @test dx ≈ dx_post
155
+
156
+ dx = deepcopy (dx_pre)
157
+ dy = deepcopy (dy_const)
158
+ Enzyme. autodiff (Reverse, metasumsq4, Active, Const (metaconcat2), Duplicated (x, dx), Duplicated (y, dy))
159
+ @test dx ≈ dx_post
160
+ @test dy ≈ dy_const
161
+ end
161
162
end
162
163
163
164
@testset " BatchReverse Apply iterate" begin
164
165
x = [(2.0 , 3.0 ), (7.9 , 11.2 )]
165
- dx = [(0.0 , 0.0 ), (0.0 , 0.0 )]
166
- dx2 = [(0.0 , 0.0 ), (0.0 , 0.0 )]
167
- out = Ref (0.0 )
168
- dout = Ref (1.0 )
169
- dout2 = Ref (3.0 )
170
- Enzyme. autodiff (Reverse, make_byref, Const, BatchDuplicatedNoNeed (out, (dout, dout2)), Const (metasumsq), Const (metaconcat), BatchDuplicated (x, (dx, dx2)))
171
- @test tupapprox (dx, [(4.0 , 6.0 ), (15.8 , 22.4 )])
172
- @test tupapprox (dx2, [(3 * 4.0 , 3 * 6.0 ), (3 * 15.8 , 3 * 22.4 )])
173
-
174
- dx = [(0.0 , 0.0 ), (0.0 , 0.0 )]
175
- dx2 = [(0.0 , 0.0 ), (0.0 , 0.0 )]
176
- out = Ref (0.0 )
177
- dout = Ref (1.0 )
178
- dout2 = Ref (3.0 )
179
- Enzyme. autodiff (Reverse, make_byref, Const, BatchDuplicated (out, (dout, dout2)), Const (metasumsq), Const (metaconcat), BatchDuplicated (x, (dx, dx2)))
180
- @test out[] ≈ 200.84999999999997
181
- @test tupapprox (dx, [(4.0 , 6.0 ), (15.8 , 22.4 )])
182
- @test tupapprox (dx2, [(3 * 4.0 , 3 * 6.0 ), (3 * 15.8 , 3 * 22.4 )])
183
-
184
- x = [[2.0 , 3.0 ], [7.9 , 11.2 ]]
185
- dx = [[0.0 , 0.0 ], [0.0 , 0.0 ]]
186
- dx2 = [[0.0 , 0.0 ], [0.0 , 0.0 ]]
187
- out = Ref (0.0 )
188
- dout = Ref (1.0 )
189
- dout2 = Ref (3.0 )
190
-
191
- Enzyme. autodiff (Reverse, make_byref, Const, BatchDuplicatedNoNeed (out, (dout, dout2)), Const (metasumsq2), Const (metaconcat), BatchDuplicated (x, (dx, dx2)))
192
- @test dx ≈ [[4.0 , 6.0 ], [15.8 , 22.4 ]]
193
- @test dx2 ≈ [[3 * 4.0 , 3 * 6.0 ], [3 * 15.8 , 3 * 22.4 ]]
194
-
195
- dx = [[0.0 , 0.0 ], [0.0 , 0.0 ]]
196
- dx2 = [[0.0 , 0.0 ], [0.0 , 0.0 ]]
197
- out = Ref (0.0 )
198
- dout = Ref (1.0 )
199
- dout2 = Ref (3.0 )
200
- Enzyme. autodiff (Reverse, make_byref, Const, BatchDuplicated (out, (dout, dout2)), Const (metasumsq2), Const (metaconcat), BatchDuplicated (x, (dx, dx2)))
201
-
202
- @test out[] ≈ 200.84999999999997
203
- @test tupapprox (dx, [[4.0 , 6.0 ], [15.8 , 22.4 ]])
204
- @test tupapprox (dx2, [[3 * 4.0 , 3 * 6.0 ], [3 * 15.8 , 3 * 22.4 ]])
205
-
206
-
207
- x = [(2.0 , 3.0 ), (7.9 , 11.2 )]
208
- dx = [(0.0 , 0.0 ), (0.0 , 0.0 )]
209
- dx2 = [(0.0 , 0.0 ), (0.0 , 0.0 )]
210
-
211
166
y = [(13 , 17 ), (25 , 31 )]
212
- out = Ref (0.0 )
213
- dout = Ref (1.0 )
214
- dout2 = Ref (3.0 )
215
- Enzyme. autodiff (Reverse, make_byref, Const, BatchDuplicatedNoNeed (out, (dout, dout2)), Const (metasumsq3), Const (metaconcat2), BatchDuplicated (x, (dx, dx2)), Const (y))
216
- @test tupapprox (dx, [(4.0 , 6.0 ), (15.8 , 22.4 )])
217
- @test tupapprox (dx2, [(3 * 4.0 , 3 * 6.0 ), (3 * 15.8 , 3 * 22.4 )])
218
-
219
-
220
- x = [(2.0 , 3.0 ), (7.9 , 11.2 )]
221
- dx = [(0.0 , 0.0 ), (0.0 , 0.0 )]
222
- dx2 = [(0.0 , 0.0 ), (0.0 , 0.0 )]
223
- y = [(13 , 17 ), (25 , 31 )]
224
- dy = [(0 , 0 ), (0 , 0 )]
225
- dy2 = [(0 , 0 ), (0 , 0 )]
226
- out = Ref (0.0 )
227
- dout = Ref (1.0 )
228
- dout2 = Ref (3.0 )
229
- Enzyme. autodiff (Reverse, make_byref, Const, BatchDuplicatedNoNeed (out, (dout, dout2)), Const (metasumsq3),Const (metaconcat2), BatchDuplicated (x, (dx, dx2)), BatchDuplicated (y, (dy, dy2)))
230
- @test tupapprox (dx, [(4.0 , 6.0 ), (15.8 , 22.4 )])
231
- @test tupapprox (dx2, [(3 * 4.0 , 3 * 6.0 ), (3 * 15.8 , 3 * 22.4 )])
232
-
233
-
234
- x = [[2.0 , 3.0 ], [7.9 , 11.2 ]]
235
- dx = [[0.0 , 0.0 ], [0.0 , 0.0 ]]
236
- dx2 = [[0.0 , 0.0 ], [0.0 , 0.0 ]]
237
- y = [[13 , 17 ], [25 , 31 ]]
238
- out = Ref (0.0 )
239
- dout = Ref (1.0 )
240
- dout2 = Ref (3.0 )
241
- Enzyme. autodiff (Reverse, make_byref, Const, BatchDuplicated (out, (dout, dout2)), Const (metasumsq4), Const (metaconcat2), BatchDuplicated (x, (dx, dx2)), Const (y))
242
- @test tupapprox (dx, [[4.0 , 6.0 ], [15.8 , 22.4 ]])
243
- @test tupapprox (dx2, [[3 * 4.0 , 3 * 6.0 ], [3 * 15.8 , 3 * 22.4 ]])
167
+ dy_const = [(0 , 0 ), (0 , 0 )]
168
+ primal = 200.84999999999997
169
+ out_pre, dout_pre, dout2_pre = 0.0 , 1.0 , 3.0
170
+ @testset " tuple $label " for (label, dx_pre, dx_post, dx2_post) in [
171
+ (
172
+ " dx == 0" ,
173
+ [(0.0 , 0.0 ), (0.0 , 0.0 )],
174
+ [(4.0 , 6.0 ), (15.8 , 22.4 )],
175
+ [(3 * 4.0 , 3 * 6.0 ), (3 * 15.8 , 3 * 22.4 )],
176
+ ),
177
+ (
178
+ " dx != 0" ,
179
+ [(1.0 , - 2.0 ), (- 3.0 , 4.0 )],
180
+ [(5.0 , 4.0 ), (12.8 , 26.4 )],
181
+ [(1.0 + 3 * 4.0 , - 2.0 + 3 * 6.0 ), (- 3.0 + 3 * 15.8 , 4.0 + 3 * 22.4 )],
182
+ ),
183
+ ]
184
+ dx, dx2 = deepcopy .((dx_pre, dx_pre))
185
+ out, dout, dout2 = Ref .((out_pre, dout_pre, dout2_pre))
186
+ Enzyme. autodiff (Reverse, make_byref, Const, BatchDuplicatedNoNeed (out, (dout, dout2)), Const (metasumsq), Const (metaconcat), BatchDuplicated (x, (dx, dx2)))
187
+ @test dout[] ≈ 0
188
+ @test dout2[] ≈ 0
189
+ @test tupapprox (dx, dx_post)
190
+ @test tupapprox (dx2, dx2_post)
191
+
192
+ dx, dx2 = deepcopy .((dx_pre, dx_pre))
193
+ out, dout, dout2 = Ref .((out_pre, dout_pre, dout2_pre))
194
+ Enzyme. autodiff (Reverse, make_byref, Const, BatchDuplicated (out, (dout, dout2)), Const (metasumsq), Const (metaconcat), BatchDuplicated (x, (dx, dx2)))
195
+ @test out[] ≈ primal
196
+ @test dout[] ≈ 0
197
+ @test dout2[] ≈ 0
198
+ @test tupapprox (dx, dx_post)
199
+ @test tupapprox (dx2, dx2_post)
200
+
201
+ dx, dx2 = deepcopy .((dx_pre, dx_pre))
202
+ out, dout, dout2 = Ref .((out_pre, dout_pre, dout2_pre))
203
+ Enzyme. autodiff (Reverse, make_byref, Const, BatchDuplicatedNoNeed (out, (dout, dout2)), Const (metasumsq3), Const (metaconcat2), BatchDuplicated (x, (dx, dx2)), Const (y))
204
+ @test dout[] ≈ 0
205
+ @test dout2[] ≈ 0
206
+ @test tupapprox (dx, dx_post)
207
+ @test tupapprox (dx2, dx2_post)
208
+
209
+ dx, dx2 = deepcopy .((dx_pre, dx_pre))
210
+ dy, dy2 = deepcopy .((dy_const, dy_const))
211
+ out, dout, dout2 = Ref .((out_pre, dout_pre, dout2_pre))
212
+ Enzyme. autodiff (Reverse, make_byref, Const, BatchDuplicatedNoNeed (out, (dout, dout2)), Const (metasumsq3), Const (metaconcat2), BatchDuplicated (x, (dx, dx2)), BatchDuplicated (y, (dy, dy2)))
213
+ @test dout[] ≈ 0
214
+ @test dout2[] ≈ 0
215
+ @test tupapprox (dx, dx_post)
216
+ @test tupapprox (dx2, dx2_post)
217
+ @test tupapprox (dy, dy_const)
218
+ @test tupapprox (dy2, dy_const)
219
+ end
244
220
245
221
x = [[2.0 , 3.0 ], [7.9 , 11.2 ]]
246
- dx = [[0.0 , 0.0 ], [0.0 , 0.0 ]]
247
- dx2 = [[0.0 , 0.0 ], [0.0 , 0.0 ]]
248
222
y = [[13 , 17 ], [25 , 31 ]]
249
- dy = [[0 , 0 ], [0 , 0 ]]
250
- dy2 = [[0 , 0 ], [0 , 0 ]]
251
- out = Ref (0.0 )
252
- dout = Ref (1.0 )
253
- dout2 = Ref (3.0 )
254
- Enzyme. autodiff (Reverse, make_byref, Const, BatchDuplicated (out, (dout, dout2)), Const (metasumsq4), Const (metaconcat2), BatchDuplicated (x, (dx, dx2)), BatchDuplicated (y, (dy, dy2)))
255
- @test tupapprox (dx, [[4.0 , 6.0 ], [15.8 , 22.4 ]])
256
- @test tupapprox (dx2, [[3 * 4.0 , 3 * 6.0 ], [3 * 15.8 , 3 * 22.4 ]])
223
+ dy_const = [[0 , 0 ], [0 , 0 ]]
224
+ primal = 200.84999999999997
225
+ out_pre, dout_pre, dout2_pre = 0.0 , 1.0 , 3.0
226
+ @testset " tuple $label " for (label, dx_pre, dx_post, dx2_post) in [
227
+ (
228
+ " dx == 0" ,
229
+ [[0.0 , 0.0 ], [0.0 , 0.0 ]],
230
+ [[4.0 , 6.0 ], [15.8 , 22.4 ]],
231
+ [[3 * 4.0 , 3 * 6.0 ], [3 * 15.8 , 3 * 22.4 ]],
232
+ ),
233
+ (
234
+ " dx != 0" ,
235
+ [[1.0 , - 2.0 ], [- 3.0 , 4.0 ]],
236
+ [[5.0 , 4.0 ], [12.8 , 26.4 ]],
237
+ [[1.0 + 3 * 4.0 , - 2.0 + 3 * 6.0 ], [- 3.0 + 3 * 15.8 , 4.0 + 3 * 22.4 ]],
238
+ ),
239
+ ]
240
+ dx, dx2 = deepcopy .((dx_pre, dx_pre))
241
+ out, dout, dout2 = Ref .((out_pre, dout_pre, dout2_pre))
242
+ Enzyme. autodiff (Reverse, make_byref, Const, BatchDuplicatedNoNeed (out, (dout, dout2)), Const (metasumsq2), Const (metaconcat), BatchDuplicated (x, (dx, dx2)))
243
+ @test dout[] ≈ 0
244
+ @test dout2[] ≈ 0
245
+ @test dx ≈ dx_post
246
+ @test dx2 ≈ dx2_post
247
+
248
+ dx, dx2 = deepcopy .((dx_pre, dx_pre))
249
+ out, dout, dout2 = Ref .((out_pre, dout_pre, dout2_pre))
250
+ Enzyme. autodiff (Reverse, make_byref, Const, BatchDuplicated (out, (dout, dout2)), Const (metasumsq2), Const (metaconcat), BatchDuplicated (x, (dx, dx2)))
251
+ @test out[] ≈ primal
252
+ @test dout[] ≈ 0
253
+ @test dout2[] ≈ 0
254
+ @test dx ≈ dx_post
255
+ @test dx2 ≈ dx2_post
256
+
257
+ dx, dx2 = deepcopy .((dx_pre, dx_pre))
258
+ out, dout, dout2 = Ref .((out_pre, dout_pre, dout2_pre))
259
+ Enzyme. autodiff (Reverse, make_byref, Const, BatchDuplicatedNoNeed (out, (dout, dout2)), Const (metasumsq4), Const (metaconcat2), BatchDuplicated (x, (dx, dx2)), Const (y))
260
+ @test dout[] ≈ 0
261
+ @test dout2[] ≈ 0
262
+ @test dx ≈ dx_post
263
+ @test dx2 ≈ dx2_post
264
+
265
+ dx, dx2 = deepcopy .((dx_pre, dx_pre))
266
+ dy, dy2 = deepcopy .((dy_const, dy_const))
267
+ out, dout, dout2 = Ref .((out_pre, dout_pre, dout2_pre))
268
+ Enzyme. autodiff (Reverse, make_byref, Const, BatchDuplicatedNoNeed (out, (dout, dout2)), Const (metasumsq4), Const (metaconcat2), BatchDuplicated (x, (dx, dx2)), BatchDuplicated (y, (dy, dy2)))
269
+ @test dout[] ≈ 0
270
+ @test dout2[] ≈ 0
271
+ @test dx ≈ dx_post
272
+ @test dx2 ≈ dx2_post
273
+ @test dy ≈ dy_const
274
+ @test dy2 ≈ dy_const
275
+ end
257
276
end
258
277
259
278
@testset " Forward Apply iterate" begin
502
521
@test ddata[1 ][1 ] ≈ 6.0
503
522
end
504
523
505
- include (" mixedapplyiter.jl" )
524
+ include (" mixedapplyiter.jl" )
0 commit comments