@@ -16,6 +16,7 @@
 from lime.discretize import DecileDiscretizer
 from lime.discretize import EntropyDiscretizer
 from lime.discretize import BaseDiscretizer
+from lime.discretize import StatsDiscretizer
 from . import explanation
 from . import lime_base
@@ -112,7 +113,8 @@ def __init__(self,
                  discretize_continuous=True,
                  discretizer='quartile',
                  sample_around_instance=False,
-                 random_state=None):
+                 random_state=None,
+                 training_data_stats=None):
         """Init function.
 
         Args:
@@ -153,11 +155,21 @@ def __init__(self,
             random_state: an integer or numpy.RandomState that will be used to
                 generate random numbers. If None, the random state will be
                 initialized using the internal numpy seed.
+            training_data_stats: a dict containing training data statistics.
+                If None, statistics will be computed from the training data;
+                this only matters if discretize_continuous is True. Must have
+                the following keys: "means", "mins", "maxs", "stds",
+                "feature_values", "feature_frequencies"
         """
         self.random_state = check_random_state(random_state)
         self.mode = mode
         self.categorical_names = categorical_names or {}
         self.sample_around_instance = sample_around_instance
+        self.training_data_stats = training_data_stats
+
+        # Check and raise a proper error if stats are supplied on the non-discretized path
+        if self.training_data_stats:
+            self.validate_training_data_stats(self.training_data_stats)
 
         if categorical_features is None:
             categorical_features = []
@@ -169,6 +181,12 @@ def __init__(self,
 
         self.discretizer = None
         if discretize_continuous:
+            # Set the discretizer if training data stats are provided
+            if self.training_data_stats:
+                discretizer = StatsDiscretizer(training_data, self.categorical_features,
+                                               self.feature_names, labels=training_labels,
+                                               data_stats=self.training_data_stats)
+
             if discretizer == 'quartile':
                 self.discretizer = QuartileDiscretizer(
                         training_data, self.categorical_features,
@@ -188,7 +206,10 @@ def __init__(self,
                                  ''' 'decile', 'entropy' or a''' +
                                  ''' BaseDiscretizer instance''')
             self.categorical_features = list(range(training_data.shape[1]))
-            discretized_training_data = self.discretizer.discretize(
+
+            # Get the discretized_training_data when the stats are not provided
+            if self.training_data_stats is None:
+                discretized_training_data = self.discretizer.discretize(
                     training_data)
 
         if kernel_width is None:
@@ -203,21 +224,27 @@ def kernel(d, kernel_width):
 
         self.feature_selection = feature_selection
         self.base = lime_base.LimeBase(kernel_fn, verbose, random_state=self.random_state)
-        self.scaler = None
         self.class_names = class_names
+
+        # Set here, though it plays no role if training data stats are provided
+        self.scaler = None
         self.scaler = sklearn.preprocessing.StandardScaler(with_mean=False)
         self.scaler.fit(training_data)
         self.feature_values = {}
         self.feature_frequencies = {}
 
         for feature in self.categorical_features:
-            if self.discretizer is not None:
-                column = discretized_training_data[:, feature]
-            else:
-                column = training_data[:, feature]
+            if training_data_stats is None:
+                if self.discretizer is not None:
+                    column = discretized_training_data[:, feature]
+                else:
+                    column = training_data[:, feature]
 
-            feature_count = collections.Counter(column)
-            values, frequencies = map(list, zip(*(sorted(feature_count.items()))))
+                feature_count = collections.Counter(column)
+                values, frequencies = map(list, zip(*(sorted(feature_count.items()))))
+            else:
+                values = training_data_stats["feature_values"][feature]
+                frequencies = training_data_stats["feature_frequencies"][feature]
 
             self.feature_values[feature] = values
             self.feature_frequencies[feature] = (np.array(frequencies) /
@@ -229,6 +256,17 @@ def kernel(d, kernel_width):
     def convert_and_round(values):
         return ['%.2f' % v for v in values]
 
+    @staticmethod
+    def validate_training_data_stats(training_data_stats):
+        """
+        Method to validate the structure of training data stats
+        """
+        stat_keys = list(training_data_stats.keys())
+        valid_stat_keys = ["means", "mins", "maxs", "stds", "feature_values", "feature_frequencies"]
+        missing_keys = list(set(valid_stat_keys) - set(stat_keys))
+        if len(missing_keys) > 0:
+            raise Exception("Missing keys in training_data_stats. Details: %s" % missing_keys)
+
     def explain_instance(self,
                          data_row,
                          predict_fn,
@@ -414,8 +452,8 @@ def __data_inverse(self,
         categorical_features = range(data_row.shape[0])
         if self.discretizer is None:
             data = self.random_state.normal(
-                    0, 1, num_samples * data_row.shape[0]).reshape(
-                    num_samples, data_row.shape[0])
+                0, 1, num_samples * data_row.shape[0]).reshape(
+                num_samples, data_row.shape[0])
             if self.sample_around_instance:
                 data = data * self.scaler.scale_ + data_row
             else:
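
For reviewers, a minimal sketch of how the new parameter can be exercised end to end. The data, shapes, and variable names here are illustrative, not part of this diff: the stats dict is assembled from a QuartileDiscretizer fitted offline, mirroring the same Counter logic the constructor applies when stats are absent.

import collections

import numpy as np

from lime.discretize import QuartileDiscretizer
from lime.lime_tabular import LimeTabularExplainer

# Illustrative training data; any (n_samples, n_features) float array works.
training_data = np.random.RandomState(0).rand(500, 4)
feature_names = ['f0', 'f1', 'f2', 'f3']

# Fit a discretizer once, offline, to obtain the per-feature bin statistics.
discretizer = QuartileDiscretizer(training_data, [], feature_names)
discretized = discretizer.discretize(training_data)

# Mirror the constructor's Counter logic to collect values/frequencies per column.
feature_values, feature_frequencies = {}, {}
for feature in range(training_data.shape[1]):
    counts = collections.Counter(discretized[:, feature])
    values, frequencies = map(list, zip(*sorted(counts.items())))
    feature_values[feature] = values
    feature_frequencies[feature] = frequencies

training_data_stats = {
    "means": discretizer.means,
    "mins": discretizer.mins,
    "maxs": discretizer.maxs,
    "stds": discretizer.stds,
    "feature_values": feature_values,
    "feature_frequencies": feature_frequencies,
}

# With stats supplied, __init__ validates the keys and routes discretization
# through StatsDiscretizer; omitting any required key raises the exception above.
explainer = LimeTabularExplainer(training_data,
                                 feature_names=feature_names,
                                 discretize_continuous=True,
                                 training_data_stats=training_data_stats)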