Skip to content

Commit a5701d5

Browse files
Merge pull request #281 from matchms/remove_duplications
Refactor load_model to use factories for settings and model
2 parents c237507 + 25a840e commit a5701d5

File tree

1 file changed

+62
-115
lines changed

1 file changed

+62
-115
lines changed

ms2deepscore/models/load_model.py

Lines changed: 62 additions & 115 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Union, Dict, Optional, Any
1+
from typing import Union, Dict, Optional, Any, Callable
22
import json
33
from pathlib import Path
44
import warnings
@@ -12,19 +12,17 @@
1212
SettingsEmbeddingEvaluator,
1313
SettingsMS2Deepscore,
1414
)
15-
from ms2deepscore.models.io_utils import _settings_to_json # re-use from your module where save() lives
15+
from ms2deepscore.models.io_utils import _settings_to_json
1616

1717

1818
# ---------- internal helpers ----------
1919

2020
def _torch_device() -> torch.device:
2121
return torch.device("cuda" if torch.cuda.is_available() else "cpu")
2222

23+
2324
def _load_ckpt_safe(filename: Union[str, Path]) -> Dict[str, Any]:
24-
"""
25-
Load a checkpoint using PyTorch's restricted unpickler (weights_only=True).
26-
The file must contain only tensors and simple Python primitives.
27-
"""
25+
"""Load a checkpoint using PyTorch's restricted unpickler (weights_only=True)."""
2826
try:
2927
return torch.load(str(filename), map_location=_torch_device(), weights_only=True)
3028
except TypeError:
@@ -34,13 +32,13 @@ def _load_ckpt_safe(filename: Union[str, Path]) -> Dict[str, Any]:
3432
raise
3533
return ckpt
3634

35+
3736
def _extract_settings_dict(ckpt: Dict[str, Any]) -> Dict[str, Any]:
3837
"""
39-
Support current format and a couple of earlier safe variants.
40-
Priority:
41-
1) settings_json (preferred, current format)
42-
2) model_params_json (earlier safe format)
43-
3) model_params (earlier safe format with a plain dict)
38+
Support current format and earlier safe variants:
39+
- settings_json (preferred)
40+
- model_params_json (older)
41+
- model_params (older plain dict)
4442
"""
4543
if "settings_json" in ckpt:
4644
return json.loads(ckpt["settings_json"])
@@ -55,23 +53,15 @@ def _extract_settings_dict(ckpt: Dict[str, Any]) -> Dict[str, Any]:
5553
"No settings found. Expected 'settings_json' (preferred) or 'model_params_json' / 'model_params'."
5654
)
5755

56+
5857
def _extract_state_dict(ckpt: Dict[str, Any]) -> Dict[str, torch.Tensor]:
59-
"""
60-
Support current key 'state_dict' and older 'model_state_dict'.
61-
"""
58+
"""Support current key 'state_dict' and older 'model_state_dict'."""
6259
if "state_dict" in ckpt:
6360
return ckpt["state_dict"]
6461
if "model_state_dict" in ckpt:
6562
return ckpt["model_state_dict"]
6663
raise KeyError("No weights found. Expected 'state_dict' or 'model_state_dict'.")
6764

68-
def _maybe_warn_version(ckpt: Dict[str, Any]) -> None:
69-
v = ckpt.get("ms2deepscore_version") or ckpt.get("version")
70-
if v and v != __version__:
71-
warnings.warn(
72-
f"Model was saved with ms2deepscore {v}, but you're running {__version__}. "
73-
"Consider updating either the model or the library if you hit incompatibilities."
74-
)
7565

