diff --git a/ivy/functional/frontends/torch/nn/modules/module.py b/ivy/functional/frontends/torch/nn/modules/module.py index b18a2f212b021..1cc8e668db55f 100644 --- a/ivy/functional/frontends/torch/nn/modules/module.py +++ b/ivy/functional/frontends/torch/nn/modules/module.py @@ -2,6 +2,7 @@ import ivy from collections import OrderedDict from typing import Any, Dict, Iterator, List, Optional, Set, Tuple, Callable +import threading # local from ivy.functional.frontends.torch.nn.parameter import Parameter @@ -368,7 +369,11 @@ def __dir__(self): def __getstate__(self): state = self.__dict__.copy() state.pop("_compiled_call_impl", None) + state.pop("_thread_local", None) + state.pop("_metrics_lock", None) return state def __setstate__(self, state): + state["_thread_local"] = threading.local() + state["_metrics_lock"] = threading.Lock() self.__dict__.update(state)