Skip to content

Commit 49c3ddc

Browse files
authored
Formatting (#33)
* foramatting * TOML
1 parent f08b89a commit 49c3ddc

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

45 files changed

+761
-604
lines changed

.JuliaFormatter.toml

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
indent = 4
2+
margin = 92
3+
always_for_in = true
4+
whitespace_typedefs = false
5+
whitespace_ops_in_indices = false
6+
remove_extra_newlines = true
7+
import_to_using = false
8+
pipe_to_function_call = false
9+
short_to_long_function_def = true
10+
always_use_return = false
11+
whitespace_in_kwargs = true
12+
annotate_untyped_fields_with_any = false
13+
format_docstrings = false
14+
align_struct_field = true
15+
align_conditional = true
16+
align_assignment = true
17+
align_pair_arrow = true
18+
conditional_to_if = true
19+
normalize_line_endings = "unix"
20+
align_matrix = false

benchmarks/bouncy.jl

Lines changed: 91 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -21,35 +21,35 @@ Random.seed!(1)
2121
function readlrdata()
2222
fname = joinpath("lr.data")
2323
z = readdlm(fname)
24-
A = z[:,1:end-1]
25-
A = [ones(size(A,1)) A]
26-
y = z[:,end] .- 1
24+
A = z[:, 1:end-1]
25+
A = [ones(size(A, 1)) A]
26+
y = z[:, end] .- 1
2727
return A, y
2828
end
2929
A, y = readlrdata();
3030
At = collect(A');
3131

3232
model_lr = @model (At, y, σ) begin
33-
d,n = size(At)
34-
θ ~ Normal=σ)^d
33+
d, n = size(At)
34+
θ ~ Normal = σ)^d
3535
for j in 1:n
36-
logitp = dot(view(At,:,j), θ)
36+
logitp = dot(view(At, :, j), θ)
3737
y[j] ~ Bernoulli(logitp = logitp)
3838
end
3939
end
4040
σ = 100.0
4141

42-
function make_grads(model_lr, At, y, σ)
43-
post = model_lr(At, y, σ) | (;y)
42+
function make_grads(model_lr, At, y, σ)
43+
post = model_lr(At, y, σ) | (; y)
4444
as_post = as(post)
4545
obj(θ) = -Tilde.unsafe_logdensityof(post, transform(as_post, θ))
4646
(θ) = -obj(θ)
4747
@inline function dneglogp(t, x, v) # two directional derivatives
48-
f(t) = obj(x + t*v)
48+
f(t) = obj(x + t * v)
4949
u = ForwardDiff.derivative(f, Dual{:hSrkahPmmC}(0.0, 1.0))
5050
u.value, u.partials[]
5151
end
52-
52+
5353
gconfig = ForwardDiff.GradientConfig(obj, rand(25), ForwardDiff.Chunk{25}())
5454
function ∇neglogp!(y, t, x)
5555
ForwardDiff.gradient!(y, obj, x, gconfig)
@@ -58,60 +58,88 @@ function make_grads(model_lr, At, y, σ)
5858
post, ℓ, dneglogp, ∇neglogp!
5959
end
6060

61-
post, ℓ, dneglogp, ∇neglogp! = make_grads(model_lr, At, y, σ)
61+
post, ℓ, dneglogp, ∇neglogp! = make_grads(model_lr, At, y, σ)
6262
# Try things out
6363
dneglogp(2.4, randn(25), randn(25));
6464
∇neglogp!(randn(25), 2.1, randn(25));
6565

