Skip to content

Commit 2addd85

Browse files
committed
suppress TypedStorage warning during nequip-package
1 parent 9d5e76a commit 2addd85

File tree

2 files changed

+8
-8
lines changed

2 files changed

+8
-8
lines changed

nequip/model/saved_models/package.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -51,15 +51,15 @@ def _get_package_metadata(imp) -> Dict[str, Any]:
5151

5252

5353
@contextlib.contextmanager
54-
def _suppress_package_importer_warnings():
54+
def _suppress_package_importer_exporter_warnings():
5555
# Ideally this ceases to exist or becomes a no-op in future versions of PyTorch
5656
with warnings.catch_warnings():
5757
# suppress torch.package TypedStorage warning
5858
warnings.filterwarnings(
5959
"ignore",
6060
message="TypedStorage is deprecated.*",
6161
category=UserWarning,
62-
module="torch.package.package_importer",
62+
module=r"torch\.package\.(package_exporter|package_importer)",
6363
)
6464
yield
6565

@@ -104,7 +104,7 @@ def ModelFromPackage(package_path: str, compile_mode: str = _EAGER_MODEL_KEY):
104104

105105
# === load model ===
106106
logger.info(f"Loading model from package file: {package_path} ...")
107-
with _suppress_package_importer_warnings():
107+
with _suppress_package_importer_exporter_warnings():
108108
# during `nequip-package`, we need to use the same importer for all the models for successful repackaging
109109
# see https://pytorch.org/docs/stable/package.html#re-export-an-imported-object
110110
if workflow_state == "package":
@@ -141,7 +141,7 @@ def ModelFromPackage(package_path: str, compile_mode: str = _EAGER_MODEL_KEY):
141141

142142
def data_dict_from_package(package_path: str) -> AtomicDataDict.Type:
143143
"""Load example data from a .nequip.zip package file."""
144-
with _suppress_package_importer_warnings():
144+
with _suppress_package_importer_exporter_warnings():
145145
imp = torch.package.PackageImporter(package_path)
146146
data = imp.load_pickle(package="model", resource="example_data.pkl")
147147
return data
@@ -159,7 +159,7 @@ def ModelTypeNamesFromPackage(package_path: str):
159159

160160
_check_file_exists(file_path=package_path, file_type="package")
161161

162-
with _suppress_package_importer_warnings():
162+
with _suppress_package_importer_exporter_warnings():
163163
imp = torch.package.PackageImporter(package_path)
164164
pkg_metadata = _get_package_metadata(imp)
165165

nequip/scripts/package.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from nequip.model.saved_models.package import (
66
_get_shared_importer,
77
_get_package_metadata,
8-
_suppress_package_importer_warnings,
8+
_suppress_package_importer_exporter_warnings,
99
)
1010
from nequip.model.saved_models import load_saved_model
1111
from nequip.model.utils import (
@@ -100,7 +100,7 @@ def main(args=None):
100100
"packed model file to inspect must end with the `.nequip.zip` extension"
101101
)
102102

103-
with _suppress_package_importer_warnings():
103+
with _suppress_package_importer_exporter_warnings():
104104
imp = torch.package.PackageImporter(args.pkg_path)
105105
pkg_metadata = _get_package_metadata(imp)
106106

@@ -250,7 +250,7 @@ def main(args=None):
250250
models_to_package.update({compile_mode: model})
251251

252252
# == package ==
253-
with _suppress_package_importer_warnings():
253+
with _suppress_package_importer_exporter_warnings():
254254
with torch.package.PackageExporter(
255255
args.output_path, importer=importers, debug=True
256256
) as exp:

0 commit comments

Comments
 (0)