Skip to content

Commit

Permalink
include queue to avoid mixing of different sampler speeds
Browse files Browse the repository at this point in the history
  • Loading branch information
JohannesBuchner committed Jan 24, 2025
1 parent 5e73ea0 commit 14ceb6e
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 60 deletions.
37 changes: 22 additions & 15 deletions tests/test_netiterintegrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,71 +395,78 @@ def test_singlepointqueue():
for pdim in 2, 3:
pp = SinglePointQueue(udim, pdim)
assert not pp.has(0)
pp.add(np.arange(udim), np.arange(pdim), 0, 0)
pp.add(np.arange(udim), np.arange(pdim), 0, 32, 0)
try:
pp.has(1)
assert False
except ValueError:
pass
assert pp.has(0)
try:
pp.add(np.arange(udim) + 1, np.arange(pdim) + 1, 1, 0)
pp.add(np.arange(udim) + 1, np.arange(pdim) + 1, 1, 10, 0)
assert False
except ValueError:
pass
u, p, L = pp.pop(0)
u, p, L, q = pp.pop(0)
assert_allclose(u, np.arange(udim))
assert_allclose(p, np.arange(pdim))
assert_allclose(L, 0)
pp.add(np.arange(udim) + 42, np.arange(pdim) + 42, 42, 0)
u, p, L = pp.pop(0)
assert_allclose(q, 32)
pp.add(np.arange(udim) + 42, np.arange(pdim) + 42, 42, 32, 0)
u, p, L, q = pp.pop(0)
assert_allclose(u, np.arange(udim) + 42)
assert_allclose(p, np.arange(pdim) + 42)
assert_allclose(L, 42)
assert_allclose(q, 32)

def test_roundrobinpointqueue():
udim = 2
for pdim in 2, 3:
pp = RoundRobinPointQueue(udim, pdim)
assert not pp.has(0)
pp.add(np.arange(udim), np.arange(pdim), 0, 42)
pp.add(np.arange(udim), np.arange(pdim), 0, 400, 42)
assert not pp.has(0)
assert pp.has(42)
pp.add(np.arange(udim) + 1, np.arange(pdim) + 1, 1, 32)
pp.add(np.arange(udim) + 5, np.arange(pdim) + 5, 5, 52)
pp.add(np.arange(udim) + 2, np.arange(pdim) + 2, 2, 42)
pp.add(np.arange(udim) + 1, np.arange(pdim) + 1, 1, 30, 32)
pp.add(np.arange(udim) + 5, np.arange(pdim) + 5, 5, 50, 52)
pp.add(np.arange(udim) + 2, np.arange(pdim) + 2, 2, 40, 42)
try:
pp.pop(0)
assert False
except IndexError:
pass
u, p, L = pp.pop(42)
u, p, L, q = pp.pop(42)
assert_allclose(u, np.arange(udim))
assert_allclose(p, np.arange(pdim))
assert_allclose(L, 0)
u, p, L = pp.pop(52)
assert_allclose(q, 400)
u, p, L, q = pp.pop(52)
assert_allclose(u, np.arange(udim) + 5)
assert_allclose(p, np.arange(pdim) + 5)
assert_allclose(L, 5)
u, p, L = pp.pop(32)
assert_allclose(q, 50)
u, p, L, q = pp.pop(32)
assert_allclose(u, np.arange(udim) + 1)
assert_allclose(p, np.arange(pdim) + 1)
assert_allclose(L, 1)
u, p, L = pp.pop(42)
assert_allclose(q, 30)
u, p, L, q = pp.pop(42)
assert_allclose(u, np.arange(udim) + 2)
assert_allclose(p, np.arange(pdim) + 2)
assert_allclose(L, 2)
assert_allclose(q, 40)
assert not pp.has(32)
assert not pp.has(42)
assert not pp.has(52)
for i in range(10001):
pp.add(np.arange(udim) + i, np.arange(pdim) + i, i, i % 42)
pp.add(np.arange(udim) + i, np.arange(pdim) + i, i, 60, i % 42)
for i in range(10001):
assert pp.has(i % 42)
u, p, L = pp.pop(i % 42)
u, p, L, q = pp.pop(i % 42)
assert_allclose(u, np.arange(udim) + i)
assert_allclose(p, np.arange(pdim) + i)
assert_allclose(L, i)
assert_allclose(q, 60)
for i in range(42):
assert not pp.has(i)

Expand Down
78 changes: 37 additions & 41 deletions ultranest/integrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from __future__ import division, print_function

