Skip to content

Commit ece0aaf

Browse files
zoj613brandonwillard
authored andcommitted
pack nuts results some more and update readme example
1 parent 0d90fbb commit ece0aaf

File tree

8 files changed

+209
-100
lines changed

8 files changed

+209
-100
lines changed

README.md

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -46,20 +46,12 @@ initial_state = nuts.new_state(y_vv, logprob_fn)
4646

4747
step_size = at.as_tensor(1e-2)
4848
inverse_mass_matrix=at.as_tensor(1.0)
49-
(
50-
next_state,
51-
potential_energy,
52-
potential_energy_grad,
53-
acceptance_prob,
54-
num_doublings,
55-
is_turning,
56-
is_diverging,
57-
), updates = kernel(*initial_state, step_size, inverse_mass_matrix)
58-
59-
next_step_fn = aesara.function([y_vv], next_state, updates=updates)
49+
chain_info, updates = kernel(initial_state, step_size, inverse_mass_matrix)
50+
51+
next_step_fn = aesara.function([y_vv], chain_info.state.position, updates=updates)
6052

6153
print(next_step_fn(0))
62-
# 0.14344008534533775
54+
# 1.1034719409361107
6355
```
6456

6557
## Install

aehmc/hmc.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def step(
7979
step_size: TensorVariable,
8080
inverse_mass_matrix: TensorVariable,
8181
num_integration_steps: int,
82-
) -> Tuple[Tuple[IntegratorState, TensorVariable, bool], Dict]:
82+
) -> Tuple[trajectory.Diagnostics, Dict]:
8383
"""Perform a single step of the HMC algorithm.
8484
8585
Parameters
@@ -120,10 +120,8 @@ def step(
120120
divergence_threshold,
121121
)
122122
updated_state = state._replace(momentum=momentum_generator(srng))
123-
new_state, acceptance_proba, is_divergent, updates = proposal_generator(
124-
srng, updated_state, step_size
125-
)
126-
return (new_state, acceptance_proba, is_divergent), updates
123+
chain_info, updates = proposal_generator(srng, updated_state, step_size)
124+
return chain_info, updates
127125

128126
return step
129127

@@ -158,7 +156,7 @@ def hmc_proposal(
158156

159157
def propose(
160158
srng: RandomStream, state: IntegratorState, step_size: TensorVariable
161-
) -> Tuple[IntegratorState, TensorVariable, bool, Dict]:
159+
) -> Tuple[trajectory.Diagnostics, Dict]:
162160
"""Use the HMC algorithm to propose a new state.
163161
164162
Parameters
@@ -195,7 +193,14 @@ def propose(
195193
p_accept = at.clip(at.exp(delta_energy), 0, 1.0)
196194
do_accept = srng.bernoulli(p_accept)
197195
final_state = IntegratorState(*ifelse(do_accept, new_state, state))
196+
chain_info = trajectory.Diagnostics(
197+
state=final_state,
198+
acceptance_probability=p_accept,
199+
is_diverging=is_transition_divergent,
200+
num_doublings=None,
201+
is_turning=None,
202+
)
198203

199-
return final_state, p_accept, is_transition_divergent, updates
204+
return chain_info, updates
200205

201206
return propose

aehmc/nuts.py

Lines changed: 30 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Callable
1+
from typing import Callable, Dict, Tuple
22

33
import aesara.tensor as at
44
import numpy as np
@@ -9,7 +9,7 @@
99
from aehmc.integrators import IntegratorState
1010
from aehmc.proposals import ProposalState
1111
from aehmc.termination import iterative_uturn
12-
from aehmc.trajectory import dynamic_integration, multiplicative_expansion
12+
from aehmc.trajectory import Diagnostics, dynamic_integration, multiplicative_expansion
1313

1414
new_state = hmc.new_state
1515

@@ -54,12 +54,10 @@ def potential_fn(x):
5454
return -logprob_fn(x)
5555

5656
def step(
57-
q: TensorVariable,
58-
potential_energy: TensorVariable,
59-
potential_energy_grad: TensorVariable,
57+
state: IntegratorState,
6058
step_size: TensorVariable,
6159
inverse_mass_matrix: TensorVariable,
62-
):
60+
) -> Tuple[Diagnostics, Dict]:
6361
"""Use the NUTS algorithm to propose a new state.
6462
6563
Parameters
@@ -112,50 +110,46 @@ def step(
112110
max_num_expansions,
113111
)
114112

115-
p = momentum_generator(srng)
116-
initial_state = IntegratorState(
117-
position=q,
118-
momentum=p,
119-
potential_energy=potential_energy,
120-
potential_energy_grad=potential_energy_grad,
113+
initial_state = state._replace(momentum=momentum_generator(srng))
114+
initial_termination_state = new_termination_state(
115+
initial_state.position, max_num_expansions
116+
)
117+
initial_energy = initial_state.potential_energy + kinetic_energy_fn(
118+
initial_state.momentum
121119
)
122-
initial_termination_state = new_termination_state(q, max_num_expansions)
123-
initial_energy = potential_energy + kinetic_energy_fn(p)
124120
initial_proposal = ProposalState(
125121
state=initial_state,
126122
energy=initial_energy,
127123
weight=at.as_tensor(0.0, dtype=np.float64),
128124
sum_log_p_accept=at.as_tensor(-np.inf, dtype=np.float64),
129125
)
130-
result, updates = expand(
126+
127+
results, updates = expand(
131128
initial_proposal,
132129
initial_state,
133130
initial_state,
134-
p,
131+
initial_state.momentum,
135132
initial_termination_state,
136133
initial_energy,
137134
step_size,
138135
)
139136

140-
# New MCMC proposal
141-
q_new = result[0][-1]
142-
potential_energy_new = result[2][-1]
143-
potential_energy_grad_new = result[3][-1]
144-
145-
# Diagnostics
146-
is_turning = result[-1][-1]
147-
is_diverging = result[-2][-1]
148-
num_doublings = result[-3][-1]
149-
acceptance_probability = result[-4][-1]
150-
151-
return (
152-
q_new,
153-
potential_energy_new,
154-
potential_energy_grad_new,
155-
acceptance_probability,
156-
num_doublings,
157-
is_turning,
158-
is_diverging,
159-
), updates
137+
# extract the last iteration from multiplicative_expansion chain diagnostics
138+
chain_info = Diagnostics(
139+
state=IntegratorState(
140+
position=results.diagnostics.state.position[-1],
141+
momentum=results.diagnostics.state.momentum[-1],
142+
potential_energy=results.diagnostics.state.potential_energy[-1],
143+
potential_energy_grad=results.diagnostics.state.potential_energy_grad[
144+
-1
145+
],
146+
),
147+
acceptance_probability=results.diagnostics.acceptance_probability[-1],
148+
num_doublings=results.diagnostics.num_doublings[-1],
149+
is_turning=results.diagnostics.is_turning[-1],
150+
is_diverging=results.diagnostics.is_diverging[-1],
151+
)
152+
153+
return chain_info, updates
160154

161155
return step

aehmc/trajectory.py

Lines changed: 97 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Callable, Dict, Tuple
1+
from typing import Callable, Dict, NamedTuple, Tuple
22

33
import aesara
44
import aesara.tensor as at
@@ -376,6 +376,23 @@ def add_one_state(
376376
return integrate
377377

378378

379+
class Diagnostics(NamedTuple):
380+
state: IntegratorState
381+
acceptance_probability: TensorVariable
382+
num_doublings: TensorVariable
383+
is_turning: TensorVariable
384+
is_diverging: TensorVariable
385+
386+
387+
class MultiplicativeExpansionResult(NamedTuple):
388+
proposals: ProposalState
389+
right_states: IntegratorState
390+
left_states: IntegratorState
391+
momentum_sums: TensorVariable
392+
termination_states: TerminationState
393+
diagnostics: Diagnostics
394+
395+
379396
def multiplicative_expansion(
380397
srng: RandomStream,
381398
trajectory_integrator: Callable,
@@ -416,7 +433,7 @@ def expand(
416433
termination_state: TerminationState,
417434
initial_energy,
418435
step_size,
419-
):
436+
) -> Tuple[MultiplicativeExpansionResult, Dict]:
420437
"""Expand the current trajectory multiplicatively.
421438
422439
At each step we draw a direction at random, build a subtrajectory starting
@@ -465,7 +482,7 @@ def expand_once(
465482
momentum_sum_ckpts,
466483
idx_min,
467484
idx_max,
468-
):
485+
) -> Tuple[Tuple[TensorVariable, ...], Dict, until]:
469486
left_state = (
470487
q_left,
471488
p_left,
@@ -591,7 +608,33 @@ def expand_once(
591608
)
592609

593610
expansion_steps = at.arange(0, max_num_expansions)
594-
results, updates = aesara.scan(
611+
# results, updates = aesara.scan(
612+
(
613+
proposal_state_position,
614+
proposal_state_momentum,
615+
proposal_state_potential_energy,
616+
proposal_state_potential_energy_grad,
617+
proposal_energy,
618+
proposal_weight,
619+
proposal_sum_log_p_accept,
620+
left_state_position,
621+
left_state_momentum,
622+
left_state_potential_energy,
623+
left_state_potential_energy_grad,
624+
right_state_position,
625+
right_state_momentum,
626+
right_state_potential_energy,
627+
right_state_potential_energy_grad,
628+
momentum_sum,
629+
momentum_checkpoints,
630+
momentum_sum_checkpoints,
631+
min_indices,
632+
max_indices,
633+
acceptance_probability,
634+
num_doublings,
635+
is_diverging,
636+
is_turning,
637+
), updates = aesara.scan(
595638
expand_once,
596639
outputs_info=(
597640
proposal.state.position,
@@ -610,16 +653,63 @@ def expand_once(
610653
right_state.potential_energy,
611654
right_state.potential_energy_grad,
612655
momentum_sum,
613-
*termination_state,
656+
termination_state.momentum_checkpoints,
657+
termination_state.momentum_sum_checkpoints,
658+
termination_state.min_index,
659+
termination_state.max_index,
614660
None,
615661
None,
616662
None,
617663
None,
618664
),
619665
sequences=expansion_steps,
620666
)
621-
622-
return results, updates
667+
# Ensure each item of the returned result sequence is packed into the appropriate namedtuples.
668+
typed_result = MultiplicativeExpansionResult(
669+
proposals=ProposalState(
670+
state=IntegratorState(
671+
position=proposal_state_position,
672+
momentum=proposal_state_momentum,
673+
potential_energy=proposal_state_potential_energy,
674+
potential_energy_grad=proposal_state_potential_energy_grad,
675+
),
676+
energy=proposal_energy,
677+
weight=proposal_weight,
678+
sum_log_p_accept=proposal_sum_log_p_accept,
679+
),
680+
left_states=IntegratorState(
681+
position=left_state_position,
682+
momentum=left_state_momentum,
683+
potential_energy=left_state_potential_energy,
684+
potential_energy_grad=left_state_potential_energy_grad,
685+
),
686+
right_states=IntegratorState(
687+
position=right_state_position,
688+
momentum=right_state_momentum,
689+
potential_energy=right_state_potential_energy,
690+
potential_energy_grad=right_state_potential_energy_grad,
691+
),
692+
momentum_sums=momentum_sum,
693+
termination_states=TerminationState(
694+
momentum_checkpoints=momentum_checkpoints,
695+
momentum_sum_checkpoints=momentum_sum_checkpoints,
696+
min_index=min_indices,
697+
max_index=max_indices,
698+
),
699+
diagnostics=Diagnostics(
700+
state=IntegratorState(
701+
position=proposal_state_position,
702+
momentum=proposal_state_momentum,
703+
potential_energy=proposal_state_potential_energy,
704+
potential_energy_grad=proposal_state_potential_energy_grad,
705+
),
706+
acceptance_probability=acceptance_probability,
707+
num_doublings=num_doublings,
708+
is_turning=is_turning,
709+
is_diverging=is_diverging,
710+
),
711+
)
712+
return typed_result, updates
623713

624714
return expand
625715

aehmc/window_adaptation.py

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from aehmc.algorithms import DualAveragingState
1111
from aehmc.integrators import IntegratorState
1212
from aehmc.mass_matrix import covariance_adaptation
13+
from aehmc.nuts import Diagnostics
1314
from aehmc.step_size import dual_averaging_adaptation
1415

1516

@@ -42,7 +43,13 @@ def one_step(
4243
step_size, # parameters
4344
inverse_mass_matrix,
4445
):
45-
chain_state = (q, potential_energy, potential_energy_grad)
46+
chain_state = IntegratorState(
47+
position=q,
48+
momentum=None,
49+
potential_energy=potential_energy,
50+
potential_energy_grad=potential_energy_grad,
51+
)
52+
4653
warmup_state = (
4754
DualAveragingState(
4855
step=step,
@@ -56,17 +63,17 @@ def one_step(
5663
parameters = (step_size, inverse_mass_matrix)
5764

5865
# Advance the chain by one step
59-
chain_state, inner_updates = kernel(*chain_state, *parameters)
66+
chain_info, inner_updates = kernel(chain_state, *parameters)
6067

6168
# Update the warmup state and parameters
6269
warmup_state, parameters = update_adapt(
63-
warmup_step, warmup_state, parameters, chain_state
70+
warmup_step, warmup_state, parameters, chain_info
6471
)
6572
da_state = warmup_state[0]
6673
return (
67-
chain_state[0], # q
68-
chain_state[1], # potential_energy
69-
chain_state[2], # potential_energy_grad
74+
chain_info.state.position, # q
75+
chain_info.state.potential_energy, # potential_energy
76+
chain_info.state.potential_energy_grad, # potential_energy_grad
7077
da_state.step,
7178
da_state.iterates, # log_step_size
7279
da_state.iterates_avg, # log_step_size_avg
@@ -182,14 +189,19 @@ def final(
182189
step_size = at.exp(da_state.iterates_avg) # return stepsize_avg at the end
183190
return step_size, inverse_mass_matrix
184191

185-
def update(step: int, warmup_state: Tuple, parameters: Tuple, chain_state: Tuple):
186-
position, _, _, p_accept, *_ = chain_state
187-
192+
def update(
193+
step: int, warmup_state: Tuple, parameters: Tuple, chain_state: Diagnostics
194+
):
188195
stage = schedule_stage[step]
189196
warmup_state, parameters = where_warmup_state(
190197
at.eq(stage, 0),
191-
fast_update(p_accept, warmup_state, parameters),
192-
slow_update(position, p_accept, warmup_state, parameters),
198+
fast_update(chain_state.acceptance_probability, warmup_state, parameters),
199+
slow_update(
200+
chain_state.state.position,
201+
chain_state.acceptance_probability,
202+
warmup_state,
203+
parameters,
204+
),
193205
)
194206

195207
is_middle_window_end = schedule_middle_window[step]

0 commit comments

Comments
 (0)