Skip to content

Commit 886bedd

Browse files
committed
Merge branch 'sokollip-data-stats-support-in-lime'
2 parents ec5df45 + 15b59c8 commit 886bedd

File tree

3 files changed

+161
-15
lines changed

3 files changed

+161
-15
lines changed

lime/discretize.py

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@ class BaseDiscretizer():
1818

1919
__metaclass__ = ABCMeta # abstract class
2020

21-
def __init__(self, data, categorical_features, feature_names, labels=None, random_state=None):
21+
def __init__(self, data, categorical_features, feature_names, labels=None, random_state=None,
22+
data_stats=None):
2223
"""Initializer
2324
Args:
2425
data: numpy 2d array
@@ -31,9 +32,12 @@ def __init__(self, data, categorical_features, feature_names, labels=None, rando
3132
column x.
3233
feature_names: list of names (strings) corresponding to the columns
3334
in the training data.
35+
data_stats: must have 'means', 'stds', 'mins' and 'maxs', use this
36+
if you don't want these values to be computed from data
3437
"""
3538
self.to_discretize = ([x for x in range(data.shape[1])
36-
if x not in categorical_features])
39+
if x not in categorical_features])
40+
self.data_stats = data_stats
3741
self.names = {}
3842
self.lambdas = {}
3943
self.means = {}
@@ -46,6 +50,13 @@ def __init__(self, data, categorical_features, feature_names, labels=None, rando
4650
bins = self.bins(data, labels)
4751
bins = [np.unique(x) for x in bins]
4852

53+
# Read the stats from data_stats if it exists
54+
if data_stats:
55+
self.means = self.data_stats.get("means")
56+
self.stds = self.data_stats.get("stds")
57+
self.mins = self.data_stats.get("mins")
58+
self.maxs = self.data_stats.get("maxs")
59+
4960
for feature, qts in zip(self.to_discretize, bins):
5061
n_bins = qts.shape[0] # Actually number of borders (= #bins-1)
5162
boundaries = np.min(data[:, feature]), np.max(data[:, feature])
@@ -60,6 +71,10 @@ def __init__(self, data, categorical_features, feature_names, labels=None, rando
6071
self.lambdas[feature] = lambda x, qts=qts: np.searchsorted(qts, x)
6172
discretized = self.lambdas[feature](data[:, feature])
6273

74+
# If data stats are provided, there is no need to compute the details below
75+
if data_stats:
76+
continue
77+
6378
self.means[feature] = []
6479
self.stds[feature] = []
6580
for x in range(n_bins + 1):
@@ -117,6 +132,31 @@ def get_inverse(q):
117132
return ret
118133

119134

135+
class StatsDiscretizer(BaseDiscretizer):
    """Discretizer driven by pre-computed training data statistics.

    Used when discretize_continuous is True but the caller supplies the
    discretization details (keys 'bins', 'means', 'stds', 'mins', 'maxs'
    in ``data_stats``) instead of having them computed from the data.
    """

    def __init__(self, data, categorical_features, feature_names, labels=None, random_state=None,
                 data_stats=None):
        """Initializer.

        Args:
            data: numpy 2d array (used by the base class for shapes and
                boundary values).
            categorical_features: list of indices of categorical columns;
                everything else is discretized.
            feature_names: list of names (strings) corresponding to the
                columns in the training data.
            labels: optional training labels (unused by this discretizer).
            random_state: integer or numpy.RandomState for reproducibility.
            data_stats: dict holding the pre-computed statistics; must
                contain 'bins' for this class to produce any bins.
        """
        BaseDiscretizer.__init__(self, data, categorical_features,
                                 feature_names, labels=labels,
                                 random_state=random_state,
                                 data_stats=data_stats)

    def bins(self, data, labels):
        """Return bin boundaries read from ``data_stats['bins']``.

        Features absent from the supplied bins are silently skipped —
        assumes the stats cover every feature in ``to_discretize``
        (NOTE(review): confirm against callers).
        """
        # Guard against a missing stats dict instead of raising AttributeError.
        bins_from_stats = (self.data_stats or {}).get("bins")
        bins = []
        if bins_from_stats is not None:
            for feature in self.to_discretize:
                bins_from_stats_feature = bins_from_stats.get(feature)
                if bins_from_stats_feature is not None:
                    qts = np.array(bins_from_stats_feature)
                    bins.append(qts)
        return bins
120160
class QuartileDiscretizer(BaseDiscretizer):
121161
def __init__(self, data, categorical_features, feature_names, labels=None, random_state=None):
122162

lime/lime_tabular.py

Lines changed: 49 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from lime.discretize import DecileDiscretizer
1717
from lime.discretize import EntropyDiscretizer
1818
from lime.discretize import BaseDiscretizer
19+
from lime.discretize import StatsDiscretizer
1920
from . import explanation
2021
from . import lime_base
2122

@@ -112,7 +113,8 @@ def __init__(self,
112113
discretize_continuous=True,
113114
discretizer='quartile',
114115
sample_around_instance=False,
115-
random_state=None):
116+
random_state=None,
117+
training_data_stats=None):
116118
"""Init function.
117119
118120
Args:
@@ -153,11 +155,21 @@ def __init__(self,
153155
random_state: an integer or numpy.RandomState that will be used to
154156
generate random numbers. If None, the random state will be
155157
initialized using the internal numpy seed.
158+
training_data_stats: a dict object having the details of training data
159+
statistics. If None, training data information will be used, only matters
160+
if discretize_continuous is True. Must have the following keys:
161+
"means", "mins", "maxs", "stds", "feature_values",
162+
"feature_frequencies"
156163
"""
157164
self.random_state = check_random_state(random_state)
158165
self.mode = mode
159166
self.categorical_names = categorical_names or {}
160167
self.sample_around_instance = sample_around_instance
168+
self.training_data_stats = training_data_stats
169+
170+
# Check and raise a proper error if stats are supplied in the non-discretized path
171+
if self.training_data_stats:
172+
self.validate_training_data_stats(self.training_data_stats)
161173

162174
if categorical_features is None:
163175
categorical_features = []
@@ -169,6 +181,12 @@ def __init__(self,
169181

170182
self.discretizer = None
171183
if discretize_continuous:
184+
# Set the discretizer if training data stats are provided
185+
if self.training_data_stats:
186+
discretizer = StatsDiscretizer(training_data, self.categorical_features,
187+
self.feature_names, labels=training_labels,
188+
data_stats=self.training_data_stats)
189+
172190
if discretizer == 'quartile':
173191
self.discretizer = QuartileDiscretizer(
174192
training_data, self.categorical_features,
@@ -188,7 +206,10 @@ def __init__(self,
188206
''' 'decile', 'entropy' or a''' +
189207
''' BaseDiscretizer instance''')
190208
self.categorical_features = list(range(training_data.shape[1]))
191-
discretized_training_data = self.discretizer.discretize(
209+
210+
# Get the discretized_training_data when the stats are not provided
211+
if(self.training_data_stats is None):
212+
discretized_training_data = self.discretizer.discretize(
192213
training_data)
193214

194215
if kernel_width is None:
@@ -203,21 +224,27 @@ def kernel(d, kernel_width):
203224

204225
self.feature_selection = feature_selection
205226
self.base = lime_base.LimeBase(kernel_fn, verbose, random_state=self.random_state)
206-
self.scaler = None
207227
self.class_names = class_names
228+
229+
# Note: the scaler has no role to play if training data stats are provided
230+
self.scaler = None
208231
self.scaler = sklearn.preprocessing.StandardScaler(with_mean=False)
209232
self.scaler.fit(training_data)
210233
self.feature_values = {}
211234
self.feature_frequencies = {}
212235

213236
for feature in self.categorical_features:
214-
if self.discretizer is not None:
215-
column = discretized_training_data[:, feature]
216-
else:
217-
column = training_data[:, feature]
237+
if training_data_stats is None:
238+
if self.discretizer is not None:
239+
column = discretized_training_data[:, feature]
240+
else:
241+
column = training_data[:, feature]
218242

219-
feature_count = collections.Counter(column)
220-
values, frequencies = map(list, zip(*(sorted(feature_count.items()))))
243+
feature_count = collections.Counter(column)
244+
values, frequencies = map(list, zip(*(sorted(feature_count.items()))))
245+
else:
246+
values = training_data_stats["feature_values"][feature]
247+
frequencies = training_data_stats["feature_frequencies"][feature]
221248

222249
self.feature_values[feature] = values
223250
self.feature_frequencies[feature] = (np.array(frequencies) /
@@ -229,6 +256,17 @@ def kernel(d, kernel_width):
229256
def convert_and_round(values):
230257
return ['%.2f' % v for v in values]
231258

259+
@staticmethod
260+
def validate_training_data_stats(training_data_stats):
261+
"""
262+
Method to validate the structure of training data stats
263+
"""
264+
stat_keys = list(training_data_stats.keys())
265+
valid_stat_keys = ["means", "mins", "maxs", "stds", "feature_values", "feature_frequencies"]
266+
missing_keys = list(set(valid_stat_keys) - set(stat_keys))
267+
if len(missing_keys) > 0:
268+
raise Exception("Missing keys in training_data_stats. Details:" % (missing_keys))
269+
232270
def explain_instance(self,
233271
data_row,
234272
predict_fn,
@@ -414,8 +452,8 @@ def __data_inverse(self,
414452
categorical_features = range(data_row.shape[0])
415453
if self.discretizer is None:
416454
data = self.random_state.normal(
417-
0, 1, num_samples * data_row.shape[0]).reshape(
418-
num_samples, data_row.shape[0])
455+
0, 1, num_samples * data_row.shape[0]).reshape(
456+
num_samples, data_row.shape[0])
419457
if self.sample_around_instance:
420458
data = data * self.scaler.scale_ + data_row
421459
else:

lime/tests/test_lime_tabular.py

Lines changed: 70 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,19 @@
11
import unittest
22

33
import numpy as np
4-
import sklearn # noqa
4+
import collections
5+
import sklearn # noqa
56
import sklearn.datasets
67
import sklearn.ensemble
7-
import sklearn.linear_model # noqa
8+
import sklearn.linear_model # noqa
89
from numpy.testing import assert_array_equal
910
from sklearn.datasets import load_iris, make_classification
1011
from sklearn.ensemble import RandomForestClassifier
1112
from sklearn.linear_model import Lasso
1213
from sklearn.linear_model import LinearRegression
1314
from lime.discretize import QuartileDiscretizer, DecileDiscretizer, EntropyDiscretizer
1415

16+
1517
try:
1618
from sklearn.model_selection import train_test_split
1719
except ImportError:
@@ -577,6 +579,72 @@ def testFeatureValues(self):
577579
assert_array_equal(explainer.feature_frequencies[1], np.array([.25, .25, .25, .25]))
578580
assert_array_equal(explainer.feature_frequencies[2], np.array([.5, .5]))
579581

582+
def test_lime_explainer_with_data_stats(self):
    """End-to-end check that LimeTabularExplainer works when driven by
    pre-computed training data statistics instead of real training data
    (the explainer receives an all-zeros array plus the stats dict)."""
    np.random.seed(1)

    rf = RandomForestClassifier(n_estimators=500)
    rf.fit(self.train, self.labels_train)
    i = np.random.randint(0, self.test.shape[0])

    # Generate stats using a quartile discretizer
    discretizer = QuartileDiscretizer(self.train, [], self.feature_names,
                                      self.target_names, random_state=20)

    d_means = discretizer.means
    d_stds = discretizer.stds
    d_mins = discretizer.mins
    d_maxs = discretizer.maxs
    d_bins = discretizer.bins(self.train, self.target_names)

    # Compute feature values and frequencies of all columns
    cat_features = np.arange(self.train.shape[1])
    discretized_training_data = discretizer.discretize(self.train)

    feature_values = {}
    feature_frequencies = {}
    for feature in cat_features:
        column = discretized_training_data[:, feature]
        feature_count = collections.Counter(column)
        values, frequencies = map(list, zip(*(feature_count.items())))
        feature_values[feature] = values
        feature_frequencies[feature] = frequencies

    # Convert bins from arrays to lists, keyed by feature index
    d_bins_revised = {index: qts.tolist() for index, qts in enumerate(d_bins)}

    # Discretized stats
    data_stats = {
        "means": d_means,
        "stds": d_stds,
        "maxs": d_maxs,
        "mins": d_mins,
        "bins": d_bins_revised,
        "feature_values": feature_values,
        "feature_frequencies": feature_frequencies,
    }

    data = np.zeros((2, len(self.feature_names)))
    explainer = LimeTabularExplainer(
        data, feature_names=self.feature_names, random_state=10,
        training_data_stats=data_stats, training_labels=self.target_names)

    exp = explainer.explain_instance(self.test[i],
                                     rf.predict_proba,
                                     num_features=2,
                                     model_regressor=LinearRegression())

    self.assertIsNotNone(exp)
    keys = [x[0] for x in exp.as_list()]
    self.assertEqual(1,
                     sum([1 if 'petal width' in x else 0 for x in keys]),
                     "Petal Width is a major feature")
    self.assertEqual(1,
                     sum([1 if 'petal length' in x else 0 for x in keys]),
                     "Petal Length is a major feature")
580648

581649
if __name__ == '__main__':
582650
unittest.main()

0 commit comments

Comments
 (0)