Skip to content

Add typing to aiida/common/{utils,extendeddicts}.py #6706

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 7 commits into
base: update-mypy
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 0 additions & 6 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,6 @@ repos:
src/aiida/cmdline/utils/ascii_vis.py|
src/aiida/cmdline/utils/common.py|
src/aiida/cmdline/utils/echo.py|
src/aiida/common/extendeddicts.py|
src/aiida/common/utils.py|
src/aiida/engine/daemon/execmanager.py|
src/aiida/engine/processes/calcjobs/manager.py|
src/aiida/engine/processes/calcjobs/monitors.py|
Expand All @@ -126,10 +124,6 @@ repos:
src/aiida/manage/configuration/__init__.py|
src/aiida/manage/configuration/config.py|
src/aiida/manage/external/rmq/launcher.py|
src/aiida/orm/comments.py|
src/aiida/orm/computers.py|
src/aiida/orm/implementation/storage_backend.py|
src/aiida/orm/nodes/comments.py|
src/aiida/orm/nodes/data/array/bands.py|
src/aiida/orm/nodes/data/array/trajectory.py|
src/aiida/orm/nodes/data/cif.py|
Expand Down
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -449,6 +449,10 @@ select = [
'RUF' # ruff
]

# https://github.com/astral-sh/ruff/issues/9298
[tool.ruff.lint.pyflakes]
extend-generics = ["aiida.orm.entities.Entity", "aiida.orm.entities.Collection"]

[tool.tox]
legacy_tox_ini = """
[tox]
Expand Down
19 changes: 10 additions & 9 deletions src/aiida/common/extendeddicts.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
"""Various dictionary types with extended functionality."""

from collections.abc import Mapping
from typing import Optional

from . import exceptions

Expand All @@ -25,7 +26,7 @@ class AttributeDict(dict):
used.
"""

def __init__(self, dictionary=None):
def __init__(self, dictionary: Optional[Mapping] = None):
"""Recursively turn the `dict` and all its nested dictionaries into `AttributeDict` instance."""
super().__init__()
if dictionary is None:
Expand All @@ -37,7 +38,7 @@ def __init__(self, dictionary=None):
else:
self[key] = value

def __repr__(self):
def __repr__(self) -> str:
"""Representation of the object."""
return f'{self.__class__.__name__}({dict.__repr__(self)})'

Expand Down Expand Up @@ -104,7 +105,7 @@ class TestExample(FixedFieldsAttributeDict):
_valid_fields = ('a','b','c')
"""

_valid_fields = tuple()
_valid_fields: tuple = tuple()

def __init__(self, init=None):
if init is None:
Expand All @@ -131,11 +132,11 @@ def __setattr__(self, attr, value):
super().__setattr__(attr, value)

@classmethod
def get_valid_fields(cls):
def get_valid_fields(cls) -> tuple:
"""Return the list of valid fields."""
return cls._valid_fields

def __dir__(self):
def __dir__(self) -> list:
return list(self._valid_fields)


Expand Down Expand Up @@ -192,7 +193,7 @@ class TestExample(DefaultFieldsAttributeDict):
See if we want that setting a default field to None means deleting it.
"""

_default_fields = tuple()
_default_fields: tuple = tuple()

def validate(self):
"""Validate the keys, if any ``validate_*`` method is available."""
Expand Down Expand Up @@ -225,14 +226,14 @@ def __getitem__(self, key):
raise

@classmethod
def get_default_fields(cls):
def get_default_fields(cls) -> list:
"""Return the list of default fields, either defined in the instance or not."""
return list(cls._default_fields)

def defaultkeys(self):
def defaultkeys(self) -> list:
"""Return the default keys defined in the instance."""
return [_ for _ in self.keys() if _ in self._default_fields]

def extrakeys(self):
def extrakeys(self) -> list:
"""Return the extra keys defined in the instance."""
return [_ for _ in self.keys() if _ not in self._default_fields]
55 changes: 29 additions & 26 deletions src/aiida/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@
import os
import re
import sys
from datetime import datetime
from typing import Any, Dict
from collections.abc import Iterable
from datetime import datetime, timedelta
from typing import Any, Dict, Optional
from uuid import UUID

from .lang import classproperty
Expand All @@ -41,7 +42,7 @@
return str(parsed_uuid) == given_uuid


def validate_list_of_string_tuples(val, tuple_length):
def validate_list_of_string_tuples(val: Any, tuple_length: int) -> bool:
"""Check that:

1. ``val`` is a list or tuple
Expand Down Expand Up @@ -75,7 +76,7 @@
return True


def get_unique_filename(filename, list_of_filenames):
def get_unique_filename(filename: str, list_of_filenames: 'list | tuple') -> str:
"""Return a unique filename that can be added to the list_of_filenames.