import csv
import itertools
import json
import operator
import os
Expand All @@ -30,9 +29,9 @@
RobustEllipsoidRegion, ScalingLayer, WrappingEllipsoid,
find_nearby)
from .netiter import (BreadthFirstIterator, MultiCounter, PointPile,
SingleCounter, TreeNode, combine_results,
count_tree_between, dump_tree, find_nodes_before,
logz_sequence)
RoundRobinPointQueue, SingleCounter, SinglePointQueue,
TreeNode, combine_results, count_tree_between, dump_tree,
find_nodes_before, logz_sequence)
from .ordertest import UniformOrderAccumulator
from .store import HDF5PointStore, NullPointStore, TextPointStore
from .utils import (create_logger, distributed_work_chunk_size,
Expand Down Expand Up @@ -1200,6 +1199,7 @@ def __init__(self,
self.pointstore = storage_backend
else:
self.pointstore = NullPointStore(3 + self.x_dim + self.num_params)
self.pointqueue = RoundRobinPointQueue(self.x_dim, self.num_params)
self.ncall = self.pointstore.ncalls
self.ncall_region = 0

Expand Down Expand Up @@ -1842,6 +1842,8 @@ def _create_point(self, iteration, Lmin, ndraw, active_u, active_values):
Parameters
-----------
iteration: int
nested sampling iteration, used for picking submitted points in order
Lmin: float
loglikelihood threshold to draw above
ndraw: float
Expand All @@ -1867,8 +1869,7 @@ def _create_point(self, iteration, Lmin, ndraw, active_u, active_values):
nit = 0
while True:
# load current index
ib = self.ib
if ib >= len(self.samples) and self.use_point_stack:
if self.use_point_stack:
# refill cache from point store
# only root accesses the point store
next_point = np.zeros((1, 3 + self.x_dim + self.num_params)) * np.nan
Expand All @@ -1885,17 +1886,18 @@ def _create_point(self, iteration, Lmin, ndraw, active_u, active_values):
self.use_point_stack = self.comm.bcast(self.use_point_stack, root=0)
next_point = self.comm.bcast(next_point, root=0)

# unpack this point (there is only one)
self.likes = next_point[:,1]
self.samples = next_point[:,3:3 + self.x_dim]
self.samplesv = next_point[:,3 + self.x_dim:3 + self.x_dim + self.num_params]
# if we already know it is not useful, advance index to enter the next if
ib = 0 if np.isfinite(self.likes[0]) else 1
# use it if we can:
if next_point[0,1] > Lmin:
return (
next_point[0,3:3 + self.x_dim],
next_point[0,3 + self.x_dim:3 + self.x_dim + self.num_params],
next_point[0,1],
)

use_stepsampler = self.stepsampler is not None
while ib >= len(self.samples):
rank_to_fetch = iteration % self.mpi_size
while not self.pointqueue.has(rank_to_fetch):
# clear and reset cache, then refill by sampling
ib = 0
if use_stepsampler:
u, v, logl, Lmin_sampled, nc = self.stepsampler.__next__(
self.region,
Expand All @@ -1916,50 +1918,44 @@ def _create_point(self, iteration, Lmin, ndraw, active_u, active_values):
u = u.reshape((1, self.x_dim))
v = v.reshape((1, self.num_params))
logl = logl.reshape((1,))
rank_origin = self.mpi_rank + np.zeros(len(u), dtype=int)

if self.use_mpi:
# keep track of rank of received points, store them away in a temporary cache
recv_samples = self.comm.gather(u, root=0)
recv_samplesv = self.comm.gather(v, root=0)
recv_likes = self.comm.gather(logl, root=0)
recv_nc = self.comm.gather(nc, root=0)
recv_rank_origin = self.comm.gather(rank_origin, root=0)
recv_samples = self.comm.bcast(recv_samples, root=0)
recv_samplesv = self.comm.bcast(recv_samplesv, root=0)
recv_likes = self.comm.bcast(recv_likes, root=0)
recv_nc = self.comm.bcast(recv_nc, root=0)
self.samples = np.concatenate(recv_samples, axis=0)
self.samplesv = np.concatenate(recv_samplesv, axis=0)
self.likes = np.concatenate(recv_likes, axis=0)
recv_rank_origin = self.comm.bcast(recv_rank_origin, root=0)
samples = np.concatenate(recv_samples, axis=0)
samplesv = np.concatenate(recv_samplesv, axis=0)
likes = np.concatenate(recv_likes, axis=0)
rank_origins = np.concatenate(recv_rank_origin, axis=0)
self.ncall += sum(recv_nc)
if use_stepsampler:
recv_Lmin_sampled = self.comm.gather(Lmin_sampled, root=0)
self.Lmin_sampled = self.comm.bcast(recv_Lmin_sampled, root=0)
else:
self.Lmin_sampled = itertools.repeat(Lmin)
else:
self.samples = u
self.samplesv = v
self.likes = logl
samples = u
samplesv = v
likes = logl
self.ncall += nc
self.Lmin_sampled = itertools.repeat(Lmin)
rank_origins = rank_origin

# process the next point which has it % point_mpi_rank == 0
if self.log:
for ui, vi, logli, Lmini in zip(self.samples, self.samplesv, self.likes, self.Lmin_sampled):
self.pointstore.add(
_listify([Lmini, logli, quality], ui, vi),
self.ncall)

if self.likes[ib] > Lmin:
u = self.samples[ib, :]
assert np.logical_and(u > 0, u < 1).all(), (u)
p = self.samplesv[ib, :]
logl = self.likes[ib]

self.ib = ib + 1
for ui, vi, logli, rank_origin in zip(samples, samplesv, likes, rank_origins):
self.pointqueue.add(ui, vi, logli, quality, rank_origin)

u, p, logl, quality = self.pointqueue.pop(rank_to_fetch)
if self.log:
self.pointstore.add(
_listify([Lmin, logl, quality], u, p),
self.ncall)

if logl > Lmin:
return u, p, logl
else:
self.ib = ib + 1

def _update_region(
self, active_u, active_node_ids,
Expand Down
21 changes: 17 additions & 4 deletions ultranest/netiter.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,11 +489,12 @@ def __init__(self, udim, pdim, chunksize=1000):
self.ps = np.zeros((self.chunksize, pdim))
self.Ls = np.zeros(self.chunksize)
self.ranks = np.zeros(self.chunksize, dtype=int)
self.quality = np.zeros(self.chunksize, dtype=int)
self.filled = np.zeros(self.chunksize, dtype=bool)
self.udim = udim
self.pdim = pdim

def add(self, newpointu, newpointp, newpointL, newrank):
def add(self, newpointu, newpointp, newpointL, newquality, newrank):
"""Save point.
Parameters
Expand All @@ -504,6 +505,8 @@ def add(self, newpointu, newpointp, newpointL, newrank):
point (in p-space)
newpointL: float
loglikelihood
newquality: float
quality
newrank: int
rank of point
"""
Expand All @@ -515,6 +518,9 @@ def add(self, newpointu, newpointp, newpointL, newrank):
self.ranks = np.concatenate(
(self.ranks, np.zeros(self.chunksize, dtype=int))
)
self.quality = np.concatenate(
(self.quality, np.zeros(self.chunksize, dtype=int))
)
self.filled = np.concatenate(
(self.filled, np.zeros(self.chunksize, dtype=bool))
)
Expand All @@ -525,6 +531,7 @@ def add(self, newpointu, newpointp, newpointL, newrank):
self.ps[i, :] = newpointp
self.Ls[i] = newpointL
self.ranks[i] = newrank
self.quality[i] = newquality
self.filled[i] = True

def pop(self, rank):
Expand All @@ -546,7 +553,7 @@ def pop(self, rank):
"""
i = submasks(self.filled, self.ranks[self.filled] == rank)[0]
self.filled[i] = False
return self.us[i, :], self.ps[i, :], self.Ls[i]
return self.us[i, :], self.ps[i, :], self.Ls[i], self.quality[i]

def has(self, rank):
"""Check if there is a next point of a given rank.
Expand Down Expand Up @@ -584,8 +591,9 @@ def __init__(self, udim, pdim, chunksize=1):
self.u = None
self.p = None
self.L = None
self.quality = None

def add(self, newpointu, newpointp, newpointL, newrank):
def add(self, newpointu, newpointp, newpointL, newquality, newrank):
"""Save point.
Parameters
Expand All @@ -596,6 +604,8 @@ def add(self, newpointu, newpointp, newpointL, newrank):
point (in p-space)
newpointL: float
loglikelihood
newquality: int
point quality
newrank: int
rank of point
"""
Expand All @@ -605,6 +615,7 @@ def add(self, newpointu, newpointp, newpointL, newrank):
self.u = newpointu
self.p = newpointp
self.L = newpointL
self.quality = newquality
else:
raise ValueError("SinglePointQueue: queue not empty")

Expand All @@ -630,10 +641,12 @@ def pop(self, rank):
u = self.u
p = self.p
L = self.L
quality = self.quality
self.u = None
self.p = None
self.L = None
return u, p, L
self.quality = None
return u, p, L, quality

def has(self, rank):
"""Check if there is a next point of a given rank.
Expand Down

0 comments on commit 14ceb6e

Please sign in to comment.