Skip to content

Commit

Permalink
Add ability to register Dask WorkerPlugin
Browse files Browse the repository at this point in the history
  • Loading branch information
rosswhitfield committed Mar 10, 2022
1 parent 8fb6efa commit 3d7b3e1
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 8 deletions.
3 changes: 2 additions & 1 deletion doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,4 +233,5 @@
[u'UT-Battelle, LLC'], 1)
]

intersphinx_mapping = {'python': ('https://docs.python.org/3', None)}
intersphinx_mapping = {'python': ('https://docs.python.org/3', None),
'distributed': ('http://distributed.dask.org/en/stable', None)}
45 changes: 45 additions & 0 deletions doc/user_guides/dask.rst
Original file line number Diff line number Diff line change
Expand Up @@ -109,4 +109,49 @@ Your batch script should then look like:
ips.py --config=ips.conf --platform=platform.conf
Running with worker plugin
--------------------------

A :class:`~distributed.diagnostics.plugin.WorkerPlugin` can be
registered on the Dask workers by passing the `dask_worker_plugin`
option to
:meth:`~ipsframework.services.ServicesProxy.submit_tasks`.

Using a WorkerPlugin in combination with shifter allows you to do
things like copying files out of the `Temporary XFS
<https://docs.nersc.gov/development/shifter/how-to-use/#temporary-xfs-files-for-optimizing-io>`_
file system. An example of that is

.. code-block:: python
import os

from distributed.diagnostics.plugin import WorkerPlugin
class DaskWorkerPlugin(WorkerPlugin):
def __init__(self, tmp_dir, target_dir):
self.tmp_dir = tmp_dir
self.target_dir = target_dir
def teardown(self, worker):
os.system(f"cp {self.tmp_dir}/* {self.target_dir}")
class Worker(Component):
def step(self, timestamp=0.0):
cwd = self.services.get_working_dir()
self.services.create_task_pool('pool')
self.services.add_task('pool', 'task_1', 1, '/tmp/', 'executable')
worker_plugin = DaskWorkerPlugin('/tmp', cwd)
ret_val = self.services.submit_tasks('pool',
use_dask=True, use_shifter=True,
dask_worker_plugin=worker_plugin)
exit_status = self.services.get_finished_tasks('pool')
where the batch script has the temporary XFS filesystem mounted as

.. code-block:: bash
#SBATCH --volume="/global/cscratch1/sd/$USER/tmpfiles:/tmp:perNodeCache=size=1G"
18 changes: 12 additions & 6 deletions ipsframework/services.py
Original file line number Diff line number Diff line change
Expand Up @@ -1848,7 +1848,7 @@ def add_task(self, task_pool_name, task_name, nproc, working_dir,
*args, keywords=keywords)

