Skip to content

Commit 2391a06

Browse files
committed
Fix tests
1 parent fd58050 commit 2391a06

File tree

2 files changed

+26
-19
lines changed

2 files changed

+26
-19
lines changed

ext/SliceSamplingTuringExt.jl

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -22,24 +22,17 @@ Turing.Inference.isgibbscomponent(::SliceSampling.Slice) = true
2222
Turing.Inference.isgibbscomponent(::SliceSampling.SliceSteppingOut) = true
2323
Turing.Inference.isgibbscomponent(::SliceSampling.SliceDoublingOut) = true
2424

25-
function Turing.Inference.getparams(
26-
::Turing.DynamicPPL.Model, sample::SliceSampling.UnivariateSliceState
27-
)
25+
const SliceSamplingStates = Union{
26+
SliceSampling.UnivariateSliceState,
27+
SliceSampling.GibbsState,
28+
SliceSampling.HitAndRunState,
29+
SliceSampling.LatentSliceState,
30+
SliceSampling.GibbsPolarSliceState,
31+
}
32+
function Turing.Inference.getparams(::Turing.DynamicPPL.Model, sample::SliceSamplingStates)
2833
return sample.transition.params
2934
end
3035

31-
function Turing.Inference.getparams(
32-
::Turing.DynamicPPL.Model, state::SliceSampling.GibbsState
33-
)
34-
return state.transition.params
35-
end
36-
37-
function Turing.Inference.getparams(
38-
::Turing.DynamicPPL.Model, state::SliceSampling.HitAndRunState
39-
)
40-
return state.transition.params
41-
end
42-
4336
function Turing.Inference.getlogp_external(
4437
::Turing.DynamicPPL.Model, t::SliceSampling.Transition, state
4538
)

test/turing.jl

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,10 @@
88
return nothing
99
end
1010

11-
@model logp_check() = x ~ Normal()
11+
@model function logp_check()
12+
a ~ Normal()
13+
return b ~ Normal()
14+
end
1215

1316
n_samples = 1000
1417
model = demo()
@@ -36,7 +39,11 @@
3639
chain_logp_check = sample(
3740
logp_check(), externalsampler(sampler), 100; progress=false
3841
)
39-
@test isapprox(logpdf.(Normal(), chain_logp_check[:x]), chain_logp_check[:logp])
42+
@test isapprox(
43+
logpdf.(Normal(), chain_logp_check[:a]) .+
44+
logpdf.(Normal(), chain_logp_check[:b]),
45+
chain_logp_check[:lp],
46+
)
4047
end
4148

4249
@testset "gibbs($sampler)" for sampler in [
@@ -55,8 +62,15 @@
5562
)
5663

5764
chain_logp_check = sample(
58-
logp_check(), Turing.Gibbs(:x => externalsampler(sampler)), 100; progress=false
65+
logp_check(),
66+
Turing.Gibbs(:a => externalsampler(sampler), :b => externalsampler(sampler)),
67+
100;
68+
progress=false,
69+
)
70+
@test isapprox(
71+
logpdf.(Normal(), chain_logp_check[:a]) .+
72+
logpdf.(Normal(), chain_logp_check[:b]),
73+
chain_logp_check[:lp],
5974
)
60-
@test isapprox(logpdf.(Normal(), chain_logp_check[:x]), chain_logp_check[:logp])
6175
end
6276
end

0 commit comments

Comments
 (0)