@@ -62,7 +62,12 @@ function init(rng::Random.AbstractRNG, ::VarName, dist::Distribution, u::Uniform
62
62
sz = Bijectors. output_size (b, size (dist))
63
63
y = rand (rng, Uniform (u. lower, u. upper), sz)
64
64
b_inv = Bijectors. inverse (b)
65
- return b_inv (y)
65
+ x = b_inv (y)
66
+ # https://github.com/TuringLang/Bijectors.jl/issues/398
67
+ if x isa Array{<: Any ,0 }
68
+ x = x[]
69
+ end
70
+ return x
66
71
end
67
72
68
73
"""
@@ -134,12 +139,14 @@ function tilde_assume(
134
139
# `init()` always returns values in original space, i.e. possibly
135
140
# constrained
136
141
x = init (ctx. rng, vn, dist, ctx. strategy)
137
- # There is a function `to_maybe_linked_internal_transform` that does this,
138
- # but unfortunately it uses `istrans(vi, vn)` which fails if vn is not in
139
- # vi, so we have to manually check. By default we will insert an unlinked
140
- # value into the varinfo.
141
- is_transformed = in_varinfo ? istrans (vi, vn) : false
142
- f = if is_transformed
142
+ # Determine whether to insert a transformed value into the VarInfo.
143
+ # If the VarInfo alrady had a value for this variable, we will
144
+ # keep the same linked status as in the original VarInfo. If not, we
145
+ # check the rest of the VarInfo to see if other variables are linked.
146
+ # istrans(vi) returns true if vi is nonempty and all variables in vi
147
+ # are linked.
148
+ insert_transformed_value = in_varinfo ? istrans (vi, vn) : istrans (vi)
149
+ f = if insert_transformed_value
143
150
to_linked_internal_transform (vi, vn, dist)
144
151
else
145
152
to_internal_transform (vi, vn, dist)
@@ -150,7 +157,7 @@ function tilde_assume(
150
157
# always converts x to a vector, i.e., if dist is univariate, f(x) will be
151
158
# a vector of length 1. It would be nice if we could unify these.
152
159
y = f (x)
153
- logjac = logabsdetjac (is_transformed ? Bijectors . bijector (dist) : identity, x)
160
+ logjac = logabsdetjac (insert_transformed_value ? link_transform (dist) : identity, x)
154
161
# Add the new value to the VarInfo. `push!!` errors if the value already
155
162
# exists, hence the need for setindex!!
156
163
if in_varinfo
0 commit comments