Skip to content

Commit

Permalink
Fix two bugs w/ kwargs and preprocess_batch_fn in preprocess_fn s…
Browse files Browse the repository at this point in the history
…erialisation (#752)
  • Loading branch information
ascillitoe authored Mar 3, 2023
1 parent 5c7df29 commit d1c137b
Show file tree
Hide file tree
Showing 6 changed files with 84 additions and 38 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ repos:
hooks:
- id: flake8
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v0.990
rev: v1.0.1
hooks:
- id: mypy
additional_dependencies: [
Expand Down
10 changes: 9 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,16 @@
# Change Log

## v0.12.0dev
[Full Changelog](https://github.com/SeldonIO/alibi-detect/compare/v0.11.0...master)
[Full Changelog](https://github.com/SeldonIO/alibi-detect/compare/v0.11.1...master)

## v0.11.1
[Full Changelog](https://github.com/SeldonIO/alibi-detect/compare/v0.11.0...v0.11.1)

### Fixed

- Fixed two bugs with the saving/loading of drift detector `preprocess_fn`s ([#752](https://github.com/SeldonIO/alibi-detect/pull/752)):
- When `preprocess_fn` was a custom Python function wrapped in a partial, the included kwargs were not serialized. This has now been fixed.
- When saving drift detector `preprocess_fn`'s, for kwargs saved to `.dill` files, the filenames are now prepended with the kwarg name, so that files aren't overwritten if multiple kwargs are saved to `.dill`.

## v0.11.0
[Full Changelog](https://github.com/SeldonIO/alibi-detect/compare/v0.10.5...v0.11.0)
Expand Down
5 changes: 4 additions & 1 deletion alibi_detect/saving/loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,10 @@ def _load_preprocess_config(cfg: dict) -> Optional[Callable]:
logger.warning('Unable to process preprocess_fn. No preprocessing function is defined.')
return None

return partial(preprocess_fn, **kwargs)
if kwargs == {}:
return preprocess_fn
else:
return partial(preprocess_fn, **kwargs)


def _load_model_config(cfg: dict) -> Callable:
Expand Down
8 changes: 4 additions & 4 deletions alibi_detect/saving/saving.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import warnings
from functools import partial
from pathlib import Path
from typing import Callable, Optional, Tuple, Union, Any, TYPE_CHECKING
from typing import Callable, Optional, Tuple, Union, Any, Dict, TYPE_CHECKING
import dill
import numpy as np
import toml
Expand Down Expand Up @@ -264,7 +264,7 @@ def _save_preprocess_config(preprocess_fn: Callable,
The config dictionary, containing references to the serialized artefacts. The format of this dict matches that
of the `preprocess` field in the drift detector specification.
"""
preprocess_cfg = {}
preprocess_cfg: Dict[str, Any] = {}
local_path = Path('preprocess_fn')

# Serialize function
Expand Down Expand Up @@ -292,7 +292,7 @@ def _save_preprocess_config(preprocess_fn: Callable,

# Arbitrary function
elif callable(v):
src, _ = _serialize_object(v, filepath, local_path)
src, _ = _serialize_object(v, filepath, local_path.joinpath(k))
kwargs.update({k: src})

# Put remaining kwargs directly into cfg
Expand All @@ -302,7 +302,7 @@ def _save_preprocess_config(preprocess_fn: Callable,
if 'preprocess_drift' in func:
preprocess_cfg.update(kwargs)
else:
kwargs.update({'kwargs': kwargs})
preprocess_cfg.update({'kwargs': kwargs})

return preprocess_cfg

Expand Down
10 changes: 9 additions & 1 deletion alibi_detect/saving/tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def encoder_dropout_model(backend, current_cases):


@fixture
def preprocess_custom(encoder_model):
def preprocess_uae(encoder_model):
"""
Preprocess function with Untrained Autoencoder.
"""
Expand Down Expand Up @@ -263,6 +263,14 @@ def preprocess_simple(x: np.ndarray):
return x*2.0


@fixture
def preprocess_simple_with_kwargs():
    """
    Fixture returning `preprocess_simple` wrapped in a `functools.partial` that carries two keyword arguments.

    Used to test serialization of a generic Python function with kwargs, within preprocess_fn.
    """
    keywords = {'kwarg1': 42, 'kwarg2': True}
    return partial(preprocess_simple, **keywords)


@fixture
def preprocess_nlp(embedding, tokenizer, max_len, backend):
"""
Expand Down
87 changes: 57 additions & 30 deletions alibi_detect/saving/tests/test_saving.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
Internal functions such as save_kernel/load_kernel_config etc are also tested.
"""
from functools import partial
import os
from pathlib import Path
from typing import Callable

Expand All @@ -19,7 +20,8 @@
import torch.nn as nn

from .datasets import BinData, CategoricalData, ContinuousData, MixedData, TextData
from .models import (encoder_model, preprocess_custom, preprocess_hiddenoutput, preprocess_simple, # noqa: F401
from .models import (encoder_model, preprocess_uae, preprocess_hiddenoutput, preprocess_simple, # noqa: F401
preprocess_simple_with_kwargs,
preprocess_nlp, LATENT_DIM, classifier_model, kernel, deep_kernel, nlp_embedding_and_tokenizer,
embedding, tokenizer, max_len, enc_dim, encoder_dropout_model, optimizer)

Expand Down Expand Up @@ -105,7 +107,7 @@ def test_load_simple_config(cfg, tmp_path):
assert v == cfg_new[k]


@parametrize('preprocess_fn', [preprocess_custom, preprocess_hiddenoutput])
@parametrize('preprocess_fn', [preprocess_uae, preprocess_hiddenoutput])
@parametrize_with_cases("data", cases=ContinuousData, prefix='data_')
def test_save_ksdrift(data, preprocess_fn, tmp_path):
"""
Expand Down Expand Up @@ -171,7 +173,7 @@ def test_save_ksdrift_nlp(data, preprocess_fn, enc_dim, tmp_path): # noqa: F811
@pytest.mark.skipif(version.parse(scipy.__version__) < version.parse('1.7.0'),
reason="Requires scipy version >= 1.7.0")
@parametrize_with_cases("data", cases=ContinuousData, prefix='data_')
def test_save_cvmdrift(data, preprocess_custom, tmp_path):
def test_save_cvmdrift(data, preprocess_uae, tmp_path):
"""
Test CVMDrift on continuous datasets, with UAE as preprocess_fn.
Expand All @@ -181,14 +183,14 @@ def test_save_cvmdrift(data, preprocess_custom, tmp_path):
X_ref, X_h0 = data
cd = CVMDrift(X_ref,
p_val=P_VAL,
preprocess_fn=preprocess_custom,
preprocess_fn=preprocess_uae,
preprocess_at_init=True,
)
save_detector(cd, tmp_path)
cd_load = load_detector(tmp_path)

# Assert
np.testing.assert_array_equal(preprocess_custom(X_ref), cd_load.x_ref)
np.testing.assert_array_equal(preprocess_uae(X_ref), cd_load.x_ref)
assert cd_load.n_features == LATENT_DIM
assert cd_load.p_val == P_VAL
assert isinstance(cd_load.preprocess_fn, Callable)
Expand All @@ -203,7 +205,7 @@ def test_save_cvmdrift(data, preprocess_custom, tmp_path):
], indirect=True
)
@parametrize_with_cases("data", cases=ContinuousData, prefix='data_')
def test_save_mmddrift(data, kernel, preprocess_custom, backend, tmp_path, seed): # noqa: F811
def test_save_mmddrift(data, kernel, preprocess_uae, backend, tmp_path, seed): # noqa: F811
"""
Test MMDDrift on continuous datasets, with UAE as preprocess_fn.
Expand All @@ -217,7 +219,7 @@ def test_save_mmddrift(data, kernel, preprocess_custom, backend, tmp_path, seed)
kwargs = {
'p_val': P_VAL,
'backend': backend,
'preprocess_fn': preprocess_custom,
'preprocess_fn': preprocess_uae,
'n_permutations': N_PERMUTATIONS,
'preprocess_at_init': True,
'kernel': kernel,
Expand All @@ -237,7 +239,7 @@ def test_save_mmddrift(data, kernel, preprocess_custom, backend, tmp_path, seed)
preds_load = cd_load.predict(X_h0)

# assertions
np.testing.assert_array_equal(preprocess_custom(X_ref), cd_load._detector.x_ref)
np.testing.assert_array_equal(preprocess_uae(X_ref), cd_load._detector.x_ref)
assert not cd_load._detector.infer_sigma
assert cd_load._detector.n_permutations == N_PERMUTATIONS
assert cd_load._detector.p_val == P_VAL
Expand All @@ -248,7 +250,7 @@ def test_save_mmddrift(data, kernel, preprocess_custom, backend, tmp_path, seed)
assert preds['data']['p_val'] == preds_load['data']['p_val']


# @parametrize('preprocess_fn', [preprocess_custom, preprocess_hiddenoutput])
# @parametrize('preprocess_fn', [preprocess_uae, preprocess_hiddenoutput])
@parametrize('preprocess_at_init', [True, False])
@parametrize_with_cases("data", cases=ContinuousData, prefix='data_')
def test_save_lsdddrift(data, preprocess_at_init, backend, tmp_path, seed):
Expand Down Expand Up @@ -553,7 +555,7 @@ def test_save_contextmmddrift(data, kernel, backend, tmp_path, seed): # noqa: F
assert cd_load._detector.n_permutations == N_PERMUTATIONS
assert cd_load._detector.p_val == P_VAL
assert isinstance(cd_load._detector.preprocess_fn, Callable)
assert cd_load._detector.preprocess_fn.func.__name__ == 'preprocess_simple'
assert cd_load._detector.preprocess_fn.__name__ == 'preprocess_simple'
assert cd._detector.x_kernel.sigma == cd_load._detector.x_kernel.sigma
assert cd._detector.c_kernel.sigma == cd_load._detector.c_kernel.sigma
assert cd._detector.x_kernel.init_sigma_fn == cd_load._detector.x_kernel.init_sigma_fn
Expand Down Expand Up @@ -629,7 +631,7 @@ def test_save_regressoruncertaintydrift(data, regressor, backend, tmp_path, seed
], indirect=True
)
@parametrize_with_cases("data", cases=ContinuousData, prefix='data_')
def test_save_onlinemmddrift(data, kernel, preprocess_custom, backend, tmp_path, seed): # noqa: F811
def test_save_onlinemmddrift(data, kernel, preprocess_uae, backend, tmp_path, seed): # noqa: F811
"""
Test MMDDriftOnline on continuous datasets, with UAE as preprocess_fn.
Expand All @@ -645,7 +647,7 @@ def test_save_onlinemmddrift(data, kernel, preprocess_custom, backend, tmp_path,
cd = MMDDriftOnline(X_ref,
ert=ERT,
backend=backend,
preprocess_fn=preprocess_custom,
preprocess_fn=preprocess_uae,
n_bootstraps=N_BOOTSTRAPS,
kernel=kernel,
window_size=WINDOW_SIZE
Expand All @@ -667,7 +669,7 @@ def test_save_onlinemmddrift(data, kernel, preprocess_custom, backend, tmp_path,
stats_load.append(pred['data']['test_stat'])

# assertions
np.testing.assert_array_equal(preprocess_custom(X_ref), cd_load._detector.x_ref)
np.testing.assert_array_equal(preprocess_uae(X_ref), cd_load._detector.x_ref)
assert cd_load._detector.n_bootstraps == N_BOOTSTRAPS
assert cd_load._detector.ert == ERT
assert isinstance(cd_load._detector.preprocess_fn, Callable)
Expand All @@ -678,7 +680,7 @@ def test_save_onlinemmddrift(data, kernel, preprocess_custom, backend, tmp_path,


@parametrize_with_cases("data", cases=ContinuousData, prefix='data_')
def test_save_onlinelsdddrift(data, preprocess_custom, backend, tmp_path, seed):
def test_save_onlinelsdddrift(data, preprocess_uae, backend, tmp_path, seed):
"""
Test LSDDDriftOnline on continuous datasets, with UAE as preprocess_fn.
Expand All @@ -694,7 +696,7 @@ def test_save_onlinelsdddrift(data, preprocess_custom, backend, tmp_path, seed):
cd = LSDDDriftOnline(X_ref,
ert=ERT,
backend=backend,
preprocess_fn=preprocess_custom,
preprocess_fn=preprocess_uae,
n_bootstraps=N_BOOTSTRAPS,
window_size=WINDOW_SIZE
)
Expand All @@ -715,7 +717,7 @@ def test_save_onlinelsdddrift(data, preprocess_custom, backend, tmp_path, seed):
stats_load.append(pred['data']['test_stat'])

# assertions
np.testing.assert_array_almost_equal(preprocess_custom(X_ref), cd_load.get_config()['x_ref'], 5)
np.testing.assert_array_almost_equal(preprocess_uae(X_ref), cd_load.get_config()['x_ref'], 5)
assert cd_load._detector.n_bootstraps == N_BOOTSTRAPS
assert cd_load._detector.ert == ERT
assert isinstance(cd_load._detector.preprocess_fn, Callable)
Expand All @@ -726,7 +728,7 @@ def test_save_onlinelsdddrift(data, preprocess_custom, backend, tmp_path, seed):


@parametrize_with_cases("data", cases=ContinuousData, prefix='data_')
def test_save_onlinecvmdrift(data, preprocess_custom, tmp_path, seed):
def test_save_onlinecvmdrift(data, preprocess_uae, tmp_path, seed):
"""
Test CVMDriftOnline on continuous datasets, with UAE as preprocess_fn.
Expand All @@ -738,7 +740,7 @@ def test_save_onlinecvmdrift(data, preprocess_custom, tmp_path, seed):
with fixed_seed(seed):
cd = CVMDriftOnline(X_ref,
ert=ERT,
preprocess_fn=preprocess_custom,
preprocess_fn=preprocess_uae,
n_bootstraps=N_BOOTSTRAPS,
window_sizes=[WINDOW_SIZE]
)
Expand All @@ -759,7 +761,7 @@ def test_save_onlinecvmdrift(data, preprocess_custom, tmp_path, seed):
stats_load.append(pred['data']['test_stat'])

# assertions
np.testing.assert_array_almost_equal(preprocess_custom(X_ref), cd_load.get_config()['x_ref'], 5)
np.testing.assert_array_almost_equal(preprocess_uae(X_ref), cd_load.get_config()['x_ref'], 5)
assert cd_load.n_bootstraps == N_BOOTSTRAPS
assert cd_load.ert == ERT
assert isinstance(cd_load.preprocess_fn, Callable)
Expand Down Expand Up @@ -1100,15 +1102,12 @@ def test_save_deepkernel(data, deep_kernel, backend, tmp_path): # noqa: F811
assert kernel_loaded.kernel_b.sigma == deep_kernel.kernel_b.sigma


@parametrize('preprocess_fn', [preprocess_custom, preprocess_hiddenoutput])
@parametrize('preprocess_fn', [preprocess_uae, preprocess_hiddenoutput])
@parametrize_with_cases("data", cases=ContinuousData.data_synthetic_nd, prefix='data_')
def test_save_preprocess(data, preprocess_fn, tmp_path, backend):
def test_save_preprocess_drift(data, preprocess_fn, tmp_path, backend):
"""
Unit test for _save_preprocess_config and _load_preprocess_config, with continuous data.
preprocess_fn's are saved (serialized) and then loaded, with assertions to check equivalence.
Note: _save_model_config, _save_embedding_config, _save_tokenizer_config, _load_model_config,
_load_embedding_config, _load_tokenizer_config and _prep_model_and_embedding are all well covered by this test.
Test saving/loading of the inbuilt `preprocess_drift` preprocessing functions when containing a `model`, with the
`model` either being a simple tf/torch model, or a `HiddenOutput` class.
"""
registry_str = 'tensorflow' if backend == 'tensorflow' else 'pytorch'
# Save preprocess_fn to config
Expand All @@ -1132,14 +1131,40 @@ def test_save_preprocess(data, preprocess_fn, tmp_path, backend):
assert isinstance(preprocess_fn_load.keywords['model'], nn.Module)


@parametrize('preprocess_fn', [preprocess_simple, preprocess_simple_with_kwargs])
def test_save_preprocess_custom(preprocess_fn, tmp_path):
    """
    Round-trip a custom preprocessing function through `_save_preprocess_config` and `resolve_config`.

    Two cases are covered: a plain function (no kwargs expected in the config) and a function wrapped in a
    `functools.partial` (its keywords are expected under the config's `kwargs` field).
    """
    has_kwargs = isinstance(preprocess_fn, partial)

    # Save preprocess_fn to config, then run the result through pydantic validation
    saved_cfg = _save_preprocess_config(preprocess_fn, input_shape=None, filepath=tmp_path)
    saved_cfg = PreprocessConfig(**_path2str(saved_cfg)).dict()

    # The function should have been written to preprocess_fn/function.dill under tmp_path
    assert tmp_path.joinpath(saved_cfg['src']).is_file()
    assert saved_cfg['src'] == os.path.join('preprocess_fn', 'function.dill')

    # kwargs are only expected when preprocess_fn is a partial
    expected_kwargs = preprocess_fn.keywords if has_kwargs else {}
    assert saved_cfg['kwargs'] == expected_kwargs

    # Resolve and load preprocess config (tests _load_preprocess_config implicitly)
    loaded_fn = resolve_config({'preprocess_fn': saved_cfg}, tmp_path)['preprocess_fn']
    if not has_kwargs:
        assert loaded_fn == preprocess_fn
        return
    assert loaded_fn.func == preprocess_fn.func
    assert loaded_fn.keywords == preprocess_fn.keywords


@parametrize('preprocess_fn', [preprocess_nlp])
@parametrize_with_cases("data", cases=TextData.movie_sentiment_data, prefix='data_')
def test_save_preprocess_nlp(data, preprocess_fn, tmp_path, backend):
"""
Unit test for _save_preprocess_config and _load_preprocess_config, with text data.
Note: _save_model_config, _save_embedding_config, _save_tokenizer_config, _load_model_config,
_load_embedding_config, _load_tokenizer_config and _prep_model_and_embedding are all covered by this test.
Test saving/loading of the inbuilt `preprocess_drift` preprocessing functions when containing a `model`, text
`tokenizer` and text `embedding` model.
"""
registry_str = 'tensorflow' if backend == 'tensorflow' else 'pytorch'
# Save preprocess_fn to config
Expand All @@ -1152,6 +1177,8 @@ def test_save_preprocess_nlp(data, preprocess_fn, tmp_path, backend):
assert cfg_preprocess['src'] == '@cd.' + registry_str + '.preprocess.preprocess_drift'
assert cfg_preprocess['embedding']['src'] == 'preprocess_fn/embedding'
assert cfg_preprocess['tokenizer']['src'] == 'preprocess_fn/tokenizer'
assert tmp_path.joinpath(cfg_preprocess['preprocess_batch_fn']).is_file()
assert cfg_preprocess['preprocess_batch_fn'] == os.path.join('preprocess_fn', 'preprocess_batch_fn.dill')

if isinstance(preprocess_fn.keywords['model'], (TransformerEmbedding_tf, TransformerEmbedding_pt)):
assert cfg_preprocess['model'] is None
Expand Down

0 comments on commit d1c137b

Please sign in to comment.