Commit ae8f96e
chore(pre-commit): [pre-commit.ci] autoupdate (#187)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Xuehai Pan <[email protected]>
pre-commit-ci[bot] and XuehaiPan authored Aug 16, 2023
1 parent e9cde22 commit ae8f96e
Showing 9 changed files with 18 additions and 18 deletions.
3 changes: 2 additions & 1 deletion .pre-commit-config.yaml
@@ -5,6 +5,7 @@ ci:
   autofix_prs: true
   autofix_commit_msg: "fix: [pre-commit.ci] auto fixes [...]"
   autoupdate_commit_msg: "chore(pre-commit): [pre-commit.ci] autoupdate"
+  autoupdate_schedule: monthly
 default_stages: [commit, push, manual]
 repos:
   - repo: https://github.com/pre-commit/pre-commit-hooks
@@ -29,7 +30,7 @@ repos:
     hooks:
       - id: clang-format
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.0.282
+    rev: v0.0.284
     hooks:
       - id: ruff
         args: [--fix, --exit-non-zero-on-fix]
5 changes: 3 additions & 2 deletions torchopt/base.py
@@ -36,6 +36,7 @@
 import itertools
 from abc import abstractmethod
 from typing import TYPE_CHECKING, Callable, NamedTuple, Protocol
+from typing_extensions import Self  # Python 3.11+
 
 
 if TYPE_CHECKING:
@@ -159,7 +160,7 @@ class ChainedGradientTransformation(GradientTransformation):
 
     transformations: tuple[GradientTransformation, ...]
 
-    def __new__(cls, *transformations: GradientTransformation) -> ChainedGradientTransformation:
+    def __new__(cls, *transformations: GradientTransformation) -> Self:
         """Create a new chained gradient transformation."""
         transformations = tuple(
             itertools.chain.from_iterable(
@@ -235,7 +236,7 @@ def __reduce__(self) -> tuple[Callable, tuple[tuple[GradientTransformation, ...
 class IdentityGradientTransformation(GradientTransformation):
     """A gradient transformation that does nothing."""
 
-    def __new__(cls) -> IdentityGradientTransformation:
+    def __new__(cls) -> Self:
         """Create a new gradient transformation that does nothing."""
         return super().__new__(cls, init=cls.init_fn, update=cls.update_fn)
 
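For context, a minimal sketch (not part of this commit) of the `typing_extensions.Self` pattern that the `__new__` annotations above switch to; the class names here are hypothetical:

from typing_extensions import Self  # Python 3.11+ exposes this as typing.Self


class Transformation:
    def __new__(cls) -> Self:
        # `Self` binds to whichever subclass is actually being instantiated,
        # so subclasses inherit a precise return type without re-annotating __new__.
        return super().__new__(cls)


class Chained(Transformation):
    pass


chained = Chained()  # a type checker infers `Chained`, not `Transformation`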
10 changes: 5 additions & 5 deletions torchopt/clip.py
@@ -33,16 +33,16 @@
 
 
 def clip_grad_norm(
-    max_norm: float | int,
-    norm_type: float | int = 2.0,
+    max_norm: float,
+    norm_type: float = 2.0,
     error_if_nonfinite: bool = False,
 ) -> GradientTransformation:
     """Clip gradient norm of an iterable of parameters.
 
     Args:
-        max_norm (float or int): The maximum absolute value for each element in the update.
-        norm_type (float or int, optional): Type of the used p-norm. Can be ``'inf'`` for infinity
-            norm. (default: :const:`2.0`)
+        max_norm (float): The maximum absolute value for each element in the update.
+        norm_type (float, optional): Type of the used p-norm. Can be ``'inf'`` for infinity norm.
+            (default: :const:`2.0`)
         error_if_nonfinite (bool, optional): If :data:`True`, an error is thrown if the total norm
             of the gradients from ``updates`` is ``nan``, ``inf``, or ``-inf``.
             (default: :data:`False`)
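As an aside (not stated in the commit), dropping `int` from these annotations loses nothing: under the typing spec's numeric tower, an `int` argument is accepted wherever `float` is expected. A tiny hypothetical illustration (`half_norm` is not a torchopt function):

def half_norm(max_norm: float) -> float:
    # Passing an int such as 10 still type-checks, so `float | int` is redundant.
    return max_norm / 2.0


half_norm(10)    # accepted by type checkers and at runtime
half_norm(10.0)  # equivalent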
4 changes: 2 additions & 2 deletions torchopt/diff/implicit/decorator.py
@@ -274,7 +274,7 @@ def _custom_root(
     def make_custom_vjp_solver_fn(
         solver_fn: Callable[..., TensorOrTensors | tuple[TensorOrTensors, Any]],
         kwarg_keys: Sequence[str],
-        args_signs: tuple[tuple[int, int, type[tuple] | type[list] | None], ...],
+        args_signs: tuple[tuple[int, int, type[tuple | list] | None], ...],
     ) -> type[Function]:
         # pylint: disable-next=missing-class-docstring,abstract-method
         class ImplicitMetaGradient(Function):
@@ -396,7 +396,7 @@ def wrapped_solver_fn(
         args, kwargs = _signature_bind(solver_fn_signature, *args, **kwargs)
         keys, vals = list(kwargs.keys()), list(kwargs.values())
 
-        args_signs: list[tuple[int, int, type[tuple] | type[list] | None]] = []
+        args_signs: list[tuple[int, int, type[tuple | list] | None]] = []
         flat_args: list[Any] = []
         args_offset = 0
         for idx, arg in enumerate(args):
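For context (not part of the commit), `type[tuple | list]` and `type[tuple] | type[list]` describe the same set of values, namely the `tuple` and `list` classes themselves; the new spelling is just more compact. A hypothetical sketch of how such a value might be used (`rebuild` is not a torchopt function):

from __future__ import annotations


def rebuild(container_type: type[tuple | list] | None, items: list) -> tuple | list | None:
    # `container_type` is expected to be the `tuple` class, the `list` class, or None.
    return None if container_type is None else container_type(items)


assert rebuild(tuple, [1, 2]) == (1, 2)
assert rebuild(list, [1, 2]) == [1, 2]
assert rebuild(None, [1, 2]) is None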
4 changes: 2 additions & 2 deletions torchopt/distributed/autograd.py
@@ -38,7 +38,7 @@ def is_available() -> bool:
 
 if is_available():
     # pylint: disable-next=unused-import,ungrouped-imports
-    from torch.distributed.autograd import DistAutogradContext, get_gradients  # noqa: F401
+    from torch.distributed.autograd import DistAutogradContext, get_gradients
 
     def backward(
         autograd_ctx_id: int,
@@ -131,4 +131,4 @@ def grad(
 
         return tuple(grads)
 
-    __all__.extend(['DistAutogradContext', 'get_gradients', 'backward', 'grad'])
+    __all__ += ['DistAutogradContext', 'get_gradients', 'backward', 'grad']
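A quick aside (not part of the commit): for a list such as `__all__`, `+=` and `.extend()` both mutate the same object in place, so this change is purely stylistic. A generic sketch:

__all__ = ['backward']
alias = __all__
__all__ += ['grad']  # in-place extension; `alias` still refers to the same list
assert alias is __all__ and __all__ == ['backward', 'grad']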
3 changes: 2 additions & 1 deletion torchopt/nn/module.py
@@ -18,6 +18,7 @@
 
 from collections import OrderedDict
 from typing import Any, Iterator, NamedTuple
+from typing_extensions import Self  # Python 3.11+
 
 import torch
 import torch.nn as nn
@@ -40,7 +41,7 @@ class MetaGradientModule(nn.Module): # pylint: disable=abstract-method
     _meta_parameters: TensorContainer
     _meta_modules: dict[str, nn.Module | None]
 
-    def __new__(cls, *args: Any, **kwargs: Any) -> MetaGradientModule:
+    def __new__(cls, *args: Any, **kwargs: Any) -> Self:
         """Create a new module instance."""
         instance = super().__new__(cls)
         flat_args: list[Any]
2 changes: 1 addition & 1 deletion torchopt/pytree.py
@@ -194,7 +194,7 @@ def tree_local_value(rref_tree: PyTree[RRef[T]]) -> PyTree[T]:
         r"""Return the local value of a tree of :class:`RRef`\s."""
         return tree_map(lambda x: x.local_value(), rref_tree)
 
-    __all__.extend(['tree_as_rref', 'tree_to_here'])
+    __all__ += ['tree_as_rref', 'tree_to_here']
 
 
 del optree, rpc
2 changes: 1 addition & 1 deletion torchopt/typing.py
@@ -115,7 +115,7 @@
 if rpc.is_available():  # pragma: no cover
     from torch.distributed.rpc import RRef  # pylint: disable=ungrouped-imports,unused-import
 
-    __all__.extend(['RRef'])
+    __all__ += ['RRef']
 else:  # pragma: no cover
     # pylint: disable-next=invalid-name
     RRef = None  # type: ignore[misc,assignment]
3 changes: 0 additions & 3 deletions torchopt/visual.py
@@ -19,7 +19,6 @@
 
 from __future__ import annotations
 
-from collections import namedtuple
 from typing import Any, Generator, Iterable, Mapping, cast
 
 import torch
@@ -33,8 +32,6 @@
 __all__ = ['make_dot', 'resize_graph']
 
 
-Node = namedtuple('Node', ('name', 'inputs', 'attr', 'op'))
-
 # Saved attrs for grad_fn (incl. saved variables) begin with `._saved_*`
 SAVED_PREFIX = '_saved_'
 
