1- from typing import Union , Dict , Optional , Any
1+ from typing import Union , Dict , Optional , Any , Callable
22import json
33from pathlib import Path
44import warnings
1212 SettingsEmbeddingEvaluator ,
1313 SettingsMS2Deepscore ,
1414)
15- from ms2deepscore .models .io_utils import _settings_to_json # re-use from your module where save() lives
15+ from ms2deepscore .models .io_utils import _settings_to_json
1616
1717
1818# ---------- internal helpers ----------
1919
2020def _torch_device () -> torch .device :
2121 return torch .device ("cuda" if torch .cuda .is_available () else "cpu" )
2222
23+
2324def _load_ckpt_safe (filename : Union [str , Path ]) -> Dict [str , Any ]:
24- """
25- Load a checkpoint using PyTorch's restricted unpickler (weights_only=True).
26- The file must contain only tensors and simple Python primitives.
27- """
25+ """Load a checkpoint using PyTorch's restricted unpickler (weights_only=True)."""
2826 try :
2927 return torch .load (str (filename ), map_location = _torch_device (), weights_only = True )
3028 except TypeError :
@@ -34,13 +32,13 @@ def _load_ckpt_safe(filename: Union[str, Path]) -> Dict[str, Any]:
3432 raise
3533 return ckpt
3634
35+
3736def _extract_settings_dict (ckpt : Dict [str , Any ]) -> Dict [str , Any ]:
3837 """
39- Support current format and a couple of earlier safe variants.
40- Priority:
41- 1) settings_json (preferred, current format)
42- 2) model_params_json (earlier safe format)
43- 3) model_params (earlier safe format with a plain dict)
38+ Support current format and earlier safe variants:
39+ - settings_json (preferred)
40+ - model_params_json (older)
41+ - model_params (older plain dict)
4442 """
4543 if "settings_json" in ckpt :
4644 return json .loads (ckpt ["settings_json" ])
@@ -55,23 +53,15 @@ def _extract_settings_dict(ckpt: Dict[str, Any]) -> Dict[str, Any]:
5553 "No settings found. Expected 'settings_json' (preferred) or 'model_params_json' / 'model_params'."
5654 )
5755
56+
5857def _extract_state_dict (ckpt : Dict [str , Any ]) -> Dict [str , torch .Tensor ]:
59- """
60- Support current key 'state_dict' and older 'model_state_dict'.
61- """
58+ """Support current key 'state_dict' and older 'model_state_dict'."""
6259 if "state_dict" in ckpt :
6360 return ckpt ["state_dict" ]
6461 if "model_state_dict" in ckpt :
6562 return ckpt ["model_state_dict" ]
6663 raise KeyError ("No weights found. Expected 'state_dict' or 'model_state_dict'." )
6764
68- def _maybe_warn_version (ckpt : Dict [str , Any ]) -> None :
69- v = ckpt .get ("ms2deepscore_version" ) or ckpt .get ("version" )
70- if v and v != __version__ :
71- warnings .warn (
72- f"Model was saved with ms2deepscore { v } , but you're running { __version__ } . "
73- "Consider updating either the model or the library if you hit incompatibilities."
74- )
7565
7666def _convert_legacy_if_requested (
7767 obj : Any ,
@@ -89,10 +79,8 @@ def _convert_legacy_if_requested(
8979 out_path = Path (convert_path )
9080 out_path .parent .mkdir (parents = True , exist_ok = True )
9181
92- # Two conversion paths:
9382 # (1) nn.Module with .state_dict and .model_settings
9483 if isinstance (obj , torch .nn .Module ) and hasattr (obj , "state_dict" ) and hasattr (obj , "model_settings" ):
95- # Build a minimal safe checkpoint (match the new save())
9684 safe_ckpt = {
9785 "format" : "ms2deepscore.safe.v1" ,
9886 "ms2deepscore_version" : getattr (obj , "version" , __version__ ),
@@ -105,12 +93,10 @@ def _convert_legacy_if_requested(
10593
10694 # (2) dict-like legacy checkpoint with params + state_dict
10795 if isinstance (obj , dict ) and ("model_state_dict" in obj or "state_dict" in obj ):
108- # Try to normalize to the new shape
10996 try :
11097 params = _extract_settings_dict (obj )
11198 except Exception :
11299 params = obj .get ("model_params" , {})
113- # JSON encode params to ensure safety
114100 settings_json = json .dumps (params , ensure_ascii = False , sort_keys = True )
115101
116102 state_dict = _extract_state_dict (obj )
@@ -124,55 +110,37 @@ def _convert_legacy_if_requested(
124110 torch .save (safe_ckpt , str (out_path ))
125111 return out_path
126112
127- # Otherwise we don't know how to convert
128113 return None
129114
130115
131- # ---------- public API ----------
116+ # ---------- generic loader ----------
132117
133- def load_model (
118+ def _load_model_generic (
134119 filename : Union [str , Path ],
135120 * ,
136- allow_legacy : bool = False ,
137- convert_legacy_to : Optional [Union [str , Path ]] = None ,
138- ) -> SiameseSpectralModel :
121+ allow_legacy : bool ,
122+ convert_legacy_to : Optional [Union [str , Path ]],
123+ settings_factory : Callable [[Dict [str , Any ]], Any ],
124+ model_factory : Callable [[Any ], torch .nn .Module ],
125+ ) -> torch .nn .Module :
139126 """
140- Load a SiameseSpectralModel.
141-
142- Normal path:
127+ Shared loader:
143128 1) Safe-load checkpoint (weights_only=True).
144- 2) Parse settings_json -> SettingsMS2Deepscore.
145- 3) Instantiate SiameseSpectralModel(settings=...) then load state_dict.
146-
147- Legacy path (only if allow_legacy=True):
148- - Attempt torch.load(weights_only=False) and either:
149- * return the nn.Module directly (if saved whole), or
150- * adapt a dict-like legacy checkpoint to the new format.
151- - If convert_legacy_to is given, also write a converted, safe checkpoint.
152-
153- Parameters
154- ----------
155- filename : str | Path
156- allow_legacy : bool, default False
157- Permit unsafe pickle loading for truly old files. Only use for trusted sources.
158- convert_legacy_to : Optional[str | Path]
159- If given and a legacy artifact is loaded, write an equivalent safe checkpoint here.
129+ 2) Parse settings -> create Settings via settings_factory(params).
130+ 3) Instantiate model via model_factory(settings), then load state_dict.
131+ 4) If safe path fails and allow_legacy=True, attempt unsafe legacy path.
160132 """
161- device = _torch_device ()
162-
163133 # --- preferred safe path
164134 try :
165135 ckpt = _load_ckpt_safe (filename )
166- _maybe_warn_version (ckpt )
167136 params = _extract_settings_dict (ckpt )
168137 state_dict = _extract_state_dict (ckpt )
169138
170- settings = SettingsMS2Deepscore ( ** params , validate_settings = False )
171- model = SiameseSpectralModel ( settings = settings )
139+ settings = settings_factory ( params )
140+ model = model_factory ( settings )
172141 model .load_state_dict (state_dict )
173142 model .eval ()
174143 return model
175-
176144 except Exception as safe_err :
177145 if not allow_legacy :
178146 raise RuntimeError (
@@ -184,24 +152,24 @@ def load_model(
184152 "Using UNSAFE legacy loading (weights_only=False). Only do this for trusted files." ,
185153 RuntimeWarning ,
186154 )
187- legacy_obj = torch .load (str (filename ), map_location = device , weights_only = False )
155+ legacy_obj = torch .load (str (filename ), map_location = _torch_device () , weights_only = False )
188156
189157 # If the whole module was saved, just use it.
190158 if isinstance (legacy_obj , torch .nn .Module ):
191159 legacy_obj .eval ()
192160 _convert_legacy_if_requested (legacy_obj , convert_path = convert_legacy_to )
193161 return legacy_obj
194162
195- # If it looks like a legacy dict checkpoint, normalize and build the model
163+ # Dict- like legacy checkpoint: normalize and build the model
196164 if isinstance (legacy_obj , dict ):
197165 try :
198166 params = _extract_settings_dict (legacy_obj )
199167 state_dict = _extract_state_dict (legacy_obj )
200168 except Exception as err :
201169 raise TypeError ("Unrecognized legacy checkpoint structure." ) from err
202170
203- settings = SettingsMS2Deepscore ( ** params , validate_settings = False )
204- model = SiameseSpectralModel ( settings = settings )
171+ settings = settings_factory ( params )
172+ model = model_factory ( settings )
205173 model .load_state_dict (state_dict )
206174 model .eval ()
207175
@@ -211,75 +179,54 @@ def load_model(
211179 raise TypeError ("Legacy artifact is neither a torch.nn.Module nor a compatible dict checkpoint." )
212180
213181
214- def load_embedding_evaluator (
182+ # ---------- public API ----------
183+
def load_model(
    filename: Union[str, Path],
    *,
    allow_legacy: bool = False,
    convert_legacy_to: Optional[Union[str, Path]] = None,
) -> SiameseSpectralModel:
    """
    Load a SiameseSpectralModel (safe-first, legacy-optional).
    """

    def make_settings(params: Dict[str, Any]) -> SettingsMS2Deepscore:
        # validate_settings=False: trust the values already stored in the checkpoint.
        return SettingsMS2Deepscore(**params, validate_settings=False)

    def make_model(settings: SettingsMS2Deepscore) -> SiameseSpectralModel:
        return SiameseSpectralModel(settings=settings)

    return _load_model_generic(
        filename,
        allow_legacy=allow_legacy,
        convert_legacy_to=convert_legacy_to,
        settings_factory=make_settings,
        model_factory=make_model,
    )  # type: ignore[return-value]
243202
244- # --- legacy fallback (unsafe; only for trusted files)
245- warnings .warn (
246- "Using UNSAFE legacy loading (weights_only=False). Only do this for trusted files." ,
247- RuntimeWarning ,
248- )
249- legacy_obj = torch .load (str (filename ), map_location = device , weights_only = False )
250-
251- if isinstance (legacy_obj , torch .nn .Module ):
252- legacy_obj .eval ()
253- _convert_legacy_if_requested (legacy_obj , convert_path = convert_legacy_to )
254- return legacy_obj
255203
256- if isinstance ( legacy_obj , dict ):
257- try :
258- params = _extract_settings_dict ( legacy_obj )
259- state_dict = _extract_state_dict ( legacy_obj )
260- except Exception as err :
261- raise TypeError ( "Unrecognized legacy checkpoint structure." ) from err
262-
263- settings = SettingsEmbeddingEvaluator ( ** params )
264- model = EmbeddingEvaluationModel ( settings = settings )
265- model . load_state_dict ( state_dict )
266- model . eval ()
267-
268- _convert_legacy_if_requested ( legacy_obj , convert_path = convert_legacy_to )
269- return model
270-
271- raise TypeError ( "Legacy artifact is neither a torch.nn.Module nor a compatible dict checkpoint." )
def load_embedding_evaluator(
    filename: Union[str, Path],
    *,
    allow_legacy: bool = False,
    convert_legacy_to: Optional[Union[str, Path]] = None,
) -> EmbeddingEvaluationModel:
    """
    Load an EmbeddingEvaluationModel (safe-first, legacy-optional).
    """

    def make_settings(params: Dict[str, Any]) -> SettingsEmbeddingEvaluator:
        return SettingsEmbeddingEvaluator(**params)

    def make_model(settings: SettingsEmbeddingEvaluator) -> EmbeddingEvaluationModel:
        return EmbeddingEvaluationModel(settings=settings)

    return _load_model_generic(
        filename,
        allow_legacy=allow_legacy,
        convert_legacy_to=convert_legacy_to,
        settings_factory=make_settings,
        model_factory=make_model,
    )  # type: ignore[return-value]
272220
273221
274- def load_linear_model (filepath ):
275- """Load a LinearModel from json.
276- """
222+ def load_linear_model (filepath : Union [str , Path ]) -> LinearModel :
223+ """Load a LinearModel from JSON."""
277224 with open (filepath , "r" , encoding = "utf-8" ) as f :
278225 model_params = json .load (f )
279226
280227 loaded_model = LinearModel (model_params ["degree" ])
281- loaded_model .model .coef_ = np .array (model_params [' coef' ])
282- loaded_model .model .intercept_ = np .array (model_params [' intercept' ])
228+ loaded_model .model .coef_ = np .array (model_params [" coef" ])
229+ loaded_model .model .intercept_ = np .array (model_params [" intercept" ])
283230 loaded_model .poly ._min_degree = model_params ["min_degree" ]
284231 loaded_model .poly ._max_degree = model_params ["max_degree" ]
285232 loaded_model .poly ._n_out_full = model_params ["_n_out_full" ]
0 commit comments