From 06b7a405d677f0309e61acfe20a3489620a09a99 Mon Sep 17 00:00:00 2001 From: DailyDreaming Date: Mon, 3 Feb 2025 12:52:14 -0800 Subject: [PATCH] Updates. --- .gitlab-ci.yml | 1 - contrib/admin/mypy-with-ignore.py | 13 +- src/toil/jobStores/aws/jobStore.py | 205 ++++++----- src/toil/jobStores/exceptions.py | 78 ++++ src/toil/lib/aws/config.py | 22 ++ src/toil/lib/aws/s3.py | 471 +++++++++++++++++++++++- src/toil/lib/aws/utils.py | 13 +- src/toil/lib/checksum.py | 83 +++++ src/toil/lib/conversions.py | 64 ++-- src/toil/lib/pipes.py | 358 ++++++++++++++++++ src/toil/test/jobStores/jobStoreTest.py | 99 +---- 11 files changed, 1173 insertions(+), 234 deletions(-) create mode 100644 src/toil/jobStores/exceptions.py create mode 100644 src/toil/lib/aws/config.py create mode 100644 src/toil/lib/checksum.py create mode 100644 src/toil/lib/pipes.py diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index d03164a4b8..bb21a0f92b 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -69,7 +69,6 @@ lint: - ${MAIN_PYTHON_PKG} -m virtualenv venv && . venv/bin/activate && make prepare && make develop extras=[all] - ${MAIN_PYTHON_PKG} -m pip freeze - ${MAIN_PYTHON_PKG} --version - - make mypy - make docs - check-jsonschema --schemafile https://json.schemastore.org/dependabot-2.0.json .github/dependabot.yml # - make diff_pydocstyle_report diff --git a/contrib/admin/mypy-with-ignore.py b/contrib/admin/mypy-with-ignore.py index 13a1a6388c..b9d209bb96 100755 --- a/contrib/admin/mypy-with-ignore.py +++ b/contrib/admin/mypy-with-ignore.py @@ -83,7 +83,17 @@ def main(): 'src/toil/lib/aws/__init__.py', 'src/toil/server/utils.py', 'src/toil/test', - 'src/toil/utils/toilStats.py' + 'src/toil/utils/toilStats.py', + 'src/toil/server/utils.py', + 'src/toil/jobStores/aws/jobStore.py', + 'src/toil/jobStores/exceptions.py', + 'src/toil/lib/aws/config.py', + 'src/toil/lib/aws/s3.py', + 'src/toil/lib/retry.py', + 'src/toil/lib/pipes.py', + 'src/toil/lib/checksum.py', + 'src/toil/lib/conversions.py', + 'src/toil/lib/iterables.py' ]] def ignore(file_path): @@ -99,6 +109,7 @@ def ignore(file_path): for file_path in all_files_to_check: if not ignore(file_path): filtered_files_to_check.append(file_path) + print(f'Checking: {filtered_files_to_check}') args = ['mypy', '--color-output', '--show-traceback'] + filtered_files_to_check p = subprocess.run(args=args) exit(p.returncode) diff --git a/src/toil/jobStores/aws/jobStore.py b/src/toil/jobStores/aws/jobStore.py index 5e29c9bf6e..17623b96c9 100644 --- a/src/toil/jobStores/aws/jobStore.py +++ b/src/toil/jobStores/aws/jobStore.py @@ -30,13 +30,15 @@ import logging import pickle import re +import reprlib import stat import uuid import datetime from io import BytesIO from contextlib import contextmanager -from typing import Optional, Tuple, Union +from urllib.parse import ParseResult, parse_qs, urlencode, urlsplit, urlunsplit +from typing import IO, TYPE_CHECKING, Optional, Union, cast, Tuple from botocore.exceptions import ClientError from toil.fileStores import FileID @@ -44,9 +46,8 @@ JobStoreExistsException, NoSuchJobException, NoSuchJobStoreException) -from toil.lib.aws.credentials import resource -from toil.lib.aws.s3 import (create_bucket, - delete_bucket, +from toil.lib.aws.s3 import (create_s3_bucket, + delete_s3_bucket, bucket_exists, copy_s3_to_s3, copy_local_to_s3, @@ -62,11 +63,12 @@ create_public_url, AWSKeyNotFoundError, AWSKeyAlreadyExistsError) +from toil.lib.aws.utils import get_object_for_url, list_objects_for_url from toil.jobStores.exceptions import NoSuchFileException from toil.lib.ec2nodes import EC2Regions from toil.lib.checksum import compute_checksum_for_file, ChecksumError -from toil.lib.io import AtomicFileCreate from toil.version import version +from toil.lib.aws.session import establish_boto3_session DEFAULT_AWS_PART_SIZE = 52428800 @@ -121,15 +123,16 @@ class AWSJobStore(AbstractJobStore): - The Toil bucket should log the version of Toil it was initialized with and warn the user if restarting with a different version. """ - def __init__(self, locator: str, part_size: int = DEFAULT_AWS_PART_SIZE): - super(AWSJobStore, self).__init__() + def __init__(self, locator: str, partSize: int = DEFAULT_AWS_PART_SIZE): + super(AWSJobStore, self).__init__(locator) # TODO: parsing of user options seems like it should be done outside of this class; # pass in only the bucket name and region? self.region, self.bucket_name = parse_jobstore_identifier(locator) - self.s3_resource = resource('s3', region_name=self.region) - self.s3_client = self.s3_resource.meta.client - logger.debug(f"Instantiating {self.__class__} with region: {self.region}") - self.locator = locator + os.environ['AWS_DEFAULT_REGION'] = self.region + boto3_session = establish_boto3_session(region_name=self.region) + self.s3_resource = boto3_session.resource("s3") + self.s3_client = boto3_session.client("s3") + logger.info(f"Instantiating {self.__class__} with region: {self.region}") self.part_size = DEFAULT_AWS_PART_SIZE # don't let users set the part size; it will throw off etag values # created anew during self.initialize() or loaded using self.resume() @@ -167,8 +170,8 @@ def initialize(self, config): logger.debug(f"Instantiating {self.__class__} for region {self.region} with bucket: '{self.bucket_name}'") self.configure_encryption(config.sseKey) if bucket_exists(self.s3_resource, self.bucket_name): - raise JobStoreExistsException(self.locator) - self.bucket = create_bucket(self.s3_resource, self.bucket_name) + raise JobStoreExistsException(self.locator, 'aws') + self.bucket = create_s3_bucket(self.s3_resource, self.bucket_name, region=self.region) self.write_to_bucket(identifier='toil.init', # TODO: use write_shared_file() here prefix=self.shared_key_prefix, data={'timestamp': str(datetime.datetime.now()), 'version': version}) @@ -186,7 +189,7 @@ def resume(self, sse_key_path: Optional[str] = None): raise NoSuchJobStoreException(self.locator) def destroy(self): - delete_bucket(self.s3_resource, self.bucket_name) + delete_s3_bucket(self.s3_resource, self.bucket_name) ###################################### BUCKET UTIL API ###################################### @@ -329,13 +332,13 @@ def getEmptyFileStoreID(self, job_id=None, cleanup=False, basename=None): * basename seems to have not been used before? """ - return self.writeFile(localFilePath=None, job_id=job_id, cleanup=cleanup) + return self.write_file(local_path=None, job_id=job_id, cleanup=cleanup) - def writeFile(self, localFilePath: str = None, job_id: str = None, file_id: str = None, cleanup: bool = False): + def write_file(self, local_path: str = None, job_id: str = None, file_id: str = None, cleanup: bool = False): """ Write a local file into the jobstore and return a file_id referencing it. - If localFilePath is None, write an empty file to s3. + If local_path is None, write an empty file to s3. job_id: If job_id AND cleanup are supplied, associate this file with that job. When the job is deleted, the @@ -348,9 +351,9 @@ def writeFile(self, localFilePath: str = None, job_id: str = None, file_id: str """ file_id = file_id or str(uuid.uuid4()) # mint a new file_id - if localFilePath: - etag = compute_checksum_for_file(localFilePath, algorithm='etag')[len('etag$'):] - file_attributes = os.stat(localFilePath) + if local_path: + etag = compute_checksum_for_file(local_path, algorithm='etag')[len('etag$'):] + file_attributes = os.stat(local_path) size = file_attributes.st_size executable = file_attributes.st_mode & stat.S_IXUSR != 0 else: # create an empty file @@ -366,8 +369,8 @@ def writeFile(self, localFilePath: str = None, job_id: str = None, file_id: str # associate this job with this file; then the file reference will be deleted when the job is self.associate_job_with_file(job_id, file_id) - if localFilePath: # TODO: this is a stub; replace with old behavior or something more efficient - with open(localFilePath, 'rb') as f: + if local_path: # TODO: this is a stub; replace with old behavior or something more efficient + with open(local_path, 'rb') as f: data = f.read() else: data = None @@ -375,8 +378,8 @@ def writeFile(self, localFilePath: str = None, job_id: str = None, file_id: str return FileID(file_id, size, executable) @contextmanager - def writeFileStream(self, job_id=None, cleanup=False, basename=None, encoding=None, errors=None): - # TODO: updateFileStream??? + def write_file_stream(self, job_id=None, cleanup=False, basename=None, encoding=None, errors=None): + # TODO: redundant with update_file_stream??? file_id = str(uuid.uuid4()) if job_id and cleanup: self.associate_job_with_file(job_id, file_id) @@ -392,7 +395,7 @@ def writeFileStream(self, job_id=None, cleanup=False, basename=None, encoding=No yield writable, file_id @contextmanager - def updateFileStream(self, file_id, encoding=None, errors=None): + def update_file_stream(self, file_id, encoding=None, errors=None): pipe = MultiPartPipe(encoding=encoding, errors=errors, part_size=self.part_size, @@ -404,7 +407,7 @@ def updateFileStream(self, file_id, encoding=None, errors=None): yield writable @contextmanager - def writeSharedFileStream(self, file_id, encoding=None, errors=None): + def write_shared_file_stream(self, file_id, encoding=None, errors=None): # TODO self._requireValidSharedFileName(file_id) pipe = MultiPartPipe(encoding=encoding, @@ -417,29 +420,29 @@ def writeSharedFileStream(self, file_id, encoding=None, errors=None): with pipe as writable: yield writable - def updateFile(self, file_id, localFilePath): + def update_file(self, file_id, local_path): # Why use this over plain write file? # TODO: job_id does nothing here without a cleanup variable - self.writeFile(localFilePath=localFilePath, file_id=file_id) + self.write_file(local_path=local_path, file_id=file_id) - def fileExists(self, file_id): + def file_exists(self, file_id): return s3_key_exists(s3_resource=self.s3_resource, bucket=self.bucket_name, key=f'{self.content_key_prefix}{file_id}', extra_args=self.encryption_args) - def getFileSize(self, file_id: str) -> int: - """Do we need both getFileSize and getSize???""" - return self.getSize(url=f's3://{self.bucket_name}/{file_id}') + def get_file_size(self, file_id: str) -> int: + """Do we need both get_file_size and _get_size???""" + return self._get_size(url=f's3://{self.bucket_name}/{file_id}') - def getSize(self, url: str) -> int: - """Do we need both getFileSize and getSize???""" + def _get_size(self, url: str) -> int: + """Do we need both get_file_size and _get_size???""" try: - return self._getObjectForUrl(url, existing=True).content_length + return get_object_for_url(url, existing=True).content_length except (AWSKeyNotFoundError, NoSuchFileException): return 0 - def readFile(self, file_id, local_path, symlink=False): + def read_file(self, file_id, local_path, symlink=False): try: metadata = self.get_file_metadata(file_id) executable = int(metadata["executable"]) # 0 or 1 @@ -457,13 +460,13 @@ def readFile(self, file_id, local_path, symlink=False): # TODO: checksum # if not self.config.disableJobStoreChecksumVerification and previously_computed_checksum: # algorithm, expected_checksum = previously_computed_checksum.split('$') - # checksum = compute_checksum_for_file(localFilePath, algorithm=algorithm) + # checksum = compute_checksum_for_file(local_path, algorithm=algorithm) # if previously_computed_checksum != checksum: - # raise ChecksumError(f'Checksum mismatch for file {localFilePath}. ' + # raise ChecksumError(f'Checksum mismatch for file {local_path}. ' # f'Expected: {previously_computed_checksum} Actual: {checksum}') @contextmanager - def readFileStream(self, file_id, encoding=None, errors=None): + def read_file_stream(self, file_id, encoding=None, errors=None): try: metadata = self.get_file_metadata(file_id) with download_stream(self.s3_resource, @@ -482,29 +485,29 @@ def readFileStream(self, file_id, encoding=None, errors=None): raise @contextmanager - def readSharedFileStream(self, sharedFileName, encoding=None, errors=None): - self._requireValidSharedFileName(sharedFileName) + def read_shared_file_stream(self, shared_file_name, encoding=None, errors=None): + self._requireValidSharedFileName(shared_file_name) if not s3_key_exists(s3_resource=self.s3_resource, # necessary? bucket=self.bucket_name, - key=f'{self.shared_key_prefix}{sharedFileName}', + key=f'{self.shared_key_prefix}{shared_file_name}', extra_args=self.encryption_args): # TRAVIS=true TOIL_OWNER_TAG="shared" /home/quokka/git/toil/v3nv/bin/python -m pytest --durations=0 --log-level DEBUG --log-cli-level INFO -r s /home/quokka/git/toil/src/toil/test/jobStores/jobStoreTest.py::EncryptedAWSJobStoreTest::testJobDeletions # throw NoSuchFileException in download_stream - raise NoSuchFileException(f's3://{self.bucket_name}/{self.shared_key_prefix}{sharedFileName}') + raise NoSuchFileException(f's3://{self.bucket_name}/{self.shared_key_prefix}{shared_file_name}') try: with download_stream(self.s3_resource, bucket=self.bucket_name, - key=f'{self.shared_key_prefix}{sharedFileName}', + key=f'{self.shared_key_prefix}{shared_file_name}', encoding=encoding, errors=errors, extra_args=self.encryption_args) as readable: yield readable except self.s3_client.exceptions.NoSuchKey: - raise NoSuchFileException(sharedFileName) + raise NoSuchFileException(shared_file_name) except ClientError as e: if e.response.get('ResponseMetadata', {}).get('HTTPStatusCode') == 404: - raise NoSuchFileException(sharedFileName) + raise NoSuchFileException(shared_file_name) def delete_file(self, file_id): """Only delete the reference.""" @@ -513,7 +516,7 @@ def delete_file(self, file_id): ###################################### URI API ###################################### - def _importFile(self, otherCls, url, sharedFileName=None, hardlink=False, symlink=False) -> FileID: + def _import_file(self, otherCls, url, shared_file_name=None, hardlink=False, symlink=False) -> FileID: """ Upload a file into the s3 bucket jobstore from the source uri. @@ -529,9 +532,9 @@ def _importFile(self, otherCls, url, sharedFileName=None, hardlink=False, symlin content_type = response['ContentType'] # e.g. "binary/octet-stream" etag = response['ETag'].strip('\"') # e.g. "\"586af4cbd7416e6aefd35ccef9cbd7c8\"" - if sharedFileName: + if shared_file_name: prefix = self.shared_key_prefix - file_id = sharedFileName + file_id = shared_file_name else: prefix = self.content_key_prefix file_id = str(uuid.uuid4()) @@ -549,20 +552,20 @@ def _importFile(self, otherCls, url, sharedFileName=None, hardlink=False, symlin data=readable.read()) # verify etag after copying here - if not sharedFileName: + if not shared_file_name: # cannot determine exec bit from foreign s3 so default to False metadata = {'etag': etag, 'executable': 0} self.write_to_bucket(identifier=file_id, prefix=self.metadata_key_prefix, data=metadata) return FileID(file_id, content_length) else: - file_id = super(AWSJobStore, self)._importFile(otherCls, url, sharedFileName=sharedFileName) + file_id = super(AWSJobStore, self)._import_file(otherCls, url, shared_file_name=shared_file_name) if file_id: # this will be None for shared_files and FileID for everything else # rely on the other jobstore to determine exec bit metadata = {'etag': None, 'executable': file_id.executable} self.write_to_bucket(identifier=file_id, prefix=self.metadata_key_prefix, data=metadata) return file_id - def _exportFile(self, otherCls, file_id: str, url) -> None: + def _export_file(self, otherCls, file_id: str, url) -> None: """Export a file_id in the jobstore to the url.""" # use a new session here to be thread-safe if issubclass(otherCls, AWSJobStore): @@ -575,7 +578,7 @@ def _exportFile(self, otherCls, file_id: str, url) -> None: else: # AWS copy and copy_object functions should be used here, but don't work with sse-c encryption # see: https://github.com/aws/aws-cli/issues/6012 - with self.readFileStream(file_id) as readable: + with self.read_file_stream(file_id) as readable: upload_to_s3(readable, self.s3_resource, dst_bucket_name, @@ -585,61 +588,51 @@ def _exportFile(self, otherCls, file_id: str, url) -> None: super(AWSJobStore, self)._defaultExportFile(otherCls, file_id, url) @classmethod - def _readFromUrl(cls, url, writable): + def _read_from_url(cls, url, writable): # TODO: this should either not be a classmethod, or determine region and boto args from the environment url = url.geturl() - srcObj = cls._getObjectForUrl(url, existing=True) - srcObj.download_fileobj(writable) + src_obj = get_object_for_url(url, existing=True) + src_obj.download_fileobj(writable) executable = False - return srcObj.content_length, executable + return src_obj.content_length, executable - @classmethod - def _writeToUrl(cls, readable, url, executable=False): - # TODO: this should either not be a classmethod, or determine region and boto args from the environment + def _write_to_url(self, readable, url, executable=False): url = url.geturl() - dstObj = cls._getObjectForUrl(url) + dst_obj = get_object_for_url(url) upload_to_s3(readable=readable, - s3_resource=resource('s3'), - bucket=dstObj.bucket_name, - key=dstObj.key) - - @staticmethod - def _getObjectForUrl(url: str, existing: Optional[bool] = None): - """ - Extracts a key (object) from a given s3:// URL. - - :param bool existing: If True, key is expected to exist. If False or None, key is - expected not to exist and it will be created. - - :rtype: S3.Object - """ - # TODO: this should either not be static, or determine region and boto args from the environment - bucket_name, key_name = parse_s3_uri(url) - obj = resource('s3').Object(bucket_name, key_name) + s3_resource=self.s3_resource('s3'), + bucket=dst_obj.bucket_name, + key=dst_obj.key) + @classmethod + def _url_exists(cls, url) -> bool: try: - obj.load() - objExists = True - except ClientError as e: - if e.response.get('ResponseMetadata', {}).get('HTTPStatusCode') == 404: - objExists = False - else: - raise - if existing is True and not objExists: - raise AWSKeyNotFoundError(f"Key '{key_name}' does not exist in bucket '{bucket_name}'.") - elif existing is False and objExists: - raise AWSKeyAlreadyExistsError(f"Key '{key_name}' exists in bucket '{bucket_name}'.") + get_object_for_url(url, existing=True) + return True + except FileNotFoundError: + # Not a file + # Might be a directory. + return cls._get_is_directory(url) + + @classmethod + def _open_url(cls, url: ParseResult) -> IO[bytes]: + src_obj = get_object_for_url(url, existing=True) + response = src_obj.get() + # We should get back a response with a stream in 'Body' + if "Body" not in response: + raise RuntimeError(f"Could not fetch body stream for {url}") + return response["Body"] - if not objExists: - obj.put() # write an empty file - return obj + @classmethod + def _list_url(cls, url: ParseResult) -> list[str]: + return list_objects_for_url(url) @classmethod - def _supportsUrl(cls, url, export=False): + def _supports_url(cls, url, export=False): # TODO: export seems unused return url.scheme.lower() == 's3' - def getPublicUrl(self, file_id: str): + def get_public_url(self, file_id: str): """Turn s3:// into http:// and put a public-read ACL on it.""" try: return create_public_url(self.s3_resource, @@ -651,7 +644,7 @@ def getPublicUrl(self, file_id: str): if e.response.get('ResponseMetadata', {}).get('HTTPStatusCode') == 404: raise NoSuchFileException(file_id) - def getSharedPublicUrl(self, file_id: str): + def get_shared_public_url(self, file_id: str): """Turn s3:// into http:// and put a public-read ACL on it.""" # since this is only for a few files like "config.pickle"... why and what is this used for? self._requireValidSharedFileName(file_id) @@ -665,6 +658,23 @@ def getSharedPublicUrl(self, file_id: str): if e.response.get('ResponseMetadata', {}).get('HTTPStatusCode') == 404: raise NoSuchFileException(file_id) + @classmethod + def _get_is_directory(cls, url: ParseResult) -> bool: + # We consider it a directory if anything is in it. + # TODO: Can we just get the first item and not the whole list? + return len(list_objects_for_url(url)) > 0 + + def get_empty_file_store_id( + self, jobStoreID=None, cleanup=False, basename=None + ) -> FileID: + info = self.FileInfo.create(jobStoreID if cleanup else None) + with info.uploadStream() as _: + # Empty + pass + info.save() + logger.debug("Created %r.", info) + return info.fileID + ###################################### LOGGING API ###################################### def write_logs(self, log_msg: Union[bytes, str]): @@ -735,3 +745,8 @@ def parse_jobstore_identifier(jobstore_identifier: str) -> Tuple[str, str]: if '--' in jobstore_name: raise ValueError(f"AWS jobstore names may not contain '--': {jobstore_name}") return region, bucket_name + + +aRepr = reprlib.Repr() +aRepr.maxstring = 38 # so UUIDs don't get truncated (36 for UUID plus 2 for quotes) +custom_repr = aRepr.repr diff --git a/src/toil/jobStores/exceptions.py b/src/toil/jobStores/exceptions.py new file mode 100644 index 0000000000..1eefafb683 --- /dev/null +++ b/src/toil/jobStores/exceptions.py @@ -0,0 +1,78 @@ +# Copyright (C) 2015-2021 Regents of the University of California +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import urllib.parse as urlparse + + +class InvalidImportExportUrlException(Exception): + def __init__(self, url): + """ + :param urlparse.ParseResult url: + """ + super().__init__("The URL '%s' is invalid." % url.geturl()) + + +class NoSuchJobException(Exception): + """Indicates that the specified job does not exist.""" + def __init__(self, jobStoreID): + """ + :param str jobStoreID: the jobStoreID that was mistakenly assumed to exist + """ + super().__init__("The job '%s' does not exist." % jobStoreID) + + +class ConcurrentFileModificationException(Exception): + """Indicates that the file was attempted to be modified by multiple processes at once.""" + def __init__(self, jobStoreFileID): + """ + :param str jobStoreFileID: the ID of the file that was modified by multiple workers + or processes concurrently + """ + super().__init__('Concurrent update to file %s detected.' % jobStoreFileID) + + +class NoSuchFileException(Exception): + """Indicates that the specified file does not exist.""" + def __init__(self, jobStoreFileID, customName=None, *extra): + """ + :param str jobStoreFileID: the ID of the file that was mistakenly assumed to exist + :param str customName: optionally, an alternate name for the nonexistent file + :param list extra: optional extra information to add to the error message + """ + # Having the extra argument may help resolve the __init__() takes at + # most three arguments error reported in + # https://github.com/DataBiosphere/toil/issues/2589#issuecomment-481912211 + if customName is None: + message = "File '%s' does not exist." % jobStoreFileID + else: + message = "File '%s' (%s) does not exist." % (customName, jobStoreFileID) + + if extra: + # Append extra data. + message += " Extra info: " + " ".join((str(x) for x in extra)) + + super().__init__(message) + + +class NoSuchJobStoreException(Exception): + """Indicates that the specified job store does not exist.""" + def __init__(self, locator): + super().__init__("The job store '%s' does not exist, so there is nothing to restart." % locator) + + +class JobStoreExistsException(Exception): + """Indicates that the specified job store already exists.""" + def __init__(self, locator): + super().__init__( + "The job store '%s' already exists. Use --restart to resume the workflow, or remove " + "the job store with 'toil clean' to start the workflow from scratch." % locator) diff --git a/src/toil/lib/aws/config.py b/src/toil/lib/aws/config.py new file mode 100644 index 0000000000..365d781abe --- /dev/null +++ b/src/toil/lib/aws/config.py @@ -0,0 +1,22 @@ +S3_PARALLELIZATION_FACTOR = 8 +S3_PART_SIZE = 16 * 1024 * 1024 +KiB = 1024 +MiB = KiB * KiB + +# Files must be larger than this before we consider multipart uploads. +AWS_MIN_CHUNK_SIZE = 64 * MiB +# Convenience variable for Boto3 TransferConfig(multipart_threhold=). +MULTIPART_THRESHOLD = AWS_MIN_CHUNK_SIZE + 1 +# Maximum number of parts allowed in a multipart upload. This is a limitation imposed by S3. +AWS_MAX_MULTIPART_COUNT = 10000 + + +def get_s3_multipart_chunk_size(filesize: int) -> int: + """Returns the chunk size of the S3 multipart object, given a file's size in bytes.""" + if filesize <= AWS_MAX_MULTIPART_COUNT * AWS_MIN_CHUNK_SIZE: + return AWS_MIN_CHUNK_SIZE + else: + div = filesize // AWS_MAX_MULTIPART_COUNT + if div * AWS_MAX_MULTIPART_COUNT < filesize: + div += 1 + return ((div + MiB - 1) // MiB) * MiB diff --git a/src/toil/lib/aws/s3.py b/src/toil/lib/aws/s3.py index b911a18dc9..37cc62e785 100644 --- a/src/toil/lib/aws/s3.py +++ b/src/toil/lib/aws/s3.py @@ -1,4 +1,4 @@ -# Copyright (C) 2015-2024 Regents of the University of California +# Copyright (C) 2015-2023 Regents of the University of California # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,21 +11,484 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import hashlib +import itertools +import urllib.parse import logging -from mypy_boto3_s3.type_defs import ListMultipartUploadsOutputTypeDef +from io import BytesIO +from typing import Tuple, Optional, Union +from datetime import timedelta +from contextlib import contextmanager +from typing import (Any, + Dict, + List, + Optional, + Union, + cast, + Literal) + +from toil.lib.retry import retry, get_error_status +from toil.lib.misc import printq from toil.lib.aws import AWSServerErrors, session +from . import build_tag_dict_from_env +from toil.lib.aws.utils import enable_public_objects, flatten_tags +from toil.lib.conversions import modify_url, MB, MIB, TB +from toil.lib.pipes import WritablePipe, ReadablePipe, HashingPipe +from toil.lib.retry import ErrorCondition from toil.lib.retry import retry +try: + from boto.exception import BotoServerError, S3ResponseError + from botocore.exceptions import ClientError + from boto3.s3.transfer import TransferConfig + from mypy_boto3_iam import IAMClient, IAMServiceResource + from mypy_boto3_s3 import S3Client, S3ServiceResource + from mypy_boto3_s3.literals import BucketLocationConstraintType + from mypy_boto3_s3.service_resource import Bucket, Object + from mypy_boto3_s3.type_defs import ListMultipartUploadsOutputTypeDef, HeadObjectOutputTypeDef + from mypy_boto3_sdb import SimpleDBClient +except ImportError: + BotoServerError = Exception # type: ignore + S3ResponseError = Exception # type: ignore + ClientError = Exception # type: ignore + # AWS/boto extra is not installed + + logger = logging.getLogger(__name__) +# AWS Defined Limits +# https://docs.aws.amazon.com/AmazonS3/latest/userguide/qfacts.html +AWS_MAX_MULTIPART_COUNT = 10000 +AWS_MAX_CHUNK_SIZE = 5 * TB +AWS_MIN_CHUNK_SIZE = 5 * MB +# Note: There is no minimum size limit on the last part of a multipart upload. + +# The chunk size we chose arbitrarily, but it must be consistent for etags +DEFAULT_AWS_CHUNK_SIZE = 128 * MIB +assert AWS_MAX_CHUNK_SIZE > DEFAULT_AWS_CHUNK_SIZE > AWS_MIN_CHUNK_SIZE + + +class NoSuchFileException(Exception): + pass + + +class AWSKeyNotFoundError(Exception): + pass + + +class AWSKeyAlreadyExistsError(Exception): + pass + + +class AWSBadEncryptionKeyError(Exception): + pass + + +# @retry(errors=[BotoServerError, S3ResponseError, ClientError]) +def create_s3_bucket( + s3_resource: "S3ServiceResource", + bucket_name: str, + region: Union["BucketLocationConstraintType", Literal["us-east-1"]], + tags: Optional[Dict[str, str]] = None, + public: bool = True +) -> "Bucket": + """ + Create an AWS S3 bucket, using the given Boto3 S3 session, with the + given name, in the given region. + + Supports the us-east-1 region, where bucket creation is special. + + *ALL* S3 bucket creation should use this function. + """ + logger.info("Creating bucket '%s' in region %s.", bucket_name, region) + if region == "us-east-1": # see https://github.com/boto/boto3/issues/125 + bucket = s3_resource.create_bucket(Bucket=bucket_name) + else: + bucket = s3_resource.create_bucket( + Bucket=bucket_name, + CreateBucketConfiguration={"LocationConstraint": region}, + ) + # wait until the bucket exists before adding tags + bucket.wait_until_exists() + + tags = build_tag_dict_from_env() if tags is None else tags + bucket_tagging = s3_resource.BucketTagging(bucket_name) + bucket_tagging.put(Tagging={'TagSet': flatten_tags(tags)}) # type: ignore + + # enabling public objects is the historical default + if public: + enable_public_objects(bucket_name) + + return bucket + + +@retry(errors=[BotoServerError, S3ResponseError, ClientError]) +def delete_s3_bucket( + s3_resource: "S3ServiceResource", + bucket_name: str, + quiet: bool = True +) -> None: + """ + Delete the bucket with 'bucket_name'. + + Note: 'quiet' is False when used for a clean up utility script (contrib/admin/cleanup_aws_resources.py) + that prints progress rather than logging. Logging should be used for all other internal Toil usage. + """ + assert isinstance(bucket_name, str), f'{bucket_name} is not a string ({type(bucket_name)}).' + logger.debug("Deleting bucket '%s'.", bucket_name) + printq(f'\n * Deleting s3 bucket: {bucket_name}\n\n', quiet) + + s3_client = s3_resource.meta.client + + try: + for u in s3_client.list_multipart_uploads(Bucket=bucket_name).get('Uploads', []): + s3_client.abort_multipart_upload( + Bucket=bucket_name, + Key=u["Key"], + UploadId=u["UploadId"] + ) + + paginator = s3_client.get_paginator('list_object_versions') + for response in paginator.paginate(Bucket=bucket_name): + # Versions and delete markers can both go in here to be deleted. + # They both have Key and VersionId, but there's no shared base type + # defined for them in the stubs to express that. See + # . So we + # have to do gymnastics to get them into the same list. + to_delete: List[Dict[str, Any]] = cast(List[Dict[str, Any]], response.get('Versions', [])) + \ + cast(List[Dict[str, Any]], response.get('DeleteMarkers', [])) + for entry in to_delete: + printq(f" Deleting {entry['Key']} version {entry['VersionId']}", quiet) + s3_client.delete_object( + Bucket=bucket_name, + Key=entry['Key'], + VersionId=entry['VersionId'] + ) + bucket = s3_resource.Bucket(bucket_name) + bucket.objects.all().delete() + bucket.object_versions.delete() + bucket.delete() + printq(f'\n * Deleted s3 bucket successfully: {bucket_name}\n\n', quiet) + logger.debug("Deleted s3 bucket successfully '%s'.", bucket_name) + except s3_client.exceptions.NoSuchBucket: + printq(f'\n * S3 bucket no longer exists: {bucket_name}\n\n', quiet) + logger.debug("S3 bucket no longer exists '%s'.", bucket_name) + except ClientError as e: + if get_error_status(e) != 404: + raise + printq(f'\n * S3 bucket no longer exists: {bucket_name}\n\n', quiet) + logger.debug("S3 bucket no longer exists '%s'.", bucket_name) + + +@retry(errors=[BotoServerError]) +def bucket_exists(s3_resource, bucket: str) -> Union[bool, Bucket]: + s3_client = s3_resource.meta.client + try: + s3_client.head_bucket(Bucket=bucket) + return s3_resource.Bucket(bucket) + except (ClientError, s3_client.exceptions.NoSuchBucket) as e: + error_code = e.response.get('ResponseMetadata', {}).get('HTTPStatusCode') + if error_code == 404: + return False + else: + raise + + +@retry(errors=[AWSServerErrors]) +def head_s3_object(bucket: str, key: str, header: Dict[str, Any], region: Optional[str] = None) -> HeadObjectOutputTypeDef: + """ + Attempt to HEAD an s3 object and return its response. + + :param bucket: AWS bucket name + :param key: AWS Key name for the s3 object + :param header: Headers to include (mostly for encryption). + See: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/s3/client/head_object.html + :param region: Region that we want to look for the bucket in + """ + s3_client = session.client("s3", region_name=region) + return s3_client.head_object(Bucket=bucket, Key=key, **header) + + @retry(errors=[AWSServerErrors]) def list_multipart_uploads( bucket: str, region: str, prefix: str, max_uploads: int = 1 ) -> ListMultipartUploadsOutputTypeDef: s3_client = session.client("s3", region_name=region) - return s3_client.list_multipart_uploads( - Bucket=bucket, MaxUploads=max_uploads, Prefix=prefix + return s3_client.list_multipart_uploads(Bucket=bucket, MaxUploads=max_uploads, Prefix=prefix) + + +@retry(errors=[BotoServerError]) +def copy_s3_to_s3(s3_resource, src_bucket, src_key, dst_bucket, dst_key, extra_args: Optional[dict] = None): + if not extra_args: + source = {'Bucket': src_bucket, 'Key': src_key} + # Note: this may have errors if using sse-c because of + # a bug with encryption using copy_object and copy (which uses copy_object for files <5GB): + # https://github.com/aws/aws-cli/issues/6012 + # this will only happen if we attempt to copy a file previously encrypted with sse-c + # copying an unencrypted file and encrypting it as sse-c seems to work fine though + kwargs = dict(CopySource=source, Bucket=dst_bucket, Key=dst_key, ExtraArgs=extra_args) + s3_resource.meta.client.copy(**kwargs) + else: + pass + + +# TODO: Determine specific retries +@retry(errors=[BotoServerError]) +def copy_local_to_s3(s3_resource, local_file_path, dst_bucket, dst_key, extra_args: Optional[dict] = None): + s3_client = s3_resource.meta.client + s3_client.upload_file(local_file_path, dst_bucket, dst_key, ExtraArgs=extra_args) + + +class MultiPartPipe(WritablePipe): + def __init__(self, part_size, s3_client, bucket_name, file_id, encryption_args, encoding, errors): + super(MultiPartPipe, self).__init__() + self.encoding = encoding + self.errors = errors + self.part_size = part_size + self.s3_client = s3_client + self.bucket_name = bucket_name + self.file_id = file_id + self.encryption_args = encryption_args + + def readFrom(self, readable): + # Get the first block of data we want to put + buf = readable.read(self.part_size) + assert isinstance(buf, bytes) + + # We will compute a checksum + hasher = hashlib.sha1() + hasher.update(buf) + + # low-level clients are thread safe + response = self.s3_client.create_multipart_upload(Bucket=self.bucket_name, + Key=self.file_id, + **self.encryption_args) + upload_id = response['UploadId'] + parts = [] + try: + for part_num in itertools.count(): + logger.debug(f'[{upload_id}] Uploading part %d of %d bytes', part_num + 1, len(buf)) + # TODO: include the Content-MD5 header: + # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/s3.html#S3.Client.complete_multipart_upload + part = self.s3_client.upload_part(Bucket=self.bucket_name, + Key=self.file_id, + PartNumber=part_num + 1, + UploadId=upload_id, + Body=BytesIO(buf), + **self.encryption_args) + parts.append({"PartNumber": part_num + 1, "ETag": part["ETag"]}) + + # Get the next block of data we want to put + buf = readable.read(self.part_size) + if len(buf) == 0: + # Don't allow any part other than the very first to be empty. + break + hasher.update(buf) + except: + self.s3_client.abort_multipart_upload(Bucket=self.bucket_name, + Key=self.file_id, + UploadId=upload_id) + else: + # Save the checksum + checksum = f'sha1${hasher.hexdigest()}' + response = self.s3_client.complete_multipart_upload(Bucket=self.bucket_name, + Key=self.file_id, + UploadId=upload_id, + MultipartUpload={"Parts": parts}) + logger.debug(f'[{upload_id}] Upload complete...') + + +def parse_s3_uri(uri: str) -> Tuple[str, str]: + # does not support s3/gs: https://docs.python.org/3/library/urllib.parse.html + # use regex instead? + if isinstance(uri, str): + uri = urllib.parse.urlparse(uri) + if uri.scheme.lower() != 's3': + raise ValueError(f'Invalid schema. Expecting s3 prefix, not: {uri}') + # bucket_name, key_name = uri[len('s3://'):].split('/', 1) + bucket_name, key_name = uri.netloc.strip('/'), uri.path.strip('/') + return bucket_name, key_name + + +def list_s3_items(s3_resource, bucket, prefix, startafter=None): + s3_client = s3_resource.meta.client + paginator = s3_client.get_paginator('list_objects_v2') + kwargs = dict(Bucket=bucket, Prefix=prefix) + if startafter: + kwargs['StartAfter'] = startafter + for page in paginator.paginate(**kwargs): + for key in page.get('Contents', []): + yield key + + +@retry(errors=[ErrorCondition(error=ClientError, error_codes=[404, 500, 502, 503, 504])]) +def upload_to_s3(readable, + s3_resource, + bucket: str, + key: str, + extra_args: Optional[dict] = None): + """ + Upload a readable object to s3, using multipart uploading if applicable. + + :param readable: a readable stream or a local file path to upload to s3 + :param S3.Resource resource: boto3 resource + :param str bucket: name of the bucket to upload to + :param str key: the name of the file to upload to + :param dict extra_args: http headers to use when uploading - generally used for encryption purposes + :param int partSize: max size of each part in the multipart upload, in bytes + :return: version of the newly uploaded file + """ + if extra_args is None: + extra_args = {} + + s3_client = s3_resource.meta.client + config = TransferConfig( + multipart_threshold=DEFAULT_AWS_CHUNK_SIZE, + multipart_chunksize=DEFAULT_AWS_CHUNK_SIZE, + use_threads=True ) + logger.debug("Uploading %s", key) + # these methods use multipart if necessary + if isinstance(readable, str): + s3_client.upload_file(Filename=readable, + Bucket=bucket, + Key=key, + ExtraArgs=extra_args, + Config=config) + else: + s3_client.upload_fileobj(Fileobj=readable, + Bucket=bucket, + Key=key, + ExtraArgs=extra_args, + Config=config) + + object_summary = s3_resource.ObjectSummary(bucket, key) + object_summary.wait_until_exists(**extra_args) + + +@contextmanager +def download_stream(s3_resource, bucket: str, key: str, checksum_to_verify: Optional[str] = None, + extra_args: Optional[dict] = None, encoding=None, errors=None): + """Context manager that gives out a download stream to download data.""" + bucket = s3_resource.Bucket(bucket) + + class DownloadPipe(ReadablePipe): + def writeTo(self, writable): + kwargs = dict(Key=key, Fileobj=writable, ExtraArgs=extra_args) + if not extra_args: + del kwargs['ExtraArgs'] + bucket.download_fileobj(**kwargs) + + try: + if checksum_to_verify: + with DownloadPipe(encoding=encoding, errors=errors) as readable: + # Interpose a pipe to check the hash + with HashingPipe(readable, encoding=encoding, errors=errors) as verified: + yield verified + else: + # Readable end of pipe produces text mode output if encoding specified + with DownloadPipe(encoding=encoding, errors=errors) as readable: + # No true checksum available, so don't hash + yield readable + except s3_resource.meta.client.exceptions.NoSuchKey: + raise AWSKeyNotFoundError(f"Key '{key}' does not exist in bucket '{bucket}'.") + except ClientError as e: + if e.response.get('ResponseMetadata', {}).get('HTTPStatusCode') == 404: + raise AWSKeyNotFoundError(f"Key '{key}' does not exist in bucket '{bucket}'.") + elif e.response.get('ResponseMetadata', {}).get('HTTPStatusCode') == 400 and \ + e.response.get('Error', {}).get('Message') == 'Bad Request' and \ + e.operation_name == 'HeadObject': + # An error occurred (400) when calling the HeadObject operation: Bad Request + raise AWSBadEncryptionKeyError('Your AWS encryption key is most likely configured incorrectly ' + '(HeadObject operation: Bad Request).') + raise + + +def download_fileobject(s3_resource, bucket: Bucket, key: str, fileobj, extra_args: Optional[dict] = None): + try: + bucket.download_fileobj(Key=key, Fileobj=fileobj, ExtraArgs=extra_args) + except s3_resource.meta.client.exceptions.NoSuchKey: + raise AWSKeyNotFoundError(f"Key '{key}' does not exist in bucket '{bucket}'.") + except ClientError as e: + if e.response.get('ResponseMetadata', {}).get('HTTPStatusCode') == 404: + raise AWSKeyNotFoundError(f"Key '{key}' does not exist in bucket '{bucket}'.") + elif e.response.get('ResponseMetadata', {}).get('HTTPStatusCode') == 400 and \ + e.response.get('Error', {}).get('Message') == 'Bad Request' and \ + e.operation_name == 'HeadObject': + # An error occurred (400) when calling the HeadObject operation: Bad Request + raise AWSBadEncryptionKeyError('Your AWS encryption key is most likely configured incorrectly ' + '(HeadObject operation: Bad Request).') + raise + + +def s3_key_exists(s3_resource, bucket: str, key: str, check: bool = False, extra_args: dict = None): + """Return True if the s3 obect exists, and False if not. Will error if encryption args are incorrect.""" + extra_args = extra_args or {} + s3_client = s3_resource.meta.client + try: + s3_client.head_object(Bucket=bucket, Key=key, **extra_args) + return True + except s3_client.exceptions.NoSuchKey: + if check: + raise AWSKeyNotFoundError(f"Key '{key}' does not exist in bucket '{bucket}'.") + return False + except ClientError as e: + if e.response.get('ResponseMetadata', {}).get('HTTPStatusCode') == 404: + if check: + raise AWSKeyNotFoundError(f"Key '{key}' does not exist in bucket '{bucket}'.") + return False + elif e.response.get('ResponseMetadata', {}).get('HTTPStatusCode') == 400 and \ + e.response.get('Error', {}).get('Message') == 'Bad Request' and \ + e.operation_name == 'HeadObject': + # An error occurred (400) when calling the HeadObject operation: Bad Request + raise AWSBadEncryptionKeyError('Your AWS encryption key is most likely configured incorrectly ' + '(HeadObject operation: Bad Request).') + else: + raise + + +def get_s3_object(s3_resource, bucket: str, key: str, extra_args: dict = None): + if extra_args is None: + extra_args = dict() + s3_client = s3_resource.meta.client + return s3_client.get_object(Bucket=bucket, Key=key, **extra_args) + + +def put_s3_object(s3_resource, bucket: str, key: str, body: Optional[bytes], extra_args: dict = None): + if extra_args is None: + extra_args = dict() + s3_client = s3_resource.meta.client + return s3_client.put_object(Bucket=bucket, Key=key, Body=body, **extra_args) + + +def generate_presigned_url(s3_resource, bucket: str, key_name: str, expiration: int) -> Tuple[str, str]: + s3_client = s3_resource.meta.client + return s3_client.generate_presigned_url( + 'get_object', + Params={'Bucket': bucket, 'Key': key_name}, + ExpiresIn=expiration) + + +def create_public_url(s3_resource, bucket: str, key: str): + bucket_obj = Bucket(bucket) + bucket_obj.Object(key).Acl().put(ACL='public-read') # TODO: do we need to generate a signed url after doing this? + url = generate_presigned_url(s3_resource=s3_resource, + bucket=bucket, + key_name=key, + # One year should be sufficient to finish any pipeline ;-) + expiration=int(timedelta(days=365).total_seconds())) + # boto doesn't properly remove the x-amz-security-token parameter when + # query_auth is False when using an IAM role (see issue #2043). Including the + # x-amz-security-token parameter without the access key results in a 403, + # even if the resource is public, so we need to remove it. + # TODO: verify that this is still the case + return modify_url(url, remove=['x-amz-security-token', 'AWSAccessKeyId', 'Signature']) + + +def get_s3_bucket_region(s3_resource, bucket: str): + s3_client = s3_resource.meta.client + # AWS returns None for the default of 'us-east-1' + return s3_client.get_bucket_location(Bucket=bucket).get('LocationConstraint', None) or 'us-east-1' diff --git a/src/toil/lib/aws/utils.py b/src/toil/lib/aws/utils.py index a6b2a2e0a3..91772089b8 100644 --- a/src/toil/lib/aws/utils.py +++ b/src/toil/lib/aws/utils.py @@ -16,8 +16,8 @@ import os import socket from collections.abc import Iterable, Iterator -from typing import TYPE_CHECKING, Any, Callable, ContextManager, Optional, cast -from urllib.parse import ParseResult +from typing import TYPE_CHECKING, Any, Callable, ContextManager, Optional, Union, cast +from urllib.parse import ParseResult, urlparse from toil.lib.aws import AWSRegionName, AWSServerErrors, session from toil.lib.conversions import strtobool @@ -346,7 +346,7 @@ def bucket_location_to_region(location: Optional[str]) -> str: return "us-east-1" if location == "" or location is None else location -def get_object_for_url(url: ParseResult, existing: Optional[bool] = None) -> "S3Object": +def get_object_for_url(url: Union[ParseResult, str], existing: Optional[bool] = None) -> "S3Object": """ Extracts a key (object) from a given parsed s3:// URL. @@ -355,6 +355,8 @@ def get_object_for_url(url: ParseResult, existing: Optional[bool] = None) -> "S3 :param bool existing: If True, key is expected to exist. If False, key is expected not to exists and it will be created. If None, the key will be created if it doesn't exist. """ + if isinstance(url, str): + url = urlparse(url) key_name = url.path[1:] bucket_name = url.netloc @@ -407,11 +409,14 @@ def get_object_for_url(url: ParseResult, existing: Optional[bool] = None) -> "S3 @retry(errors=[AWSServerErrors]) -def list_objects_for_url(url: ParseResult) -> list[str]: +def list_objects_for_url(url: Union[ParseResult, str]) -> list[str]: """ Extracts a key (object) from a given parsed s3:// URL. The URL will be supplemented with a trailing slash if it is missing. """ + if isinstance(url, str): + url = urlparse(url) + key_name = url.path[1:] bucket_name = url.netloc diff --git a/src/toil/lib/checksum.py b/src/toil/lib/checksum.py new file mode 100644 index 0000000000..101a91d6da --- /dev/null +++ b/src/toil/lib/checksum.py @@ -0,0 +1,83 @@ +# Copyright (C) 2015-2021 Regents of the University of California +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import hashlib + +from io import BytesIO +from typing import BinaryIO, Union + +from toil.lib.aws.config import S3_PART_SIZE + +logger = logging.getLogger(__name__) + + +class ChecksumError(Exception): + """Raised when a download does not contain the correct data.""" + + +class Etag: + """A hasher for s3 etags.""" + def __init__(self, chunk_size): + self.etag_bytes = 0 + self.etag_parts = [] + self.etag_hasher = hashlib.md5() + self.chunk_size = chunk_size + + def update(self, chunk): + if self.etag_bytes + len(chunk) > self.chunk_size: + chunk_head = chunk[:self.chunk_size - self.etag_bytes] + chunk_tail = chunk[self.chunk_size - self.etag_bytes:] + self.etag_hasher.update(chunk_head) + self.etag_parts.append(self.etag_hasher.digest()) + self.etag_hasher = hashlib.md5() + self.etag_hasher.update(chunk_tail) + self.etag_bytes = len(chunk_tail) + else: + self.etag_hasher.update(chunk) + self.etag_bytes += len(chunk) + + def hexdigest(self): + if self.etag_bytes: + self.etag_parts.append(self.etag_hasher.digest()) + self.etag_bytes = 0 + if len(self.etag_parts) > 1: + etag = hashlib.md5(b"".join(self.etag_parts)).hexdigest() + return f'{etag}-{len(self.etag_parts)}' + else: + return self.etag_hasher.hexdigest() + + +hashers = {'sha1': hashlib.sha1(), + 'sha256': hashlib.sha256(), + 'etag': Etag(chunk_size=S3_PART_SIZE)} + + +def compute_checksum_for_file(local_file_path: str, algorithm: str = 'sha1') -> str: + with open(local_file_path, 'rb') as fh: + checksum_result = compute_checksum_for_content(fh, algorithm=algorithm) + return checksum_result + + +def compute_checksum_for_content(fh: Union[BinaryIO, BytesIO], algorithm: str = 'sha1') -> str: + """ + Note: Chunk size matters for s3 etags, and must be the same to get the same hash from the same object. + Therefore this buffer is not modifiable throughout Toil. + """ + hasher = hashers[algorithm] + contents = fh.read(S3_PART_SIZE) + while contents != b'': + hasher.update(contents) + contents = fh.read(S3_PART_SIZE) + + return f'{algorithm}${hasher.hexdigest()}' diff --git a/src/toil/lib/conversions.py b/src/toil/lib/conversions.py index f391f5ade4..fbb33a3d78 100644 --- a/src/toil/lib/conversions.py +++ b/src/toil/lib/conversions.py @@ -2,40 +2,28 @@ Conversion utilities for mapping memory, disk, core declarations from strings to numbers and vice versa. Also contains general conversion functions """ - import math -from typing import Optional, SupportsInt, Union +import urllib.parse + +from typing import Optional, SupportsInt, Union, List + +KIB = 1024 +MIB = 1024 ** 2 +GIB = 1024 ** 3 +TIB = 1024 ** 4 +PIB = 1024 ** 5 +EIB = 1024 ** 6 + +KB = 1000 +MB = 1000 ** 2 +GB = 1000 ** 3 +TB = 1000 ** 4 +PB = 1000 ** 5 +EB = 1000 ** 6 # See https://en.wikipedia.org/wiki/Binary_prefix -BINARY_PREFIXES = [ - "ki", - "mi", - "gi", - "ti", - "pi", - "ei", - "kib", - "mib", - "gib", - "tib", - "pib", - "eib", -] -DECIMAL_PREFIXES = [ - "b", - "k", - "m", - "g", - "t", - "p", - "e", - "kb", - "mb", - "gb", - "tb", - "pb", - "eb", -] +BINARY_PREFIXES = ['ki', 'mi', 'gi', 'ti', 'pi', 'ei', 'kib', 'mib', 'gib', 'tib', 'pib', 'eib'] +DECIMAL_PREFIXES = ['b', 'k', 'm', 'g', 't', 'p', 'e', 'kb', 'mb', 'gb', 'tb', 'pb', 'eb'] VALID_PREFIXES = BINARY_PREFIXES + DECIMAL_PREFIXES @@ -185,3 +173,17 @@ def strtobool(val: str) -> bool: def opt_strtobool(b: Optional[str]) -> Optional[bool]: """Convert an optional string representation of bool to None or bool""" return b if b is None else strtobool(b) + + +def modify_url(url: str, remove: List[str]) -> str: + """ + Given a valid URL string, split out the params, remove any offending + params in 'remove', and return the cleaned URL. + """ + scheme, netloc, path, query, fragment = urllib.parse.urlsplit(url) + params = urllib.parse.parse_qs(query) + for param_key in remove: + if param_key in params: + del params[param_key] + query = urllib.parse.urlencode(params, doseq=True) + return urllib.parse.urlunsplit((scheme, netloc, path, query, fragment)) diff --git a/src/toil/lib/pipes.py b/src/toil/lib/pipes.py new file mode 100644 index 0000000000..68f9144328 --- /dev/null +++ b/src/toil/lib/pipes.py @@ -0,0 +1,358 @@ +import errno +import logging +import os +import hashlib +from abc import ABC, abstractmethod + +from toil.lib.checksum import ChecksumError +from toil.lib.threading import ExceptionalThread + +log = logging.getLogger(__name__) + + +class WritablePipe(ABC): + """ + An object-oriented wrapper for os.pipe. Clients should subclass it, implement + :meth:`.readFrom` to consume the readable end of the pipe, then instantiate the class as a + context manager to get the writable end. See the example below. + + >>> import sys, shutil + >>> class MyPipe(WritablePipe): + ... def readFrom(self, readable): + ... shutil.copyfileobj(codecs.getreader('utf-8')(readable), sys.stdout) + >>> with MyPipe() as writable: + ... _ = writable.write('Hello, world!\\n'.encode('utf-8')) + Hello, world! + + Each instance of this class creates a thread and invokes the readFrom method in that thread. + The thread will be join()ed upon normal exit from the context manager, i.e. the body of the + `with` statement. If an exception occurs, the thread will not be joined but a well-behaved + :meth:`.readFrom` implementation will terminate shortly thereafter due to the pipe having + been closed. + + Now, exceptions in the reader thread will be reraised in the main thread: + + >>> class MyPipe(WritablePipe): + ... def readFrom(self, readable): + ... raise RuntimeError('Hello, world!') + >>> with MyPipe() as writable: + ... pass + Traceback (most recent call last): + ... + RuntimeError: Hello, world! + + More complicated, less illustrative tests: + + Same as above, but proving that handles are closed: + + >>> x = os.dup(0); os.close(x) + >>> class MyPipe(WritablePipe): + ... def readFrom(self, readable): + ... raise RuntimeError('Hello, world!') + >>> with MyPipe() as writable: + ... pass + Traceback (most recent call last): + ... + RuntimeError: Hello, world! + >>> y = os.dup(0); os.close(y); x == y + True + + Exceptions in the body of the with statement aren't masked, and handles are closed: + + >>> x = os.dup(0); os.close(x) + >>> class MyPipe(WritablePipe): + ... def readFrom(self, readable): + ... pass + >>> with MyPipe() as writable: + ... raise RuntimeError('Hello, world!') + Traceback (most recent call last): + ... + RuntimeError: Hello, world! + >>> y = os.dup(0); os.close(y); x == y + True + """ + + @abstractmethod + def readFrom(self, readable): + """ + Implement this method to read data from the pipe. This method should support both + binary and text mode output. + + :param file readable: the file object representing the readable end of the pipe. Do not + explicitly invoke the close() method of the object, that will be done automatically. + """ + raise NotImplementedError() + + def _reader(self): + with os.fdopen(self.readable_fh, 'rb') as readable: + # TODO: If the reader somehow crashes here, both threads might try + # to close readable_fh. Fortunately we don't do anything that + # should be able to fail here. + self.readable_fh = None # signal to parent thread that we've taken over + self.readFrom(readable) + self.reader_done = True + + def __init__(self, encoding=None, errors=None): + """ + The specified encoding and errors apply to the writable end of the pipe. + + :param str encoding: the name of the encoding used to encode the file. Encodings are the same + as for encode(). Defaults to None which represents binary mode. + + :param str errors: an optional string that specifies how encoding errors are to be handled. Errors + are the same as for open(). Defaults to 'strict' when an encoding is specified. + """ + super(WritablePipe, self).__init__() + self.encoding = encoding + self.errors = errors + self.readable_fh = None + self.writable = None + self.thread = None + self.reader_done = False + + def __enter__(self): + self.readable_fh, writable_fh = os.pipe() + self.writable = os.fdopen(writable_fh, 'wb' if self.encoding == None else 'wt', encoding=self.encoding, errors=self.errors) + self.thread = ExceptionalThread(target=self._reader) + self.thread.start() + return self.writable + + def __exit__(self, exc_type, exc_val, exc_tb): + # Closeing the writable end will send EOF to the readable and cause the reader thread + # to finish. + # TODO: Can close() fail? If so, would we try and clean up after the reader? + self.writable.close() + try: + if self.thread is not None: + # reraises any exception that was raised in the thread + self.thread.join() + except Exception as e: + if exc_type is None: + # Only raise the child exception if there wasn't + # already an exception in the main thread + raise + else: + log.error('Swallowing additional exception in reader thread: %s', str(e)) + finally: + # The responsibility for closing the readable end is generally that of the reader + # thread. To cover the small window before the reader takes over we also close it here. + readable_fh = self.readable_fh + if readable_fh is not None: + # Close the file handle. The reader thread must be dead now. + os.close(readable_fh) + + +class ReadablePipe(ABC): + """ + An object-oriented wrapper for os.pipe. Clients should subclass it, implement + :meth:`.writeTo` to place data into the writable end of the pipe, then instantiate the class + as a context manager to get the writable end. See the example below. + + >>> import sys, shutil + >>> class MyPipe(ReadablePipe): + ... def writeTo(self, writable): + ... writable.write('Hello, world!\\n'.encode('utf-8')) + >>> with MyPipe() as readable: + ... shutil.copyfileobj(codecs.getreader('utf-8')(readable), sys.stdout) + Hello, world! + + Each instance of this class creates a thread and invokes the :meth:`.writeTo` method in that + thread. The thread will be join()ed upon normal exit from the context manager, i.e. the body + of the `with` statement. If an exception occurs, the thread will not be joined but a + well-behaved :meth:`.writeTo` implementation will terminate shortly thereafter due to the + pipe having been closed. + + Now, exceptions in the reader thread will be reraised in the main thread: + + >>> class MyPipe(ReadablePipe): + ... def writeTo(self, writable): + ... raise RuntimeError('Hello, world!') + >>> with MyPipe() as readable: + ... pass + Traceback (most recent call last): + ... + RuntimeError: Hello, world! + + More complicated, less illustrative tests: + + Same as above, but proving that handles are closed: + + >>> x = os.dup(0); os.close(x) + >>> class MyPipe(ReadablePipe): + ... def writeTo(self, writable): + ... raise RuntimeError('Hello, world!') + >>> with MyPipe() as readable: + ... pass + Traceback (most recent call last): + ... + RuntimeError: Hello, world! + >>> y = os.dup(0); os.close(y); x == y + True + + Exceptions in the body of the with statement aren't masked, and handles are closed: + + >>> x = os.dup(0); os.close(x) + >>> class MyPipe(ReadablePipe): + ... def writeTo(self, writable): + ... pass + >>> with MyPipe() as readable: + ... raise RuntimeError('Hello, world!') + Traceback (most recent call last): + ... + RuntimeError: Hello, world! + >>> y = os.dup(0); os.close(y); x == y + True + """ + + @abstractmethod + def writeTo(self, writable): + """ + Implement this method to write data from the pipe. This method should support both + binary and text mode input. + + :param file writable: the file object representing the writable end of the pipe. Do not + explicitly invoke the close() method of the object, that will be done automatically. + """ + raise NotImplementedError() + + def _writer(self): + try: + with os.fdopen(self.writable_fh, 'wb') as writable: + self.writeTo(writable) + except IOError as e: + # The other side of the pipe may have been closed by the + # reading thread, which is OK. + if e.errno != errno.EPIPE: + raise + + def __init__(self, encoding=None, errors=None): + """ + The specified encoding and errors apply to the readable end of the pipe. + + :param str encoding: the name of the encoding used to encode the file. Encodings are the same + as for encode(). Defaults to None which represents binary mode. + + :param str errors: an optional string that specifies how encoding errors are to be handled. Errors + are the same as for open(). Defaults to 'strict' when an encoding is specified. + """ + super(ReadablePipe, self).__init__() + self.encoding = encoding + self.errors = errors + self.writable_fh = None + self.readable = None + self.thread = None + + def __enter__(self): + readable_fh, self.writable_fh = os.pipe() + self.readable = os.fdopen(readable_fh, 'rb' if self.encoding == None else 'rt', encoding=self.encoding, errors=self.errors) + self.thread = ExceptionalThread(target=self._writer) + self.thread.start() + return self.readable + + def __exit__(self, exc_type, exc_val, exc_tb): + # Close the read end of the pipe. The writing thread may + # still be writing to the other end, but this will wake it up + # if that's the case. + self.readable.close() + try: + if self.thread is not None: + # reraises any exception that was raised in the thread + self.thread.join() + except: + if exc_type is None: + # Only raise the child exception if there wasn't + # already an exception in the main thread + raise + + +class ReadableTransformingPipe(ReadablePipe): + """ + A pipe which is constructed around a readable stream, and which provides a + context manager that gives a readable stream. + + Useful as a base class for pipes which have to transform or otherwise visit + bytes that flow through them, instead of just consuming or producing data. + + Clients should subclass it and implement :meth:`.transform`, like so: + + >>> import sys, shutil + >>> class MyPipe(ReadableTransformingPipe): + ... def transform(self, readable, writable): + ... writable.write(readable.read().decode('utf-8').upper().encode('utf-8')) + >>> class SourcePipe(ReadablePipe): + ... def writeTo(self, writable): + ... writable.write('Hello, world!\\n'.encode('utf-8')) + >>> with SourcePipe() as source: + ... with MyPipe(source) as transformed: + ... shutil.copyfileobj(codecs.getreader('utf-8')(transformed), sys.stdout) + HELLO, WORLD! + + The :meth:`.transform` method runs in its own thread, and should move data + chunk by chunk instead of all at once. It should finish normally if it + encounters either an EOF on the readable, or a :class:`BrokenPipeError` on + the writable. This means tat it should make sure to actually catch a + :class:`BrokenPipeError` when writing. + + See also: :class:`toil.lib.misc.WriteWatchingStream`. + + """ + def __init__(self, source, encoding=None, errors=None): + """ + :param str encoding: the name of the encoding used to encode the file. Encodings are the same + as for encode(). Defaults to None which represents binary mode. + + :param str errors: an optional string that specifies how encoding errors are to be handled. Errors + are the same as for open(). Defaults to 'strict' when an encoding is specified. + """ + super(ReadableTransformingPipe, self).__init__(encoding=encoding, errors=errors) + self.source = source + + @abstractmethod + def transform(self, readable, writable): + """ + Implement this method to ship data through the pipe. + + :param file readable: the input stream file object to transform. + + :param file writable: the file object representing the writable end of the pipe. Do not + explicitly invoke the close() method of the object, that will be done automatically. + """ + raise NotImplementedError() + + def writeTo(self, writable): + self.transform(self.source, writable) + + +class HashingPipe(ReadableTransformingPipe): + """ + Class which checksums all the data read through it. If it + reaches EOF and the checksum isn't correct, raises ChecksumError. + + Assumes info actually has a checksum. + """ + def __init__(self, source, encoding=None, errors=None, checksum_to_verify=None): + """ + :param str encoding: the name of the encoding used to encode the file. Encodings are the same + as for encode(). Defaults to None which represents binary mode. + + :param str errors: an optional string that specifies how encoding errors are to be handled. Errors + are the same as for open(). Defaults to 'strict' when an encoding is specified. + """ + super(HashingPipe, self).__init__(source=source, encoding=encoding, errors=errors) + self.checksum_to_verify = checksum_to_verify + + def transform(self, readable, writable): + hash_object = hashlib.sha1() + contents = readable.read(1024 * 1024) + while contents != b'': + hash_object.update(contents) + try: + writable.write(contents) + except BrokenPipeError: + # Read was stopped early by user code. + # Can't check the checksum. + return + contents = readable.read(1024 * 1024) + final_computed_checksum = f'sha1${hash_object.hexdigest()}' + if not self.checksum_to_verify == final_computed_checksum: + raise ChecksumError(f'Checksum mismatch. Expected: {self.checksum_to_verify} Actual: {final_computed_checksum}') diff --git a/src/toil/test/jobStores/jobStoreTest.py b/src/toil/test/jobStores/jobStoreTest.py index 1c88c071ab..dca723b4bf 100644 --- a/src/toil/test/jobStores/jobStoreTest.py +++ b/src/toil/test/jobStores/jobStoreTest.py @@ -1435,98 +1435,6 @@ def _corruptJobStore(self): assert isinstance(self.jobstore_initialized, AWSJobStore) # type hinting self.jobstore_initialized.destroy() - def testSDBDomainsDeletedOnFailedJobstoreBucketCreation(self): - """ - This test ensures that SDB domains bound to a jobstore are deleted if the jobstore bucket - failed to be created. We simulate a failed jobstore bucket creation by using a bucket in a - different region with the same name. - """ - from botocore.exceptions import ClientError - - from toil.jobStores.aws.jobStore import BucketLocationConflictException - from toil.lib.aws.session import establish_boto3_session - from toil.lib.aws.utils import retry_s3 - - externalAWSLocation = "us-west-1" - for testRegion in "us-east-1", "us-west-2": - # We run this test twice, once with the default s3 server us-east-1 as the test region - # and once with another server (us-west-2). The external server is always us-west-1. - # This incidentally tests that the BucketLocationConflictException is thrown when using - # both the default, and a non-default server. - testJobStoreUUID = str(uuid.uuid4()) - # Create the bucket at the external region - bucketName = "domain-test-" + testJobStoreUUID + "--files" - client = establish_boto3_session().client( - "s3", region_name=externalAWSLocation - ) - resource = establish_boto3_session().resource( - "s3", region_name=externalAWSLocation - ) - - for attempt in retry_s3(delays=(2, 5, 10, 30, 60), timeout=600): - with attempt: - # Create the bucket at the home region - client.create_bucket( - Bucket=bucketName, - CreateBucketConfiguration={ - "LocationConstraint": externalAWSLocation - }, - ) - - owner_tag = os.environ.get("TOIL_OWNER_TAG") - if owner_tag: - for attempt in retry_s3(delays=(1, 1, 2, 4, 8, 16), timeout=33): - with attempt: - bucket_tagging = resource.BucketTagging(bucketName) - bucket_tagging.put( - Tagging={"TagSet": [{"Key": "Owner", "Value": owner_tag}]} - ) - - options = Job.Runner.getDefaultOptions( - "aws:" + testRegion + ":domain-test-" + testJobStoreUUID - ) - options.logLevel = "DEBUG" - try: - with Toil(options) as toil: - pass - except BucketLocationConflictException: - # Catch the expected BucketLocationConflictException and ensure that the bound - # domains don't exist in SDB. - sdb = establish_boto3_session().client( - region_name=self.awsRegion(), service_name="sdb" - ) - next_token = None - allDomainNames = [] - while True: - if next_token is None: - domains = sdb.list_domains(MaxNumberOfDomains=100) - else: - domains = sdb.list_domains( - MaxNumberOfDomains=100, NextToken=next_token - ) - allDomainNames.extend(domains["DomainNames"]) - next_token = domains.get("NextToken") - if next_token is None: - break - self.assertFalse([d for d in allDomainNames if testJobStoreUUID in d]) - else: - self.fail() - finally: - try: - for attempt in retry_s3(): - with attempt: - client.delete_bucket(Bucket=bucketName) - except ClientError as e: - # The actual HTTP code of the error is in status. - if ( - e.response.get("ResponseMetadata", {}).get("HTTPStatusCode") - == 404 - ): - # The bucket doesn't exist; maybe a failed delete actually succeeded. - pass - else: - raise - @slow def testInlinedFiles(self): from toil.jobStores.aws.jobStore import AWSJobStore @@ -1667,13 +1575,8 @@ def _largeLogEntrySize(self): # So we get into the else branch of reader() in uploadStream(multiPart=False): return AWSJobStore.FileInfo.maxBinarySize() * 2 - def _batchDeletionSize(self): - from toil.jobStores.aws.jobStore import AWSJobStore - return AWSJobStore.itemsPerBatchDelete - - -@needs_aws_s3 +# @needs_aws_s3 class InvalidAWSJobStoreTest(ToilTest): def testInvalidJobStoreName(self): from toil.jobStores.aws.jobStore import AWSJobStore