Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add typing to aiida/common/{utils,extendeddicts}.py #6706

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 0 additions & 6 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,6 @@ repos:
src/aiida/cmdline/utils/ascii_vis.py|
src/aiida/cmdline/utils/common.py|
src/aiida/cmdline/utils/echo.py|
src/aiida/common/extendeddicts.py|
src/aiida/common/utils.py|
src/aiida/engine/daemon/execmanager.py|
src/aiida/engine/processes/calcjobs/manager.py|
src/aiida/engine/processes/calcjobs/monitors.py|
Expand All @@ -122,10 +120,6 @@ repos:
src/aiida/manage/external/rmq/launcher.py|
src/aiida/manage/tests/main.py|
src/aiida/manage/tests/pytest_fixtures.py|
src/aiida/orm/comments.py|
src/aiida/orm/computers.py|
src/aiida/orm/implementation/storage_backend.py|
src/aiida/orm/nodes/comments.py|
src/aiida/orm/nodes/data/array/bands.py|
src/aiida/orm/nodes/data/array/trajectory.py|
src/aiida/orm/nodes/data/cif.py|
Expand Down
19 changes: 10 additions & 9 deletions src/aiida/common/extendeddicts.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
"""Various dictionary types with extended functionality."""

from collections.abc import Mapping
from typing import Optional

from . import exceptions

Expand All @@ -25,7 +26,7 @@ class AttributeDict(dict):
used.
"""

def __init__(self, dictionary=None):
def __init__(self, dictionary: Optional[Mapping] = None):
"""Recursively turn the `dict` and all its nested dictionaries into `AttributeDict` instance."""
super().__init__()
if dictionary is None:
Expand All @@ -37,7 +38,7 @@ def __init__(self, dictionary=None):
else:
self[key] = value

def __repr__(self):
def __repr__(self) -> str:
"""Representation of the object."""
return f'{self.__class__.__name__}({dict.__repr__(self)})'

Expand Down Expand Up @@ -104,7 +105,7 @@ class TestExample(FixedFieldsAttributeDict):
_valid_fields = ('a','b','c')
"""

_valid_fields = tuple()
_valid_fields: tuple = tuple()

def __init__(self, init=None):
if init is None:
Expand All @@ -131,11 +132,11 @@ def __setattr__(self, attr, value):
super().__setattr__(attr, value)

@classmethod
def get_valid_fields(cls):
def get_valid_fields(cls) -> tuple:
"""Return the list of valid fields."""
return cls._valid_fields

def __dir__(self):
def __dir__(self) -> list:
return list(self._valid_fields)


Expand Down Expand Up @@ -192,7 +193,7 @@ class TestExample(DefaultFieldsAttributeDict):
See if we want that setting a default field to None means deleting it.
"""

_default_fields = tuple()
_default_fields: tuple = tuple()

def validate(self):
"""Validate the keys, if any ``validate_*`` method is available."""
Expand Down Expand Up @@ -225,14 +226,14 @@ def __getitem__(self, key):
raise

@classmethod
def get_default_fields(cls):
def get_default_fields(cls) -> list:
"""Return the list of default fields, either defined in the instance or not."""
return list(cls._default_fields)

def defaultkeys(self):
def defaultkeys(self) -> list:
"""Return the default keys defined in the instance."""
return [_ for _ in self.keys() if _ in self._default_fields]

def extrakeys(self):
def extrakeys(self) -> list:
"""Return the extra keys defined in the instance."""
return [_ for _ in self.keys() if _ not in self._default_fields]
59 changes: 31 additions & 28 deletions src/aiida/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@
import os
import re
import sys
from datetime import datetime
from typing import Any, Dict
from collections.abc import Iterable
from datetime import datetime, timedelta
from typing import Any, Dict, Optional
from uuid import UUID

from .lang import classproperty
Expand All @@ -41,7 +42,7 @@
return str(parsed_uuid) == given_uuid


def validate_list_of_string_tuples(val, tuple_length):
def validate_list_of_string_tuples(val: Any, tuple_length: int) -> bool:
"""Check that:

