
Commit d35a91f

fix: [pre-commit.ci] auto fixes [...]
Parent: 9d9e309

10 files changed: +5562 −2925 lines

torchopt/distributed/api.py

Lines changed: 2 additions & 2 deletions

@@ -318,12 +318,12 @@ def remote_async_call(
         futures.append(fut)

     future = cast(
-        Future[List[T]],
+        'Future[List[T]]',
         torch.futures.collect_all(futures).then(lambda fut: [f.wait() for f in fut.wait()]),
     )
     if reducer is not None:
         return cast(
-            Future[U],
+            'Future[U]',
             future.then(lambda fut: reducer(fut.wait())),
         )
     return future
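
Most of the changes in this commit apply the same mechanical fix: the first argument to typing.cast is wrapped in quotes. cast never evaluates that argument at runtime, so the string form behaves identically while skipping the construction of subscripted type objects (and, where relevant, the runtime import of typing-only names). A minimal standalone sketch of the runtime equivalence (not TorchOpt code, just the general pattern):

from typing import List, cast

xs: object = [1, 2, 3]
# cast() returns its second argument unchanged, so both spellings yield
# the very same object at runtime, and type checkers accept either one.
assert cast('List[int]', xs) is cast(List[int], xs) is xs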

torchopt/nn/stateless.py

Lines changed: 1 addition & 1 deletion

@@ -84,7 +84,7 @@ def reparametrize(
     module: nn.Module,
     named_tensors: dict[str, torch.Tensor] | Iterable[tuple[str, torch.Tensor]],
     allow_missing: bool = False,
-) -> Generator[nn.Module, None, None]:
+) -> Generator[nn.Module]:
     """Reparameterize the module parameters and/or buffers."""
     if not isinstance(named_tensors, dict):
         named_tensors = dict(named_tensors)
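
The shortened return annotation relies on type-parameter defaults: under PEP 696 (available in typing_extensions, and in the standard library from Python 3.13), Generator's send and return type parameters default to None, so Generator[nn.Module] means exactly Generator[nn.Module, None, None]. A sketch of the shorthand, assuming Python 3.13 or newer:

from collections.abc import Generator

def count_up(n: int) -> Generator[int]:  # same as Generator[int, None, None]
    # A plain generator: yields ints, accepts no sent values, returns None.
    yield from range(n)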

torchopt/utils.py

Lines changed: 9 additions & 9 deletions

@@ -79,13 +79,13 @@ def fn_(obj: Any) -> None:
         obj.detach_().requires_grad_(requires_grad)

     if isinstance(target, ModuleState):
-        true_target = cast(TensorTree, (target.params, target.buffers))
+        true_target = cast('TensorTree', (target.params, target.buffers))
     elif isinstance(target, nn.Module):
-        true_target = cast(TensorTree, tuple(target.parameters()))
+        true_target = cast('TensorTree', tuple(target.parameters()))
     elif isinstance(target, MetaOptimizer):
-        true_target = cast(TensorTree, target.state_dict())
+        true_target = cast('TensorTree', target.state_dict())
     else:
-        true_target = cast(TensorTree, target)  # tree of tensors
+        true_target = cast('TensorTree', target)  # tree of tensors

     pytree.tree_map_(fn_, true_target)

@@ -325,7 +325,7 @@ def recover_state_dict(
     from torchopt.optim.meta.base import MetaOptimizer

     if isinstance(target, nn.Module):
-        params, buffers, *_ = state = cast(ModuleState, state)
+        params, buffers, *_ = state = cast('ModuleState', state)
         params_containers, buffers_containers = extract_module_containers(target, with_buffers=True)

         if state.detach_buffers:

@@ -343,7 +343,7 @@ def clone_detach_(t: torch.Tensor) -> torch.Tensor:
     ):
         tgt.update(src)
     elif isinstance(target, MetaOptimizer):
-        state = cast(Sequence[OptState], state)
+        state = cast('Sequence[OptState]', state)
         target.load_state_dict(state)
     else:
         raise TypeError(f'Unexpected class of {target}')

@@ -422,9 +422,9 @@ def module_clone(  # noqa: C901

     if isinstance(target, (nn.Module, MetaOptimizer)):
         if isinstance(target, nn.Module):
-            containers = cast(TensorTree, extract_module_containers(target, with_buffers=True))
+            containers = cast('TensorTree', extract_module_containers(target, with_buffers=True))
         else:
-            containers = cast(TensorTree, target.state_dict())
+            containers = cast('TensorTree', target.state_dict())
         tensors = pytree.tree_leaves(containers)
         memo = {id(t): t for t in tensors}
         cloned = copy.deepcopy(target, memo=memo)

@@ -476,7 +476,7 @@ def clone_detach_(t: torch.Tensor) -> torch.Tensor:
     else:
         replicate = clone_detach_

-    return pytree.tree_map(replicate, cast(TensorTree, target))
+    return pytree.tree_map(replicate, cast('TensorTree', target))


 @overload
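
The quoted form also pairs naturally with typing-only imports: once the cast target is a string, an alias like TensorTree only needs to exist for the type checker. A minimal sketch of that general pattern, using a hypothetical stand-in alias rather than TorchOpt's actual definitions or import layout:

from typing import TYPE_CHECKING, cast

if TYPE_CHECKING:
    from collections.abc import Sequence

    # Hypothetical stand-in for an alias like torchopt.typing.TensorTree.
    TensorTree = Sequence[object]

def as_tree(target: object) -> object:
    # The quoted type is resolved only by the type checker; at runtime
    # cast() simply returns `target`, and TensorTree is never defined.
    return cast('TensorTree', target)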

torchopt/visual.py

Lines changed: 1 addition & 1 deletion

@@ -129,7 +129,7 @@ def make_dot(  # noqa: C901
     elif isinstance(param, Generator):
         param_map.update({v: k for k, v in param})
     else:
-        param_map.update({v: k for k, v in cast(Mapping, param).items()})
+        param_map.update({v: k for k, v in cast('Mapping', param).items()})

     node_attr = {
         'style': 'filled',