66-
6766
d = 25 # number of parameters
6867
t0 = 0.0;
6968
x0 = zeros(d); # starting point sampler
7069
# estimated posterior mean (n=100000, 797s)
71-
μ̂ = [3.406, -0.5918, 0.0352, -0.3874, 0.004481, -0.2346, -0.1495, -0.2184, 0.01219, 0.1731, -0.00976, -0.3224, 0.2168, 0.08002, -0.2829, -1.581, 0.6666, -0.9984, 1.081, 1.405, 0.327, -0.1357, -0.6446, -0.06583, -0.04994]
70+
μ̂ = [
71+
3.406,
72+
-0.5918,
73+
0.0352,
74+
-0.3874,
75+
0.004481,
76+
-0.2346,
77+
-0.1495,
78+
-0.2184,
79+
0.01219,
80+
0.1731,
81+
-0.00976,
82+
-0.3224,
83+
0.2168,
84+
0.08002,
85+
-0.2829,
86+
-1.581,
87+
0.6666,
88+
-0.9984,
89+
1.081,
90+
1.405,
91+
0.327,
92+
-0.1357,
93+
-0.6446,
94+
-0.06583,
95+
-0.04994,
96+
]
7297
n = 2000
7398
c = 4.0 # initial guess for the bound
7499

75-
init_scale=1;
76-
@time pf_result = pathfinder(ℓ; dim=d, init_scale);
100+
init_scale = 1;
101+
@time pf_result = pathfinder(ℓ; dim = d, init_scale);
77102
M = PDMats.PDiagMat(diag(pf_result.fit_distribution.Σ));
78103
M = pf_result.fit_distribution.Σ;
79104
x0 = pf_result.fit_distribution.μ;
80105
v0 = PDMats.unwhiten(M, randn(length(x0)));
81106

82-
83-
84-
85-
86107
MAP = pf_result.optim_solution; # MAP, could be useful for control variates
87108

88109
# define BouncyParticle sampler (has two relevant parameters)
89-
Z = BouncyParticle(missing, # graphical structure
110+
Z = BouncyParticle(
111+
missing, # graphical structure
90112
MAP, # MAP estimate, unused
91113
2.0, # momentum refreshment rate and sample saving rate
92114
0.95, # momentum correlation / only gradually change momentum in refreshment/momentum update
93115
M, # metric (PDMat compatible object for momentum covariance)
94-
missing # legacy
95-
) ;
96-
97-
sampler = ZZB.NotFactSampler(Z, (dneglogp, ∇neglogp!), ZZB.LocalBound(c), t0 => (x0, v0), ZZB.Rng(ZZB.Seed()), (),
98-
(; adapt=true, # adapt bound c
99-
subsample=true, # keep only samples at refreshment times
100-
));
101-
116+
missing, # legacy
117+
);
118+
119+
sampler = ZZB.NotFactSampler(
120+
Z,
121+
(dneglogp, ∇neglogp!),
122+
ZZB.LocalBound(c),
123+
t0 => (x0, v0),
124+
ZZB.Rng(ZZB.Seed()),
125+
(),
126+
(;
127+
adapt = true, # adapt bound c
128+
subsample = true, # keep only samples at refreshment times
129+
),
130+
);
102131

103132
using TupleVectors: chainvec
104133
using Tilde.MeasureTheory: transform
105134

106-
107-
function collect_sampler(t, sampler, n; progress=true, progress_stops=20)
135+
function collect_sampler(t, sampler, n; progress = true, progress_stops = 20)
108136
if progress
109137
prg = Progress(progress_stops, 1)
110138
else
111139
prg = missing
112140
end
113141
stops = ismissing(prg) ? 0 : max(prg.n - 1, 0) # allow one stop for cleanup
114-
nstop = n/stops
142+
nstop = n / stops
115143

116144
x1 = transform(t, sampler.u0[2][1])
117145
tv = chainvec(x1, n)
@@ -124,42 +152,54 @@ function collect_sampler(t, sampler, n; progress=true, progress_stops=20)
124152
tv[j] = transform(t, val[2])
125153
ϕ = iterate(sampler, state)
126154
if j > nstop
127-
nstop += n/stops
128-
next!(prg)
129-
end
155+
nstop += n / stops
156+
next!(prg)
157+
end
130158
end
131159
ismissing(prg) || ProgressMeter.finish!(prg)
132-
tv, (;uT=state[1], acc=state[3][1], total=state[3][2], bound=state[4].c)
160+
tv, (; uT = state[1], acc = state[3][1], total = state[3][2], bound = state[4].c)
133161
end
134-
collect_sampler(as(post), sampler, 10; progress=false);
162+
collect_sampler(as(post), sampler, 10; progress = false);
135163

