5
5
Internal functions such as save_kernel/load_kernel_config etc are also tested.
6
6
"""
7
7
from functools import partial
8
+ import os
8
9
from pathlib import Path
9
10
from typing import Callable
10
11
19
20
import torch .nn as nn
20
21
21
22
from .datasets import BinData , CategoricalData , ContinuousData , MixedData , TextData
22
- from .models import (encoder_model , preprocess_custom , preprocess_hiddenoutput , preprocess_simple , # noqa: F401
23
+ from .models import (encoder_model , preprocess_uae , preprocess_hiddenoutput , preprocess_simple , # noqa: F401
24
+ preprocess_simple_with_kwargs ,
23
25
preprocess_nlp , LATENT_DIM , classifier_model , kernel , deep_kernel , nlp_embedding_and_tokenizer ,
24
26
embedding , tokenizer , max_len , enc_dim , encoder_dropout_model , optimizer )
25
27
@@ -105,7 +107,7 @@ def test_load_simple_config(cfg, tmp_path):
105
107
assert v == cfg_new [k ]
106
108
107
109
108
- @parametrize ('preprocess_fn' , [preprocess_custom , preprocess_hiddenoutput ])
110
+ @parametrize ('preprocess_fn' , [preprocess_uae , preprocess_hiddenoutput ])
109
111
@parametrize_with_cases ("data" , cases = ContinuousData , prefix = 'data_' )
110
112
def test_save_ksdrift (data , preprocess_fn , tmp_path ):
111
113
"""
@@ -171,7 +173,7 @@ def test_save_ksdrift_nlp(data, preprocess_fn, enc_dim, tmp_path): # noqa: F811
171
173
@pytest .mark .skipif (version .parse (scipy .__version__ ) < version .parse ('1.7.0' ),
172
174
reason = "Requires scipy version >= 1.7.0" )
173
175
@parametrize_with_cases ("data" , cases = ContinuousData , prefix = 'data_' )
174
- def test_save_cvmdrift (data , preprocess_custom , tmp_path ):
176
+ def test_save_cvmdrift (data , preprocess_uae , tmp_path ):
175
177
"""
176
178
Test CVMDrift on continuous datasets, with UAE as preprocess_fn.
177
179
@@ -181,14 +183,14 @@ def test_save_cvmdrift(data, preprocess_custom, tmp_path):
181
183
X_ref , X_h0 = data
182
184
cd = CVMDrift (X_ref ,
183
185
p_val = P_VAL ,
184
- preprocess_fn = preprocess_custom ,
186
+ preprocess_fn = preprocess_uae ,
185
187
preprocess_at_init = True ,
186
188
)
187
189
save_detector (cd , tmp_path )
188
190
cd_load = load_detector (tmp_path )
189
191
190
192
# Assert
191
- np .testing .assert_array_equal (preprocess_custom (X_ref ), cd_load .x_ref )
193
+ np .testing .assert_array_equal (preprocess_uae (X_ref ), cd_load .x_ref )
192
194
assert cd_load .n_features == LATENT_DIM
193
195
assert cd_load .p_val == P_VAL
194
196
assert isinstance (cd_load .preprocess_fn , Callable )
@@ -203,7 +205,7 @@ def test_save_cvmdrift(data, preprocess_custom, tmp_path):
203
205
], indirect = True
204
206
)
205
207
@parametrize_with_cases ("data" , cases = ContinuousData , prefix = 'data_' )
206
- def test_save_mmddrift (data , kernel , preprocess_custom , backend , tmp_path , seed ): # noqa: F811
208
+ def test_save_mmddrift (data , kernel , preprocess_uae , backend , tmp_path , seed ): # noqa: F811
207
209
"""
208
210
Test MMDDrift on continuous datasets, with UAE as preprocess_fn.
209
211
@@ -217,7 +219,7 @@ def test_save_mmddrift(data, kernel, preprocess_custom, backend, tmp_path, seed)
217
219
kwargs = {
218
220
'p_val' : P_VAL ,
219
221
'backend' : backend ,
220
- 'preprocess_fn' : preprocess_custom ,
222
+ 'preprocess_fn' : preprocess_uae ,
221
223
'n_permutations' : N_PERMUTATIONS ,
222
224
'preprocess_at_init' : True ,
223
225
'kernel' : kernel ,
@@ -237,7 +239,7 @@ def test_save_mmddrift(data, kernel, preprocess_custom, backend, tmp_path, seed)
237
239
preds_load = cd_load .predict (X_h0 )
238
240
239
241
# assertions
240
- np .testing .assert_array_equal (preprocess_custom (X_ref ), cd_load ._detector .x_ref )
242
+ np .testing .assert_array_equal (preprocess_uae (X_ref ), cd_load ._detector .x_ref )
241
243
assert not cd_load ._detector .infer_sigma
242
244
assert cd_load ._detector .n_permutations == N_PERMUTATIONS
243
245
assert cd_load ._detector .p_val == P_VAL
@@ -248,7 +250,7 @@ def test_save_mmddrift(data, kernel, preprocess_custom, backend, tmp_path, seed)
248
250
assert preds ['data' ]['p_val' ] == preds_load ['data' ]['p_val' ]
249
251
250
252
251
- # @parametrize('preprocess_fn', [preprocess_custom , preprocess_hiddenoutput])
253
+ # @parametrize('preprocess_fn', [preprocess_uae , preprocess_hiddenoutput])
252
254
@parametrize ('preprocess_at_init' , [True , False ])
253
255
@parametrize_with_cases ("data" , cases = ContinuousData , prefix = 'data_' )
254
256
def test_save_lsdddrift (data , preprocess_at_init , backend , tmp_path , seed ):
@@ -553,7 +555,7 @@ def test_save_contextmmddrift(data, kernel, backend, tmp_path, seed): # noqa: F
553
555
assert cd_load ._detector .n_permutations == N_PERMUTATIONS
554
556
assert cd_load ._detector .p_val == P_VAL
555
557
assert isinstance (cd_load ._detector .preprocess_fn , Callable )
556
- assert cd_load ._detector .preprocess_fn .func . __name__ == 'preprocess_simple'
558
+ assert cd_load ._detector .preprocess_fn .__name__ == 'preprocess_simple'
557
559
assert cd ._detector .x_kernel .sigma == cd_load ._detector .x_kernel .sigma
558
560
assert cd ._detector .c_kernel .sigma == cd_load ._detector .c_kernel .sigma
559
561
assert cd ._detector .x_kernel .init_sigma_fn == cd_load ._detector .x_kernel .init_sigma_fn
@@ -629,7 +631,7 @@ def test_save_regressoruncertaintydrift(data, regressor, backend, tmp_path, seed
629
631
], indirect = True
630
632
)
631
633
@parametrize_with_cases ("data" , cases = ContinuousData , prefix = 'data_' )
632
- def test_save_onlinemmddrift (data , kernel , preprocess_custom , backend , tmp_path , seed ): # noqa: F811
634
+ def test_save_onlinemmddrift (data , kernel , preprocess_uae , backend , tmp_path , seed ): # noqa: F811
633
635
"""
634
636
Test MMDDriftOnline on continuous datasets, with UAE as preprocess_fn.
635
637
@@ -645,7 +647,7 @@ def test_save_onlinemmddrift(data, kernel, preprocess_custom, backend, tmp_path,
645
647
cd = MMDDriftOnline (X_ref ,
646
648
ert = ERT ,
647
649
backend = backend ,
648
- preprocess_fn = preprocess_custom ,
650
+ preprocess_fn = preprocess_uae ,
649
651
n_bootstraps = N_BOOTSTRAPS ,
650
652
kernel = kernel ,
651
653
window_size = WINDOW_SIZE
@@ -667,7 +669,7 @@ def test_save_onlinemmddrift(data, kernel, preprocess_custom, backend, tmp_path,
667
669
stats_load .append (pred ['data' ]['test_stat' ])
668
670
669
671
# assertions
670
- np .testing .assert_array_equal (preprocess_custom (X_ref ), cd_load ._detector .x_ref )
672
+ np .testing .assert_array_equal (preprocess_uae (X_ref ), cd_load ._detector .x_ref )
671
673
assert cd_load ._detector .n_bootstraps == N_BOOTSTRAPS
672
674
assert cd_load ._detector .ert == ERT
673
675
assert isinstance (cd_load ._detector .preprocess_fn , Callable )
@@ -678,7 +680,7 @@ def test_save_onlinemmddrift(data, kernel, preprocess_custom, backend, tmp_path,
678
680
679
681
680
682
@parametrize_with_cases ("data" , cases = ContinuousData , prefix = 'data_' )
681
- def test_save_onlinelsdddrift (data , preprocess_custom , backend , tmp_path , seed ):
683
+ def test_save_onlinelsdddrift (data , preprocess_uae , backend , tmp_path , seed ):
682
684
"""
683
685
Test LSDDDriftOnline on continuous datasets, with UAE as preprocess_fn.
684
686
@@ -694,7 +696,7 @@ def test_save_onlinelsdddrift(data, preprocess_custom, backend, tmp_path, seed):
694
696
cd = LSDDDriftOnline (X_ref ,
695
697
ert = ERT ,
696
698
backend = backend ,
697
- preprocess_fn = preprocess_custom ,
699
+ preprocess_fn = preprocess_uae ,
698
700
n_bootstraps = N_BOOTSTRAPS ,
699
701
window_size = WINDOW_SIZE
700
702
)
@@ -715,7 +717,7 @@ def test_save_onlinelsdddrift(data, preprocess_custom, backend, tmp_path, seed):
715
717
stats_load .append (pred ['data' ]['test_stat' ])
716
718
717
719
# assertions
718
- np .testing .assert_array_almost_equal (preprocess_custom (X_ref ), cd_load .get_config ()['x_ref' ], 5 )
720
+ np .testing .assert_array_almost_equal (preprocess_uae (X_ref ), cd_load .get_config ()['x_ref' ], 5 )
719
721
assert cd_load ._detector .n_bootstraps == N_BOOTSTRAPS
720
722
assert cd_load ._detector .ert == ERT
721
723
assert isinstance (cd_load ._detector .preprocess_fn , Callable )
@@ -726,7 +728,7 @@ def test_save_onlinelsdddrift(data, preprocess_custom, backend, tmp_path, seed):
726
728
727
729
728
730
@parametrize_with_cases ("data" , cases = ContinuousData , prefix = 'data_' )
729
- def test_save_onlinecvmdrift (data , preprocess_custom , tmp_path , seed ):
731
+ def test_save_onlinecvmdrift (data , preprocess_uae , tmp_path , seed ):
730
732
"""
731
733
Test CVMDriftOnline on continuous datasets, with UAE as preprocess_fn.
732
734
@@ -738,7 +740,7 @@ def test_save_onlinecvmdrift(data, preprocess_custom, tmp_path, seed):
738
740
with fixed_seed (seed ):
739
741
cd = CVMDriftOnline (X_ref ,
740
742
ert = ERT ,
741
- preprocess_fn = preprocess_custom ,
743
+ preprocess_fn = preprocess_uae ,
742
744
n_bootstraps = N_BOOTSTRAPS ,
743
745
window_sizes = [WINDOW_SIZE ]
744
746
)
@@ -759,7 +761,7 @@ def test_save_onlinecvmdrift(data, preprocess_custom, tmp_path, seed):
759
761
stats_load .append (pred ['data' ]['test_stat' ])
760
762
761
763
# assertions
762
- np .testing .assert_array_almost_equal (preprocess_custom (X_ref ), cd_load .get_config ()['x_ref' ], 5 )
764
+ np .testing .assert_array_almost_equal (preprocess_uae (X_ref ), cd_load .get_config ()['x_ref' ], 5 )
763
765
assert cd_load .n_bootstraps == N_BOOTSTRAPS
764
766
assert cd_load .ert == ERT
765
767
assert isinstance (cd_load .preprocess_fn , Callable )
@@ -1100,15 +1102,12 @@ def test_save_deepkernel(data, deep_kernel, backend, tmp_path): # noqa: F811
1100
1102
assert kernel_loaded .kernel_b .sigma == deep_kernel .kernel_b .sigma
1101
1103
1102
1104
1103
- @parametrize ('preprocess_fn' , [preprocess_custom , preprocess_hiddenoutput ])
1105
+ @parametrize ('preprocess_fn' , [preprocess_uae , preprocess_hiddenoutput ])
1104
1106
@parametrize_with_cases ("data" , cases = ContinuousData .data_synthetic_nd , prefix = 'data_' )
1105
- def test_save_preprocess (data , preprocess_fn , tmp_path , backend ):
1107
+ def test_save_preprocess_drift (data , preprocess_fn , tmp_path , backend ):
1106
1108
"""
1107
- Unit test for _save_preprocess_config and _load_preprocess_config, with continuous data.
1108
-
1109
- preprocess_fn's are saved (serialized) and then loaded, with assertions to check equivalence.
1110
- Note: _save_model_config, _save_embedding_config, _save_tokenizer_config, _load_model_config,
1111
- _load_embedding_config, _load_tokenizer_config and _prep_model_and_embedding are all well covered by this test.
1109
+ Test saving/loading of the inbuilt `preprocess_drift` preprocessing functions when containing a `model`, with the
1110
+ `model` either being a simple tf/torch model, or a `HiddenOutput` class.
1112
1111
"""
1113
1112
registry_str = 'tensorflow' if backend == 'tensorflow' else 'pytorch'
1114
1113
# Save preprocess_fn to config
@@ -1132,14 +1131,40 @@ def test_save_preprocess(data, preprocess_fn, tmp_path, backend):
1132
1131
assert isinstance (preprocess_fn_load .keywords ['model' ], nn .Module )
1133
1132
1134
1133
1134
+ @parametrize ('preprocess_fn' , [preprocess_simple , preprocess_simple_with_kwargs ])
1135
+ def test_save_preprocess_custom (preprocess_fn , tmp_path ):
1136
+ """
1137
+ Test saving/loading of custom preprocessing functions, without and with kwargs.
1138
+ """
1139
+ # Save preprocess_fn to config
1140
+ filepath = tmp_path
1141
+ cfg_preprocess = _save_preprocess_config (preprocess_fn , input_shape = None , filepath = filepath )
1142
+ cfg_preprocess = _path2str (cfg_preprocess )
1143
+ cfg_preprocess = PreprocessConfig (** cfg_preprocess ).dict () # pydantic validation
1144
+
1145
+ assert tmp_path .joinpath (cfg_preprocess ['src' ]).is_file ()
1146
+ assert cfg_preprocess ['src' ] == os .path .join ('preprocess_fn' , 'function.dill' )
1147
+ if isinstance (preprocess_fn , partial ): # kwargs expected
1148
+ assert cfg_preprocess ['kwargs' ] == preprocess_fn .keywords
1149
+ else : # no kwargs expected
1150
+ assert cfg_preprocess ['kwargs' ] == {}
1151
+
1152
+ # Resolve and load preprocess config
1153
+ cfg = {'preprocess_fn' : cfg_preprocess }
1154
+ preprocess_fn_load = resolve_config (cfg , tmp_path )['preprocess_fn' ] # tests _load_preprocess_config implicitly
1155
+ if isinstance (preprocess_fn , partial ):
1156
+ assert preprocess_fn_load .func == preprocess_fn .func
1157
+ assert preprocess_fn_load .keywords == preprocess_fn .keywords
1158
+ else :
1159
+ assert preprocess_fn_load == preprocess_fn
1160
+
1161
+
1135
1162
@parametrize ('preprocess_fn' , [preprocess_nlp ])
1136
1163
@parametrize_with_cases ("data" , cases = TextData .movie_sentiment_data , prefix = 'data_' )
1137
1164
def test_save_preprocess_nlp (data , preprocess_fn , tmp_path , backend ):
1138
1165
"""
1139
- Unit test for _save_preprocess_config and _load_preprocess_config, with text data.
1140
-
1141
- Note: _save_model_config, _save_embedding_config, _save_tokenizer_config, _load_model_config,
1142
- _load_embedding_config, _load_tokenizer_config and _prep_model_and_embedding are all covered by this test.
1166
+ Test saving/loading of the inbuilt `preprocess_drift` preprocessing functions when containing a `model`, text
1167
+ `tokenizer` and text `embedding` model.
1143
1168
"""
1144
1169
registry_str = 'tensorflow' if backend == 'tensorflow' else 'pytorch'
1145
1170
# Save preprocess_fn to config
@@ -1152,6 +1177,8 @@ def test_save_preprocess_nlp(data, preprocess_fn, tmp_path, backend):
1152
1177
assert cfg_preprocess ['src' ] == '@cd.' + registry_str + '.preprocess.preprocess_drift'
1153
1178
assert cfg_preprocess ['embedding' ]['src' ] == 'preprocess_fn/embedding'
1154
1179
assert cfg_preprocess ['tokenizer' ]['src' ] == 'preprocess_fn/tokenizer'
1180
+ assert tmp_path .joinpath (cfg_preprocess ['preprocess_batch_fn' ]).is_file ()
1181
+ assert cfg_preprocess ['preprocess_batch_fn' ] == os .path .join ('preprocess_fn' , 'preprocess_batch_fn.dill' )
1155
1182
1156
1183
if isinstance (preprocess_fn .keywords ['model' ], (TransformerEmbedding_tf , TransformerEmbedding_pt )):
1157
1184
assert cfg_preprocess ['model' ] is None
0 commit comments