Skip to content

Commit 2dd11c8

Browse files
committed
updates
1 parent c0171f0 commit 2dd11c8

File tree

1 file changed

+12
-6
lines changed

1 file changed

+12
-6
lines changed

src/projection.jl

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -264,10 +264,8 @@ end
264264

265265
# Tuple
266266
function ProjectTo(xs::Tuple)
267-
elements = map(xs) do x
268-
x isa Union{Number, AbstractArray{<:Number}} ? ProjectTo(x) : identity
269-
end
270-
if all(p -> p isa ProjectTo{<:AbstractZero}, elements)
267+
elements = map(ProjectTo, xs)
268+
if elements isa Tuple{Vararg{ProjectTo{<:AbstractZero}}}
271269
ProjectTo{NoTangent}() # short-circuit if all elements project to zero
272270
else
273271
return ProjectTo{Tuple}(; type=Val(typeof(xs)), elements=elements)
@@ -284,8 +282,16 @@ function (project::ProjectTo{Tuple})(dx::Tuple)
284282
dz = map((f, y) -> f(y), project.elements, dx)
285283
return Tangent{_val(project.type)}(dz...)
286284
end
287-
(project::ProjectTo{Tuple})(dx) = project(NTuple{length(project.elements)}(dx))
288-
(::ProjectTo{Tuple})(dx::AbstractZero) = dx # else ambiguous
285+
function (project::ProjectTo{Tuple})(dx::AbstractArray)
286+
for d in 1:ndims(dx)
287+
if size(dx, d) != get(size(project.elements), d, 1)
288+
throw(_projection_mismatch(axes(project.elements), size(dx)))
289+
end
290+
end
291+
dz = ntuple(i -> project.elements[i](dx[i]), length(project.elements))
292+
return Tangent{_val(project.type)}(dz...)
293+
end
294+
# (::ProjectTo{Tuple})(dx::AbstractZero) = dx # else ambiguous
289295

290296
# Ref
291297
function ProjectTo(x::Ref)

0 commit comments

Comments
 (0)