
Commit e458ab3

Docstrings updates (#548)
* ale docs update
* ale some conventions updates
* ale single-under
* moving back to original return format with inconsistency italic/normal font
* first pass through anchor_base
* back to standard returns without :
* anchor_text and anchor_explanation first pass
* first pass through anchor_text
* first pass through cfproto and cem
* first pass through cfrl_base
* first pass through cfrl_tabular
* first pass through shap_wrappers
* first pass through models
* First pass through backend top
* first pass through backends pytorch
* first pass through backends tensorflow
* small updates
* minor correction
* example duplicated attribute - TabularSampler
* replaced attributes docstrings
* second pass through explainers.
* added explain fields up to cfproto (inclusive)
* add description of explanation return fields
* add links to docstrings + minor corrections
* tensor to array
* fixed duplicated target names
* first pass confidence docs
* fixed broken links in shap
* included links in defaults SHAP
* minor fixes
* fixed minor indentation and punctuation. private IG build_explanation
* Fixed IG and test_shap_wrappers build_explanation
* minor updates on defaults, interfaces, autoencoder and anchortabular
* fixed approximation_methods
* fixed data
* fixed discretizer
* fix app_methods, distance & distributed
* fix language model
* first pass through utils
* second pass through utils
* fixed mypy errors
1 parent ce961ca · commit e458ab3

Some content is hidden: large commits have part of their diff hidden by default, so only a subset of the changed files is shown below.

46 files changed: 1894 additions & 1405 deletions

alibi/api/defaults.py

Lines changed: 4 additions & 2 deletions
@@ -104,7 +104,8 @@
     'kwargs',
 ]
 """
-KernelShap parameters updated and return in metadata['params'].
+KernelShap parameters updated and returned in ``metadata['params']``.
+See :py:class:`alibi.explainers.shap_wrappers.KernelShap`.
 """

 DEFAULT_META_KERNEL_SHAP = {

@@ -172,7 +173,8 @@
     'kwargs'
 ]
 """
-TreeShap parameters updated and return in metadata['params'].
+TreeShap parameters updated and returned in ``metadata['params']``.
+See :py:class:`alibi.explainers.shap_wrappers.TreeShap`.
 """

 DEFAULT_META_TREE_SHAP = {
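
For context, the ``metadata['params']`` dictionary these docstrings point to is the ``meta['params']`` mapping exposed on a fitted explainer. A minimal sketch of inspecting it, assuming `shap` and scikit-learn are installed (the model and data here are illustrative, not from this commit):

import numpy as np
from sklearn.linear_model import LogisticRegression
from alibi.explainers import KernelShap

X_ref = np.random.rand(50, 4)                      # illustrative background data
y_ref = (X_ref.sum(axis=1) > 2).astype(int)
clf = LogisticRegression().fit(X_ref, y_ref)

explainer = KernelShap(clf.predict_proba)          # wrap the prediction function
explainer.fit(X_ref)                               # parameters get recorded in the metadata

# the keys listed in the default parameter list above end up here:
print(sorted(explainer.meta['params'].keys()))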

alibi/api/interfaces.py

Lines changed: 17 additions & 9 deletions
@@ -68,7 +68,7 @@ class Explainer(abc.ABC):
     """
     Base class for explainer algorithms
     """
-    meta = attr.ib(default=attr.Factory(default_meta), repr=alibi_pformat)  # type: dict
+    meta: dict = attr.ib(default=attr.Factory(default_meta), repr=alibi_pformat)  #: Explainer meta-data.

     def __attrs_post_init__(self):
         # add a name and version to the metadata dictionary

@@ -102,6 +102,14 @@ def load(cls, path: Union[str, os.PathLike], predictor: Any) -> "Explainer":
         return load_explainer(path, predictor)

     def reset_predictor(self, predictor: Any) -> None:
+        """
+        Resets the predictor.
+
+        Parameters
+        ----------
+        predictor
+            New predictor.
+        """
         raise NotImplementedError

     def save(self, path: Union[str, os.PathLike]) -> None:

@@ -118,14 +126,14 @@ def save(self, path: Union[str, os.PathLike]) -> None:
     def _update_metadata(self, data_dict: dict, params: bool = False) -> None:
         """
         Updates the metadata of the explainer using the data from the `data_dict`. If the params option
-        is specified, then each key-value pair is added to the metadata `'params'` dictionary.
+        is specified, then each key-value pair is added to the metadata ``'params'`` dictionary.

         Parameters
         ----------
         data_dict
             Contains the data to be stored in the metadata.
         params
-            If True, the method updates the `'params'` attribute of the metatadata.
+            If ``True``, the method updates the ``'params'`` attribute of the metadata.
         """

         if params:

@@ -151,34 +159,34 @@ class Explanation:

     def __attrs_post_init__(self):
         """
-        Expose keys stored in self.meta and self.data as attributes of the class.
+        Expose keys stored in `self.meta` and `self.data` as attributes of the class.
         """
         for key, value in ChainMap(self.meta, self.data).items():
             setattr(self, key, value)

     def to_json(self) -> str:
         """
-        Serialize the explanation data and metadata into a json format.
+        Serialize the explanation data and metadata into a `json` format.

         Returns
         -------
-        String containing json representation of the explanation
+        String containing `json` representation of the explanation.
         """
         return json.dumps(attr.asdict(self), cls=NumpyEncoder)

     @classmethod
     def from_json(cls, jsonrepr) -> "Explanation":
         """
-        Create an instance of an Explanation class using a json representation of the Explanation.
+        Create an instance of an `Explanation` class using a `json` representation of the `Explanation`.

         Parameters
         ----------
         jsonrepr
-            json representation of an explanation
+            `json` representation of an explanation.

         Returns
         -------
-        An Explanation object
+        An Explanation object.
         """
         dictrepr = json.loads(jsonrepr)
         try:
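
The methods touched above are the generic hooks shared by all explainers: `reset_predictor` swaps in a new model without rebuilding the explainer, and `Explanation` objects round-trip through JSON via `to_json`/`from_json`. A minimal round-trip sketch (the toy meta/data values are illustrative; real explanations are returned by an explainer's `explain` call):

from alibi.api.interfaces import Explanation

# a toy explanation for demonstration purposes
explanation = Explanation(meta={'name': 'DemoExplainer', 'version': '0.0.0'},
                          data={'score': [0.1, 0.9]})

json_repr = explanation.to_json()            # serialize meta + data to a JSON string
restored = Explanation.from_json(json_repr)  # rebuild an Explanation from that string

# __attrs_post_init__ exposes the keys of meta/data as attributes
assert restored.meta == explanation.meta
assert restored.score == [0.1, 0.9]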

alibi/confidence/model_linearity.py

Lines changed: 43 additions & 46 deletions
@@ -16,15 +16,15 @@ def _linear_superposition(alphas, vecs, shape):
     Parameters
     ----------
     alphas
-        Coefficients of the superposition
+        Coefficients of the superposition.
     vecs
-        Tensors of the superposition
+        Tensors of the superposition.
     shape
-        Shape of each tensor
+        Shape of each tensor.

     Returns
     -------
-    Linear tensor superposition
+    Linear tensor superposition.
     """
     input_str = string.ascii_lowercase[2: 2 + len(shape)]
     einstr = 'a,ba{}->b{}'.format(input_str, input_str)
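
The two context lines at the end of this hunk build an `einsum` subscript that contracts the coefficient vector against a batch of stacked tensors. A standalone sketch of that contraction with toy shapes (the array values are illustrative):

import string
import numpy as np

alphas = np.array([0.5, 0.5])              # coefficients of the superposition
vecs = np.random.rand(8, 2, 3, 4)          # (nb_instances, nb_terms, *feature_shape)
shape = (3, 4)                             # shape of each tensor

# same subscript construction as in _linear_superposition:
# 'a,bacd->bcd' computes sum_a alphas[a] * vecs[b, a, ...] for every instance b
input_str = string.ascii_lowercase[2: 2 + len(shape)]
einstr = 'a,ba{}->b{}'.format(input_str, input_str)
superposition = np.einsum(einstr, alphas, vecs)        # -> shape (8, 3, 4)

# explicit equivalent of the contraction
manual = (alphas[None, :, None, None] * vecs).sum(axis=1)
assert np.allclose(superposition, manual)
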
@@ -39,19 +39,19 @@ def _calculate_global_linearity(predict_fn: Callable, input_shape: Tuple, X_samp
     Parameters
     ----------
     predict_fn
-        Model prediction function
+        Model prediction function.
     input_shape
-        Shape of the input
+        Shape of the input.
     X_samples
-        Array of feature vectors in the linear superposition
+        Array of feature vectors in the linear superposition.
     model_type
-        'classifier' or 'regressor'
+        Supported values: ``'classifier'`` | ``'regressor'``.
     alphas
-        Array of coefficients in the linear superposition
+        Array of coefficients in the linear superposition.

     Returns
     -------
-    Linearity score
+    Linearity score.

     """
     ss = X_samples.shape[:2]  # X_samples shape=(nb_instances, nb_samples, nb_features)

@@ -105,27 +105,27 @@ def _calculate_global_linearity(predict_fn: Callable, input_shape: Tuple, X_samp

 def _calculate_pairwise_linearity(predict_fn: Callable, x: np.ndarray, input_shape: Tuple, X_samples: np.ndarray,
                                   model_type: str, alphas: np.ndarray) -> np.ndarray:
-    """Calculates the norm of the difference between the output of a linear superposition of a test vector x and
-    vectors in X_samples and the linear superposition of the outputs, averaged over all the vectors in X_samples.
+    """Calculates the norm of the difference between the output of a linear superposition of a test vector `x` and
+    vectors in `X_samples` and the linear superposition of the outputs, averaged over all the vectors in `X_samples`.

     Parameters
     ----------
     predict_fn
-        Model prediction function
+        Model prediction function.
     x
-        Test instance for which to calculate the linearity measure
+        Test instance for which to calculate the linearity measure.
     input_shape
-        Shape of the input
+        Shape of the input.
     X_samples
-        Array of feature vectors in the linear superposition
+        Array of feature vectors in the linear superposition.
     model_type
-        'classifier' or 'regressor'
+        Supported values: ``'classifier'`` | ``'regressor'``.
     alphas
-        Array of coefficients in the linear superposition
+        Array of coefficients in the linear superposition.

     Returns
     -------
-    Linearity score
+    Linearity score.

     """
     ss = X_samples.shape[:2]  # X_samples shape=(nb_instances, nb_samples, nb_features)
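
Both helpers measure the same quantity: the norm of the difference between the model applied to a linear superposition of inputs and that superposition applied to the model outputs; the 'pairwise' variant averages this over pairs formed with the test instance. A toy sketch of the global variant for a multi-output regressor (illustrative only; the real functions are batched and treat classifier probabilities separately):

import numpy as np

def toy_global_linearity(predict_fn, X_samples, alphas):
    # X_samples: (nb_samples, nb_features), alphas: (nb_samples,)
    out_of_superposition = predict_fn(alphas @ X_samples)      # f(sum_i a_i x_i)
    superposition_of_outputs = alphas @ predict_fn(X_samples)  # sum_i a_i f(x_i)
    return np.linalg.norm(out_of_superposition - superposition_of_outputs)

W = np.array([[1.0, -2.0], [0.5, 1.5], [2.0, 0.0]])

def linear_model(X):
    return X @ W                                               # an exactly linear map

X_samples = np.random.rand(5, 3)
alphas = np.full(5, 1.0 / 5)                                   # convex combination

print(toy_global_linearity(linear_model, X_samples, alphas))   # ~0 for a linear model
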
@@ -190,15 +190,15 @@ def _sample_knn(x: np.ndarray, X_train: np.ndarray, nb_samples: int = 10) -> np.
     Parameters
     ----------
     x
-        Central instance for sampling
+        Central instance for sampling.
     X_train
         Training set.
     nb_samples
         Number of samples to generate.

     Returns
     -------
-    Sampled vectors
+    Sampled vectors.

     """
     x = x.reshape(x.shape[0], -1)

@@ -221,23 +221,23 @@ def _sample_knn(x: np.ndarray, X_train: np.ndarray, nb_samples: int = 10) -> np.

 def _sample_grid(x: np.ndarray, feature_range: np.ndarray, epsilon: float = 0.04,
                  nb_samples: int = 10, res: int = 100) -> np.ndarray:
-    """Samples data points uniformly from an interval centered at x and with size epsilon * Delta,
-    with delta = f_max - f_min the features ranges.
+    """Samples data points uniformly from an interval centered at `x` and with size `epsilon * delta`,
+    with `delta = f_max - f_min` the features ranges.

     Parameters
     ----------
     x
         Instance of interest.
     feature_range
-        Array with min and max values for each feature
+        Array with min and max values for each feature.
     epsilon
         Size of the sampling region around central instance as percentage of features range.
     nb_samples
         Number of samples to generate.

     Returns
     -------
-    Sampled vectors
+    Sampled vectors.

     """
     nb_instances = x.shape[0]
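
A compact sketch of the grid-style sampling described in this docstring: draw points uniformly in a box of width `epsilon * (f_max - f_min)` centered on the instance (illustrative only; the real `_sample_grid` also takes a `res` argument, omitted here):

import numpy as np

def toy_sample_grid(x, feature_range, epsilon=0.04, nb_samples=10):
    # x: (nb_features,); feature_range: (nb_features, 2) with columns [min, max]
    delta = feature_range[:, 1] - feature_range[:, 0]   # f_max - f_min per feature
    half_width = 0.5 * epsilon * delta
    low, high = x - half_width, x + half_width
    return np.random.uniform(low, high, size=(nb_samples, x.shape[0]))

x = np.array([0.2, 0.5, 0.8])
feature_range = np.array([[0.0, 1.0]] * 3)              # each feature lives in [0, 1]
print(toy_sample_grid(x, feature_range).shape)          # (10, 3)
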
@@ -271,7 +271,7 @@ def _linearity_measure(predict_fn: Callable,
                        alphas: Optional[np.ndarray] = None,
                        model_type: str = 'classifier',
                        agg: str = 'global') -> np.ndarray:
-    """Calculate the linearity measure of the model around an instance of interest x.
+    """Calculate the linearity measure of the model around an instance of interest `x`.

     Parameters
     ----------

@@ -284,7 +284,7 @@ def _linearity_measure(predict_fn: Callable,
     feature_range
         Array with min and max values for each feature.
     method
-        Method for sampling. Supported values 'knn' or 'grid'.
+        Method for sampling. Supported values: ``'knn'`` | ``'grid'``.
     epsilon
         Size of the sampling region around the central instance as a percentage of feature range.
     nb_samples

@@ -294,13 +294,13 @@ def _linearity_measure(predict_fn: Callable,
     alphas
         Array of coefficients in the superposition.
     model_type
-        Type of task. Supported values are 'regressor' or 'classifier'.
+        Type of task. Supported values: ``'regressor'`` | ``'classifier'``.
     agg
-        Aggregation method. Supported values are 'global' or 'pairwise'.
+        Aggregation method. Supported values: ``'global'`` | ``'pairwise'``.

     Returns
     -------
-    Linearity score
+    Linearity score.

     """
     input_shape = x.shape[1:]

@@ -339,11 +339,11 @@ def _infer_feature_range(X_train: np.ndarray) -> np.ndarray:
     Parameters
     ----------
     X_train
-        Training set
+        Training set.

     Returns
     -------
-    Feature range
+    Feature range.
     """
     X_train = X_train.reshape(X_train.shape[0], -1)
     return np.vstack((X_train.min(axis=0), X_train.max(axis=0))).T

@@ -365,7 +365,7 @@ def __init__(self,
         Parameters
         ----------
         method
-            Method for sampling. Supported methods are 'knn' or 'grid'.
+            Method for sampling. Supported methods: ``'knn'`` | ``'grid'``.
         epsilon
             Size of the sampling region around the central instance as a percentage of the features range.
         nb_samples

@@ -375,9 +375,9 @@ def __init__(self,
         alphas
             Coefficients in the superposition.
         agg
-            Aggregation method. Supported values are 'global' or 'pairwise'.
+            Aggregation method. Supported values: ``'global'`` | ``'pairwise'``.
         model_type
-            Type of task. Supported values are 'regressor' or 'classifier'.
+            Type of task. Supported values: ``'regressor'`` | ``'classifier'``.
         """
         self.method = method
         self.epsilon = epsilon

@@ -395,11 +395,8 @@ def fit(self, X_train: np.ndarray) -> None:
         Parameters
         ----------
         X_train
-            Training set
+            Training set.

-        Returns
-        -------
-        None
         """
         self.X_train = X_train
         self.feature_range = _infer_feature_range(X_train)

@@ -412,13 +409,13 @@ def score(self, predict_fn: Callable, x: np.ndarray) -> np.ndarray:
         Parameters
         ----------
         predict_fn
-            Prediction function
+            Prediction function.
         x
-            Instance of interest
+            Instance of interest.

         Returns
         -------
-        Linearity measure
+        Linearity measure.

         """
         input_shape = x.shape[1:]

@@ -466,7 +463,7 @@ def linearity_measure(predict_fn: Callable,
     feature_range
         Array with min and max values for each feature.
     method
-        Method for sampling. Supported values 'knn' or 'grid'.
+        Method for sampling. Supported values: ``'knn'`` | ``'grid'``.
     X_train
         Training set.
     epsilon

@@ -478,13 +475,13 @@ def linearity_measure(predict_fn: Callable,
     alphas
         Coefficients in the superposition.
     agg
-        Aggregation method. Supported values 'global' or 'pairwise'.
+        Aggregation method. Supported values: ``'global'`` | ``'pairwise'``.
     model_type
-        Type of task. Supported values 'regressor' or 'classifier'.
+        Type of task. Supported values: ``'regressor'`` | ``'classifier'``.

     Returns
     -------
-    Linearity measure
+    Linearity measure.

     """
     if method == 'knn':
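
Putting the documented public API together, a minimal usage sketch of `LinearityMeasure` (assuming scikit-learn is installed; the data, model and keyword values are illustrative, with keyword names taken from the docstrings above):

import numpy as np
from sklearn.ensemble import RandomForestClassifier
from alibi.confidence import LinearityMeasure

X_train = np.random.rand(200, 4)
y_train = (X_train[:, 0] + X_train[:, 1] > 1).astype(int)
clf = RandomForestClassifier(n_estimators=50).fit(X_train, y_train)

lm = LinearityMeasure(method='grid', epsilon=0.04, nb_samples=10,
                      agg='pairwise', model_type='classifier')
lm.fit(X_train)                                    # infers the per-feature ranges
scores = lm.score(clf.predict_proba, X_train[:5])  # linearity around five instances
print(scores)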