1. ``val`` is a list or tuple
Expand Down Expand Up @@ -75,7 +76,7 @@
return True


def get_unique_filename(filename, list_of_filenames):
def get_unique_filename(filename: str, list_of_filenames: 'list | tuple') -> str:
"""Return a unique filename that can be added to the list_of_filenames.

If filename is not in list_of_filenames, it simply returns the filename
Expand Down Expand Up @@ -104,7 +105,7 @@
return new_filename


def str_timedelta(dt, max_num_fields=3, short=False, negative_to_zero=False):
def str_timedelta(dt: timedelta, max_num_fields: int = 3, short: bool = False, negative_to_zero: bool = False) -> str:
"""Given a dt in seconds, return it in a HH:MM:SS format.

:param dt: a TimeDelta object
Expand Down Expand Up @@ -170,7 +171,7 @@
return f'{raw_string}{negative_string}'


def get_class_string(obj):
def get_class_string(obj: Any) -> str:
"""Return the string identifying the class of the object (module + object name,
joined by dots).

Expand All @@ -182,7 +183,7 @@
return f'{obj.__module__}.{obj.__class__.__name__}'


def get_object_from_string(class_string):
def get_object_from_string(class_string: str) -> Any:
"""Given a string identifying an object (as returned by the get_class_string
method) load and return the actual object.
"""
Expand All @@ -193,7 +194,7 @@
return getattr(importlib.import_module(the_module), the_name)


def grouper(n, iterable):
def grouper(n: int, iterable: Iterable) -> Iterable:
"""Given an iterable, returns an iterable that returns tuples of groups of
elements from iterable of length n, except the last one that has the
required length to exaust iterable (i.e., there is no filling applied).
Expand Down Expand Up @@ -223,11 +224,11 @@
self.seq = -1

def array_counter(self):
self.seq += 1
self.seq += 1 # type: ignore[operator]

Check warning on line 227 in src/aiida/common/utils.py

View check run for this annotation

Codecov / codecov/patch

src/aiida/common/utils.py#L227

Added line #L227 was not covered by tests
return self.seq


def are_dir_trees_equal(dir1, dir2):
def are_dir_trees_equal(dir1: str, dir2: str) -> tuple[bool, str]:
"""Compare two directories recursively. Files in each directory are
assumed to be equal if their names and contents are equal.

Expand Down Expand Up @@ -376,14 +377,14 @@
}

@classmethod
def get_prettifiers(cls):
def get_prettifiers(cls) -> list[str]:
"""Return a list of valid prettifier strings

:return: a list of strings
"""
return sorted(cls.prettifiers.keys())

def __init__(self, format):
def __init__(self, format: Optional[str]):
"""Create a class to pretttify strings of a given format

:param format: a string with the format to use to prettify.
Expand All @@ -397,7 +398,7 @@
except KeyError:
raise ValueError(f"Unknown prettifier format {format}; valid formats: {', '.join(self.get_prettifiers())}")

def prettify(self, label):
def prettify(self, label: str) -> str:
"""Prettify a label using the format passed in the initializer

:param label: the string to prettify
Expand All @@ -406,7 +407,7 @@
return self._prettifier_f(label)


def prettify_labels(labels, format=None):
def prettify_labels(labels: list, format: Optional[str] = None) -> list:
"""Prettify label for typesetting in various formats

:param labels: a list of length-2 tuples, in the format(position, label)
Expand All @@ -420,7 +421,7 @@
return [(pos, prettifier.prettify(label)) for pos, label in labels]


def join_labels(labels, join_symbol='|', threshold=1.0e-6):
def join_labels(labels: list, join_symbol: str = '|', threshold: float = 1.0e-6):
"""Join labels with a joining symbol when they are very close

:param labels: a list of length-2 tuples, in the format(position, label)
Expand All @@ -446,9 +447,9 @@
return new_labels


def strip_prefix(full_string, prefix):
def strip_prefix(full_string: str, prefix: str) -> str:
"""Strip the prefix from the given string and return it. If the prefix is not present
the original string will be returned unaltered
the original string will be returned unaltered.

