Skip to content

Commit d1c137b

Browse files
author
Ashley Scillitoe
authored
Fix two bugs w/ kwargs and preprocess_batch_fn in preprocess_fn serialisation (#752)
1 parent 5c7df29 commit d1c137b

File tree

6 files changed

+84
-38
lines changed

6 files changed

+84
-38
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ repos:
44
hooks:
55
- id: flake8
66
- repo: https://github.com/pre-commit/mirrors-mypy
7-
rev: v0.990
7+
rev: v1.0.1
88
hooks:
99
- id: mypy
1010
additional_dependencies: [

CHANGELOG.md

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,16 @@
11
# Change Log
22

33
## v0.12.0dev
4-
[Full Changelog](https://github.com/SeldonIO/alibi-detect/compare/v0.11.0...master)
4+
[Full Changelog](https://github.com/SeldonIO/alibi-detect/compare/v0.11.1...master)
55

6+
## v0.11.1
7+
[Full Changelog](https://github.com/SeldonIO/alibi-detect/compare/v0.11.0...v0.11.1)
8+
9+
### Fixed
10+
11+
- Fixed two bugs with the saving/loading of drift detector `preprocess_fn`'s [#752](https://github.com/SeldonIO/alibi-detect/pull/752)):
12+
- When `preprocess_fn` was a custom Python function wrapped in a partial, included kwarg's were not serialized. This has now been fixed.
13+
- When saving drift detector `preprocess_fn`'s, for kwargs saved to `.dill` files, the filenames are now prepended with the kwarg name, so that files aren't overwritten if multiple kwargs are saved to `.dill`.
614

715
## v0.11.0
816
[Full Changelog](https://github.com/SeldonIO/alibi-detect/compare/v0.10.5...v0.11.0)

alibi_detect/saving/loading.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,10 @@ def _load_preprocess_config(cfg: dict) -> Optional[Callable]:
253253
logger.warning('Unable to process preprocess_fn. No preprocessing function is defined.')
254254
return None
255255

256-
return partial(preprocess_fn, **kwargs)
256+
if kwargs == {}:
257+
return preprocess_fn
258+
else:
259+
return partial(preprocess_fn, **kwargs)
257260

258261

259262
def _load_model_config(cfg: dict) -> Callable:

alibi_detect/saving/saving.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import warnings
55
from functools import partial
66
from pathlib import Path
7-
from typing import Callable, Optional, Tuple, Union, Any, TYPE_CHECKING
7+
from typing import Callable, Optional, Tuple, Union, Any, Dict, TYPE_CHECKING
88
import dill
99
import numpy as np
1010
import toml
@@ -264,7 +264,7 @@ def _save_preprocess_config(preprocess_fn: Callable,
264264
The config dictionary, containing references to the serialized artefacts. The format if this dict matches that
265265
of the `preprocess` field in the drift detector specification.
266266
"""
267-
preprocess_cfg = {}
267+
preprocess_cfg: Dict[str, Any] = {}
268268
local_path = Path('preprocess_fn')
269269

270270
# Serialize function
@@ -292,7 +292,7 @@ def _save_preprocess_config(preprocess_fn: Callable,
292292

293293
# Arbitrary function
294294
elif callable(v):
295-
src, _ = _serialize_object(v, filepath, local_path)
295+
src, _ = _serialize_object(v, filepath, local_path.joinpath(k))
296296
kwargs.update({k: src})
297297

298298
# Put remaining kwargs directly into cfg
@@ -302,7 +302,7 @@ def _save_preprocess_config(preprocess_fn: Callable,
302302
if 'preprocess_drift' in func:
303303
preprocess_cfg.update(kwargs)
304304
else:
305-
kwargs.update({'kwargs': kwargs})
305+
preprocess_cfg.update({'kwargs': kwargs})
306306

307307
return preprocess_cfg
308308

alibi_detect/saving/tests/models.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def encoder_dropout_model(backend, current_cases):
8989

9090

9191
@fixture
92-
def preprocess_custom(encoder_model):
92+
def preprocess_uae(encoder_model):
9393
"""
9494
Preprocess function with Untrained Autoencoder.
9595
"""
@@ -263,6 +263,14 @@ def preprocess_simple(x: np.ndarray):
263263
return x*2.0
264264

265265

266+
@fixture
267+
def preprocess_simple_with_kwargs():
268+
"""
269+
Simple function to test serialization of generic Python function with kwargs, within preprocess_fn.
270+
"""
271+
return partial(preprocess_simple, kwarg1=42, kwarg2=True)
272+
273+
266274
@fixture
267275
def preprocess_nlp(embedding, tokenizer, max_len, backend):
268276
"""

alibi_detect/saving/tests/test_saving.py

Lines changed: 57 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
Internal functions such as save_kernel/load_kernel_config etc are also tested.
66
"""
77
from functools import partial
8+
import os
89
from pathlib import Path
910
from typing import Callable
1011

@@ -19,7 +20,8 @@
1920
import torch.nn as nn
2021

2122
from .datasets import BinData, CategoricalData, ContinuousData, MixedData, TextData
22-
from .models import (encoder_model, preprocess_custom, preprocess_hiddenoutput, preprocess_simple, # noqa: F401
23+
from .models import (encoder_model, preprocess_uae, preprocess_hiddenoutput, preprocess_simple, # noqa: F401
24+
preprocess_simple_with_kwargs,
2325
preprocess_nlp, LATENT_DIM, classifier_model, kernel, deep_kernel, nlp_embedding_and_tokenizer,
2426
embedding, tokenizer, max_len, enc_dim, encoder_dropout_model, optimizer)
2527

@@ -105,7 +107,7 @@ def test_load_simple_config(cfg, tmp_path):
105107
assert v == cfg_new[k]
106108

107109

108-
@parametrize('preprocess_fn', [preprocess_custom, preprocess_hiddenoutput])
110+
@parametrize('preprocess_fn', [preprocess_uae, preprocess_hiddenoutput])
109111
@parametrize_with_cases("data", cases=ContinuousData, prefix='data_')
110112
def test_save_ksdrift(data, preprocess_fn, tmp_path):
111113
"""
@@ -171,7 +173,7 @@ def test_save_ksdrift_nlp(data, preprocess_fn, enc_dim, tmp_path): # noqa: F811
171173
@pytest.mark.skipif(version.parse(scipy.__version__) < version.parse('1.7.0'),
172174
reason="Requires scipy version >= 1.7.0")
173175
@parametrize_with_cases("data", cases=ContinuousData, prefix='data_')
174-
def test_save_cvmdrift(data, preprocess_custom, tmp_path):
176+
def test_save_cvmdrift(data, preprocess_uae, tmp_path):
175177
"""
176178
Test CVMDrift on continuous datasets, with UAE as preprocess_fn.
177179
@@ -181,14 +183,14 @@ def test_save_cvmdrift(data, preprocess_custom, tmp_path):
181183
X_ref, X_h0 = data
182184
cd = CVMDrift(X_ref,
183185
p_val=P_VAL,
184-
preprocess_fn=preprocess_custom,
186+
preprocess_fn=preprocess_uae,
185187
preprocess_at_init=True,
186188
)
187189
save_detector(cd, tmp_path)
188190
cd_load = load_detector(tmp_path)
189191

190192
# Assert
191-
np.testing.assert_array_equal(preprocess_custom(X_ref), cd_load.x_ref)
193+
np.testing.assert_array_equal(preprocess_uae(X_ref), cd_load.x_ref)
192194
assert cd_load.n_features == LATENT_DIM
193195
assert cd_load.p_val == P_VAL
194196
assert isinstance(cd_load.preprocess_fn, Callable)
@@ -203,7 +205,7 @@ def test_save_cvmdrift(data, preprocess_custom, tmp_path):
203205
], indirect=True
204206
)
205207
@parametrize_with_cases("data", cases=ContinuousData, prefix='data_')
206-
def test_save_mmddrift(data, kernel, preprocess_custom, backend, tmp_path, seed): # noqa: F811
208+
def test_save_mmddrift(data, kernel, preprocess_uae, backend, tmp_path, seed): # noqa: F811
207209
"""
208210
Test MMDDrift on continuous datasets, with UAE as preprocess_fn.
209211
@@ -217,7 +219,7 @@ def test_save_mmddrift(data, kernel, preprocess_custom, backend, tmp_path, seed)
217219
kwargs = {
218220
'p_val': P_VAL,
219221
'backend': backend,
220-
'preprocess_fn': preprocess_custom,
222+
'preprocess_fn': preprocess_uae,
221223
'n_permutations': N_PERMUTATIONS,
222224
'preprocess_at_init': True,
223225
'kernel': kernel,
@@ -237,7 +239,7 @@ def test_save_mmddrift(data, kernel, preprocess_custom, backend, tmp_path, seed)
237239
preds_load = cd_load.predict(X_h0)
238240

239241
# assertions
240-
np.testing.assert_array_equal(preprocess_custom(X_ref), cd_load._detector.x_ref)
242+
np.testing.assert_array_equal(preprocess_uae(X_ref), cd_load._detector.x_ref)
241243
assert not cd_load._detector.infer_sigma
242244
assert cd_load._detector.n_permutations == N_PERMUTATIONS
243245
assert cd_load._detector.p_val == P_VAL
@@ -248,7 +250,7 @@ def test_save_mmddrift(data, kernel, preprocess_custom, backend, tmp_path, seed)
248250
assert preds['data']['p_val'] == preds_load['data']['p_val']
249251

250252

251-
# @parametrize('preprocess_fn', [preprocess_custom, preprocess_hiddenoutput])
253+
# @parametrize('preprocess_fn', [preprocess_uae, preprocess_hiddenoutput])
252254
@parametrize('preprocess_at_init', [True, False])
253255
@parametrize_with_cases("data", cases=ContinuousData, prefix='data_')
254256
def test_save_lsdddrift(data, preprocess_at_init, backend, tmp_path, seed):
@@ -553,7 +555,7 @@ def test_save_contextmmddrift(data, kernel, backend, tmp_path, seed): # noqa: F
553555
assert cd_load._detector.n_permutations == N_PERMUTATIONS
554556
assert cd_load._detector.p_val == P_VAL
555557
assert isinstance(cd_load._detector.preprocess_fn, Callable)
556-
assert cd_load._detector.preprocess_fn.func.__name__ == 'preprocess_simple'
558+
assert cd_load._detector.preprocess_fn.__name__ == 'preprocess_simple'
557559
assert cd._detector.x_kernel.sigma == cd_load._detector.x_kernel.sigma
558560
assert cd._detector.c_kernel.sigma == cd_load._detector.c_kernel.sigma
559561
assert cd._detector.x_kernel.init_sigma_fn == cd_load._detector.x_kernel.init_sigma_fn
@@ -629,7 +631,7 @@ def test_save_regressoruncertaintydrift(data, regressor, backend, tmp_path, seed
629631
], indirect=True
630632
)
631633
@parametrize_with_cases("data", cases=ContinuousData, prefix='data_')
632-
def test_save_onlinemmddrift(data, kernel, preprocess_custom, backend, tmp_path, seed): # noqa: F811
634+
def test_save_onlinemmddrift(data, kernel, preprocess_uae, backend, tmp_path, seed): # noqa: F811
633635
"""
634636
Test MMDDriftOnline on continuous datasets, with UAE as preprocess_fn.
635637
@@ -645,7 +647,7 @@ def test_save_onlinemmddrift(data, kernel, preprocess_custom, backend, tmp_path,
645647
cd = MMDDriftOnline(X_ref,
646648
ert=ERT,
647649
backend=backend,
648-
preprocess_fn=preprocess_custom,
650+
preprocess_fn=preprocess_uae,
649651
n_bootstraps=N_BOOTSTRAPS,
650652
kernel=kernel,
651653
window_size=WINDOW_SIZE
@@ -667,7 +669,7 @@ def test_save_onlinemmddrift(data, kernel, preprocess_custom, backend, tmp_path,
667669
stats_load.append(pred['data']['test_stat'])
668670

669671
# assertions
670-
np.testing.assert_array_equal(preprocess_custom(X_ref), cd_load._detector.x_ref)
672+
np.testing.assert_array_equal(preprocess_uae(X_ref), cd_load._detector.x_ref)
671673
assert cd_load._detector.n_bootstraps == N_BOOTSTRAPS
672674
assert cd_load._detector.ert == ERT
673675
assert isinstance(cd_load._detector.preprocess_fn, Callable)
@@ -678,7 +680,7 @@ def test_save_onlinemmddrift(data, kernel, preprocess_custom, backend, tmp_path,
678680

679681

680682
@parametrize_with_cases("data", cases=ContinuousData, prefix='data_')
681-
def test_save_onlinelsdddrift(data, preprocess_custom, backend, tmp_path, seed):
683+
def test_save_onlinelsdddrift(data, preprocess_uae, backend, tmp_path, seed):
682684
"""
683685
Test LSDDDriftOnline on continuous datasets, with UAE as preprocess_fn.
684686
@@ -694,7 +696,7 @@ def test_save_onlinelsdddrift(data, preprocess_custom, backend, tmp_path, seed):
694696
cd = LSDDDriftOnline(X_ref,
695697
ert=ERT,
696698
backend=backend,
697-
preprocess_fn=preprocess_custom,
699+
preprocess_fn=preprocess_uae,
698700
n_bootstraps=N_BOOTSTRAPS,
699701
window_size=WINDOW_SIZE
700702
)
@@ -715,7 +717,7 @@ def test_save_onlinelsdddrift(data, preprocess_custom, backend, tmp_path, seed):
715717
stats_load.append(pred['data']['test_stat'])
716718

717719
# assertions
718-
np.testing.assert_array_almost_equal(preprocess_custom(X_ref), cd_load.get_config()['x_ref'], 5)
720+
np.testing.assert_array_almost_equal(preprocess_uae(X_ref), cd_load.get_config()['x_ref'], 5)
719721
assert cd_load._detector.n_bootstraps == N_BOOTSTRAPS
720722
assert cd_load._detector.ert == ERT
721723
assert isinstance(cd_load._detector.preprocess_fn, Callable)
@@ -726,7 +728,7 @@ def test_save_onlinelsdddrift(data, preprocess_custom, backend, tmp_path, seed):
726728

727729

728730
@parametrize_with_cases("data", cases=ContinuousData, prefix='data_')
729-
def test_save_onlinecvmdrift(data, preprocess_custom, tmp_path, seed):
731+
def test_save_onlinecvmdrift(data, preprocess_uae, tmp_path, seed):
730732
"""
731733
Test CVMDriftOnline on continuous datasets, with UAE as preprocess_fn.
732734
@@ -738,7 +740,7 @@ def test_save_onlinecvmdrift(data, preprocess_custom, tmp_path, seed):
738740
with fixed_seed(seed):
739741
cd = CVMDriftOnline(X_ref,
740742
ert=ERT,
741-
preprocess_fn=preprocess_custom,
743+
preprocess_fn=preprocess_uae,
742744
n_bootstraps=N_BOOTSTRAPS,
743745
window_sizes=[WINDOW_SIZE]
744746
)
@@ -759,7 +761,7 @@ def test_save_onlinecvmdrift(data, preprocess_custom, tmp_path, seed):
759761
stats_load.append(pred['data']['test_stat'])
760762

761763
# assertions
762-
np.testing.assert_array_almost_equal(preprocess_custom(X_ref), cd_load.get_config()['x_ref'], 5)
764+
np.testing.assert_array_almost_equal(preprocess_uae(X_ref), cd_load.get_config()['x_ref'], 5)
763765
assert cd_load.n_bootstraps == N_BOOTSTRAPS
764766
assert cd_load.ert == ERT
765767
assert isinstance(cd_load.preprocess_fn, Callable)
@@ -1100,15 +1102,12 @@ def test_save_deepkernel(data, deep_kernel, backend, tmp_path): # noqa: F811
11001102
assert kernel_loaded.kernel_b.sigma == deep_kernel.kernel_b.sigma
11011103

11021104

1103-
@parametrize('preprocess_fn', [preprocess_custom, preprocess_hiddenoutput])
1105+
@parametrize('preprocess_fn', [preprocess_uae, preprocess_hiddenoutput])
11041106
@parametrize_with_cases("data", cases=ContinuousData.data_synthetic_nd, prefix='data_')
1105-
def test_save_preprocess(data, preprocess_fn, tmp_path, backend):
1107+
def test_save_preprocess_drift(data, preprocess_fn, tmp_path, backend):
11061108
"""
1107-
Unit test for _save_preprocess_config and _load_preprocess_config, with continuous data.
1108-
1109-
preprocess_fn's are saved (serialized) and then loaded, with assertions to check equivalence.
1110-
Note: _save_model_config, _save_embedding_config, _save_tokenizer_config, _load_model_config,
1111-
_load_embedding_config, _load_tokenizer_config and _prep_model_and_embedding are all well covered by this test.
1109+
Test saving/loading of the inbuilt `preprocess_drift` preprocessing functions when containing a `model`, with the
1110+
`model` either being a simple tf/torch model, or a `HiddenOutput` class.
11121111
"""
11131112
registry_str = 'tensorflow' if backend == 'tensorflow' else 'pytorch'
11141113
# Save preprocess_fn to config
@@ -1132,14 +1131,40 @@ def test_save_preprocess(data, preprocess_fn, tmp_path, backend):
11321131
assert isinstance(preprocess_fn_load.keywords['model'], nn.Module)
11331132

11341133

1134+
@parametrize('preprocess_fn', [preprocess_simple, preprocess_simple_with_kwargs])
1135+
def test_save_preprocess_custom(preprocess_fn, tmp_path):
1136+
"""
1137+
Test saving/loading of custom preprocessing functions, without and with kwargs.
1138+
"""
1139+
# Save preprocess_fn to config
1140+
filepath = tmp_path
1141+
cfg_preprocess = _save_preprocess_config(preprocess_fn, input_shape=None, filepath=filepath)
1142+
cfg_preprocess = _path2str(cfg_preprocess)
1143+
cfg_preprocess = PreprocessConfig(**cfg_preprocess).dict() # pydantic validation
1144+
1145+
assert tmp_path.joinpath(cfg_preprocess['src']).is_file()
1146+
assert cfg_preprocess['src'] == os.path.join('preprocess_fn', 'function.dill')
1147+
if isinstance(preprocess_fn, partial): # kwargs expected
1148+
assert cfg_preprocess['kwargs'] == preprocess_fn.keywords
1149+
else: # no kwargs expected
1150+
assert cfg_preprocess['kwargs'] == {}
1151+
1152+
# Resolve and load preprocess config
1153+
cfg = {'preprocess_fn': cfg_preprocess}
1154+
preprocess_fn_load = resolve_config(cfg, tmp_path)['preprocess_fn'] # tests _load_preprocess_config implicitly
1155+
if isinstance(preprocess_fn, partial):
1156+
assert preprocess_fn_load.func == preprocess_fn.func
1157+
assert preprocess_fn_load.keywords == preprocess_fn.keywords
1158+
else:
1159+
assert preprocess_fn_load == preprocess_fn
1160+
1161+
11351162
@parametrize('preprocess_fn', [preprocess_nlp])
11361163
@parametrize_with_cases("data", cases=TextData.movie_sentiment_data, prefix='data_')
11371164
def test_save_preprocess_nlp(data, preprocess_fn, tmp_path, backend):
11381165
"""
1139-
Unit test for _save_preprocess_config and _load_preprocess_config, with text data.
1140-
1141-
Note: _save_model_config, _save_embedding_config, _save_tokenizer_config, _load_model_config,
1142-
_load_embedding_config, _load_tokenizer_config and _prep_model_and_embedding are all covered by this test.
1166+
Test saving/loading of the inbuilt `preprocess_drift` preprocessing functions when containing a `model`, text
1167+
`tokenizer` and text `embedding` model.
11431168
"""
11441169
registry_str = 'tensorflow' if backend == 'tensorflow' else 'pytorch'
11451170
# Save preprocess_fn to config
@@ -1152,6 +1177,8 @@ def test_save_preprocess_nlp(data, preprocess_fn, tmp_path, backend):
11521177
assert cfg_preprocess['src'] == '@cd.' + registry_str + '.preprocess.preprocess_drift'
11531178
assert cfg_preprocess['embedding']['src'] == 'preprocess_fn/embedding'
11541179
assert cfg_preprocess['tokenizer']['src'] == 'preprocess_fn/tokenizer'
1180+
assert tmp_path.joinpath(cfg_preprocess['preprocess_batch_fn']).is_file()
1181+
assert cfg_preprocess['preprocess_batch_fn'] == os.path.join('preprocess_fn', 'preprocess_batch_fn.dill')
11551182

11561183
if isinstance(preprocess_fn.keywords['model'], (TransformerEmbedding_tf, TransformerEmbedding_pt)):
11571184
assert cfg_preprocess['model'] is None

0 commit comments

Comments
 (0)