Commit b58ccf6

Refactored nested_sampling() function to use the multiprocessing
standard package. Updated setup.py and travis files (remove pathos install).
1 parent 4ede6fd commit b58ccf6
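
Why the refactor works: multiprocessing.Pool dispatches work to its workers by pickling the callable, and the nested def loglike() / def prior_transform() previously defined inside nested_sampling() cannot be pickled, which is why pathos (dill-based pickling) was needed before. Replacing the closures with module-level callable classes, ms.Loglike and ms.Prior_transform, makes the objects picklable by the standard library. Below is a minimal sketch of that pattern; the Loglike body and the toy linear_model are illustrative assumptions, only the constructor arguments mirror the ms.Loglike(...) call in the diff.

import multiprocessing as mp
import numpy as np


def linear_model(params, x):
    # Toy model used only in this sketch:
    return params[0] + params[1]*x


class Loglike(object):
    # Picklable callable standing in for ms.Loglike(); the constructor
    # arguments mirror the call in the diff below, but the body here is
    # an assumption reconstructed from the removed loglike() closure.
    def __init__(self, data, uncert, func, params, indparams, pstep):
        self.data = data
        self.uncert = uncert
        self.func = func
        self.params = np.copy(params)
        self.indparams = indparams
        self.pstep = pstep
        self.ifree = np.where(pstep > 0)[0]
        self.ishare = np.where(pstep < 0)[0]

    def __call__(self, pars):
        # Update free parameters, propagate shared ones, and return the
        # Gaussian log-likelihood:
        self.params[self.ifree] = pars
        for s in self.ishare:
            self.params[s] = self.params[-int(self.pstep[s])-1]
        model = self.func(self.params, *self.indparams)
        return -0.5*np.sum((self.data-model)**2/self.uncert**2)


if __name__ == "__main__":
    x = np.linspace(0.0, 1.0, 5)
    data = 1.0 + 2.0*x
    uncert = np.tile(0.1, 5)
    loglike = Loglike(data, uncert, linear_model, np.array([1.0, 2.0]),
                      [x], np.array([1.0, 1.0]))
    # Unlike a nested def, this instance survives pickling, so a standard
    # multiprocessing.Pool can evaluate it in worker processes:
    with mp.Pool(2) as pool:
        print(pool.map(loglike, [np.array([1.0, 2.0]), np.array([0.5, 1.5])]))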

File tree

3 files changed: 28 additions, 34 deletions


.travis.yml

Lines changed: 0 additions & 1 deletion
@@ -11,7 +11,6 @@ matrix:
 install:
     - pip install -r requirements.txt
     - pip install dynesty>=0.9.5
-    - pip install pathos>=0.2.4
     - pip install -e .
 
 script: pytest tests -v

mc3/ns_driver.py

Lines changed: 27 additions & 31 deletions
@@ -4,7 +4,9 @@
 __all__ = ["nested_sampling"]
 
 import sys
+import multiprocessing as mp
 import numpy as np
+
 from . import stats as ms
 
 if sys.version_info.major == 2:
@@ -95,45 +97,35 @@ def nested_sampling(data, uncert, func, params, indparams, pmin, pmax, pstep,
     """
     try:
         import dynesty
-        if ncpu > 1:
-            from pathos.multiprocessing import ProcessingPool
     except ImportError as error:
         log.error("ModuleNotFoundError: {}".format(error))
 
     nfree = int(np.sum(pstep > 0))
     ifree = np.where(pstep > 0)[0]
     ishare = np.where(pstep < 0)[0]
-    # Can't use multiprocessing.Pool since map can't pickle defs:
-    pool = ProcessingPool(nodes=ncpu) if ncpu > 1 else None
-    queue_size = ncpu if ncpu > 1 else None
-
-    # Setup prior transform:
-    priors = []
-    for p0, plo, pup, min, max, step in zip(prior, priorlow, priorup,
-                                            pmin, pmax, pstep):
-        if step <= 0:
-            continue
-        if plo == 0.0 and pup == 0.0:
-            priors.append(ms.ppf_uniform(min, max))
-        else:
-            priors.append(ms.ppf_gaussian(p0, plo, pup))
-
-    def prior_transform(u):
-        return [p(v) for p,v in zip(priors,u)]
-
-    def loglike(pars):
-        params[ifree] = pars
-        for s in ishare:
-            params[s] = params[-int(pstep[s])-1]
-        return -0.5*np.sum((data-func(params, *indparams))**2/uncert**2)
+
+    # Multiprocessing setup:
+    if ncpu > 1:
+        pool = mp.Pool(ncpu)
+        queue_size = ncpu
+    else:
+        pool = None
+        queue_size = None
 
     # Intercept kwargs that go into DynamicNestedSampler():
-    skip_logp = False
     if 'loglikelihood' in kwargs:
         loglike = kwargs.pop('loglikelihood')
+    else:
+        loglike = ms.Loglike(data, uncert, func, params, indparams, pstep)
+
     if 'prior_transform' in kwargs:
         prior_transform = kwargs.pop('prior_transform')
         skip_logp = True
+    else:
+        prior_transform = ms.Prior_transform(prior, priorlow, priorup,
+                                             pmin, pmax, pstep)
+        skip_logp = False
+
     if 'ndim' in kwargs:
         nfree = kwargs.pop('ndim')
     if 'pool' in kwargs:
@@ -148,7 +140,7 @@ def loglike(pars):
 
     weights = np.exp(sampler.results.logwt - sampler.results.logz[-1])
     isample = resample_equal(weights)
-    posterior = sampler.results.samples[isample] #[::thinning]
+    posterior = sampler.results.samples[isample]
     chisq = -2.0*sampler.results.logl[isample]
 
     # Contribution to chi-square from non-uniform priors:
@@ -173,25 +165,29 @@ def loglike(pars):
     log.msg("\nNested Sampling Summary:"
             "\n------------------------")
 
-    # Get some stats:
+    posterior = posterior[::thinning]
+    chisq = chisq[::thinning]
+    log_post = log_post[::thinning]
+    # Number of evaluated and kept samples:
     nsample = sampler.results['niter']
-    nZsample = len(posterior) # Valid samples (after thinning and burning)
+    nzsample = len(posterior)
 
     fmt = len(str(nsample))
     log.msg("Number of evaluated samples: {:{}d}".
             format(nsample, fmt), indent=2)
     log.msg("Thinning factor: {:{}d}".
             format(thinning, fmt), indent=2)
     log.msg("NS sample size (thinned): {:{}d}".
-            format(nZsample, fmt), indent=2)
+            format(nzsample, fmt), indent=2)
     log.msg("Sampling efficiency: {:.2f}%\n".
             format(sampler.results['eff']), indent=2)
 
     # Build the output dict:
     output = {
         # The posterior:
         'posterior':posterior,
-        'zchain':np.zeros(nsample, int),
+        'zchain':np.zeros(nzsample, int),
+        'zmask':np.arange(nzsample),
         'chisq':chisq,
         'log_post':log_post,
         'burnin':0,
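
For orientation, the sketch below shows roughly how the pieces assembled in this driver (pool, queue_size, loglike, prior_transform, the number of free parameters) feed into dynesty, and how equally weighted posterior samples come back out. It is a standalone approximation, not mc3's actual call: the toy loglike and prior_transform stand in for ms.Loglike and ms.Prior_transform, and dynesty's resample_equal utility replaces mc3's own helper.

import multiprocessing as mp

import numpy as np
import dynesty
from dynesty.utils import resample_equal


def loglike(pars):
    # Stand-in for ms.Loglike(): an isotropic Gaussian centered at 1.
    return -0.5*np.sum((pars - 1.0)**2)


def prior_transform(u):
    # Stand-in for ms.Prior_transform(): uniform prior on [-5, 5].
    return 10.0*u - 5.0


if __name__ == "__main__":
    ncpu = 4
    pool = mp.Pool(ncpu) if ncpu > 1 else None
    queue_size = ncpu if ncpu > 1 else None

    sampler = dynesty.DynamicNestedSampler(
        loglike, prior_transform, ndim=2,
        pool=pool, queue_size=queue_size)
    sampler.run_nested()

    # Equally weighted posterior draws from the importance weights,
    # here using dynesty's own utility instead of mc3's resample_equal():
    results = sampler.results
    weights = np.exp(results.logwt - results.logz[-1])
    weights /= np.sum(weights)  # guard against round-off
    posterior = resample_equal(results.samples, weights)

    if pool is not None:
        pool.close()
        pool.join()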

setup.py

Lines changed: 1 addition & 2 deletions
@@ -45,8 +45,7 @@
         'scipy>=0.17.1',
         'matplotlib>=2.0',],
     tests_require = ['pytest>=3.9',
-        'dynesty>=0.9.5',
-        'pathos>=0.2.4'],
+        'dynesty>=0.9.5'],
     include_package_data=True,
     license = 'MIT',
     description = 'Multi-core Markov-chain Monte Carlo package.',
