diff --git a/batchspawner/batchspawner.py b/batchspawner/batchspawner.py index 866bf51..1adc21c 100644 --- a/batchspawner/batchspawner.py +++ b/batchspawner/batchspawner.py @@ -458,7 +458,6 @@ async def start(self): " but died before launching the single-user server." ) - self.db.commit() self.log.info( "Notebook server job {} started at {}:{}".format( self.job_id, self.ip, self.port diff --git a/batchspawner/tests/conftest.py b/batchspawner/tests/conftest.py index 0fe4f7e..ade777b 100644 --- a/batchspawner/tests/conftest.py +++ b/batchspawner/tests/conftest.py @@ -1,4 +1 @@ """Relevant pytest fixtures are re-used from JupyterHub's test suite""" - -# We use "db" directly, but we also need event_loop -from jupyterhub.tests.conftest import db, event_loop # noqa diff --git a/batchspawner/tests/test_spawners.py b/batchspawner/tests/test_spawners.py index e5d43c0..5fe0dee 100644 --- a/batchspawner/tests/test_spawners.py +++ b/batchspawner/tests/test_spawners.py @@ -4,13 +4,12 @@ import pwd import re import time +from contextlib import contextmanager from getpass import getuser from unittest import mock import pytest -from jupyterhub import orm from jupyterhub.objects import Hub, Server -from jupyterhub.user import User from traitlets import Unicode from .. import BatchSpawnerRegexStates, JobStatus @@ -20,12 +19,21 @@ testport = 54321 +@pytest.fixture +def user(): + mock_user = mock.Mock() + mock_user.name = mock_user.escaped_name = getuser() + mock_user.url = "" + + return mock_user + + @pytest.fixture(autouse=True) -def _always_get_my_home(): +def _always_get_my_home(user): # pwd.getbwnam() is always called with the current user # ignoring the requested name, which usually doesn't exist getpwnam = pwd.getpwnam - with mock.patch.object(pwd, "getpwnam", lambda name: getpwnam(getuser())): + with mock.patch.object(pwd, "getpwnam", lambda name: getpwnam(user.name)): yield @@ -67,168 +75,162 @@ async def run_command(self, *args, **kwargs): return out -def new_spawner(db, spawner_class=BatchDummy, **kwargs): +@contextmanager +def new_spawner(user, spawner_class=BatchDummy, **kwargs): kwargs.setdefault("cmd", ["singleuser_command"]) - user = db.query(orm.User).first() - hub = Hub() - user = User(user, {}) - server = Server() - # Set it after constructions because it isn't a traitlet. - kwargs.setdefault("hub", hub) - kwargs.setdefault("user", user) - kwargs.setdefault("poll_interval", 1) - - # These are not traitlets so we have to set them here - spawner = user._new_spawner("", spawner_class=spawner_class, **kwargs) - spawner.server = server - spawner.mock_port = testport - return spawner + + with mock.patch.object(spawner_class, "_find_my_config") as mock_find_config: + mock_find_config.return_value = {} + spawner = spawner_class(user=user, hub=Hub(), **kwargs) + spawner.server = Server() + spawner.mock_port = testport + yield spawner def check_ip(spawner, value): assert spawner.ip == value -async def test_spawner_start_stop_poll(db, event_loop): - spawner = new_spawner(db=db) - - status = await asyncio.wait_for(spawner.poll(), timeout=5) - assert status == 1 - assert spawner.job_id == "" - assert spawner.get_state() == {} +async def test_spawner_start_stop_poll(user): + with new_spawner(user=user) as spawner: + status = await asyncio.wait_for(spawner.poll(), timeout=5) + assert status == 1 + assert spawner.job_id == "" + assert spawner.get_state() == {} - await asyncio.wait_for(spawner.start(), timeout=5) - check_ip(spawner, testhost) - assert spawner.job_id == testjob + await asyncio.wait_for(spawner.start(), timeout=5) + check_ip(spawner, testhost) + assert spawner.job_id == testjob - status = await asyncio.wait_for(spawner.poll(), timeout=5) - assert status is None - spawner.batch_query_cmd = "echo NOPE" - await asyncio.wait_for(spawner.stop(), timeout=5) - status = await asyncio.wait_for(spawner.poll(), timeout=5) - assert status == 1 - assert spawner.get_state() == {} + status = await asyncio.wait_for(spawner.poll(), timeout=5) + assert status is None + spawner.batch_query_cmd = "echo NOPE" + await asyncio.wait_for(spawner.stop(), timeout=5) + status = await asyncio.wait_for(spawner.poll(), timeout=5) + assert status == 1 + assert spawner.get_state() == {} -async def test_stress_submit(db, event_loop): +async def test_stress_submit(user): for i in range(200): time.sleep(0.01) - test_spawner_start_stop_poll(db, event_loop) + test_spawner_start_stop_poll(user) -async def test_spawner_state_reload(db, event_loop): - spawner = new_spawner(db=db) - assert spawner.get_state() == {} +async def test_spawner_state_reload(user): + with new_spawner(user=user) as spawner: + assert spawner.get_state() == {} - await asyncio.wait_for(spawner.start(), timeout=30) - check_ip(spawner, testhost) - assert spawner.job_id == testjob + await asyncio.wait_for(spawner.start(), timeout=30) + check_ip(spawner, testhost) + assert spawner.job_id == testjob - state = spawner.get_state() - assert state == dict(job_id=testjob, job_status="RUN " + testhost) - spawner = new_spawner(db=db) - spawner.clear_state() - assert spawner.get_state() == {} - spawner.load_state(state) - # We used to check IP here, but that is actually only computed on start(), - # and is not part of the spawner's persistent state - assert spawner.job_id == testjob + state = spawner.get_state() + assert state == dict(job_id=testjob, job_status="RUN " + testhost) + with new_spawner(user=user) as spawner: + spawner.clear_state() + assert spawner.get_state() == {} + spawner.load_state(state) + # We used to check IP here, but that is actually only computed on start(), + # and is not part of the spawner's persistent state + assert spawner.job_id == testjob -async def test_submit_failure(db, event_loop): - spawner = new_spawner(db=db) - assert spawner.get_state() == {} - spawner.batch_submit_cmd = "cat > /dev/null; true" - with pytest.raises(RuntimeError): - await asyncio.wait_for(spawner.start(), timeout=30) - assert spawner.job_id == "" - assert spawner.job_status == "" +async def test_submit_failure(user): + with new_spawner(user=user) as spawner: + assert spawner.get_state() == {} + spawner.batch_submit_cmd = "cat > /dev/null; true" + with pytest.raises(RuntimeError): + await asyncio.wait_for(spawner.start(), timeout=30) + assert spawner.job_id == "" + assert spawner.job_status == "" -async def test_submit_pending_fails(db, event_loop): + +async def test_submit_pending_fails(user): """Submission works, but the batch query command immediately fails""" - spawner = new_spawner(db=db) - assert spawner.get_state() == {} - spawner.batch_query_cmd = "echo xyz" - with pytest.raises(RuntimeError): + with new_spawner(user=user) as spawner: + assert spawner.get_state() == {} + spawner.batch_query_cmd = "echo xyz" + with pytest.raises(RuntimeError): + await asyncio.wait_for(spawner.start(), timeout=30) + status = await asyncio.wait_for(spawner.query_job_status(), timeout=30) + assert status == JobStatus.NOTFOUND + assert spawner.job_id == "" + assert spawner.job_status == "" + + +async def test_poll_fails(user): + """Submission works, but a later .poll() fails""" + with new_spawner(user=user) as spawner: + assert spawner.get_state() == {} + # The start is successful: await asyncio.wait_for(spawner.start(), timeout=30) - status = await asyncio.wait_for(spawner.query_job_status(), timeout=30) - assert status == JobStatus.NOTFOUND - assert spawner.job_id == "" - assert spawner.job_status == "" + spawner.batch_query_cmd = "echo xyz" + # Now, the poll fails: + await asyncio.wait_for(spawner.poll(), timeout=30) + # .poll() will run self.clear_state() if it's not found: + assert spawner.job_id == "" + assert spawner.job_status == "" -async def test_poll_fails(db, event_loop): - """Submission works, but a later .poll() fails""" - spawner = new_spawner(db=db) - assert spawner.get_state() == {} - # The start is successful: - await asyncio.wait_for(spawner.start(), timeout=30) - spawner.batch_query_cmd = "echo xyz" - # Now, the poll fails: - await asyncio.wait_for(spawner.poll(), timeout=30) - # .poll() will run self.clear_state() if it's not found: - assert spawner.job_id == "" - assert spawner.job_status == "" - - -async def test_unknown_status(db, event_loop): +async def test_unknown_status(user): """Polling returns an unknown status""" - spawner = new_spawner(db=db) - assert spawner.get_state() == {} - # The start is successful: - await asyncio.wait_for(spawner.start(), timeout=30) - spawner.batch_query_cmd = "echo UNKNOWN" - # This poll should not fail: - await asyncio.wait_for(spawner.poll(), timeout=30) - status = await asyncio.wait_for(spawner.query_job_status(), timeout=30) - assert status == JobStatus.UNKNOWN - assert spawner.job_id == "12345" - assert spawner.job_status != "" - - -async def test_templates(db, event_loop): - """Test templates in the run_command commands""" - spawner = new_spawner(db=db) - - # Test when not running - spawner.cmd_expectlist = [ - re.compile(".*RUN"), - ] - status = await asyncio.wait_for(spawner.poll(), timeout=5) - assert status == 1 - assert spawner.job_id == "" - assert spawner.get_state() == {} - - # Test starting - spawner.cmd_expectlist = [ - re.compile(".*echo"), - re.compile(".*RUN"), - ] - await asyncio.wait_for(spawner.start(), timeout=5) - check_ip(spawner, testhost) - assert spawner.job_id == testjob + with new_spawner(user=user) as spawner: + assert spawner.get_state() == {} + # The start is successful: + await asyncio.wait_for(spawner.start(), timeout=30) + spawner.batch_query_cmd = "echo UNKNOWN" + # This poll should not fail: + await asyncio.wait_for(spawner.poll(), timeout=30) + status = await asyncio.wait_for(spawner.query_job_status(), timeout=30) + assert status == JobStatus.UNKNOWN + assert spawner.job_id == "12345" + assert spawner.job_status != "" - # Test poll - running - spawner.cmd_expectlist = [ - re.compile(".*RUN"), - ] - status = await asyncio.wait_for(spawner.poll(), timeout=5) - assert status is None - - # Test stopping - spawner.batch_query_cmd = "echo NOPE" - spawner.cmd_expectlist = [ - re.compile(".*STOP"), - re.compile(".*NOPE"), - ] - await asyncio.wait_for(spawner.stop(), timeout=5) - status = await asyncio.wait_for(spawner.poll(), timeout=5) - assert status == 1 - assert spawner.get_state() == {} +async def test_templates(user): + """Test templates in the run_command commands""" -async def test_batch_script(db, event_loop): + with new_spawner(user=user) as spawner: + # Test when not running + spawner.cmd_expectlist = [ + re.compile(".*RUN"), + ] + status = await asyncio.wait_for(spawner.poll(), timeout=5) + assert status == 1 + assert spawner.job_id == "" + assert spawner.get_state() == {} + + # Test starting + spawner.cmd_expectlist = [ + re.compile(".*echo"), + re.compile(".*RUN"), + ] + await asyncio.wait_for(spawner.start(), timeout=5) + check_ip(spawner, testhost) + assert spawner.job_id == testjob + + # Test poll - running + spawner.cmd_expectlist = [ + re.compile(".*RUN"), + ] + status = await asyncio.wait_for(spawner.poll(), timeout=5) + assert status is None + + # Test stopping + spawner.batch_query_cmd = "echo NOPE" + spawner.cmd_expectlist = [ + re.compile(".*STOP"), + re.compile(".*NOPE"), + ] + await asyncio.wait_for(spawner.stop(), timeout=5) + status = await asyncio.wait_for(spawner.poll(), timeout=5) + assert status == 1 + assert spawner.get_state() == {} + + +async def test_batch_script(user): """Test that the batch script substitutes {cmd}""" class BatchDummyTestScript(BatchDummy): @@ -237,14 +239,14 @@ async def _get_batch_script(self, **subvars): assert "singleuser_command" in script return script - spawner = new_spawner(db=db, spawner_class=BatchDummyTestScript) - # status = await asyncio.wait_for(spawner.poll(), timeout=5) - await asyncio.wait_for(spawner.start(), timeout=5) - # status = await asyncio.wait_for(spawner.poll(), timeout=5) - # await asyncio.wait_for(spawner.stop(), timeout=5) + with new_spawner(user=user, spawner_class=BatchDummyTestScript) as spawner: + # status = await asyncio.wait_for(spawner.poll(), timeout=5) + await asyncio.wait_for(spawner.start(), timeout=5) + # status = await asyncio.wait_for(spawner.poll(), timeout=5) + # await asyncio.wait_for(spawner.stop(), timeout=5) -async def test_exec_prefix(db, event_loop): +async def test_exec_prefix(user): """Test that all run_commands have exec_prefix""" class BatchDummyTestScript(BatchDummy): @@ -257,29 +259,29 @@ async def run_command(self, cmd, *args, **kwargs): out = await super().run_command(cmd, *args, **kwargs) return out - spawner = new_spawner(db=db, spawner_class=BatchDummyTestScript) - # Not running - status = await asyncio.wait_for(spawner.poll(), timeout=5) - assert status == 1 - # Start - await asyncio.wait_for(spawner.start(), timeout=5) - assert spawner.job_id == testjob - # Poll - status = await asyncio.wait_for(spawner.poll(), timeout=5) - assert status is None - # Stop - spawner.batch_query_cmd = "echo NOPE" - await asyncio.wait_for(spawner.stop(), timeout=5) - status = await asyncio.wait_for(spawner.poll(), timeout=5) - assert status == 1 + with new_spawner(user=user, spawner_class=BatchDummyTestScript) as spawner: + # Not running + status = await asyncio.wait_for(spawner.poll(), timeout=5) + assert status == 1 + # Start + await asyncio.wait_for(spawner.start(), timeout=5) + assert spawner.job_id == testjob + # Poll + status = await asyncio.wait_for(spawner.poll(), timeout=5) + assert status is None + # Stop + spawner.batch_query_cmd = "echo NOPE" + await asyncio.wait_for(spawner.stop(), timeout=5) + status = await asyncio.wait_for(spawner.poll(), timeout=5) + assert status == 1 async def run_spawner_script( - db, spawner, script, batch_script_re_list=None, spawner_kwargs={} + user, spawner, script, batch_script_re_list=None, spawner_kwargs={} ): """Run a spawner script and test that the output and behavior is as expected. - db: same as in this module + user: mock user spawner: the BatchSpawnerBase subclass to test script: list of (input_re_to_match, output) batch_script_re_list: if given, assert batch script matches all of these @@ -312,27 +314,29 @@ async def run_command(self, cmd, input=None, env=None): print(" --> " + out) return out - spawner = new_spawner(db=db, spawner_class=BatchDummyTestScript, **spawner_kwargs) - # Not running at beginning (no command run) - status = await asyncio.wait_for(spawner.poll(), timeout=5) - assert status == 1 - # batch_submit_cmd - # batch_query_cmd (result=pending) - # batch_query_cmd (result=running) - await asyncio.wait_for(spawner.start(), timeout=5) - assert spawner.job_id == testjob - check_ip(spawner, testhost) - # batch_query_cmd - status = await asyncio.wait_for(spawner.poll(), timeout=5) - assert status is None - # batch_cancel_cmd - await asyncio.wait_for(spawner.stop(), timeout=5) - # batch_poll_cmd - status = await asyncio.wait_for(spawner.poll(), timeout=5) - assert status == 1 - - -async def test_torque(db, event_loop): + with new_spawner( + user=user, spawner_class=BatchDummyTestScript, **spawner_kwargs + ) as spawner: + # Not running at beginning (no command run) + status = await asyncio.wait_for(spawner.poll(), timeout=5) + assert status == 1 + # batch_submit_cmd + # batch_query_cmd (result=pending) + # batch_query_cmd (result=running) + await asyncio.wait_for(spawner.start(), timeout=5) + assert spawner.job_id == testjob + check_ip(spawner, testhost) + # batch_query_cmd + status = await asyncio.wait_for(spawner.poll(), timeout=5) + assert status is None + # batch_cancel_cmd + await asyncio.wait_for(spawner.stop(), timeout=5) + # batch_poll_cmd + status = await asyncio.wait_for(spawner.poll(), timeout=5) + assert status == 1 + + +async def test_torque(user): spawner_kwargs = { "req_nprocs": "5", "req_memory": "5678", @@ -368,7 +372,7 @@ async def test_torque(db, event_loop): from .. import TorqueSpawner await run_spawner_script( - db, + user, TorqueSpawner, script, batch_script_re_list=batch_script_re_list, @@ -376,7 +380,7 @@ async def test_torque(db, event_loop): ) -async def test_moab(db, event_loop): +async def test_moab(user): spawner_kwargs = { "req_nprocs": "5", "req_memory": "5678", @@ -409,7 +413,7 @@ async def test_moab(db, event_loop): from .. import MoabSpawner await run_spawner_script( - db, + user, MoabSpawner, script, batch_script_re_list=batch_script_re_list, @@ -417,7 +421,7 @@ async def test_moab(db, event_loop): ) -async def test_pbs(db, event_loop): +async def test_pbs(user): spawner_kwargs = { "req_nprocs": "4", "req_memory": "10256", @@ -450,7 +454,7 @@ async def test_pbs(db, event_loop): from .. import PBSSpawner await run_spawner_script( - db, + user, PBSSpawner, script, batch_script_re_list=batch_script_re_list, @@ -458,7 +462,7 @@ async def test_pbs(db, event_loop): ) -async def test_slurm(db, event_loop): +async def test_slurm(user): spawner_kwargs = { "req_runtime": "3-05:10:10", "req_nprocs": "5", @@ -482,7 +486,7 @@ async def test_slurm(db, event_loop): from .. import SlurmSpawner await run_spawner_script( - db, + user, SlurmSpawner, normal_slurm_script, batch_script_re_list=batch_script_re_list, @@ -509,7 +513,7 @@ async def test_slurm(db, event_loop): async def run_typical_slurm_spawner( - db, + user, spawner=SlurmSpawner, script=normal_slurm_script, batch_script_re_list=None, @@ -521,7 +525,7 @@ async def run_typical_slurm_spawner( of batch scripts. """ return await run_spawner_script( - db, + user, spawner, script, batch_script_re_list=batch_script_re_list, @@ -529,7 +533,7 @@ async def run_typical_slurm_spawner( ) -# async def test_gridengine(db, event_loop): +# async def test_gridengine(user): # spawner_kwargs = { # 'req_options': 'some_option_asdf', # } @@ -546,12 +550,12 @@ async def run_typical_slurm_spawner( # (re.compile(r'sudo.*qstat'), ''), # ] # from .. import GridengineSpawner -# await run_spawner_script(db, GridengineSpawner, script, +# await run_spawner_script(user, GridengineSpawner, script, # batch_script_re_list=batch_script_re_list, # spawner_kwargs=spawner_kwargs) -async def test_condor(db, event_loop): +async def test_condor(user): spawner_kwargs = { "req_nprocs": "5", "req_memory": "5678", @@ -578,7 +582,7 @@ async def test_condor(db, event_loop): from .. import CondorSpawner await run_spawner_script( - db, + user, CondorSpawner, script, batch_script_re_list=batch_script_re_list, @@ -586,7 +590,7 @@ async def test_condor(db, event_loop): ) -async def test_lfs(db, event_loop): +async def test_lfs(user): spawner_kwargs = { "req_nprocs": "5", "req_memory": "5678", @@ -617,7 +621,7 @@ async def test_lfs(db, event_loop): from .. import LsfSpawner await run_spawner_script( - db, + user, LsfSpawner, script, batch_script_re_list=batch_script_re_list, @@ -625,7 +629,7 @@ async def test_lfs(db, event_loop): ) -async def test_keepvars(db, event_loop): +async def test_keepvars(user): # req_keepvars spawner_kwargs = { "req_keepvars": "ABCDE", @@ -634,7 +638,7 @@ async def test_keepvars(db, event_loop): re.compile(r"--export=ABCDE", re.X | re.M), ] await run_typical_slurm_spawner( - db, + user, spawner_kwargs=spawner_kwargs, batch_script_re_list=batch_script_re_list, ) @@ -648,13 +652,13 @@ async def test_keepvars(db, event_loop): re.compile(r"--export=ABCDE,XYZ", re.X | re.M), ] await run_typical_slurm_spawner( - db, + user, spawner_kwargs=spawner_kwargs, batch_script_re_list=batch_script_re_list, ) -async def test_early_stop(db, event_loop): +async def test_early_stop(user): script = [ (re.compile(r"sudo.*sbatch"), str(testjob)), (re.compile(r"sudo.*squeue"), "PENDING "), # pending @@ -667,4 +671,4 @@ async def test_early_stop(db, event_loop): (re.compile(r"sudo.*scancel"), "STOP"), ] with pytest.raises(RuntimeError, match="job has disappeared"): - await run_spawner_script(db, SlurmSpawner, script) + await run_spawner_script(user, SlurmSpawner, script)