diff --git a/doc/conf.py b/doc/conf.py index 7b07e352..acd85174 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -233,4 +233,5 @@ [u'UT-Battelle, LLC'], 1) ] -intersphinx_mapping = {'python': ('https://docs.python.org/3', None)} +intersphinx_mapping = {'python': ('https://docs.python.org/3', None), + 'distributed': ('http://distributed.dask.org/en/stable', None)} diff --git a/doc/user_guides/dask.rst b/doc/user_guides/dask.rst index 1e527c9e..7c3c5dda 100644 --- a/doc/user_guides/dask.rst +++ b/doc/user_guides/dask.rst @@ -109,4 +109,49 @@ You batch script should then look like: ips.py --config=ips.conf --platform=platform.conf +Running with worker plugin +-------------------------- +There is the ability to set a +:class:`~distributed.diagnostics.plugin.WorkerPlugin` on the dask +worker using the `dask_worker_plugin` option in +:meth:`~ipsframework.services.ServicesProxy.submit_tasks`. + +Using a WorkerPlugin in combination with shifter allows you to do +things like copying files out of the `Temporary XFS +`_ +file system. An example of that is + +.. code-block:: python + + from distributed.diagnostics.plugin import WorkerPlugin + + class DaskWorkerPlugin(WorkerPlugin): + def __init__(self, tmp_dir, target_dir): + self.tmp_dir = tmp_dir + self.target_dir = target_dir + + def teardown(self, worker): + os.system(f"cp {self.tmp_dir}/* {self.target_dir}") + + class Worker(Component): + def step(self, timestamp=0.0): + cwd = self.services.get_working_dir() + + self.services.create_task_pool('pool') + self.services.add_task('pool', 'task_1', 1, '/tmp/', 'executable') + + worker_plugin = DaskWorkerPlugin('/tmp', cwd) + + ret_val = self.services.submit_tasks('pool', + use_dask=True, use_shifter=True, + dask_worker_plugin=worker_plugin) + + exit_status = self.services.get_finished_tasks('pool') + + +where the batch script has the temporary XFS filesystem mounted as + +.. 
code-block:: bash + + #SBATCH --volume="/global/cscratch1/sd/$USER/tmpfiles:/tmp:perNodeCache=size=1G" diff --git a/ipsframework/services.py b/ipsframework/services.py index 9d5241a8..ed6022ef 100644 --- a/ipsframework/services.py +++ b/ipsframework/services.py @@ -1848,7 +1848,7 @@ def add_task(self, task_pool_name, task_name, nproc, working_dir, *args, keywords=keywords) def submit_tasks(self, task_pool_name, block=True, use_dask=False, dask_nodes=1, - dask_ppn=None, launch_interval=0.0, use_shifter=False): + dask_ppn=None, launch_interval=0.0, use_shifter=False, dask_worker_plugin=None): """ Launch all unfinished tasks in task pool *task_pool_name*. If *block* is ``True``, return when all tasks have been launched. If *block* is ``False``, return when all @@ -1860,7 +1860,7 @@ def submit_tasks(self, task_pool_name, block=True, use_dask=False, dask_nodes=1, start_time = time.time() self._send_monitor_event('IPS_TASK_POOL_BEGIN', 'task_pool = %s ' % task_pool_name) task_pool: TaskPool = self.task_pools[task_pool_name] - retval = task_pool.submit_tasks(block, use_dask, dask_nodes, dask_ppn, launch_interval, use_shifter) + retval = task_pool.submit_tasks(block, use_dask, dask_nodes, dask_ppn, launch_interval, use_shifter, dask_worker_plugin) elapsed_time = time.time() - start_time self._send_monitor_event('IPS_TASK_POOL_END', 'task_pool = %s elapsed time = %.2f S' % (task_pool_name, elapsed_time), @@ -2066,7 +2066,7 @@ def add_task(self, task_name, nproc, working_dir, binary, *args, **keywords): self.queued_tasks[task_name] = Task(task_name, nproc, working_dir, binary_fullpath, *args, **keywords["keywords"]) - def submit_dask_tasks(self, block=True, dask_nodes=1, dask_ppn=None, use_shifter=False): + def submit_dask_tasks(self, block=True, dask_nodes=1, dask_ppn=None, use_shifter=False, dask_worker_plugin=None): """Launch tasks in *queued_tasks* using dask. 
:param block: Unused, this will always return after tasks are submitted @@ -2077,6 +2077,8 @@ def submit_dask_tasks(self, block=True, dask_nodes=1, dask_ppn=None, use_shifter :type dask_ppn: int :param use_shifter: Option to launch dask scheduler and workers in shifter container :type use_shifter: bool + :param dask_worker_plugin: If provided this will be registered as a worker plugin with the dask client + :type dask_worker_plugin: distributed.diagnostics.plugin.WorkerPlugin """ services: ServicesProxy = self.services self.dask_file_name = os.path.join(os.getcwd(), @@ -2115,6 +2117,9 @@ def submit_dask_tasks(self, block=True, dask_nodes=1, dask_ppn=None, use_shifter self.dask_client = self.dask.distributed.Client(scheduler_file=self.dask_file_name) + if dask_worker_plugin is not None: + self.dask_client.register_worker_plugin(dask_worker_plugin) + try: self.worker_event_logfile = services.sim_name + '_' + services.get_config_param("PORTAL_RUNID") + '_' + self.name + '_{}.json' except KeyError: @@ -2135,7 +2140,7 @@ def submit_dask_tasks(self, block=True, dask_nodes=1, dask_ppn=None, use_shifter self.queued_tasks = {} return len(self.futures) - def submit_tasks(self, block=True, use_dask=False, dask_nodes=1, dask_ppn=None, launch_interval=0.0, use_shifter=False): + def submit_tasks(self, block=True, use_dask=False, dask_nodes=1, dask_ppn=None, launch_interval=0.0, use_shifter=False, dask_worker_plugin=None): """Launch tasks in *queued_tasks*. Finished tasks are handled before launching new ones. 
If *block* is ``True``, the number of tasks submitted is returned after all tasks have been launched @@ -2157,7 +2162,8 @@ def submit_tasks(self, block=True, use_dask=False, dask_nodes=1, dask_ppn=None, :type launch_internal: float :param use_shifter: Option to launch dask scheduler and workers in shifter container :type use_shifter: bool - + :param dask_worker_plugin: If provided this will be registered as a worker plugin with the dask client + :type dask_worker_plugin: distributed.diagnostics.plugin.WorkerPlugin """ if use_dask: @@ -2167,7 +2173,7 @@ def submit_tasks(self, block=True, use_dask=False, dask_nodes=1, dask_ppn=None, self.services.error("Requested to run dask within shifter but shifter not available") raise Exception("shifter not found") else: - return self.submit_dask_tasks(block, dask_nodes, dask_ppn, use_shifter) + return self.submit_dask_tasks(block, dask_nodes, dask_ppn, use_shifter, dask_worker_plugin) elif not TaskPool.dask: self.services.warning("Requested use_dask but cannot because import dask failed") elif not self.serial_pool: diff --git a/tests/helloworld/hello_worker_task_pool_dask.py b/tests/helloworld/hello_worker_task_pool_dask.py index 3d6326e9..a64e1024 100644 --- a/tests/helloworld/hello_worker_task_pool_dask.py +++ b/tests/helloworld/hello_worker_task_pool_dask.py @@ -3,6 +3,7 @@ # ------------------------------------------------------------------------------- from time import sleep import copy +from distributed.diagnostics.plugin import WorkerPlugin from ipsframework import Component @@ -12,6 +13,14 @@ def myFun(*args): return 0 +class DaskWorkerPlugin(WorkerPlugin): + def setup(self, worker): + print("Running setup of worker") + + def teardown(self, worker): + print("Running teardown of worker") + + class HelloWorker(Component): def __init__(self, services, config): super().__init__(services, config) @@ -32,7 +41,10 @@ def step(self, timestamp=0.0, **keywords): self.services.add_task('pool', 'func_' + str(i), 1, cwd, myFun, 
duration) - ret_val = self.services.submit_tasks('pool', use_dask=True, dask_nodes=1, dask_ppn=10) + worker_plugin = DaskWorkerPlugin() + + ret_val = self.services.submit_tasks('pool', use_dask=True, dask_nodes=1, dask_ppn=10, + dask_worker_plugin=worker_plugin) print('ret_val = ', ret_val) exit_status = self.services.get_finished_tasks('pool') print(exit_status) diff --git a/tests/helloworld/test_helloworld.py b/tests/helloworld/test_helloworld.py index 09113b3b..f2e17b4e 100644 --- a/tests/helloworld/test_helloworld.py +++ b/tests/helloworld/test_helloworld.py @@ -245,6 +245,10 @@ def test_helloworld_task_pool_dask(tmpdir, capfd): assert captured_out[4] == 'HelloDriver: finished worker init call' assert captured_out[5] == 'HelloDriver: beginning step call' assert captured_out[6] == 'Hello from HelloWorker' + + assert "Running setup of worker" in captured_out + assert "Running teardown of worker" in captured_out + assert 'ret_val = 9' in captured_out for duration in ("0.2", "0.4", "0.6"):