@@ -79,13 +79,13 @@ def fn_(obj: Any) -> None:
79
79
obj .detach_ ().requires_grad_ (requires_grad )
80
80
81
81
if isinstance (target , ModuleState ):
82
- true_target = cast (TensorTree , (target .params , target .buffers ))
82
+ true_target = cast (' TensorTree' , (target .params , target .buffers ))
83
83
elif isinstance (target , nn .Module ):
84
- true_target = cast (TensorTree , tuple (target .parameters ()))
84
+ true_target = cast (' TensorTree' , tuple (target .parameters ()))
85
85
elif isinstance (target , MetaOptimizer ):
86
- true_target = cast (TensorTree , target .state_dict ())
86
+ true_target = cast (' TensorTree' , target .state_dict ())
87
87
else :
88
- true_target = cast (TensorTree , target ) # tree of tensors
88
+ true_target = cast (' TensorTree' , target ) # tree of tensors
89
89
90
90
pytree .tree_map_ (fn_ , true_target )
91
91
@@ -325,7 +325,7 @@ def recover_state_dict(
325
325
from torchopt .optim .meta .base import MetaOptimizer
326
326
327
327
if isinstance (target , nn .Module ):
328
- params , buffers , * _ = state = cast (ModuleState , state )
328
+ params , buffers , * _ = state = cast (' ModuleState' , state )
329
329
params_containers , buffers_containers = extract_module_containers (target , with_buffers = True )
330
330
331
331
if state .detach_buffers :
@@ -343,7 +343,7 @@ def clone_detach_(t: torch.Tensor) -> torch.Tensor:
343
343
):
344
344
tgt .update (src )
345
345
elif isinstance (target , MetaOptimizer ):
346
- state = cast (Sequence [OptState ], state )
346
+ state = cast (' Sequence[OptState]' , state )
347
347
target .load_state_dict (state )
348
348
else :
349
349
raise TypeError (f'Unexpected class of { target } ' )
@@ -422,9 +422,9 @@ def module_clone( # noqa: C901
422
422
423
423
if isinstance (target , (nn .Module , MetaOptimizer )):
424
424
if isinstance (target , nn .Module ):
425
- containers = cast (TensorTree , extract_module_containers (target , with_buffers = True ))
425
+ containers = cast (' TensorTree' , extract_module_containers (target , with_buffers = True ))
426
426
else :
427
- containers = cast (TensorTree , target .state_dict ())
427
+ containers = cast (' TensorTree' , target .state_dict ())
428
428
tensors = pytree .tree_leaves (containers )
429
429
memo = {id (t ): t for t in tensors }
430
430
cloned = copy .deepcopy (target , memo = memo )
@@ -476,7 +476,7 @@ def clone_detach_(t: torch.Tensor) -> torch.Tensor:
476
476
else :
477
477
replicate = clone_detach_
478
478
479
- return pytree .tree_map (replicate , cast (TensorTree , target ))
479
+ return pytree .tree_map (replicate , cast (' TensorTree' , target ))
480
480
481
481
482
482
@overload
0 commit comments