Skip to content

Commit

Permalink
Fix to allow config.toml to be loaded with [meta] not present (#591)
Browse files Browse the repository at this point in the history
  • Loading branch information
ascillitoe authored Aug 17, 2022
1 parent af57b12 commit cf03eb7
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 13 deletions.
6 changes: 4 additions & 2 deletions alibi_detect/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,10 @@ def from_config(cls, config: dict):
config
A config dictionary matching the schema's in :class:`~alibi_detect.saving.schemas`.
"""
# Check for exisiting version_warning. meta is pop'd as don't want to pass as arg/kwarg
version_warning = config.pop('meta', {}).pop('version_warning', False)
# Check for existing version_warning. meta is pop'd as don't want to pass as arg/kwarg
meta = config.pop('meta', None)
meta = {} if meta is None else meta # Needed because pydantic sets meta=None if it is missing from the config
version_warning = meta.pop('version_warning', False)
# Init detector
detector = cls(**config)
# Add version_warning
Expand Down
2 changes: 1 addition & 1 deletion alibi_detect/saving/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ class DetectorConfig(CustomBaseModel):
"Name of the detector e.g. `MMDDrift`."
backend: Literal['tensorflow', 'pytorch', 'sklearn'] = 'tensorflow'
"The detector backend."
meta: Optional[MetaData]
meta: Optional[MetaData] = None
"Config metadata. Should not be edited."
# Note: Although not all detectors have a backend, we define in base class as `backend` also determines
# whether tf or torch models used for preprocess_fn.
Expand Down
37 changes: 37 additions & 0 deletions alibi_detect/saving/tests/test_saving.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from pathlib import Path
from typing import Callable

import toml
import dill
import numpy as np
import pytest
Expand Down Expand Up @@ -61,6 +62,16 @@
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
REGISTERED_OBJECTS = registry.get_all()

# Minimal MMDDrift detector config used to exercise config loading
MMD_CFG = dict(
    name='MMDDrift',
    x_ref=np.array(
        [[-0.30074928], [1.50240758], [0.43135768], [2.11295779], [0.79684913]]
    ),
    p_val=0.05,
    n_permutations=150,
    data_type='tabular',
)
CFGS = [MMD_CFG]

# TODO - future: Some of the fixtures can/should be moved elsewhere (i.e. if they can be recycled for use elsewhere)


Expand Down Expand Up @@ -259,6 +270,32 @@ def preprocess_hiddenoutput(classifier, backend):
return preprocess_fn


@parametrize('cfg', CFGS)
def test_load_simple_config(cfg, tmp_path):
    """
    Test that a bare-bones `config.toml` without a [meta] field can be loaded by `load_detector`.
    """
    config_path = tmp_path.joinpath('config.toml')
    reference_data = cfg['x_ref']
    # Persist x_ref to disk and point the config entry at the saved file instead
    np.save(str(tmp_path.joinpath('x_ref.npy')), reference_data)
    cfg['x_ref'] = 'x_ref.npy'
    # Write the config to disk, then round-trip it through load_detector
    with open(config_path, 'w') as handle:
        toml.dump(cfg, handle)
    detector = load_detector(config_path)
    assert detector.__class__.__name__ == cfg['name']
    # The original cfg is not fully specified, so only compare the items it contains
    reloaded_cfg = detector.get_config()
    for key, value in cfg.items():
        if key == 'x_ref':
            assert value == 'x_ref.npy'
        else:
            assert value == reloaded_cfg[key]


@parametrize('preprocess_fn', [preprocess_custom, preprocess_hiddenoutput])
@parametrize_with_cases("data", cases=ContinuousData, prefix='data_')
def test_save_ksdrift(data, preprocess_fn, tmp_path):
Expand Down
27 changes: 17 additions & 10 deletions alibi_detect/saving/tests/test_validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from alibi_detect.saving import validate_config
from alibi_detect.saving.saving import X_REF_FILENAME
from alibi_detect.version import __config_spec__, __version__
from copy import deepcopy

# Define a detector config dict
mmd_cfg = {
Expand All @@ -16,19 +17,14 @@
'x_ref': np.array([[-0.30074928], [1.50240758], [0.43135768], [2.11295779], [0.79684913]]),
'p_val': 0.05
}
cfgs = [mmd_cfg]
n_tests = len(cfgs)

# Variant of mmd_cfg with the optional 'meta' section stripped out (as simple as it gets!)
mmd_cfg_nometa = deepcopy(mmd_cfg)
del mmd_cfg_nometa['meta']

@pytest.fixture
def select_cfg(request):
return cfgs[request.param]


@pytest.mark.parametrize('select_cfg', list(range(n_tests)), indirect=True)
def test_validate_config(select_cfg):
cfg = select_cfg

@pytest.mark.parametrize('cfg', [mmd_cfg])
def test_validate_config(cfg):
# Original cfg
# Check original cfg doesn't raise errors
cfg_full = validate_config(cfg, resolved=True)
Expand Down Expand Up @@ -81,3 +77,14 @@ def test_validate_config(select_cfg):
with pytest.raises(ValidationError):
cfg_err = validate_config(cfg_err, resolved=True)
assert not cfg_err.get('meta').get('version_warning')


@pytest.mark.parametrize('cfg', [mmd_cfg_nometa])
def test_validate_config_wo_meta(cfg):
    # A config lacking the optional meta dict should validate in the resolved case
    _ = validate_config(cfg, resolved=True)

    # ...and also in the unresolved case, where x_ref is a filepath rather than an array
    unresolved = cfg.copy()
    unresolved['x_ref'] = X_REF_FILENAME
    _ = validate_config(unresolved)
1 change: 1 addition & 0 deletions alibi_detect/saving/validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def validate_config(cfg: dict, resolved: bool = False) -> dict:

# Get meta data
meta = cfg.get('meta')
meta = {} if meta is None else meta # Needed because pydantic sets meta=None if it is missing from the config
version_warning = meta.get('version_warning', False)
version = meta.get('version', None)
config_spec = meta.get('config_spec', None)
Expand Down

0 comments on commit cf03eb7

Please sign in to comment.