264
264
265
265
# Tuple
266
266
function ProjectTo (xs:: Tuple )
267
- elements = map (xs) do x
268
- x isa Union{Number, AbstractArray{<: Number }} ? ProjectTo (x) : identity
269
- end
270
- if all (p -> p isa ProjectTo{<: AbstractZero }, elements)
267
+ elements = map (ProjectTo, xs)
268
+ if elements isa Tuple{Vararg{ProjectTo{<: AbstractZero }}}
271
269
ProjectTo {NoTangent} () # short-circuit if all elements project to zero
272
270
else
273
271
return ProjectTo {Tuple} (; type= Val (typeof (xs)), elements= elements)
@@ -284,8 +282,16 @@ function (project::ProjectTo{Tuple})(dx::Tuple)
284
282
dz = map ((f, y) -> f (y), project. elements, dx)
285
283
return Tangent {_val(project.type)} (dz... )
286
284
end
287
- (project:: ProjectTo{Tuple} )(dx) = project (NTuple {length(project.elements)} (dx))
288
- (:: ProjectTo{Tuple} )(dx:: AbstractZero ) = dx # else ambiguous
285
+ function (project:: ProjectTo{Tuple} )(dx:: AbstractArray )
286
+ for d in 1 : ndims (dx)
287
+ if size (dx, d) != get (size (project. elements), d, 1 )
288
+ throw (_projection_mismatch (axes (project. elements), size (dx)))
289
+ end
290
+ end
291
+ dz = ntuple (i -> project. elements[i](dx[i]), length (project. elements))
292
+ return Tangent {_val(project.type)} (dz... )
293
+ end
294
+ # (::ProjectTo{Tuple})(dx::AbstractZero) = dx # else ambiguous
289
295
290
296
# Ref
291
297
function ProjectTo (x:: Ref )
0 commit comments