136164
elapsed_time = @elapsed @time begin
137-
global bps_samples, info
138-
bps_samples, info = collect_sampler(as(post), sampler, n; progress=false)
165+
global bps_samples, info
166+
bps_samples, info = collect_sampler(as(post), sampler, n; progress = false)
139167
end
140168

141169
using MCMCChains
142170
bps_chain = MCMCChains.Chains(bps_samples.θ);
143-
bps_chain = setinfo(bps_chain, (;start_time=0.0, stop_time = elapsed_time));
171+
bps_chain = setinfo(bps_chain, (; start_time = 0.0, stop_time = elapsed_time));
144172

145-
μ̂1 = round.(mean(bps_chain).nt[:mean], sigdigits=4)
173+
μ̂1 = round.(mean(bps_chain).nt[:mean], sigdigits = 4)
146174
println("μ̂ (BPS) = ", μ̂1)
147175

148176
using SampleChainsDynamicHMC
149177
init_params = pf_result.draws[:, 1];
150178
inv_metric = (pf_result.fit_distribution.Σ);
151-
Tilde.sample(post, dynamichmc(
152-
;init=(; q=init_params, κ=GaussianKineticEnergy(inv_metric)),
153-
warmup_stages=default_warmup_stages(; middle_steps=0, doubling_stages=0),
154-
), 1,1);
155-
hmc_time = @elapsed @time (hmc_samples = Tilde.sample(post, dynamichmc(
156-
;init=(; q=init_params, κ=GaussianKineticEnergy(inv_metric)),
157-
warmup_stages=default_warmup_stages(; middle_steps=0, doubling_stages=0),
158-
), 2000,1));
179+
Tilde.sample(
180+
post,
181+
dynamichmc(;
182+
init = (; q = init_params, κ = GaussianKineticEnergy(inv_metric)),
183+
warmup_stages = default_warmup_stages(; middle_steps = 0, doubling_stages = 0),
184+
),
185+
1,
186+
1,
187+
);
188+
hmc_time = @elapsed @time (
189+
hmc_samples = Tilde.sample(
190+
post,
191+
dynamichmc(;
192+
init = (; q = init_params, κ = GaussianKineticEnergy(inv_metric)),
193+
warmup_stages = default_warmup_stages(; middle_steps = 0, doubling_stages = 0),
194+
),
195+
2000,
196+
1,
197+
)
198+
);
159199
hmc_chain = MCMCChains.Chains(hmc_samples.θ);
160-
μ̂2 = round.(mean(hmc_chain).nt[:mean], sigdigits=4);
200+
μ̂2 = round.(mean(hmc_chain).nt[:mean], sigdigits = 4);
161201
println("μ̂ (HMC) = ", μ̂2)
162-
hmc_chain = MCMCChains.setinfo(hmc_chain, (;start_time=0.0, stop_time = hmc_time));
202+
hmc_chain = MCMCChains.setinfo(hmc_chain, (; start_time = 0.0, stop_time = hmc_time));
163203

164204
ess_bps = MCMCChains.ess_rhat(bps_chain).nt.ess_per_sec;
165205
ess_hmc = MCMCChains.ess_rhat(hmc_chain).nt.ess_per_sec;
@@ -173,4 +213,4 @@ ylabel!(plt, "DynamicHMC");
173213
plt_bounds = collect(extrema(ess_hmc));
174214
lineplot!(plt, plt_bounds, plt_bounds);
175215
plt
176-
@info "For each coordinate, a point (x,y) shows the effective sample size per second for BPS (x) and HMC (y) . In blue is the diagonal x=y"
216+
@info "For each coordinate, a point (x,y) shows the effective sample size per second for BPS (x) and HMC (y) . In blue is the diagonal x=y"

0 commit comments

Comments
 (0)