Add "connect to job" functionality, use that for CondorSpawner #200

Open · wants to merge 10 commits into base: main
141 changes: 141 additions & 0 deletions batchspawner/batchspawner.py
@@ -32,6 +32,7 @@
from jupyterhub.spawner import Spawner
from traitlets import Integer, List, Unicode, Float, Dict, default

from jupyterhub.utils import random_port
from jupyterhub.spawner import set_user_setuid


@@ -186,6 +187,20 @@ def _req_keepvars_default(self):
"specification.",
).tag(config=True)

    connect_to_job_cmd = Unicode('',
        help="Command to connect to the running batch job and forward the port "
             "of the running notebook to the Hub. If empty, direct connectivity is assumed. "
             "Uses self.job_id as {job_id}, self.port as {port} and self.ip as {host}. "
             "If {rport} is used in this string, it is set to self.port, "
             "and a new random self.port is chosen locally and used as {port}. "
             "This is useful e.g. for SSH port forwarding."
        ).tag(config=True)
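
    # For illustration only, a sketch of a possible site configuration, assuming a
    # generic execute host reachable via passwordless SSH (the SSH setup is an
    # assumption, not part of this change):
    #
    #   c.BatchSpawnerBase.connect_to_job_cmd = (
    #       "ssh -o ExitOnForwardFailure=yes -N -L {port}:localhost:{rport} {host}"
    #   )
    #
    # The CondorSpawner further down ships a concrete default based on
    # condor_ssh_to_job.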

    rport = Integer(0,
        help="Remote port of the notebook; will be set if it differs from self.port. "
             "This is set by connect_to_job() if needed."
        )

# Raw output of job submission command unless overridden
job_id = Unicode()

@@ -215,6 +230,26 @@ def cmd_formatted_for_batch(self):
"""The command which is substituted inside of the batch script"""
return " ".join([self.batchspawner_singleuser_cmd] + self.cmd + self.get_args())

    async def connect_to_job(self):
        """Ensure the port of the single-user server is reachable from the
        machine batchspawner runs on. Only called if connect_to_job_cmd is set.

        If the template string connect_to_job_cmd contains {rport},
        a new random self.port is chosen locally (useful e.g. for SSH port forwarding).
        """
subvars = self.get_req_subvars()
subvars['host'] = self.ip
subvars['job_id'] = self.job_id
if '{rport}' in self.connect_to_job_cmd:
self.rport = self.port
self.port = random_port()
subvars['rport'] = self.rport
subvars['port'] = self.port
else:
subvars['port'] = self.port
cmd = ' '.join((format_template(self.exec_prefix, **subvars),
format_template(self.connect_to_job_cmd, **subvars)))
await self.run_background_command(cmd)
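
    # Worked example with hypothetical values: if self.port is 8888 and the
    # template contains {rport}, connect_to_job() sets self.rport = 8888, draws a
    # fresh random local port (say 54321), and formats the command with
    # {port} = 54321 and {rport} = 8888, so the forwarder bridges 54321 -> 8888.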

async def run_command(self, cmd, input=None, env=None):
proc = await asyncio.create_subprocess_shell(
cmd,
@@ -268,6 +303,46 @@ async def run_command(self, cmd, input=None, env=None):
out = out.decode().strip()
return out

    # List of running background tasks, e.g. used by connect_to_job.
    # A traitlets List, so each spawner instance gets its own list instead of
    # all instances sharing one mutable class attribute.
    background_processes = List()

    async def _async_wait_process(self, sleep_time):
        """Helper coroutine which just sleeps asynchronously; used for delayed checks."""
        await asyncio.sleep(sleep_time)

    async def run_background_command(self, cmd, startup_check_delay=1, input=None, env=None):
        """Run the given command in the background, add it to background_processes,
        and fail if the command exits within startup_check_delay seconds."""
        background_process = asyncio.ensure_future(self.run_command(cmd, input, env))
        success_check_delay = asyncio.ensure_future(self._async_wait_process(startup_check_delay))

        # Run the actual command and the delayed success check concurrently.
        done, pending = await asyncio.wait([background_process, success_check_delay], return_when=asyncio.FIRST_COMPLETED)

        # If the success check task finished first, the command is still running: all is good.
        if success_check_delay in done:
            background_task = list(pending)[0]
            self.background_processes.append(background_task)
            return background_task

        # Otherwise, the command exited early: clean up and fail.
        self.log.error("Background command exited early: %s" % cmd)
        gather_pending = asyncio.gather(*pending)
        gather_pending.cancel()
        try:
            self.log.debug("Cancelling pending success check task...")
            await gather_pending
        except asyncio.CancelledError:
            self.log.debug("Cancel was successful.")

        # Retrieve a potential exception from the "done" task, then fail in any
        # case, since an early exit is an error even if the command exited cleanly.
        try:
            await asyncio.gather(*done)
        except Exception:
            self.log.debug("Retrieving exception from failed background task...")
        raise RuntimeError('{} failed!'.format(cmd))
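
    # Intended usage (a sketch; the one-second default delay is only a heuristic):
    # callers such as connect_to_job() simply await this method and rely on the
    # returned task staying alive, e.g.
    #
    #   task = await self.run_background_command(cmd)  # raises RuntimeError if
    #                                                  # cmd exits within startup_check_delay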

async def _get_batch_script(self, **subvars):
"""Format batch script from vars"""
# Could be overridden by subclasses, but mainly useful for testing
@@ -299,6 +374,27 @@ async def submit_batch_script(self):
self.job_id = ""
return self.job_id

    def background_tasks_ok(self):
        # Check background tasks; return False as soon as one has exited.
        if self.background_processes:
            self.log.debug('Checking background processes...')
            for background_process in self.background_processes:
                if background_process.done():
                    self.log.debug('Found a background process in state "done"...')
                    background_exception = None
                    try:
                        background_exception = background_process.exception()
                    except asyncio.CancelledError:
                        self.log.error('Background process was cancelled!')
                    if background_exception is not None:
                        self.log.error('Background process exited with an exception:')
                        self.log.error(background_exception)
                    self.log.error('At least one background process exited!')
                    return False
                else:
                    self.log.debug('Found a not-yet-done background process...')
            self.log.debug('All background processes still running.')
        return True

# Override if your batch system needs something more elaborate to query the job status
batch_query_cmd = Unicode(
"",
@@ -353,6 +449,29 @@ async def cancel_batch_job(self):
)
)
self.log.info("Cancelling job " + self.job_id + ": " + cmd)

        if self.background_processes:
            self.log.debug('Job being cancelled, cancelling background processes...')
            for background_process in self.background_processes:
                if not background_process.cancelled():
                    try:
                        background_process.cancel()
                    except Exception:
                        self.log.error('Encountered an exception cancelling background process...')
                    self.log.debug('Cancelled background process, waiting for it to finish...')
                    # Await the task directly: asyncio.wait() would swallow the
                    # task's CancelledError instead of propagating it.
                    try:
                        await background_process
                    except asyncio.CancelledError:
                        self.log.debug('Successfully cancelled background process.')
                    except Exception:
                        self.log.error('Background process exited with another exception!')
                        raise
                else:
                    self.log.debug('Background process already cancelled...')
            self.background_processes.clear()
            self.log.debug('All background processes cancelled.')

await self.run_command(cmd)

def load_state(self, state):
@@ -400,6 +519,13 @@ async def poll(self):
"""Poll the process"""
status = await self.query_job_status()
if status in (JobStatus.PENDING, JobStatus.RUNNING, JobStatus.UNKNOWN):
if not self.background_tasks_ok():
self.log.debug('Going to stop job, since background tasks have failed!')
await self.stop(now=True)
status = await self.query_job_status()
if status not in (JobStatus.PENDING, JobStatus.RUNNING, JobStatus.UNKNOWN):
self.clear_state()
return 1
return None
else:
self.clear_state()
@@ -459,6 +585,14 @@ async def start(self):
if hasattr(self, "mock_port"):
self.port = self.mock_port

if self.connect_to_job_cmd:
await self.connect_to_job()

        # Port and IP may have been changed by connect_to_job(); push the new values out to JupyterHub.
if self.server:
self.server.port = self.port
self.server.ip = self.ip

self.db.commit()
self.log.info(
"Notebook server job {0} started at {1}:{2}".format(
@@ -887,6 +1021,7 @@ class CondorSpawner(UserEnvMixin, BatchSpawnerRegexStates):
'condor_q {job_id} -format "%s, " JobStatus -format "%s" RemoteHost -format "\n" True'
).tag(config=True)
batch_cancel_cmd = Unicode("condor_rm {job_id}").tag(config=True)
connect_to_job_cmd = Unicode("condor_ssh_to_job -ssh \"ssh -L {port}:localhost:{rport} -oExitOnForwardFailure=yes\" {job_id}").tag(config=True)
# job status: 1 = pending, 2 = running
state_pending_re = Unicode(r"^1,").tag(config=True)
state_running_re = Unicode(r"^2,").tag(config=True)
@@ -909,6 +1044,12 @@ def cmd_formatted_for_batch(self):
.replace("'", "''")
)

    def state_gethost(self):
        """Return localhost if connect_to_job is used, since it forwards the
        single-user server port from the spawned job."""
        if self.connect_to_job_cmd:
            return "localhost"
        else:
            return super().state_gethost()

class LsfSpawner(BatchSpawnerBase):
"""A Spawner that uses IBM's Platform Load Sharing Facility (LSF) to launch notebooks."""
1 change: 1 addition & 0 deletions batchspawner/tests/test_spawners.py
@@ -560,6 +560,7 @@ def test_condor(db, io_loop):
"req_nprocs": "5",
"req_memory": "5678",
"req_options": "some_option_asdf",
"connect_to_job_cmd": "",
}
batch_script_re_list = [
re.compile(r"exec batchspawner-singleuser singleuser_command"),