If filename is not in list_of_filenames, it simply returns the filename
Expand Down Expand Up @@ -104,7 +105,7 @@
return new_filename


def str_timedelta(dt, max_num_fields=3, short=False, negative_to_zero=False):
def str_timedelta(dt: timedelta, max_num_fields: int = 3, short: bool = False, negative_to_zero: bool = False) -> str:
"""Given a dt in seconds, return it in a HH:MM:SS format.

:param dt: a TimeDelta object
Expand Down Expand Up @@ -170,7 +171,7 @@
return f'{raw_string}{negative_string}'


def get_class_string(obj):
def get_class_string(obj: Any) -> str:
"""Return the string identifying the class of the object (module + object name,
joined by dots).

Expand All @@ -182,7 +183,7 @@
return f'{obj.__module__}.{obj.__class__.__name__}'


def get_object_from_string(class_string):
def get_object_from_string(class_string: str) -> Any:
"""Given a string identifying an object (as returned by the get_class_string
method) load and return the actual object.
"""
Expand All @@ -193,7 +194,7 @@
return getattr(importlib.import_module(the_module), the_name)


def grouper(n, iterable):
def grouper(n: int, iterable: Iterable) -> Iterable:
"""Given an iterable, returns an iterable that returns tuples of groups of
elements from iterable of length n, except the last one that has the
required length to exaust iterable (i.e., there is no filling applied).
Expand Down Expand Up @@ -223,11 +224,11 @@
self.seq = -1

def array_counter(self):
self.seq += 1
self.seq += 1 # type: ignore[operator]

Check warning on line 227 in src/aiida/common/utils.py

View check run for this annotation

Codecov / codecov/patch

src/aiida/common/utils.py#L227

Added line #L227 was not covered by tests
return self.seq


def are_dir_trees_equal(dir1, dir2):
def are_dir_trees_equal(dir1: str, dir2: str) -> tuple[bool, str]:
"""Compare two directories recursively. Files in each directory are
assumed to be equal if their names and contents are equal.

Expand Down Expand Up @@ -376,14 +377,14 @@
}

@classmethod
def get_prettifiers(cls):
def get_prettifiers(cls) -> list[str]:
"""Return a list of valid prettifier strings

:return: a list of strings
"""
return sorted(cls.prettifiers.keys())

def __init__(self, format):
def __init__(self, format: Optional[str]):
"""Create a class to pretttify strings of a given format

:param format: a string with the format to use to prettify.
Expand All @@ -397,7 +398,7 @@
except KeyError:
raise ValueError(f"Unknown prettifier format {format}; valid formats: {', '.join(self.get_prettifiers())}")

def prettify(self, label):
def prettify(self, label: str) -> str:
"""Prettify a label using the format passed in the initializer

:param label: the string to prettify
Expand All @@ -406,7 +407,7 @@
return self._prettifier_f(label)


def prettify_labels(labels, format=None):
def prettify_labels(labels: list, format: Optional[str] = None) -> list:
"""Prettify label for typesetting in various formats

:param labels: a list of length-2 tuples, in the format(position, label)
Expand All @@ -420,7 +421,7 @@
return [(pos, prettifier.prettify(label)) for pos, label in labels]


def join_labels(labels, join_symbol='|', threshold=1.0e-6):
def join_labels(labels: list, join_symbol: str = '|', threshold: float = 1.0e-6):
"""Join labels with a joining symbol when they are very close

