Skip to content

Commit 8588802

Browse files
committed
Fix more tests
1 parent 9d754ff commit 8588802

File tree

3 files changed

+18
-15
lines changed

3 files changed

+18
-15
lines changed

ext/DynamicPPLMCMCChainsExt.jl

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -124,9 +124,7 @@ function DynamicPPL.predict(
124124
iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
125125
predictive_samples = map(iters) do (sample_idx, chain_idx)
126126
# Extract values from the chain
127-
values_dict = DynamicPPL.chain_sample_to_varname_dict(
128-
parameter_only_chain, sample_idx, chain_idx
129-
)
127+
values_dict = chain_sample_to_varname_dict(parameter_only_chain, sample_idx, chain_idx)
130128
# Resample any variables that are not present in `values_dict`
131129
_, varinfo = last(
132130
DynamicPPL.init!!(
@@ -268,9 +266,7 @@ function DynamicPPL.returned(model::DynamicPPL.Model, chain_full::MCMCChains.Cha
268266
iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
269267
return map(iters) do (sample_idx, chain_idx)
270268
# Extract values from the chain
271-
values_dict = DynamicPPL.chain_sample_to_varname_dict(
272-
parameter_only_chain, sample_idx, chain_idx
273-
)
269+
values_dict = chain_sample_to_varname_dict(chain, sample_idx, chain_idx)
274270
# Resample any variables that are not present in `values_dict`, and
275271
# return the model's retval (`first`).
276272
first(

src/contexts/init.jl

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,12 @@ function init(rng::Random.AbstractRNG, ::VarName, dist::Distribution, u::Uniform
6262
sz = Bijectors.output_size(b, size(dist))
6363
y = rand(rng, Uniform(u.lower, u.upper), sz)
6464
b_inv = Bijectors.inverse(b)
65-
return b_inv(y)
65+
x = b_inv(y)
66+
# https://github.com/TuringLang/Bijectors.jl/issues/398
67+
if x isa Array{<:Any,0}
68+
x = x[]
69+
end
70+
return x
6671
end
6772

6873
"""
@@ -134,12 +139,14 @@ function tilde_assume(
134139
# `init()` always returns values in original space, i.e. possibly
135140
# constrained
136141
x = init(ctx.rng, vn, dist, ctx.strategy)
137-
# There is a function `to_maybe_linked_internal_transform` that does this,
138-
# but unfortunately it uses `istrans(vi, vn)` which fails if vn is not in
139-
# vi, so we have to manually check. By default we will insert an unlinked
140-
# value into the varinfo.
141-
is_transformed = in_varinfo ? istrans(vi, vn) : false
142-
f = if is_transformed
142+
# Determine whether to insert a transformed value into the VarInfo.
143+
# If the VarInfo alrady had a value for this variable, we will
144+
# keep the same linked status as in the original VarInfo. If not, we
145+
# check the rest of the VarInfo to see if other variables are linked.
146+
# istrans(vi) returns true if vi is nonempty and all variables in vi
147+
# are linked.
148+
insert_transformed_value = in_varinfo ? istrans(vi, vn) : istrans(vi)
149+
f = if insert_transformed_value
143150
to_linked_internal_transform(vi, vn, dist)
144151
else
145152
to_internal_transform(vi, vn, dist)
@@ -150,7 +157,7 @@ function tilde_assume(
150157
# always converts x to a vector, i.e., if dist is univariate, f(x) will be
151158
# a vector of length 1. It would be nice if we could unify these.
152159
y = f(x)
153-
logjac = logabsdetjac(is_transformed ? Bijectors.bijector(dist) : identity, x)
160+
logjac = logabsdetjac(insert_transformed_value ? link_transform(dist) : identity, x)
154161
# Add the new value to the VarInfo. `push!!` errors if the value already
155162
# exists, hence the need for setindex!!
156163
if in_varinfo

test/varinfo.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ end
4343
end
4444
model = gdemo(1.0, 2.0)
4545

46-
vi = DynamicPPL.untyped_varinfo(model, SampleFromUniform())
46+
_, vi = DynamicPPL.init!!(model, VarInfo(), UniformInit())
4747
tvi = DynamicPPL.typed_varinfo(vi)
4848

4949
meta = vi.metadata

0 commit comments

Comments
 (0)