def submit_tasks(self, task_pool_name, block=True, use_dask=False, dask_nodes=1,
dask_ppn=None, launch_interval=0.0, use_shifter=False):
dask_ppn=None, launch_interval=0.0, use_shifter=False, dask_worker_plugin=None):
"""
Launch all unfinished tasks in task pool *task_pool_name*. If *block* is ``True``,
return when all tasks have been launched. If *block* is ``False``, return when all
Expand All @@ -1860,7 +1860,7 @@ def submit_tasks(self, task_pool_name, block=True, use_dask=False, dask_nodes=1,
start_time = time.time()
self._send_monitor_event('IPS_TASK_POOL_BEGIN', 'task_pool = %s ' % task_pool_name)
task_pool: TaskPool = self.task_pools[task_pool_name]
retval = task_pool.submit_tasks(block, use_dask, dask_nodes, dask_ppn, launch_interval, use_shifter)
retval = task_pool.submit_tasks(block, use_dask, dask_nodes, dask_ppn, launch_interval, use_shifter, dask_worker_plugin)
elapsed_time = time.time() - start_time
self._send_monitor_event('IPS_TASK_POOL_END', 'task_pool = %s elapsed time = %.2f S' %
(task_pool_name, elapsed_time),
Expand Down Expand Up @@ -2066,7 +2066,7 @@ def add_task(self, task_name, nproc, working_dir, binary, *args, **keywords):
self.queued_tasks[task_name] = Task(task_name, nproc, working_dir, binary_fullpath, *args,
**keywords["keywords"])

def submit_dask_tasks(self, block=True, dask_nodes=1, dask_ppn=None, use_shifter=False):
def submit_dask_tasks(self, block=True, dask_nodes=1, dask_ppn=None, use_shifter=False, dask_worker_plugin=None):
"""Launch tasks in *queued_tasks* using dask.
:param block: Unused, this will always return after tasks are submitted
Expand All @@ -2077,6 +2077,8 @@ def submit_dask_tasks(self, block=True, dask_nodes=1, dask_ppn=None, use_shifter
:type dask_ppn: int
:param use_shifter: Option to launch dask scheduler and workers in shifter container
:type use_shifter: bool
:param dask_worker_plugin: If provided this will be registered as a worker plugin with the dask client
:type dask_worker_plugin: distributed.diagnostics.plugin.WorkerPlugin
"""
services: ServicesProxy = self.services
self.dask_file_name = os.path.join(os.getcwd(),
Expand Down Expand Up @@ -2115,6 +2117,9 @@ def submit_dask_tasks(self, block=True, dask_nodes=1, dask_ppn=None, use_shifter

self.dask_client = self.dask.distributed.Client(scheduler_file=self.dask_file_name)

if dask_worker_plugin is not None:
self.dask_client.register_worker_plugin(dask_worker_plugin)

try:
self.worker_event_logfile = services.sim_name + '_' + services.get_config_param("PORTAL_RUNID") + '_' + self.name + '_{}.json'
except KeyError:
Expand All @@ -2135,7 +2140,7 @@ def submit_dask_tasks(self, block=True, dask_nodes=1, dask_ppn=None, use_shifter
self.queued_tasks = {}
return len(self.futures)

def submit_tasks(self, block=True, use_dask=False, dask_nodes=1, dask_ppn=None, launch_interval=0.0, use_shifter=False):
def submit_tasks(self, block=True, use_dask=False, dask_nodes=1, dask_ppn=None, launch_interval=0.0, use_shifter=False, dask_worker_plugin=None):
"""Launch tasks in *queued_tasks*. Finished tasks are handled before
launching new ones. If *block* is ``True``, the number of
tasks submitted is returned after all tasks have been launched
Expand All @@ -2157,7 +2162,8 @@ def submit_tasks(self, block=True, use_dask=False, dask_nodes=1, dask_ppn=None,
:type launch_interval: float
:param use_shifter: Option to launch dask scheduler and workers in shifter container
:type use_shifter: bool
:param dask_worker_plugin: If provided this will be registered as a worker plugin with the dask client
:type dask_worker_plugin: distributed.diagnostics.plugin.WorkerPlugin
"""

if use_dask:
Expand All @@ -2167,7 +2173,7 @@ def submit_tasks(self, block=True, use_dask=False, dask_nodes=1, dask_ppn=None,
self.services.error("Requested to run dask within shifter but shifter not available")
raise Exception("shifter not found")
else:
return self.submit_dask_tasks(block, dask_nodes, dask_ppn, use_shifter)
return self.submit_dask_tasks(block, dask_nodes, dask_ppn, use_shifter, dask_worker_plugin)
elif not TaskPool.dask:
self.services.warning("Requested use_dask but cannot because import dask failed")
elif not self.serial_pool:
Expand Down
14 changes: 13 additions & 1 deletion tests/helloworld/hello_worker_task_pool_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# -------------------------------------------------------------------------------
from time import sleep
import copy
from distributed.diagnostics.plugin import WorkerPlugin
from ipsframework import Component


Expand All @@ -12,6 +13,14 @@ def myFun(*args):
return 0


class DaskWorkerPlugin(WorkerPlugin):
    """Minimal worker plugin that announces its lifecycle hooks.

    The test suite checks that both the setup and teardown messages
    appear in the captured output, proving the plugin was registered
    with the Dask client and its hooks ran on the worker.
    """

    def setup(self, worker):
        # Called once when the plugin is attached to a worker.
        self._announce("setup")

    def teardown(self, worker):
        # Called when the worker is shut down.
        self._announce("teardown")

    def _announce(self, phase):
        # Single formatting point for both lifecycle messages.
        print(f"Running {phase} of worker")


class HelloWorker(Component):
def __init__(self, services, config):
super().__init__(services, config)
Expand All @@ -32,7 +41,10 @@ def step(self, timestamp=0.0, **keywords):
self.services.add_task('pool', 'func_' + str(i), 1,
cwd, myFun, duration)

ret_val = self.services.submit_tasks('pool', use_dask=True, dask_nodes=1, dask_ppn=10)
worker_plugin = DaskWorkerPlugin()

ret_val = self.services.submit_tasks('pool', use_dask=True, dask_nodes=1, dask_ppn=10,
dask_worker_plugin=worker_plugin)
print('ret_val = ', ret_val)
exit_status = self.services.get_finished_tasks('pool')
print(exit_status)
Expand Down
4 changes: 4 additions & 0 deletions tests/helloworld/test_helloworld.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,10 @@ def test_helloworld_task_pool_dask(tmpdir, capfd):
assert captured_out[4] == 'HelloDriver: finished worker init call'
assert captured_out[5] == 'HelloDriver: beginning step call'
assert captured_out[6] == 'Hello from HelloWorker'

assert "Running setup of worker" in captured_out
assert "Running teardown of worker" in captured_out

assert 'ret_val = 9' in captured_out

for duration in ("0.2", "0.4", "0.6"):
Expand Down

0 comments on commit 3d7b3e1

Please sign in to comment.