diff --git a/08-snakemake/profiles/slurm/CookieCutter.py b/08-snakemake/profiles/slurm/CookieCutter.py new file mode 100644 index 0000000..cfa9bb8 --- /dev/null +++ b/08-snakemake/profiles/slurm/CookieCutter.py @@ -0,0 +1,39 @@ +# +# Based on lsf CookieCutter.py +# +import os +import json + +d = os.path.dirname(__file__) +with open(os.path.join(d, "settings.json")) as fh: + settings = json.load(fh) + + +def from_entry_or_env(values, key): + """Return value from ``values`` and override with environment variables.""" + if key in os.environ: + return os.environ[key] + else: + return values[key] + + +class CookieCutter: + + SBATCH_DEFAULTS = from_entry_or_env(settings, "SBATCH_DEFAULTS") + CLUSTER_NAME = from_entry_or_env(settings, "CLUSTER_NAME") + CLUSTER_CONFIG = from_entry_or_env(settings, "CLUSTER_CONFIG") + + @staticmethod + def get_cluster_option() -> str: + cluster = CookieCutter.CLUSTER_NAME + if cluster != "": + return f"--cluster={cluster}" + return "" + + @staticmethod + def get_cluster_logpath() -> str: + return "logs/slurm/%r/%j" + + @staticmethod + def get_cluster_jobname() -> str: + return "%r_%w" diff --git a/08-snakemake/profiles/slurm/__pycache__/CookieCutter.cpython-311.pyc b/08-snakemake/profiles/slurm/__pycache__/CookieCutter.cpython-311.pyc new file mode 100644 index 0000000..59a2d99 Binary files /dev/null and b/08-snakemake/profiles/slurm/__pycache__/CookieCutter.cpython-311.pyc differ diff --git a/08-snakemake/profiles/slurm/__pycache__/slurm_utils.cpython-311.pyc b/08-snakemake/profiles/slurm/__pycache__/slurm_utils.cpython-311.pyc new file mode 100644 index 0000000..9b4aa84 Binary files /dev/null and b/08-snakemake/profiles/slurm/__pycache__/slurm_utils.cpython-311.pyc differ diff --git a/08-snakemake/profiles/slurm/config.yaml b/08-snakemake/profiles/slurm/config.yaml new file mode 100644 index 0000000..a21ef0e --- /dev/null +++ b/08-snakemake/profiles/slurm/config.yaml @@ -0,0 +1,29 @@ + +cluster-sidecar: "slurm-sidecar.py" +cluster-cancel: "scancel" +restart-times: "2" +jobscript: "slurm-jobscript.sh" +cluster: "slurm-submit.py" +cluster-status: "slurm-status.py" +max-jobs-per-second: "10" +max-status-checks-per-second: "10" +local-cores: 1 +latency-wait: "5" +use-conda: "False" +use-singularity: "False" +jobs: "500" +printshellcmds: "True" + +# Example resource configuration +# default-resources: +# - runtime=100 +# - mem_mb=6000 +# - disk_mb=1000000 +# # set-threads: map rule names to threads +# set-threads: +# - single_core_rule=1 +# - multi_core_rule=10 +# # set-resources: map rule names to resources in general +# set-resources: +# - high_memory_rule:mem_mb=12000 +# - long_running_rule:runtime=1200 diff --git a/08-snakemake/profiles/slurm/settings.json b/08-snakemake/profiles/slurm/settings.json new file mode 100644 index 0000000..4402649 --- /dev/null +++ b/08-snakemake/profiles/slurm/settings.json @@ -0,0 +1,5 @@ +{ + "SBATCH_DEFAULTS": "", + "CLUSTER_NAME": "eri", + "CLUSTER_CONFIG": "" +} diff --git a/08-snakemake/profiles/slurm/slurm-jobscript.sh b/08-snakemake/profiles/slurm/slurm-jobscript.sh new file mode 100755 index 0000000..391741e --- /dev/null +++ b/08-snakemake/profiles/slurm/slurm-jobscript.sh @@ -0,0 +1,3 @@ +#!/bin/bash +# properties = {properties} +{exec_job} diff --git a/08-snakemake/profiles/slurm/slurm-sidecar.py b/08-snakemake/profiles/slurm/slurm-sidecar.py new file mode 100755 index 0000000..e79f5da --- /dev/null +++ b/08-snakemake/profiles/slurm/slurm-sidecar.py @@ -0,0 +1,330 @@ +#!/usr/bin/env python3 +"""Run a Snakemake v7+ sidecar 
process for Slurm
+
+This sidecar process will poll ``squeue --user [user] --format='%i,%T'``
+every 60 seconds by default (use the environment variable
+``SNAKEMAKE_SLURM_SQUEUE_WAIT`` to adjust this).
+
+Note that you may have to adjust this value to fit your ``MinJobAge`` Slurm
+configuration. Jobs remain known to the Slurm controller for at least
+``MinJobAge`` seconds (300 seconds by default). If you query ``squeue`` every
+60 seconds, this is plenty of time to observe all job status changes that
+are relevant for Snakemake.
+
+If the environment variable ``SNAKEMAKE_CLUSTER_SIDECAR_VARS`` is set then
+the ``slurm-status.py`` script of the slurm profile will attempt to query this
+sidecar process via HTTP. As the sidecar process does not update its
+cache in real time, setting ``SNAKEMAKE_SLURM_SQUEUE_WAIT`` too large might
+lead to Snakemake missing the "done" job state. The defaults of
+``SNAKEMAKE_SLURM_SQUEUE_WAIT=60`` and Slurm's ``MinJobAge=600`` work well
+together and you will see all relevant job statuses.
+
+If the sidecar is queried for a job ID that it has not seen yet, it will
+fall back to querying ``sacct``, so it also works with Snakemake's "resume
+external job" feature. The ``slurm-submit.py`` script of the Snakemake profile
+registers all jobs with this sidecar via POST.
+"""
+
+import http.server
+import json
+import logging
+import os
+import subprocess
+import sys
+import signal
+import time
+import threading
+import uuid
+
+from CookieCutter import CookieCutter
+
+
+#: Enables debug messages for slurm sidecar.
+DEBUG = bool(int(os.environ.get("SNAKEMAKE_SLURM_DEBUG", "0")))
+#: Enables HTTP request logging in sidecar.
+LOG_REQUESTS = bool(int(os.environ.get("SNAKEMAKE_SLURM_LOG_REQUESTS", "0")))
+#: Command to call when calling squeue
+SQUEUE_CMD = os.environ.get("SNAKEMAKE_SLURM_SQUEUE_CMD", "squeue")
+#: Number of seconds to wait between ``squeue`` calls.
+SQUEUE_WAIT = int(os.environ.get("SNAKEMAKE_SLURM_SQUEUE_WAIT", "60"))
+
+logger = logging.getLogger(__name__)
+if DEBUG:
+    logging.basicConfig(level=logging.DEBUG)
+    logger.setLevel(logging.DEBUG)
+
+
+class PollSqueueThread(threading.Thread):
+    """Thread that polls ``squeue`` until stopped by ``stop()``"""
+
+    def __init__(
+        self,
+        squeue_wait,
+        squeue_cmd,
+        squeue_timeout=2,
+        sleep_time=0.01,
+        max_tries=3,
+        *args,
+        **kwargs
+    ):
+        super().__init__(target=self._work, *args, **kwargs)
+        #: Time to wait between squeue calls.
+        self.squeue_wait = squeue_wait
+        #: Command to call squeue with.
+        self.squeue_cmd = squeue_cmd
+        #: Whether or not the thread should stop.
+        self.stopped = threading.Event()
+        #: Previous call to ``squeue``
+        self.prev_call = 0.0
+        #: Time to sleep between iterations in seconds. Thread can only be
+        #: terminated after this interval when waiting.
+        self.sleep_time = sleep_time
+        #: Maximal running time to accept for call to ``squeue``.
+        self.squeue_timeout = squeue_timeout
+        #: Maximal number of tries if call to ``squeue`` fails.
+        self.max_tries = max_tries
+        #: Dict mapping the job id to the job state string.
+        self.states = {}
+        #: Make at least one call to squeue, must not fail.
+        logger.debug("initializing thread")
+        self._call_squeue(allow_failure=False)
+        self.prev_call = time.time()
+
+    def _work(self):
+        """Execute the thread's action"""
+        while not self.stopped.is_set():
+            now = time.time()
+            if now - self.prev_call > self.squeue_wait:
+                self._call_squeue()
+                self.prev_call = now
+            time.sleep(self.sleep_time)
+
+    def get_state(self, jobid):
+        """Return the job state for the given jobid."""
+        jobid = str(jobid)
+        if jobid not in self.states:
+            try:
+                self.states[jobid] = self._get_state_sacct(jobid)
+            except Exception:
+                return "__not_seen_yet__"
+        return self.states.get(jobid, "__not_seen_yet__")
+
+    def register_job(self, jobid):
+        """Register job with the given ID."""
+        self.states.setdefault(jobid, None)
+
+    def _get_state_sacct(self, jobid):
+        """Implement retrieving state via sacct for resuming jobs."""
+        cluster = CookieCutter.get_cluster_option()
+        cmd = ["sacct", "-P", "-b", "-j", jobid, "-n"]
+        if cluster:
+            cmd.append(cluster)
+        try_num = 0
+        while try_num < self.max_tries:
+            try_num += 1
+            try:
+                logger.debug("Calling %s (try %d)", cmd, try_num)
+                output = subprocess.check_output(cmd, timeout=self.squeue_timeout, text=True)
+            except subprocess.TimeoutExpired as e:
+                logger.warning("Call to %s timed out (try %d of %d)", cmd, try_num, self.max_tries)
+                continue
+            except subprocess.CalledProcessError as e:
+                logger.warning("Call to %s failed (try %d of %d)", cmd, try_num, self.max_tries)
+                continue
+            try:
+                parsed = {x.split("|")[0]: x.split("|")[1] for x in output.strip().split("\n")}
+                logger.debug("Returning state of %s as %s", jobid, parsed[jobid])
+                return parsed[jobid]
+            except IndexError:
+                logger.warning("Could not parse %s (try %d of %d)", repr(output), try_num, self.max_tries)
+            secs = try_num / 2.0
+            logger.info("Sleeping %f seconds", secs)
+            time.sleep(secs)
+        raise Exception("Problem with call to %s" % cmd)
+
+    def stop(self):
+        """Flag thread to stop execution"""
+        logger.debug("stopping thread")
+        self.stopped.set()
+
+    def _call_squeue(self, allow_failure=True):
+        """Run the call to ``squeue``"""
+        cluster = CookieCutter.get_cluster_option()
+        try_num = 0
+        cmd = [SQUEUE_CMD, "--user={}".format(os.environ.get("USER")), "--format=%i,%T", "--state=all"]
+        if cluster:
+            cmd.append(cluster)
+        while try_num < self.max_tries:
+            try_num += 1
+            try:
+                logger.debug("Calling %s (try %d)", cmd, try_num)
+                output = subprocess.check_output(cmd, timeout=self.squeue_timeout, text=True)
+                logger.debug("Output is:\n---\n%s\n---", output)
+                break
+            except subprocess.TimeoutExpired as e:
+                if not allow_failure:
+                    raise
+                logger.debug("Call to %s timed out (try %d of %d)", cmd, try_num, self.max_tries)
+            except subprocess.CalledProcessError as e:
+                if not allow_failure:
+                    raise
+                logger.debug("Call to %s failed (try %d of %d)", cmd, try_num, self.max_tries)
+        if try_num >= self.max_tries:
+            logger.debug("Giving up for this round")
+        else:
+            logger.debug("parsing output")
+            self._parse_output(output)
+
+    def _parse_output(self, output):
+        """Parse output of ``squeue`` call."""
+        header = None
+        for line in output.splitlines():
+            line = line.strip()
+            arr = line.split(",")
+            if not header:
+                if not line.startswith("JOBID"):
+                    continue  # skip lines before the header
+                header = arr
+            else:
+                logger.debug("Updating state of %s to %s", arr[0], arr[1])
+                self.states[arr[0]] = arr[1]
+
+
+class JobStateHttpHandler(http.server.BaseHTTPRequestHandler):
+    """HTTP handler class that responds to ``/job/status/${jobid}/`` GET requests"""
+
+    def do_GET(self):
+        """Handle GET requests (only to ``/job/status/${job_id}/?``)"""
+        logger.debug("--- BEGIN GET")
+        # Remove trailing slashes from path.
+        path = self.path
+        while path.endswith("/"):
+            path = path[:-1]
+        # Ensure that /job/status was requested
+        if not self.path.startswith("/job/status/"):
+            self.send_response(400)
+            self.end_headers()
+            return
+        # Ensure authentication bearer is correct
+        auth_required = "Bearer %s" % self.server.http_secret
+        auth_header = self.headers.get("Authorization")
+        logger.debug(
+            "Authorization header is %s, required: %s" % (repr(auth_header), repr(auth_required))
+        )
+        if auth_header != auth_required:
+            self.send_response(403)
+            self.end_headers()
+            return
+        # Otherwise, query job ID status
+        job_id = self.path[len("/job/status/") :]
+        try:
+            job_id = job_id.split("%20")[3]
+        except IndexError:
+            pass
+        logger.debug("Querying for job ID %s" % repr(job_id))
+        status = self.server.poll_thread.get_state(job_id)
+        logger.debug("Status: %s" % status)
+        if not status:
+            self.send_response(404)
+            self.end_headers()
+        else:
+            self.send_response(200)
+            self.send_header("Content-type", "application/json")
+            self.end_headers()
+            output = json.dumps({"status": status})
+            logger.debug("Sending %s" % repr(output))
+            self.wfile.write(output.encode("utf-8"))
+        logger.debug("--- END GET")
+
+    def do_POST(self):
+        """Handle POSTs (only to ``/job/register/${job_id}/?``)"""
+        logger.debug("--- BEGIN POST")
+        # Remove trailing slashes from path.
+        path = self.path
+        while path.endswith("/"):
+            path = path[:-1]
+        # Ensure that /job/register was requested
+        if not self.path.startswith("/job/register/"):
+            self.send_response(400)
+            self.end_headers()
+            return
+        # Ensure authentication bearer is correct
+        auth_required = "Bearer %s" % self.server.http_secret
+        auth_header = self.headers.get("Authorization")
+        logger.debug(
+            "Authorization header is %s, required: %s", repr(auth_header), repr(auth_required)
+        )
+        # Otherwise, register job ID
+        job_id = self.path[len("/job/register/") :]
+        self.server.poll_thread.register_job(job_id)
+        self.send_response(200)
+        self.end_headers()
+        logger.debug("--- END POST")
+
+    def log_request(self, *args, **kwargs):
+        if LOG_REQUESTS:
+            super().log_request(*args, **kwargs)
+
+
+class JobStateHttpServer(http.server.HTTPServer):
+    """The HTTP server class"""
+
+    allow_reuse_address = False
+
+    def __init__(self, poll_thread):
+        """Initialize the server and print the ``SNAKEMAKE_CLUSTER_SIDECAR_VARS`` to stdout, then flush."""
+        super().__init__(("0.0.0.0", 0), JobStateHttpHandler)
+        #: The ``PollSqueueThread`` with the state dictionary.
+        self.poll_thread = poll_thread
+        #: The secret to use.
+        self.http_secret = str(uuid.uuid4())
+        sidecar_vars = {
+            "server_port": self.server_port,
+            "server_secret": self.http_secret,
+            "pid": os.getpid(),
+        }
+        logger.debug(json.dumps(sidecar_vars))
+        sys.stdout.write(json.dumps(sidecar_vars) + "\n")
+        sys.stdout.flush()
+
+    def log_message(self, *args, **kwargs):
+        """Log messages are printed if ``DEBUG`` is ``True``."""
+        if DEBUG:
+            super().log_message(*args, **kwargs)
+
+
+def main():
+    # Start thread to poll ``squeue`` in a controlled fashion.
+    poll_thread = PollSqueueThread(SQUEUE_WAIT, SQUEUE_CMD, name="poll-squeue")
+    poll_thread.start()
+
+    # Initialize HTTP server that makes available the output of ``squeue --user [user]``
+    # in a controlled fashion.
+ http_server = JobStateHttpServer(poll_thread) + http_thread = threading.Thread(name="http-server", target=http_server.serve_forever) + http_thread.start() + + # Allow for graceful shutdown of poll thread and HTTP server. + def signal_handler(signum, frame): + """Handler for Unix signals. Shuts down http_server and poll_thread.""" + logger.info("Shutting down squeue poll thread and HTTP server...") + # from remote_pdb import set_trace + # set_trace() + poll_thread.stop() + http_server.shutdown() + logger.info("... HTTP server and poll thread shutdown complete.") + for thread in threading.enumerate(): + logger.info("ACTIVE %s", thread.name) + + signal.signal(signal.SIGINT, signal_handler) + signal.signal(signal.SIGTERM, signal_handler) + + # Actually run the server. + poll_thread.join() + logger.debug("poll_thread done") + http_thread.join() + logger.debug("http_thread done") + + +if __name__ == "__main__": + sys.exit(int(main() or 0)) diff --git a/08-snakemake/profiles/slurm/slurm-status.py b/08-snakemake/profiles/slurm/slurm-status.py new file mode 100755 index 0000000..7cc28d1 --- /dev/null +++ b/08-snakemake/profiles/slurm/slurm-status.py @@ -0,0 +1,106 @@ +#!/usr/bin/env python3 +import json +import os +import re +import requests +import subprocess as sp +import shlex +import sys +import time +import logging +from CookieCutter import CookieCutter + +logger = logging.getLogger(__name__) + +STATUS_ATTEMPTS = 20 +SIDECAR_VARS = os.environ.get("SNAKEMAKE_CLUSTER_SIDECAR_VARS", None) +DEBUG = bool(int(os.environ.get("SNAKEMAKE_SLURM_DEBUG", "0"))) + +if DEBUG: + logging.basicConfig(level=logging.DEBUG) + logger.setLevel(logging.DEBUG) + + +def get_status_direct(jobid): + """Get status directly from sacct/scontrol""" + cluster = CookieCutter.get_cluster_option() + for i in range(STATUS_ATTEMPTS): + try: + sacct_res = sp.check_output(shlex.split(f"sacct {cluster} -P -b -j {jobid} -n")) + res = {x.split("|")[0]: x.split("|")[1] for x in sacct_res.decode().strip().split("\n")} + break + except sp.CalledProcessError as e: + logger.error("sacct process error") + logger.error(e) + except IndexError as e: + logger.error(e) + pass + # Try getting job with scontrol instead in case sacct is misconfigured + try: + sctrl_res = sp.check_output(shlex.split(f"scontrol {cluster} -o show job {jobid}")) + m = re.search(r"JobState=(\w+)", sctrl_res.decode()) + res = {jobid: m.group(1)} + break + except sp.CalledProcessError as e: + logger.error("scontrol process error") + logger.error(e) + if i >= STATUS_ATTEMPTS - 1: + print("failed") + exit(0) + else: + time.sleep(1) + + return res[jobid] or "" + + +def get_status_sidecar(jobid): + """Get status from cluster sidecar""" + sidecar_vars = json.loads(SIDECAR_VARS) + url = "http://localhost:%d/job/status/%s" % (sidecar_vars["server_port"], jobid) + headers = {"Authorization": "Bearer %s" % sidecar_vars["server_secret"]} + try: + resp = requests.get(url, headers=headers) + if resp.status_code == 404: + return "" # not found yet + logger.debug("sidecar returned: %s" % resp.json()) + resp.raise_for_status() + return resp.json().get("status") or "" + except requests.exceptions.ConnectionError as e: + logger.warning("slurm-status.py: could not query side car: %s", e) + logger.info("slurm-status.py: falling back to direct query") + return get_status_direct(jobid) + + +jobid = sys.argv[1] + +if SIDECAR_VARS: + logger.debug("slurm-status.py: querying sidecar") + status = get_status_sidecar(jobid) +else: + logger.debug("slurm-status.py: direct query") + status = 
get_status_direct(jobid) + +logger.debug("job status: %s", repr(status)) + +if status == "BOOT_FAIL": + print("failed") +elif status == "OUT_OF_MEMORY": + print("failed") +elif status.startswith("CANCELLED"): + print("failed") +elif status == "COMPLETED": + print("success") +elif status == "DEADLINE": + print("failed") +elif status == "FAILED": + print("failed") +elif status == "NODE_FAIL": + print("failed") +elif status == "PREEMPTED": + print("failed") +elif status == "TIMEOUT": + print("failed") +elif status == "SUSPENDED": + print("running") +else: + print("running") diff --git a/08-snakemake/profiles/slurm/slurm-submit.py b/08-snakemake/profiles/slurm/slurm-submit.py new file mode 100755 index 0000000..c5544b4 --- /dev/null +++ b/08-snakemake/profiles/slurm/slurm-submit.py @@ -0,0 +1,105 @@ +#!/usr/bin/env python3 +""" +Snakemake SLURM submit script. +""" +import json +import logging +import os + +import requests +from snakemake.utils import read_job_properties + +import slurm_utils +from CookieCutter import CookieCutter + +logger = logging.getLogger(__name__) + +SIDECAR_VARS = os.environ.get("SNAKEMAKE_CLUSTER_SIDECAR_VARS", None) +DEBUG = bool(int(os.environ.get("SNAKEMAKE_SLURM_DEBUG", "0"))) + +if DEBUG: + logging.basicConfig(level=logging.DEBUG) + logger.setLevel(logging.DEBUG) + + +def register_with_sidecar(jobid): + if SIDECAR_VARS is None: + return + sidecar_vars = json.loads(SIDECAR_VARS) + url = "http://localhost:%d/job/register/%s" % (sidecar_vars["server_port"], jobid) + logger.debug("POST to %s", url) + headers = {"Authorization": "Bearer %s" % sidecar_vars["server_secret"]} + requests.post(url, headers=headers) + + +# cookiecutter arguments +SBATCH_DEFAULTS = CookieCutter.SBATCH_DEFAULTS +CLUSTER = CookieCutter.get_cluster_option() +CLUSTER_CONFIG = CookieCutter.CLUSTER_CONFIG + +RESOURCE_MAPPING = { + "time": ("time", "runtime", "walltime"), + "mem": ("mem", "mem_mb", "ram", "memory"), + "mem-per-cpu": ("mem-per-cpu", "mem_per_cpu", "mem_per_thread"), + "nodes": ("nodes", "nnodes"), + "partition": ("partition", "queue"), +} + +# parse job +jobscript = slurm_utils.parse_jobscript() +job_properties = read_job_properties(jobscript) + +sbatch_options = {} +cluster_config = slurm_utils.load_cluster_config(CLUSTER_CONFIG) + +# 1) sbatch default arguments and cluster +sbatch_options.update(slurm_utils.parse_sbatch_defaults(SBATCH_DEFAULTS)) +sbatch_options.update(slurm_utils.parse_sbatch_defaults(CLUSTER)) + +# 2) cluster_config defaults +sbatch_options.update(cluster_config["__default__"]) + +# 3) Convert resources (no unit conversion!) 
and threads +sbatch_options.update(slurm_utils.convert_job_properties(job_properties, RESOURCE_MAPPING)) + +# 4) cluster_config for particular rule +sbatch_options.update(cluster_config.get(job_properties.get("rule"), {})) + +# 5) cluster_config options +sbatch_options.update(job_properties.get("cluster", {})) + +# convert human-friendly time - leaves slurm format time as is +if "time" in sbatch_options: + duration = str(sbatch_options["time"]) + sbatch_options["time"] = str(slurm_utils.Time(duration)) + +# 6) Format pattern in snakemake style +sbatch_options = slurm_utils.format_values(sbatch_options, job_properties) + +# 7) create output and error filenames and paths +joblog = slurm_utils.JobLog(job_properties) +log = "" +if "output" not in sbatch_options and CookieCutter.get_cluster_logpath(): + outlog = joblog.outlog + log = outlog + sbatch_options["output"] = outlog + +if "error" not in sbatch_options and CookieCutter.get_cluster_logpath(): + errlog = joblog.errlog + log = errlog + sbatch_options["error"] = errlog + +# ensure sbatch output dirs exist +for o in ("output", "error"): + slurm_utils.ensure_dirs_exist(sbatch_options[o]) if o in sbatch_options else None + +# 9) Set slurm job name +if "job-name" not in sbatch_options and "job_name" not in sbatch_options: + sbatch_options["job-name"] = joblog.jobname + +# submit job and echo id back to Snakemake (must be the only stdout) +jobid = slurm_utils.submit_job(jobscript, **sbatch_options) +logger.debug("Registering %s with sidecar...", jobid) +register_with_sidecar(jobid) +logger.debug("... done registering with sidecar") +print(jobid) diff --git a/08-snakemake/profiles/slurm/slurm_utils.py b/08-snakemake/profiles/slurm/slurm_utils.py new file mode 100644 index 0000000..c420154 --- /dev/null +++ b/08-snakemake/profiles/slurm/slurm_utils.py @@ -0,0 +1,403 @@ +#!/usr/bin/env python3 +import argparse +import math +import os +import re +import subprocess as sp +import sys +from datetime import timedelta +from os.path import dirname +from time import time as unix_time +from typing import Union +from uuid import uuid4 +import shlex +from io import StringIO + +from CookieCutter import CookieCutter +from snakemake import io +from snakemake.exceptions import WorkflowError +from snakemake.io import Wildcards +from snakemake.logging import logger +from snakemake.utils import AlwaysQuotedFormatter +from snakemake.utils import QuotedFormatter +from snakemake.utils import SequenceFormatter + + +def _convert_units_to_mb(memory): + """If memory is specified with SI unit, convert to MB""" + if isinstance(memory, int) or isinstance(memory, float): + return int(memory) + siunits = {"K": 1e-3, "M": 1, "G": 1e3, "T": 1e6} + regex = re.compile(r"(\d+)({})$".format("|".join(siunits.keys()))) + m = regex.match(memory) + if m is None: + logger.error( + (f"unsupported memory specification '{memory}';" " allowed suffixes: [K|M|G|T]") + ) + sys.exit(1) + factor = siunits[m.group(2)] + return int(int(m.group(1)) * factor) + + +def parse_jobscript(): + """Minimal CLI to require/only accept single positional argument.""" + p = argparse.ArgumentParser(description="SLURM snakemake submit script") + p.add_argument("jobscript", help="Snakemake jobscript with job properties.") + return p.parse_args().jobscript + + +def parse_sbatch_defaults(parsed): + """Unpack SBATCH_DEFAULTS.""" + d = shlex.split(parsed) if type(parsed) == str else parsed + args = {} + for keyval in [a.split("=") for a in d]: + k = keyval[0].strip().strip("-") + v = keyval[1].strip() if len(keyval) == 
2 else None + args[k] = v + return args + + +def load_cluster_config(path): + """Load config to dict + + Load configuration to dict either from absolute path or relative + to profile dir. + """ + if path: + path = os.path.join(dirname(__file__), os.path.expandvars(path)) + dcc = io.load_configfile(path) + else: + dcc = {} + if "__default__" not in dcc: + dcc["__default__"] = {} + return dcc + + +# adapted from format function in snakemake.utils +def format(_pattern, _quote_all=False, **kwargs): # noqa: A001 + """Format a pattern in Snakemake style. + This means that keywords embedded in braces are replaced by any variable + values that are available in the current namespace. + """ + fmt = SequenceFormatter(separator=" ") + if _quote_all: + fmt.element_formatter = AlwaysQuotedFormatter() + else: + fmt.element_formatter = QuotedFormatter() + try: + return fmt.format(_pattern, **kwargs) + except KeyError as ex: + raise NameError( + f"The name {ex} is unknown in this context. Please " + "make sure that you defined that variable. " + "Also note that braces not used for variable access " + "have to be escaped by repeating them " + ) + + +# adapted from Job.format_wildcards in snakemake.jobs +def format_wildcards(string, job_properties): + """Format a string with variables from the job.""" + + class Job(object): + def __init__(self, job_properties): + for key in job_properties: + setattr(self, key, job_properties[key]) + + job = Job(job_properties) + if "params" in job_properties: + job._format_params = Wildcards(fromdict=job_properties["params"]) + else: + job._format_params = None + if "wildcards" in job_properties: + job._format_wildcards = Wildcards(fromdict=job_properties["wildcards"]) + else: + job._format_wildcards = None + _variables = dict() + _variables.update(dict(params=job._format_params, wildcards=job._format_wildcards)) + if hasattr(job, "rule"): + _variables.update(dict(rule=job.rule)) + try: + return format(string, **_variables) + except NameError as ex: + raise WorkflowError("NameError with group job {}: {}".format(job.jobid, str(ex))) + except IndexError as ex: + raise WorkflowError("IndexError with group job {}: {}".format(job.jobid, str(ex))) + + +# adapted from ClusterExecutor.cluster_params function in snakemake.executor +def format_values(dictionary, job_properties): + formatted = dictionary.copy() + for key, value in list(formatted.items()): + if key == "mem": + value = str(_convert_units_to_mb(value)) + if isinstance(value, str): + try: + formatted[key] = format_wildcards(value, job_properties) + except NameError as e: + msg = "Failed to format cluster config " "entry for job {}.".format( + job_properties["rule"] + ) + raise WorkflowError(msg, e) + return formatted + + +def convert_job_properties(job_properties, resource_mapping=None): + options = {} + if resource_mapping is None: + resource_mapping = {} + resources = job_properties.get("resources", {}) + for k, v in resource_mapping.items(): + options.update({k: resources[i] for i in v if i in resources}) + + if "threads" in job_properties: + options["cpus-per-task"] = job_properties["threads"] + + slurm_opts = resources.get("slurm", "") + if not isinstance(slurm_opts, str): + raise ValueError( + "The `slurm` argument to resources must be a space-separated string" + ) + + for opt in slurm_opts.split(): + kv = opt.split("=", maxsplit=1) + k = kv[0] + v = None if len(kv) == 1 else kv[1] + options[k.lstrip("-").replace("_", "-")] = v + + return options + + +def ensure_dirs_exist(path): + """Ensure output folder for Slurm 
log files exist."""
+    di = dirname(path)
+    if di == "":
+        return
+    if not os.path.exists(di):
+        os.makedirs(di, exist_ok=True)
+    return
+
+
+def format_sbatch_options(**sbatch_options):
+    """Format sbatch options"""
+    options = []
+    for k, v in sbatch_options.items():
+        val = ""
+        if v is not None:
+            val = f"={v}"
+        options.append(f"--{k}{val}")
+    return options
+
+
+def submit_job(jobscript, **sbatch_options):
+    """Submit jobscript and return jobid."""
+    options = format_sbatch_options(**sbatch_options)
+    try:
+        cmd = ["sbatch"] + ["--parsable"] + options + [jobscript]
+        res = sp.check_output(cmd)
+    except sp.CalledProcessError as e:
+        raise e
+    # Get jobid
+    res = res.decode()
+    try:
+        jobid = re.search(r"(\d+)", res).group(1)
+    except Exception as e:
+        raise e
+    return jobid
+
+
+timeformats = [
+    re.compile(r"^(?P<days>\d+)-(?P<hours>\d+):(?P<minutes>\d+):(?P<seconds>\d+)$"),
+    re.compile(r"^(?P<days>\d+)-(?P<hours>\d+):(?P<minutes>\d+)$"),
+    re.compile(r"^(?P<days>\d+)-(?P<hours>\d+)$"),
+    re.compile(r"^(?P<hours>\d+):(?P<minutes>\d+):(?P<seconds>\d+)$"),
+    re.compile(r"^(?P<minutes>\d+):(?P<seconds>\d+)$"),
+    re.compile(r"^(?P<minutes>\d+)$"),
+]
+
+
+def time_to_minutes(time):
+    """Convert time string to minutes.
+
+    According to slurm:
+
+    Acceptable time formats include "minutes", "minutes:seconds",
+    "hours:minutes:seconds", "days-hours", "days-hours:minutes"
+    and "days-hours:minutes:seconds".
+
+    """
+    if not isinstance(time, str):
+        time = str(time)
+    d = {"days": 0, "hours": 0, "minutes": 0, "seconds": 0}
+    regex = list(filter(lambda regex: regex.match(time) is not None, timeformats))
+    if len(regex) == 0:
+        return
+    assert len(regex) == 1, "multiple time formats match"
+    m = regex[0].match(time)
+    d.update(m.groupdict())
+    minutes = (
+        int(d["days"]) * 24 * 60
+        + int(d["hours"]) * 60
+        + int(d["minutes"])
+        + math.ceil(int(d["seconds"]) / 60)
+    )
+    assert minutes > 0, "minutes has to be greater than 0"
+    return minutes
+
+
+class InvalidTimeUnitError(Exception):
+    pass
+
+
+class Time:
+    _nanosecond_size = 1
+    _microsecond_size = 1000 * _nanosecond_size
+    _millisecond_size = 1000 * _microsecond_size
+    _second_size = 1000 * _millisecond_size
+    _minute_size = 60 * _second_size
+    _hour_size = 60 * _minute_size
+    _day_size = 24 * _hour_size
+    _week_size = 7 * _day_size
+    units = {
+        "s": _second_size,
+        "m": _minute_size,
+        "h": _hour_size,
+        "d": _day_size,
+        "w": _week_size,
+    }
+    pattern = re.compile(rf"(?P<val>\d+(\.\d*)?|\.\d+)(?P<unit>[a-zA-Z])")
+
+    def __init__(self, duration: str):
+        self.duration = Time._from_str(duration)
+
+    def __str__(self) -> str:
+        return Time._timedelta_to_slurm(self.duration)
+
+    def __repr__(self):
+        return str(self)
+
+    @staticmethod
+    def _timedelta_to_slurm(delta: Union[timedelta, str]) -> str:
+        if isinstance(delta, timedelta):
+            d = dict()
+            d["hours"], rem = divmod(delta.seconds, 3600)
+            d["minutes"], d["seconds"] = divmod(rem, 60)
+            d["hours"] += delta.days * 24
+            return "{hours}:{minutes:02d}:{seconds:02d}".format(**d)
+        elif isinstance(delta, str):
+            return delta
+        else:
+            raise ValueError("Time is in an unknown format '{}'".format(delta))
+
+    @staticmethod
+    def _from_str(duration: str) -> Union[timedelta, str]:
+        """Parse a duration string to a datetime.timedelta"""
+
+        matches = Time.pattern.finditer(duration)
+
+        total = 0
+        n_matches = 0
+        for m in matches:
+            n_matches += 1
+            value = m.group("val")
+            unit = m.group("unit").lower()
+            if unit not in Time.units:
+                raise InvalidTimeUnitError(
+                    "Unknown unit '{}' in time {}".format(unit, duration)
+                )
+
+            total += float(value) * Time.units[unit]
+
+        if n_matches == 0:
+            return duration
+
+        microseconds = total / Time._microsecond_size
+        return timedelta(microseconds=microseconds)
+
+
+class JobLog:
+    def __init__(self, job_props: dict):
+        self.job_properties = job_props
+        self.uid = str(uuid4())
+
+    @property
+    def wildcards(self) -> dict:
+        return self.job_properties.get("wildcards", dict())
+
+    @property
+    def wildcards_str(self) -> str:
+        return (
+            ".".join("{}={}".format(k, v) for k, v in self.wildcards.items())
+            or "unique"
+        )
+
+    @property
+    def rule_name(self) -> str:
+        if not self.is_group_jobtype:
+            return self.job_properties.get("rule", "nameless_rule")
+        return self.groupid
+
+    @property
+    def groupid(self) -> str:
+        return self.job_properties.get("groupid", "group")
+
+    @property
+    def is_group_jobtype(self) -> bool:
+        return self.job_properties.get("type", "") == "group"
+
+    @property
+    def short_uid(self) -> str:
+        return self.uid.split("-")[0]
+
+    def pattern_replace(self, s: str) -> str:
+        """
+        %r - rule name. If group job, will use the group ID instead
+        %i - snakemake job ID
+        %w - wildcards. e.g., wildcards A and B will be concatenated as 'A=<value>.B=<value>'
+        %U - a random universally unique identifier
+        %S - shortened version of %U
+        %T - Unix time, aka seconds since epoch (rounded to an integer)
+        """
+        replacement = {
+            "%r": self.rule_name,
+            "%i": self.jobid,
+            "%w": self.wildcards_str,
+            "%U": self.uid,
+            "%T": str(int(unix_time())),
+            "%S": self.short_uid,
+        }
+        for old, new in replacement.items():
+            s = s.replace(old, new)
+
+        return s
+
+    @property
+    def jobname(self) -> str:
+        jobname_pattern = CookieCutter.get_cluster_jobname()
+        if not jobname_pattern:
+            return ""
+
+        return self.pattern_replace(jobname_pattern)
+
+    @property
+    def jobid(self) -> str:
+        """The snakemake jobid"""
+        if self.is_group_jobtype:
+            return self.job_properties.get("jobid", "").split("-")[0]
+        return str(self.job_properties.get("jobid"))
+
+    @property
+    def logpath(self) -> str:
+        logpath_pattern = CookieCutter.get_cluster_logpath()
+        if not logpath_pattern:
+            return ""
+
+        return self.pattern_replace(logpath_pattern)
+
+    @property
+    def outlog(self) -> str:
+        return self.logpath + ".out"
+
+    @property
+    def errlog(self) -> str:
+        return self.logpath + ".err"
diff --git a/08-snakemake/run_snakefile.sh b/08-snakemake/run_snakefile.sh
index babe13a..efbba31 100755
--- a/08-snakemake/run_snakefile.sh
+++ b/08-snakemake/run_snakefile.sh
@@ -5,7 +5,7 @@ APSIM_JOBS=100
 
 # Run Snakefile_txt
 echo "Processing text files..."
-snakemake -s Snakefile_1 --profile nesi --jobs 1
+snakemake -s Snakefile_1 --profile slurm --jobs 1
 
 # Check if the previous command was successful
 if [ $? -eq 0 ]; then
@@ -13,7 +13,7 @@ if [ $? -eq 0 ]; then
 
     # Run Snakefile_apsimx
     echo "Processing APSIM files..."
-    snakemake -s Snakefile_2 --profile nesi --jobs $APSIM_JOBS
+    snakemake -s Snakefile_2 --profile slurm --jobs $APSIM_JOBS
 
     if [ $? -eq 0 ]; then
         echo "APSIM file processing completed successfully."
diff --git a/submit.sh b/submit.sh
index 1944cb9..e0e1105 100755
--- a/submit.sh
+++ b/submit.sh
@@ -28,6 +28,8 @@ echo ""
 echo -e "${YELLOW}Provide the path to working directory for apsim_simulations:${NC} \c"
 read -r working_dir
 
+echo ""
+
 # Check if the directory exists, if not create it
 if [ ! -d "$working_dir" ]; then
     mkdir -p "$working_dir"
@@ -48,17 +50,26 @@ cp 08-snakemake/run_snakefile.sh "$working_dir"
 
 #Copy snakemake profile 08-snakemake/profiles to ~/.config/snakemake
 mkdir -p ~/.config/snakemake
-cp -r 08-snakemake/profiles/nesi ~/.config/snakemake/
+cp -r 08-snakemake/profiles/slurm ~/.config/snakemake/
 
 # Print completion message
 echo -e "\nSwitching to working directory now and running generate_apsim_configs.R to create config files."
+echo ""
 
 # Change to the working directory
 cd "$working_dir" || exit
 
 # Print current directory to confirm
 echo -e "${GREEN}${BOLD}Current working directory: $(pwd)${NC}"
+echo ""
+#Load modules
+echo -e "${YELLOW}Loading required modules...${NC}"
+if [[ $(hostname) == *eri* ]]; then
+    module purge && module load snakemake/7.32.3-foss-2023a-Python-3.11.6 R/4.4.1-foss-2023a Graphviz/12.1.2
+elif [[ $(hostname) == *mahuika* ]]; then
+    module purge >/dev/null 2>&1 && module load snakemake/7.32.3-gimkl-2022a-Python-3.11.3 R/4.3.1-gimkl-2022a
+fi
 
 # Execute the R script
 echo -e "${YELLOW}Generating config files and splitting into multiple sets...${NC}"
@@ -80,21 +91,13 @@ echo -e "${GREEN}${BOLD}Config files generation complete.${NC}"
 
 echo ""
 
-#Load modules
-echo -e "${YELLOW}Loading required modules and copying nesi Snakemake profile...${NC}"
-if [[ $(hostname) == *eri* ]]; then
-    module purge && module load snakemake/7.32.3-foss-2023a-Python-3.11.6 R/4.4.1-foss-2023a Graphviz/12.1.2
-elif [[ $(hostname) == *mahuika* ]]; then
-    module purge >/dev/null 2>&1 && module load snakemake/7.32.3-gimkl-2022a-Python-3.11.3 R/4.3.1-gimkl-2022a
-fi
-
-
-
 # Ask if the user wants to submit the APSIM-HPC workflow
 echo -n -e "${YELLOW}Would you like to submit the APSIM-HPC workflow to generate .db files? (yes/no) : ${NC}"
 read -r submit_answer
 
+echo ""
+
 if [ "${submit_answer,,}" = "yes" ]; then
     # Verify the user is in the correct directory
     if [ "$(pwd)" != "$working_dir" ]; then