:param full_string: the string from which to remove the prefix
:param prefix: the prefix to remove
Expand Down Expand Up @@ -480,14 +481,14 @@
lines, use obj.stderr_lines. If False, obj.stderr_lines is None.
"""

def __init__(self, capture_stderr=False):
def __init__(self, capture_stderr: bool = False):
"""Construct a new instance."""
self.stdout_lines = []
self.stdout_lines: list[str] = []
super().__init__()

self._capture_stderr = capture_stderr
if self._capture_stderr:
self.stderr_lines = []
self.stderr_lines: Optional[list] = []

Check warning on line 491 in src/aiida/common/utils.py

View check run for this annotation

Codecov / codecov/patch

src/aiida/common/utils.py#L491

Added line #L491 was not covered by tests
else:
self.stderr_lines = None

Expand All @@ -508,7 +509,8 @@
sys.stdout = self._stdout
del self._stringioout # free up some memory
if self._capture_stderr:
self.stderr_lines.extend(self._stringioerr.getvalue().splitlines())
# NOTE: mypy is not clever enough to now that when we're here, self.stderr_lines is not None
self.stderr_lines.extend(self._stringioerr.getvalue().splitlines()) # type: ignore[union-attr]

Check warning on line 513 in src/aiida/common/utils.py

View check run for this annotation

Codecov / codecov/patch

src/aiida/common/utils.py#L513

Added line #L513 was not covered by tests
sys.stderr = self._stderr
del self._stringioerr # free up some memory

Expand All @@ -531,23 +533,23 @@

def __init__(self, *error_cls):
self.error_cls = error_cls
self.errors = {k: [] for k in self.error_cls}
self.errors: dict[type, list] = {k: [] for k in self.error_cls}

def run(self, function, *args, **kwargs):
try:
function(*args, **kwargs)
except self.error_cls as err:
self.errors[err.__class__].append(err)

def success(self):
def success(self) -> bool:
return bool(not any(self.errors.values()))

def result(self, raise_error=Exception):
if raise_error:
self.raise_errors(raise_error)
return self.success(), self.errors

def raise_errors(self, raise_cls):
def raise_errors(self, raise_cls: type[Exception]) -> None:
if not self.success():
raise raise_cls(f'The following errors were encountered: {self.errors}')

Expand All @@ -562,7 +564,7 @@
4 (dare + hour + minute +second)
"""

def __init__(self, dtobj, precision):
def __init__(self, dtobj: datetime, precision: int):
"""Constructor to check valid datetime object and precision"""
if not isinstance(dtobj, datetime):
raise TypeError('dtobj argument has to be a datetime object')
Expand Down Expand Up @@ -597,9 +599,10 @@
factor = 1024 # 1 KB = 1024 B
index = 0

while size_in_bytes >= factor and index < len(prefixes) - 1:
size_in_bytes /= factor
converted_size: float = size_in_bytes
while converted_size >= factor and index < len(prefixes) - 1:
converted_size /= factor
index += 1

# Format the size to two decimal places
return f'{size_in_bytes:.2f} {prefixes[index]}'
return f'{converted_size:.2f} {prefixes[index]}'
7 changes: 4 additions & 3 deletions src/aiida/orm/comments.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

if TYPE_CHECKING:
from aiida.orm import Node, User
from aiida.orm.implementation import StorageBackend
from aiida.orm.implementation import BackendComment, BackendNode, StorageBackend # noqa: F401

Check warning on line 21 in src/aiida/orm/comments.py

View check run for this annotation

Codecov / codecov/patch

src/aiida/orm/comments.py#L21

Added line #L21 was not covered by tests

__all__ = ('Comment',)

Expand Down Expand Up @@ -146,15 +146,16 @@
return self._backend_entity.set_mtime(value)