7666
def _convert_legacy_if_requested(
7767
obj: Any,
@@ -89,10 +79,8 @@ def _convert_legacy_if_requested(
8979
out_path = Path(convert_path)
9080
out_path.parent.mkdir(parents=True, exist_ok=True)
9181

92-
# Two conversion paths:
9382
# (1) nn.Module with .state_dict and .model_settings
9483
if isinstance(obj, torch.nn.Module) and hasattr(obj, "state_dict") and hasattr(obj, "model_settings"):
95-
# Build a minimal safe checkpoint (match the new save())
9684
safe_ckpt = {
9785
"format": "ms2deepscore.safe.v1",
9886
"ms2deepscore_version": getattr(obj, "version", __version__),
@@ -105,12 +93,10 @@ def _convert_legacy_if_requested(
10593

10694
# (2) dict-like legacy checkpoint with params + state_dict
10795
if isinstance(obj, dict) and ("model_state_dict" in obj or "state_dict" in obj):
108-
# Try to normalize to the new shape
10996
try:
11097
params = _extract_settings_dict(obj)
11198
except Exception:
11299
params = obj.get("model_params", {})
113-
# JSON encode params to ensure safety
114100
settings_json = json.dumps(params, ensure_ascii=False, sort_keys=True)
115101

116102
state_dict = _extract_state_dict(obj)
@@ -124,55 +110,37 @@ def _convert_legacy_if_requested(
124110
torch.save(safe_ckpt, str(out_path))
125111
return out_path
126112

127-
# Otherwise we don't know how to convert
128113
return None
129114

130115

131-
# ---------- public API ----------
116+
# ---------- generic loader ----------
132117

133-
def load_model(
118+
def _load_model_generic(
134119
filename: Union[str, Path],
135120
*,
136-
allow_legacy: bool = False,
137-
convert_legacy_to: Optional[Union[str, Path]] = None,
138-
) -> SiameseSpectralModel:
121+
allow_legacy: bool,
122+
convert_legacy_to: Optional[Union[str, Path]],
123+
settings_factory: Callable[[Dict[str, Any]], Any],
124+
model_factory: Callable[[Any], torch.nn.Module],
125+
) -> torch.nn.Module:
139126
"""
140-
Load a SiameseSpectralModel.
141-
142-
Normal path:
127+
Shared loader:
143128
1) Safe-load checkpoint (weights_only=True).
144-
2) Parse settings_json -> SettingsMS2Deepscore.
145-
3) Instantiate SiameseSpectralModel(settings=...) then load state_dict.
146-
147-
Legacy path (only if allow_legacy=True):
148-
- Attempt torch.load(weights_only=False) and either:
149-
* return the nn.Module directly (if saved whole), or
150-
* adapt a dict-like legacy checkpoint to the new format.
151-
- If convert_legacy_to is given, also write a converted, safe checkpoint.
152-
153-
Parameters
154-
----------
155-
filename : str | Path
156-
allow_legacy : bool, default False
157-
Permit unsafe pickle loading for truly old files. Only use for trusted sources.
158-
convert_legacy_to : Optional[str | Path]
159-
If given and a legacy artifact is loaded, write an equivalent safe checkpoint here.
129+
2) Parse settings -> create Settings via settings_factory(params).
130+
3) Instantiate model via model_factory(settings), then load state_dict.
131+
4) If safe path fails and allow_legacy=True, attempt unsafe legacy path.
160132
"""
161-
device = _torch_device()
162-
163133
# --- preferred safe path
164134
try:
165135
ckpt = _load_ckpt_safe(filename)
166-
_maybe_warn_version(ckpt)
167136
params = _extract_settings_dict(ckpt)
168137
state_dict = _extract_state_dict(ckpt)
169138

170-
settings = SettingsMS2Deepscore(**params, validate_settings=False)
171-
model = SiameseSpectralModel(settings=settings)
139+
settings = settings_factory(params)
140+
model = model_factory(settings)
172141
model.load_state_dict(state_dict)
173142
model.eval()
174143
return model
175-
176144
except Exception as safe_err:
177145
if not allow_legacy:
178146
raise RuntimeError(
@@ -184,24 +152,24 @@ def load_model(
184152
"Using UNSAFE legacy loading (weights_only=False). Only do this for trusted files.",
185153
RuntimeWarning,
186154
)
187-
legacy_obj = torch.load(str(filename), map_location=device, weights_only=False)
155+
legacy_obj = torch.load(str(filename), map_location=_torch_device(), weights_only=False)
188156

189157
# If the whole module was saved, just use it.
190158
if isinstance(legacy_obj, torch.nn.Module):
191159
legacy_obj.eval()
192160
_convert_legacy_if_requested(legacy_obj, convert_path=convert_legacy_to)
193161
return legacy_obj
194162

195-
# If it looks like a legacy dict checkpoint, normalize and build the model
163+
# Dict-like legacy checkpoint: normalize and build the model
196164
if isinstance(legacy_obj, dict):
197165
try:
198166
params = _extract_settings_dict(legacy_obj)
199167
state_dict = _extract_state_dict(legacy_obj)
200168
except Exception as err:
201169
raise TypeError("Unrecognized legacy checkpoint structure.") from err
202170

203-
settings = SettingsMS2Deepscore(**params, validate_settings=False)
204-
model = SiameseSpectralModel(settings=settings)
171+
settings = settings_factory(params)
172+
model = model_factory(settings)
205173
model.load_state_dict(state_dict)
206174
model.eval()
207175

@@ -211,75 +179,54 @@ def load_model(
211179
raise TypeError("Legacy artifact is neither a torch.nn.Module nor a compatible dict checkpoint.")
212180

213181

214-
def load_embedding_evaluator(
182+
# ---------- public API ----------
183+
184+
def load_model(
    filename: Union[str, Path],
    *,
    allow_legacy: bool = False,
    convert_legacy_to: Optional[Union[str, Path]] = None,
) -> SiameseSpectralModel:
    """
    Load a SiameseSpectralModel (safe-first, legacy-optional).
    """
    def _build_settings(params: Dict[str, Any]) -> SettingsMS2Deepscore:
        # validate_settings=False: stored settings are taken as-is on load.
        return SettingsMS2Deepscore(**params, validate_settings=False)

    def _build_model(settings: SettingsMS2Deepscore) -> SiameseSpectralModel:
        return SiameseSpectralModel(settings=settings)

    return _load_model_generic(
        filename,
        allow_legacy=allow_legacy,
        convert_legacy_to=convert_legacy_to,
        settings_factory=_build_settings,
        model_factory=_build_model,
    )  # type: ignore[return-value]
243202

244-
# --- legacy fallback (unsafe; only for trusted files)
245-
warnings.warn(
246-
"Using UNSAFE legacy loading (weights_only=False). Only do this for trusted files.",
247-
RuntimeWarning,
248-
)
249-
legacy_obj = torch.load(str(filename), map_location=device, weights_only=False)
250-
251-
if isinstance(legacy_obj, torch.nn.Module):
252-
legacy_obj.eval()
253-
_convert_legacy_if_requested(legacy_obj, convert_path=convert_legacy_to)
254-
return legacy_obj
255203

256-
if isinstance(legacy_obj, dict):
257-
try:
258-
params = _extract_settings_dict(legacy_obj)
259-
state_dict = _extract_state_dict(legacy_obj)
260-
except Exception as err:
261-
raise TypeError("Unrecognized legacy checkpoint structure.") from err
262-
263-
settings = SettingsEmbeddingEvaluator(**params)
264-
model = EmbeddingEvaluationModel(settings=settings)
265-
model.load_state_dict(state_dict)
266-
model.eval()
267-
268-
_convert_legacy_if_requested(legacy_obj, convert_path=convert_legacy_to)
269-
return model
270-
271-
raise TypeError("Legacy artifact is neither a torch.nn.Module nor a compatible dict checkpoint.")
204+
def load_embedding_evaluator(
    filename: Union[str, Path],
    *,
    allow_legacy: bool = False,
    convert_legacy_to: Optional[Union[str, Path]] = None,
) -> EmbeddingEvaluationModel:
    """
    Load an EmbeddingEvaluationModel (safe-first, legacy-optional).
    """
    def _build_settings(params: Dict[str, Any]) -> SettingsEmbeddingEvaluator:
        return SettingsEmbeddingEvaluator(**params)

    def _build_model(settings: SettingsEmbeddingEvaluator) -> EmbeddingEvaluationModel:
        return EmbeddingEvaluationModel(settings=settings)

    return _load_model_generic(
        filename,
        allow_legacy=allow_legacy,
        convert_legacy_to=convert_legacy_to,
        settings_factory=_build_settings,
        model_factory=_build_model,
    )  # type: ignore[return-value]
272220

273221

274-
def load_linear_model(filepath):
275-
"""Load a LinearModel from json.
276-
"""
222+
def load_linear_model(filepath: Union[str, Path]) -> LinearModel:
223+
"""Load a LinearModel from JSON."""
277224
with open(filepath, "r", encoding="utf-8") as f:
278225
model_params = json.load(f)
279226

280227
loaded_model = LinearModel(model_params["degree"])
281-
loaded_model.model.coef_ = np.array(model_params['coef'])
282-
loaded_model.model.intercept_ = np.array(model_params['intercept'])
228+
loaded_model.model.coef_ = np.array(model_params["coef"])
229+
loaded_model.model.intercept_ = np.array(model_params["intercept"])
283230
loaded_model.poly._min_degree = model_params["min_degree"]
284231
loaded_model.poly._max_degree = model_params["max_degree"]
285232
loaded_model.poly._n_out_full = model_params["_n_out_full"]

0 commit comments

Comments
 (0)