@@ -21,35 +21,35 @@ Random.seed!(1)
21
21
function readlrdata ()
22
22
fname = joinpath (" lr.data" )
23
23
z = readdlm (fname)
24
- A = z[:,1 : end - 1 ]
25
- A = [ones (size (A,1 )) A]
26
- y = z[:,end ] .- 1
24
+ A = z[:, 1 : end - 1 ]
25
+ A = [ones (size (A, 1 )) A]
26
+ y = z[:, end ] .- 1
27
27
return A, y
28
28
end
29
29
A, y = readlrdata ();
30
30
At = collect (A' );
31
31
32
32
model_lr = @model (At, y, σ) begin
33
- d,n = size (At)
34
- θ ~ Normal (σ= σ)^ d
33
+ d, n = size (At)
34
+ θ ~ Normal (σ = σ)^ d
35
35
for j in 1 : n
36
- logitp = dot (view (At,:, j), θ)
36
+ logitp = dot (view (At, :, j), θ)
37
37
y[j] ~ Bernoulli (logitp = logitp)
38
38
end
39
39
end
40
40
σ = 100.0
41
41
42
- function make_grads (model_lr, At, y, σ)
43
- post = model_lr (At, y, σ) | (;y)
42
+ function make_grads (model_lr, At, y, σ)
43
+ post = model_lr (At, y, σ) | (; y)
44
44
as_post = as (post)
45
45
obj (θ) = - Tilde. unsafe_logdensityof (post, transform (as_post, θ))
46
46
ℓ (θ) = - obj (θ)
47
47
@inline function dneglogp (t, x, v) # two directional derivatives
48
- f (t) = obj (x + t* v)
48
+ f (t) = obj (x + t * v)
49
49
u = ForwardDiff. derivative (f, Dual {:hSrkahPmmC} (0.0 , 1.0 ))
50
50
u. value, u. partials[]
51
51
end
52
-
52
+
53
53
gconfig = ForwardDiff. GradientConfig (obj, rand (25 ), ForwardDiff. Chunk {25} ())
54
54
function ∇neglogp! (y, t, x)
55
55
ForwardDiff. gradient! (y, obj, x, gconfig)
@@ -58,60 +58,88 @@ function make_grads(model_lr, At, y, σ)
58
58
post, ℓ, dneglogp, ∇neglogp!
59
59
end
60
60
61
- post, ℓ, dneglogp, ∇neglogp! = make_grads (model_lr, At, y, σ)
61
+ post, ℓ, dneglogp, ∇neglogp! = make_grads (model_lr, At, y, σ)
62
62
# Try things out
63
63
dneglogp (2.4 , randn (25 ), randn (25 ));
64
64
∇neglogp! (randn (25 ), 2.1 , randn (25 ));
65
65
66
-
67
66
d = 25 # number of parameters
68
67
t0 = 0.0 ;
69
68
x0 = zeros (d); # starting point sampler
70
69
# estimated posterior mean (n=100000, 797s)
71
- μ̂ = [3.406 , - 0.5918 , 0.0352 , - 0.3874 , 0.004481 , - 0.2346 , - 0.1495 , - 0.2184 , 0.01219 , 0.1731 , - 0.00976 , - 0.3224 , 0.2168 , 0.08002 , - 0.2829 , - 1.581 , 0.6666 , - 0.9984 , 1.081 , 1.405 , 0.327 , - 0.1357 , - 0.6446 , - 0.06583 , - 0.04994 ]
70
+ μ̂ = [
71
+ 3.406 ,
72
+ - 0.5918 ,
73
+ 0.0352 ,
74
+ - 0.3874 ,
75
+ 0.004481 ,
76
+ - 0.2346 ,
77
+ - 0.1495 ,
78
+ - 0.2184 ,
79
+ 0.01219 ,
80
+ 0.1731 ,
81
+ - 0.00976 ,
82
+ - 0.3224 ,
83
+ 0.2168 ,
84
+ 0.08002 ,
85
+ - 0.2829 ,
86
+ - 1.581 ,
87
+ 0.6666 ,
88
+ - 0.9984 ,
89
+ 1.081 ,
90
+ 1.405 ,
91
+ 0.327 ,
92
+ - 0.1357 ,
93
+ - 0.6446 ,
94
+ - 0.06583 ,
95
+ - 0.04994 ,
96
+ ]
72
97
n = 2000
73
98
c = 4.0 # initial guess for the bound
74
99
75
- init_scale= 1 ;
76
- @time pf_result = pathfinder (ℓ; dim= d, init_scale);
100
+ init_scale = 1 ;
101
+ @time pf_result = pathfinder (ℓ; dim = d, init_scale);
77
102
M = PDMats. PDiagMat (diag (pf_result. fit_distribution. Σ));
78
103
M = pf_result. fit_distribution. Σ;
79
104
x0 = pf_result. fit_distribution. μ;
80
105
v0 = PDMats. unwhiten (M, randn (length (x0)));
81
106
82
-
83
-
84
-
85
-
86
107
MAP = pf_result. optim_solution; # MAP, could be useful for control variates
87
108
88
109
# define BouncyParticle sampler (has two relevant parameters)
89
- Z = BouncyParticle (missing , # graphical structure
110
+ Z = BouncyParticle (
111
+ missing , # graphical structure
90
112
MAP, # MAP estimate, unused
91
113
2.0 , # momentum refreshment rate and sample saving rate
92
114
0.95 , # momentum correlation / only gradually change momentum in refreshment/momentum update
93
115
M, # metric (PDMat compatible object for momentum covariance)
94
- missing # legacy
95
- ) ;
96
-
97
- sampler = ZZB. NotFactSampler (Z, (dneglogp, ∇neglogp!), ZZB. LocalBound (c), t0 => (x0, v0), ZZB. Rng (ZZB. Seed ()), (),
98
- (; adapt= true , # adapt bound c
99
- subsample= true , # keep only samples at refreshment times
100
- ));
101
-
116
+ missing , # legacy
117
+ );
118
+
119
+ sampler = ZZB. NotFactSampler (
120
+ Z,
121
+ (dneglogp, ∇neglogp!),
122
+ ZZB. LocalBound (c),
123
+ t0 => (x0, v0),
124
+ ZZB. Rng (ZZB. Seed ()),
125
+ (),
126
+ (;
127
+ adapt = true , # adapt bound c
128
+ subsample = true , # keep only samples at refreshment times
129
+ ),
130
+ );
102
131
103
132
using TupleVectors: chainvec
104
133
using Tilde. MeasureTheory: transform
105
134
106
-
107
- function collect_sampler (t, sampler, n; progress= true , progress_stops= 20 )
135
+ function collect_sampler (t, sampler, n; progress = true , progress_stops = 20 )
108
136
if progress
109
137
prg = Progress (progress_stops, 1 )
110
138
else
111
139
prg = missing
112
140
end
113
141
stops = ismissing (prg) ? 0 : max (prg. n - 1 , 0 ) # allow one stop for cleanup
114
- nstop = n/ stops
142
+ nstop = n / stops
115
143
116
144
x1 = transform (t, sampler. u0[2 ][1 ])
117
145
tv = chainvec (x1, n)
@@ -124,42 +152,54 @@ function collect_sampler(t, sampler, n; progress=true, progress_stops=20)
124
152
tv[j] = transform (t, val[2 ])
125
153
ϕ = iterate (sampler, state)
126
154
if j > nstop
127
- nstop += n/ stops
128
- next! (prg)
129
- end
155
+ nstop += n / stops
156
+ next! (prg)
157
+ end
130
158
end
131
159
ismissing (prg) || ProgressMeter. finish! (prg)
132
- tv, (;uT = state[1 ], acc= state[3 ][1 ], total= state[3 ][2 ], bound= state[4 ]. c)
160
+ tv, (; uT = state[1 ], acc = state[3 ][1 ], total = state[3 ][2 ], bound = state[4 ]. c)
133
161
end
134
- collect_sampler (as (post), sampler, 10 ; progress= false );
162
+ collect_sampler (as (post), sampler, 10 ; progress = false );
135
163
136
164
elapsed_time = @elapsed @time begin
137
- global bps_samples, info
138
- bps_samples, info = collect_sampler (as (post), sampler, n; progress= false )
165
+ global bps_samples, info
166
+ bps_samples, info = collect_sampler (as (post), sampler, n; progress = false )
139
167
end
140
168
141
169
using MCMCChains
142
170
bps_chain = MCMCChains. Chains (bps_samples. θ);
143
- bps_chain = setinfo (bps_chain, (;start_time= 0.0 , stop_time = elapsed_time));
171
+ bps_chain = setinfo (bps_chain, (; start_time = 0.0 , stop_time = elapsed_time));
144
172
145
- μ̂1 = round .(mean (bps_chain). nt[:mean ], sigdigits= 4 )
173
+ μ̂1 = round .(mean (bps_chain). nt[:mean ], sigdigits = 4 )
146
174
println (" μ̂ (BPS) = " , μ̂1)
147
175
148
176
using SampleChainsDynamicHMC
149
177
init_params = pf_result. draws[:, 1 ];
150
178
inv_metric = (pf_result. fit_distribution. Σ);
151
- Tilde. sample (post, dynamichmc (
152
- ;init= (; q= init_params, κ= GaussianKineticEnergy (inv_metric)),
153
- warmup_stages= default_warmup_stages (; middle_steps= 0 , doubling_stages= 0 ),
154
- ), 1 ,1 );
155
- hmc_time = @elapsed @time (hmc_samples = Tilde. sample (post, dynamichmc (
156
- ;init= (; q= init_params, κ= GaussianKineticEnergy (inv_metric)),
157
- warmup_stages= default_warmup_stages (; middle_steps= 0 , doubling_stages= 0 ),
158
- ), 2000 ,1 ));
179
+ Tilde. sample (
180
+ post,
181
+ dynamichmc (;
182
+ init = (; q = init_params, κ = GaussianKineticEnergy (inv_metric)),
183
+ warmup_stages = default_warmup_stages (; middle_steps = 0 , doubling_stages = 0 ),
184
+ ),
185
+ 1 ,
186
+ 1 ,
187
+ );
188
+ hmc_time = @elapsed @time (
189
+ hmc_samples = Tilde. sample (
190
+ post,
191
+ dynamichmc (;
192
+ init = (; q = init_params, κ = GaussianKineticEnergy (inv_metric)),
193
+ warmup_stages = default_warmup_stages (; middle_steps = 0 , doubling_stages = 0 ),
194
+ ),
195
+ 2000 ,
196
+ 1 ,
197
+ )
198
+ );
159
199
hmc_chain = MCMCChains. Chains (hmc_samples. θ);
160
- μ̂2 = round .(mean (hmc_chain). nt[:mean ], sigdigits= 4 );
200
+ μ̂2 = round .(mean (hmc_chain). nt[:mean ], sigdigits = 4 );
161
201
println (" μ̂ (HMC) = " , μ̂2)
162
- hmc_chain = MCMCChains. setinfo (hmc_chain, (;start_time= 0.0 , stop_time = hmc_time));
202
+ hmc_chain = MCMCChains. setinfo (hmc_chain, (; start_time = 0.0 , stop_time = hmc_time));
163
203
164
204
ess_bps = MCMCChains. ess_rhat (bps_chain). nt. ess_per_sec;
165
205
ess_hmc = MCMCChains. ess_rhat (hmc_chain). nt. ess_per_sec;
@@ -173,4 +213,4 @@ ylabel!(plt, "DynamicHMC");
173
213
plt_bounds = collect (extrema (ess_hmc));
174
214
lineplot! (plt, plt_bounds, plt_bounds);
175
215
plt
176
- @info " For each coordinate, a point (x,y) shows the effective sample size per second for BPS (x) and HMC (y) . In blue is the diagonal x=y"
216
+ @info " For each coordinate, a point (x,y) shows the effective sample size per second for BPS (x) and HMC (y) . In blue is the diagonal x=y"
0 commit comments