@property
def node(self) -> 'Node':
def node(self) -> 'BackendNode':
return self._backend_entity.node

@property
def user(self) -> 'User':
return entities.from_backend_entity(users.User, self._backend_entity.user)

def set_user(self, value: 'User') -> None:
self._backend_entity.user = value.backend_entity
# ignoring mypy error: Property "user" defined in "BackendComment" is read-only
self._backend_entity.user = value.backend_entity # type: ignore[misc]

Check warning on line 158 in src/aiida/orm/comments.py

View check run for this annotation

Codecov / codecov/patch

src/aiida/orm/comments.py#L158

Added line #L158 was not covered by tests

@property
def content(self) -> str:
Expand Down
12 changes: 6 additions & 6 deletions src/aiida/orm/computers.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

if TYPE_CHECKING:
from aiida.orm import AuthInfo, User
from aiida.orm.implementation import StorageBackend
from aiida.orm.implementation import BackendComputer, StorageBackend # noqa: F401

Check warning on line 24 in src/aiida/orm/computers.py

View check run for this annotation

Codecov / codecov/patch

src/aiida/orm/computers.py#L24

Added line #L24 was not covered by tests
from aiida.schedulers import Scheduler
from aiida.transports import Transport

Expand Down Expand Up @@ -54,7 +54,7 @@

def list_labels(self) -> List[str]:
"""Return a list with all the labels of the computers in the DB."""
return self._backend.computers.list_names()
return self._backend.computers.list_names() # type: ignore[attr-defined]

Check warning on line 57 in src/aiida/orm/computers.py

View check run for this annotation

Codecov / codecov/patch

src/aiida/orm/computers.py#L57

Added line #L57 was not covered by tests

def delete(self, pk: int) -> None:
"""Delete the computer with the given id"""
Expand Down Expand Up @@ -224,7 +224,7 @@
"""Validates the mpirun_command variable. MUST be called after properly
checking for a valid scheduler.
"""
if not isinstance(mpirun_cmd, (tuple, list)) or not all(isinstance(i, str) for i in mpirun_cmd):
if not isinstance(mpirun_cmd, (tuple, list)) or not all(isinstance(i, str) for i in mpirun_cmd): # type: ignore[redundant-expr]
raise exceptions.ValidationError('the mpirun_command must be a list of strings')

try:
Expand Down Expand Up @@ -278,7 +278,7 @@
if def_cpus_per_machine is None:
return

if not isinstance(def_cpus_per_machine, int) or def_cpus_per_machine <= 0:
if not isinstance(def_cpus_per_machine, int) or def_cpus_per_machine <= 0: # type: ignore[redundant-expr]

Check warning on line 281 in src/aiida/orm/computers.py

View check run for this annotation

Codecov / codecov/patch

src/aiida/orm/computers.py#L281

Added line #L281 was not covered by tests
raise exceptions.ValidationError(
'Invalid value for default_mpiprocs_per_machine, must be a positive integer, or an empty string if you '
'do not want to provide a default value.'
Expand All @@ -290,7 +290,7 @@
if def_memory_per_machine is None:
return

if not isinstance(def_memory_per_machine, int) or def_memory_per_machine <= 0:
if not isinstance(def_memory_per_machine, int) or def_memory_per_machine <= 0: # type: ignore[redundant-expr]
raise exceptions.ValidationError(
f'Invalid value for def_memory_per_machine, must be a positive int, got: {def_memory_per_machine}'
)
Expand Down Expand Up @@ -487,7 +487,7 @@
"""Set the mpirun command. It must be a list of strings (you can use
string.split() if you have a single, space-separated string).
"""
if not isinstance(val, (tuple, list)) or not all(isinstance(i, str) for i in val):
if not isinstance(val, (tuple, list)) or not all(isinstance(i, str) for i in val): # type: ignore[redundant-expr]
raise TypeError('the mpirun_command must be a list of strings')
self.set_property('mpirun_command', val)

Expand Down
Loading
Loading