Skip to content

Commit f4621e7

Browse files
committed
WIP: premapping based cyclic type zeros
1 parent fe63c33 commit f4621e7

File tree

2 files changed

+75
-1
lines changed

2 files changed

+75
-1
lines changed

src/tangent_types/abstract_zero.jl

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,19 @@ end
138138
)
139139
Expr(:kw, fname, fval)
140140
end
141-
return if has_mutable_tangent(primal)
141+
142+
# easy case exit early, can't hold references, can't be a reference.
143+
if isbitstype(primal)
144+
return :($Tangent{$primal}($(Expr(:parameters, zfield_exprs...))))
145+
end
146+
147+
# hard case need to be prepared for cycic references to this, or that are contained within this
148+
quote
149+
counts = $count_references!(primal)
150+
end
151+
152+
## TODO rewrite below
153+
has_mutable_tangent(primal)
142154
any_mask = map(fieldnames(primal), fieldtypes(primal)) do fname, ftype
143155
# If it is is unassigned, or if it doesn't have a concrete type, let it take any value for its tangent
144156
fdef = :(!isdefined(primal, $(QuoteNode(fname))) || !isconcretetype($ftype))
@@ -171,6 +183,36 @@ function zero_tangent(x::Array{P,N}) where {P,N}
171183
return y
172184
end
173185

186+
###############################################
187+
count_references!(x) = count_references(IdDict{Any, Int}(), x)
188+
function count_references!(counts::IdDict{Any, Int}, x)
189+
isbits(x) && return counts # can't be a refernece and can't hold a reference
190+
counts[x] = get(counts, x, 0) + 1 # Increment *before* recursing
191+
if counts[x] == 1 # Only recurse the first time
192+
for ii in fieldcount(typeof(x))
193+
field = getfield(x, ii)
194+
count_references!(counts, field)
195+
end
196+
end
197+
return counts
198+
end
199+
200+
function count_references!(counts::IdDict{Any, Int}, x::Array)
201+
counts[x] = get(counts, x, 0) + 1 # increment before recursing
202+
isbitstype(eltype(x)) && return counts # no need to look inside, it can't hold references
203+
if counts[x] == 1 # only recurse the first time
204+
for ele in x
205+
count_references!(counts, ele)
206+
end
207+
end
208+
return counts
209+
end
210+
211+
count_references!(counts::IdDict{Any, Int}, ::DataType) = counts
212+
213+
###############################################
214+
215+
174216
# Sad heauristic methods we need because of unassigned values
175217
guess_zero_tangent_type(::Type{T}) where {T<:Number} = T
176218
guess_zero_tangent_type(::Type{T}) where {T<:Integer} = typeof(float(zero(T)))

test/tangent_types/abstract_zero.jl

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -275,4 +275,36 @@ end
275275
@test d.z == [2.0, 3.0]
276276
@test d.z isa SubArray
277277
end
278+
279+
280+
@testset "cyclic references" begin
281+
mutable struct Link
282+
data::Float64
283+
next::Link
284+
Link(data) = new(data)
285+
end
286+
287+
lk = Link(1.5)
288+
lk.next = lk
289+
290+
d = zero_tangent(lk)
291+
@test d.data == 0.0
292+
@test d.next === d
293+
294+
struct CarryingArray
295+
x::Vector
296+
end
297+
ca = CarryingArray(Any[1.5])
298+
push!(ca.x, ca)
299+
@test d_ca = zero_tangent(ca)
300+
@test d_ca[1] == 0.0
301+
@test d_ca[2] === _ca
302+
303+
# Idea: check if typeof(xs) <: eltype(xs), if so need to cache it before computing
304+
xs = Any[1.5]
305+
push!(xs, xs)
306+
@test d_xs = zero_tangent(xs)
307+
@test d_xs[1] == 0.0
308+
@test d_xs[2] == d_xs
309+
end
278310
end

0 commit comments

Comments
 (0)