:param labels: a list of length-2 tuples, in the format(position, label)
Expand Down Expand Up @@ -466,14 +467,14 @@
lines, use obj.stderr_lines. If False, obj.stderr_lines is None.
"""

def __init__(self, capture_stderr=False):
def __init__(self, capture_stderr: bool = False):
"""Construct a new instance."""
self.stdout_lines = []
self.stdout_lines: list[str] = []
super().__init__()

self._capture_stderr = capture_stderr
if self._capture_stderr:
self.stderr_lines = []
self.stderr_lines: Optional[list] = []

Check warning on line 477 in src/aiida/common/utils.py

View check run for this annotation

Codecov / codecov/patch

src/aiida/common/utils.py#L477

Added line #L477 was not covered by tests
else:
self.stderr_lines = None

Expand All @@ -494,7 +495,8 @@
sys.stdout = self._stdout
del self._stringioout # free up some memory
if self._capture_stderr:
self.stderr_lines.extend(self._stringioerr.getvalue().splitlines())
# NOTE: mypy is not clever enough to now that when we're here, self.stderr_lines is not None
self.stderr_lines.extend(self._stringioerr.getvalue().splitlines()) # type: ignore[union-attr]

Check warning on line 499 in src/aiida/common/utils.py

View check run for this annotation

Codecov / codecov/patch

src/aiida/common/utils.py#L499

Added line #L499 was not covered by tests
sys.stderr = self._stderr
del self._stringioerr # free up some memory

Expand All @@ -517,23 +519,23 @@

def __init__(self, *error_cls):
self.error_cls = error_cls
self.errors = {k: [] for k in self.error_cls}
self.errors: dict[type, list] = {k: [] for k in self.error_cls}

def run(self, function, *args, **kwargs):
try:
function(*args, **kwargs)
except self.error_cls as err:
self.errors[err.__class__].append(err)

def success(self):
def success(self) -> bool:
return bool(not any(self.errors.values()))

def result(self, raise_error=Exception):
if raise_error:
self.raise_errors(raise_error)
return self.success(), self.errors

def raise_errors(self, raise_cls):
def raise_errors(self, raise_cls: type[Exception]) -> None:
if not self.success():
raise raise_cls(f'The following errors were encountered: {self.errors}')

Expand All @@ -548,7 +550,7 @@
4 (dare + hour + minute +second)
"""

def __init__(self, dtobj, precision):
def __init__(self, dtobj: datetime, precision: int):
"""Constructor to check valid datetime object and precision"""
if not isinstance(dtobj, datetime):
raise TypeError('dtobj argument has to be a datetime object')
Expand Down Expand Up @@ -583,9 +585,10 @@
factor = 1024 # 1 KB = 1024 B
index = 0

while size_in_bytes >= factor and index < len(prefixes) - 1:
size_in_bytes /= factor
converted_size: float = size_in_bytes
while converted_size >= factor and index < len(prefixes) - 1:
converted_size /= factor
index += 1

# Format the size to two decimal places
return f'{size_in_bytes:.2f} {prefixes[index]}'
return f'{converted_size:.2f} {prefixes[index]}'
2 changes: 1 addition & 1 deletion src/aiida/orm/authinfos.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

if TYPE_CHECKING:
from aiida.orm.implementation import StorageBackend
from aiida.orm.implementation.authinfos import BackendAuthInfo # noqa: F401
from aiida.orm.implementation.authinfos import BackendAuthInfo

Check warning on line 26 in src/aiida/orm/authinfos.py

View check run for this annotation

Codecov / codecov/patch

src/aiida/orm/authinfos.py#L26

Added line #L26 was not covered by tests
from aiida.transports import Transport

__all__ = ('AuthInfo',)
Expand Down
11 changes: 6 additions & 5 deletions src/aiida/orm/comments.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from . import entities

if TYPE_CHECKING:
from aiida.orm.implementation import StorageBackend
from aiida.orm.implementation import BackendComment, BackendNode, StorageBackend

Check warning on line 20 in src/aiida/orm/comments.py

View check run for this annotation

Codecov / codecov/patch

src/aiida/orm/comments.py#L20

Added line #L20 was not covered by tests

from .nodes.node import Node
from .users import User
Expand Down Expand Up @@ -81,13 +81,13 @@
description='Node PK that the comment is attached to',
is_attribute=False,
orm_class='core.node',
orm_to_model=lambda comment, _: comment.node.pk,
orm_to_model=lambda comment, _: comment.node.pk, # type: ignore[attr-defined]
)
user: int = MetadataField(
description='User PK that created the comment',
is_attribute=False,
orm_class='core.user',
orm_to_model=lambda comment, _: comment.user.pk,
orm_to_model=lambda comment, _: comment.user.pk, # type: ignore[attr-defined]
)
content: str = MetadataField(description='Content of the comment', is_attribute=False)

Expand Down Expand Up @@ -139,7 +139,7 @@
return self._backend_entity.set_mtime(value)

@property
def node(self) -> 'Node':
def node(self) -> 'BackendNode':
return self._backend_entity.node

@property
Expand All @@ -149,7 +149,8 @@
return entities.from_backend_entity(User, self._backend_entity.user)

def set_user(self, value: 'User') -> None:
self._backend_entity.user = value.backend_entity
# ignoring mypy error: Property "user" defined in "BackendComment" is read-only
self._backend_entity.user = value.backend_entity # type: ignore[misc]

Check warning on line 153 in src/aiida/orm/comments.py

View check run for this annotation

Codecov / codecov/patch

src/aiida/orm/comments.py#L153

Added line #L153 was not covered by tests

@property
def content(self) -> str:
Expand Down
Loading
Loading