diff --git a/changelog.txt b/changelog.txt index d11f2857f..ba206ee17 100644 --- a/changelog.txt +++ b/changelog.txt @@ -14,6 +14,7 @@ Version 0.26 Added +++++ +- Added feature to hide the progress bar when printing the project status (#685). - ``test-workflow`` CLI option for testing template environments/submission scripts (#747). - Frontier environment and template (#743). - Added ``-o`` / ``--operation`` flag to report project status information for specific operations (#725). diff --git a/flow/project.py b/flow/project.py index bcb701aed..c16f97a0d 100644 --- a/flow/project.py +++ b/flow/project.py @@ -2602,6 +2602,7 @@ def _fetch_status( err, ignore_errors, status_parallelization="none", + hide_progress=False, names=None, ): """Fetch status for the provided aggregates / jobs. @@ -2617,6 +2618,8 @@ def _fetch_status( status_parallelization : str Parallelization mode for fetching the status. Allowed values are "thread", "process", or "none". (Default value = "none") + hide_progress : bool + Hide the progress bar when printing status output (Default value = False). names : iterable of :class:`str` Only show status for operations that match the provided set of names (interpreted as regular expressions), or all if the argument is @@ -2653,7 +2656,9 @@ def _fetch_status( "Valid choices are 'thread', 'process', or 'none'." 
) - parallel_executor = _get_parallel_executor(status_parallelization) + parallel_executor = _get_parallel_executor( + status_parallelization, hide_progress + ) # Update the project's status cache scheduler_info = self._query_scheduler_status( @@ -2750,11 +2755,13 @@ def compute_status(data): self._get_job_labels, ignore_errors=ignore_errors, ) - job_labels = parallel_executor( - compute_labels, - individual_jobs, - desc="Fetching labels", - file=err, + job_labels = list( + parallel_executor( + compute_labels, + individual_jobs, + desc="Fetching labels", + file=err, + ) ) def combine_group_and_operation_status(aggregate_status_results): @@ -2795,7 +2802,6 @@ def combine_group_and_operation_status(aggregate_status_results): "_error": error_message, } ) - return status_results_combined, job_labels, individual_jobs PRINT_STATUS_ALL_VARYING_PARAMETERS = True @@ -2824,6 +2830,7 @@ def print_status( profile=False, eligible_jobs_max_lines=None, output_format="terminal", + hide_progress=False, operation=None, ): """Print the status of the project. @@ -2875,6 +2882,8 @@ def print_status( output_format : str Status output format, supports: 'terminal' (default), 'markdown' or 'html'. + hide_progress : bool + Hide the progress bar from the status output. (Default value = False) operation : iterable of :class:`str` Show status of operations that match the provided set of names (interpreted as regular expressions), or all if the argument is @@ -2923,6 +2932,7 @@ def print_status( err=err, ignore_errors=ignore_errors, status_parallelization=status_parallelization, + hide_progress=hide_progress, names=operation, ) @@ -3003,6 +3013,7 @@ def print_status( err=err, ignore_errors=ignore_errors, status_parallelization=status_parallelization, + hide_progress=hide_progress, names=operation, ) profiling_results = None @@ -3557,7 +3568,7 @@ def run( will not exceed this argument. The default is 1, there is no limit if this argument is None. 
progress : bool - Show a progress bar during execution. (Default value = False) + Show a progress bar during execution (Default value = False). order : str, callable, or None Specify the order of operations. Possible values are: @@ -5001,6 +5012,11 @@ class MyProject(FlowProject): "to show result for. Defaults to the main module. " "(requires pprofile)", ) + parser_status.add_argument( + "--hide-progress", + action="store_true", + help="Hide the progress bar", + ) parser_status.add_argument( "-o", "--operation", diff --git a/flow/util/misc.py b/flow/util/misc.py index f13ac0927..7d93c9da8 100644 --- a/flow/util/misc.py +++ b/flow/util/misc.py @@ -7,6 +7,7 @@ import os import warnings from collections.abc import MutableMapping +from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor from contextlib import contextmanager from functools import lru_cache, partial from itertools import cycle, islice @@ -335,7 +336,7 @@ def _run_cloudpickled_func(func, *args): return unpickled_func(*args) -def _get_parallel_executor(parallelization="none"): +def _get_parallel_executor(parallelization="none", hide_progress=False): """Get an executor for the desired parallelization strategy. This executor shows a progress bar while executing a function over an @@ -346,50 +347,60 @@ def _get_parallel_executor(parallelization="none"): (see :meth:`concurrent.futures.Executor.map`). All other ``**kwargs`` are passed to the tqdm progress bar. + Warning + ------- + We ignore keyword arguments when ``hide_progress == True``. + Parameters ---------- parallelization : str Parallelization mode. Allowed values are "thread", "process", or "none". (Default value = "none") + hide_progress : bool + Hide the progress bar when printing status output (Default value = False). Returns ------- callable - A callable with signature ``func, iterable, **kwargs``. + A callable with signature ``func, iterable, **kwargs`` which returns an iterator. 
""" - if parallelization == "thread": + if parallelization == "process": + executor = ProcessPoolExecutor().map + if not hide_progress: + executor = partial(process_map, tqdm_class=tqdm) def parallel_executor(func, iterable, **kwargs): - return thread_map(func, iterable, tqdm_class=tqdm, **kwargs) - - elif parallelization == "process": + # The top-level function called on each process cannot be a local function, it must be a + # module-level function. Creating a partial here allows us to use the passed function + # "func" regardless of whether it is a local function. + func = partial(_run_cloudpickled_func, cloudpickle.dumps(func)) + # The tqdm progress bar requires a total. We compute the total in advance because a map + # iterable (which has no total) is passed to process_map. + kwargs.setdefault("total", len(iterable)) + iterable = map(cloudpickle.dumps, iterable) + if hide_progress: + return executor(func, iterable) + return executor(func, iterable, **kwargs) + + elif parallelization == "thread": + executor = ThreadPoolExecutor().map + if not hide_progress: + executor = partial(thread_map, tqdm_class=tqdm) def parallel_executor(func, iterable, **kwargs): - # The tqdm progress bar requires a total. We compute the total in - # advance because a map iterable (which has no total) is passed to - # process_map. - if "total" not in kwargs: - kwargs["total"] = len(iterable) - - return process_map( - # The top-level function called on each process cannot be a - # local function, it must be a module-level function. Creating - # a partial here allows us to use the passed function "func" - # regardless of whether it is a local function. 
- partial(_run_cloudpickled_func, cloudpickle.dumps(func)), - map(cloudpickle.dumps, iterable), - tqdm_class=tqdm, - **kwargs, - ) + if hide_progress: + return executor(func, iterable) + return executor(func, iterable, **kwargs) else: + executor = map if hide_progress else partial(tmap, tqdm_class=tqdm) def parallel_executor(func, iterable, **kwargs): - if "chunksize" in kwargs: - # Chunk size only applies to thread/process parallel executors - del kwargs["chunksize"] - return list(tmap(func, iterable, tqdm_class=tqdm, **kwargs)) + if hide_progress: + return executor(func, iterable) + kwargs.pop("chunksize", None) + return executor(func, iterable, **kwargs) return parallel_executor diff --git a/tests/test_status.py b/tests/test_status.py index 1669ce907..dbfb2edeb 100644 --- a/tests/test_status.py +++ b/tests/test_status.py @@ -6,9 +6,35 @@ import sys import generate_status_reference_data as gen +import pytest import signac +@pytest.fixture(params=[True, False]) +def hide_progress_bar(request): + return request.param + + +@pytest.fixture(params=["thread", "process", "none"]) +def parallelization(request): + return request.param + + +def test_hide_progress_bar(hide_progress_bar, parallelization): + with signac.TemporaryProject() as p, signac.TemporaryProject() as status_pr: + gen.init(p) + fp = gen._TestProject.get_project(path=p.path) + fp._flow_config["status_parallelization"] = parallelization + status_pr.import_from(origin=gen.ARCHIVE_PATH) + for job in status_pr: + kwargs = job.statepoint() + tmp_err = io.TextIOWrapper(io.BytesIO(), sys.stderr.encoding) + fp.print_status(**kwargs, err=tmp_err, hide_progress=hide_progress_bar) + tmp_err.seek(0) + generated_tqdm = tmp_err.read() + assert ("Fetching status" not in generated_tqdm) == hide_progress_bar + + def test_print_status(): # Must import the data into the project. with signac.TemporaryProject() as p, signac.TemporaryProject() as status_pr: