Skip to content

Commit 14a3735

Browse files
committed
pipe through --cores correctly
1 parent 87d07d0 commit 14a3735

File tree

2 files changed

+10
-6
lines changed

2 files changed

+10
-6
lines changed

docs/source/release-history.rst

+1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ Release History
55
Unreleased (2021-09-10)
66
=======================
77
- Fix sign on cosmological K correction
8+
- Pipe through the ``--cores`` argument correctly
89

910
v1.1.0 (2020-10-02)
1011
===================

superphot/fit.py

+9-6
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
ITERATIONS = 10000
1818
TUNING = 25000
1919
WALKERS = 25
20+
CORES = 1
2021
PARAMNAMES = ['Amplitude', 'Plateau Slope (d$^{-1}$)', 'Plateau Duration (d)',
2122
'Reference Epoch (d)', 'Rise Time (d)', 'Fall Time (d)']
2223

@@ -303,7 +304,7 @@ def setup_model2(obs, parameters, x_priors, y_priors):
303304

304305

305306
def sample_or_load_trace(model, trace_file, force=False, iterations=ITERATIONS, walkers=WALKERS, tuning=TUNING,
306-
cores=1):
307+
cores=CORES):
307308
"""
308309
Run a Metropolis Hastings MCMC for the given model with a certain number iterations, burn in (tuning), and walkers.
309310
@@ -324,7 +325,7 @@ def sample_or_load_trace(model, trace_file, force=False, iterations=ITERATIONS,
324325
tuning : int, optional
325326
The number of iterations used for tuning.
326327
cores : int, optional
327-
The number of walkers to run in parallel. Default: 1.
328+
The number of walkers to run in parallel.
328329
329330
Returns
330331
-------
@@ -585,7 +586,7 @@ def select_event_data(t, phase_min=PHASE_MIN, phase_max=PHASE_MAX, nsigma=None):
585586

586587

587588
def two_iteration_mcmc(light_curve, outfile, filters=None, force=False, force_second=False, do_diagnostics=True,
588-
iterations=ITERATIONS, walkers=WALKERS, tuning=TUNING):
589+
iterations=ITERATIONS, walkers=WALKERS, tuning=TUNING, cores=CORES):
589590
"""
590591
Fit the model to the observed light curve. Then combine the posteriors for each filter and use that as the new prior
591592
for a second iteration of fitting.
@@ -611,6 +612,8 @@ def two_iteration_mcmc(light_curve, outfile, filters=None, force=False, force_se
611612
The number of cores and walkers used.
612613
tuning : int, optional
613614
The number of iterations used for tuning.
615+
cores : int, optional
616+
The number of walkers to run in parallel.
614617
615618
Returns
616619
-------
@@ -633,7 +636,7 @@ def two_iteration_mcmc(light_curve, outfile, filters=None, force=False, force_se
633636
obs = t[t['FLT'] == fltr]
634637
model1, parameters1 = setup_model1(obs, t['FLUXCAL'].max())
635638
outfile1 = outfile.format('_1' + fltr)
636-
trace1 = sample_or_load_trace(model1, outfile1, force, iterations, walkers, tuning)
639+
trace1 = sample_or_load_trace(model1, outfile1, force, iterations, walkers, tuning, cores)
637640
traces1[fltr] = trace1
638641
if do_diagnostics:
639642
diagnostics(obs, trace1, parameters1, outfile1)
@@ -648,7 +651,7 @@ def two_iteration_mcmc(light_curve, outfile, filters=None, force=False, force_se
648651
obs = t[t['FLT'] == fltr]
649652
model2, parameters2 = setup_model2(obs, parameters1, x_priors, y_priors)
650653
outfile2 = outfile.format('_2' + fltr)
651-
trace2 = sample_or_load_trace(model2, outfile2, force or force_second, iterations, walkers, tuning)
654+
trace2 = sample_or_load_trace(model2, outfile2, force or force_second, iterations, walkers, tuning, cores)
652655
traces2[fltr] = trace2
653656
if do_diagnostics:
654657
diagnostics(obs, trace2, parameters2, outfile2)
@@ -663,7 +666,7 @@ def _main():
663666
parser.add_argument('--iterations', type=int, default=ITERATIONS, help='Number of steps after burn-in')
664667
parser.add_argument('--tuning', type=int, default=TUNING, help='Number of burn-in steps')
665668
parser.add_argument('--walkers', type=int, default=WALKERS, help='Number of walkers')
666-
parser.add_argument('--cores', type=int, default=1, help='Number of walkers to run in parallel')
669+
parser.add_argument('--cores', type=int, default=CORES, help='Number of walkers to run in parallel')
667670
parser.add_argument('--output-dir', type=str, default='.', help='Path in which to save the PyMC3 trace data')
668671
parser.add_argument('--zmin', type=float, help='Do not fit the transient if redshift <= zmin in the header')
669672
parser.add_argument('-f', '--force', action='store_true', help='redo the fit even if the trace is already saved')

0 commit comments

Comments
 (0)