From 449aed874037c24c54006d426ebb6b2235803f2c Mon Sep 17 00:00:00 2001 From: Daniel Hollas Date: Tue, 14 Jan 2025 14:50:30 +0000 Subject: [PATCH 1/5] Run mypy on orm/{comments,computer}.py --- .pre-commit-config.yaml | 3 --- src/aiida/orm/comments.py | 7 ++++--- src/aiida/orm/computers.py | 12 ++++++------ src/aiida/transports/transport.py | 2 +- 4 files changed, 11 insertions(+), 13 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index fb2019ac26..77a7478cc4 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -122,10 +122,7 @@ repos: src/aiida/manage/external/rmq/launcher.py| src/aiida/manage/tests/main.py| src/aiida/manage/tests/pytest_fixtures.py| - src/aiida/orm/comments.py| - src/aiida/orm/computers.py| src/aiida/orm/implementation/storage_backend.py| - src/aiida/orm/nodes/comments.py| src/aiida/orm/nodes/data/array/bands.py| src/aiida/orm/nodes/data/array/trajectory.py| src/aiida/orm/nodes/data/cif.py| diff --git a/src/aiida/orm/comments.py b/src/aiida/orm/comments.py index bc92351c54..b8594da7c5 100644 --- a/src/aiida/orm/comments.py +++ b/src/aiida/orm/comments.py @@ -18,7 +18,7 @@ if TYPE_CHECKING: from aiida.orm import Node, User - from aiida.orm.implementation import StorageBackend + from aiida.orm.implementation import BackendComment, BackendNode, StorageBackend # noqa: F401 __all__ = ('Comment',) @@ -146,7 +146,7 @@ def set_mtime(self, value: datetime) -> None: return self._backend_entity.set_mtime(value) @property - def node(self) -> 'Node': + def node(self) -> 'BackendNode': return self._backend_entity.node @property @@ -154,7 +154,8 @@ def user(self) -> 'User': return entities.from_backend_entity(users.User, self._backend_entity.user) def set_user(self, value: 'User') -> None: - self._backend_entity.user = value.backend_entity + # ignoring mypy error: Property "user" defined in "BackendComment" is read-only + self._backend_entity.user = value.backend_entity # type: ignore[misc] @property def content(self) -> str: diff --git a/src/aiida/orm/computers.py b/src/aiida/orm/computers.py index bae925b25c..51fd491f99 100644 --- a/src/aiida/orm/computers.py +++ b/src/aiida/orm/computers.py @@ -21,7 +21,7 @@ if TYPE_CHECKING: from aiida.orm import AuthInfo, User - from aiida.orm.implementation import StorageBackend + from aiida.orm.implementation import BackendComputer, StorageBackend # noqa: F401 from aiida.schedulers import Scheduler from aiida.transports import Transport @@ -54,7 +54,7 @@ def get_or_create(self, label: Optional[str] = None, **kwargs) -> Tuple[bool, 'C def list_labels(self) -> List[str]: """Return a list with all the labels of the computers in the DB.""" - return self._backend.computers.list_names() + return self._backend.computers.list_names() # type: ignore[attr-defined] def delete(self, pk: int) -> None: """Delete the computer with the given id""" @@ -224,7 +224,7 @@ def _mpirun_command_validator(self, mpirun_cmd: Union[List[str], Tuple[str, ...] """Validates the mpirun_command variable. MUST be called after properly checking for a valid scheduler. """ - if not isinstance(mpirun_cmd, (tuple, list)) or not all(isinstance(i, str) for i in mpirun_cmd): + if not isinstance(mpirun_cmd, (tuple, list)) or not all(isinstance(i, str) for i in mpirun_cmd): # type: ignore[redundant-expr] raise exceptions.ValidationError('the mpirun_command must be a list of strings') try: @@ -278,7 +278,7 @@ def _default_mpiprocs_per_machine_validator(cls, def_cpus_per_machine: Optional[ if def_cpus_per_machine is None: return - if not isinstance(def_cpus_per_machine, int) or def_cpus_per_machine <= 0: + if not isinstance(def_cpus_per_machine, int) or def_cpus_per_machine <= 0: # type: ignore[redundant-expr] raise exceptions.ValidationError( 'Invalid value for default_mpiprocs_per_machine, must be a positive integer, or an empty string if you ' 'do not want to provide a default value.' @@ -290,7 +290,7 @@ def default_memory_per_machine_validator(cls, def_memory_per_machine: Optional[i if def_memory_per_machine is None: return - if not isinstance(def_memory_per_machine, int) or def_memory_per_machine <= 0: + if not isinstance(def_memory_per_machine, int) or def_memory_per_machine <= 0: # type: ignore[redundant-expr] raise exceptions.ValidationError( f'Invalid value for def_memory_per_machine, must be a positive int, got: {def_memory_per_machine}' ) @@ -487,7 +487,7 @@ def set_mpirun_command(self, val: Union[List[str], Tuple[str, ...]]) -> None: """Set the mpirun command. It must be a list of strings (you can use string.split() if you have a single, space-separated string). """ - if not isinstance(val, (tuple, list)) or not all(isinstance(i, str) for i in val): + if not isinstance(val, (tuple, list)) or not all(isinstance(i, str) for i in val): # type: ignore[redundant-expr] raise TypeError('the mpirun_command must be a list of strings') self.set_property('mpirun_command', val) diff --git a/src/aiida/transports/transport.py b/src/aiida/transports/transport.py index a6d755973e..06d2a4e573 100644 --- a/src/aiida/transports/transport.py +++ b/src/aiida/transports/transport.py @@ -43,7 +43,7 @@ class Transport(abc.ABC): """Abstract class for a generic transport (ssh, local, ...) contains the set of minimal methods.""" # This will be used for ``Computer.get_minimum_job_poll_interval`` - DEFAULT_MINIMUM_JOB_POLL_INTERVAL = 10 + DEFAULT_MINIMUM_JOB_POLL_INTERVAL = 10.0 # This is used as a global default in case subclasses don't redefine this, # but this should be redefined in plugins where appropriate From 8d0fa457dc2989fadc7a34824855271fd3599ebb Mon Sep 17 00:00:00 2001 From: Daniel Hollas Date: Tue, 14 Jan 2025 15:00:25 +0000 Subject: [PATCH 2/5] Run typing on orm/implementation/storage_backend.py --- .pre-commit-config.yaml | 1 - src/aiida/orm/implementation/storage_backend.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 77a7478cc4..c0211570f0 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -122,7 +122,6 @@ repos: src/aiida/manage/external/rmq/launcher.py| src/aiida/manage/tests/main.py| src/aiida/manage/tests/pytest_fixtures.py| - src/aiida/orm/implementation/storage_backend.py| src/aiida/orm/nodes/data/array/bands.py| src/aiida/orm/nodes/data/array/trajectory.py| src/aiida/orm/nodes/data/cif.py| diff --git a/src/aiida/orm/implementation/storage_backend.py b/src/aiida/orm/implementation/storage_backend.py index 6137508f51..7e19a25ca3 100644 --- a/src/aiida/orm/implementation/storage_backend.py +++ b/src/aiida/orm/implementation/storage_backend.py @@ -456,7 +456,7 @@ def get_orm_entities(self, detailed: bool = False) -> dict: """ from aiida.orm import Comment, Computer, Group, Log, Node, QueryBuilder, User - data = {} + data: dict[str, Any] = {} query_user = QueryBuilder(self).append(User, project=['email']) data['Users'] = {'count': query_user.count()} From b378e1606a93002b3a47e956fefa7fee24a0c737 Mon Sep 17 00:00:00 2001 From: Daniel Hollas Date: Tue, 14 Jan 2025 18:38:37 +0000 Subject: [PATCH 3/5] Add typing to common/utils.py --- .pre-commit-config.yaml | 1 - src/aiida/common/utils.py | 59 ++++++++++++++++++++------------------- 2 files changed, 31 insertions(+), 29 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c0211570f0..ccd36ea7f5 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -110,7 +110,6 @@ repos: src/aiida/cmdline/utils/common.py| src/aiida/cmdline/utils/echo.py| src/aiida/common/extendeddicts.py| - src/aiida/common/utils.py| src/aiida/engine/daemon/execmanager.py| src/aiida/engine/processes/calcjobs/manager.py| src/aiida/engine/processes/calcjobs/monitors.py| diff --git a/src/aiida/common/utils.py b/src/aiida/common/utils.py index 1b2f2b14ce..6020983378 100644 --- a/src/aiida/common/utils.py +++ b/src/aiida/common/utils.py @@ -14,8 +14,9 @@ import os import re import sys -from datetime import datetime -from typing import Any, Dict +from collections.abc import Iterable +from datetime import datetime, timedelta +from typing import Any, Dict, Optional from uuid import UUID from .lang import classproperty @@ -41,7 +42,7 @@ def validate_uuid(given_uuid: str) -> bool: return str(parsed_uuid) == given_uuid -def validate_list_of_string_tuples(val, tuple_length): +def validate_list_of_string_tuples(val: Any, tuple_length: int) -> bool: """Check that: 1. ``val`` is a list or tuple @@ -75,7 +76,7 @@ def validate_list_of_string_tuples(val, tuple_length): return True -def get_unique_filename(filename, list_of_filenames): +def get_unique_filename(filename: str, list_of_filenames: list | tuple) -> str: """Return a unique filename that can be added to the list_of_filenames. If filename is not in list_of_filenames, it simply returns the filename @@ -104,7 +105,7 @@ def get_unique_filename(filename, list_of_filenames): return new_filename -def str_timedelta(dt, max_num_fields=3, short=False, negative_to_zero=False): +def str_timedelta(dt: timedelta, max_num_fields: int = 3, short: bool = False, negative_to_zero: bool = False) -> str: """Given a dt in seconds, return it in a HH:MM:SS format. :param dt: a TimeDelta object @@ -170,7 +171,7 @@ def str_timedelta(dt, max_num_fields=3, short=False, negative_to_zero=False): return f'{raw_string}{negative_string}' -def get_class_string(obj): +def get_class_string(obj: Any) -> str: """Return the string identifying the class of the object (module + object name, joined by dots). @@ -182,7 +183,7 @@ def get_class_string(obj): return f'{obj.__module__}.{obj.__class__.__name__}' -def get_object_from_string(class_string): +def get_object_from_string(class_string: str) -> Any: """Given a string identifying an object (as returned by the get_class_string method) load and return the actual object. """ @@ -193,7 +194,7 @@ def get_object_from_string(class_string): return getattr(importlib.import_module(the_module), the_name) -def grouper(n, iterable): +def grouper(n: int, iterable: Iterable) -> Iterable: """Given an iterable, returns an iterable that returns tuples of groups of elements from iterable of length n, except the last one that has the required length to exaust iterable (i.e., there is no filling applied). @@ -223,11 +224,11 @@ def __init__(self): self.seq = -1 def array_counter(self): - self.seq += 1 + self.seq += 1 # type: ignore[operator] return self.seq -def are_dir_trees_equal(dir1, dir2): +def are_dir_trees_equal(dir1: str, dir2: str) -> tuple[bool, str]: """Compare two directories recursively. Files in each directory are assumed to be equal if their names and contents are equal. @@ -376,14 +377,14 @@ def prettifiers(cls) -> Dict[str, Any]: # noqa: N805 } @classmethod - def get_prettifiers(cls): + def get_prettifiers(cls) -> list[str]: """Return a list of valid prettifier strings :return: a list of strings """ return sorted(cls.prettifiers.keys()) - def __init__(self, format): + def __init__(self, format: Optional[str]): """Create a class to pretttify strings of a given format :param format: a string with the format to use to prettify. @@ -397,7 +398,7 @@ def __init__(self, format): except KeyError: raise ValueError(f"Unknown prettifier format {format}; valid formats: {', '.join(self.get_prettifiers())}") - def prettify(self, label): + def prettify(self, label: str) -> str: """Prettify a label using the format passed in the initializer :param label: the string to prettify @@ -406,7 +407,7 @@ def prettify(self, label): return self._prettifier_f(label) -def prettify_labels(labels, format=None): +def prettify_labels(labels: list, format: Optional[str] = None) -> list: """Prettify label for typesetting in various formats :param labels: a list of length-2 tuples, in the format(position, label) @@ -420,7 +421,7 @@ def prettify_labels(labels, format=None): return [(pos, prettifier.prettify(label)) for pos, label in labels] -def join_labels(labels, join_symbol='|', threshold=1.0e-6): +def join_labels(labels: list, join_symbol: str = '|', threshold: float = 1.0e-6): """Join labels with a joining symbol when they are very close :param labels: a list of length-2 tuples, in the format(position, label) @@ -446,9 +447,9 @@ def join_labels(labels, join_symbol='|', threshold=1.0e-6): return new_labels -def strip_prefix(full_string, prefix): +def strip_prefix(full_string: str, prefix: str) -> str: """Strip the prefix from the given string and return it. If the prefix is not present - the original string will be returned unaltered + the original string will be returned unaltered. :param full_string: the string from which to remove the prefix :param prefix: the prefix to remove @@ -480,14 +481,14 @@ class Capturing: lines, use obj.stderr_lines. If False, obj.stderr_lines is None. """ - def __init__(self, capture_stderr=False): + def __init__(self, capture_stderr: bool = False): """Construct a new instance.""" - self.stdout_lines = [] + self.stdout_lines: list[str] = [] super().__init__() self._capture_stderr = capture_stderr if self._capture_stderr: - self.stderr_lines = [] + self.stderr_lines: Optional[list] = [] else: self.stderr_lines = None @@ -508,7 +509,8 @@ def __exit__(self, *args): sys.stdout = self._stdout del self._stringioout # free up some memory if self._capture_stderr: - self.stderr_lines.extend(self._stringioerr.getvalue().splitlines()) + # NOTE: mypy is not clever enough to now that when we're here, self.stderr_lines is not None + self.stderr_lines.extend(self._stringioerr.getvalue().splitlines()) # type: ignore[union-attr] sys.stderr = self._stderr del self._stringioerr # free up some memory @@ -531,7 +533,7 @@ class ErrorAccumulator: def __init__(self, *error_cls): self.error_cls = error_cls - self.errors = {k: [] for k in self.error_cls} + self.errors: dict[type, list] = {k: [] for k in self.error_cls} def run(self, function, *args, **kwargs): try: @@ -539,7 +541,7 @@ def run(self, function, *args, **kwargs): except self.error_cls as err: self.errors[err.__class__].append(err) - def success(self): + def success(self) -> bool: return bool(not any(self.errors.values())) def result(self, raise_error=Exception): @@ -547,7 +549,7 @@ def result(self, raise_error=Exception): self.raise_errors(raise_error) return self.success(), self.errors - def raise_errors(self, raise_cls): + def raise_errors(self, raise_cls: type[Exception]) -> None: if not self.success(): raise raise_cls(f'The following errors were encountered: {self.errors}') @@ -562,7 +564,7 @@ class DatetimePrecision: 4 (dare + hour + minute +second) """ - def __init__(self, dtobj, precision): + def __init__(self, dtobj: datetime, precision: int): """Constructor to check valid datetime object and precision""" if not isinstance(dtobj, datetime): raise TypeError('dtobj argument has to be a datetime object') @@ -597,9 +599,10 @@ def format_directory_size(size_in_bytes: int) -> str: factor = 1024 # 1 KB = 1024 B index = 0 - while size_in_bytes >= factor and index < len(prefixes) - 1: - size_in_bytes /= factor + converted_size: float = size_in_bytes + while converted_size >= factor and index < len(prefixes) - 1: + converted_size /= factor index += 1 # Format the size to two decimal places - return f'{size_in_bytes:.2f} {prefixes[index]}' + return f'{converted_size:.2f} {prefixes[index]}' From d4289e77875d311240578710f1b51ca1c59df758 Mon Sep 17 00:00:00 2001 From: Daniel Hollas Date: Tue, 14 Jan 2025 18:51:11 +0000 Subject: [PATCH 4/5] Add typing to common/extendeddicts.py --- .pre-commit-config.yaml | 1 - src/aiida/common/extendeddicts.py | 19 ++++++++++--------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ccd36ea7f5..447a454491 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -109,7 +109,6 @@ repos: src/aiida/cmdline/utils/ascii_vis.py| src/aiida/cmdline/utils/common.py| src/aiida/cmdline/utils/echo.py| - src/aiida/common/extendeddicts.py| src/aiida/engine/daemon/execmanager.py| src/aiida/engine/processes/calcjobs/manager.py| src/aiida/engine/processes/calcjobs/monitors.py| diff --git a/src/aiida/common/extendeddicts.py b/src/aiida/common/extendeddicts.py index b45c191c31..59665c1811 100644 --- a/src/aiida/common/extendeddicts.py +++ b/src/aiida/common/extendeddicts.py @@ -9,6 +9,7 @@ """Various dictionary types with extended functionality.""" from collections.abc import Mapping +from typing import Optional from . import exceptions @@ -25,7 +26,7 @@ class AttributeDict(dict): used. """ - def __init__(self, dictionary=None): + def __init__(self, dictionary: Optional[Mapping] = None): """Recursively turn the `dict` and all its nested dictionaries into `AttributeDict` instance.""" super().__init__() if dictionary is None: @@ -37,7 +38,7 @@ def __init__(self, dictionary=None): else: self[key] = value - def __repr__(self): + def __repr__(self) -> str: """Representation of the object.""" return f'{self.__class__.__name__}({dict.__repr__(self)})' @@ -104,7 +105,7 @@ class TestExample(FixedFieldsAttributeDict): _valid_fields = ('a','b','c') """ - _valid_fields = tuple() + _valid_fields: tuple = tuple() def __init__(self, init=None): if init is None: @@ -131,11 +132,11 @@ def __setattr__(self, attr, value): super().__setattr__(attr, value) @classmethod - def get_valid_fields(cls): + def get_valid_fields(cls) -> tuple: """Return the list of valid fields.""" return cls._valid_fields - def __dir__(self): + def __dir__(self) -> list: return list(self._valid_fields) @@ -192,7 +193,7 @@ class TestExample(DefaultFieldsAttributeDict): See if we want that setting a default field to None means deleting it. """ - _default_fields = tuple() + _default_fields: tuple = tuple() def validate(self): """Validate the keys, if any ``validate_*`` method is available.""" @@ -225,14 +226,14 @@ def __getitem__(self, key): raise @classmethod - def get_default_fields(cls): + def get_default_fields(cls) -> list: """Return the list of default fields, either defined in the instance or not.""" return list(cls._default_fields) - def defaultkeys(self): + def defaultkeys(self) -> list: """Return the default keys defined in the instance.""" return [_ for _ in self.keys() if _ in self._default_fields] - def extrakeys(self): + def extrakeys(self) -> list: """Return the extra keys defined in the instance.""" return [_ for _ in self.keys() if _ not in self._default_fields] From 55d03cdf7a0a250cb9f34c09990b6825382fe342 Mon Sep 17 00:00:00 2001 From: Daniel Hollas Date: Tue, 14 Jan 2025 18:56:55 +0000 Subject: [PATCH 5/5] Fix 3.9 --- src/aiida/common/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/aiida/common/utils.py b/src/aiida/common/utils.py index 6020983378..8117001bf6 100644 --- a/src/aiida/common/utils.py +++ b/src/aiida/common/utils.py @@ -76,7 +76,7 @@ def validate_list_of_string_tuples(val: Any, tuple_length: int) -> bool: return True -def get_unique_filename(filename: str, list_of_filenames: list | tuple) -> str: +def get_unique_filename(filename: str, list_of_filenames: 'list | tuple') -> str: """Return a unique filename that can be added to the list_of_filenames. If filename is not in list_of_filenames, it simply returns the filename